In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import os
import itertools
from torch.autograd import Variable

In [None]:
# Clear the CUDA cache up front (presumably to release GPU memory that an
# earlier run in this kernel left cached — confirm this is still needed).
torch.cuda.empty_cache()

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive. The dataset folders and the checkpoint
# directory used later in this notebook are all paths under this mount point.
drive.mount('/content/drive')

In [None]:
# Load Data
# Preprocessing applied to every image, in this order:
#   convert to tensor -> resize to 224x224 -> normalise each channel with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
new_size = 1  # batch size


def _dataset_and_loader(image_root):
    """Build the ImageFolder dataset rooted at `image_root` and a shuffling DataLoader over it."""
    folder = datasets.ImageFolder(root=image_root, transform=transform)
    return folder, torch.utils.data.DataLoader(folder, batch_size=new_size, shuffle=True)


# Domain A = benign images, domain B = malignant images
dataset_A, loader_A = _dataset_and_loader('/content/drive/MyDrive/my_data/train/benign')
dataset_B, loader_B = _dataset_and_loader('/content/drive/MyDrive/my_data/train/malignant')

In [None]:
import matplotlib.pyplot as plt

def imshow(img, mean=0.5, std=0.5):
    """
    Displays a single normalised image tensor on the current matplotlib axes.

    The notebook's transform normalises every channel with mean 0.5 / std 0.5,
    which puts pixel values in [-1, 1]. matplotlib expects float RGB data in
    [0, 1] and clips anything outside it, so the normalisation is undone here
    before plotting.

    Parameters:
    - img: a single Tensor image (C, H, W).
    - mean: per-channel mean that was used for normalisation (default 0.5).
    - std: per-channel std that was used for normalisation (default 0.5).
    """
    # detach()/cpu() make the helper safe for tensors that track gradients or live on the GPU
    img = img.detach().cpu() * std + mean
    # Clamp guards against tiny numerical overshoot outside the displayable range
    img = img.clamp(0, 1).numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    plt.imshow(img)
    plt.axis('off')

# Draw one batch from each domain; the labels returned by the loaders are not needed here
dataiter_A = iter(loader_A)
images_A, _ = next(dataiter_A)

dataiter_B = iter(loader_B)
images_B, _ = next(dataiter_B)

# Show the first image of each batch side by side
plt.figure(figsize=(8, 4))
_preview_panels = [
    ('Sample from benign ()', images_A),
    ('Sample from malignant ()', images_B),
]
for _panel_no, (_caption, _batch) in enumerate(_preview_panels, start=1):
    plt.subplot(1, 2, _panel_no)
    plt.title(_caption)
    imshow(_batch[0])
plt.show()

In [None]:
#######Generator##########
#########################

class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 convolutional stem -> two stride-2 downsampling convolutions ->
    `n_blocks` residual blocks -> two stride-2 transposed convolutions ->
    7x7 output convolution followed by Tanh.

    Parameters:
    - input_nc: number of channels of the input images.
    - output_nc: number of channels of the generated images.
    - n_blocks: number of residual blocks (must be >= 0).
    - ngf: number of filters produced by the stem; doubled at each downsampling step.
    - norm_layer: callable taking a channel count and returning a normalisation layer.
    - use_dropout: forwarded to every residual block.
    """

    def __init__(self, input_nc, output_nc, n_blocks=9, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        assert n_blocks >= 0
        super().__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        n_downsampling = 2  # number of stride-2 stages on each side of the residual trunk
        layers = []

        # Stem: reflection-padded 7x7 convolution keeps the spatial size
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            norm_layer(ngf),
            nn.ReLU(True),
        ])

        # Encoder: each stage halves the resolution and doubles the channel count
        channels = ngf
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=True),
                norm_layer(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Residual trunk at the lowest resolution
        layers.extend(
            ResnetBlock(channels, padding_type='reflect', norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=True)
            for _ in range(n_blocks)
        )

        # Decoder: mirror of the encoder, doubling the resolution and halving the channels
        for _ in range(n_downsampling):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=True),
                norm_layer(halved),
                nn.ReLU(True),
            ])
            channels = halved

        # Output head: back to `output_nc` channels, squashed by Tanh
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run `input` through the full generator stack."""
        return self.model(input)
#resnet block
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convolutions whose output is added back onto the input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super().__init__()
        # The residual branch; the skip connection itself is applied in forward()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the residual branch as an nn.Sequential.

        Parameters:
        - dim: channel count of both the input and the output.
        - padding_type: 'reflect' or 'replicate' insert an explicit padding layer;
          any other value falls back to zero padding inside the convolution.
        - norm_layer: callable taking a channel count and returning a normalisation layer.
        - use_dropout: if true, a Dropout(0.5) follows the first activation.
        - use_bias: whether the convolutions carry a bias term.
        """
        if padding_type == 'reflect':
            pad_layer = nn.ReflectionPad2d
        elif padding_type == 'replicate':
            pad_layer = nn.ReplicationPad2d
        else:
            pad_layer = None
        # With an explicit padding layer the convolution itself must not pad again
        conv_padding = 1 if pad_layer is None else 0

        def padded_conv_with_norm():
            # [optional padding layer] -> 3x3 convolution -> normalisation
            stage = [] if pad_layer is None else [pad_layer(1)]
            stage.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias))
            stage.append(norm_layer(dim))
            return stage

        # First half: convolution stage followed by the activation and optional dropout
        layers = padded_conv_with_norm()
        layers.append(nn.ReLU(True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second half: convolution stage with no activation after it
        layers.extend(padded_conv_with_norm())

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        return x + self.conv_block(x)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class CustomDiscriminator(nn.Module):
    """VGG16-backed discriminator with two fully connected heads.

    The convolutional part of a pre-trained VGG16 is shared by:
      * a fake/real head that emits one sigmoid score per image, and
      * a benign/malignant head that emits `num_classes` raw scores per image.

    forward() returns the pair (fake_real_output, benign_malignant_output).
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Keep only the convolutional feature extractor of the pre-trained VGG16
        self.features = models.vgg16(pretrained=True).features

        # Length of the flattened feature map; assumes input images of size (3, 224, 224)
        self.feature_map_size = 512 * 7 * 7

        def dense_stack(n_outputs):
            # Both heads share this three-layer layout and differ only in output width
            return [
                nn.Linear(self.feature_map_size, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_outputs),
            ]

        # Binary fake/real head, squashed to (0, 1) by the sigmoid
        self.fake_real_classifier = nn.Sequential(*dense_stack(1), nn.Sigmoid())

        # Benign/malignant head; emits raw scores with no final activation
        self.benign_malignant_classifier = nn.Sequential(*dense_stack(num_classes))

    def forward(self, x):
        """Extract VGG16 features once and feed the flattened result to both heads."""
        flat = self.features(x)
        flat = flat.view(flat.size(0), -1)  # (batch, channels * height * width)
        return self.fake_real_classifier(flat), self.benign_malignant_classifier(flat)


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Number of channels in the input images and in the output images
input_nc = output_nc = 3
# Residual blocks per generator (typical number for a CycleGAN)
n_residual_blocks = 9

# Settings passed to the Adam optimizers: learning rate and first beta coefficient
lr = 2e-4
beta1 = 0.5

In [None]:
# Generators: one per translation direction (A -> B first, then B -> A), both on `device`
netG_A2B, netG_B2A = (
    ResnetGenerator(input_nc, output_nc, n_blocks=n_residual_blocks).to(device)
    for _ in range(2)
)

# Discriminators: one per image domain (A first, then B), both on `device`
D_A, D_B = (CustomDiscriminator().to(device) for _ in range(2))

In [None]:
from torch.optim import Adam

# Optimizers
def _make_adam(parameters):
    """Adam over `parameters` with the notebook-wide learning rate and betas."""
    return optim.Adam(parameters, lr=lr, betas=(beta1, 0.999))


# A single optimizer updates both generators together; each discriminator has its own
optimizer_G = _make_adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()))
optimizer_D_A = _make_adam(D_A.parameters())
optimizer_D_B = _make_adam(D_B.parameters())

In [None]:
# Define loss functions
# Adversarial loss: mean-squared error between the discriminator's fake/real
# score and an all-ones / all-zeros target.
# NOTE(review): the fake/real head of CustomDiscriminator ends in a Sigmoid, so
# this MSE is computed on values in (0, 1) — confirm that is intended.
criterion_GAN = nn.MSELoss().to(device)
# Cycle-consistency loss: L1 distance between a reconstructed image and the original.
criterion_cycle = nn.L1Loss().to(device)
# Identity loss: L1 distance between a generator's output and its own input
# when the input already comes from the generator's target domain.
criterion_identity = nn.L1Loss().to(device)
# Benign/malignant classification loss on the discriminator's class scores.
criterion_classification = nn.CrossEntropyLoss().to(device)

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def plot_single_real_and_fake_image(real_image, fake_image):
    """
    Plots a comparison of a single real and a single generated (fake) image.

    Both tensors are detached from the autograd graph before being converted
    to NumPy, so a generator output that still requires grad can be passed in
    directly without raising on conversion.

    Parameters:
    - real_image: a single Tensor image (C, H, W).
    - fake_image: a single Tensor image (C, H, W).
    """
    plt.figure(figsize=(8, 4))

    panels = (("Real Image", real_image), ("Generated Image", fake_image))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.axis("off")
        plt.title(title)
        # make_grid(normalize=True) rescales the image for display; the permute
        # moves channels last, which is the layout plt.imshow expects.
        grid = vutils.make_grid(image.detach(), normalize=True)
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())

    plt.show()

In [None]:
import os
import torch

def save_models_to_drive(epoch, netG_A2B, netG_B2A, D_A, D_B, drive_path='/content/drive/MyDrive/vgg_16_CycleGAN_Models'):
    """
    Save model parameters to Google Drive.

    One state_dict file per model is written to `drive_path`, named
    `<prefix>_epoch_<epoch>.pth` with the prefixes `netG_A2B`, `netG_B2A`,
    `netD_A` and `netD_B`. Existing files for the same epoch are overwritten.

    Parameters:
        epoch (int): The current epoch number (used in the file names).
        netG_A2B (nn.Module): Generator model from domain A to B.
        netG_B2A (nn.Module): Generator model from domain B to A.
        D_A (nn.Module): Discriminator model for domain A.
        D_B (nn.Module): Discriminator model for domain B.
        drive_path (str): The path in Google Drive to save the models.
    """
    # exist_ok=True creates missing parents and avoids the check-then-create race
    os.makedirs(drive_path, exist_ok=True)

    # File-name prefix -> model; the discriminator prefixes keep the historical "netD_" spelling
    models_by_prefix = {
        'netG_A2B': netG_A2B,
        'netG_B2A': netG_B2A,
        'netD_A': D_A,
        'netD_B': D_B,
    }
    for prefix, model in models_by_prefix.items():
        checkpoint_path = os.path.join(drive_path, f'{prefix}_epoch_{epoch}.pth')
        torch.save(model.state_dict(), checkpoint_path)

    print(f"Saved models at epoch {epoch} to {drive_path}")

# Example usage within the training loop:
# save_models_to_drive(epoch, netG_A2B, netG_B2A, D_A, D_B)


In [None]:
# Clear the CUDA cache again before the training loop starts (presumably to
# free GPU memory cached while the models and sample plots were created).
torch.cuda.empty_cache()

In [None]:
num_epochs = 25

import torch
from torch import nn, optim
from torch.autograd import Variable
import itertools
import time

import os

# Checkpoints are written every `save_every_n_epochs` epochs (1 = after every epoch)
save_every_n_epochs = 1

# Record the total training start time
total_training_start_time = time.time()

# zip() below stops at the shorter loader, so this is the real number of
# batches processed per epoch (used only for the progress message).
batches_per_epoch = min(len(loader_A), len(loader_B))

# Training Loop
for epoch in range(num_epochs):

    # Record the start time of the epoch
    epoch_start_time = time.time()

    for i, (real_A, real_B) in enumerate(zip(loader_A, loader_B)):
        # Set model input. Each loader yields (images, labels); only the images
        # are used because the class label is implied by the domain.
        real_A = real_A[0].to(device)  # images from domain A (benign)
        real_B = real_B[0].to(device)  # images from domain B (malignant)

        # -------------------------------
        #  Train Generators A2B and B2A
        # -------------------------------
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image of its own target domain should return it unchanged
        loss_id_A = criterion_identity(netG_B2A(real_A), real_A)
        loss_id_B = criterion_identity(netG_A2B(real_B), real_B)

        # GAN loss: generators are rewarded when the discriminators score their output as real
        fake_B = netG_A2B(real_A)  # generating images from A to B domain
        pred_fake, class_fake_B = D_B(fake_B)  # predicting real/fake and class
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones(pred_fake.size(), device=device))

        fake_A = netG_B2A(real_B)  # generating images from B to A domain
        pred_fake, class_fake_A = D_A(fake_A)  # predicting real/fake and class
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones(pred_fake.size(), device=device))

        # Cycle loss: translating there and back should reproduce the original image
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

        # Classification loss for fake images (benign/malignant)
        target_fake_B = torch.full((class_fake_B.size(0),), 1, device=device, dtype=torch.long)  # All fake_B are malignant (label 1)
        target_fake_A = torch.full((class_fake_A.size(0),), 0, device=device, dtype=torch.long)  # All fake_A are benign (label 0)
        loss_class_fake_B = criterion_classification(class_fake_B, target_fake_B)
        loss_class_fake_A = criterion_classification(class_fake_A, target_fake_A)

        # Total loss for Generators (all terms are weighted equally)
        loss_G = (loss_id_A + loss_id_B + loss_GAN_A2B + loss_GAN_B2A +
                  loss_cycle_ABA + loss_cycle_BAB + loss_class_fake_B + loss_class_fake_A)
        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator D_A
        # -----------------------
        # zero_grad() also discards the gradients the generator step left on D_A
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real_A, class_real_A = D_A(real_A)
        loss_D_real_A = criterion_GAN(pred_real_A, torch.ones(pred_real_A.size(), device=device))
        target_real_A = torch.full((class_real_A.size(0),), 0, device=device, dtype=torch.long)  # All real_A are benign (label 0)
        loss_class_real_A = criterion_classification(class_real_A, target_real_A)

        # Fake loss (detach to avoid training G on these labels)
        pred_fake_A, class_fake_A = D_A(fake_A.detach())
        loss_D_fake_A = criterion_GAN(pred_fake_A, torch.zeros(pred_fake_A.size(), device=device))
        loss_class_fake_A = criterion_classification(class_fake_A, target_fake_A)

        # Total loss for Discriminator A
        loss_D_A = (loss_D_real_A + loss_D_fake_A + loss_class_real_A + loss_class_fake_A) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator D_B
        # -----------------------
        # zero_grad() also discards the gradients the generator step left on D_B
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real_B, class_real_B = D_B(real_B)
        loss_D_real_B = criterion_GAN(pred_real_B, torch.ones(pred_real_B.size(), device=device))
        target_real_B = torch.full((class_real_B.size(0),), 1, device=device, dtype=torch.long)  # All real_B are malignant (label 1)
        loss_class_real_B = criterion_classification(class_real_B, target_real_B)

        # Fake loss (detach to avoid training G on these labels)
        pred_fake_B, class_fake_B = D_B(fake_B.detach())
        loss_D_fake_B = criterion_GAN(pred_fake_B, torch.zeros(pred_fake_B.size(), device=device))
        loss_class_fake_B = criterion_classification(class_fake_B, target_fake_B)

        # Total loss for Discriminator B
        loss_D_B = (loss_D_real_B + loss_D_fake_B + loss_class_real_B + loss_class_fake_B) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        # ---------------------
        #  Log Progress
        # ---------------------
        print(f"Epoch [{epoch}/{num_epochs}] Batch {i}/{batches_per_epoch} \
              Loss D_A: {loss_D_A.item()}, Loss D_B: {loss_D_B.item()} \
              Loss G: {loss_G.item()}")

        # If at save interval => show generated image samples
        if i % 20 == 0:  # For example, visualize every 20 batches
            # Pass the first image of the batch; detach so plotting never touches the autograd graph
            plot_single_real_and_fake_image(real_A[0], fake_B[0].detach())
            plot_single_real_and_fake_image(real_B[0], fake_A[0].detach())

            # Predict class of images and print.
            # These forward passes are for monitoring only: run them in eval mode
            # (dropout off, so predictions are deterministic) and without building
            # an autograd graph, then restore the previous train/eval mode.
            d_a_was_training, d_b_was_training = D_A.training, D_B.training
            D_A.eval()
            D_B.eval()
            with torch.no_grad():
                _, class_real_A = D_A(real_A)
                _, class_real_B = D_B(real_B)
                _, class_fake_A = D_A(fake_A)
                _, class_fake_B = D_B(fake_B)
            D_A.train(d_a_was_training)
            D_B.train(d_b_was_training)

            # Report the prediction for the first image of the batch (the one plotted above);
            # indexing with [0] keeps this working for batch sizes larger than 1.
            pred_class_real_A = class_real_A.argmax(dim=1)[0].item()
            pred_class_real_B = class_real_B.argmax(dim=1)[0].item()
            pred_class_fake_A = class_fake_A.argmax(dim=1)[0].item()
            pred_class_fake_B = class_fake_B.argmax(dim=1)[0].item()

            print(f"Predicted class for real_A: {'Benign' if pred_class_real_A == 0 else 'Malignant'}")
            print(f"Predicted class for real_B: {'Benign' if pred_class_real_B == 0 else 'Malignant'}")
            print(f"Predicted class for fake_A: {'Benign' if pred_class_fake_A == 0 else 'Malignant'}")
            print(f"Predicted class for fake_B: {'Benign' if pred_class_fake_B == 0 else 'Malignant'}")

    # Record the end time of the epoch
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time
    print(f"Epoch {epoch} completed in {epoch_duration:.2f} seconds")

    # Update learning rates
    #lr_scheduler_G.step()
    #lr_scheduler_D_A.step()
    #lr_scheduler_D_B.step()

    # Save a checkpoint of all four networks at the configured interval
    if (epoch + 1) % save_every_n_epochs == 0:
        save_models_to_drive(epoch, netG_A2B, netG_B2A, D_A, D_B)

# After the final epoch, save the refined generated images
#save_final_generated_images(G, dataloader, classifier, epoch=num_epochs, base_directory="output_breakhis/final_images", device=device)

# Record the total training end time
total_training_end_time = time.time()
total_training_duration = total_training_end_time - total_training_start_time
print(f"Total training time: {total_training_duration:.2f} seconds")