In [None]:

import os
import shutil
import random

# Root locations: Kaggle's writable working directory and the mounted input dataset.
BASE_DIR = "/kaggle/working"
DATA_INPUT_DIR = "/kaggle/input/bangladeshi-hospitals-eye-dataset/Bangladeshi Hospitals Dataset"  # adjust if your dataset name is different

# Project layout created under BASE_DIR: code packages, output folders and
# the CycleGAN train/test folders for domains A and B.
_project_subdirs = (
    'models',
    'utils',
    'checkpoints',
    'results',
    'data/trainA',
    'data/trainB',
    'data/testA',
    'data/testB',
)
for _subdir in _project_subdirs:
    os.makedirs(os.path.join(BASE_DIR, _subdir), exist_ok=True)

print("Directories created.")

In [None]:
# Cell 2: Install Dependencies


!pip install torch torchvision numpy matplotlib Pillow

print("Dependencies installed.")

In [None]:
%%writefile utils/image_pool.py
import random
import torch

class ImagePool():
    """History buffer of generated images.

    Discriminators can be updated with a mix of freshly generated images and
    images produced by earlier generator states, instead of only the latest
    outputs, which can help stabilize training.
    """
    def __init__(self, pool_size):
        """Set up the buffer.

        Parameters:
            pool_size (int) -- capacity of the buffer; with pool_size=0 no buffer is used
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # storage is only allocated when buffering is enabled
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Exchange a batch of new images for a batch drawn from the history.

        Parameters:
            images: the latest generated images from the generator (iterated along dim 0)
        Returns a tensor with one entry per input image, concatenated along dim 0.

        While the buffer is filling up, every image is stored and returned unchanged.
        Once it is full, each image is handled by a coin flip: either it replaces a
        randomly chosen stored image (and that older image is returned), or it is
        returned as-is and the buffer is left untouched.
        """
        if self.pool_size == 0:  # buffering disabled: pass the batch straight through
            return images

        selected = []
        for sample in images:
            sample = torch.unsqueeze(sample.data, 0)  # detached view with a leading batch dim
            if self.num_imgs >= self.pool_size:
                # Buffer is full: swap with a stored image or keep the current one.
                if random.uniform(0, 1) > 0.5:
                    slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    previous = self.images[slot].clone()
                    self.images[slot] = sample
                    selected.append(previous)
                else:
                    selected.append(sample)
            else:
                # Buffer still has room: remember the image and return it unchanged.
                self.num_imgs += 1
                self.images.append(sample)
                selected.append(sample)
        return torch.cat(selected, 0)

In [None]:
%%writefile utils/transforms.py
from torchvision import transforms

def get_transforms(image_size=(256, 256), domain='A'):
    """
    Build the training-time preprocessing pipeline for one CycleGAN domain.

    Every domain is resized to `image_size`, randomly flipped horizontally,
    converted to a tensor and normalized with mean 0.5 / std 0.5 per channel.
    Domain 'B' (healthy images) additionally gets random rotation, random
    translation and brightness/contrast jitter between the flip and the
    tensor conversion.
    """
    steps = [
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
    if domain == 'B':
        # Stronger augmentation, applied to domain B only.
        steps += [
            transforms.RandomRotation(15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)

def get_test_transforms(image_size=(256, 256)):
    """
    Build the deterministic preprocessing pipeline for test images:
    resize to `image_size`, convert to a tensor, normalize with mean 0.5 /
    std 0.5 per channel. No random augmentation is applied.
    """
    resize = transforms.Resize(image_size)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return transforms.Compose([resize, to_tensor, normalize])

# Smoke test: prints the three pipelines. Runs only when this file is executed
# directly, never when it is imported.
if __name__ == '__main__':
    pipelines = (
        ("Train transforms (A):", get_transforms(domain='A')),
        ("Train transforms (B):", get_transforms(domain='B')),
        ("Test transforms:", get_test_transforms()),
    )
    for label, pipeline in pipelines:
        print(label, pipeline)

In [None]:
%%writefile utils/dataset.py
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random

class CycleGAN_Dataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN.

    Each item is a pair `(img_A, img_B)`. The dataset length is the size of the
    larger domain; the smaller domain is cycled with a modulo index.

    Parameters:
        root_A, root_B -- directories holding the images of domain A / domain B
                          (only regular files directly inside them are used)
        transform      -- callable applied to domain-A images (and to domain-B
                          images when `transform_B` is not given)
        transform_B    -- optional callable applied to domain-B images
        max_images     -- if not None, each domain is shuffled and truncated to
                          at most this many files

    Raises:
        ValueError -- if exactly one of the two domains contains no files, since
                      indexing such a dataset could never succeed
    """
    def __init__(self, root_A, root_B, transform=None, transform_B=None, max_images=None):
        self.root_A = root_A
        self.root_B = root_B
        self.transform = transform
        self.transform_B = transform_B if transform_B is not None else transform

        files_A_all = self._list_files(root_A)
        files_B_all = self._list_files(root_B)

        # Shuffle and limit files if max_images is specified
        if max_images is not None:
            random.shuffle(files_A_all)
            random.shuffle(files_B_all)
            self.files_A = files_A_all[:min(len(files_A_all), max_images)]
            self.files_B = files_B_all[:min(len(files_B_all), max_images)]
        else:
            self.files_A = files_A_all
            self.files_B = files_B_all

        self.len_A = len(self.files_A)
        self.len_B = len(self.files_B)
        self.length = max(self.len_A, self.len_B)

        # One empty domain next to a non-empty one would otherwise surface later
        # as a ZeroDivisionError from the modulo in __getitem__.
        if self.length > 0 and min(self.len_A, self.len_B) == 0:
            empty_root = root_A if self.len_A == 0 else root_B
            raise ValueError(f"No files found in '{empty_root}'; both domains need at least one image.")

    @staticmethod
    def _list_files(root):
        """Return the regular files directly inside `root`, as sorted full paths.

        Sorting makes the order independent of the filesystem's listing order,
        so indices (and seeded shuffles) are reproducible.
        """
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if os.path.isfile(os.path.join(root, name))
        )

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(self.files_A[index % self.len_A]) as raw_A:
            img_A = raw_A.convert("RGB")
        with Image.open(self.files_B[index % self.len_B]) as raw_B:
            img_B = raw_B.convert("RGB")

        if self.transform:
            img_A = self.transform(img_A)
        if self.transform_B:
            img_B = self.transform_B(img_B)

        return img_A, img_B

# Smoke test: builds the dataset with and without a per-domain image cap and
# prints both lengths. Runs only when this file is executed directly.
if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    roots = dict(root_A='data/trainA', root_B='data/trainB')

    dataset_limited = CycleGAN_Dataset(transform=transform, max_images=100, **roots)
    print(f"Dataset length (limited to 100 per domain): {len(dataset_limited)}")

    dataset_full = CycleGAN_Dataset(transform=transform, **roots)
    print(f"Dataset length (full): {len(dataset_full)}")


In [None]:
%%writefile models/generator_A2B.py
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    """Residual block: the input is added to the output of two padded 3x3 conv stages."""
    def __init__(self, dim):
        super().__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Assemble the residual branch for `dim` channels.

        Layout: [pad, conv, instance-norm, ReLU, pad, conv, instance-norm].
        Reflection padding of 1 before each unpadded 3x3 conv keeps the
        spatial size unchanged.
        """
        def conv_stage():
            # One pad -> conv -> norm stage; channel count is preserved.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        first_stage = conv_stage()
        second_stage = conv_stage()
        return nn.Sequential(*first_stage, nn.ReLU(True), *second_stage)

    def forward(self, x):
        # Skip connection around the conv branch.
        return x + self.conv_block(x)

class Generator(nn.Module):
    """Resnet-based generator.

    Architecture: 7x7 stem conv, two stride-2 downsampling convs, `n_blocks`
    ResnetBlocks, two stride-2 transposed-conv upsampling layers, and a final
    7x7 conv followed by Tanh.

    Parameters:
        input_nc (int)  -- number of channels of the input image
        output_nc (int) -- number of channels of the output image
        ngf (int)       -- number of filters in the first conv layer
        norm_layer      -- normalization layer class (or a functools.partial of one)
        use_dropout     -- accepted for interface compatibility; not used here
        n_blocks (int)  -- number of ResnetBlocks in the bottleneck (>= 0)
        padding_type    -- accepted for interface compatibility; reflection
                           padding is always used
    """
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        # `norm_layer` is a class (or a functools.partial wrapping one), not an
        # instance, so it has to be compared by identity. The previous
        # `type(norm_layer) == nn.BatchNorm2d` check was always False.
        # A conv bias is redundant when followed by BatchNorm.
        norm_cls = getattr(norm_layer, 'func', norm_layer)
        use_bias = norm_cls is not nn.BatchNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each step halves the resolution and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual bottleneck at the lowest resolution.
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult)]

        # Upsampling: mirror of the downsampling path.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]  # bounds the output to [-1, 1]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

# Shape smoke test: pushes one random 256x256 RGB batch through the generator.
# Runs only when this file is executed directly.
if __name__ == '__main__':
    net = Generator(input_nc=3, output_nc=3)
    sample = torch.randn(1, 3, 256, 256)
    result = net(sample)
    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {result.shape}")


In [None]:
%%writefile models/generator_B2A.py
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    """Residual block: the input is added to the output of two padded 3x3 conv stages."""
    def __init__(self, dim):
        super().__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Assemble the residual branch for `dim` channels.

        Layout: [pad, conv, instance-norm, ReLU, pad, conv, instance-norm].
        Reflection padding of 1 before each unpadded 3x3 conv keeps the
        spatial size unchanged.
        """
        def conv_stage():
            # One pad -> conv -> norm stage; channel count is preserved.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]

        first_stage = conv_stage()
        second_stage = conv_stage()
        return nn.Sequential(*first_stage, nn.ReLU(True), *second_stage)

    def forward(self, x):
        # Skip connection around the conv branch.
        return x + self.conv_block(x)

class Generator(nn.Module):
    """Resnet-based generator.

    Architecture: 7x7 stem conv, two stride-2 downsampling convs, `n_blocks`
    ResnetBlocks, two stride-2 transposed-conv upsampling layers, and a final
    7x7 conv followed by Tanh.

    Parameters:
        input_nc (int)  -- number of channels of the input image
        output_nc (int) -- number of channels of the output image
        ngf (int)       -- number of filters in the first conv layer
        norm_layer      -- normalization layer class (or a functools.partial of one)
        use_dropout     -- accepted for interface compatibility; not used here
        n_blocks (int)  -- number of ResnetBlocks in the bottleneck (>= 0)
        padding_type    -- accepted for interface compatibility; reflection
                           padding is always used
    """
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        # `norm_layer` is a class (or a functools.partial wrapping one), not an
        # instance, so it has to be compared by identity. The previous
        # `type(norm_layer) == nn.BatchNorm2d` check was always False.
        # A conv bias is redundant when followed by BatchNorm.
        norm_cls = getattr(norm_layer, 'func', norm_layer)
        use_bias = norm_cls is not nn.BatchNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each step halves the resolution and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual bottleneck at the lowest resolution.
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult)]

        # Upsampling: mirror of the downsampling path.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]  # bounds the output to [-1, 1]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

# Shape smoke test: pushes one random 256x256 RGB batch through the generator.
# Runs only when this file is executed directly.
if __name__ == '__main__':
    net = Generator(input_nc=3, output_nc=3)
    sample = torch.randn(1, 3, 256, 256)
    result = net(sample)
    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {result.shape}")

In [None]:
%%writefile models/discriminator_A.py
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Defines a PatchGAN discriminator.

    The network outputs a one-channel map of real/fake scores rather than a
    single scalar per image.

    Parameters:
        input_nc (int) -- number of channels of the input image
        ndf (int)      -- number of filters in the first conv layer
        n_layers (int) -- number of stride-2 conv layers (including the first one)
        norm_layer     -- normalization layer class (or a functools.partial of one)
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        super(Discriminator, self).__init__()
        # `norm_layer` is a class (or a functools.partial wrapping one), not an
        # instance, so it has to be compared by identity. The previous
        # `type(norm_layer) == nn.BatchNorm2d` check was always False.
        # A conv bias is redundant when followed by BatchNorm.
        norm_cls = getattr(norm_layer, 'func', norm_layer)
        use_bias = norm_cls is not nn.BatchNorm2d

        kw = 4
        padw = 1
        # First layer has no normalization, so it keeps its default bias.
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        # Stride-2 layers: channel multiplier doubles each time, capped at 8.
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more conv at stride 1 before the scoring layer.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Final 1-channel score map (no activation; the loss is applied to raw scores).
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

# Shape smoke test: pushes one random 256x256 RGB batch through the discriminator.
# Runs only when this file is executed directly.
if __name__ == '__main__':
    net = Discriminator(input_nc=3)
    sample = torch.randn(1, 3, 256, 256)
    scores = net(sample)
    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {scores.shape}")

In [None]:
%%writefile models/discriminator_B.py
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Defines a PatchGAN discriminator.

    The network outputs a one-channel map of real/fake scores rather than a
    single scalar per image.

    Parameters:
        input_nc (int) -- number of channels of the input image
        ndf (int)      -- number of filters in the first conv layer
        n_layers (int) -- number of stride-2 conv layers (including the first one)
        norm_layer     -- normalization layer class (or a functools.partial of one)
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        super(Discriminator, self).__init__()
        # `norm_layer` is a class (or a functools.partial wrapping one), not an
        # instance, so it has to be compared by identity. The previous
        # `type(norm_layer) == nn.BatchNorm2d` check was always False.
        # A conv bias is redundant when followed by BatchNorm.
        norm_cls = getattr(norm_layer, 'func', norm_layer)
        use_bias = norm_cls is not nn.BatchNorm2d

        kw = 4
        padw = 1
        # First layer has no normalization, so it keeps its default bias.
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        # Stride-2 layers: channel multiplier doubles each time, capped at 8.
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more conv at stride 1 before the scoring layer.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Final 1-channel score map (no activation; the loss is applied to raw scores).
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

# Shape smoke test: pushes one random 256x256 RGB batch through the discriminator.
# Runs only when this file is executed directly.
if __name__ == '__main__':
    net = Discriminator(input_nc=3)
    sample = torch.randn(1, 3, 256, 256)
    scores = net(sample)
    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {scores.shape}")

In [None]:

# Cell 10: Data Organization (CORRECTED)

import os
import shutil
import random

# Ensure all necessary directories exist
for split_name in ('trainA', 'trainB', 'valA', 'valB', 'testA', 'testB'):
    os.makedirs(os.path.join(BASE_DIR, 'data', split_name), exist_ok=True)

# Seed the RNG so the class balancing below (and any later shuffles in this
# cell) give the same result on every run.
random.seed(42)

# Copy Healthy images to a temporary 'all_healthy' staging directory
temp_all_healthy_dir = os.path.join(BASE_DIR, 'data/all_healthy')
os.makedirs(temp_all_healthy_dir, exist_ok=True)
healthy_source_dir = os.path.join(DATA_INPUT_DIR, 'Healthy')
for f in os.listdir(healthy_source_dir):
    if os.path.isfile(os.path.join(healthy_source_dir, f)):
        shutil.copy(os.path.join(healthy_source_dir, f), os.path.join(temp_all_healthy_dir, f))
print("Copied Healthy images to temporary all_healthy directory.")

# Index Diseased images by class (every input folder except 'Healthy').
# Nothing is copied yet: only the balanced selection below is staged, which
# avoids copying every image and then deleting most of them again.
# Listings are sorted so the seeded shuffles are reproducible.
temp_all_diseased_dir = os.path.join(BASE_DIR, 'data/all_diseased')
diseased_dirs = sorted(
    d for d in os.listdir(DATA_INPUT_DIR)
    if os.path.isdir(os.path.join(DATA_INPUT_DIR, d)) and d != 'Healthy'
)
if not diseased_dirs:
    # Without this guard the per-class division below fails with ZeroDivisionError.
    raise RuntimeError(f"No diseased class folders found in {DATA_INPUT_DIR}")
diseased_files_by_class = {}
for d in diseased_dirs:
    class_dir = os.path.join(DATA_INPUT_DIR, d)
    diseased_files_by_class[d] = sorted(
        os.path.join(class_dir, f) for f in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, f))
    )
print(f"Indexed {sum(len(v) for v in diseased_files_by_class.values())} Diseased images "
      f"across {len(diseased_dirs)} classes.")

# Balance diseased images across classes. The target is split evenly and the
# first `remainder` classes get one extra image (with 9 classes and a target of
# 2676 that is 297 per class plus 3 extras). A class with fewer images than its
# quota contributes everything it has; the shortfall is not redistributed.
total_diseased_target = 2676  # Match healthy images
num_classes = len(diseased_dirs)
images_per_class = total_diseased_target // num_classes
remainder = total_diseased_target % num_classes
selected_diseased_files = []
selected_dest_names = []
for class_name, files in diseased_files_by_class.items():
    random.shuffle(files)
    num_to_select = images_per_class + (1 if remainder > 0 else 0)
    remainder -= 1 if remainder > 0 else 0
    chosen = files[:min(len(files), num_to_select)]
    selected_diseased_files.extend(chosen)
    # Prefix with the class name so that equally named files from different
    # classes do not overwrite each other in the shared staging directory.
    selected_dest_names.extend(f"{class_name}_{os.path.basename(p)}" for p in chosen)
    print(f"Selected {len(chosen)} images from {class_name}")

# Stage the selected diseased files in a fresh temp_all_diseased_dir
if os.path.isdir(temp_all_diseased_dir):
    shutil.rmtree(temp_all_diseased_dir)  # drop leftovers from an interrupted earlier run
os.makedirs(temp_all_diseased_dir, exist_ok=True)
for src_path, dest_name in zip(selected_diseased_files, selected_dest_names):
    shutil.copy(src_path, os.path.join(temp_all_diseased_dir, dest_name))
print("Copied Diseased images to temporary all_diseased directory.")
print(f"Total selected diseased images: {len(selected_diseased_files)}")

# Split data into train, val, test
def split_dataset(source_dir, train_dir, val_dir, test_dir, val_ratio=0.2, test_ratio=0.2):
    """Randomly partition the files of `source_dir` and MOVE them into three directories.

    Parameters:
        source_dir -- directory whose regular files are split (it is emptied of them)
        train_dir, val_dir, test_dir -- destination directories (created if missing)
        val_ratio, test_ratio -- fractions of the files that go to val / test;
                                 the rest goes to train

    For very small inputs the counts are adjusted so that train, val and test
    each receive at least one file whenever there are enough files to do so.

    Raises:
        ValueError -- if a ratio is negative or the two ratios sum to more than 1
                      (that would give a negative train count and overlapping slices)
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1 + 1e-9:
        raise ValueError(
            f"val_ratio and test_ratio must be non-negative and sum to at most 1, "
            f"got val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    for target_dir in (train_dir, val_dir, test_dir):
        os.makedirs(target_dir, exist_ok=True)

    # Sorted before shuffling so a seeded RNG yields the same split on every run.
    all_files = sorted(f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f)))
    random.shuffle(all_files)
    num_total = len(all_files)
    num_test = int(num_total * test_ratio)
    num_val = int(num_total * val_ratio)
    num_train = num_total - num_test - num_val

    # Guarantee non-empty splits for tiny datasets ...
    if num_train == 0 and num_total > 0: num_train = 1
    if num_val == 0 and num_total > 1: num_val = 1
    if num_test == 0 and num_total > 2: num_test = 1
    # ... and take the surplus back from the first split that can afford it.
    if num_train + num_val + num_test > num_total:
        diff = (num_train + num_val + num_test) - num_total
        if num_train > diff: num_train -= diff
        elif num_val > diff: num_val -= diff
        elif num_test > diff: num_test -= diff

    train_files = all_files[:num_train]
    val_files = all_files[num_train:num_train + num_val]
    test_files = all_files[num_train + num_val:num_train + num_val + num_test]

    print(f"Total files: {num_total}, Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")

    for target_dir, names in ((train_dir, train_files), (val_dir, val_files), (test_dir, test_files)):
        for f in names:
            shutil.move(os.path.join(source_dir, f), os.path.join(target_dir, f))

    print(f"Split data from {source_dir} into {train_dir}, {val_dir}, {test_dir}")

# Diseased images form domain A, healthy images form domain B. Each staging
# directory is split into the matching train/val/test folders under BASE_DIR/data.
_split_jobs = (
    ("\nSplitting Diseased data into trainA, valA, testA...", temp_all_diseased_dir, 'A'),
    ("\nSplitting Healthy data into trainB, valB, testB...", temp_all_healthy_dir, 'B'),
)
for _banner, _staging_dir, _domain in _split_jobs:
    print(_banner)
    split_dataset(
        _staging_dir,
        *(os.path.join(BASE_DIR, f'data/{_split}{_domain}') for _split in ('train', 'val', 'test'))
    )

# Remove the staging directories now that their files have been moved out
for _staging_dir in (temp_all_diseased_dir, temp_all_healthy_dir):
    shutil.rmtree(_staging_dir)
print("\nTemporary data directories cleaned up.")
print("Final data organization complete.")


In [None]:
%%writefile train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import itertools
import numpy as np
import matplotlib.pyplot as plt

# Import modules from the current working directory
import sys
sys.path.append('/kaggle/working/utils')
sys.path.append('/kaggle/working/models')

from dataset import CycleGAN_Dataset
from transforms import get_transforms
from image_pool import ImagePool
from generator_A2B import Generator as Generator_A2B
from generator_B2A import Generator as Generator_B2A
from discriminator_A import Discriminator as Discriminator_A
from discriminator_B import Discriminator as Discriminator_B

# --- Hyperparameters and Configuration ---
class Config:
    """Paths, hyperparameters and logging cadence read by the training script."""
    def __init__(self):
        # Data locations per split (A and B are the two image domains)
        self.dataroot_A, self.dataroot_B = '/kaggle/working/data/trainA', '/kaggle/working/data/trainB'
        self.val_dataroot_A, self.val_dataroot_B = '/kaggle/working/data/valA', '/kaggle/working/data/valB'
        self.test_dataroot_A, self.test_dataroot_B = '/kaggle/working/data/testA', '/kaggle/working/data/testB'

        # Output locations
        self.save_dir = '/kaggle/working/checkpoints'
        self.results_dir = '/kaggle/working/results'

        # Input pipeline
        self.batch_size = 16
        self.image_size = 128  # side length images are resized to
        self.max_train_images_per_domain = None  # None = use all available images
        self.max_val_images_per_domain = None

        # Network shape
        self.input_nc, self.output_nc = 3, 3
        self.ngf, self.ndf = 64, 64
        self.n_resnet_blocks = 9

        # Optimisation: the discriminators use a higher learning rate than the generators (TTUR)
        self.lr_g = 0.0002
        self.lr_d = 0.0004
        self.beta1 = 0.5
        self.num_epochs, self.decay_epoch = 100, 50

        # Loss weights and fake-image history size
        self.lambda_cycle = 10.0
        self.lambda_identity = 1.0  # kept low to encourage transformation
        self.pool_size = 50

        # Logging / checkpoint cadence
        self.display_freq = 20
        self.save_latest_freq = 100
        self.save_epoch_freq = 5

        # Early stopping on the validation generator loss
        self.early_stopping_patience = 20
        self.early_stopping_min_delta = 0.01

opt = Config()

# --- Device Configuration ---
# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Initialize Models ---
# Two generators (A->B and B->A) and one discriminator per domain.
netG_A2B = Generator_A2B(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=opt.n_resnet_blocks)
netG_B2A = Generator_B2A(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=opt.n_resnet_blocks)
netD_A = Discriminator_A(opt.input_nc, opt.ndf)
netD_B = Discriminator_B(opt.input_nc, opt.ndf)

# With several GPUs, wrap every network so batches are split across devices.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs!")
    netG_A2B, netG_B2A, netD_A, netD_B = (
        nn.DataParallel(net) for net in (netG_A2B, netG_B2A, netD_A, netD_B)
    )

for net in (netG_A2B, netG_B2A, netD_A, netD_B):
    net.to(device)

# --- Loss Functions ---
criterionGAN = nn.MSELoss()      # adversarial objective (squared error against 1/0 targets)
criterionCycle = nn.L1Loss()     # cycle-consistency reconstruction
criterionIdentity = nn.L1Loss()  # identity mapping

# --- Optimizers ---
# Both generators share one optimizer; each discriminator has its own.
adam_betas = (opt.beta1, 0.999)
generator_params = itertools.chain(netG_A2B.parameters(), netG_B2A.parameters())
optimizer_G = optim.Adam(generator_params, lr=opt.lr_g, betas=adam_betas)
optimizer_D_A = optim.Adam(netD_A.parameters(), lr=opt.lr_d, betas=adam_betas)
optimizer_D_B = optim.Adam(netD_B.parameters(), lr=opt.lr_d, betas=adam_betas)

# --- Learning Rate Schedulers ---
def get_scheduler(optimizer, opt):
    """Create a LambdaLR scheduler with a constant-then-linear-decay schedule.

    Parameters:
        optimizer -- the optimizer whose learning rate is scheduled
        opt       -- object with integer attributes `num_epochs` and `decay_epoch`

    The multiplier is 1.0 for the first `opt.decay_epoch` epochs and then
    decreases linearly. The decay span is `num_epochs - decay_epoch + 1` so the
    multiplier stays strictly positive in the final epoch (without the `+ 1`
    the last epoch would train with a learning rate of exactly 0) and so that
    `decay_epoch == num_epochs` does not divide by zero.
    """
    def lambda_rule(epoch):
        decay_span = opt.num_epochs - opt.decay_epoch + 1
        return 1.0 - max(0, epoch + 1 - opt.decay_epoch) / decay_span

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

scheduler_G = get_scheduler(optimizer_G, opt)
scheduler_D_A = get_scheduler(optimizer_D_A, opt)
scheduler_D_B = get_scheduler(optimizer_D_B, opt)

# --- Image Pool for Discriminator Training ---
fake_A_pool = ImagePool(opt.pool_size)
fake_B_pool = ImagePool(opt.pool_size)

# --- Data Loaders ---
# Imported here so this section is self-contained; get_test_transforms lives in
# the same utils/transforms.py module as get_transforms.
from transforms import get_test_transforms

image_hw = (opt.image_size, opt.image_size)
num_workers = (os.cpu_count() or 0) // 2  # half the cores; 0 when the count is unknown

# Training: random augmentation, configured per domain.
transform_A = get_transforms(image_size=image_hw, domain='A')
transform_B = get_transforms(image_size=image_hw, domain='B')
train_dataset = CycleGAN_Dataset(root_A=opt.dataroot_A, root_B=opt.dataroot_B,
                                 transform=transform_A, transform_B=transform_B,
                                 max_images=opt.max_train_images_per_domain)
train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=num_workers)

# Validation: deterministic preprocessing only. Random flips/rotations/jitter
# would make the validation loss fluctuate between epochs for reasons unrelated
# to the model, which undermines the early-stopping comparison.
val_transform = get_test_transforms(image_size=image_hw)
val_dataset = CycleGAN_Dataset(root_A=opt.val_dataroot_A, root_B=opt.val_dataroot_B,
                               transform=val_transform,
                               max_images=opt.max_val_images_per_domain)
val_dataloader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=num_workers)

# --- Early Stopping ---
best_val_loss = float('inf')
epochs_no_improve = 0
early_stop = False

# --- Loss Tracking ---
train_losses_G, val_losses_G = [], []

# --- Training Loop ---
print("Starting Training Loop...")

# Create the output directories once instead of re-checking inside the loop.
os.makedirs(opt.results_dir, exist_ok=True)
os.makedirs(opt.save_dir, exist_ok=True)

# One entry per completed epoch (mean generator loss), used for the final plot.
train_losses_G, val_losses_G = [], []

for epoch in range(opt.num_epochs):
    if early_stop:
        print("Early stopping triggered.")
        break

    # --- Training Phase ---
    netG_A2B.train()
    netG_B2A.train()
    netD_A.train()
    netD_B.train()

    epoch_train_losses = []  # per-batch generator losses of this epoch
    for i, (real_A, real_B) in enumerate(train_dataloader):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # --- Train Generators G_A2B and G_B2A ---
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image of its target domain should leave it unchanged
        identity_B = netG_A2B(real_B)
        loss_identity_B = criterionIdentity(identity_B, real_B) * opt.lambda_identity
        identity_A = netG_B2A(real_A)
        loss_identity_A = criterionIdentity(identity_A, real_A) * opt.lambda_identity

        # GAN loss D_B(G_A2B(A)): fake B images should be scored as real by D_B
        fake_B = netG_A2B(real_A)
        pred_fake_B = netD_B(fake_B)
        loss_GAN_A2B = criterionGAN(pred_fake_B, torch.ones_like(pred_fake_B))

        # GAN loss D_A(G_B2A(B)): fake A images should be scored as real by D_A
        fake_A = netG_B2A(real_B)
        pred_fake_A = netD_A(fake_A)
        loss_GAN_B2A = criterionGAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle consistency loss: A -> B -> A and B -> A -> B should reproduce the input
        reconstructed_A = netG_B2A(fake_B)
        loss_cycle_A = criterionCycle(reconstructed_A, real_A) * opt.lambda_cycle

        reconstructed_B = netG_A2B(fake_A)
        loss_cycle_B = criterionCycle(reconstructed_B, real_B) * opt.lambda_cycle

        # Total generator loss
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A + loss_cycle_B + loss_identity_A + loss_identity_B
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), max_norm=1.0)
        optimizer_G.step()

        # --- Train Discriminator D_A ---
        # zero_grad also discards the gradients D_A received during the generator backward pass.
        optimizer_D_A.zero_grad()
        pred_real_A = netD_A(real_A)
        loss_D_real_A = criterionGAN(pred_real_A, torch.ones_like(pred_real_A))
        fake_A_from_pool = fake_A_pool.query(fake_A)
        pred_fake_A = netD_A(fake_A_from_pool.detach())
        loss_D_fake_A = criterionGAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
        loss_D_A.backward()
        torch.nn.utils.clip_grad_norm_(netD_A.parameters(), max_norm=1.0)
        optimizer_D_A.step()

        # --- Train Discriminator D_B ---
        optimizer_D_B.zero_grad()
        pred_real_B = netD_B(real_B)
        loss_D_real_B = criterionGAN(pred_real_B, torch.ones_like(pred_real_B))
        fake_B_from_pool = fake_B_pool.query(fake_B)
        pred_fake_B = netD_B(fake_B_from_pool.detach())
        loss_D_fake_B = criterionGAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
        loss_D_B.backward()
        torch.nn.utils.clip_grad_norm_(netD_B.parameters(), max_norm=1.0)
        optimizer_D_B.step()

        # --- Print Training Progress ---
        if i % opt.display_freq == 0:
            print(f"Epoch [{epoch+1}/{opt.num_epochs}], Step [{i+1}/{len(train_dataloader)}], "
                  f"Loss_G: {loss_G.item():.4f}, Loss_D_A: {loss_D_A.item():.4f}, Loss_D_B: {loss_D_B.item():.4f}, "
                  f"Loss_GAN_A2B: {loss_GAN_A2B.item():.4f}, Loss_GAN_B2A: {loss_GAN_B2A.item():.4f}, "
                  f"Loss_cycle_A: {loss_cycle_A.item():.4f}, Loss_cycle_B: {loss_cycle_B.item():.4f}, "
                  f"Loss_identity_A: {loss_identity_A.item():.4f}, Loss_identity_B: {loss_identity_B.item():.4f}")

            # Save generated images for visualization
            save_image(fake_B.detach(), os.path.join(opt.results_dir, f'fake_B_epoch{epoch+1}_step{i+1}.png'), normalize=True)
            save_image(fake_A.detach(), os.path.join(opt.results_dir, f'fake_A_epoch{epoch+1}_step{i+1}.png'), normalize=True)
            save_image(real_A.detach(), os.path.join(opt.results_dir, f'real_A_epoch{epoch+1}_step{i+1}.png'), normalize=True)
            save_image(real_B.detach(), os.path.join(opt.results_dir, f'real_B_epoch{epoch+1}_step{i+1}.png'), normalize=True)
        elif i % 10 == 0:
            # More frequent B-domain snapshots; skipped when the branch above
            # already wrote the very same files for this step.
            save_image(fake_B.detach(), os.path.join(opt.results_dir, f'fake_B_epoch{epoch+1}_step{i+1}.png'), normalize=True)
            save_image(real_B.detach(), os.path.join(opt.results_dir, f'real_B_epoch{epoch+1}_step{i+1}.png'), normalize=True)

        # Track training loss
        epoch_train_losses.append(loss_G.item())

    train_losses_G.append(float(np.mean(epoch_train_losses)) if epoch_train_losses else float('nan'))

    # Update learning rates
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()

    # --- Validation Phase ---
    netG_A2B.eval()
    netG_B2A.eval()
    # Batch losses live in their own list so the per-epoch history in
    # val_losses_G is never reset or mixed with batch-level values.
    val_batch_losses = []
    with torch.no_grad():
        for i, (real_A_val, real_B_val) in enumerate(val_dataloader):
            real_A_val = real_A_val.to(device)
            real_B_val = real_B_val.to(device)

            # Calculate generator losses on validation set
            identity_B_val = netG_A2B(real_B_val)
            loss_identity_B_val = criterionIdentity(identity_B_val, real_B_val) * opt.lambda_identity
            identity_A_val = netG_B2A(real_A_val)
            loss_identity_A_val = criterionIdentity(identity_A_val, real_A_val) * opt.lambda_identity

            fake_B_val = netG_A2B(real_A_val)
            pred_fake_B_val = netD_B(fake_B_val)
            loss_GAN_A2B_val = criterionGAN(pred_fake_B_val, torch.ones_like(pred_fake_B_val))

            fake_A_val = netG_B2A(real_B_val)
            pred_fake_A_val = netD_A(fake_A_val)
            loss_GAN_B2A_val = criterionGAN(pred_fake_A_val, torch.ones_like(pred_fake_A_val))

            reconstructed_A_val = netG_B2A(fake_B_val)
            loss_cycle_A_val = criterionCycle(reconstructed_A_val, real_A_val) * opt.lambda_cycle

            reconstructed_B_val = netG_A2B(fake_A_val)
            loss_cycle_B_val = criterionCycle(reconstructed_B_val, real_B_val) * opt.lambda_cycle

            total_val_loss_G = loss_GAN_A2B_val + loss_GAN_B2A_val + loss_cycle_A_val + loss_cycle_B_val + loss_identity_A_val + loss_identity_B_val
            val_batch_losses.append(total_val_loss_G.item())

    # NaN (never an "improvement") when the validation loader yielded no batches.
    avg_val_loss_G = float(np.mean(val_batch_losses)) if val_batch_losses else float('nan')
    val_losses_G.append(avg_val_loss_G)
    print(f"Epoch [{epoch+1}/{opt.num_epochs}], Average Validation Loss G: {avg_val_loss_G:.4f}")

    # --- Early Stopping Logic ---
    if avg_val_loss_G < best_val_loss - opt.early_stopping_min_delta:
        best_val_loss = avg_val_loss_G
        epochs_no_improve = 0
        torch.save(netG_A2B.state_dict(), os.path.join(opt.save_dir, 'netG_A2B_best.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(opt.save_dir, 'netG_B2A_best.pth'))
        torch.save(netD_A.state_dict(), os.path.join(opt.save_dir, 'netD_A_best.pth'))
        torch.save(netD_B.state_dict(), os.path.join(opt.save_dir, 'netD_B_best.pth'))
        print(f"Saved best models at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epochs.")
        if epochs_no_improve >= opt.early_stopping_patience:
            early_stop = True  # acted on at the top of the next epoch

    # Save models periodically
    if (epoch + 1) % opt.save_epoch_freq == 0:
        torch.save(netG_A2B.state_dict(), os.path.join(opt.save_dir, f'netG_A2B_epoch_{epoch+1}.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(opt.save_dir, f'netG_B2A_epoch_{epoch+1}.pth'))
        torch.save(netD_A.state_dict(), os.path.join(opt.save_dir, f'netD_A_epoch_{epoch+1}.pth'))
        torch.save(netD_B.state_dict(), os.path.join(opt.save_dir, f'netD_B_epoch_{epoch+1}.pth'))
        print(f"Models saved for epoch {epoch+1}")

# Save loss plot: both curves hold one mean value per completed epoch.
epochs_ran = range(1, len(train_losses_G) + 1)
plt.plot(epochs_ran, train_losses_G, label='Train Loss G')
plt.plot(epochs_ran, val_losses_G, label='Val Loss G')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(opt.results_dir, 'loss_plot.png'))
plt.close()

print("Training complete.")

In [None]:
# Cell 12: Run Training
# Executes the train.py file (written by the %%writefile cell above) in a separate Python process.

!python train.py

In [None]:
%%writefile predict.py
import torch
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os

# Import modules from the current working directory (Kaggle's /kaggle/working/)
import sys
sys.path.append('/kaggle/working/models')
sys.path.append('/kaggle/working/utils')

from generator_A2B import Generator as Generator_A2B
from transforms import get_test_transforms

# --- Configuration ---
class Config:
    """Settings for translating a single image with the trained A->B generator.

    Every setting is a keyword argument whose default is the value this script
    was originally hard-coded with, so ``Config()`` behaves exactly as before
    while callers can override individual values, e.g.
    ``Config(input_image_path='other.jpg')``.

    Parameters:
        image_size (int)        -- side length; the script requests test transforms for
                                   (image_size, image_size)
        input_nc (int)          -- first argument passed to the generator (presumably the
                                   number of input channels -- confirm against the model)
        output_nc (int)         -- second argument passed to the generator (presumably the
                                   number of output channels -- confirm against the model)
        ngf (int)               -- number of generator filters
        n_resnet_blocks (int)   -- passed to the generator as ``n_blocks``
        model_path (str)        -- path to the trained generator weights (best model)
        input_image_path (str)  -- path to the diseased eye image to translate
        output_image_path (str) -- path where the translated healthy eye image is saved
    """
    def __init__(
        self,
        image_size=256,
        input_nc=3,
        output_nc=3,
        ngf=64,
        n_resnet_blocks=9,
        model_path='/kaggle/working/checkpoints/netG_A2B_best.pth',
        input_image_path='/kaggle/working/data/testA/sample_diseased_eye.jpg',
        output_image_path='/kaggle/working/results/predicted_healthy_eye.png',
    ):
        self.image_size = image_size
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_resnet_blocks = n_resnet_blocks
        self.model_path = model_path
        self.input_image_path = input_image_path
        self.output_image_path = output_image_path

opt = Config()

# --- Device Configuration ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Initialize Generator ---
netG_A2B = Generator_A2B(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=opt.n_resnet_blocks).to(device)

# --- Load trained model weights ---
if not os.path.exists(opt.model_path):
    print(f"Error: Model not found at {opt.model_path}. Please train the model first.")
    # Non-zero status so callers can tell the run failed; sys.exit is also
    # always available, unlike the interactive exit() helper.
    sys.exit(1)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (e.g. ones produced by train.py).
state_dict = torch.load(opt.model_path, map_location=device)

# Strip a leading 'module.' from parameter names (presumably left by a
# DataParallel-style wrapper at save time -- confirm) so the keys match the
# bare generator built above.
wrapper_prefix = 'module.'
new_state_dict = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in state_dict.items()
}

netG_A2B.load_state_dict(new_state_dict)
netG_A2B.eval()  # Set generator to evaluation mode
print(f"Loaded model from {opt.model_path}")

# --- Image Transformation ---
transform = get_test_transforms(image_size=(opt.image_size, opt.image_size))

# --- Load and Transform Input Image ---
if not os.path.exists(opt.input_image_path):
    print(f"Error: Input image not found at {opt.input_image_path}. Please provide a valid image.")
    sys.exit(1)

# The context manager closes the underlying file; convert() returns an
# independent RGB copy that stays valid afterwards.
with Image.open(opt.input_image_path) as source_image:
    input_image = source_image.convert("RGB")
input_tensor = transform(input_image).unsqueeze(0).to(device)  # Add batch dimension
print(f"Loaded input image from {opt.input_image_path}")

# --- Perform Inference ---
with torch.no_grad():
    output_tensor = netG_A2B(input_tensor)

# --- Save Output Image ---
output_dir = os.path.dirname(opt.output_image_path)
if output_dir:  # empty when the output path is a bare filename
    os.makedirs(output_dir, exist_ok=True)
# NOTE(review): normalize=True is believed to rescale using this tensor's own
# min/max; if the generator output is in [-1, 1], passing value_range=(-1, 1)
# may give more consistent brightness -- confirm against the torchvision docs.
save_image(output_tensor, opt.output_image_path, normalize=True)
print(f"Translated image saved to {opt.output_image_path}")

print("Prediction complete.")


In [None]:
# Cell 15: Copy Best Model for Download

import os
import shutil

# Expected location of the best A->B generator checkpoint, and where its
# copy should be placed.
source_path = '/kaggle/working/checkpoints/netG_A2B_best.pth'
destination_path = '/kaggle/working/netG_A2B_best.pth'  # root of the working directory

# Guard first: without a checkpoint there is nothing to copy.
if not os.path.exists(source_path):
    print(f"Error: Model not found at {source_path}. Please ensure training completed successfully.")
else:
    # Copy (not move): the original stays in checkpoints/.
    shutil.copy(source_path, destination_path)
    print(f"Successfully copied {source_path} to {destination_path}")

# The other best checkpoints can be copied the same way if needed, e.g.:
# shutil.copy('/kaggle/working/checkpoints/netG_B2A_best.pth', '/kaggle/working/netG_B2A_best.pth')
# shutil.copy('/kaggle/working/checkpoints/netD_A_best.pth', '/kaggle/working/netD_A_best.pth')
# shutil.copy('/kaggle/working/checkpoints/netD_B_best.pth', '/kaggle/working/netD_B_best.pth')


In [None]:
# Cell 14: Run Prediction

# predict.py reads /kaggle/working/data/testA/sample_diseased_eye.jpg and exits with an
# error message if that file (or the trained generator checkpoint) is missing.
# Before running this, ensure you have a sample image in /kaggle/working/data/testA/
# For example, you might copy one from your dataset:
# !cp /kaggle/input/bangladeshi-hospitals-eye-dataset/Bangladeshi\ Hospitals\ Dataset/Central\ Serous\ Chorioretinopathy\ \[Color\ Fundus\]/CSCR1.jpg /kaggle/working/data/testA/sample_diseased_eye.jpg

# Runs the script in a separate Python process via the shell.
!python predict.py

In [None]:
import shutil
import os

# Define the output directory
output_dir = '/kaggle/working/'


def clear_directory(directory):
    """Delete everything inside ``directory`` while keeping the directory itself.

    Regular files and symbolic links are unlinked (a link is removed, never
    followed), and sub-directories are removed recursively. Deletion is
    best-effort: an entry that cannot be removed is reported and skipped so the
    remaining entries are still processed.

    Parameters:
        directory (str) -- path of the directory to empty

    Returns:
        int -- number of top-level entries that were removed; 0 if
               ``directory`` does not exist or is not a directory.
    """
    if not os.path.isdir(directory):
        print(f"Directory not found, nothing to remove: {directory}")
        return 0

    removed = 0
    for item in os.listdir(directory):
        item_path = os.path.join(directory, item)
        try:
            # Links are tested before directories so that a symlink pointing
            # at a directory is unlinked rather than its target being emptied.
            if os.path.isfile(item_path) or os.path.islink(item_path):
                os.unlink(item_path)  # Remove file or link
                removed += 1
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)  # Remove directory and its contents
                removed += 1
        except OSError as e:
            print(f"Error deleting {item_path}: {e}")
    return removed


# WARNING: destructive. This wipes every file and directory in
# /kaggle/working/, including anything written there by earlier cells;
# download whatever you need first.
removed_count = clear_directory(output_dir)