# In [None]:
import os
import glob
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image


class UNetGenerator(nn.Module):
    """Four-level U-Net generator with channel-wise skip connections.

    Encoder stages ``d1``..``d4`` are 4x4 stride-2 convolutions (each halves
    the spatial size); decoder stages ``u1``..``u4`` are 4x4 stride-2
    transposed convolutions (each doubles it). The outputs of ``u1``, ``u2``
    and ``u3`` are concatenated along the channel axis with the encoder
    features of matching resolution (``d3``, ``d2``, ``d1`` respectively), and
    a 3x3 convolution followed by ``Tanh`` yields ``out_channels`` maps in
    ``[-1, 1]``.

    Because of the four halvings, input height and width must be multiples of
    16, otherwise the skip concatenations fail on mismatched sizes.

    Args:
        in_channels: Number of input channels (default 4).
        out_channels: Number of output channels (default 3).
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()

        def _encoder_stage(c_in, c_out, normalize=True):
            # Conv(4x4, stride 2, pad 1) -> [BatchNorm] -> LeakyReLU(0.2).
            stage = [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1)]
            if normalize:
                # 0.8 is BatchNorm2d's *eps* (its second positional parameter).
                # NOTE(review): unusually large; possibly intended as momentum.
                # Kept as-is because changing it would alter the outputs of a
                # model trained with this setting.
                stage.append(nn.BatchNorm2d(c_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*stage)

        def _decoder_stage(c_in, c_out, dropout=False):
            # ConvTranspose(4x4, stride 2, pad 1) -> BatchNorm -> ReLU -> [Dropout].
            stage = [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out, eps=0.8),
                nn.ReLU(inplace=True),
            ]
            if dropout:
                stage.append(nn.Dropout(0.5))
            return nn.Sequential(*stage)

        # Attribute names and their registration order define the state_dict
        # keys that saved checkpoints expect -- do not rename or reorder.
        self.d1 = _encoder_stage(in_channels, 64, normalize=False)
        self.d2 = _encoder_stage(64, 128)
        self.d3 = _encoder_stage(128, 256)
        self.d4 = _encoder_stage(256, 512)  # bottleneck

        # Input widths of u2..u4 include the concatenated skip channels.
        self.u1 = _decoder_stage(512, 256)
        self.u2 = _decoder_stage(256 + 256, 128)  # u1 output + d3
        self.u3 = _decoder_stage(128 + 128, 64)   # u2 output + d2
        self.u4 = _decoder_stage(64 + 64, 64)     # u3 output + d1

        self.final = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # squashes the output into [-1, 1]
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        encoder_features = []
        h = x
        for stage in (self.d1, self.d2, self.d3):
            h = stage(h)
            encoder_features.append(h)
        h = self.d4(h)  # bottleneck; no skip is taken from here

        # Deepest skip first: u1 pairs with d3, u2 with d2, u3 with d1.
        for stage, skip in zip((self.u1, self.u2, self.u3), reversed(encoder_features)):
            h = torch.cat([stage(h), skip], dim=1)

        return self.final(self.u4(h))


class InferenceDataset(Dataset):
    """Pairs each context image with a same-named mask and blanks the image centre.

    Every ``*.png`` directly inside ``img_root`` is an item; its mask is the
    file with the same basename inside ``mask_root``. Each item is the tuple
    ``(img, masked_img, mask_img, i, mask_path)``:

    * ``img`` -- the image converted to RGB, then passed through ``transforms_``.
    * ``masked_img`` -- a copy of ``img`` whose centred ``mask_size`` x
      ``mask_size`` square is set to 1.
    * ``mask_img`` -- the mask as a 1-channel float tensor of 0/1 values,
      resized to ``img_size`` x ``img_size`` with nearest-neighbour sampling.
    * ``i`` -- top/left offset of the blanked square.
    * ``mask_path`` -- path of the mask file that was read.

    ``img_size`` is only used for the mask resize and the centre offset; the
    caller's ``transforms_`` must produce ``img_size`` x ``img_size`` images.

    NOTE(review): the blanked region is always the fixed centre square; the
    loaded mask only becomes part of the returned tuple and does not choose
    which pixels are blanked. Confirm this matches how the generator was trained.
    """

    def __init__(self, img_root, mask_root, transforms_, img_size=128, mask_size=64):
        # Validate first: a square larger than the image gives a negative
        # offset, and negative slice starts silently index from the far edge.
        if not 0 <= mask_size <= img_size:
            raise ValueError(
                f"mask_size must be between 0 and img_size ({img_size}), got {mask_size}"
            )

        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mask_size = mask_size

        self.files = sorted(glob.glob(os.path.join(img_root, "*.png")))
        self.mask_files = [os.path.join(mask_root, os.path.basename(f)) for f in self.files]

        # Nearest-neighbour keeps the mask binary-valued after resizing.
        self.mask_transform = transforms.Compose([
            transforms.Resize((img_size, img_size), Image.NEAREST),
            transforms.ToTensor(),
        ])

    def apply_center_mask(self, img):
        """Return a copy of ``img`` (C, H, W) with the centre square set to 1, and its offset."""
        i = (self.img_size - self.mask_size) // 2
        masked_img = img.clone()
        masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1
        return masked_img, i

    def __getitem__(self, index):
        img_path = self.files[index]
        mask_path = self.mask_files[index]

        # Force three channels: RGBA, grayscale or palette PNGs would otherwise
        # reach a 3-channel Normalize (like the one the inference code builds)
        # with the wrong channel count and fail.
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img.convert("RGB")
        img = self.transform(rgb_img)

        with Image.open(mask_path) as pil_mask:
            gray_mask = pil_mask.convert("L")
        mask_img = self.mask_transform(gray_mask)
        mask_img = (mask_img > 0.5).float()

        masked_img, i = self.apply_center_mask(img)

        return img, masked_img, mask_img, i, mask_path  # mask path names the output file

    def __len__(self):
        return len(self.files)


def generate_inpainted_images(
    generator_path,
    img_folder,
    mask_folder,
    output_folder,
    img_size=128,
    mask_size=64,
    batch_size=8
):
    """Inpaint the centre square of every ``*.png`` in ``img_folder`` and save the results.

    Each image is resized to ``img_size`` x ``img_size`` (bicubic) and
    normalised to ``[-1, 1]``. ``InferenceDataset`` blanks its centred
    ``mask_size`` square. The generator then receives the blanked image
    concatenated with the same-named mask from ``mask_folder`` as a fourth
    channel. Only the centre square of the generator output is pasted back;
    all other pixels come from the resized input. Results are mapped back to
    ``[0, 1]`` and written to ``output_folder`` under the mask's file name.

    Args:
        generator_path: Path to a state_dict for
            ``UNetGenerator(in_channels=4, out_channels=3)``.
        img_folder: Folder whose top-level ``*.png`` files are processed.
        mask_folder: Folder holding one mask per image, matched by file name.
        output_folder: Destination folder; created if missing.
        img_size: Side length images and masks are resized to. Must be a
            positive multiple of 16, because the generator halves the size
            four times and concatenates skip features of matching size.
        mask_size: Side length of the centred square that is regenerated;
            must satisfy ``0 <= mask_size <= img_size``.
        batch_size: Number of images per forward pass.

    Raises:
        ValueError: If ``img_size`` or ``mask_size`` is out of range.
    """
    # Fail fast with a clear message instead of a shape error in torch.cat
    # after the model has already been loaded.
    if img_size < 16 or img_size % 16 != 0:
        raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
    if not 0 <= mask_size <= img_size:
        raise ValueError(
            f"mask_size must be between 0 and img_size ({img_size}), got {mask_size}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_folder, exist_ok=True)

    generator = UNetGenerator(in_channels=4, out_channels=3).to(device)
    # NOTE(review): torch.load unpickles the checkpoint, so only load files
    # from a trusted source. For a plain state_dict, weights_only=True is the
    # safer option on torch versions that support it.
    generator.load_state_dict(torch.load(generator_path, map_location=device))
    generator.eval()

    transforms_ = [
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]

    dataset = InferenceDataset(
        img_root=img_folder,
        mask_root=mask_folder,
        transforms_=transforms_,
        img_size=img_size,
        mask_size=mask_size,
    )
    if len(dataset) == 0:
        print(f"No .png images found in {img_folder}; nothing to generate.")
        return
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    with torch.no_grad():
        # The unmasked image (first tuple element) is not needed for inference.
        for _, masked_imgs, mask_imgs, center_i, mask_paths in dataloader:
            masked_imgs = masked_imgs.to(device)
            mask_imgs = mask_imgs.to(device)

            gen_output = generator(torch.cat((masked_imgs, mask_imgs), dim=1))

            filled_imgs = masked_imgs.clone()
            for b, offset in enumerate(center_i.tolist()):
                window = slice(offset, offset + mask_size)
                # Paste only the regenerated centre; keep the real context.
                filled_imgs[b, :, window, window] = gen_output[b, :, window, window]

                out_path = os.path.join(output_folder, os.path.basename(mask_paths[b]))
                save_image(filled_imgs[b] * 0.5 + 0.5, out_path)  # [-1, 1] -> [0, 1]
                print(f"Saved: {out_path}")


# Paths for inference
generator_path = "/home/MICCAI25/GAN_CMC/generator_context_encoder_GAN_stain_augmentation.pth"
img_folder = "/home/MICCAI25/CMC_context_images"
mask_folder = "/home/MICCAI25/Ami-Br_Train_Augmented_masks"
output_folder = "/home/MICCAI25/GAN_CMC_Synthetic_Images/Image_CMC_Mask_AMI-Br"

img_size = 128
mask_size = 64
batch_size = 8

# Run only when executed as a script or notebook cell (both have
# __name__ == "__main__"). Importing this module then no longer starts a full
# inference run as a side effect. That includes the main-module re-import done
# by DataLoader workers started with the "spawn" method.
if __name__ == "__main__":
    generate_inpainted_images(
        generator_path=generator_path,
        img_folder=img_folder,
        mask_folder=mask_folder,
        output_folder=output_folder,
        img_size=img_size,
        mask_size=mask_size,
        batch_size=batch_size
    )
