### Imports

In [None]:
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # The device name can only be queried when a CUDA device actually exists;
    # calling it unconditionally raises on CPU-only machines.
    print(device, torch.cuda.get_device_name(0))
else:
    print(device)

### Loading the data and defining utilities

In [None]:
# Notebook-wide settings for the data pipeline.
im_size = 64                   # nominal side length of the square images, in pixels
batch_size = 512               # number of samples per DataLoader batch
data_dir = '../data/cats_GAN'  # root folder handed to ImageFolder

from torchvision import transforms, datasets


def load_transformed_dataset(im_size: int=64):
    """Build the image dataset rooted at the module-level ``data_dir``.

    Every image is resized to ``im_size`` x ``im_size``, randomly mirrored
    horizontally, converted to a tensor and finally mapped through
    ``x * 2 - 1`` (``denorm`` applies the inverse mapping).

    Args:
        im_size: Target height and width, in pixels.

    Returns:
        A ``datasets.ImageFolder`` that applies the pipeline above.
    """
    steps = [
        transforms.Resize((im_size, im_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1),
    ]
    pipeline = transforms.Compose(steps)
    return datasets.ImageFolder(data_dir, transform=pipeline)

def denorm(image):
    """Invert the ``x * 2 - 1`` scaling, i.e. return ``(image + 1) / 2``."""
    shifted = image + 1
    return shifted / 2

def show_sample_images(images):
    """Plot up to 64 images from a batch in a gap-free grid.

    The grid has 8 rows whenever at least 8 images are available; smaller
    batches use one row per image so the column count never drops to zero.
    Images that do not fill a complete column are left out.

    Args:
        images: Tensor of shape ``[N, C, H, W]`` scaled as produced by
            ``load_transformed_dataset`` (``denorm`` is applied before
            plotting).
    """
    n_samples = min(images.size(0), 64)
    if n_samples == 0:
        # Nothing to draw; an empty grid would make plt.subplots raise.
        return

    n_row = min(8, n_samples)
    n_col = n_samples // n_row

    # Move to host memory once so plotting also works for GPU tensors.
    images = denorm(images[:n_samples]).detach().cpu()
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    _, axes = plt.subplots(n_row, n_col, figsize=(8, 8), squeeze=False)
    plt.suptitle("Sample cats from the dataset", fontsize=16)

    for i, ax in enumerate(axes.flat):
        # imshow expects channels last: [C, H, W] -> [H, W, C].
        ax.imshow(images[i].permute(1, 2, 0).numpy())
        ax.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

def cross_entropy_loss(logits, labels):
    """Binary cross-entropy between raw ``logits`` and target ``labels``.

    Thin wrapper around ``F.binary_cross_entropy_with_logits`` with its
    default settings.
    """
    loss = F.binary_cross_entropy_with_logits(input=logits, target=labels)
    return loss

def get_noise(n_batch, latent_size):
    """Draw standard-normal noise of shape ``[n_batch, latent_size]``.

    The tensor is created directly on the module-level ``device``.
    """
    shape = (n_batch, latent_size)
    return torch.randn(shape, device=device)

def visualize_images(images, n_rows=8, n_cols=8, title=None):
    """Show a batch of images in an ``n_rows`` x ``n_cols`` grid.

    Args:
        images: Tensor of shape ``[N, C, H, W]`` with values nominally in
            ``[-1, 1]``; it is rescaled to ``[0, 1]`` and clamped.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        title: Optional figure title.

    If fewer than ``n_rows * n_cols`` images are supplied, the remaining
    cells are left blank instead of raising ``IndexError``.
    """
    # Rescale to [0, 1] and clamp, then move to host memory once for plotting.
    images = (images / 2 + 0.5).clamp(0, 1).detach().cpu()
    cmap = 'viridis'  # NOTE(review): presumably only matters for single-channel data

    # figsize is (width, height), i.e. (columns, rows). squeeze=False keeps
    # `axes` 2-D so single-row / single-column grids work as well.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    n_available = images.shape[0]
    # axes.flat walks the grid row by row, matching index i * n_cols + j.
    for idx, ax in enumerate(axes.flat):
        ax.axis("off")
        if idx >= n_available:
            continue
        # imshow expects channels last: [C, H, W] -> [H, W, C].
        ax.imshow(images[idx].permute(1, 2, 0).numpy(), cmap=cmap)
    plt.subplots_adjust(wspace=0, hspace=0)
    if title:
        plt.suptitle(title, fontsize=16)
    plt.show()

def plot_learning_curve(g_loss_history, d_g_z_loss_history, d_x_loss_history, d_loss_history):
    """Draw the four recorded loss histories on one shared figure.

    Args:
        g_loss_history: Generator loss values.
        d_g_z_loss_history: Discriminator loss values on generated images.
        d_x_loss_history: Discriminator loss values on real images.
        d_loss_history: Combined discriminator loss values.
    """
    curves = (
        (g_loss_history, 'Generator'),
        (d_g_z_loss_history, 'Discriminator (Generated)'),
        (d_x_loss_history, 'Discriminator (Real)'),
        (d_loss_history, 'Discriminator Total'),
    )

    plt.figure(figsize=(15, 8))
    for history, name in curves:
        plt.plot(history, label=name, alpha=0.7)

    # Axis labels and title are kept in the notebook's original language.
    plt.xlabel('Ilość iteracji')
    plt.ylabel('Wartość funkcji straty')
    plt.legend()
    plt.title('Krzywe uczenia')
    plt.grid(True)
    plt.show()

# Build the dataset at the configured resolution (previously the function's
# default was used, so changing `im_size` had no effect on the data).
data = load_transformed_dataset(im_size)
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

# Preview the first shuffled batch only.
for batch in dataloader:
    images, _ = batch
    show_sample_images(images)
    break

### Generator and Discriminator models

In [None]:
import torch.nn as nn

class Generator(nn.Module):
    """Transposed-convolution generator producing 64x64 images.

    A latent vector is projected to a ``512 x 4 x 4`` feature map and then
    upsampled by four stride-2 transposed convolutions (4 -> 8 -> 16 -> 32
    -> 64). The output passes through ``tanh`` so generated pixels lie in
    ``[-1, 1]``, the same range the real images are scaled to.

    Args:
        latent_size: Dimensionality of the input noise vector.
        n_c: Number of channels in the generated image.
    """

    def __init__(self, latent_size=128, n_c=3):
        super().__init__()
        # in: [{batch_size}, {latent_size}]
        self.linear = nn.Linear(latent_size, 512 * 4 * 4)
        self.norm_1 = nn.BatchNorm2d(512)
        self.relu_1 = nn.LeakyReLU(0.2)
        # out: [{batch_size}, 512, 4, 4]

        self.conv_layer_2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.norm_2 = nn.BatchNorm2d(256)
        self.relu_2 = nn.LeakyReLU(0.2)
        # out: [{batch_size}, 256, 8, 8]

        self.conv_layer_3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.norm_3 = nn.BatchNorm2d(128)
        self.relu_3 = nn.LeakyReLU(0.2)
        # out: [{batch_size}, 128, 16, 16]

        self.conv_layer_4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.norm_4 = nn.BatchNorm2d(64)
        self.relu_4 = nn.LeakyReLU(0.2)
        # out: [{batch_size}, 64, 32, 32]

        self.output_layer = nn.ConvTranspose2d(64, n_c, kernel_size=4, stride=2, padding=1)
        # Parameter-free squashing to [-1, 1]; adds no state_dict entries.
        self.output_activation = nn.Tanh()
        # out: [{batch_size}, {n_c}, 64, 64]

    def forward(self, input_layer):
        """Generate images from noise of shape ``[batch, latent_size]``.

        Returns:
            Tensor of shape ``[batch, n_c, 64, 64]`` with values in ``[-1, 1]``.
        """
        h = self.linear(input_layer)
        # Reinterpret the flat projection as a 512-channel 4x4 feature map.
        h = h.view(-1, 512, 4, 4)
        h = self.norm_1(h)
        h = self.relu_1(h)

        h = self.conv_layer_2(h)
        h = self.norm_2(h)
        h = self.relu_2(h)

        h = self.conv_layer_3(h)
        h = self.norm_3(h)
        h = self.relu_3(h)

        h = self.conv_layer_4(h)
        h = self.norm_4(h)
        h = self.relu_4(h)

        x = self.output_layer(h)
        # Without this the output is unbounded while real data lives in [-1, 1].
        x = self.output_activation(x)
        return x


class Discriminator(nn.Module):
    """Convolutional discriminator for 64x64 images, returning one logit each.

    Three stride-2 convolutions reduce the spatial size 64 -> 32 -> 16 -> 8,
    and a final 8x8 convolution collapses it to 1x1, so inputs are expected
    to be 64x64.

    Args:
        image_size: Despite the name, this sets the base channel width of
            the convolution stack (channels grow as 1x, 2x, 4x this value).
        n_c: Number of channels in the input images.
    """

    def __init__(self, image_size, n_c=3):
        super().__init__()
        # in: [{batch_size}, {n_c}, 64, 64]
        # The first block has no normalisation layer: forward never applied
        # one, so the previously registered (unused) BatchNorm was removed.
        self.conv_1 = nn.Conv2d(n_c, image_size, kernel_size=4, stride=2, padding=1)
        self.leaky_relu_1 = nn.LeakyReLU(0.2)
        # out [{batch_size}, image_size, 32, 32]

        self.conv_2 = nn.Conv2d(image_size, 2 * image_size, kernel_size=4, stride=2, padding=1)
        self.norm_2 = nn.BatchNorm2d(2 * image_size)
        self.leaky_relu_2 = nn.LeakyReLU(0.2)
        # out [{batch_size}, 2 * image_size, 16, 16]

        self.conv_3 = nn.Conv2d(2 * image_size, 4 * image_size, kernel_size=4, stride=2, padding=1)
        self.norm_3 = nn.BatchNorm2d(4 * image_size)
        self.leaky_relu_3 = nn.LeakyReLU(0.2)
        # out [{batch_size}, 4 * image_size, 8, 8]

        # The unused BatchNorm formerly declared here was sized for
        # 8 * image_size channels and could never have matched this
        # single-channel output, so it was dropped as dead code.
        self.conv_4 = nn.Conv2d(4 * image_size, 1, kernel_size=8, stride=1, padding=0)
        self.leaky_relu_4 = nn.LeakyReLU(0.2)
        # out [{batch_size}, 1, 1, 1]

        self.flatten = nn.Flatten()
        self.output_layer = nn.Linear(1, 1)

    def forward(self, input_layer):
        """Score a batch of images.

        Args:
            input_layer: Tensor of shape ``[batch, n_c, 64, 64]``.

        Returns:
            Tensor of shape ``[batch, 1]`` holding raw (pre-sigmoid) logits.
        """
        h = self.conv_1(input_layer)
        h = self.leaky_relu_1(h)

        h = self.conv_2(h)
        h = self.norm_2(h)
        h = self.leaky_relu_2(h)

        h = self.conv_3(h)
        h = self.norm_3(h)
        h = self.leaky_relu_3(h)

        h = self.conv_4(h)
        h = self.leaky_relu_4(h)

        h = self.flatten(h)
        x = self.output_layer(h)
        return x

### Initialize parameters for GAN model training

In [None]:
# Image geometry: height, width, channels.
n_h, n_w, n_c = [im_size, im_size, 3]
latent_size = 128  # dimensionality of the generator's latent input
n_batch = 1000     # number of latent vectors drawn per get_noise call
lr = 2.5e-4        # learning rate handed to both Adam optimizers
epochs = 10000     # number of passes over the dataloader in the training loop

# Resolve the device before any tensor is created on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# get_noise already allocates on `device`, so no extra .to(device) is needed.
noise_vector = get_noise(n_batch, latent_size)

#### Training loop

In [None]:
generator = Generator(latent_size=latent_size).to(device)
discriminator = Discriminator(im_size).to(device)

g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Per-iteration loss traces consumed by plot_learning_curve.
g_loss_history = []
d_g_z_loss_history = []
d_x_loss_history = []
d_loss_history = []

for epoch in range(epochs):
    if epoch % 100 == 0:
        # Preview only: no autograd graph is needed just to plot samples.
        with torch.no_grad():
            generated_images = generator(get_noise(n_batch, latent_size))
        visualize_images(generated_images, title=f"Wygenerowane obrazy po {epoch} epokach")
        if epoch != 0:
            # Losses exist only after at least one epoch has been trained.
            plot_learning_curve(g_loss_history, d_g_z_loss_history, d_x_loss_history, d_loss_history)
            print(f"Epoka [{epoch + 1}/{epochs}], g_loss: {g_loss.item()}, d_g_z_loss: {d_g_z_loss.item()},  d_x_loss: {d_x_loss.item()}, d_loss: {d_loss.item()}")

    for batch in tqdm(dataloader, ncols=80, leave=False):
        # A dedicated name keeps the dataset object `data` from being clobbered.
        real_images, _ = batch
        real_images = real_images.to(device)

        # Draw a fresh latent batch every iteration; reusing one batch for the
        # whole epoch would show both networks the same fakes over and over.
        noise_vector = get_noise(n_batch, latent_size)

        # ---- Discriminator step ----
        d_optimizer.zero_grad()
        # The generator is not updated here, so skip building its graph.
        with torch.no_grad():
            fake_images = generator(noise_vector)
        d_g_z = discriminator(fake_images)
        d_x = discriminator(real_images)

        d_g_z_loss = cross_entropy_loss(d_g_z, torch.zeros_like(d_g_z))
        d_x_loss = cross_entropy_loss(d_x, torch.ones_like(d_x))
        d_loss = (d_g_z_loss + d_x_loss) / 2

        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        g_optimizer.zero_grad()
        g_z = generator(noise_vector)
        d_g_z = discriminator(g_z)
        # The generator is rewarded when its fakes are classified as real.
        g_loss = cross_entropy_loss(d_g_z, torch.ones_like(d_g_z))

        g_loss.backward()
        g_optimizer.step()

        g_loss_history.append(g_loss.item())
        d_g_z_loss_history.append(d_g_z_loss.item())
        d_x_loss_history.append(d_x_loss.item())
        d_loss_history.append(d_loss.item())


In [None]:
# Inference only: skip autograd bookkeeping while sampling the final grid.
with torch.no_grad():
    images = generator(get_noise(n_batch, latent_size))
visualize_images(images, title=f"Wygenerowana seria obrazów po {epochs} epokach")