In [1]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch

# Directory and file setup
image_dir = "/kaggle/input/glaucoma-datasets/G1020/Images"
metadata_file = "/kaggle/input/glaucoma-datasets/G1020/G1020.csv"
SEED = 42

# pandas/sklearn get SEED via random_state below; torch is seeded separately because
# DataLoader shuffling, the random crops/flips and model weight init use torch's RNG.
torch.manual_seed(SEED)

# Load metadata and attach the absolute image path for every imageID
metadata = pd.read_csv(metadata_file)
metadata['path'] = metadata['imageID'].map(lambda image_id: os.path.join(image_dir, image_id))

# Separate data by class (1 = glaucoma, 0 = normal)
glaucoma_df = metadata[metadata['binaryLabels'] == 1]
normal_df = metadata[metadata['binaryLabels'] == 0]

# Undersample the majority class so both classes have the same number of images
min_samples = min(len(glaucoma_df), len(normal_df))
if min_samples == 0:
    raise ValueError("metadata must contain at least one image of each class")

# Randomly sample without replacement from each class
glaucoma_df = glaucoma_df.sample(n=min_samples, random_state=SEED)
normal_df = normal_df.sample(n=min_samples, random_state=SEED)

# Concatenate the balanced data
balanced_df = pd.concat([glaucoma_df, normal_df]).reset_index(drop=True)

# Stratified split keeps the 50/50 class balance in both train and test sets
train_df, test_df = train_test_split(
    balanced_df, test_size=0.2, random_state=SEED, stratify=balanced_df['binaryLabels']
)

# Image transformations
transformations = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # float tensor in [0, 1]
    # Rescale to [-1, 1] so real images share the range of the generator's Tanh output;
    # otherwise the discriminator can tell real from fake by pixel range alone and the
    # L1 identity loss compares tensors on different scales.
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Custom dataset class
class GlaucomaDataset(Dataset):
    """Map-style dataset over fundus images listed in a metadata DataFrame.

    Each row must provide an image file path in ``'path'`` and a class label in
    ``'binaryLabels'``. Rows are addressed by position (``iloc``), so the DataFrame
    index does not need to be contiguous (e.g. after filtering or sampling).
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe: DataFrame with ``'path'`` and ``'binaryLabels'`` columns.
            transform: optional callable applied to the RGB PIL image.
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        ``label`` is returned as a plain Python ``int``. ``image`` is the transformed
        RGB image, or the PIL image itself when no transform is set.
        """
        row = self.dataframe.iloc[idx]  # single positional lookup for both columns
        # Context manager releases the file handle as soon as the RGB copy exists,
        # instead of leaving it to garbage collection.
        with Image.open(row['path']) as img:
            image = img.convert("RGB")
        label = int(row['binaryLabels'])
        if self.transform:
            image = self.transform(image)
        return image, label

# Create datasets and dataloaders from the TRAINING split only, so test_df stays held out
# from GAN training.
BATCH_SIZE = 32

glaucoma_train_df = train_df[train_df['binaryLabels'] == 1]
normal_train_df = train_df[train_df['binaryLabels'] == 0]

glaucoma_dataset = GlaucomaDataset(glaucoma_train_df, transform=transformations)
normal_dataset = GlaucomaDataset(normal_train_df, transform=transformations)

# drop_last=True: the stratified split can leave the two classes one sample apart, so
# their final batches could differ in size. Dropping incomplete batches keeps every
# paired (glaucoma, normal) batch the same size and avoids tiny BatchNorm batches.
# NOTE(review): a class with fewer than BATCH_SIZE training images yields zero batches.
glaucoma_loader = DataLoader(glaucoma_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
normal_loader = DataLoader(normal_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [2]:
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

class Generator(nn.Module):
    """Encoder-decoder generator mapping a 3-channel image to a 3-channel image of the same size.

    The encoder applies four 4x4 stride-2 convolutions, halving H and W each time
    (total factor 16) while channels grow 3 -> 64 -> 128 -> 256 -> 512. The decoder
    undoes this with four (nearest x2 upsample -> 3x3 conv) stages. A final Tanh puts
    the output in [-1, 1].

    Input height and width must be multiples of 16. Otherwise the output size would
    not match the input size. This notebook's 224 x 224 crops give a 14 x 14
    bottleneck.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Encoder (layer order/indices unchanged so existing state_dicts still load)
        self.downsample = nn.Sequential(
            # input: 3 x H x W (3 x 224 x 224 with this notebook's crops)
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # 64 x H/2 x W/2 (112)
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128 x H/4 x W/4 (56)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 256 x H/8 x W/8 (28)
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # 512 x H/16 x W/16 (14)
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )
        # Decoder: nearest-neighbour x2 upsampling, each followed by a 3x3 conv
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),  # 512 x H/8 x W/8 (28)
            conv3x3(512, 256),  # 256 x H/8 x W/8
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 256 x H/4 x W/4 (56)
            conv3x3(256, 128),  # 128 x H/4 x W/4
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 128 x H/2 x W/2 (112)
            conv3x3(128, 64),  # 64 x H/2 x W/2
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 64 x H x W (224)
            conv3x3(64, 3),  # 3 x H x W
            nn.Tanh(),  # final output: 3 x H x W, values in [-1, 1]
        )

    def forward(self, x):
        """Translate a batch ``x`` of shape (N, 3, H, W) into a batch of the same shape.

        Raises:
            ValueError: if H or W is not a multiple of 16. The encoder/decoder would
                then return a different spatial size than it was given.
        """
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"Generator input height and width must be multiples of 16, got {height}x{width}"
            )
        x = self.downsample(x)
        x = self.upsample(x)
        return x

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style critic that scores each image in a batch with a value in (0, 1).

    Six 4x4 stride-2 convolutions shrink the spatial size 224 -> 112 -> 56 -> 28 -> 14
    -> 7 -> 3 while channels double from 64 up to 2048. Every convolution except the
    first is followed by BatchNorm, and all of them by LeakyReLU(0.2). An unpadded 3x3
    convolution then collapses the 3x3 map to 1x1, and a Sigmoid squashes the result.
    The head yields exactly one score per image only when the feature map is 3x3 at
    that point, which holds for 224 x 224 inputs.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Stem: conv + LeakyReLU, no BatchNorm.
        stages = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        # Downsampling body: conv -> BatchNorm -> LeakyReLU with doubling widths.
        widths = [64, 128, 256, 512, 1024, 2048]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Head: 3x3 map -> single score in (0, 1).
        stages.extend([nn.Conv2d(2048, 1, 3, 1, 0), nn.Sigmoid()])
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the scores flattened to a 1-D tensor (one entry per image for 224x224 input)."""
        scores = self.main(input)
        return scores.view(-1)

In [4]:
import matplotlib.pyplot as plt
import os
from IPython.display import display

def visualize_images(real_images, transformed_images, epoch, save_dir='/kaggle/working'):
    real_images = real_images.detach().cpu()
    transformed_images = transformed_images.detach().cpu()
    
    fig, axes = plt.subplots(nrows=2, ncols=8, figsize=(16, 4))
    for i in range(8):
        axes[0, i].imshow(real_images[i].permute(1, 2, 0))
        axes[0, i].axis('off')
        axes[1, i].imshow(transformed_images[i].permute(1, 2, 0))
        axes[1, i].axis('off')

    plt.show()  # Display the figure in the output cell
    plt.savefig(os.path.join(save_dir, f'epoch_{epoch+1}.png'))  # Save the figure to the file system
    plt.close(fig)  # Close the figure to free memory

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
lambda1_values = [1.0]  # weights of the L1 identity term to sweep; one training run per value


def _train_one_epoch(generator, discriminator, glaucoma_loader, normal_loader,
                     optimizerD, optimizerG, lambda1, device, mse_loss, l1_loss):
    """Run one pass over the paired loaders and return (mean_D_loss, mean_G_loss, last_batch).

    ``last_batch`` holds scalar diagnostics and up to 8 detached input/output images from
    the final batch, for logging and visualisation.

    Raises:
        ValueError: if zipping the two loaders yields no batches.
    """
    real_label, fake_label = 1.0, 0.0
    running_lossD = running_lossG = 0.0
    n_batches = 0

    for (glaucoma_images, _), (normal_images, _) in zip(glaucoma_loader, normal_loader):
        normal_images = normal_images.to(device)
        glaucoma_images = glaucoma_images.to(device)

        # Discriminator: normal images are "real", generator outputs from glaucoma images are "fake".
        # full_like sizes each target from its own output, so the two loaders may yield
        # batches of different sizes (e.g. uneven final batches).
        real_outputs = discriminator(normal_images)
        lossD_real = mse_loss(real_outputs, torch.full_like(real_outputs, real_label))

        fake_images = generator(glaucoma_images)
        fake_outputs = discriminator(fake_images.detach())
        lossD_fake = mse_loss(fake_outputs, torch.full_like(fake_outputs, fake_label))

        optimizerD.zero_grad()
        lossD = (lossD_real + lossD_fake) / 2
        lossD.backward()
        optimizerD.step()

        # Generator: fool the (just updated) discriminator while staying close to its input.
        optimizerG.zero_grad()
        output_gen = discriminator(fake_images)
        lossG = mse_loss(output_gen, torch.full_like(output_gen, real_label))
        perceptual_loss = l1_loss(fake_images, glaucoma_images)  # L1 identity term
        total_gen_loss = lossG + lambda1 * perceptual_loss
        total_gen_loss.backward()
        optimizerG.step()

        running_lossD += lossD.item()
        running_lossG += total_gen_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("glaucoma_loader / normal_loader yielded no paired batches")

    last_batch = {
        'perceptual': perceptual_loss.item(),
        'd_real': real_outputs.mean().item(),
        'd_fake': fake_outputs.mean().item(),
        'd_gen': output_gen.mean().item(),
        'inputs': glaucoma_images[:8].detach(),
        'fakes': fake_images[:8].detach(),
    }
    # Average over the batches actually seen (zip stops at the shorter loader).
    return running_lossD / n_batches, running_lossG / n_batches, last_batch


def _plot_loss_curves(loss_dict, lambda1_values, num_epochs, save_dir):
    """Plot per-epoch D and G losses for every lambda on one axis, save to save_dir, then show."""
    fig, ax = plt.subplots(figsize=(10, 7))
    epochs = range(1, num_epochs + 1)
    for lambda1 in lambda1_values:
        key = f'lambda_{lambda1:.1f}'
        ax.plot(epochs, loss_dict[key]['D'], label=f'D_loss_{key}')
        ax.plot(epochs, loss_dict[key]['G'], label=f'G_loss_{key}')
    ax.set(xlabel='Epochs', ylabel='Loss', title='Loss vs. Epochs for different lambda values')
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(save_dir, 'loss_vs_epochs.png'))
    plt.show()
    plt.close(fig)


def train_gan_for_lambda(generator, discriminator, glaucoma_loader, normal_loader, lambda1_values,
                         num_epochs=100, save_dir='/kaggle/working', plot=True):
    """Train the glaucoma -> normal GAN once per ``lambda1`` value.

    Per batch, the discriminator learns to score normal images 1 and generated images 0.
    Both terms use MSE (least-squares GAN targets). The generator minimises the MSE
    adversarial term plus ``lambda1 * L1(G(x), x)``.

    Every lambda run starts from the same initial weights, so the runs are comparable.
    Weights are saved per lambda as ``{generator,discriminator}_lambda_{lambda1:.1f}.pth``
    in ``save_dir``.

    Args:
        generator, discriminator: the two networks; moved to CUDA when available.
        glaucoma_loader, normal_loader: iterables of ``(images, labels)`` batches; iterated
            together with ``zip``, so an epoch stops at the shorter one.
        lambda1_values: iterable of L1 weights to sweep.
        num_epochs: epochs per lambda value.
        save_dir: output directory for weights and figures (created if missing).
        plot: when True, plot and save the loss curves at the end.

    Returns:
        dict mapping ``'lambda_{lambda1:.1f}'`` -> ``{'D': [...], 'G': [...]}`` per-epoch mean losses.
        NOTE(review): lambdas that round to the same 1-decimal key (e.g. 0.05 and 0.1)
        share a dict entry and overwrite each other's weight files.

    Raises:
        ValueError: if the loaders yield no paired batches.
    """
    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(save_dir, exist_ok=True)

    generator.to(device)
    discriminator.to(device)
    # Snapshot the starting weights (including BatchNorm buffers) so each lambda is
    # trained from the same initialisation instead of continuing the previous run.
    initial_G = {k: v.detach().clone() for k, v in generator.state_dict().items()}
    initial_D = {k: v.detach().clone() for k, v in discriminator.state_dict().items()}

    loss_dict = {f'lambda_{lambda1:.1f}': {'D': [], 'G': []} for lambda1 in lambda1_values}

    for lambda1 in lambda1_values:
        print(f"Training with lambda1 = {lambda1}")
        generator.load_state_dict(initial_G)
        discriminator.load_state_dict(initial_D)
        generator.train()
        discriminator.train()

        # Fresh optimizers and schedulers for each lambda value.
        # NOTE(review): StepLR(step_size=10, gamma=0.1) cuts the LR tenfold every 10 epochs
        # (2e-5 -> 2e-7 by epoch 20, 2e-14 for the last ten of 100), so late epochs barely
        # update the weights — confirm this schedule is intended.
        optimizerD = optim.Adam(discriminator.parameters(), lr=0.00002, betas=(0.5, 0.999))
        optimizerG = optim.Adam(generator.parameters(), lr=0.00002, betas=(0.5, 0.999))
        schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=10, gamma=0.1)
        schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=10, gamma=0.1)

        history = loss_dict[f'lambda_{lambda1:.1f}']
        for epoch in range(num_epochs):
            mean_D, mean_G, last = _train_one_epoch(
                generator, discriminator, glaucoma_loader, normal_loader,
                optimizerD, optimizerG, lambda1, device, mse_loss, l1_loss,
            )
            history['D'].append(mean_D)
            history['G'].append(mean_G)

            if (epoch + 1) % 10 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}: '
                      f'Loss_D: {mean_D:.4f}, Loss_G: {mean_G:.4f}, '
                      f'Perceptual Loss: {last["perceptual"]:.4f}, '
                      f'D(x): {last["d_real"]:.4f}, D(G(z)): {last["d_fake"]:.4f} / {last["d_gen"]:.4f}')
                visualize_images(last['inputs'], last['fakes'], epoch, save_dir)

            schedulerD.step()
            schedulerG.step()

        # Save model weights after training with the current lambda1 value
        torch.save(generator.state_dict(), os.path.join(save_dir, f'generator_lambda_{lambda1:.1f}.pth'))
        torch.save(discriminator.state_dict(), os.path.join(save_dir, f'discriminator_lambda_{lambda1:.1f}.pth'))

    if plot:
        _plot_loss_curves(loss_dict, lambda1_values, num_epochs, save_dir)
    return loss_dict
    
# Put both networks on one device (CUDA when available, otherwise CPU).
model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(model_device)
discriminator = Discriminator().to(model_device)

# Run the lambda sweep; the trailing semicolon stops the cell from echoing the return value.
train_gan_for_lambda(
    generator,
    discriminator,
    glaucoma_loader,
    normal_loader,
    lambda1_values,
    100,
    '/kaggle/working',
);

Training with lambda1 = 1.0
