In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import glob
import itertools
import os
import time
from tqdm import tqdm
import torchvision.transforms as transforms
from PIL import Image


# ---- Configuration ----
IMAGE_SIZE = 128                      # training crop size (height = width)
NUM_RESIDUAL_BLOCKS = 6               # residual blocks in each generator
BATCH_SIZE = 1
LR = 0.0002                           # Adam learning rate for all networks
BETA1 = 0.5                           # Adam beta1
LAMBDA_CYCLE = 10.0                   # weight of the cycle-consistency loss
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE  # weight of the identity loss
NUM_EPOCHS = 20
# Epoch at which the linear LR decay towards zero starts. It must be smaller than
# NUM_EPOCHS or the decay never happens (e.g. 100 with 20 epochs disables it);
# halfway follows the CycleGAN schedule (constant LR, then linear decay).
DECAY_EPOCH = NUM_EPOCHS // 2
CHECKPOINT_INTERVAL = 5               # save a checkpoint every N epochs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs (CPU and CUDA) for reproducible weight init, shuffling and
# augmentation; some GPU kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)


class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with InstanceNorm and a skip connection.

    The block keeps both the channel count (`in_features`) and the spatial size,
    and returns `x + F(x)`.
    """

    def __init__(self, in_features):
        super().__init__()

        def pad_conv_norm():
            # Reflection pad by 1 so the 3x3 conv preserves height and width.
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, 3),
                    nn.InstanceNorm2d(in_features)]

        self.conv_block = nn.Sequential(
            *pad_conv_norm(),
            nn.ReLU(inplace=True),
            *pad_conv_norm(),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class ImageBuffer():
    """History buffer of generated images used when updating the discriminators.

    The discriminators see a mix of the newest fakes and fakes produced by earlier
    generator versions. Randomness comes from torch's global RNG, so the buffer needs
    no other imports and follows torch.manual_seed.

    Args:
        max_size: maximum number of stored images. With max_size <= 0 the buffer is
            disabled and push_and_pop returns (a detached view of) its input.
    """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []  # stored images, each with a leading batch dim of size 1

    def push_and_pop(self, data):
        """Store the images of `data` (split along dim 0) and return a batch of the same size.

        While the buffer is filling, each new image is stored and returned as-is. Once
        full, with probability 0.5 a randomly chosen stored image is returned (a clone)
        and replaced by the new one; otherwise the new image is returned unchanged.

        Args:
            data: batch of images stacked along dim 0; gradients are not tracked.

        Returns:
            Tensor with the same shape as `data`.
        """
        if self.max_size <= 0:
            return data.detach()

        to_return = []
        for element in data.detach():
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif torch.rand(1).item() > 0.5:
                i = int(torch.randint(self.max_size, (1,)).item())
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

In [2]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout: 7x7 conv (ngf) -> two stride-2 downsampling convs -> `num_residual_blocks`
    residual blocks -> two stride-2 transposed convs -> 7x7 conv to `output_nc` + tanh.
    For inputs whose sides are multiples of 4, the output has the same height and width
    as the input, and values lie in [-1, 1] because of the final tanh.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        num_residual_blocks: number of ResidualBlock layers at the bottleneck.
        ngf: filters of the first conv; doubled at each downsampling step (default 64).
    """
    def __init__(self, input_nc, output_nc, num_residual_blocks=6, ngf=64):
        super().__init__()

        # Initial 7x7 convolution; reflection padding keeps the spatial size.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, 7),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(inplace=True)]

        # Downsampling: 2 stride-2 convolutions, each halves H/W and doubles the channels.
        in_features = ngf
        for _ in range(2):
            out_features = in_features * 2
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Residual blocks at the bottleneck resolution.
        model += [ResidualBlock(in_features) for _ in range(num_residual_blocks)]

        # Upsampling: 2 fractionally-strided (transposed) convolutions, each doubles H/W
        # and halves the channels.
        for _ in range(2):
            out_features = in_features // 2
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Output layer: 7x7 conv mapping back to `ngf`-derived features -> output_nc channels.
        # Uses in_features (== ngf here) instead of a hard-coded channel count.
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_features, output_nc, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Translate a batch of images of shape (N, input_nc, H, W)."""
        return self.model(x)


class Discriminator(nn.Module):
    """PatchGAN discriminator (C64-C128-C256-C512) producing a map of per-patch scores.

    Each C-block is a 4x4 convolution with stride 2 (halving height and width),
    InstanceNorm (skipped in the first block) and LeakyReLU(0.2). The final
    ZeroPad2d((1, 0, 1, 0)) + 4x4 conv with padding=1 keeps the spatial size and maps
    512 channels to 1. For inputs whose sides are multiples of 16 the output is
    (N, 1, H/16, W/16): 8x8 for 128x128 images, 16x16 for 256x256 images.

    Note: all four blocks use stride 2, unlike the 70x70 PatchGAN of pix2pix/CycleGAN,
    whose C512 block uses stride 1 (giving a 30x30 output for 256x256 inputs).

    Args:
        input_nc: number of channels of the input images.
    """
    def __init__(self, input_nc):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Conv(4x4, stride 2) -> [InstanceNorm] -> LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_nc, 64, normalize=False), # C64 (no norm)
            *discriminator_block(64, 128),                       # C128
            *discriminator_block(128, 256),                      # C256
            *discriminator_block(256, 512),                      # C512 (stride 2, like the others)
            nn.ZeroPad2d((1, 0, 1, 0)),                          # pad left/top by 1 so the next conv keeps H and W
            nn.Conv2d(512, 1, 4, padding=1)                      # 1-channel patch-score map
        )

    def forward(self, x):
        """Return the (N, 1, H', W') patch-score map for images `x` of shape (N, input_nc, H, W)."""
        return self.model(x)

In [3]:
import numpy as np
# =======================================================
# A. Loss Functions (LSGAN Loss)
# =======================================================

class LSGANLoss(nn.Module):
    """Least-squares GAN loss.

    Computes the mean squared error between a discriminator output and a constant
    target (real_label or fake_label) broadcast to the output's shape.

    Args:
        target_real_label: target value for real samples (default 1.0).
        target_fake_label: target value for fake samples (default 0.0).
    """
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        # Buffers follow the module across .to(device); float() keeps them floating point
        # even when integer labels are passed, so MSELoss sees matching dtypes.
        self.register_buffer('real_label', torch.tensor(float(target_real_label)))
        self.register_buffer('fake_label', torch.tensor(float(target_fake_label)))
        self.loss = nn.MSELoss()

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real/fake label expanded (without copying) to `prediction`'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Return MSE(prediction, target), target = real_label if `target_is_real` else fake_label."""
        # Defined as forward (not an overridden __call__) so nn.Module.__call__ still
        # dispatches here and runs any registered hooks.
        return self.loss(prediction, self.get_target_tensor(prediction, target_is_real))

class L1CycleLoss(nn.Module):
    """Mean absolute error (L1) loss, used for both the cycle-consistency and identity terms."""
    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, input, target):
        """Return mean(|input - target|)."""
        # Defined as forward (not an overridden __call__) so nn.Module.__call__ still
        # dispatches here and runs any registered hooks.
        return self.loss(input, target)

import os
import glob
from torch.utils.data import Dataset

class UnpairedDataset(Dataset):
    """Unpaired two-domain training set read from `<root>/<dataset_name>/trainA` and `trainB`.

    Item `index` combines A image `index % len(A)` with B image `index % len(B)`. The
    length is the size of the larger domain, so one pass visits every image of both.

    Args:
        root: directory that contains the dataset folder.
        dataset_name: name of the dataset folder (e.g. 'apple2orange').
        transform: callable applied to each RGB PIL image; if None the PIL images are returned.
    """

    # File extensions accepted as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, root, dataset_name, transform=None):
        self.root_path = root
        self.dataset_name = dataset_name
        self.transform = transform

        path_A = os.path.join(self.root_path, self.dataset_name, 'trainA')
        path_B = os.path.join(self.root_path, self.dataset_name, 'trainB')
        self.files_A = self._list_images(path_A)
        self.files_B = self._list_images(path_B)

        # Report once at construction instead of on every len() call.
        if not self.files_A or not self.files_B:
            print(f"Error: No files found in {os.path.join(self.root_path, self.dataset_name)}/trainA or trainB.")

    @classmethod
    def _list_images(cls, directory):
        """Return sorted paths of regular files in `directory` with an image extension."""
        # Filtering by extension (case-insensitively) accepts e.g. '.JPG' while skipping
        # sub-directories and non-image files that Image.open cannot read.
        return sorted(
            path for path in glob.glob(os.path.join(directory, '*'))
            if os.path.isfile(path) and path.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def _load(self, path):
        """Open `path` as an RGB image, close the file, and apply the transform if any."""
        from PIL import Image
        with Image.open(path) as raw:
            image = raw.convert('RGB')  # independent in-memory copy; file closed on exit
        return self.transform(image) if self.transform is not None else image

    def __getitem__(self, index):
        if not self.files_A or not self.files_B:
            raise IndexError("Dataset lists are empty. Check file paths.")

        item_A = self._load(self.files_A[index % len(self.files_A)])
        item_B = self._load(self.files_B[index % len(self.files_B)])
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        # Every item needs one image from each domain, so an empty domain means no items.
        if not self.files_A or not self.files_B:
            return 0
        return max(len(self.files_A), len(self.files_B))



class UnpairedTestDataset(Dataset):
    """Unpaired two-domain test set read from `<root>/<dataset_name>/testA` and `testB`.

    Item `index` combines A image `index % len(A)` with B image `index % len(B)`; the
    length is the size of the larger domain.

    Args:
        root: directory that contains the dataset folder.
        dataset_name: name of the dataset folder (e.g. 'apple2orange').
        transform: callable applied to each RGB PIL image; if None the PIL images are returned.
    """

    # File extensions accepted as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, root, dataset_name, transform=None):
        self.transform = transform
        path_A = os.path.join(root, dataset_name, 'testA')
        path_B = os.path.join(root, dataset_name, 'testB')
        self.files_A = self._list_images(path_A)
        self.files_B = self._list_images(path_B)

        if not self.files_A or not self.files_B:
            print(f"Error: No files found in {os.path.join(root, dataset_name)}/testA or testB.")

    @classmethod
    def _list_images(cls, directory):
        """Return sorted paths of regular files in `directory` with an image extension."""
        # Filtering by extension (case-insensitively) accepts e.g. '.JPG' while skipping
        # sub-directories and non-image files that Image.open cannot read.
        return sorted(
            path for path in glob.glob(os.path.join(directory, '*'))
            if os.path.isfile(path) and path.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def _load(self, path):
        """Open `path` as an RGB image, close the file, and apply the transform if any."""
        from PIL import Image
        with Image.open(path) as raw:
            image = raw.convert('RGB')  # independent in-memory copy; file closed on exit
        return self.transform(image) if self.transform is not None else image

    def __getitem__(self, index):
        # Without this guard an empty domain would raise ZeroDivisionError on `%`.
        if not self.files_A or not self.files_B:
            raise IndexError("Dataset lists are empty. Check file paths.")

        item_A = self._load(self.files_A[index % len(self.files_A)])
        item_B = self._load(self.files_B[index % len(self.files_B)])
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        # Every item needs one image from each domain, so an empty domain means no items.
        if not self.files_A or not self.files_B:
            return 0
        return max(len(self.files_A), len(self.files_B))

In [4]:
def save_checkpoint(epoch, netG_A2B, netG_B2A, netD_A, netD_B, loss_history, checkpoint_dir='checkpoints_apple'):
    """Serialize the four CycleGAN networks and the loss history for `epoch`.

    Writes `<checkpoint_dir>/cyclegan_checkpoint_epoch_<epoch>.pth`, creating the
    directory if needed, and prints the saved path. Optimizer and scheduler states
    are not included.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    state = dict(epoch=epoch)
    state['netG_A2B_state_dict'] = netG_A2B.state_dict()
    state['netG_B2A_state_dict'] = netG_B2A.state_dict()
    state['netD_A_state_dict'] = netD_A.state_dict()
    state['netD_B_state_dict'] = netD_B.state_dict()
    state['loss_history'] = loss_history

    path = os.path.join(checkpoint_dir, f'cyclegan_checkpoint_epoch_{epoch}.pth')
    torch.save(state, path)
    print(f"\n---> Checkpoint saved successfully at: {path}")


def train_cyclegan(netG_A2B, netG_B2A, netD_A, netD_B, dataloader):
    """Train a CycleGAN with LSGAN adversarial, L1 cycle-consistency and L1 identity losses.

    Each batch performs one joint update of both generators, then one update of each
    discriminator on real images and on fakes drawn from an ImageBuffer history. The
    learning rate is constant before DECAY_EPOCH and decays linearly towards zero at
    NUM_EPOCHS. Every CHECKPOINT_INTERVAL epochs the networks and loss history are saved.

    Args:
        netG_A2B: generator translating domain A to domain B.
        netG_B2A: generator translating domain B to domain A.
        netD_A: discriminator for domain A images.
        netD_B: discriminator for domain B images.
        dataloader: iterable of dicts holding image batches under keys 'A' and 'B'.

    Returns:
        dict mapping each loss name to a list holding that loss's mean value per epoch.

    Raises:
        ValueError: if the dataloader is empty.
    """
    if len(dataloader) == 0:
        raise ValueError("train_cyclegan: the dataloader is empty - check the dataset paths.")

    optimizer_G = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                             lr=LR, betas=(BETA1, 0.999))
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=LR, betas=(BETA1, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=LR, betas=(BETA1, 0.999))

    criterion_GAN = LSGANLoss().to(DEVICE)
    criterion_cycle = L1CycleLoss().to(DEVICE)
    criterion_identity = L1CycleLoss().to(DEVICE)  # L1 loss for identity

    fake_A_buffer = ImageBuffer()
    fake_B_buffer = ImageBuffer()

    # Constant LR until DECAY_EPOCH, then linear decay towards 0 at NUM_EPOCHS.
    # max(1, ...) avoids a ZeroDivisionError when DECAY_EPOCH >= NUM_EPOCHS.
    decay_span = max(1, NUM_EPOCHS - DECAY_EPOCH)

    def lr_lambda(epoch):
        if epoch < DECAY_EPOCH:
            return 1.0
        return max(0.0, 1.0 - (epoch - DECAY_EPOCH) / decay_span)

    scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
    scheduler_D_A = optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
    scheduler_D_B = optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

    def set_requires_grad(nets, requires_grad):
        # The discriminators are frozen during the generator step: the gradients they
        # would receive are cleared by their own zero_grad() before use, so skipping
        # them only saves compute and leaves the updates unchanged.
        for net in nets:
            for param in net.parameters():
                param.requires_grad_(requires_grad)

    loss_history = {
        'G_total': [],
        'D_A': [],
        'D_B': [],
        'Cycle': [],
        'G_A2B_GAN': [],
        'G_B2A_GAN': []
    }

    for net in (netG_A2B, netG_B2A, netD_A, netD_B):
        net.train()

    for epoch in range(1, NUM_EPOCHS + 1):
        # Running sums of detached losses (summing tensors avoids a GPU sync per batch).
        epoch_sums = dict.fromkeys(loss_history, 0.0)
        n_seen = 0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}"):
            real_A = batch['A'].to(DEVICE)
            real_B = batch['B'].to(DEVICE)

            # ---- Generators ----
            set_requires_grad((netD_A, netD_B), False)
            optimizer_G.zero_grad()

            # Identity loss: each generator should leave images of its output domain unchanged.
            loss_identity_A = criterion_identity(netG_B2A(real_A), real_A)
            loss_identity_B = criterion_identity(netG_A2B(real_B), real_B)
            loss_identity = (loss_identity_A + loss_identity_B) * LAMBDA_IDENTITY

            # Adversarial (LSGAN) loss: generators push D(fake) towards the "real" label.
            fake_B = netG_A2B(real_A)
            loss_GAN_A2B = criterion_GAN(netD_B(fake_B), True)
            fake_A = netG_B2A(real_B)
            loss_GAN_B2A = criterion_GAN(netD_A(fake_A), True)
            loss_GAN = loss_GAN_A2B + loss_GAN_B2A

            # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the input.
            loss_cycle_ABA = criterion_cycle(netG_B2A(fake_B), real_A)
            loss_cycle_BAB = criterion_cycle(netG_A2B(fake_A), real_B)
            loss_cycle = (loss_cycle_ABA + loss_cycle_BAB) * LAMBDA_CYCLE

            loss_G = loss_GAN + loss_cycle + loss_identity
            loss_G.backward()
            optimizer_G.step()

            # ---- Discriminators (real -> 1, buffered fake -> 0) ----
            set_requires_grad((netD_A, netD_B), True)

            optimizer_D_A.zero_grad()
            loss_D_real_A = criterion_GAN(netD_A(real_A), True)
            fake_A_pop = fake_A_buffer.push_and_pop(fake_A.detach())
            loss_D_fake_A = criterion_GAN(netD_A(fake_A_pop), False)
            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            optimizer_D_B.zero_grad()
            loss_D_real_B = criterion_GAN(netD_B(real_B), True)
            fake_B_pop = fake_B_buffer.push_and_pop(fake_B.detach())
            loss_D_fake_B = criterion_GAN(netD_B(fake_B_pop), False)
            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            epoch_sums['G_total'] += loss_G.detach()
            epoch_sums['D_A'] += loss_D_A.detach()
            epoch_sums['D_B'] += loss_D_B.detach()
            epoch_sums['Cycle'] += loss_cycle.detach()
            epoch_sums['G_A2B_GAN'] += loss_GAN_A2B.detach()
            epoch_sums['G_B2A_GAN'] += loss_GAN_B2A.detach()
            n_seen += 1

        scheduler_G.step()
        scheduler_D_A.step()
        scheduler_D_B.step()

        # Record the per-epoch mean rather than the noisy value of the last batch.
        for key, total in epoch_sums.items():
            loss_history[key].append(float(total) / max(n_seen, 1))

        if epoch % CHECKPOINT_INTERVAL == 0:
            save_checkpoint(epoch, netG_A2B, netG_B2A, netD_A, netD_B, loss_history)

    return loss_history

In [5]:
# Training-time preprocessing (CycleGAN-style augmentation): resize the shorter side to
# int(1.12 * IMAGE_SIZE), take a random IMAGE_SIZE crop, randomly flip horizontally, then
# map pixel values from [0, 1] to [-1, 1] to match the generators' tanh output range.
# InterpolationMode is the documented torchvision argument; passing PIL resampling
# integers is deprecated.
transform = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])




if __name__ == '__main__':
    # Build the networks in the order G_A2B, G_B2A, D_A, D_B and move them to DEVICE.
    netG_A2B, netG_B2A = (Generator(3, 3, NUM_RESIDUAL_BLOCKS).to(DEVICE) for _ in range(2))
    netD_A, netD_B = (Discriminator(3).to(DEVICE) for _ in range(2))

    ROOT_PATH = '/content/drive/MyDrive'

    dataloader = DataLoader(UnpairedDataset(ROOT_PATH, 'apple2orange', transform),
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            pin_memory=True)

    print(f"--- Starting CycleGAN Training on {DEVICE} ---")
    print(f"Hyperparameters: LR={LR}, Lambda_Cycle={LAMBDA_CYCLE}, ImageSize={IMAGE_SIZE}")

    # Wall-clock time of the whole training run.
    start_time = time.time()
    train_cyclegan(netG_A2B, netG_B2A, netD_A, netD_B, dataloader)
    end_time = time.time()

    print(f"Training finished. Total time: {end_time - start_time:.2f} seconds.")

--- Starting CycleGAN Training on cuda ---
Hyperparameters: LR=0.0002, Lambda_Cycle=10.0, ImageSize=128


Epoch 1/20: 100%|██████████| 1019/1019 [13:30<00:00,  1.26it/s]
Epoch 2/20: 100%|██████████| 1019/1019 [02:06<00:00,  8.08it/s]
Epoch 3/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.10it/s]
Epoch 4/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]
Epoch 5/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.10it/s]



---> Checkpoint saved successfully at: checkpoints_apple/cyclegan_checkpoint_epoch_5.pth


Epoch 6/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.11it/s]
Epoch 7/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.11it/s]
Epoch 8/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.10it/s]
Epoch 9/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]
Epoch 10/20: 100%|██████████| 1019/1019 [02:06<00:00,  8.09it/s]



---> Checkpoint saved successfully at: checkpoints_apple/cyclegan_checkpoint_epoch_10.pth


Epoch 11/20: 100%|██████████| 1019/1019 [02:06<00:00,  8.09it/s]
Epoch 12/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]
Epoch 13/20: 100%|██████████| 1019/1019 [02:06<00:00,  8.08it/s]
Epoch 14/20: 100%|██████████| 1019/1019 [02:06<00:00,  8.08it/s]
Epoch 15/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]



---> Checkpoint saved successfully at: checkpoints_apple/cyclegan_checkpoint_epoch_15.pth


Epoch 16/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]
Epoch 17/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.10it/s]
Epoch 18/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]
Epoch 19/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.10it/s]
Epoch 20/20: 100%|██████████| 1019/1019 [02:05<00:00,  8.09it/s]



---> Checkpoint saved successfully at: checkpoints_apple/cyclegan_checkpoint_epoch_20.pth
Training finished. Total time: 3204.18 seconds.


In [6]:
import torch
import os
import matplotlib.pyplot as plt
import numpy as np

def load_checkpoint(filename, netG_A2B, netG_B2A, netD_A, netD_B):
    """Restore the four networks in place from a checkpoint file.

    The checkpoint is loaded onto the CPU; load_state_dict copies the tensors into the
    networks' existing parameters.

    Returns:
        The checkpoint's 'loss_history' entry, or None (after printing an error) when
        `filename` does not exist.
    """
    if not os.path.exists(filename):
        print(f"Error: Checkpoint file not found: {filename}")
        return None

    print(f"Loading checkpoint from: {filename}")
    checkpoint = torch.load(filename, map_location='cpu')

    targets = (
        (netG_A2B, 'netG_A2B_state_dict'),
        (netG_B2A, 'netG_B2A_state_dict'),
        (netD_A, 'netD_A_state_dict'),
        (netD_B, 'netD_B_state_dict'),
    )
    for net, key in targets:
        net.load_state_dict(checkpoint[key])

    return checkpoint['loss_history']

def plot_loss_history(loss_history, num_epochs, filename='cyclegan_loss_plot_modified.png'):
    """Plot the recorded CycleGAN losses in three panels and save the figure.

    Panels: total + cycle generator loss, the two generator GAN losses, and the two
    discriminator losses. At most `num_epochs` values per series are plotted; a shorter
    history (e.g. from an earlier checkpoint) is plotted over the epochs it contains.

    Args:
        loss_history: dict with keys 'G_total', 'Cycle', 'G_A2B_GAN', 'G_B2A_GAN',
            'D_A' and 'D_B', each a list with one value per epoch.
        num_epochs: maximum number of epochs to plot.
        filename: output image path (default keeps the previous file name).
    """
    panels = [
        ('Generator Losses (Total & Cycle)',
         [('G_total', 'Total Generator Loss'), ('Cycle', 'Cycle Consistency Loss')]),
        ('Separate Generator GAN Losses',
         [('G_A2B_GAN', 'G_A2B GAN Loss (A to B)'), ('G_B2A_GAN', 'G_B2A GAN Loss (B to A)')]),
        ('Discriminator Losses',
         [('D_A', 'Discriminator A Loss (A domain)'), ('D_B', 'Discriminator B Loss (B domain)')]),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(18, 5))
    for ax, (title, series) in zip(axes, panels):
        for key, label in series:
            values = loss_history[key][:num_epochs]
            # x follows the data length, so a history shorter than num_epochs does not
            # raise "x and y must have same first dimension".
            ax.plot(range(1, len(values) + 1), values, label=label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)
    print(f"Loss plot saved as '{filename}'")

def generate_and_plot_samples(netG_A2B, netG_B2A, test_dataloader, epoch, filename):
    """Translate one test batch in both directions and save a 2x3 image grid.

    Row 1: real A, A->B translation, A->B->A reconstruction.
    Row 2: real B, B->A translation, B->A->B reconstruction.
    Only the first image of the batch is shown. The generators run in eval mode without
    gradients and are returned to their previous train/eval mode afterwards.

    Args:
        netG_A2B, netG_B2A: generators; inputs are moved to netG_A2B's parameter device.
        test_dataloader: yields dicts with image tensors under 'A' and 'B', presumably
            normalized to [-1, 1] by the notebook's Normalize((0.5,)*3, (0.5,)*3) transform.
        epoch: epoch number shown in the figure title.
        filename: output image path.
    """
    was_training = (netG_A2B.training, netG_B2A.training)
    device = next(netG_A2B.parameters()).device

    data = next(iter(test_dataloader))
    real_A = data['A'].to(device)
    real_B = data['B'].to(device)

    netG_A2B.eval()
    netG_B2A.eval()
    try:
        with torch.no_grad():
            fake_B = netG_A2B(real_A)
            cycle_A = netG_B2A(fake_B)

            fake_A = netG_B2A(real_B)
            cycle_B = netG_A2B(fake_A)
    finally:
        # Do not leave the caller's networks in eval mode.
        netG_A2B.train(was_training[0])
        netG_B2A.train(was_training[1])

    def to_img_np(tensor):
        # First image of the batch, CHW -> HWC, then denormalize [-1, 1] -> [0, 1].
        img = (tensor[0].detach().cpu().permute(1, 2, 0) + 1) / 2.0
        return np.clip(img.numpy(), 0, 1)

    panels = [
        (real_A, 'Real A (Apple)'), (fake_B, 'Fake B (Orange)'), (cycle_A, 'Cycle A (Reconstructed)'),
        (real_B, 'Real B (Orange)'), (fake_A, 'Fake A (Apple)'), (cycle_B, 'Cycle B (Reconstructed)'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(10, 7))
    fig.suptitle(f"CycleGAN Results - Epoch {epoch}", fontsize=16)
    for ax, (tensor, title) in zip(axes.flat, panels):
        ax.imshow(to_img_np(tensor))
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(filename)
    plt.close(fig)
    print(f"Sample images saved as '{filename}'")

In [8]:
    # Evaluation: reload selected checkpoints, save sample grids and the loss curves.
    netG_A2B = Generator(3, 3, NUM_RESIDUAL_BLOCKS).to(DEVICE)
    netG_B2A = Generator(3, 3, NUM_RESIDUAL_BLOCKS).to(DEVICE)
    netD_A = Discriminator(3).to(DEVICE)
    netD_B = Discriminator(3).to(DEVICE)

    # Deterministic preprocessing (no random crop/flip) and a fixed order, so every
    # checkpoint is evaluated on the same test images and the grids are comparable.
    eval_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE, transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    ROOT_PATH = '/content/drive/MyDrive'
    CHECKPOINT_DIR = '/content/checkpoints_apple'
    test_dataset = UnpairedTestDataset(ROOT_PATH, 'apple2orange', eval_transform)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    epochs_to_plot = [5, 10, 20]
    final_epoch = max(epochs_to_plot)
    full_loss_history = None

    for epoch in epochs_to_plot:
        checkpoint_file = os.path.join(CHECKPOINT_DIR, f'cyclegan_checkpoint_epoch_{epoch}.pth')
        loss_history = load_checkpoint(checkpoint_file, netG_A2B, netG_B2A, netD_A, netD_B)
        if loss_history is None:
            continue

        if epoch == final_epoch:
            # The last checkpoint holds the loss history of all preceding epochs.
            full_loss_history = loss_history

        generate_and_plot_samples(
            netG_A2B, netG_B2A, test_dataloader, epoch,
            filename=f'cyclegan_samples_epoch_{epoch}.png'
        )

    if full_loss_history is not None:
        plot_loss_history(full_loss_history, num_epochs=final_epoch)
    else:
        print(f"\nCould not find Epoch {final_epoch} checkpoint to plot the full loss history.")

Loading checkpoint from: /content/checkpoints_apple/cyclegan_checkpoint_epoch_5.pth
Sample images saved as 'cyclegan_samples_epoch_5.png'
Loading checkpoint from: /content/checkpoints_apple/cyclegan_checkpoint_epoch_10.pth
Sample images saved as 'cyclegan_samples_epoch_10.png'
Loading checkpoint from: /content/checkpoints_apple/cyclegan_checkpoint_epoch_20.pth
Sample images saved as 'cyclegan_samples_epoch_20.png'
Loss plot saved as 'cyclegan_loss_plot_modified.png'
