In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
import torch
import torch.nn as nn

In [None]:
class ResNetBlock(nn.Module):
    """Residual block made of two reflection-padded 3x3 conv + InstanceNorm stages.

    A ReLU sits between the two stages and the block's input is added to the
    result, so the channel count and the spatial size are left unchanged.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def padded_conv_norm():
            # A reflection pad of 1 compensates for the unpadded 3x3 convolution,
            # keeping height and width the same.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3),
                nn.InstanceNorm2d(dim),
            ]

        # Stage 1 -> ReLU -> stage 2 (no activation after the second stage).
        layers = padded_conv_norm() + [nn.ReLU(inplace=True)] + padded_conv_norm()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual computed by ``self.block``."""
        residual = self.block(x)
        return x + residual

class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions
    (each doubling the channel count), ``n_blocks`` residual blocks, two
    stride-2 transposed convolutions (each halving the channel count), and a
    7x7 output convolution followed by ``Tanh``.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_blocks: number of ``ResNetBlock`` modules at the bottleneck.
        ngf: number of feature channels produced by the stem convolution.
            Every other width in the network is derived from it. The default
            of 64 reproduces the original fixed architecture.
    """

    def __init__(self, input_nc, output_nc, n_blocks=6, ngf=64):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, 7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Downsampling: each step halves H and W and doubles the channels.
        in_features = ngf
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True)
            ]
            in_features = out_features

        # Residual blocks operate at the bottleneck width (4 * ngf).
        for _ in range(n_blocks):
            model += [ResNetBlock(in_features)]

        # Upsampling: output_padding=1 makes each transposed convolution
        # exactly double H and W, mirroring the downsampling path.
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(True)
            ]
            in_features = out_features

        # Use the tracked width rather than a literal so the output layer
        # always matches whatever the upsampling path produced.
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the generator; the final ``Tanh`` bounds the output to [-1, 1]."""
        return self.model(x)


In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel score map.

    Four stride-2 4x4 convolutions widen the features 64 -> 128 -> 256 -> 512
    (InstanceNorm on all but the first, LeakyReLU(0.2) after each), followed
    by a stride-1 4x4 convolution down to a single channel.

    Args:
        input_nc: number of channels of the images being scored.
    """

    def __init__(self, input_nc):
        super().__init__()
        widths = [64, 128, 256, 512]

        # First stage has no normalisation layer.
        layers = [
            nn.Conv2d(input_nc, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining strided stages: conv -> InstanceNorm -> LeakyReLU.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Final projection to one score per spatial location (no activation).
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-activated) score map for ``x``."""
        scores = self.model(x)
        return scores


In [None]:
def gan_loss(pred, target_is_real):
    """Mean-squared error between ``pred`` and a constant real/fake target.

    Args:
        pred: tensor of discriminator outputs.
        target_is_real: if truthy, compare against all ones; otherwise
            compare against all zeros.

    Returns:
        Scalar tensor holding the MSE loss.
    """
    # Pick the constructor first, then build a target shaped like ``pred``.
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return F.mse_loss(pred, make_target(pred))

def cycle_loss(reconstructed, original):
    """Mean absolute (L1) error between a reconstructed tensor and the original.

    Args:
        reconstructed: tensor obtained after a full translation round trip.
        original: the tensor the round trip started from.

    Returns:
        Scalar tensor holding the L1 loss.
    """
    l1 = F.l1_loss(input=reconstructed, target=original)
    return l1

def identity_loss(same, original):
    """Mean absolute (L1) error between ``same`` and ``original``.

    Args:
        same: generator output for an input that is compared directly
            against that input.
        original: the tensor ``same`` is compared with.

    Returns:
        Scalar tensor holding the L1 loss.
    """
    l1 = F.l1_loss(input=same, target=original)
    return l1


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models: G_AB translates domain A -> B and G_BA translates B -> A;
# D_A and D_B score images of domain A and domain B respectively.
G_AB = ResNetGenerator(3, 3).to(device)
G_BA = ResNetGenerator(3, 3).to(device)
D_A = Discriminator(3).to(device)
D_B = Discriminator(3).to(device)

# Optimizers: the two generators share one optimizer because a single
# combined loss (loss_G below) is back-propagated through both of them.
g_optimizer = torch.optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))
d_a_optimizer = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_b_optimizer = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Load datasets
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    # Per-channel statistics spelled out for the 3-channel input the models
    # are built for; maps [0, 1] to [-1, 1], the range of the generators' Tanh.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# NOTE(review): both dataset roots are empty strings in this notebook — set
# them to the domain A and domain B image folders before running this cell.
dataset_A = datasets.ImageFolder("", transform=transform)
dataset_B = datasets.ImageFolder("", transform=transform)
loader_A = DataLoader(dataset_A, batch_size=1, shuffle=True)
loader_B = DataLoader(dataset_B, batch_size=1, shuffle=True)

# Training
lambda_cyc = 10  # weight of the cycle-consistency term in loss_G
lambda_id = 5    # weight of the identity term in loss_G

# Keep ONE iterator over domain B alive across steps. Creating a fresh
# iterator per step (next(iter(loader_B))) re-shuffles every time, so B would
# be sampled with replacement and pay the iterator start-up cost on each step.
iter_B = iter(loader_B)

for epoch in range(100):
    for real_A, _ in loader_A:
        try:
            real_B, _ = next(iter_B)
        except StopIteration:
            # Domain B is exhausted: start a new (re-shuffled) pass over it.
            iter_B = iter(loader_B)
            real_B, _ = next(iter_B)

        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # Train Generators
        # Freeze the discriminators so the generator backward pass does not
        # compute discriminator weight gradients that would be thrown away.
        # Gradients still flow through them to the generators' outputs.
        D_A.requires_grad_(False)
        D_B.requires_grad_(False)
        g_optimizer.zero_grad()

        fake_B = G_AB(real_A)
        rec_A = G_BA(fake_B)

        fake_A = G_BA(real_B)
        rec_B = G_AB(fake_A)

        # Identity loss: feeding a generator an image already in its target
        # domain should return that image unchanged.
        idt_A = G_BA(real_A)
        idt_B = G_AB(real_B)

        loss_id = identity_loss(idt_A, real_A) + identity_loss(idt_B, real_B)

        # GAN loss: generators are rewarded when the discriminators score
        # their outputs as real.
        loss_gan_AB = gan_loss(D_B(fake_B), True)
        loss_gan_BA = gan_loss(D_A(fake_A), True)

        # Cycle loss: A -> B -> A and B -> A -> B should reproduce the input.
        loss_cyc = cycle_loss(rec_A, real_A) + cycle_loss(rec_B, real_B)

        loss_G = loss_gan_AB + loss_gan_BA + lambda_cyc * loss_cyc + lambda_id * loss_id
        loss_G.backward()
        g_optimizer.step()

        # Re-enable discriminator gradients for their own updates.
        D_A.requires_grad_(True)
        D_B.requires_grad_(True)

        # Train Discriminator A
        d_a_optimizer.zero_grad()
        loss_real = gan_loss(D_A(real_A), True)
        # detach() stops the discriminator loss from reaching the generators.
        loss_fake = gan_loss(D_A(fake_A.detach()), False)
        loss_D_A = (loss_real + loss_fake) * 0.5
        loss_D_A.backward()
        d_a_optimizer.step()

        # Train Discriminator B
        d_b_optimizer.zero_grad()
        loss_real = gan_loss(D_B(real_B), True)
        loss_fake = gan_loss(D_B(fake_B.detach()), False)
        loss_D_B = (loss_real + loss_fake) * 0.5
        loss_D_B.backward()
        d_b_optimizer.step()


In [None]:
from torchvision.utils import save_image

# Translate one domain-A batch with the A -> B generator and write it to disk.
G_AB.eval()
with torch.no_grad():
    # Only the first batch is needed; skip silently if the loader is empty.
    first_batch = next(iter(loader_A), None)
    if first_batch is not None:
        real_A, _ = first_batch
        real_A = real_A.to(device)
        fake_B = G_AB(real_A)
        # De-normalize from [-1, 1] back to [0, 1] before saving.
        save_image(fake_B.add(1).div(2), 'sample_output.png')
