# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR

class Generator(nn.Module):
    """Small encoder/decoder image-to-image network.

    The encoder applies two 4x4, stride-2 convolutions (3 -> 64 -> 128 channels),
    each halving the spatial size of even-sized inputs. The decoder mirrors it with
    two 4x4, stride-2 transposed convolutions (128 -> 64 -> 3 channels). The final
    Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layer order inside each Sequential fixes the state_dict keys
        # (encoder.0, encoder.2, encoder.3, decoder.0, ...), so keep it stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to a 3-channel image of the same size."""
        return self.decoder(self.encoder(x))

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a grid of per-patch scores.

    Architecture:
      * Two 4x4, stride-2 convolutions (3 -> 64 -> 128 channels, LeakyReLU 0.2,
        BatchNorm on the second).
      * A 4x4, stride-1 convolution down to one channel.
      * A Sigmoid, so every score lies in [0, 1].

    For an input of shape (N, 3, H, W) with H and W divisible by 4, the output has
    shape (N, 1, H/4 - 1, W/4 - 1). For example, 256x256 images give 63x63 scores.

    Defined exactly once: a second copy of this class would silently rebind the
    name to a different (if identical) class object.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Layer order fixes the state_dict keys (model.0, model.2, model.3, model.5).
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 1, 4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the Sigmoid score map for the image batch ``x``."""
        return self.model(x)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Two generators (G maps domain X to Y, F maps Y back to X) and one
# discriminator per domain (D_X for X images, D_Y for Y images).
G, F = Generator().to(device), Generator().to(device)
D_X, D_Y = Discriminator().to(device), Discriminator().to(device)

# Initialize weights
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

# Apply the custom initialisation to all four networks, in construction order.
for _net in (G, F, D_X, D_Y):
    _net.apply(init_weights)
del _net

# Adam (lr 2e-4, betas 0.5/0.999) everywhere. The two generators share one
# optimiser because a single combined loss updates both of them.
optimizer_G = optim.Adam([*G.parameters(), *F.parameters()], lr=0.0002, betas=(0.5, 0.999))
optimizer_D_X = optim.Adam(D_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_Y = optim.Adam(D_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Adversarial terms: MSE against all-ones / all-zeros targets (least-squares-GAN style).
# Cycle-consistency terms: L1 reconstruction error.
adversarial_loss = nn.MSELoss()
cycle_loss = nn.L1Loss()

class ImageDataset(Dataset):
    """Unlabelled dataset over the image files in a single directory.

    The directory listing is sorted, so index order is deterministic across runs and
    platforms (``os.listdir`` order is arbitrary). Subdirectories are skipped. By
    default only files with a common image extension are kept, so stray files such as
    ``.DS_Store`` do not crash ``Image.open``.

    Args:
        root_dir: directory containing the images (not searched recursively).
        transform: callable applied to each image after conversion to RGB.
        extensions: case-insensitive filename suffixes to accept. Pass ``None`` to
            accept every regular file.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform, extensions=IMG_EXTENSIONS):
        suffixes = None if extensions is None else tuple(e.lower() for e in extensions)
        self.images = []
        for name in sorted(os.listdir(root_dir)):
            path = os.path.join(root_dir, name)
            if not os.path.isfile(path):
                continue
            if suffixes is not None and not name.lower().endswith(suffixes):
                continue
            self.images.append(path)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # The ``with`` block releases the file handle even if decoding fails.
        # convert() loads the pixels into a new image, which outlives the handle.
        with Image.open(self.images[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

# Resize to 256x256, convert to a [0, 1] tensor, then map each channel to [-1, 1]
# (x - 0.5) / 0.5, which matches the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# One shuffled loader per (unpaired) domain: X = horses, Y = zebras.
loader_X, loader_Y = (
    DataLoader(ImageDataset(root, transform), batch_size=4, shuffle=True)
    for root in ('path/to/horse/images', 'path/to/zebra/images')
)

# Baseline CycleGAN training run (100 epochs, default hyper-parameters).
# Each step first updates both generators jointly, then each discriminator on
# real images vs. detached fakes.
for epoch in range(100):
    # zip() stops at the shorter loader, so each epoch sees
    # min(len(loader_X), len(loader_Y)) batch pairs.
    for real_X, real_Y in zip(loader_X, loader_Y):
        real_X, real_Y = real_X.to(device), real_Y.to(device)

        # Train Generators
        optimizer_G.zero_grad()
        fake_Y = G(real_X)
        fake_X = F(real_Y)
        # Run each discriminator once and derive the target from that same output.
        # A second D call just for ``ones_like`` would waste a forward pass and
        # update BatchNorm running statistics twice.
        pred_fake_Y = D_Y(fake_Y)
        pred_fake_X = D_X(fake_X)
        loss_G = (adversarial_loss(pred_fake_Y, torch.ones_like(pred_fake_Y))
                  + adversarial_loss(pred_fake_X, torch.ones_like(pred_fake_X))
                  # Cycle-consistency terms, unweighted in this run (lambda = 1).
                  + cycle_loss(F(fake_Y), real_X)
                  + cycle_loss(G(fake_X), real_Y))
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminators. Fakes are detached so no gradient reaches the generators.
        optimizer_D_X.zero_grad()
        pred_real_X = D_X(real_X)
        pred_fake_X = D_X(fake_X.detach())
        loss_D_X = (adversarial_loss(pred_real_X, torch.ones_like(pred_real_X))
                    + adversarial_loss(pred_fake_X, torch.zeros_like(pred_fake_X))) / 2
        loss_D_X.backward()
        optimizer_D_X.step()

        optimizer_D_Y.zero_grad()
        pred_real_Y = D_Y(real_Y)
        pred_fake_Y = D_Y(fake_Y.detach())
        loss_D_Y = (adversarial_loss(pred_real_Y, torch.ones_like(pred_real_Y))
                    + adversarial_loss(pred_fake_Y, torch.zeros_like(pred_fake_Y))) / 2
        loss_D_Y.backward()
        optimizer_D_Y.step()

torch.save(G.state_dict(), 'G.pth')
torch.save(F.state_dict(), 'F.pth')
torch.save(D_X.state_dict(), 'D_X.pth')
torch.save(D_Y.state_dict(), 'D_Y.pth')

def tensor_to_image(img):
    """Convert a normalised CHW image tensor into an HWC NumPy array in [0, 1].

    Undoes the Normalize(0.5, 0.5) transform via ``x * 0.5 + 0.5`` and clips to
    [0, 1], the range imshow expects for float RGB. ``detach()`` makes this safe for
    tensors that still carry autograd history (e.g. raw generator output), which
    ``.numpy()`` would otherwise reject.
    """
    return (img.detach().cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy()


def show_images(real, fake):
    """Display a real image next to its generated counterpart.

    Args:
        real: single CHW image tensor, normalised to roughly [-1, 1]. It may live on
            any device and may require grad.
        fake: same format as ``real``.
    """
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, img, title in zip(axes, (real, fake), ("Real Image", "Fake Image")):
        ax.imshow(tensor_to_image(img))
        ax.set_title(title)
    plt.show()

# Quick visual check of G (X -> Y) on one batch from loader_X.
# no_grad: this is inference only, so no autograd graph is built. That also lets
# show_images call .numpy() on the generated tensor.
# NOTE(review): G is still in train mode, so BatchNorm uses this batch's statistics
# (and updates its running averages); call G.eval() first if that is not wanted.
with torch.no_grad():
    real_image = next(iter(loader_X)).to(device)
    fake_image = G(real_image)
show_images(real_image[0], fake_image[0])

params = {
    'learning_rate': [0.0001, 0.0002, 0.0005],
    'batch_size': [4, 8, 16],
    'lambda_cycle': [5.0, 10.0, 20.0],
    'beta1': [0.5, 0.9],
}

from itertools import product


def _build_cyclegan(lr, beta1):
    """Create freshly initialised G, F, D_X, D_Y plus their three Adam optimisers."""
    nets = [Generator().to(device), Generator().to(device),
            Discriminator().to(device), Discriminator().to(device)]
    for net in nets:
        net.apply(init_weights)  # same initialisation as the baseline models
    g, f, d_x, d_y = nets
    betas = (beta1, 0.999)
    return (g, f, d_x, d_y,
            optim.Adam(list(g.parameters()) + list(f.parameters()), lr=lr, betas=betas),
            optim.Adam(d_x.parameters(), lr=lr, betas=betas),
            optim.Adam(d_y.parameters(), lr=lr, betas=betas))


def train_for_few_epochs(G, F, D_X, D_Y, optimizer_G, optimizer_D_X, optimizer_D_Y,
                         lambda_cycle, loader_X, loader_Y, num_epochs=2):
    """Briefly train a CycleGAN and return a score for hyper-parameter selection.

    Runs ``num_epochs`` epochs of the usual CycleGAN updates. The generator loss
    weights its cycle-consistency terms by ``lambda_cycle``.

    Returns:
        The mean *unweighted* cycle-consistency L1 loss over the final epoch (lower
        is better), or ``float('inf')`` if no batch was seen. The unweighted value
        keeps runs with different ``lambda_cycle`` on a comparable scale.
        NOTE(review): this is only a crude proxy for translation quality. It ignores
        the adversarial terms and tends to favour large lambda; FID is a better
        criterion.
    """
    score, n_batches = 0.0, 0
    for _ in range(num_epochs):
        score, n_batches = 0.0, 0  # only the final epoch contributes to the score
        for real_X, real_Y in zip(loader_X, loader_Y):
            real_X, real_Y = real_X.to(device), real_Y.to(device)

            optimizer_G.zero_grad()
            fake_Y, fake_X = G(real_X), F(real_Y)
            pred_fake_Y, pred_fake_X = D_Y(fake_Y), D_X(fake_X)
            cycle = cycle_loss(F(fake_Y), real_X) + cycle_loss(G(fake_X), real_Y)
            loss_G = (adversarial_loss(pred_fake_Y, torch.ones_like(pred_fake_Y))
                      + adversarial_loss(pred_fake_X, torch.ones_like(pred_fake_X))
                      + lambda_cycle * cycle)
            loss_G.backward()
            optimizer_G.step()

            for D, opt, real, fake in ((D_X, optimizer_D_X, real_X, fake_X),
                                       (D_Y, optimizer_D_Y, real_Y, fake_Y)):
                opt.zero_grad()
                pred_real, pred_fake = D(real), D(fake.detach())
                loss_D = (adversarial_loss(pred_real, torch.ones_like(pred_real))
                          + adversarial_loss(pred_fake, torch.zeros_like(pred_fake))) / 2
                loss_D.backward()
                opt.step()

            score += cycle.item()
            n_batches += 1
    return score / n_batches if n_batches else float('inf')


best_loss = float('inf')
best_params = None

for lr, batch, lam, beta in product(params['learning_rate'],
                                    params['batch_size'],
                                    params['lambda_cycle'],
                                    params['beta1']):
    # Initialize models and optimizers with the current set of parameters
    G, F, D_X, D_Y, optimizer_G, optimizer_D_X, optimizer_D_Y = _build_cyclegan(lr, beta)

    # Rebuild the loaders (reusing the datasets) so the batch size under test
    # actually takes effect.
    search_loader_X = DataLoader(loader_X.dataset, batch_size=batch, shuffle=True)
    search_loader_Y = DataLoader(loader_Y.dataset, batch_size=batch, shuffle=True)

    # Train for a few epochs and evaluate performance
    loss = train_for_few_epochs(G, F, D_X, D_Y, optimizer_G, optimizer_D_X, optimizer_D_Y,
                                lam, search_loader_X, search_loader_Y)

    if loss < best_loss:
        best_loss = loss
        best_params = (lr, batch, lam, beta)

if best_params is None:
    raise RuntimeError("Hyper-parameter search produced no finite loss; "
                       "check the datasets and the training step.")

print(f"Best Parameters: Learning Rate={best_params[0]}, Batch Size={best_params[1]}, "
      f"Lambda Cycle={best_params[2]}, Beta1={best_params[3]}")

# Start the main training run from fresh models, optimisers and loaders configured
# with the selected hyper-parameters. Otherwise it would continue training whichever
# combination happened to run last.
G, F, D_X, D_Y, optimizer_G, optimizer_D_X, optimizer_D_Y = _build_cyclegan(
    best_params[0], best_params[3])
loader_X = DataLoader(loader_X.dataset, batch_size=best_params[1], shuffle=True)
loader_Y = DataLoader(loader_Y.dataset, batch_size=best_params[1], shuffle=True)

import wandb  # Weights & Biases for logging

# Open a Weights & Biases run in project "CycleGAN" and record the search grid as its config.
wandb.init(
    config=params,
    project="CycleGAN",
)

def save_checkpoint(epoch, model, optimizer, filename="checkpoint.pth"):
    """Write a training checkpoint to ``filename`` with ``torch.save``.

    The saved object is a dict with keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``. Restore it with ``model.load_state_dict(...)`` and
    ``optimizer.load_state_dict(...)`` on the corresponding entries.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(checkpoint, filename)

# Main training run.
# params['lambda_cycle'] holds the whole search grid (a list), and a tensor cannot be
# multiplied by a list, so a single scalar weight is chosen here. Use the value picked
# by the hyper-parameter search when available, otherwise 10.0 (the value used in the
# CycleGAN paper).
lambda_cycle = best_params[2] if best_params is not None else 10.0
num_epochs = 100

for epoch in range(num_epochs):
    total_loss_G, total_loss_D, num_batches = 0.0, 0.0, 0

    for real_X, real_Y in zip(loader_X, loader_Y):
        real_X, real_Y = real_X.to(device), real_Y.to(device)

        # Train Generators: run each discriminator once and reuse its output for the target.
        optimizer_G.zero_grad()
        fake_Y = G(real_X)
        fake_X = F(real_Y)
        pred_fake_Y = D_Y(fake_Y)
        pred_fake_X = D_X(fake_X)

        loss_G = (adversarial_loss(pred_fake_Y, torch.ones_like(pred_fake_Y)) +
                  adversarial_loss(pred_fake_X, torch.ones_like(pred_fake_X)) +
                  cycle_loss(F(fake_Y), real_X) * lambda_cycle +
                  cycle_loss(G(fake_X), real_Y) * lambda_cycle)
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminators on real vs. detached fakes. Targets are built from the
        # same predictions, so D never re-runs on the non-detached generator output.
        optimizer_D_X.zero_grad()
        pred_real_X = D_X(real_X)
        pred_fake_X = D_X(fake_X.detach())
        loss_D_X = (adversarial_loss(pred_real_X, torch.ones_like(pred_real_X)) +
                    adversarial_loss(pred_fake_X, torch.zeros_like(pred_fake_X))) / 2
        loss_D_X.backward()
        optimizer_D_X.step()

        optimizer_D_Y.zero_grad()
        pred_real_Y = D_Y(real_Y)
        pred_fake_Y = D_Y(fake_Y.detach())
        loss_D_Y = (adversarial_loss(pred_real_Y, torch.ones_like(pred_real_Y)) +
                    adversarial_loss(pred_fake_Y, torch.zeros_like(pred_fake_Y))) / 2
        loss_D_Y.backward()
        optimizer_D_Y.step()

        total_loss_G += loss_G.item()
        total_loss_D += (loss_D_X.item() + loss_D_Y.item())
        num_batches += 1

    # zip() stops at the shorter loader, so average over the batches actually seen
    # (len(loader_X) over-counts whenever loader_Y is shorter).
    avg_loss_G = total_loss_G / max(num_batches, 1)
    avg_loss_D = total_loss_D / max(num_batches, 1)

    # Log losses to Weights & Biases
    wandb.log({"Generator Loss": avg_loss_G,
               "Discriminator Loss": avg_loss_D})

    # Checkpoint every 10 completed epochs (so the final epoch is included), saving
    # all four networks. F shares optimizer_G with G.
    if (epoch + 1) % 10 == 0:
        for name, model, opt in (("G", G, optimizer_G), ("F", F, optimizer_G),
                                 ("D_X", D_X, optimizer_D_X), ("D_Y", D_Y, optimizer_D_Y)):
            save_checkpoint(epoch, model, opt, filename=f"{name}_epoch_{epoch}.pth")

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss G: {avg_loss_G:.4f}, Loss D: {avg_loss_D:.4f}")

from pytorch_fid import fid_score

def compute_fid(real_path, fake_path, batch_size=50, dims=2048):
    """Compute, print and return the FID between two folders of images.

    Args:
        real_path: directory of real images, passed straight to
            ``pytorch_fid.fid_score.calculate_fid_given_paths``.
        fake_path: directory of generated images, passed the same way.
        batch_size: images per Inception forward pass (default 50).
        dims: Inception feature dimensionality passed to pytorch-fid (default 2048).

    Returns:
        The FID value reported by pytorch-fid (lower means more similar sets).
    """
    fid_value = fid_score.calculate_fid_given_paths([real_path, fake_path],
                                                    batch_size=batch_size,
                                                    device=device,
                                                    dims=dims)
    print(f"FID Score: {fid_value}")
    return fid_value

# Example usage: once generated images have been written to disk, score them
# against the real images.
compute_fid(real_path='path/to/real/images', fake_path='path/to/generated/images')

import torchvision.utils as vutils

def show_grid(real_images, fake_images, title):
    """Show a batch of real images followed by generated ones in a single grid.

    Both batches are concatenated along the batch dimension (so they must share
    C, H, W). The result is then detached and moved to the CPU: imshow converts via
    NumPy, which rejects CUDA tensors and tensors carrying autograd history. Finally
    the images are tiled with ``make_grid`` using min-max normalisation.
    """
    grid = vutils.make_grid(
        torch.cat((real_images, fake_images)).detach().cpu(),
        padding=2,
        normalize=True,
    )
    _, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.set_title(title)
    ax.imshow(grid.permute(1, 2, 0))
    plt.show()

# Visual check on one batch from loader_X. no_grad because this is inference only;
# results are moved to the CPU because matplotlib cannot display device tensors.
with torch.no_grad():
    real_batch = next(iter(loader_X)).to(device)
    fake_batch = G(real_batch)
show_grid(real_batch.cpu(), fake_batch.cpu(), "Real vs Generated Images")