In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG (all devices) so weight initialisation and DataLoader
# shuffling are repeatable across Restart & Run All. cuDNN kernels may still
# introduce some run-to-run nondeterminism on GPU.
SEED = 42
torch.manual_seed(SEED)

# Data Loading

In [None]:
class CycleGANDataset(Dataset):
    """Unpaired two-domain image dataset laid out as ``root/trainA`` and ``root/trainB``.

    Each item pairs one image from domain A with one image from domain B. The
    dataset length is the size of the larger domain; both file lists are indexed
    modulo their own length, so the smaller domain is cycled through and pairs
    are determined by the index.

    Args:
        root: Directory containing the ``trainA`` and ``trainB`` sub-directories.
        transform: Optional callable applied separately to each RGB PIL image
            (for example a ``torchvision.transforms.Compose``).

    Raises:
        ValueError: if either domain directory contains no regular files.
    """

    def __init__(self, root, transform=None):
        self.transform = transform

        # Path to domain A (e.g., Summer) and domain B (e.g., Winter)
        self.dir_A = os.path.join(root, 'trainA')
        self.dir_B = os.path.join(root, 'trainB')

        # Sorted so the index -> file mapping is deterministic.
        self.files_A = self._list_files(self.dir_A)
        self.files_B = self._list_files(self.dir_B)

        # If only one domain were empty, __getitem__ would later fail with an
        # opaque ZeroDivisionError (index % 0); fail early with a clear message.
        for directory, files in ((self.dir_A, self.files_A), (self.dir_B, self.files_B)):
            if not files:
                raise ValueError(f"No image files found in {directory!r}")

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular files directly inside ``directory``.

        Sub-directories are skipped because they cannot be opened as images.
        """
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def _load_image(self, directory, files, index):
        """Open ``files[index % len(files)]`` from ``directory`` as RGB and apply the transform."""
        path = os.path.join(directory, files[index % len(files)])
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy that outlives the `with` block.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, index):
        img_A = self._load_image(self.dir_A, self.files_A, index)
        img_B = self._load_image(self.dir_B, self.files_B, index)
        return {'A': img_A, 'B': img_B}

# Preprocessing pipeline: fixed 256x256 resize, conversion to a float tensor in
# [0, 1], then per-channel normalization (x - 0.5) / 0.5 into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Unpaired summer/winter dataset and a shuffling loader over it.
dataset = CycleGANDataset(
    root='/kaggle/input/summer2winter-yosemite',
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Sanity check: draw one batch and print the tensor shape for each domain.
sample = next(iter(dataloader))
print(sample['A'].shape, sample['B'].shape)

# Model Building

In [None]:
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    F is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm. Both
    convolutions keep the channel count and spatial size, so the skip
    connection is a plain element-wise add.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def conv3x3():
            # padding=1 keeps H and W unchanged. The bias is dropped because the
            # following InstanceNorm subtracts the per-channel mean anyway.
            return nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(), nn.InstanceNorm2d(dim), nn.ReLU(True),
            conv3x3(), nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem conv (64 channels) -> two stride-2 convs (64->128->256)
    -> ``n_residual_blocks`` x ResnetBlock(256) -> two stride-2 transposed convs
    (256->128->64) -> 7x7 output conv followed by Tanh, so outputs lie in [-1, 1].
    For inputs whose height and width are divisible by 4, the output has the
    same spatial size as the input.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        n_residual_blocks: number of residual blocks at the bottleneck.
    """

    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super().__init__()

        # Stem: 7x7 conv with padding=3 keeps the spatial size.
        layers = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
        ]

        # Downsampling: each stride-2 conv halves H and W and doubles the channels.
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ])

        # Bottleneck: residual blocks at 256 channels.
        layers.extend(ResnetBlock(256) for _ in range(n_residual_blocks))

        # Upsampling: output_padding=1 makes each transposed conv exactly double H and W.
        for in_ch, out_ch in ((256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1,
                                   output_padding=1, bias=False),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ])

        # Output head: Tanh bounds the generated pixels to [-1, 1].
        layers.append(nn.Conv2d(64, output_channels, kernel_size=7, padding=3))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """PatchGAN-style discriminator that outputs a map of per-patch scores.

    Four 4x4 convolutions (64->128->256->512 channels; the first three with
    stride 2, the fourth with stride 1), each followed by LeakyReLU(0.2), with
    InstanceNorm on all but the first. A final 4x4 conv reduces to one channel.
    No sigmoid is applied, so the scores are unbounded.

    Args:
        input_channels: channels of the image being scored.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        def stage(in_ch, out_ch, stride, normalize):
            # conv -> [InstanceNorm] -> LeakyReLU
            modules = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1)]
            if normalize:
                modules.append(nn.InstanceNorm2d(out_ch))
            modules.append(nn.LeakyReLU(0.2, True))
            return modules

        self.model = nn.Sequential(
            *stage(input_channels, 64, stride=2, normalize=False),
            *stage(64, 128, stride=2, normalize=True),
            *stage(128, 256, stride=2, normalize=True),
            *stage(256, 512, stride=1, normalize=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        )

    def forward(self, x):
        return self.model(x)

# Example usage: build one network of each kind and print its layer structure.
generator = GeneratorResNet()
discriminator = Discriminator()

for network in (generator, discriminator):
    print(network)


# Function to Display Images

In [None]:
def denormalize(img_tensor):
    """Turn a normalized (C, H, W) image tensor into an (H, W, C) NumPy array for display.

    Inverts the [-1, 1] normalization with ``x * 0.5 + 0.5``. The tensor is
    detached and copied to the CPU first, so it can come straight from a model
    output that tracks gradients or lives on the GPU.
    """
    array = img_tensor.detach().cpu().numpy()
    array = array * 0.5 + 0.5          # [-1, 1] -> [0, 1]
    return array.transpose(1, 2, 0)    # (C, H, W) -> (H, W, C)

def show_training_images(real_A, fake_B, real_B, fake_A):
    """Display a 2x2 grid comparing real images with their translations.

    Top row: real A and the A->B translation; bottom row: real B and the B->A
    translation. Each argument is a single normalized image tensor that
    ``denormalize`` converts for ``imshow`` (assumed (C, H, W) in [-1, 1]).

    The figure is explicitly closed after ``plt.show()``: this function is
    called repeatedly during training, and on backends that do not close
    figures on show() the unclosed figures would accumulate in memory.
    """
    panels = [
        (real_A, "Real A (e.g., Summer)"),
        (fake_B, "Fake B (e.g., Winter)"),
        (real_B, "Real B (e.g., Winter)"),
        (fake_A, "Fake A (e.g., Summer)"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    # axes.flat walks the grid row by row, matching the panel order above.
    for ax, (image, title) in zip(axes.flat, panels):
        ax.imshow(denormalize(image))
        ax.set_title(title)
        ax.axis("off")

    plt.show()
    plt.close(fig)

# Training

In [None]:
# Two generators (A->B and B->A) and one discriminator per domain.
generator_A2B = GeneratorResNet().to(device)
generator_B2A = GeneratorResNet().to(device)
discriminator_A = Discriminator().to(device)
discriminator_B = Discriminator().to(device)

# Loss functions: least-squares (MSE) adversarial loss and L1 cycle-consistency loss.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
lambda_cycle = 10.0  # weight of the cycle-consistency terms relative to the adversarial terms

# Optimizers with adjusted learning rates: discriminators learn at half the generators' rate.
optimizer_G = optim.Adam(
    list(generator_A2B.parameters()) + list(generator_B2A.parameters()),
    lr=0.0001, betas=(0.5, 0.999),
)
optimizer_D_A = optim.Adam(discriminator_A.parameters(), lr=0.00005, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(discriminator_B.parameters(), lr=0.00005, betas=(0.5, 0.999))

num_epochs = 100


def adversarial_loss(prediction, target_is_real):
    """MSE between discriminator scores and an all-ones (real) or all-zeros (fake) target.

    The target is built from ``prediction`` itself, so each discriminator
    forward pass is computed only once per loss term.
    """
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return criterion_GAN(prediction, make_target(prediction))


for epoch in range(num_epochs):
    running = {'G': 0.0, 'D_A': 0.0, 'D_B': 0.0}
    n_batches = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, batch in enumerate(progress_bar):
        # Train on the whole batch. Keeping only batch[...][0] silently discarded
        # all but one of every batch_size loaded images. If the GPU runs out of
        # memory, lower batch_size in the DataLoader instead.
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        # ---- Generators ----
        fake_B = generator_A2B(real_A)
        fake_A = generator_B2A(real_B)

        # Reconstruct images (cycle consistency): A -> B -> A and B -> A -> B.
        rec_A = generator_B2A(fake_B)
        rec_B = generator_A2B(fake_A)

        # Generators try to make the discriminators score their fakes as real.
        loss_GAN_A2B = adversarial_loss(discriminator_B(fake_B), True)
        loss_GAN_B2A = adversarial_loss(discriminator_A(fake_A), True)

        loss_cycle_A = criterion_cycle(rec_A, real_A)
        loss_cycle_B = criterion_cycle(rec_B, real_B)

        loss_G = loss_GAN_A2B + loss_GAN_B2A + lambda_cycle * (loss_cycle_A + loss_cycle_B)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminators ----
        # Fakes are detached so these losses do not backpropagate into the
        # generators. loss_G.backward() also left gradients on the
        # discriminator parameters; zero_grad() before each backward clears them.
        loss_D_A = adversarial_loss(discriminator_A(real_A), True) + \
                   adversarial_loss(discriminator_A(fake_A.detach()), False)
        optimizer_D_A.zero_grad()
        loss_D_A.backward()
        optimizer_D_A.step()

        loss_D_B = adversarial_loss(discriminator_B(real_B), True) + \
                   adversarial_loss(discriminator_B(fake_B.detach()), False)
        optimizer_D_B.zero_grad()
        loss_D_B.backward()
        optimizer_D_B.step()

        # Per-step losses for the progress bar and the epoch averages.
        step_losses = {'G': loss_G.item(), 'D_A': loss_D_A.item(), 'D_B': loss_D_B.item()}
        for key, value in step_losses.items():
            running[key] += value
        n_batches += 1
        progress_bar.set_postfix(step_losses)

        if i % 200 == 0:
            # Visualize the first image of the current batch.
            show_training_images(real_A[0], fake_B[0], real_B[0], fake_A[0])

    # Epoch-mean losses, 1-based epoch to match the progress bar; max() guards an empty loader.
    denom = max(n_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss G: {running['G'] / denom:.4f}, "
          f"Loss D_A: {running['D_A'] / denom:.4f}, Loss D_B: {running['D_B'] / denom:.4f}")


# Saving Models

In [None]:
# Persist the weights (state_dicts only) of all four networks, using paths
# relative to the current working directory.
checkpoints = (
    (generator_A2B, 'generator_A2B.pth'),
    (generator_B2A, 'generator_B2A.pth'),
    (discriminator_A, 'discriminator_A.pth'),
    (discriminator_B, 'discriminator_B.pth'),
)
for network, path in checkpoints:
    torch.save(network.state_dict(), path)