## 🔧 Library Imports

We begin by importing the libraries needed to train a CycleGAN in PyTorch: modules for building networks (`torch.nn`), optimization (`torch.optim`), image preprocessing and dataset loading (`torchvision.transforms`, `torchvision.datasets`, `DataLoader`), and image handling via PIL. We also import `itertools`, used later to chain the parameters of both generators into one optimizer, and `Variable` from `torch.autograd`. Since PyTorch 0.4, `Variable` is a deprecated wrapper that simply returns a tensor, so plain tensors behave identically.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import os
import itertools
from torch.autograd import Variable

## 📁 Mount Google Drive

To access the dataset stored in Google Drive, we mount the drive into the Colab runtime. This allows for persistent access to image folders and model checkpoints.


In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

## 📁 Load and Preprocess the Datasets

In this step, we load two image datasets — one for **benign** and one for **malignant** cases — directly from Google Drive.  
We apply a basic transformation pipeline using `transforms.Compose` that:
- Converts each image into a PyTorch tensor
- Normalizes each channel with a mean of 0.5 and a standard deviation of 0.5, i.e. computes `(x - 0.5) / 0.5`, which maps pixel values from `[0, 1]` to `[-1, 1]`

Finally, we wrap the datasets into `DataLoader` objects for batch-wise access during training.
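
As a quick check of the arithmetic, the sketch below applies the same `(x - mean) / std` formula to a few sample values and then inverts it. It works on plain tensors instead of calling torchvision, and the helper names `normalize` and `denormalize` are illustrative, not part of any library.

```python
import torch

def normalize(x, mean=0.5, std=0.5):
    # Same per-channel formula that transforms.Normalize applies: (x - mean) / std
    return (x - mean) / std

def denormalize(x, mean=0.5, std=0.5):
    # Inverse mapping, useful before displaying an image with matplotlib
    return x * std + mean

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # values as produced by ToTensor, in [0, 1]
scaled = normalize(pixels)                    # now in [-1, 1]
restored = denormalize(scaled)                # back in [0, 1]
```

The inverse mapping matters whenever a normalized tensor is passed to `plt.imshow`, which expects floating-point values in `[0, 1]`.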


In [None]:
# Load Data
# Transformation pipeline: convert to a tensor, then scale each channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 1

# ImageFolder expects `root` to contain at least one subdirectory of images.
# The folder-derived labels are ignored later; the domain itself defines the class.
dataset_A = datasets.ImageFolder(root='/content/drive/MyDrive/my_data/train/benign', transform=transform)     # benign
dataset_B = datasets.ImageFolder(root='/content/drive/MyDrive/my_data/train/malignant', transform=transform)  # malignant

# One shuffled DataLoader per domain
loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

## 🖼️ Visualize Sample Images from Each Class

Before we begin training, it’s important to visualize a few images from both classes (benign and malignant) to:
- Confirm the data was loaded and preprocessed correctly
- Gain an intuitive understanding of the visual patterns our model will try to learn

We’ll display one image from each domain using `matplotlib`.


In [None]:
import matplotlib.pyplot as plt

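def imshow(img):
    # Undo the normalization: map [-1, 1] back to [0, 1] so colors display correctly
    img = img * 0.5 + 0.5
    img = img.numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    plt.imshow(img)
    plt.axis('off')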

dataiter_A = iter(loader_A)
images_A, _ = next(dataiter_A)

dataiter_B = iter(loader_B)
images_B, _ = next(dataiter_B)

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title('Sample from benign (domain A)')
imshow(images_A[0])

plt.subplot(1, 2, 2)
plt.title('Sample from malignant (domain B)')
imshow(images_B[0])
plt.show()

## ⚙️ Define the ResNet-based Generator

This step defines the **generator model** based on a ResNet architecture. The generator takes images from one domain (e.g., benign) and transforms them into the other domain (e.g., malignant).

The model includes:
- Initial convolution block with reflection padding
- Downsampling layers to compress spatial information
- A series of ResNet blocks with skip connections
- Upsampling layers to restore original image resolution
- A final output layer with `Tanh` activation to map pixel values to `[-1, 1]`

We also define the reusable ResNet block used within the generator.
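
To see how the spatial resolution changes through these stages, the sketch below traces one image dimension using the standard convolution and transposed-convolution output-size formulas (dilation 1), with the kernel, stride, and padding values used in the generator code. The function names are illustrative helpers, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length of a convolution along one spatial axis
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    # Output length of a transposed convolution along one spatial axis
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def generator_sizes(size, n_downsampling=2):
    sizes = [size]
    # Initial block: ReflectionPad2d(3) adds 3 pixels per side, then a 7x7 conv
    size = conv_out(size + 6, kernel=7)
    sizes.append(size)
    # Downsampling: 3x3 conv, stride 2, padding 1
    for _ in range(n_downsampling):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    # The ResNet blocks leave the size unchanged, so they are skipped here
    # Upsampling: 3x3 transposed conv, stride 2, padding 1, output_padding 1
    for _ in range(n_downsampling):
        size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
        sizes.append(size)
    # Final block: ReflectionPad2d(3) followed by a 7x7 conv
    size = conv_out(size + 6, kernel=7)
    sizes.append(size)
    return sizes

trace = generator_sizes(256)      # [256, 256, 128, 64, 128, 256, 256]
odd_trace = generator_sizes(115)  # ends at 116, not 115
```

With two downsampling layers, a dimension that is a multiple of 4 is restored exactly. Other sizes can come back slightly larger, which would make the L1 identity and cycle losses fail on a shape mismatch, so resizing or cropping inputs to a multiple of 4 is a reasonable precaution.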


In [None]:
#######Generator##########
#########################

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, n_blocks=9, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf #no of generator filters
        #n_blocks = resnet blocks

        # Initial convolutional block
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsample
        #reducing spatial dimensions
        n_downsampling = 2 #no. of downsampling layers
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=True),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Resnet blocks
        #using shortcut connections bypassing few layers
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type='reflect', norm_layer=norm_layer, use_dropout=use_dropout, use_bias=True)]

        # Upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
#resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        # Create the convolutional block
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
    #function to build convolution block
    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = [] #to hold layers
        p = 0

        #determining padding type for first layer
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        else:
            p = 1  # 'zero' padding
        #first convolution layer
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        # Optional dropout layer
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        #determining padding type for 2nd layer
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        else:
            p = 1  # 'zero' padding
        #second convolutional block
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

## 🧠 Define the Dual Output Discriminator (ResNet50)

We now define a **dual-purpose discriminator** built on top of a pre-trained ResNet50 model. This discriminator does two things:
1. **Discriminates real vs. fake** images (for adversarial training)
2. **Classifies images as benign or malignant** (for medical relevance)

We remove the final fully connected layer of ResNet50 and replace it with two custom branches:
- One for binary real/fake classification
- One for benign/malignant classification
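
The shared-backbone, two-head pattern can be sketched on a much smaller network. In the example below, a tiny convolutional backbone stands in for the pre-trained ResNet50 (an assumption made only to keep the example light and free of weight downloads), and the class name `TinyDualHead` is hypothetical.

```python
import torch
import torch.nn as nn

class TinyDualHead(nn.Module):
    """Illustrative stand-in: a small backbone replaces the pre-trained ResNet50."""

    def __init__(self, feat_dim=16, num_classes=2):
        super().__init__()
        # Backbone producing one feature vector per image
        self.backbone = nn.Sequential(
            nn.Conv2d(3, feat_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Head 1: real/fake probability
        self.rf_head = nn.Sequential(nn.Linear(feat_dim, 1), nn.Sigmoid())
        # Head 2: class logits (benign/malignant)
        self.class_head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)  # computed once, shared by both heads
        return self.rf_head(features), self.class_head(features)

model = TinyDualHead()
rf, logits = model(torch.randn(4, 3, 32, 32))
# rf has shape (4, 1) with values in [0, 1]; logits has shape (4, 2)
```

Both heads read the same feature vector, so gradients from the adversarial and classification losses both flow into the shared backbone.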


In [None]:
from torchvision.models import resnet50, ResNet50_Weights #pre trained resnet 50 with weights
class ResNet50DualOutputDiscriminator(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet50DualOutputDiscriminator, self).__init__()
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.resnet.fc = nn.Identity()  # Removing the fully connected layer

        # Real/Fake classifier
        self.rf_classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        # Benign/Malignant classifier
        self.class_classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        features = self.resnet(x)
        real_fake_output = self.rf_classifier(features)
        class_output = self.class_classifier(features)
        return real_fake_output, class_output


## 💻 Select Compute Device

We define the computation device, defaulting to GPU (`cuda`) if available, otherwise falling back to CPU. This enables seamless training on Colab’s hardware accelerators.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## 🔧 Define Training Hyperparameters

Here, we specify the core hyperparameters used across the training pipeline:
- `input_nc` and `output_nc`: Number of channels in the input and output images (3 for RGB)
- `n_residual_blocks`: Number of ResNet blocks inside the generator
- `lr` and `beta1`: Learning rate and the decay rate of Adam's first-moment (momentum-like) running average


In [None]:
input_nc = 3  # number of channels in the input images
output_nc = 3  # number of channels in the output images
n_residual_blocks = 9  # typical number for a CycleGAN

lr = 0.0002
beta1 = 0.5

## 🔄 Initialize Generators and Discriminators

We now create two generators and two discriminators:
- `netG_A2B`: Converts images from domain A (benign) to domain B (malignant)
- `netG_B2A`: Converts images from domain B (malignant) to domain A (benign)
- `netD_A`: Discriminator for domain A (benign)
- `netD_B`: Discriminator for domain B (malignant)

These models will be trained adversarially in the CycleGAN setup.


In [None]:
# Generators
netG_A2B = ResnetGenerator(input_nc, output_nc, n_blocks=n_residual_blocks).to(device)
netG_B2A = ResnetGenerator(input_nc, output_nc, n_blocks=n_residual_blocks).to(device)

# Discriminators
netD_A = ResNet50DualOutputDiscriminator().to(device)
netD_B = ResNet50DualOutputDiscriminator().to(device)

## 🚀 Define Optimizers for All Networks

In this step, we define the optimizers that will be used to update the weights of:
- Both generators (`optimizer_G`)
- Both discriminators (`optimizer_D_A` and `optimizer_D_B`)

We use the Adam optimizer with a learning rate of `0.0002` and `beta1=0.5`, which is common for GAN training.
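
A single optimizer can update both generators because `itertools.chain` concatenates their parameter iterators. The sketch below demonstrates this with two small `nn.Linear` layers standing in for the generators; the variable names are illustrative.

```python
import itertools
import torch
import torch.nn as nn

torch.manual_seed(0)
g_a2b = nn.Linear(2, 2)  # stand-in for netG_A2B
g_b2a = nn.Linear(2, 2)  # stand-in for netG_B2A

# One Adam instance over the parameters of both modules
optimizer = torch.optim.Adam(
    itertools.chain(g_a2b.parameters(), g_b2a.parameters()),
    lr=0.0002, betas=(0.5, 0.999),
)

before_a = g_a2b.weight.detach().clone()
before_b = g_b2a.weight.detach().clone()

x = torch.ones(1, 2)
loss = g_b2a(g_a2b(x)).sum()  # a loss that depends on both modules
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Both modules were updated by the single step
changed_a = not torch.equal(before_a, g_a2b.weight)
changed_b = not torch.equal(before_b, g_b2a.weight)
```

Sharing one optimizer fits the generators because their losses are summed and backpropagated together, whereas each discriminator has its own loss and therefore its own optimizer.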


In [None]:
# Optimizers: one shared by both generators, one per discriminator
optimizer_G = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=lr, betas=(beta1, 0.999))
optimizer_D_A = optim.Adam(netD_A.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D_B = optim.Adam(netD_B.parameters(), lr=lr, betas=(beta1, 0.999))

## 📉 Define Loss Functions

To guide the training process, we define several loss functions:
- **Adversarial loss** (`MSELoss`): A least-squares GAN objective that pushes the generators to produce images the discriminators score as real
- **Cycle-consistency loss** (`L1Loss`): Penalizes the difference between an image and its reconstruction after a round trip A → B → A (and B → A → B)
- **Identity loss** (`L1Loss`): Penalizes a generator for altering an image that already belongs to its target domain, which helps preserve color and structure
- **Classification loss** (`CrossEntropyLoss`): Encourages generated images to be classified as the intended medical class (benign/malignant)
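
To make these terms concrete, here is a small worked example that evaluates three of the criteria on hand-picked tensors. The input values are invented purely for illustration.

```python
import torch
import torch.nn as nn

criterion_gan = nn.MSELoss()
criterion_l1 = nn.L1Loss()
criterion_cls = nn.CrossEntropyLoss()

# Adversarial term: the generator wants the discriminator to output 1 ("real")
pred_fake = torch.tensor([[0.8]])
loss_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))  # (0.8 - 1)^2 = 0.04

# Cycle term: mean absolute difference between original and reconstruction
real = torch.tensor([1.0, 2.0])
recovered = torch.tensor([1.5, 2.5])
loss_cycle = criterion_l1(recovered, real)                       # mean(0.5, 0.5) = 0.5

# Classification term: equal logits for both classes give -log(0.5) ≈ 0.693
logits = torch.zeros(1, 2)
target = torch.tensor([1])  # label 1 = malignant in this notebook
loss_cls = criterion_cls(logits, target)
```

The identity loss uses the same `L1Loss` computation as the cycle term, applied to a generator's output on an image from its own target domain.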


In [None]:
# Define loss functions
criterion_GAN = nn.MSELoss().to(device)
criterion_cycle = nn.L1Loss().to(device)
criterion_identity = nn.L1Loss().to(device)
criterion_classification = nn.CrossEntropyLoss().to(device)

## 🧪 Define Image Visualization Function

This utility function displays a **real image vs. its generated (fake) counterpart** side-by-side using `matplotlib`.

It helps monitor the quality of generated images during training and visually assess how well the generator is performing.
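
Before a generator output can be shown with `plt.imshow`, it has to be detached from the autograd graph, rescaled to `[0, 1]`, and reordered from `(C, H, W)` to `(H, W, C)`. The sketch below performs those steps by hand with a min-max rescale, which is what `make_grid(..., normalize=True)` is understood to do by default; `to_displayable` is a hypothetical helper name.

```python
import torch

def to_displayable(img):
    """Min-max rescale a (C, H, W) tensor to [0, 1] and return an (H, W, C) array."""
    img = img.detach().cpu().float()
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo).clamp(min=1e-5)  # guard against a constant image
    return img.permute(1, 2, 0).numpy()

# Mimics a Tanh generator output that still tracks gradients
fake = torch.tanh(torch.randn(3, 8, 8, requires_grad=True))
array = to_displayable(fake)  # NumPy array of shape (8, 8, 3)
```

Note that min-max rescaling stretches each image to its own range, so brightness is not directly comparable between the real and generated panels.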


In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def plot_single_real_and_fake_image(real_image, fake_image):
    """
    Plots a comparison of a single real and a single generated (fake) image.

    Parameters:
    - real_image: a single Tensor image (C, H, W).
    - fake_image: a single Tensor image (C, H, W).
    """
    plt.figure(figsize=(8, 4))

    # Display the real image
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Image")
    real_image = vutils.make_grid(real_image.detach(), normalize=True).permute(1, 2, 0).cpu().numpy()
    plt.imshow(real_image)

    # Display the fake image
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Generated Image")
    fake_image = vutils.make_grid(fake_image.detach(), normalize=True).permute(1, 2, 0).cpu().numpy()
    plt.imshow(fake_image)

    plt.show()

## 💾 Save Trained Models to Google Drive

To avoid losing progress during training, we define a function that saves the current state of all models (`netG_A2B`, `netG_B2A`, `netD_A`, `netD_B`) to Google Drive.

Model checkpoints are saved by epoch, allowing us to resume or analyze training at different stages.
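
Resuming from a checkpoint is the reverse operation: build a model with the same architecture and load the saved `state_dict` into it. The sketch below shows the round trip with a small `nn.Linear` standing in for one of the networks and a temporary directory in place of Google Drive; the epoch number is hypothetical.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # stand-in for one of the four networks
epoch = 4                # hypothetical epoch number

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, f'netG_A2B_epoch_{epoch}.pth')
    torch.save(model.state_dict(), path)

    # The restored model must have the same architecture as the saved one
    restored = nn.Linear(3, 2)
    state = torch.load(path, map_location='cpu')
    restored.load_state_dict(state)

same_weights = torch.equal(model.weight, restored.weight)
```

Passing `map_location` lets a checkpoint written on a GPU runtime be loaded on a CPU-only one. To resume training exactly, the optimizer states would need to be saved as well, which the function in this notebook does not do.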


In [None]:
import os
import torch

def save_models_to_drive(epoch, netG_A2B, netG_B2A, netD_A, netD_B, drive_path='/content/drive/MyDrive/resnet_50_CycleGAN_Models'):
    """
    Save model parameters to Google Drive.

    Parameters:
        epoch (int): The current epoch number.
        netG_A2B (nn.Module): Generator model from domain A to B.
        netG_B2A (nn.Module): Generator model from domain B to A.
        netD_A (nn.Module): Discriminator model for domain A.
        netD_B (nn.Module): Discriminator model for domain B.
        drive_path (str): The path in Google Drive to save the models.
    """
    # Create the target directory if it does not exist yet
    os.makedirs(drive_path, exist_ok=True)

    # Define file paths for saving
    path_G_A2B = os.path.join(drive_path, f'netG_A2B_epoch_{epoch}.pth')
    path_G_B2A = os.path.join(drive_path, f'netG_B2A_epoch_{epoch}.pth')
    path_D_A = os.path.join(drive_path, f'netD_A_epoch_{epoch}.pth')
    path_D_B = os.path.join(drive_path, f'netD_B_epoch_{epoch}.pth')

    # Save the models
    torch.save(netG_A2B.state_dict(), path_G_A2B)
    torch.save(netG_B2A.state_dict(), path_G_B2A)
    torch.save(netD_A.state_dict(), path_D_A)
    torch.save(netD_B.state_dict(), path_D_B)

    print(f"Saved models at epoch {epoch} to {drive_path}")

# Example usage within the training loop:
# save_models_to_drive(epoch, netG_A2B, netG_B2A, netD_A, netD_B)


## 🧹 Clear CUDA Cache

Before training begins, we clear the CUDA memory cache to prevent potential memory issues on the GPU.


In [None]:
# Clear CUDA cache
torch.cuda.empty_cache()

## 🧠 Train the CycleGAN with Classification

This is the main training loop where the entire CycleGAN is trained over multiple epochs.

For each epoch and batch:
- The generators are trained to fool the discriminators and reconstruct the original images via cycle-consistency
- The discriminators are trained to distinguish between real and fake images
- Classification loss is also applied to encourage the generated images to have the correct medical label

Every 20 batches, we:
- Visualize real vs. fake images
- Print the predicted classes of real and generated samples

At the end of each epoch, we save model checkpoints.

This completes the core CycleGAN training process for benign ↔ malignant image translation.
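
One detail of the loop worth isolating is the use of `.detach()` when the discriminators score generated images. The sketch below, with small `nn.Linear` layers as stand-ins for a generator and a discriminator, shows that detaching stops the discriminator's backward pass from reaching the generator.

```python
import torch
import torch.nn as nn

generator = nn.Linear(2, 2)      # stand-in for a generator
discriminator = nn.Linear(2, 1)  # stand-in for a discriminator

fake = generator(torch.ones(1, 2))

# Discriminator step: detach so the backward pass stops at `fake`
loss_d = discriminator(fake.detach()).pow(2).mean()
loss_d.backward()

generator_untouched = generator.weight.grad is None          # no gradient reached G
discriminator_has_grad = discriminator.weight.grad is not None
```

Without the detach, the discriminator's backward pass would need the generator's computation graph, which has already been freed by the generator's own `backward()` call earlier in the same iteration.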


In [None]:
import time

num_epochs = 25
# Record the total training start time
total_training_start_time = time.time()

# Training Loop
for epoch in range(num_epochs):

    # Record the start time of the epoch
    epoch_start_time = time.time()

    for i, (batch_A, batch_B) in enumerate(zip(loader_A, loader_B)):
        # Each batch is an (images, labels) pair; keep the images and move them to `device`
        real_A = batch_A[0].to(device)  # images from domain A (benign)
        real_B = batch_B[0].to(device)  # images from domain B (malignant)

        # -------------------------------
        #  Train Generators A2B and B2A
        # -------------------------------
        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(netG_B2A(real_A), real_A)
        loss_id_B = criterion_identity(netG_A2B(real_B), real_B)

        # GAN loss
        fake_B = netG_A2B(real_A) # generating images from A to B domain
        pred_fake, class_fake_B = netD_B(fake_B)  # predicting real/fake and class
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones(pred_fake.size(), device=device))

        fake_A = netG_B2A(real_B) # generating images from B to A domain
        pred_fake, class_fake_A = netD_A(fake_A)  # predicting real/fake and class
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones(pred_fake.size(), device=device))

        # Cycle loss
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

        # Classification loss for fake images (benign/malignant)
        target_fake_B = torch.full((class_fake_B.size(0),), 1, device=device, dtype=torch.long)  # All fake_B are malignant (label 1)
        target_fake_A = torch.full((class_fake_A.size(0),), 0, device=device, dtype=torch.long)  # All fake_A are benign (label 0)
        loss_class_fake_B = criterion_classification(class_fake_B, target_fake_B)
        loss_class_fake_A = criterion_classification(class_fake_A, target_fake_A)

        # Total loss for Generators
        loss_G = (loss_id_A + loss_id_B + loss_GAN_A2B + loss_GAN_B2A +
                  loss_cycle_ABA + loss_cycle_BAB + loss_class_fake_B + loss_class_fake_A)
        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator D_A
        # -----------------------
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real_A, class_real_A = netD_A(real_A)
        loss_D_real_A = criterion_GAN(pred_real_A, torch.ones(pred_real_A.size(), device=device))
        target_real_A = torch.full((class_real_A.size(0),), 0, device=device, dtype=torch.long)  # All real_A are benign (label 0)
        loss_class_real_A = criterion_classification(class_real_A, target_real_A)

        # Fake loss (detach to avoid training G on these labels)
        pred_fake_A, class_fake_A = netD_A(fake_A.detach())
        loss_D_fake_A = criterion_GAN(pred_fake_A, torch.zeros(pred_fake_A.size(), device=device))
        loss_class_fake_A = criterion_classification(class_fake_A, target_fake_A)

        # Total loss for Discriminator A
        loss_D_A = (loss_D_real_A + loss_D_fake_A + loss_class_real_A + loss_class_fake_A) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator D_B
        # -----------------------
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real_B, class_real_B = netD_B(real_B)
        loss_D_real_B = criterion_GAN(pred_real_B, torch.ones(pred_real_B.size(), device=device))
        target_real_B = torch.full((class_real_B.size(0),), 1, device=device, dtype=torch.long)  # All real_B are malignant (label 1)
        loss_class_real_B = criterion_classification(class_real_B, target_real_B)

        # Fake loss (detach to avoid training G on these labels)
        pred_fake_B, class_fake_B = netD_B(fake_B.detach())
        loss_D_fake_B = criterion_GAN(pred_fake_B, torch.zeros(pred_fake_B.size(), device=device))
        loss_class_fake_B = criterion_classification(class_fake_B, target_fake_B)

        # Total loss for Discriminator B
        loss_D_B = (loss_D_real_B + loss_D_fake_B + loss_class_real_B + loss_class_fake_B) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        # ---------------------
        #  Log Progress
        # ---------------------
        n_batches = min(len(loader_A), len(loader_B))  # zip() stops at the shorter loader
        print(f"Epoch [{epoch}/{num_epochs}] Batch {i}/{n_batches} "
              f"Loss D_A: {loss_D_A.item():.4f}, Loss D_B: {loss_D_B.item():.4f}, "
              f"Loss G: {loss_G.item():.4f}")

        # If at save interval => save generated image samples
        if i % 20 == 0:  # For example, visualize every 20 batches
            plot_single_real_and_fake_image(real_A[0], fake_B[0])  # Pass the first image of the batch
            plot_single_real_and_fake_image(real_B[0], fake_A[0])

            # Predict the class of each image (no gradients are needed for this report)
            with torch.no_grad():
                _, class_real_A = netD_A(real_A)
                _, class_real_B = netD_B(real_B)
                _, class_fake_A = netD_A(fake_A)
                _, class_fake_B = netD_B(fake_B)

            # Report the prediction for the first image of each batch
            pred_class_real_A = class_real_A[0].argmax().item()
            pred_class_real_B = class_real_B[0].argmax().item()
            pred_class_fake_A = class_fake_A[0].argmax().item()
            pred_class_fake_B = class_fake_B[0].argmax().item()

            print(f"Predicted class for real_A: {'Benign' if pred_class_real_A == 0 else 'Malignant'}")
            print(f"Predicted class for real_B: {'Benign' if pred_class_real_B == 0 else 'Malignant'}")
            print(f"Predicted class for fake_A: {'Benign' if pred_class_fake_A == 0 else 'Malignant'}")
            print(f"Predicted class for fake_B: {'Benign' if pred_class_fake_B == 0 else 'Malignant'}")

    # Record the end time of the epoch
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time
    print(f"Epoch {epoch} completed in {epoch_duration:.2f} seconds")

    # Update learning rates
    # lr_scheduler_G.step()
    # lr_scheduler_D_A.step()
    # lr_scheduler_D_B.step()

    # Save a checkpoint at the end of every epoch
    save_models_to_drive(epoch, netG_A2B, netG_B2A, netD_A, netD_B)

# After the final epoch, save the refined generated images
# save_final_generated_images(G, dataloader, classifier, epoch=num_epochs, base_directory="output_breakhis/final_images", device=device)

# Record the total training end time
total_training_end_time = time.time()
total_training_duration = total_training_end_time - total_training_start_time
print(f"Total training time: {total_training_duration:.2f} seconds")