# In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import model
import os

def mask_parse(mask):
    """Replicate a single-channel mask into three identical channels.

    A trailing channel axis is appended to ``mask`` and filled with three
    copies of it, e.g. a (512, 512) mask becomes (512, 512, 3). The dtype of
    the input is preserved.
    """
    # Stacking along a new last axis is equivalent to expand_dims(-1)
    # followed by concatenating three copies on that axis.
    channels = (mask, mask, mask)
    return np.stack(channels, axis=-1)

# Load the trained network and move it to CUDA if available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: this rebinds the name ``model`` from the imported project module to the
# network instance; process_images_in_folder uses this instance by that name.
model = model.build_unet()
# os.path.join keeps the checkpoint path portable across Windows and POSIX.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load(os.path.join('files', 'ridge_gabor.pth'), map_location=device))
model.eval().to(device)  # inference mode, then move weights to the chosen device

# PIL image -> float tensor of layout (C, H, W) with values scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor()
])

def load_image(image_path, size=(512, 512)):
    """Load an image file as a model-ready batch tensor.

    The image is converted to RGB and resized to ``size``. It is then passed
    through the module-level ``transform``, given a leading batch axis, and
    moved to the module-level ``device``.

    Args:
        image_path: Path of the image file to read.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to 512x512, the size used throughout this script.

    Returns:
        Tensor of shape (1, C, H, W) on ``device``.

    Raises:
        OSError: If the file cannot be opened or is not a readable image
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    # The context manager releases the file handle deterministically.
    # convert() returns a new, fully loaded image, so it outlives the block.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = image.resize(size)
    return transform(image).unsqueeze(0).to(device)

def get_prediction(model, image_tensor):
    """Run ``model`` on a batch and return a 3-channel binary mask.

    Only the first element of the batch is used. Its raw outputs go through
    a sigmoid and the leading (channel) axis is squeezed away. Values above
    0.5 become 1 and the rest become 0. The resulting uint8 mask is then
    replicated to three channels by ``mask_parse``.
    """
    with torch.no_grad():
        probs = torch.sigmoid(model(image_tensor))
        # First batch element, copied to host memory as a NumPy array.
        first = probs[0].cpu().numpy()
        # Dropping axis 0 assumes a single output channel (np.squeeze raises
        # otherwise) -- TODO confirm against build_unet.
        binary = (np.squeeze(first, axis=0) > 0.5).astype(np.uint8)
        return mask_parse(binary)

def superimpose_mask(original_image, mask, alpha=0.5):
    """Blend a predicted mask over the image it was predicted from.

    Args:
        original_image: Batch tensor of shape (1, C, H, W). Values are
            expected in [0, 1], e.g. the output of ``load_image``.
        mask: uint8 array of shape (H, W, 3) holding 0/1 values, e.g. the
            output of ``get_prediction``.
        alpha: Weight of the mask in the blend (0.0 = image only,
            1.0 = mask only). Defaults to 0.5.

    Returns:
        A ``PIL.Image`` equal to ``image * (1 - alpha) + mask * alpha``.
    """
    # (1, C, H, W) -> (H, W, C) NumPy array on the CPU, as PIL expects.
    pixels = original_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Round before casting. ToTensor stores x/255 as float32, so x/255*255 can
    # land just below x (e.g. 17 -> 16.999998), and a bare astype would
    # truncate that to x-1. Clipping guards against values slightly outside [0, 1].
    pixels = np.clip(np.rint(pixels * 255), 0, 255).astype(np.uint8)
    original_image_pil = Image.fromarray(pixels)

    # Scale the 0/1 mask to 0/255 so foreground pixels blend in as white.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    return Image.blend(original_image_pil, mask_image, alpha=alpha)

def process_images_in_folder(image_folder, mask_output_folder):
    """Predict, blend and save a mask overlay for every image under a folder.

    ``image_folder`` is walked recursively and its directory tree is mirrored
    under ``mask_output_folder``. Each input ``<sub>/<name>`` is written as
    ``<mask_output_folder>/<sub>/blended_<name>``. Files that cannot be opened
    as images are reported and skipped instead of aborting the whole run.

    Prediction uses the module-level ``model``.
    """
    for root, _dirs, files in os.walk(image_folder):
        # Output directory mirroring ``root``. normpath folds the '.' that
        # relpath returns for the top-level folder itself.
        output_dir = os.path.normpath(
            os.path.join(mask_output_folder, os.path.relpath(root, image_folder))
        )
        # os.walk visits every directory, so this mirrors the whole tree,
        # including directories that contain no images.
        os.makedirs(output_dir, exist_ok=True)

        for file_name in files:
            image_path = os.path.join(root, file_name)

            try:
                original_image_tensor = load_image(image_path)
            except OSError as err:  # PIL's UnidentifiedImageError is an OSError
                print(f"Skipping {image_path}: {err}")
                continue

            predicted_mask = get_prediction(model, original_image_tensor)
            blended_image = superimpose_mask(original_image_tensor, predicted_mask)

            # Prefix the file name only. Prefixing the relative path
            # ("blended_" + "sub/x.png") would rename the subfolder instead.
            blended_save_path = os.path.join(output_dir, f"blended_{file_name}")
            blended_image.save(blended_save_path)

            print(f"Saved blended image for {file_name} at {blended_save_path}")

# Main folder paths. os.path.join keeps them portable across Windows and POSIX.
image_folder = os.path.join('new_data_ridge', 'test', 'to_send')  # Root folder with subfolders and images
mask_output_folder = 'output'  # Root output folder where blended mask overlays are written

# Process all images in the folder and subfolders
process_images_in_folder(image_folder, mask_output_folder)

print("Mask generation, superimposition, and saving complete.")
