<a href="https://colab.research.google.com/github/Calcifer777/learn-deep-learning/blob/main/generative-models/samples/GAN_fashion_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import logging
from pathlib import Path
from typing import Tuple

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader
from torch.optim import AdamW, Optimizer
from torchvision import datasets
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms import (
    Compose,
    Normalize,
    ToTensor,
)

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"{DEVICE =}")

In [None]:
# Shell cleanup: when an output folder already exists, delete the files
# inside it so artifacts from a previous run do not mix with this one.
!test -e samples && rm samples/*
!test -e models && rm models/*

# Params

In [None]:
# --- Hyper-parameters -------------------------------------------------------
LATENT_DIM = 64  # was 32
INPUT_DIM = 28
LR_G = 0.0002
LR_D = 0.0002
EPOCHS = 50
BATCH_SIZE = 16  # was 32
EPS = 1e-5
BASE_DIM_GEN = 7

# --- Output locations (created up front so later saves cannot fail) ---------
PATH_MODELS = Path("./models/")
PATH_SAMPLES = Path("./samples/")
for _out_dir in (PATH_MODELS, PATH_SAMPLES):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Model

In [None]:
class Generator(nn.Module):
    """Turn latent vectors into single-channel images squashed by ``Tanh``.

    A linear layer projects each latent vector of size ``LATENT_DIM`` onto a
    ``256 x BASE_DIM_GEN x BASE_DIM_GEN`` feature map. Two upsampling
    convolution stages then double the spatial size each (the convolutions
    use "same" padding), and a final convolution stage reduces the result to
    one channel.
    """

    def __init__(self):
        super().__init__()
        width = 256
        side = BASE_DIM_GEN
        # Latent vector -> flat features -> (channels, height, width) map.
        stem = [
            nn.Linear(LATENT_DIM, side * side * width),
            nn.LeakyReLU(0.2),
            nn.Unflatten(-1, (width, side, side)),
        ]
        # Halve the channel count while doubling the resolution, then collapse
        # to a single output channel.
        stages = [
            Generator.cnn_block(width, width // 2, upsample=True),
            Generator.cnn_block(width // 2, width // 4, upsample=True),
            Generator.cnn_block(width // 4, 1, final_layer=True),
        ]
        self.net = nn.Sequential(*stem, *stages)

    @staticmethod
    def cnn_block(
        in_ch: int,
        out_ch: int,
        upsample: bool = False,
        final_layer: bool = False,
    ):
        """Build one convolution stage of the generator.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            upsample: When true, insert a 2x nearest-neighbour upsampling
                right after the convolution.
            final_layer: When true, finish with ``Tanh``; otherwise finish
                with ``LeakyReLU(0.2)``.

        Returns:
            An ``nn.Sequential`` ordered as convolution, optional upsampling,
            batch normalisation, activation.
        """
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding="same")
        resize = [nn.Upsample(scale_factor=2, mode="nearest")] if upsample else []
        norm = nn.BatchNorm2d(out_ch)
        if final_layer:
            activation = nn.Tanh()
        else:
            activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, *resize, norm, activation)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the full generator stack."""
        return self.net(x)

In [None]:
# Shape smoke test: feed a random batch through the generator one layer at a
# time and print the tensor shape after every layer.
test_generator = Generator().to(DEVICE)

x = torch.rand((BATCH_SIZE, LATENT_DIM)).to(DEVICE)

for layer in test_generator.net:
    x = layer(x)
    print(x.shape)

In [None]:
class Discriminator(nn.Module):
    """Score single-channel images with a value in ``[0, 1]``.

    Two strided convolution blocks extract features, which are flattened and
    mapped to one sigmoid-activated output per image.
    """

    def __init__(self):
        super().__init__()
        features = [
            self.conv_block(1, 64),
            self.conv_block(64, 256),
        ]
        head = [
            nn.Flatten(),
            # 4096 = 256 channels * 4 * 4: a 28x28 input shrinks to 12x12 and
            # then 4x4 under two unpadded 5x5 convolutions with stride 2.
            nn.Linear(4096, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*features, *head)

    @staticmethod
    def conv_block(in_ch: int, out_ch: int, final_layer: bool = False):
        """Build a strided convolution block.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            final_layer: When true, omit the trailing ``LeakyReLU(0.2)``.

        Returns:
            An ``nn.Sequential`` of convolution, batch normalisation and,
            unless ``final_layer`` is set, the activation.
        """
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2),
            nn.BatchNorm2d(out_ch),
        ]
        if not final_layer:
            layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor):
        """Return the sigmoid score for every image in ``x``."""
        return self.net(x)



In [None]:
# Shape smoke test: run a random 28x28 batch through the discriminator layer
# by layer, printing the shape after each one.
discriminator_test = Discriminator().to(DEVICE)
x = torch.rand((16, 1, 28, 28)).to(DEVICE)

for layer in discriminator_test.net:
    x = layer(x)
    print(x.shape)

# Last expression of the cell: shows the mean score of the random batch.
x.mean()

# Dataset

In [None]:
# Untransformed copy of the training split, used only to inspect raw examples.
ds_sample = datasets.FashionMNIST(
    root="./data/fashion-mnist/",
    train=True,
    download=True,
)

In [None]:
# Inspect the first training example: print its label and display its image.
sample_img, sample_label = ds_sample[0]
print(f"{sample_label = }")
plt.imshow(sample_img)

In [None]:
def rescale(img: Image.Image):
    """Convert a PIL image to a tensor and rescale its pixel values.

    Each value ``v`` becomes ``(v - 127.5) / 127.5``, so 0 maps to -1 and
    255 maps to +1.
    """
    half_range = 127.5
    pixels = pil_to_tensor(img)
    return (pixels - half_range) / half_range

In [None]:
# Training split with every image passed through `rescale` on access.
ds = datasets.FashionMNIST(
    root="./data/fashion-mnist/",
    download=True,
    train=True,
    transform=Compose([rescale]),
)

In [None]:
# Shuffled mini-batches of BATCH_SIZE examples drawn from the training set.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

# Train

In [None]:
# Standard-normal prior from which every latent vector is drawn.
prior = torch.distributions.Normal(loc=0, scale=1)


def sample(sample_shape: torch.Size = torch.Size()):
    """Draw a tensor of shape ``sample_shape`` from ``prior``.

    With the default empty shape a single scalar draw is returned.
    """
    draws = prior.sample(sample_shape)
    return draws

In [None]:
# Binary cross-entropy averaged over the batch; shared by both training steps.
loss_fn = nn.BCELoss(reduction="mean")

In [None]:
def train_discriminator(
    real_img: Tensor,
    generator: nn.Module,
    discriminator: nn.Module,
    optim_d: Optimizer,
) -> float:
    """Run one optimisation step of the discriminator.

    The discriminator is pushed to score ``real_img`` as 1 and freshly
    generated images as 0.

    Args:
        real_img: Batch of real images. It is not moved here, so it must
            already be on the discriminator's device.
        generator: Produces the fake images; it is not updated by this step.
        discriminator: Network being optimised.
        optim_d: Optimiser holding the discriminator parameters.

    Returns:
        The average of the real and fake BCE losses as a Python float.
    """
    optim_d.zero_grad()
    latent = sample((real_img.shape[0], LATENT_DIM)).to(DEVICE)
    # The generator only supplies inputs for this step, so skip building an
    # autograd graph for it instead of building one and detaching afterwards.
    with torch.no_grad():
        gen_img = generator(latent)
    # Keep probabilities strictly inside (0, 1) before taking logs in the loss.
    gen_img_probs = torch.clamp(discriminator(gen_img), EPS, 1 - EPS)
    real_img_probs = torch.clamp(discriminator(real_img), EPS, 1 - EPS)
    loss_real = loss_fn(real_img_probs, torch.ones_like(real_img_probs))
    loss_fake = loss_fn(gen_img_probs, torch.zeros_like(gen_img_probs))
    loss = (loss_real + loss_fake) / 2
    loss.backward()
    optim_d.step()
    return loss.item()

In [None]:
def train_generator(
    batch: Tensor,
    generator: nn.Module,
    discriminator: nn.Module,
    optim_g: Optimizer,
) -> float:
    """Run one optimisation step of the generator.

    The generator is pushed to produce images that the discriminator scores
    as real (target 1).

    Args:
        batch: Current real batch; only its first dimension is read, to
            decide how many images to generate.
        generator: Network being optimised.
        discriminator: Scores the generated images; ``optim_g`` is expected
            not to hold its parameters.
        optim_g: Optimiser holding the generator parameters.

    Returns:
        The BCE loss of the generated images against the "real" target, as a
        Python float.
    """
    optim_g.zero_grad()
    latent = sample((batch.shape[0], LATENT_DIM)).to(DEVICE)
    gen_img = generator(latent)
    # Keep probabilities strictly inside (0, 1) before taking logs in the loss.
    gen_img_probs = torch.clamp(discriminator(gen_img), EPS, 1 - EPS)
    # Build the target from the predictions so shape, dtype and device always
    # match, the same way train_discriminator builds its targets.
    loss = loss_fn(gen_img_probs, torch.ones_like(gen_img_probs))
    loss.backward()
    optim_g.step()
    return loss.item()

In [None]:
def train_step(
    batch: Tensor,
    generator: nn.Module,
    discriminator: nn.Module,
    optim_g: Optimizer,
    optim_d: Optimizer,
):
    """Update the discriminator, then the generator, on one batch.

    Both networks are switched to training mode first.

    Returns:
        A dict with the generator loss under ``"g"`` and the discriminator
        loss under ``"d"``.
    """
    for network in (generator, discriminator):
        network.train()

    loss_d = train_discriminator(batch, generator, discriminator, optim_d)
    loss_g = train_generator(batch, generator, discriminator, optim_g)

    return {"g": loss_g, "d": loss_d}

In [None]:
# Smoke test: run a single training step on one batch with throw-away models
# and optimisers, to check that the whole pipeline fits together.
sample_batch = next(iter(dl))
sample_batch_img, _ = sample_batch
sample_batch_img.shape

test_generator = Generator().to(DEVICE)
test_discriminator = Discriminator().to(DEVICE)
test_optim_g = AdamW(test_generator.parameters())
test_optim_d = AdamW(test_discriminator.parameters())

# The generator optimiser goes first, then the discriminator optimiser.
# Passing test_optim_d twice would leave the generator without any update.
train_step(
    sample_batch_img.to(DEVICE),
    test_generator,
    test_discriminator,
    test_optim_g,
    test_optim_d,
)

In [None]:
def train_epoch(
    dl: DataLoader,
    generator: nn.Module,
    discriminator: nn.Module,
    optim_g: Optimizer,
    optim_d: Optimizer,
):
    """Train both networks for one pass over ``dl``.

    Every batch is moved to ``DEVICE`` and handed to ``train_step``; the
    per-batch losses are logged at INFO level. Labels yielded by the loader
    are ignored.

    Returns:
        A dict with the per-batch average generator loss under ``"g"`` and
        discriminator loss under ``"d"``.
    """
    totals = {"g": 0.0, "d": 0.0}
    for batch_idx, (batch, _) in enumerate(tqdm(dl)):
        loss = train_step(
            batch.to(DEVICE),
            generator,
            discriminator,
            optim_g,
            optim_d,
        )
        for key in totals:
            totals[key] += loss[key]
        logging.info(
            f"Batch ({batch_idx}/{len(dl)})"
            f" - Loss_G: {loss['g']:.4f}"
            f" - Loss_D: {loss['d']:.4f}"
        )
    return {key: total / len(dl) for key, total in totals.items()}

In [None]:
# train_epoch(
#     dl,
#     generator,
#     discriminator,
#     optimizer_g,
#     optimizer_d,
# )

In [None]:
@torch.inference_mode()
def generate_samples(model: nn.Module, path_out: Path = None):
    """Plot a 5x5 grid of images generated from random latent vectors.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Generator mapping ``(N, LATENT_DIM)`` latents to images.
        path_out: Where to save the figure. When ``None`` the figure is shown
            instead of saved.
    """
    rows, cols = 5, 5

    model.eval()
    latent = sample((rows * cols, LATENT_DIM)).to(DEVICE)
    images = model(latent).detach().cpu().numpy()

    fig, axes = plt.subplots(rows, cols)
    for panel, image in zip(axes.flatten(), images):
        # Drop the leading axes of each image, then undo the
        # (v - 127.5) / 127.5 scaling applied to the training data.
        pixels = np.reshape(image, images.shape[2:]) * 127.5 + 127.5
        panel.imshow(pixels)  # , cmap='gray')
        panel.axis("off")

    if path_out is not None:
        fig.savefig(path_out, bbox_inches="tight")
    else:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
def train(
    dl: DataLoader,
    generator: nn.Module,
    discriminator: nn.Module,
    optim_g: Optimizer,
    optim_d: Optimizer,
):
    """Train the GAN for ``EPOCHS`` epochs.

    After every epoch the average losses are logged, both networks are
    checkpointed under ``PATH_MODELS`` (the same files are overwritten each
    epoch), and a sample grid is written to ``PATH_SAMPLES``.

    Returns:
        A dict with the per-epoch average losses as lists under
        ``"generator"`` and ``"discriminator"``.
    """
    losses = dict(
        generator=[],
        discriminator=[],
    )
    for epoch_id in range(EPOCHS):
        loss = train_epoch(
            dl,
            generator,
            discriminator,
            optim_g,
            optim_d,
        )
        logging.warning(
            f"Epoch ({epoch_id}/{EPOCHS})"
            f" - Loss_G: {loss['g']:.4f}"
            f" - Loss_D: {loss['d']:.4f}"
        )
        losses["generator"].append(loss['g'])
        losses["discriminator"].append(loss['d'])
        # Save each network to its own file; writing both to 'generator.pt'
        # would overwrite the generator weights with the discriminator's.
        torch.save(generator.state_dict(), PATH_MODELS / 'generator.pt')
        torch.save(discriminator.state_dict(), PATH_MODELS / 'discriminator.pt')
        # Generate samples
        generate_samples(generator, PATH_SAMPLES / f"epoch_{epoch_id}.png")
    return losses

In [None]:
# Fresh networks for the real training run, placed on the selected device.
generator, discriminator = Generator().to(DEVICE), Discriminator().to(DEVICE)

In [None]:
# One AdamW optimiser per network, each with its own learning rate.
optimizer_g = AdamW(generator.parameters(), lr=LR_G)
optimizer_d = AdamW(discriminator.parameters(), lr=LR_D)

In [None]:
# Full training run; `history` holds the per-epoch average losses.
history = train(
    dl=dl,
    generator=generator,
    discriminator=discriminator,
    optim_g=optimizer_g,
    optim_d=optimizer_d,
)

# Diagnostics

## Losses

In [None]:
# One curve per network, labelled so the legend is built from the plots.
for series_name in ("generator", "discriminator"):
    plt.plot(range(EPOCHS), history[series_name], label=series_name)
plt.legend()
plt.title("Loss")

## Samples

### Unconditional generation

In [None]:
# No output path is given, so the sample grid is shown inline.
generate_samples(generator)

### Sample interpolation

In [None]:
# Draw the two latent endpoints from the same prior helper used in training.
latents = sample((2, LATENT_DIM))
z1, z2 = torch.chunk(input=latents, chunks=2, dim=0)

# Interpolate. linspace includes both 0 and 1, so the strip starts exactly at
# z1 and ends exactly at z2 (an arange with step 0.1 would stop at 0.9).
weights = torch.linspace(0, 1, steps=10)
z_inter = torch.lerp(
    z1,
    z2,
    weights.unsqueeze(1),  # (steps, 1) broadcasts against the (1, LATENT_DIM) endpoints
)

generator.eval()
with torch.inference_mode():
    generated = generator(z_inter.to(DEVICE)).detach().cpu().numpy()

fig, axes = plt.subplots(1, weights.shape[0], figsize=(16, 9))

for panel, image in zip(axes.flatten(), generated):
    # Drop the leading axes of each image, then undo the
    # (v - 127.5) / 127.5 scaling applied to the training data.
    plottable_image = np.reshape(image, generated.shape[2:])
    plottable_image = (plottable_image * 127.5) + 127.5
    panel.imshow(plottable_image)
    panel.axis('off')

plt.show()
plt.close(fig)