In [None]:
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchvision.utils import make_grid
from torchvision.models import vgg16
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim_metric
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
import matplotlib.pyplot as plt

In [None]:
##########################################################
# TODO: Implement the T1T2Dataset class:
# 1. Initialize the dataset by loading the T1 and T2 image paths.
# 2. Implement the __len__ method to return the number of samples.
# 3. Implement the __getitem__ method to load and transform images.
##########################################################

class T1T2Dataset(Dataset):
    """Paired T1 -> T2 MRI slice dataset.

    T1 and T2 files are located with glob patterns, sorted independently and
    paired by position, so the i-th sorted T1 file must belong with the i-th
    sorted T2 file (e.g. ``case01_T1.png`` / ``case01_T2.png``).

    NOTE(review): the on-disk layout of ``data_dir`` is not visible here; the
    default patterns assume file names containing ``T1`` / ``T2``. Pass
    ``t1_pattern`` / ``t2_pattern`` (e.g. ``"T1/*.png"``) if the layout differs.

    Args:
        data_dir: Root directory of the slice images.
        transform: Optional callable applied to each PIL image. It is invoked
            separately for the T1 and the T2 image, so it must be deterministic
            (no random augmentation) or the two halves of a pair would be
            transformed differently.
        t1_pattern: Glob pattern, relative to ``data_dir``, selecting T1 files.
        t2_pattern: Glob pattern, relative to ``data_dir``, selecting T2 files.

    Raises:
        FileNotFoundError: If no T1 file matches ``t1_pattern``.
        ValueError: If the numbers of T1 and T2 files differ.
    """

    def __init__(self, data_dir, transform=None, t1_pattern="**/*T1*", t2_pattern="**/*T2*"):
        # Step 1: Load and sort image paths
        self.t1_images = self._find_files(data_dir, t1_pattern)
        self.t2_images = self._find_files(data_dir, t2_pattern)
        if not self.t1_images:
            raise FileNotFoundError(
                f"No T1 images matching {t1_pattern!r} under {data_dir!r}"
            )
        if len(self.t1_images) != len(self.t2_images):
            raise ValueError(
                f"Found {len(self.t1_images)} T1 but {len(self.t2_images)} T2 images "
                f"under {data_dir!r}; every slice needs exactly one partner"
            )
        self.transform = transform

    @staticmethod
    def _find_files(data_dir, pattern):
        """Return the sorted paths of regular files under `data_dir` matching `pattern`."""
        matches = glob.glob(os.path.join(data_dir, pattern), recursive=True)
        # Directories can match name patterns too (e.g. a folder called "T1").
        return sorted(path for path in matches if os.path.isfile(path))

    @staticmethod
    def _load_grayscale(path):
        """Open an image file and return it as a single-channel ("L") PIL image."""
        # convert() returns a new, fully loaded image, so the file handle can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(path) as img:
            return img.convert("L")

    def __len__(self):
        """Number of T1/T2 pairs."""
        # Step 2: Return the number of samples
        return len(self.t1_images)

    def __getitem__(self, idx):
        """Return the ``(t1_image, t2_image)`` pair at `idx`, transformed if a transform is set."""
        # Step 3: Load images and apply transformations.
        # Grayscale keeps a single channel, matching the 1-channel generator
        # input and the single-value Normalize(mean=[0.5], std=[0.5]).
        t1_image = self._load_grayscale(self.t1_images[idx])
        t2_image = self._load_grayscale(self.t2_images[idx])

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)

        return t1_image, t2_image

#############################################################
# END OF YOUR CODE
##########################################################

In [None]:
# Deterministic preprocessing shared by the T1 and T2 images: resize to the
# 256x256 resolution the networks are built for, convert to a tensor, then
# rescale with mean 0.5 / std 0.5 (later cells undo this with `x * 0.5 + 0.5`).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Paired T1/T2 slices read from the slice directory.
dataset = T1T2Dataset(data_dir='../img/pair_slices/', transform=transform)

In [None]:
# Fixed seed so the train/val/test partition is identical across kernel
# restarts; otherwise a reloaded checkpoint could be tested on slices it was
# trained on.
SPLIT_SEED = 42
BATCH_SIZE = 16
NUM_WORKERS = 4

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
# Early stopping monitors val metrics, so an empty validation split would break training.
assert val_size > 0 and test_size > 0, (
    f"dataset too small to split ({len(dataset)} pairs); need at least 10"
)

# NOTE(review): this splits at the slice level. If several slices come from the
# same subject, they can land in different splits; split by subject if IDs exist.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
##########################################################
# TODO: Implement the UNetGenerator class:
# 1. Define the encoder (downsampling) layers.
# 2. Define the decoder (upsampling) layers with skip connections.
# 3. Implement the forward method to connect the layers.
##########################################################

class UNetGenerator(nn.Module):
    """pix2pix-style U-Net generator with eight down- and eight up-sampling stages.

    Every encoder block halves H and W (kernel 4, stride 2, padding 1), so a
    256x256 input reaches a 1x1 bottleneck; spatial sizes must therefore be
    multiples of 256 for the skip connections to line up. Decoder outputs are
    concatenated with the mirrored encoder feature maps (skip connections),
    and the final Tanh maps outputs to [-1, 1].

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the generated image.
        num_filters: Width of the first encoder block; deeper blocks use
            multiples of it (capped at 8x).
    """

    def __init__(self, input_channels=1, output_channels=1, num_filters=64):
        super(UNetGenerator, self).__init__()
        nf = num_filters

        # Step 1: Encoder (spatial size for a 256x256 input in the comments).
        self.down1 = self.conv_block(input_channels, nf, normalize=False)  # 128
        self.down2 = self.conv_block(nf, nf * 2)                           # 64
        self.down3 = self.conv_block(nf * 2, nf * 4)                       # 32
        self.down4 = self.conv_block(nf * 4, nf * 8)                       # 16
        self.down5 = self.conv_block(nf * 8, nf * 8)                       # 8
        self.down6 = self.conv_block(nf * 8, nf * 8)                       # 4
        self.down7 = self.conv_block(nf * 8, nf * 8)                       # 2
        # No BatchNorm at the 1x1 bottleneck: with a batch of one, BatchNorm in
        # train mode would see a single value per channel and raise.
        self.down8 = self.conv_block(nf * 8, nf * 8, normalize=False)      # 1

        # Step 2: Decoder. After up1, each block's input is the previous output
        # concatenated with the matching encoder map, hence the doubled channels.
        self.up1 = self.deconv_block(nf * 8, nf * 8, dropout=0.5)   # 2
        self.up2 = self.deconv_block(nf * 16, nf * 8, dropout=0.5)  # 4
        self.up3 = self.deconv_block(nf * 16, nf * 8, dropout=0.5)  # 8
        self.up4 = self.deconv_block(nf * 16, nf * 8)               # 16
        self.up5 = self.deconv_block(nf * 16, nf * 4)               # 32
        self.up6 = self.deconv_block(nf * 8, nf * 2)                # 64
        self.up7 = self.deconv_block(nf * 4, nf)                    # 128
        # Output layer: no BatchNorm/ReLU; Tanh keeps outputs in [-1, 1], the
        # range of the Normalize(0.5, 0.5) targets.
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(nf * 2, output_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )                                                           # 256

    def conv_block(self, in_channels, out_channels, normalize=True):
        """Stride-2 conv (halves H, W) + optional BatchNorm + LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def deconv_block(self, in_channels, out_channels, dropout=0.0):
        """Stride-2 transposed conv (doubles H, W) + BatchNorm + ReLU + optional Dropout."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map `x` of shape (N, input_channels, H, W) to (N, output_channels, H, W)."""
        # Step 3: Forward pass with skip connections
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = torch.cat([self.up1(d8), d7], dim=1)
        u2 = torch.cat([self.up2(u1), d6], dim=1)
        u3 = torch.cat([self.up3(u2), d5], dim=1)
        u4 = torch.cat([self.up4(u3), d4], dim=1)
        u5 = torch.cat([self.up5(u4), d3], dim=1)
        u6 = torch.cat([self.up6(u5), d2], dim=1)
        u7 = torch.cat([self.up7(u6), d1], dim=1)
        u8 = self.up8(u7)

        return u8

#############################################################
# END OF YOUR CODE
##########################################################

In [None]:
##########################################################
# TODO: Implement the Discriminator class:
# 1. Define the convolutional layers of the discriminator.
# 2. Implement the forward method.
##########################################################

class Discriminator(nn.Module):
    """PatchGAN discriminator in the pix2pix style.

    Scores overlapping patches of the (condition, image) pair instead of the
    whole image. For a 256x256 input the output is (N, 1, 30, 30), which is
    the patch-grid size the training loop uses. The final Sigmoid yields
    probabilities, as required by the model's ``nn.BCELoss``.

    Args:
        input_channels: Channels of the concatenated (input, target) pair.
        num_filters: Width of the first conv layer.
    """

    def __init__(self, input_channels=2, num_filters=64):
        super(Discriminator, self).__init__()
        nf = num_filters
        # Step 1: Convolutional layers (spatial size for a 256x256 input).
        self.model = nn.Sequential(
            *self._disc_block(input_channels, nf, stride=2, normalize=False),  # 128
            *self._disc_block(nf, nf * 2, stride=2),                          # 64
            *self._disc_block(nf * 2, nf * 4, stride=2),                      # 32
            *self._disc_block(nf * 4, nf * 8, stride=1),                      # 31
            nn.Conv2d(nf * 8, 1, kernel_size=4, stride=1, padding=1),         # 30
            nn.Sigmoid(),
        )

    @staticmethod
    def _disc_block(in_channels, out_channels, stride, normalize=True):
        """Return the layers of one 4x4 conv + optional BatchNorm + LeakyReLU(0.2) stage."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    def forward(self, x):
        """Return per-patch real/fake probabilities for the concatenated pair `x`."""
        # Step 2: Forward pass
        return self.model(x)

#############################################################
# END OF YOUR CODE
##########################################################

In [None]:
##########################################################
# TODO: Implement the Pix2PixModel class:
# 1. Implement the training_step for both generator and discriminator.
# 2. Implement the validation_step to compute evaluation metrics.
##########################################################
class Pix2PixModel(pl.LightningModule):
    """Conditional GAN (pix2pix) translating T1 slices into T2 slices.

    The generator is trained with an adversarial loss, an L1 loss (weight 100)
    and a VGG16 perceptual loss (weight 10). The discriminator sees the T1
    input concatenated with either the real or the generated T2 image. Both
    networks are updated in ``training_step`` with manual optimization.
    Images are expected in [-1, 1], the range of Normalize(mean=0.5, std=0.5).

    Args:
        lr: Adam learning rate for both optimizers.
        beta1: Adam beta1 for both optimizers.
    """

    def __init__(self, lr=2e-4, beta1=0.5):
        super(Pix2PixModel, self).__init__()
        self.generator = UNetGenerator()
        self.discriminator = Discriminator()
        self.lr = lr
        self.beta1 = beta1
        # BCELoss expects probabilities in [0, 1]; the discriminator is assumed
        # to end in a Sigmoid.
        self.loss_gan = nn.BCELoss()
        self.loss_l1 = nn.L1Loss()
        # Frozen pre-trained VGG16 (first 16 feature modules) for the perceptual loss.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of `weights=...`.
        self.vgg = vgg16(pretrained=True).features[:16].eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        # ImageNet statistics for VGG inputs. Non-persistent buffers follow the
        # module across devices without being stored in checkpoints.
        self.register_buffer(
            "vgg_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "vgg_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )
        self.save_hyperparameters()

        # Enable manual optimization: both optimizers are stepped in training_step.
        self.automatic_optimization = False

    def perceptual_loss(self, gen, target):
        """L1 distance between VGG16 feature maps of `gen` and `target`.

        Both inputs are [batch_size, 1, H, W] images in [-1, 1]. ImageNet
        statistics assume [0, 1] inputs, so the images are first mapped to
        [0, 1], then replicated to three channels and normalized.
        """
        def to_vgg_input(images):
            rgb = (images * 0.5 + 0.5).repeat(1, 3, 1, 1)
            return (rgb - self.vgg_mean) / self.vgg_std

        gen_features = self.vgg(to_vgg_input(gen))
        target_features = self.vgg(to_vgg_input(target))
        return F.l1_loss(gen_features, target_features)

    def forward(self, x):
        """Translate T1 images `x` into T2 images with the generator."""
        return self.generator(x)

    def configure_optimizers(self):
        """Return the generator and discriminator Adam optimizers, in that order."""
        optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        return [optimizer_g, optimizer_d]

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update."""
        real_A, real_B = batch  # real_A is T1 (condition), real_B is T2 (target)
        opt_g, opt_d = self.optimizers()

        # -----------------
        #  Train Generator
        # -----------------
        fake_B = self.generator(real_A)
        pred_fake = self.discriminator(torch.cat([real_A, fake_B], dim=1))
        # Targets are shaped like the prediction, so they always match the
        # discriminator's patch grid regardless of input resolution.
        loss_gan = self.loss_gan(pred_fake, torch.ones_like(pred_fake))
        loss_l1 = self.loss_l1(fake_B, real_B)
        loss_perceptual = self.perceptual_loss(fake_B, real_B)
        loss_G = loss_gan + 100 * loss_l1 + 10 * loss_perceptual

        opt_g.zero_grad()
        self.manual_backward(loss_G)
        opt_g.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        pred_real = self.discriminator(torch.cat([real_A, real_B], dim=1))
        loss_real = self.loss_gan(pred_real, torch.ones_like(pred_real))

        # Detach so the discriminator loss does not reach the generator.
        fake_B_detached = fake_B.detach()
        pred_fake = self.discriminator(torch.cat([real_A, fake_B_detached], dim=1))
        loss_fake = self.loss_gan(pred_fake, torch.zeros_like(pred_fake))

        loss_D = (loss_real + loss_fake) * 0.5

        # loss_G was backpropagated through the discriminator and left gradients
        # in its parameters; clear them so D is updated from loss_D only.
        opt_d.zero_grad()
        self.manual_backward(loss_D)
        opt_d.step()

        self.log('loss_G', loss_G, prog_bar=True, on_step=True, on_epoch=True)
        self.log('loss_D', loss_D, prog_bar=True, on_step=True, on_epoch=True)

    def _evaluate_batch(self, batch, batch_idx, stage):
        """Shared body of validation_step and test_step.

        Generates T2 from T1 and logs the batch-mean PSNR and SSIM as
        ``{stage}_psnr`` / ``{stage}_ssim``. On the first batch it also logs an
        image grid (rows: input T1, generated T2, real T2) as ``{stage}_images``.
        """
        real_A, real_B = batch
        fake_B = self.generator(real_A)

        # [-1, 1] -> [0, 1]; clamping keeps values consistent with data_range=1.0.
        real_np = (real_B * 0.5 + 0.5).clamp(0, 1).detach().float().cpu().numpy()
        fake_np = (fake_B * 0.5 + 0.5).clamp(0, 1).detach().float().cpu().numpy()

        num_samples = real_np.shape[0]
        psnr_values = [
            psnr_metric(real_np[i, 0], fake_np[i, 0], data_range=1.0) for i in range(num_samples)
        ]
        ssim_values = [
            ssim_metric(real_np[i, 0], fake_np[i, 0], data_range=1.0) for i in range(num_samples)
        ]

        self.log(f'{stage}_psnr', float(sum(psnr_values) / num_samples), prog_bar=True)
        self.log(f'{stage}_ssim', float(sum(ssim_values) / num_samples), prog_bar=True)

        if batch_idx == 0:
            self._log_image_grid(f'{stage}_images', real_A, fake_B, real_B)

    def _log_image_grid(self, tag, real_A, fake_B, real_B, max_images=4):
        """Log up to `max_images` (T1, generated T2, real T2) triplets as one grid image."""
        experiment = getattr(self.logger, 'experiment', None)
        # add_image is the TensorBoard writer API; skip when no such logger is attached.
        if not hasattr(experiment, 'add_image'):
            return
        k = min(max_images, real_A.size(0))
        # nrow=k keeps one modality per row even for batches smaller than max_images.
        grid = make_grid(
            torch.cat([real_A[:k], fake_B[:k], real_B[:k]], dim=0).detach(),
            nrow=k, normalize=True
        )
        experiment.add_image(tag, grid, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Log val_psnr / val_ssim and, on the first batch, a sample grid."""
        self._evaluate_batch(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        """Log test_psnr / test_ssim and, on the first batch, a sample grid."""
        self._evaluate_batch(batch, batch_idx, 'test')
#############################################################
# END OF YOUR CODE
##########################################################

In [None]:
# Initialize Model and Trainer.
# Seed the Python/NumPy/torch RNGs so weight initialisation and shuffling are reproducible.
pl.seed_everything(42)
model = Pix2PixModel()

##########################################################
# TODO: Set up the training process:
# 1. Initialize the Trainer with appropriate parameters.
# 2. Start the training process.
##########################################################

# Stop when validation PSNR has not improved for 5 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_psnr',
    min_delta=0.00,
    patience=5,
    verbose=True,
    mode='max'
)
logger = TensorBoardLogger("lightning_logs", name="GAN")

# Upper bound on training length; EarlyStopping can end training sooner.
MAX_EPOCHS = 100

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
    # Pin to one device so Lightning does not pick a multi-process strategy inside the notebook.
    devices=1,
    logger=logger,
    callbacks=[early_stop_callback],
)

# Train the Model
trainer.fit(model, train_loader, val_loader)

In [None]:
# If you encounter "ModuleNotFoundError: Neither `tensorboard` nor `tensorboardX`
# is available. Try `pip install`ing either.", make sure both `tensorboard` and
# `tensorboardX` are installed as listed in requirements.txt, then check with:
#
#     from lightning_utilities.core.imports import RequirementCache
#     print(RequirementCache("tensorboard"))
#     print(RequirementCache("tensorboardx"))
#
# If both lines below are printed but the error persists, restart the kernel
# so the newly installed packages are picked up:
#
#     Requirement 'tensorboard' met
#     Requirement 'tensorboardx' met

In [None]:
# Persist the trained model as a Lightning checkpoint.
checkpoint_path = "GAN.pth"
print('model training is done.')
trainer.save_checkpoint(checkpoint_path)

In [None]:
# Evaluate the trained model on the held-out test split; the returned
# metrics are shown as the cell output.
test_metrics = trainer.test(model, test_loader)
test_metrics

In [None]:
def visualize_results(model, dataloader, num_images=5):
    """Plot input T1, generated T2 and ground-truth T2 side by side.

    Only the first batch of `dataloader` is used. Up to `num_images` samples
    from it are drawn, one figure per sample. `model` is switched to eval mode
    and left there.

    Args:
        model: Object exposing a `generator` module (e.g. Pix2PixModel).
        dataloader: Iterable of (T1, T2) batches normalized to [-1, 1].
        num_images: Maximum number of samples to plot.
    """
    model.eval()
    batch = next(iter(dataloader), None)
    if batch is None:  # empty loader: nothing to show
        return

    real_A, real_B = batch
    # Run the generator on whatever device its weights live on.
    first_param = next(model.generator.parameters(), None)
    device = first_param.device if first_param is not None else real_A.device
    with torch.no_grad():
        fake_B = model.generator(real_A.to(device)).cpu()

    # Denormalize [-1, 1] -> [0, 1] for display.
    panels = (
        ('Input T1', real_A * 0.5 + 0.5),
        ('Generated T2', fake_B * 0.5 + 0.5),
        ('Ground Truth T2', real_B * 0.5 + 0.5),
    )

    for i in range(min(num_images, real_A.size(0))):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (title, images) in zip(axs, panels):
            # Fixed vmin/vmax give all panels the same intensity scale; by
            # default imshow rescales each panel to its own min/max.
            ax.imshow(images[i, 0].numpy(), cmap='gray', vmin=0.0, vmax=1.0)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)

# Show input / prediction / ground truth for the first few test samples.
visualize_results(model, test_loader, num_images=5)