<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/CycleGAN_for_Image_to_Image_Translation_Without_Paired_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import itertools

class UNetGenerator(nn.Module):
    """Encoder/decoder image generator with additive skip connections.

    The encoder is four 4x4, stride-2 convolutions widening the channels to
    64, 128, 256 and 512. The decoder mirrors it with four 4x4, stride-2
    transposed convolutions. Each decoder stage after the first receives the
    sum of the previous decoder output and the encoder output of matching
    width. No layer has a bias term and there is no normalisation layer.

    Args:
        input_channels: Number of channels of the input image tensor.
        output_channels: Number of channels of the generated image tensor.

    Note:
        Input height and width must be multiples of 16; otherwise the
        encoder and decoder feature maps being summed differ in size.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # Every layer shares the same geometry: 4x4 kernel, stride 2, padding 1.
        geometry = dict(kernel_size=4, stride=2, padding=1, bias=False)

        # Downsampling path
        self.enc1 = nn.Conv2d(input_channels, 64, **geometry)
        self.enc2 = nn.Conv2d(64, 128, **geometry)
        self.enc3 = nn.Conv2d(128, 256, **geometry)
        self.enc4 = nn.Conv2d(256, 512, **geometry)

        # Upsampling path
        self.dec1 = nn.ConvTranspose2d(512, 256, **geometry)
        self.dec2 = nn.ConvTranspose2d(256, 128, **geometry)
        self.dec3 = nn.ConvTranspose2d(128, 64, **geometry)
        self.dec4 = nn.ConvTranspose2d(64, output_channels, **geometry)

    def forward(self, x):
        """Map ``x`` to an image tensor squashed into [-1, 1] by ``tanh``."""
        # Encoder outputs are stored before activation; ReLU is applied only
        # on the way into the next layer.
        down1 = self.enc1(x)
        down2 = self.enc2(F.relu(down1))
        down3 = self.enc3(F.relu(down2))
        bottleneck = self.enc4(F.relu(down3))

        up = self.dec1(F.relu(bottleneck))
        # Each remaining decoder stage consumes (decoder output + matching
        # encoder output), deepest skip first.
        skip_stages = (
            (down3, self.dec2),
            (down2, self.dec3),
            (down1, self.dec4),
        )
        for skip, stage in skip_stages:
            up = stage(F.relu(up + skip))

        return torch.tanh(up)

class PatchGANDiscriminator(nn.Module):
    """Convolutional discriminator that emits a grid of per-patch scores.

    The network is a stack of five 4x4 convolutions (strides 2, 2, 2, 1, 1;
    padding 1; no bias). The first is followed by LeakyReLU(0.2), the middle
    three by BatchNorm and LeakyReLU(0.2), and the last by a sigmoid, so the
    output is a single-channel map of values in [0, 1].

    Args:
        input_channels: Number of channels of the image tensor to classify.
    """

    def __init__(self, input_channels):
        super().__init__()

        # Stem: no normalisation on the first convolution.
        layers = [
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Conv -> BatchNorm -> LeakyReLU stages: (in_channels, out_channels, stride).
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Head: collapse to one channel and squash to a probability per patch.
        layers += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score map produced by the stack for ``x``."""
        return self.main(x)

# Train on the GPU when one is available; everything below also runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the networks (moved to `device` before the optimizers are built)
generator_A2B = UNetGenerator(input_channels=3, output_channels=3).to(device)
generator_B2A = UNetGenerator(input_channels=3, output_channels=3).to(device)
discriminator_A = PatchGANDiscriminator(input_channels=3).to(device)
discriminator_B = PatchGANDiscriminator(input_channels=3).to(device)

# Loss functions
# BCELoss expects probabilities, which the discriminators' final Sigmoid provides.
adversarial_loss = nn.BCELoss()
cycle_loss = nn.L1Loss()

# Optimizers: one shared optimizer for both generators, one per discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(generator_A2B.parameters(), generator_B2A.parameters()),
    lr=0.0002,
    betas=(0.5, 0.999),
)
optimizer_D_A = torch.optim.Adam(discriminator_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(discriminator_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Load the datasets and create the dataloaders
# Images are resized/cropped to 256x256 and scaled to [-1, 1], matching the
# tanh output range of the generators.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset_A = datasets.ImageFolder(root='path_to_dataset_A', transform=transform)
dataset_B = datasets.ImageFolder(root='path_to_dataset_B', transform=transform)
dataloader_A = DataLoader(dataset_A, batch_size=1, shuffle=True)
dataloader_B = DataLoader(dataset_B, batch_size=1, shuffle=True)

# zip() stops at the shorter loader, so that is the real number of steps.
steps_per_epoch = min(len(dataloader_A), len(dataloader_B))

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    # ImageFolder batches are (images, labels); the class labels are unused
    # for unpaired translation, so they are discarded here.
    for i, ((real_A, _), (real_B, _)) in enumerate(zip(dataloader_A, dataloader_B)):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # Train Generators
        optimizer_G.zero_grad()

        fake_B = generator_A2B(real_A)
        recon_A = generator_B2A(fake_B)
        fake_A = generator_B2A(real_B)
        recon_B = generator_A2B(fake_A)

        loss_cycle_A = cycle_loss(recon_A, real_A)
        loss_cycle_B = cycle_loss(recon_B, real_B)

        # Run each discriminator once and reuse the prediction to size the
        # target tensor instead of paying for a second forward pass.
        pred_fake_B = discriminator_B(fake_B)
        pred_fake_A = discriminator_A(fake_A)
        loss_G_A2B = adversarial_loss(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_G_B2A = adversarial_loss(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle-consistency terms are weighted 10x relative to the adversarial terms.
        loss_G = loss_G_A2B + loss_G_B2A + 10 * (loss_cycle_A + loss_cycle_B)
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator A
        # zero_grad() also clears the gradients left by loss_G.backward().
        optimizer_D_A.zero_grad()
        pred_A_real = discriminator_A(real_A)
        # detach() keeps the discriminator loss from back-propagating into the generators.
        pred_A_fake = discriminator_A(fake_A.detach())
        loss_D_A_real = adversarial_loss(pred_A_real, torch.ones_like(pred_A_real))
        loss_D_A_fake = adversarial_loss(pred_A_fake, torch.zeros_like(pred_A_fake))
        loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # Train Discriminator B
        optimizer_D_B.zero_grad()
        pred_B_real = discriminator_B(real_B)
        pred_B_fake = discriminator_B(fake_B.detach())
        loss_D_B_real = adversarial_loss(pred_B_real, torch.ones_like(pred_B_real))
        loss_D_B_fake = adversarial_loss(pred_B_fake, torch.zeros_like(pred_B_fake))
        loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{steps_per_epoch}], Loss G: {loss_G.item()}, Loss D_A: {loss_D_A.item()}, Loss D_B: {loss_D_B.item()}")