In [None]:
import os
import argparse
import glob
import random
import numpy as np
from PIL import Image
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

import albumentations as A
from albumentations.pytorch import ToTensorV2
from tiatoolbox.tools.stainaugment import StainAugmentor
import wandb
import matplotlib.pyplot as plt
import random
from torchvision.utils import make_grid

seed = 42

# Seed the PyTorch (CPU, then every visible GPU), NumPy and Python RNGs so
# that runs are repeatable. The generators are independent, so order is free.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Open the wandb run and create the folder that receives sample grids.
wandb.init(project="context_encoder_GAN_stain_augmentation_Ami-Br_new")
os.makedirs("context_encoder_GAN_stain_augmentation_AMI-Br", exist_ok=True)

class UNetGenerator(nn.Module):
    """Four-level U-Net mapping a masked image plus extra channels to an image.

    Each encoder stage halves H and W with a stride-2 4x4 convolution. Each
    decoder stage doubles them with a stride-2 4x4 transposed convolution and
    then concatenates the encoder feature map of the same resolution (skip
    connection). A final 3x3 convolution followed by tanh yields
    ``out_channels`` maps in [-1, 1]. Input height/width should be divisible
    by 16 so that the skip tensors line up for concatenation.

    Args:
        in_channels: number of input channels (default 4).
        out_channels: number of output channels (default 3).
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()

        def encoder_stage(c_in, c_out, normalize=True):
            # Stride-2 4x4 conv: halves the spatial size.
            modules = [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1)]
            if normalize:
                # NOTE(review): BatchNorm2d's second positional parameter is
                # ``eps``, so 0.8 is an epsilon here (possibly meant as
                # momentum). Kept unchanged to preserve trained behavior.
                modules.append(nn.BatchNorm2d(c_out, eps=0.8))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*modules)

        def decoder_stage(c_in, c_out, dropout=False):
            # Stride-2 4x4 transposed conv: doubles the spatial size.
            modules = [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out, eps=0.8),
                nn.ReLU(inplace=True),
            ]
            if dropout:
                modules.append(nn.Dropout(0.5))
            return nn.Sequential(*modules)

        # Encoder: in_channels -> 64 -> 128 -> 256 -> 512, H -> H / 16.
        self.d1 = encoder_stage(in_channels, 64, normalize=False)
        self.d2 = encoder_stage(64, 128)
        self.d3 = encoder_stage(128, 256)
        self.d4 = encoder_stage(256, 512)

        # Decoder input widths include the concatenated skip features.
        self.u1 = decoder_stage(512, 256)        # 256 out, + d3 (256) -> 512
        self.u2 = decoder_stage(256 + 256, 128)  # 128 out, + d2 (128) -> 256
        self.u3 = decoder_stage(128 + 128, 64)   # 64 out,  + d1 (64)  -> 128
        self.u4 = decoder_stage(64 + 64, 64)

        self.final = nn.Sequential(
            nn.Conv2d(64, out_channels, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        skips = []
        h = x
        for stage in (self.d1, self.d2, self.d3):
            h = stage(h)
            skips.append(h)
        h = self.d4(h)

        # Deepest skip first: u1 pairs with d3, u2 with d2, u3 with d1.
        for stage, skip in zip((self.u1, self.u2, self.u3), reversed(skips)):
            h = torch.cat([stage(h), skip], dim=1)
        return self.final(self.u4(h))



class Discriminator(nn.Module):
    """Convolutional real/fake discriminator over an image plus conditioning channels.

    Five stride-2 3x3 conv stages (64-128-256-512-512 filters; the last two
    dilated by 2 and 4) are followed by a 4x4 average pool and a 3x3 conv down
    to one channel. The output is a raw logit map with no final activation;
    for 128x128 inputs it has shape (N, 1, 1, 1).

    Args:
        channels: number of input channels (default 4).
    """

    def __init__(self, channels=4):
        super().__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize, dilation=1):
            # ``dilation`` used to be accepted but never forwarded to Conv2d, so
            # the dilated stages silently ran undilated. padding == dilation
            # keeps the output size of a 3x3 kernel equal to the padding=1,
            # undilated case: floor((H - 1) / stride) + 1.
            layers = [
                nn.Conv2d(
                    in_filters,
                    out_filters,
                    3,
                    stride,
                    padding=dilation,
                    dilation=dilation,
                )
            ]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        # (out_filters, stride, use InstanceNorm, dilation) per stage.
        stage_specs = [
            (64, 2, False, 1),
            (128, 2, True, 1),
            (256, 2, True, 1),
            (512, 2, True, 2),
            (512, 2, True, 4),
        ]
        for out_filters, stride, norm, dil in stage_specs:
            layers.extend(discriminator_block(in_filters, out_filters, stride, norm, dil))
            in_filters = out_filters

        # 128 px input -> 4x4 after five stride-2 stages; pool that to 1x1.
        layers.append(nn.AvgPool2d(4))
        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the discriminator logits for ``img`` (N, channels, H, W)."""
        return self.model(img)




class ImageDataset(Dataset):
    """Image/mask dataset that cuts a square hole out of every image for inpainting.

    Every ``*.png`` in ``root`` is paired with the file of the same name in
    ``mask_root``. Files are sorted; the last ``val_size`` form the validation
    split (``mode != "train"``) and the rest the training split.

    Training samples receive geometric + stain/colour augmentation and a hole
    at a random position. Validation samples are only resized and get a
    centred hole.

    Args:
        root: directory containing the RGB ``*.png`` images.
        mask_root: directory containing same-named grayscale masks.
        img_size: side length images and masks are resized to.
        mask_size: side length of the square hole.
        mode: ``"train"`` for the training split, anything else for validation.
        val_size: number of trailing files held out for validation (default 500).
    """

    def __init__(self, root, mask_root, img_size=128, mask_size=64, mode="train", val_size=500):
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode

        files = sorted(glob.glob(f"{root}/*.png"))
        # Explicit split index: ``files[:-0]`` would wrongly be empty for val_size == 0.
        split = max(len(files) - val_size, 0)
        self.files = files[:split] if mode == "train" else files[split:]

        self.mask_files = [os.path.join(mask_root, os.path.basename(f)) for f in self.files]

        # 2x3 stain matrix handed to tiatoolbox's StainAugmentor (presumably
        # the hematoxylin/eosin vectors of a reference slide -- confirm).
        stain_matrix = np.array([
            [0.91633014, -0.20408072, -0.34451435],
            [0.17669817, 0.92528011, 0.33561059]
        ])

        if mode == "train":
            self.augmentation = A.Compose([
                A.Resize(img_size, img_size),
                A.RandomRotate90(),
                # ``p`` by keyword: in albumentations releases whose first
                # positional parameter is ``always_apply``, ``HorizontalFlip(0.5)``
                # would flip every single sample.
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                StainAugmentor(method="macenko", stain_matrix=stain_matrix),
                A.RandomBrightnessContrast(p=0.2),
                A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ])
        else:
            self.augmentation = A.Compose([
                A.Resize(img_size, img_size),
                A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ToTensorV2(),
            ])

        # Legacy mask-only pipeline. __getitem__ augments image and mask jointly
        # through ``self.augmentation`` and does not use it; kept for callers.
        self.mask_transform = A.Compose([
            A.Resize(img_size, img_size, interpolation=1),
            ToTensorV2()
        ])

    def _cut_square(self, img, y1, x1):
        """Cut the ``mask_size`` square at (y1, x1) out of ``img`` (C, H, W).

        Returns ``(masked_img, masked_part, y1, x1)``: a copy of ``img`` whose
        square is set to 1, and a copy of the removed square.
        """
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2].clone()
        masked_img = img.clone()
        # 1 is the top of the Normalize(0.5, 0.5) range, i.e. a white hole.
        masked_img[:, y1:y2, x1:x2] = 1
        return masked_img, masked_part, y1, x1

    def apply_random_mask(self, img):
        """Cut a hole at a random position at least ``border`` px from every edge."""
        # Keep a margin so the hole never touches the image border.
        border = 16
        # randint's upper bound is exclusive: the +1 lets the hole reach exactly
        # ``border`` px from the bottom/right edge, mirroring the top/left margin.
        y1, x1 = np.random.randint(border, self.img_size - self.mask_size - border + 1, 2)
        return self._cut_square(img, y1, x1)

    def apply_center_mask(self, img):
        """Cut a hole centred in the image (deterministic, used for validation)."""
        offset = (self.img_size - self.mask_size) // 2
        return self._cut_square(img, offset, offset)

    def __getitem__(self, index):
        """Return ``(img, masked_img, masked_part, mask_img, y1, x1)`` for ``index``.

        ``img`` is the augmented, normalized (3, img_size, img_size) tensor;
        ``mask_img`` is the binarized (1, img_size, img_size) float mask;
        ``masked_part`` is the removed square and (y1, x1) its top-left corner.
        """
        with Image.open(self.files[index]) as im:
            img = np.array(im.convert("RGB"))
        with Image.open(self.mask_files[index]) as im:
            mask_img = np.array(im.convert("L"))

        # Image and mask go through the same Compose call so the geometric
        # augmentations (resize, rotation, flips) stay aligned between them.
        augmented = self.augmentation(image=img, mask=mask_img)
        img = augmented["image"]
        mask_img = (augmented["mask"] > 0.5).float()[None, :, :]

        cut = self.apply_random_mask if self.mode == "train" else self.apply_center_mask
        masked_img, masked_part, y1, x1 = cut(img)
        return img, masked_img, masked_part, mask_img, y1, x1

    def __len__(self):
        return len(self.files)

# Training hyper-parameters; logged to the wandb run config below.
opt = argparse.Namespace(
    n_epochs=500,
    batch_size=8,
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    n_cpu=4,
    img_size=128,
    mask_size=64,
    channels=3,
    mask_channels=1,
    sample_interval=100,
)

wandb.config.update(vars(opt))

# Must be *called*: the bare function object ``torch.cuda.is_available`` is
# always truthy, which forced .cuda() and torch.cuda.FloatTensor on CPU-only hosts.
cuda = torch.cuda.is_available()

# Both networks see the RGB image concatenated with the mask channel(s).
generator = UNetGenerator(in_channels=opt.channels + opt.mask_channels, out_channels=opt.channels)
discriminator = Discriminator(channels=opt.channels + opt.mask_channels)

if cuda:
    generator.cuda()
    discriminator.cuda()

# The discriminator outputs raw logits for a binary real/fake decision, so the
# adversarial term is BCE-with-logits (it replaced an earlier L2 loss).
adversarial_loss = nn.BCEWithLogitsLoss()
pixelwise_loss = nn.L1Loss()

if cuda:
    adversarial_loss.cuda()
    pixelwise_loss.cuda()


def weights_init_normal(m):
    """Draw conv and transposed-conv weights from N(0, 0.02) (DCGAN-style init).

    ``nn.ConvTranspose2d`` is not a subclass of ``nn.Conv2d``, so it is listed
    explicitly; otherwise the generator's decoder kept PyTorch's default init.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)


generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Train and validation splits are drawn from the same folders; ImageDataset
# separates them according to ``mode``.
data_dir = "/data/Mitosis_Detection/MICCAI25/AMI-Br_dataset/Train_All"

train_dataset = ImageDataset(
    root=f"{data_dir}/Images",
    mask_root=f"{data_dir}/Masks",
    img_size=opt.img_size,
    mask_size=opt.mask_size,
    mode="train"
)

val_dataset = ImageDataset(
    root=f"{data_dir}/Images",
    mask_root=f"{data_dir}/Masks",
    img_size=opt.img_size,
    mask_size=opt.mask_size,
    mode="val"
)

train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=opt.n_cpu)


optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Pass both models as one sequence: wandb.watch's second positional parameter
# is ``criterion``, so ``wandb.watch(generator, discriminator)`` never watched D.
wandb.watch([generator, discriminator], log="all")

def save_sample(batches_done):
    """Inpaint a random set of validation samples, save the grid as PNG and log it.

    Each grid cell stacks, top to bottom: the mask image (repeated to 3
    channels), the masked input, the masked input with the generated patch
    pasted into the hole, and the original image. The PNG is written to
    ``context_encoder_GAN_stain_augmentation_AMI-Br/<batches_done>_random.png``
    and logged to wandb under ``inferred_images``.

    The generator runs in eval mode under ``torch.no_grad()`` and its previous
    train/eval mode is restored afterwards, so calling this from the training
    loop neither feeds validation data into BatchNorm running statistics nor
    builds an unused autograd graph.

    Args:
        batches_done: global training step, used in the file name and caption.
    """
    from torch.utils.data.dataloader import default_collate

    # Load only one batch worth of validation samples (the old
    # ``list(val_loader)`` decoded the entire validation set on every call).
    dataset = val_loader.dataset
    n_pick = min(len(dataset), val_loader.batch_size)
    indices = random.sample(range(len(dataset)), n_pick)
    random_batch = default_collate([dataset[j] for j in indices])
    samples, masked_samples, masked_parts, mask_img, y1, x1 = random_batch

    samples = samples.type(Tensor)
    masked_samples = masked_samples.type(Tensor)
    mask_img = mask_img.type(Tensor)
    gen_input = torch.cat((masked_samples, mask_img), dim=1)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_output = generator(gen_input)
    finally:
        generator.train(was_training)

    # Paste only the generated hole back into the masked input.
    filled_samples = masked_samples.clone()
    for b in range(samples.size(0)):
        yy1 = y1[b].item()
        xx1 = x1[b].item()
        filled_samples[b, :, yy1:yy1 + opt.mask_size, xx1:xx1 + opt.mask_size] = \
            gen_output[b, :, yy1:yy1 + opt.mask_size, xx1:xx1 + opt.mask_size]

    mask_img_3c = mask_img.repeat(1, 3, 1, 1)

    # Concatenate along height so each grid cell is a vertical strip of 4 images.
    output_grid = torch.cat((mask_img_3c, masked_samples, filled_samples, samples), dim=-2)

    grid = make_grid(output_grid, nrow=4, normalize=True, padding=2, pad_value=1)

    np_grid = grid.permute(1, 2, 0).cpu().numpy()

    save_path = f"context_encoder_GAN_stain_augmentation_AMI-Br/{batches_done}_random.png"
    plt.figure(figsize=(12, 12))
    plt.imshow(np_grid)
    plt.axis("off")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

    wandb.log({"inferred_images": [wandb.Image(save_path, caption=f"Step {batches_done} (Random Sample)")]})

def compute_psnr(img1, img2):
    """Return the PSNR in dB between two tensors whose values lie in [-1, 1].

    Both tensors are first rescaled to [0, 1], so the peak value is 1. The MSE
    is a single mean over every element of the inputs (the whole batch, not
    per image). Identical inputs yield 100.0 instead of infinity.
    """
    a = (img1 + 1) / 2
    b = (img2 + 1) / 2
    mse = ((a - b) ** 2).mean()
    if mse.item() == 0:
        return 100.0
    log_ratio = torch.log10(1.0 / mse.sqrt()).item()
    return 20 * log_ratio

def compute_ssim(img1, img2):
    """Return the mean SSIM between two (N, C, H, W) tensors with values in [-1, 1].

    Both inputs are moved to NumPy and rescaled to [0, 1]. SSIM is computed per
    channel with ``data_range=1.0``, averaged over the channels of each image,
    and then averaged over the N images.
    """
    pred = (img1.detach().cpu().numpy() + 1) / 2
    target = (img2.detach().cpu().numpy() + 1) / 2
    n_imgs, n_ch, _, _ = pred.shape

    def image_score(n):
        # Channel-averaged SSIM of image ``n``.
        channel_total = sum(ssim(pred[n, c], target[n, c], data_range=1.0) for c in range(n_ch))
        return channel_total / n_ch

    return sum(image_score(n) for n in range(n_imgs)) / n_imgs




In [None]:
# Visual sanity check of one training batch. Columns: augmented image, image
# with the square hole, the cut-out patch, and the mask image.
imgs, masked_imgs, masked_parts, mask_img, y1, x1 = next(iter(train_loader))

# Undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5 maps [-1, 1] back to [0, 1].
s = torch.Tensor([0.5, 0.5, 0.5])
m = torch.Tensor([0.5, 0.5, 0.5])

# One row per sample in the batch; squeeze=False keeps ``ax`` 2-D for one row.
fig, ax = plt.subplots(imgs.size(0), 4, figsize=(10, 10), squeeze=False)
for k in range(imgs.size(0)):
    # Out-of-place ops so the loaded batch tensors are left unmodified.
    ax[k, 0].imshow((imgs[k].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 1].imshow((masked_imgs[k].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 2].imshow((masked_parts[k].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 3].imshow(mask_img[k, 0])
for col, title in enumerate(('imgs', 'masked_imgs', 'masked_parts', 'mask_img')):
    ax[0, col].set_title(title)


    

In [None]:
# Visual check of what the generator receives and what it must reproduce,
# for one training batch.
imgs, masked_imgs, masked_parts, mask_img, y1, x1 = next(iter(train_loader))

# Generator input: masked RGB image (channels 0-2) + mask image (channel 3).
gen_input = torch.cat((masked_imgs, mask_img), dim=1)

# Put every cut-out patch back into its hole. Because the dataset only
# overwrote the hole, this reconstructs ``imgs`` exactly.
real_filled = masked_imgs.clone()
for b in range(imgs.size(0)):
    yy1 = y1[b].item()
    xx1 = x1[b].item()
    real_filled[b, :, yy1:yy1 + opt.mask_size, xx1:xx1 + opt.mask_size] = masked_parts[b]

# Undo Normalize(mean=0.5, std=0.5) for display.
s = torch.Tensor([0.5, 0.5, 0.5])
m = torch.Tensor([0.5, 0.5, 0.5])

fig, ax = plt.subplots(imgs.size(0), 6, figsize=(10, 10), squeeze=False)
for k in range(imgs.size(0)):
    yy1 = y1[k].item()
    xx1 = x1[k].item()

    # Binary map of the square hole the generator has to fill.
    gen_mask = torch.zeros(size=imgs.shape[2:4])
    gen_mask[yy1:yy1 + opt.mask_size, xx1:xx1 + opt.mask_size] = 1

    # Out-of-place ops so the batch tensors are left unmodified.
    ax[k, 0].imshow((gen_input[k, 0:3].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 1].imshow(gen_input[k, 3])
    ax[k, 2].imshow(gen_mask)
    ax[k, 3].imshow((masked_parts[k].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 4].imshow((real_filled[k].permute(1, 2, 0) * s + m).clamp(0, 1))
    ax[k, 5].imshow((imgs[k].permute(1, 2, 0) * s + m).clamp(0, 1))

column_titles = (
    'gen_input [0:3]',
    'gen_input [3]',
    'mask for inpainting ',
    'GT for inpainting ',
    'real filled ',
    'ref img ',
)
for col, title in enumerate(column_titles):
    ax[0, col].set_title(title)



    

In [None]:
def _crop_patches(images, ys, xs, size):
    """Stack images[b, :, ys[b]:ys[b]+size, xs[b]:xs[b]+size] into (N, C, size, size)."""
    return torch.stack([
        images[b, :, y:y + size, x:x + size]
        for b, (y, x) in enumerate(zip(ys.tolist(), xs.tolist()))
    ])


def _paste_patches(base, source, ys, xs, size):
    """Return a copy of ``base`` whose per-sample square at (ys[b], xs[b]) comes from ``source``."""
    out = base.clone()
    for b, (y, x) in enumerate(zip(ys.tolist(), xs.tolist())):
        out[b, :, y:y + size, x:x + size] = source[b, :, y:y + size, x:x + size]
    return out


epoch_d_losses = []
epoch_g_adv_losses = []
epoch_g_pixel_losses = []
epoch_psnrs = []
epoch_ssims = []

for epoch in range(opt.n_epochs):
    generator.train()
    discriminator.train()

    epoch_d_loss_sum = 0.0
    epoch_g_adv_sum = 0.0
    epoch_g_pixel_sum = 0.0
    n_batches = 0

    for i, batch in enumerate(train_loader):
        imgs, masked_imgs, masked_parts, mask_img, y1, x1 = batch
        imgs = imgs.type(Tensor)
        masked_imgs = masked_imgs.type(Tensor)
        masked_parts = masked_parts.type(Tensor)
        mask_img = mask_img.type(Tensor)

        # Targets shaped like the discriminator's (N, 1, 1, 1) logit map.
        valid = Tensor(imgs.shape[0], 1, 1, 1).fill_(1.0)
        fake = Tensor(imgs.shape[0], 1, 1, 1).fill_(0.0)

        # ---------------- Generator step ----------------
        optimizer_G.zero_grad()
        gen_input = torch.cat((masked_imgs, mask_img), dim=1)
        gen_output = generator(gen_input)

        # Reconstruction loss on the hole only. All patches share one size, so
        # a single L1 over the stacked patches equals the mean of the
        # per-sample L1 losses.
        gen_parts = _crop_patches(gen_output, y1, x1, opt.mask_size)
        g_pixel_loss = pixelwise_loss(gen_parts, masked_parts)

        # Generated images are labelled "valid": G is trained to fool D.
        # NOTE(review): D scores the raw generator output, including pixels
        # outside the hole that are discarded at inference; scoring the
        # composite (context + generated hole) may be intended -- confirm.
        g_adv_loss = adversarial_loss(discriminator(torch.cat((gen_output, mask_img), dim=1)), valid)

        g_loss = 0.1 * g_adv_loss + 1.0 * g_pixel_loss
        g_loss.backward()
        optimizer_G.step()

        # ---------------- Discriminator step ----------------
        # zero_grad also discards the gradients g_loss.backward() left in D.
        optimizer_D.zero_grad()

        # Real samples are the original images. Re-inserting masked_parts into
        # masked_imgs would reproduce ``imgs`` exactly, so no separate
        # "real_filled" tensor is needed.
        real_loss = adversarial_loss(discriminator(torch.cat((imgs, mask_img), dim=1)), valid)
        fake_loss = adversarial_loss(discriminator(torch.cat((gen_output.detach(), mask_img), dim=1)), fake)

        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

        print(
            f"[Epoch {epoch}/{opt.n_epochs}] [Batch {i}/{len(train_loader)}] "
            f"[D loss: {d_loss.item():.6f}] [G adv: {g_adv_loss.item():.6f}, pixel: {g_pixel_loss.item():.6f}]"
        )

        wandb.log({
            "D_loss": d_loss.item(),
            "G_adv": g_adv_loss.item(),
            "G_pixel": g_pixel_loss.item(),
            "epoch": epoch
        })

        batches_done = epoch * len(train_loader) + i
        if batches_done % opt.sample_interval == 0:
            save_sample(batches_done)

        epoch_d_loss_sum += d_loss.item()
        epoch_g_adv_sum += g_adv_loss.item()
        epoch_g_pixel_sum += g_pixel_loss.item()
        n_batches += 1

    # ---------------- Validation ----------------
    # PSNR / SSIM of the centre-hole inpainting, averaged over the whole
    # validation set (weighted by batch size) rather than only its first batch.
    # Outside the hole val_filled equals val_imgs, so these whole-image scores
    # are higher than hole-only scores would be.
    generator.eval()
    psnr_sum = 0.0
    ssim_sum = 0.0
    n_val = 0
    with torch.no_grad():
        for val_imgs, val_masked_imgs, _, val_mask_img, val_y1, val_x1 in val_loader:
            val_imgs = val_imgs.type(Tensor)
            val_masked_imgs = val_masked_imgs.type(Tensor)
            val_mask_img = val_mask_img.type(Tensor)

            gen_output = generator(torch.cat((val_masked_imgs, val_mask_img), dim=1))
            val_filled = _paste_patches(val_masked_imgs, gen_output, val_y1, val_x1, opt.mask_size)

            batch_n = val_imgs.size(0)
            psnr_sum += compute_psnr(val_filled, val_imgs) * batch_n
            ssim_sum += compute_ssim(val_filled, val_imgs) * batch_n
            n_val += batch_n

    current_psnr = psnr_sum / n_val if n_val else float("nan")
    current_ssim = ssim_sum / n_val if n_val else float("nan")

    wandb.log({
        "PSNR": current_psnr,
        "SSIM": current_ssim,
        "epoch": epoch
    })

    epoch_d_losses.append(epoch_d_loss_sum / n_batches)
    epoch_g_adv_losses.append(epoch_g_adv_sum / n_batches)
    epoch_g_pixel_losses.append(epoch_g_pixel_sum / n_batches)
    epoch_psnrs.append(current_psnr)
    epoch_ssims.append(current_ssim)

# Create the checkpoint folder first so saving cannot fail after training.
checkpoint_dir = "/home/MICCAI25/GAN_AMI-Br"
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(generator.state_dict(), os.path.join(checkpoint_dir, "generator_context_encoder_GAN_stain_augmentation.pth"))
torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, "discriminator_context_encoder_GAN_stain_augmentation.pth"))

avg_d_loss = np.mean(epoch_d_losses)
avg_g_adv_loss = np.mean(epoch_g_adv_losses)
avg_g_pixel_loss = np.mean(epoch_g_pixel_losses)
avg_psnr = np.mean(epoch_psnrs)
avg_ssim = np.mean(epoch_ssims)

print("Training Complete!")
print(f"Average D Loss: {avg_d_loss:.6f}")
print(f"Average G Adv Loss: {avg_g_adv_loss:.6f}")
print(f"Average G Pixel Loss: {avg_g_pixel_loss:.6f}")
print(f"Average PSNR: {avg_psnr:.4f}")
print(f"Average SSIM: {avg_ssim:.4f}")

wandb.finish()