In [None]:
import os
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torchvision


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + InstanceNorm stages added onto the input.

    Spatial size and channel count are preserved, so the stage output can be summed
    directly with the input (identity skip connection).

    Args:
        channels: Number of input and output feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        stages = []
        for position in range(2):
            stages.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3),
                nn.InstanceNorm2d(channels),
            ])
            # Only the first stage gets a non-linearity; the second feeds the sum directly.
            if position == 0:
                stages.append(nn.ReLU(inplace=True))
        # Attribute name and layer order are kept so saved state_dicts still load.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        transformed = self.block(x)
        return x + transformed

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator used for both CycleGAN mappings.

    Layout: 7x7 stem conv (64 ch) -> two stride-2 convs (64 -> 128 -> 256 ch)
    -> ``n_residual_blocks`` ResidualBlocks at 256 ch -> two stride-2 transposed
    convs (256 -> 128 -> 64 ch) -> 7x7 output conv -> tanh.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        n_residual_blocks: Number of ResidualBlocks in the bottleneck.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=9):
        super().__init__()

        # Stem at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        width = 64

        # Encoder: each step halves H/W and doubles the channel count.
        for _ in range(2):
            wider = width * 2
            layers.extend([
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider

        # Bottleneck at the lowest resolution.
        layers.extend(ResidualBlock(width) for _ in range(n_residual_blocks))

        # Decoder: each step doubles H/W and halves the channel count.
        for _ in range(2):
            narrower = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, narrower, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower

        # Head: project back to image channels; tanh bounds outputs to [-1, 1],
        # the same range the Normalize((0.5,)*3, (0.5,)*3) transforms produce.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, out_channels, kernel_size=7),
            nn.Tanh(),
        ])

        # Attribute name and layer indices are unchanged, so state_dicts stay compatible.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that outputs a grid of per-patch real/fake scores.

    Four stride-2 4x4 conv stages (64, 128, 256, 512 channels, LeakyReLU 0.2, with
    InstanceNorm on all but the first), followed by a 4x4 conv to one channel.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        layers = []
        prev_width = in_channels
        for depth, width in enumerate((64, 128, 256, 512)):
            layers.append(nn.Conv2d(prev_width, width, kernel_size=4, stride=2, padding=1))
            # The first stage has no normalization.
            if depth > 0:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev_width = width

        # One raw (un-squashed) score per spatial patch, compared against 1/0
        # targets with the MSE criterion during training.
        layers.append(nn.Conv2d(prev_width, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Least-squares (LSGAN) adversarial objective: MSE between discriminator scores and 1/0 targets.
criterion_GAN = nn.MSELoss()
# Separate L1 reconstruction losses for the cycle-consistency and identity terms.
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

In [None]:
# Seed every RNG this notebook draws from (Python `random`, NumPy, and torch, whose
# manual_seed also seeds all CUDA devices) before the models are built, so weight
# initialisation and DataLoader shuffling repeat across Restart & Run All.
# NOTE(review): cuDNN kernels may still be non-deterministic on GPU; set
# torch.backends.cudnn.deterministic = True if bit-exact reruns are required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G_AB = Generator().to(device)  # Monet → Photo
G_BA = Generator().to(device)  # Photo → Monet
D_A = Discriminator().to(device)  # Discriminate real Monet
D_B = Discriminator().to(device)  # Discriminate real Photo

lr = 0.0002
betas = (0.5, 0.999)

# A single optimizer updates both generators jointly (their losses are summed in
# training); each discriminator gets its own optimizer.
optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=betas)
optimizer_D_A = optim.Adam(D_A.parameters(), lr=lr, betas=betas)
optimizer_D_B = optim.Adam(D_B.parameters(), lr=lr, betas=betas)

In [None]:
class UnpairedImageDataset(Dataset):
    """Unpaired two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Each item is a dict ``{'A': image_from_A, 'B': image_from_B}``. The domains are
    unpaired and may differ in size, so the length is that of the larger domain and
    indices wrap around (modulo) in the smaller one.

    Args:
        root_dir: Directory containing the ``<mode>A`` / ``<mode>B`` sub-folders.
        transforms_A: Optional callable applied to each domain-A PIL image.
        transforms_B: Optional callable applied to each domain-B PIL image.
        mode: Split prefix, e.g. ``'train'`` or ``'test'``.
        unaligned: If True, pick the domain-B image at random instead of by
            ``idx % len(B)``, so A/B pairings change between epochs. Defaults to
            False (deterministic pairing).

    Raises:
        FileNotFoundError: If a domain folder is missing or contains no image files.
    """

    # Only files with these extensions are treated as images; anything else
    # (e.g. Jupyter's ``.ipynb_checkpoints`` folder, ``Thumbs.db``) is ignored.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transforms_A=None, transforms_B=None, mode='train', unaligned=False):
        super().__init__()
        self.dir_A = os.path.join(root_dir, mode + 'A')
        self.dir_B = os.path.join(root_dir, mode + 'B')

        self.images_A = self._list_images(self.dir_A)
        self.images_B = self._list_images(self.dir_B)

        self.transforms_A = transforms_A
        self.transforms_B = transforms_B
        self.unaligned = unaligned

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image file names in ``directory``; fail loudly if there are none."""
        names = sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )
        # An empty domain would otherwise surface later as a ZeroDivisionError in
        # __getitem__ (idx % 0) or an empty-sampler error in the DataLoader.
        if not names:
            raise FileNotFoundError(f"No image files found in {directory!r}")
        return names

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file immediately."""
        # convert() returns a fully loaded image, so it stays valid after `with` closes the file.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        return max(len(self.images_A), len(self.images_B))

    def __getitem__(self, idx):
        img_A_path = os.path.join(self.dir_A, self.images_A[idx % len(self.images_A)])
        if self.unaligned:
            idx_B = random.randrange(len(self.images_B))
        else:
            idx_B = idx % len(self.images_B)
        img_B_path = os.path.join(self.dir_B, self.images_B[idx_B])

        img_A = self._load_rgb(img_A_path)
        img_B = self._load_rgb(img_B_path)

        if self.transforms_A:
            img_A = self.transforms_A(img_A)
        if self.transforms_B:
            img_B = self.transforms_B(img_B)

        return {'A': img_A, 'B': img_B}

# Training: resize the shorter side to 286 px, random-crop to 256x256 and randomly
# mirror (jitter augmentation), then map pixels to [-1, 1] to match the generator's
# tanh output range. InterpolationMode is the documented argument type; PIL integer
# constants are deprecated in torchvision.
transform_train = T.Compose([
    T.Resize(286, interpolation=T.InterpolationMode.BICUBIC),
    T.RandomCrop(256),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Test: deterministic resize only. Resize(int) scales the *shorter* side to 256, so
# non-square images keep their aspect ratio and are not cropped to 256x256.
transform_test = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Expected layout: <DATA_ROOT>/trainA, trainB, testA, testB (A = Monet, B = photos).
# Set the MONET_DATA_ROOT environment variable to use another location instead of
# editing this cell; the fallback is the original local path.
DATA_ROOT = os.environ.get('MONET_DATA_ROOT', 'C:/DEEP_LEARNING/ProjectMonet/datasets/Monet2photo')

dataset = UnpairedImageDataset(root_dir=DATA_ROOT, transforms_A=transform_train, transforms_B=transform_train, mode='train')
# NOTE(review): num_workers=0 loads images in the kernel process. Worker processes on
# Windows (spawn start method) generally cannot unpickle classes defined in a notebook,
# so only raise this after moving UnpairedImageDataset into an importable .py module.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)


In [None]:
import time

def sample_images(epoch, batches_done, dataloader, save_dir='images'):
    """Save a grid of real / translated / reconstructed images for one batch.

    With batch_size=1 the grid has two rows:
    real_A | G_AB(real_A) | G_BA(G_AB(real_A)) and
    real_B | G_BA(real_B) | G_AB(G_BA(real_B)).

    Uses the notebook-level generators ``G_AB`` / ``G_BA`` and ``device``. The
    generators' previous train/eval modes are restored afterwards, even if an
    exception is raised.

    Args:
        epoch: Epoch number, used in the file name.
        batches_done: Batch index, used in the file name.
        dataloader: Source of the batch; the first batch of a fresh iterator is used.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()

    try:
        with torch.no_grad():
            data = next(iter(dataloader))
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            fake_B = G_AB(real_A)
            fake_A = G_BA(real_B)

            recov_A = G_BA(fake_B)
            recov_B = G_AB(fake_A)

            imgs = torch.cat((real_A, fake_B, recov_A, real_B, fake_A, recov_B), 0)
            imgs = ((imgs + 1) / 2).clamp(0, 1)  # [-1,1] -> [0,1]

            grid = torchvision.utils.make_grid(imgs, nrow=3)
            out_path = os.path.join(save_dir, f'epoch{epoch}_batch{batches_done}.png')
            torchvision.utils.save_image(grid, out_path)
    finally:
        # Restore whatever mode the caller had, rather than forcing train().
        G_AB.train(was_training[0])
        G_BA.train(was_training[1])


In [None]:
import matplotlib.pyplot as plt

def plot_loss_progression(G_losses, D_A_losses, D_B_losses, cycle_losses=None, identity_losses=None, GEN_A_losses=None, GEN_B_losses=None, xlabel='Iteration'):
    """Plot training loss curves in two side-by-side panels.

    Left panel: total generator loss and both discriminator losses.
    Right panel: whichever optional component curves are provided (cycle, identity,
    and the per-direction generator totals).

    Args:
        G_losses, D_A_losses, D_B_losses: Sequences of loss values, one per logged step.
        cycle_losses, identity_losses, GEN_A_losses, GEN_B_losses: Optional sequences;
            ``None`` or empty sequences are skipped.
        xlabel: Shared x-axis label. ``train_cycleGAN`` records one value per batch,
            hence the default 'Iteration'.
    """
    fig, (ax_main, ax_parts) = plt.subplots(1, 2, figsize=(14, 6))

    ax_main.plot(G_losses, label='G Loss')
    ax_main.plot(D_A_losses, label='D_A Loss')
    ax_main.plot(D_B_losses, label='D_B Loss')
    ax_main.set(xlabel=xlabel, ylabel='Loss', title='Generator and Discriminator Loss')
    ax_main.legend()
    ax_main.grid(True)

    optional_curves = [
        (cycle_losses, 'Cycle Loss', {}),
        (identity_losses, 'Identity Loss', {}),
        (GEN_A_losses, 'GEN_A (A→B) Loss', {'linestyle': '--'}),
        (GEN_B_losses, 'GEN_B (B→A) Loss', {'linestyle': '--'}),
    ]
    for values, label, style in optional_curves:
        # len() check (not truthiness) so NumPy arrays are accepted too.
        if values is not None and len(values) > 0:
            ax_parts.plot(values, label=label, **style)
    ax_parts.set(xlabel=xlabel, ylabel='Loss', title='Cycle, Identity & Per-Generator Loss')
    # legend() warns when no labelled artists exist, so only add it if something was plotted.
    if ax_parts.get_legend_handles_labels()[0]:
        ax_parts.legend()
    ax_parts.grid(True)

    fig.tight_layout()
    plt.show()


In [None]:
import torchvision.utils as vutils

def show_sample_images(dataloader, G_AB, G_BA, device):
    """Display real / translated / reconstructed images for one batch inline.

    With batch_size=1 the grid has two rows:
    real_A | G_AB(real_A) | G_BA(G_AB(real_A)) and
    real_B | G_BA(real_B) | G_AB(G_BA(real_B)).

    The generators' previous train/eval modes are restored afterwards, even if an
    exception is raised.

    Args:
        dataloader: Source of the batch; the first batch of a fresh iterator is used.
        G_AB: Generator mapping domain A to domain B.
        G_BA: Generator mapping domain B to domain A.
        device: Device the batch is moved to before inference.
    """
    def denorm(x):
        # [-1, 1] -> [0, 1]; clamp keeps values in the range imshow expects for float RGB.
        return ((x + 1) / 2).clamp(0, 1)

    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()

    try:
        with torch.no_grad():
            batch = next(iter(dataloader))
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            fake_B = G_AB(real_A)
            recov_A = G_BA(fake_B)

            fake_A = G_BA(real_B)
            recov_B = G_AB(fake_A)

            images = torch.cat([
                denorm(real_A), denorm(fake_B), denorm(recov_A),
                denorm(real_B), denorm(fake_A), denorm(recov_B)
            ], 0)

            grid_img = vutils.make_grid(images, nrow=3)
    finally:
        # Restore whatever mode the caller had, rather than forcing train().
        G_AB.train(was_training[0])
        G_BA.train(was_training[1])

    plt.figure(figsize=(12, 8))
    plt.axis('off')
    plt.title('Real A | Fake B | Recov A || Real B | Fake A | Recov B')
    plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy())
    plt.show()


In [None]:
# Labels
real_label = 1.0
fake_label = 0.0

n_epochs = 25
lambda_cycle = 10.0
lambda_identity = 5.0

In [None]:
def _generator_step(real_A, real_B):
    """Run one joint optimizer_G update of G_AB and G_BA; return (loss floats, fake_A, fake_B)."""
    optimizer_G.zero_grad()

    # Identity: each generator should leave images already in its output domain unchanged.
    same_B = G_AB(real_B)
    loss_identity_B = criterion_identity(same_B, real_B) * lambda_identity
    same_A = G_BA(real_A)
    loss_identity_A = criterion_identity(same_A, real_A) * lambda_identity

    # Adversarial: generators are rewarded when the discriminators score fakes as real (1).
    # *_like already allocates on pred's device, so no extra .to(device) is needed.
    fake_B = G_AB(real_A)
    pred_fake_B = D_B(fake_B)
    loss_GAN_AB = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

    fake_A = G_BA(real_B)
    pred_fake_A = D_A(fake_A)
    loss_GAN_BA = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

    # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the input.
    recov_A = G_BA(fake_B)
    loss_cycle_A = criterion_cycle(recov_A, real_A) * lambda_cycle
    recov_B = G_AB(fake_A)
    loss_cycle_B = criterion_cycle(recov_B, real_B) * lambda_cycle

    loss_G = loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B + loss_identity_A + loss_identity_B
    loss_G.backward()
    optimizer_G.step()

    # Logging-only sums; no autograd graph needed.
    with torch.no_grad():
        losses = {
            'G': loss_G.item(),
            'cycle': (loss_cycle_A + loss_cycle_B).item(),
            'identity': (loss_identity_A + loss_identity_B).item(),
            # Per-generator totals: the terms whose final output is produced by G_AB / G_BA.
            'GEN_A': (loss_GAN_AB + loss_cycle_B + loss_identity_B).item(),
            'GEN_B': (loss_GAN_BA + loss_cycle_A + loss_identity_A).item(),
        }
    return losses, fake_A, fake_B


def _discriminator_step(discriminator, optimizer, real, fake):
    """Run one update of ``discriminator`` (real -> 1, fake -> 0) and return the loss as a float."""
    optimizer.zero_grad()

    pred_real = discriminator(real)
    loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

    # detach(): this step must not backpropagate into the generators.
    pred_fake = discriminator(fake.detach())
    loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

    # Halved, presumably to slow D's learning relative to G as in the CycleGAN paper.
    loss = (loss_real + loss_fake) * 0.5
    loss.backward()
    optimizer.step()
    return loss.item()


def train_cycleGAN(n_epochs, dataloader, device, log_every=200, sample_every=2200):
    """Train the notebook-level CycleGAN, then plot losses and save the generator weights.

    Uses the models (G_AB, G_BA, D_A, D_B), optimizers, criteria and lambda weights
    defined in earlier cells. Losses are recorded once per batch. After training the
    loss history is plotted and the generator weights are written to
    ``G_AB_photo.pth`` / ``G_BA_monet.pth`` in the current working directory.

    Args:
        n_epochs: Number of passes over ``dataloader``.
        dataloader: Yields dicts with ``'A'`` and ``'B'`` image batches.
        device: Device the batches are moved to (models are assumed to already be there).
        log_every: Print losses every this many batches of each epoch (0/None disables).
        sample_every: Save and display sample translations every this many batches of
            each epoch (0/None disables).
    """
    history = {key: [] for key in ('G', 'D_A', 'D_B', 'cycle', 'identity', 'GEN_A', 'GEN_B')}

    for epoch in range(1, n_epochs + 1):
        start_time = time.time()
        for i, batch in enumerate(tqdm(dataloader)):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            gen_losses, fake_A, fake_B = _generator_step(real_A, real_B)
            loss_D_A = _discriminator_step(D_A, optimizer_D_A, real_A, fake_A)
            loss_D_B = _discriminator_step(D_B, optimizer_D_B, real_B, fake_B)

            if log_every and i % log_every == 0:
                # tqdm.write prints above the progress bar instead of corrupting it.
                tqdm.write(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                           f"[D_A loss: {loss_D_A:.4f}] [D_B loss: {loss_D_B:.4f}] "
                           f"[G loss: {gen_losses['G']:.4f}]")

            if sample_every and i % sample_every == 0:
                sample_images(epoch, i, dataloader)
                show_sample_images(dataloader, G_AB, G_BA, device)

            history['D_A'].append(loss_D_A)
            history['D_B'].append(loss_D_B)
            for key, value in gen_losses.items():
                history[key].append(value)

        print(f"Epoch {epoch} finished in {(time.time() - start_time):.2f} seconds.")

    plot_loss_progression(history['G'], history['D_A'], history['D_B'],
                          history['cycle'], history['identity'],
                          GEN_A_losses=history['GEN_A'],
                          GEN_B_losses=history['GEN_B'])

    torch.save(G_AB.state_dict(), "G_AB_photo.pth")  # Monet → Photo
    torch.save(G_BA.state_dict(), "G_BA_monet.pth")  # Photo → Monet



In [None]:
# Settings for the training run below. These override the earlier hyperparameter
# cell; only the epoch count differs (15 instead of 25).
n_epochs, lambda_cycle, lambda_identity = 15, 10.0, 5.0

In [None]:
import torch
# Environment check: is a CUDA device visible to PyTorch, and which one would be used?
gpu_available = torch.cuda.is_available()
print("CUDA available:", gpu_available)
print("Number of GPUs:", torch.cuda.device_count())
print("Using GPU:", torch.cuda.get_device_name(0) if gpu_available else "No GPU")

In [None]:
import torch
# CUDA version PyTorch was built against, then the cuDNN version (either may be None on CPU-only builds).
for version_info in (torch.version.cuda, torch.backends.cudnn.version()):
    print(version_info)

In [None]:
# Query GPU/driver status with the nvidia-smi command-line tool (IPython `!` shell escape).
!nvidia-smi


In [None]:
# Launch training with the values assigned by the most recent config cell (n_epochs = 15).
train_cycleGAN(n_epochs=n_epochs, dataloader=dataloader, device=device)