# Setup and Imports 

In [None]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch
import numpy as np
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import torchvision

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_cuda_visible = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_visible else torch.device("cpu")
device

In [None]:
batch_size = 128
workers = 8
# Both networks are built for 64x64 images (Generator output / Discriminator input).
image_size = 64

dataset = dset.ImageFolder(
    root="data",
    transform=transforms.Compose([
        # Resize the shorter side, then center-crop, so every sample is exactly
        # image_size x image_size; otherwise the Discriminator's output shape
        # (and batching of mixed-size images) breaks.
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map [0, 1] -> [-1, 1] to match the generator's Tanh output range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
#for batch_idx, (data, labels) in enumerate(dataloader):
   # print(f"Batch {batch_idx + 1}:")
   # print(f"  - Data shape (images): {data.shape}")
   # print(f"  - Labels shape: {labels.shape}") 

   # if batch_idx == 2:
   #     break

# Model Architecture 

In [None]:
import torch.nn as nn

class Discriminator(nn.Module):
    """DCGAN discriminator mapping 64x64 images to one probability per image.

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
    while the channel count grows from discriminator_features to
    discriminator_features*8. A final 4x4 convolution collapses the map to a
    single logit, and a Sigmoid turns it into a probability.

    Args:
        color_channels: number of input image channels.
        discriminator_features: channel width of the first convolution.
    """

    def __init__(self, color_channels, discriminator_features):
        super().__init__()
        width = discriminator_features

        # First stage has no BatchNorm: (color_channels, 64, 64) -> (width, 32, 32).
        stages = [
            nn.Conv2d(color_channels, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
        ]
        # Three downsampling stages, each doubling channels and halving H/W:
        # (width, 32, 32) -> (width*2, 16, 16) -> (width*4, 8, 8) -> (width*8, 4, 4)
        for mult in (1, 2, 4):
            stages += [
                nn.Conv2d(width * mult, width * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult * 2),
                nn.LeakyReLU(0.2),
            ]
        # (width*8, 4, 4) -> (1, 1, 1), squashed into (0, 1).
        stages += [
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.dis = nn.Sequential(*stages)

    def forward(self, input):
        # Flatten (N, 1, 1, 1) to (N,) so it lines up with per-sample labels.
        return self.dis(input).view(-1)

In [None]:
class Generator(nn.Module):
    """DCGAN generator mapping noise (N, latent_dim, 1, 1) to images (N, color_channels, 64, 64).

    Five transposed convolutions upsample 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels
    while the channel count shrinks from generator_features*8 to
    generator_features. A final Tanh keeps outputs in [-1, 1].

    Args:
        latent_dim: number of channels of the input noise tensor.
        generator_features: channel width of the last hidden stage.
        color_channels: number of output image channels.
    """

    def __init__(self, latent_dim, generator_features, color_channels):
        super().__init__()
        feats = generator_features

        # Project the 1x1 latent to a (feats*8, 4, 4) feature map.
        layers = [
            nn.ConvTranspose2d(latent_dim, feats * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feats * 8),
            nn.ReLU(),
        ]
        # Three upsampling stages, each doubling H/W and halving channels:
        # (feats*8, 4, 4) -> (feats*4, 8, 8) -> (feats*2, 16, 16) -> (feats, 32, 32)
        for mult in (8, 4, 2):
            in_ch, out_ch = feats * mult, feats * mult // 2
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        # Final upsample to (color_channels, 64, 64), bounded to [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(feats, color_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        return self.gen(x)

In [None]:
def generate_and_save_fid_samples(generator, num_samples=2000, dir="fid_samples",
                                  *, batch_size=128, latent_dim=100, device=None):
    """Write ``num_samples`` generated images to ``dir`` as PNG files for FID scoring.

    Files are named ``sample_<idx>.png`` for ``idx`` in ``range(num_samples)``.
    Existing files with those names are overwritten.

    Args:
        generator: module mapping ``(N, latent_dim, 1, 1)`` noise to images in
            ``[-1, 1]`` (e.g. ending in Tanh).
        num_samples: number of images to write.
        dir: output directory; created if missing.
        batch_size: images generated per forward pass. This is independent of the
            notebook's global ``batch_size``, which the training loop may change.
        latent_dim: noise channel count (100 for the generator in this notebook).
        device: device to sample noise on. Defaults to the device of the
            generator's first parameter, or the CPU if it has no parameters.

    The generator runs in eval mode. Its previous train/eval mode is restored
    afterwards, even if saving fails.
    """
    import os

    os.makedirs(dir, exist_ok=True)

    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Step through the sample indices so exactly num_samples images are
            # generated; the last batch is trimmed instead of wasted.
            for start in range(0, num_samples, batch_size):
                count = min(batch_size, num_samples - start)
                z = torch.randn(count, latent_dim, 1, 1, device=device)
                # Fixed affine map [-1, 1] -> [0, 1]. Per-image min-max
                # normalization would stretch each image's contrast and
                # distort the colour statistics that FID compares.
                images = ((generator(z) + 1) / 2).clamp(0, 1)
                for offset, img in enumerate(images):
                    save_image(img, os.path.join(dir, f"sample_{start + offset}.png"))
    finally:
        generator.train(was_training)

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, meant to be passed to ``nn.Module.apply``.

    Layers whose class name contains "Conv" (covering Conv2d and
    ConvTranspose2d) get weights drawn from N(0, 0.02). Layers whose class name
    contains "BatchNorm" get weights drawn from N(1, 0.02) and zero bias. All
    other modules are left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
def gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty: mean over the batch of (||grad critic(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample interpolation between ``real`` and ``fake``.
    The penalty is built with ``create_graph=True``, so it can be backpropagated
    into the critic's parameters. Note: the visible training loop does not call
    this function.

    Args:
        critic: callable returning one score per sample.
        real: batch of real samples, shape ``(N, ...)``.
        fake: batch of generated samples, same shape as ``real``; may be detached.
        device: device on which to draw the interpolation weights.

    Returns:
        Scalar tensor holding the penalty.
    """
    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over every remaining dimension
    # (works for image batches and for flat feature batches alike).
    epsilon = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=device)
    interpolated = epsilon * real + (1 - epsilon) * fake
    if not interpolated.requires_grad:
        # real/fake carry no graph (e.g. fake was detached), so interpolated is
        # a leaf; it must require grad for autograd.grad to differentiate w.r.t. it.
        interpolated.requires_grad_(True)

    critic_scores = critic(interpolated)

    # create_graph=True makes the penalty differentiable and also retains the graph.
    gradients = torch.autograd.grad(
        outputs=critic_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_scores),
        create_graph=True,
    )[0]

    gradients = gradients.flatten(1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty

# Training 

In [None]:
# Generator: 100-channel 1x1 noise -> 3x64x64 images; discriminator on 3-channel images.
gnet = Generator(latent_dim=100, generator_features=64, color_channels=3).to(device)
dnet = Discriminator(color_channels=3, discriminator_features=64).to(device)

# DCGAN weight initialisation, generator first, then discriminator.
gnet.apply(weights_init)
dnet.apply(weights_init)

# Adam with beta1 = 0.5 for both networks; the discriminator uses a 4x larger
# learning rate than the generator.
adam_betas = (0.5, 0.999)
g_optim = torch.optim.Adam(gnet.parameters(), lr=1e-4, betas=adam_betas)
d_optim = torch.optim.Adam(dnet.parameters(), lr=4e-4, betas=adam_betas)

# Multiply each learning rate by 0.99 on every scheduler step.
lr_decay = 0.99
g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optim, gamma=lr_decay)
d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optim, gamma=lr_decay)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
lossFn = nn.BCELoss()

In [None]:
import os
os.makedirs("samples", exist_ok=True)
os.makedirs("generated_fid_samples", exist_ok=True)

from tqdm.auto import tqdm

epochs = 100

# Per-epoch mean losses: column 0 = discriminator, column 1 = generator.
losses = np.zeros((epochs, 2))
# Per-epoch discriminator accuracy: column 0 = fraction of real images scored
# > 0.5, column 1 = fraction of generated images scored < 0.5.
decisions = np.zeros((epochs, 2))

# Fixed noise so the per-epoch sample grids are directly comparable.
fixed_noise = torch.randn(64, 100, 1, 1).to(device)

num_batches = len(dataloader)

for epoch in tqdm(range(epochs)):

    d_loss_epoch = 0
    g_loss_epoch = 0
    d_real_epoch = 0
    d_fake_epoch = 0

    for X, _ in dataloader:

        real_images = X.to(device)
        # Local name on purpose: the global `batch_size` is read elsewhere, and
        # the final batch of an epoch may be smaller.
        n = real_images.size(0)

        # Smoothed targets: 0.9 for real, 0.1 for fake.
        real_labels = torch.full((n,), 0.9, device=device)
        fake_labels = torch.full((n,), 0.1, device=device)

        # ---- Discriminator step ----
        # detach(): the discriminator loss must not backpropagate through the generator.
        fake_images = gnet(torch.randn(n, 100, 1, 1, device=device)).detach()

        # dnet returns shape (n,), matching the label tensors.
        pred_real = dnet(real_images)
        pred_fake = dnet(fake_images)

        # nn.BCELoss takes (input=predictions, target=labels).
        d_loss_real = lossFn(pred_real, real_labels)
        d_loss_fake = lossFn(pred_fake, fake_labels)
        d_loss_combined = d_loss_real + d_loss_fake

        d_optim.zero_grad()
        d_loss_combined.backward()
        d_optim.step()

        # Accuracy measured on the same (pre-update) predictions for real and fake.
        d_real_epoch += (pred_real > 0.5).float().mean().item()
        d_fake_epoch += (pred_fake < 0.5).float().mean().item()

        # ---- Generator step ----
        # Fresh samples scored by the just-updated discriminator; the generator
        # is trained to make them look real.
        gen_images = gnet(torch.randn(n, 100, 1, 1, device=device))
        g_loss = lossFn(dnet(gen_images), real_labels)

        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        d_loss_epoch += d_loss_combined.item()
        g_loss_epoch += g_loss.item()

    g_scheduler.step()
    d_scheduler.step()

    losses[epoch] = (d_loss_epoch / num_batches, g_loss_epoch / num_batches)
    decisions[epoch] = (d_real_epoch / num_batches, d_fake_epoch / num_batches)

    # Eval mode: BatchNorm uses running statistics, so they are not updated
    # with the fixed-noise batch (train-mode BN updates them even under no_grad).
    gnet.eval()
    with torch.no_grad():
        snapshot = gnet(fixed_noise).cpu()
    gnet.train()
    save_image(snapshot, f'samples/epoch_{epoch+1}.png', nrow=8, normalize=True)

    if epoch % 5 == 0:
        generate_and_save_fid_samples(gnet, num_samples=2000, dir="generated_fid_samples")

        metrics_dict = calculate_metrics(
            # NOTE(review): presumably the folder of real training images
            # (ImageFolder root "data" with class folder "data") — verify.
            input1="data/data",
            input2="generated_fid_samples",
            cuda=torch.cuda.is_available(),
            fid=True,
            isc=False,
            kid=False,
            verbose=False,
        )

        fid = metrics_dict['frechet_inception_distance']
        print(f"Epoch {epoch}: FID={fid:.2f}, LR={g_optim.param_groups[0]['lr']:.2e}")

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, figsize=(18, 5))

# Per-epoch mean losses (column 0: discriminator, column 1: generator).
ax[0].plot(losses)
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].set_title('Model loss')
ax[0].legend(['Discriminator', 'Generator'])

# Generator loss against discriminator loss, one point per epoch. Every epoch
# is plotted at a visible alpha; there are only `epochs` points.
ax[1].plot(losses[:, 0], losses[:, 1], 'k.', alpha=.5)
ax[1].set_xlabel('Discriminator loss')
ax[1].set_ylabel('Generator loss')
ax[1].set_title('Loss trajectory')

# `decisions` stores accuracies, not probabilities: column 0 is the fraction
# of real images scored > 0.5, column 1 the fraction of fakes scored < 0.5.
ax[2].plot(decisions)
ax[2].set_xlabel('Epochs')
ax[2].set_ylabel('Fraction classified correctly')
ax[2].set_title('Discriminator accuracy')
ax[2].legend(['Real', 'Fake'])

plt.show()

In [None]:
# Display a 3x4 grid of fresh samples from the trained generator.
gnet.eval()
with torch.no_grad():
    noise = torch.randn(12, 100, 1, 1).to(device)
    # Tanh range [-1, 1] -> [0, 1], then channels-last (N, H, W, C) for imshow.
    fake_data = ((gnet(noise).cpu() + 1) / 2).permute(0, 2, 3, 1)

    fig, axs = plt.subplots(3, 4, figsize=(8, 6))
    for ax, image in zip(axs.flatten(), fake_data):
        ax.imshow(image)
        ax.axis('off')
    plt.show()