In [2]:
try:
    import torchinfo
except:
    %pip install torchinfo jaxtyping einops datasets
    %pip install -U datasets

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting jaxtyping
  Downloading jaxtyping-0.3.3-py3-none-any.whl.metadata (7.8 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Downloading jaxtyping-0.3.3-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.9/55.9 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, torchinfo, jaxtyping
Successfully installed jaxtyping-0.3.3 torchinfo-1.8.0 wadler-lindig-0.1.7
Collecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading datasets-4.4.1-py3-none-any.whl (511 kB)
[2K   [90m━━━━━━━

In [3]:
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import einops
import torch as t
from torch.nn import BatchNorm2d, Conv2d, Linear, ReLU, Sequential, ConvTranspose2d
import torchinfo
import wandb
from datasets import load_dataset
from einops.layers.torch import Rearrange
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

# Pick the compute backend once for the whole notebook: prefer Apple's MPS, then
# CUDA, and fall back to the CPU when neither accelerator is available.
if t.backends.mps.is_available():
    _backend = "mps"
elif t.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"

device = t.device(_backend)

In [4]:
def get_dataset(image_size: int = 64, root: str = "./data") -> Dataset:
    """
    Builds the CelebA training split with the preprocessing used for the DCGAN.

    Every image is resized, center-cropped to a square of side `image_size`, converted
    to a tensor and normalized per channel with mean 0.5 and std 0.5 (which maps
    `ToTensor`'s [0, 1] range onto [-1, 1], the output range of a tanh generator).

    Args:
        image_size:
            side length (in pixels) of the square images produced by the transform.
            Defaults to 64, the value previously hard-coded here.
        root:
            directory the dataset is downloaded to and read from.

    Returns:
        The torchvision CelebA dataset (`split="train"`) with the transform attached.
        The dataset is downloaded into `root` if it is not already present.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    trainset = datasets.CelebA(
        root=root,
        split="train",
        download=True,
        transform=transform,
    )

    return trainset

In [5]:
class Tanh(nn.Module):
    """Elementwise hyperbolic tangent activation, mapping inputs into (-1, 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """
        Applies tanh to every element of `x`.

        Args:
            x: input tensor of any shape.

        Returns:
            A tensor with the same shape as `x`.
        """
        # The textbook form (e^x - e^-x) / (e^x + e^-x) overflows to inf / inf = NaN
        # once |x| is large (around 89 in float32), so use the stable built-in instead.
        return t.tanh(x)


class LeakyReLU(nn.Module):
    """Leaky ReLU activation: identity for positive inputs, a scaled copy otherwise."""

    def __init__(self, negative_slope: float = 0.01):
        """
        Args:
            negative_slope: factor applied to inputs that are not strictly positive.
        """
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x: Tensor) -> Tensor:
        """Returns `x` where `x > 0` and `negative_slope * x` everywhere else."""
        is_positive = x > 0
        scaled = self.negative_slope * x
        return t.where(is_positive, x, scaled)

    def extra_repr(self) -> str:
        """Shows the slope when the module is printed."""
        return "negative_slope={}".format(self.negative_slope)


class Sigmoid(nn.Module):
    """Elementwise logistic sigmoid activation, mapping inputs into (0, 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """
        Applies the sigmoid function to every element of `x`.

        Args:
            x: input tensor of any shape.

        Returns:
            A tensor with the same shape as `x`.
        """
        # With the naive 1 / (1 + exp(-x)), exp(-x) overflows to inf for very negative x.
        # The forward value is still 0, but autograd then multiplies 0 by inf and
        # produces NaN gradients. The built-in is stable in both directions.
        return t.sigmoid(x)

In [6]:
class Generator(nn.Module):
    def __init__(
        self,
        latent_dim_size: int = 100,
        img_size: int = 64,
        img_channels: int = 3,
        hidden_channels: list[int] = [128, 256, 512],
    ):
        """
        Generator half of the DCGAN (the diagram at the top of page 4 of the paper).

        A latent vector is linearly projected and reshaped into a small feature map, which
        is then passed through one transposed convolution per entry of `hidden_channels`.
        Each of those convolutions (kernel 4, stride 2, padding 1) doubles the spatial
        size, so `img_size` has to be divisible by 2 ** len(hidden_channels).

        Args:
            latent_dim_size:
                length of the latent vector fed to the generator
            img_size:
                side length of the (square) generated image
            img_channels:
                number of channels of the generated image (3 for RGB, 1 for grayscale)
            hidden_channels:
                channel counts of the hidden layers, listed from the outermost layer to
                the one closest to the middle of the DCGAN. The list is reversed
                internally, so the generator uses it last-to-first; the reversed copy is
                what gets stored on `self.hidden_channels`.
        """
        total_upsampling = 2 ** len(hidden_channels)
        assert img_size % total_upsampling == 0, "activation size must double at each layer"

        super().__init__()

        # Work on a reversed copy so that index 0 is the widest, innermost layer.
        channels = hidden_channels[::-1]

        self.latent_dim_size = latent_dim_size
        self.img_size = img_size
        self.img_channels = img_channels
        self.hidden_channels = channels

        # Spatial side of the first feature map and the number of features the latent
        # vector must be projected to in order to fill it.
        start_side = img_size // total_upsampling
        projected_features = channels[0] * start_side**2

        self.project_and_reshape = Sequential(
            Linear(latent_dim_size, projected_features, bias=False),
            Rearrange('b (c h w) -> b c h w', h=start_side, w=start_side),
            BatchNorm2d(channels[0]),
            ReLU(),
        )

        # (in, out) channel pairs for every upsampling stage; the last pair ends at the
        # image channels and gets a Tanh instead of BatchNorm + ReLU.
        stage_io = list(zip(channels, channels[1:] + [img_channels]))

        upsampling = []
        for c_in, c_out in stage_io[:-1]:
            upsampling.extend(
                [
                    ConvTranspose2d(c_in, c_out, 4, 2, 1),
                    BatchNorm2d(c_out),
                    ReLU(),
                ]
            )

        final_in, final_out = stage_io[-1]
        upsampling.extend(
            [
                ConvTranspose2d(final_in, final_out, 4, 2, 1),
                Tanh(),
            ]
        )

        self.hidden_layers = Sequential(*upsampling)

    def forward(self, x: Tensor) -> Tensor:
        """Maps a batch of latent vectors to a batch of images."""
        feature_map = self.project_and_reshape(x)
        return self.hidden_layers(feature_map)


class Discriminator(nn.Module):
    def __init__(
        self,
        img_size: int = 64,
        img_channels: int = 3,
        hidden_channels: list[int] = [128, 256, 512],
    ):
        """
        Implements the discriminator architecture from the DCGAN paper (the mirror image of
        the diagram at the top of page 4). Each strided convolution (kernel 4, stride 2,
        padding 1) halves the size of the activations, so image size has to be divisible
        by 2 ** len(hidden_channels).

        The network outputs raw logits: no sigmoid is applied here.

        Args:
            img_size:
                the size of the image, i.e. the input of the discriminator
            img_channels:
                the number of channels in the image (3 for RGB, 1 for grayscale)
            hidden_channels:
                the number of channels in the hidden layers of the discriminator, in the
                order the layers are applied (first convolution first)
        """
        n_layers = len(hidden_channels)
        assert img_size % (2**n_layers) == 0, "activation size must double at each layer"

        super().__init__()

        self.img_size = img_size
        self.img_channels = img_channels
        self.hidden_channels = hidden_channels

        in_channels = [img_channels] + hidden_channels[:-1]
        out_channels = hidden_channels

        conv_layers = []
        for idx, (i, o) in enumerate(zip(in_channels, out_channels)):
            conv_layers.append(Conv2d(i, o, kernel_size=4, stride=2, padding=1))
            # The first convolution is not followed by batch norm; all later ones are.
            if idx != 0:
                conv_layers.append(BatchNorm2d(o))

            conv_layers.append(LeakyReLU(0.2))

        self.hidden_layers = Sequential(*conv_layers)

        # Flatten the final feature map and map it to a single logit per image.
        final_height = img_size // (2**n_layers)
        final_size = hidden_channels[-1] * (final_height**2)
        self.classifier = Sequential(
            Rearrange('b c h w -> b (c h w)'),
            Linear(final_size, 1, bias=False),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Scores a batch of images.

        Args:
            x: image batch of shape (batch, img_channels, img_size, img_size).

        Returns:
            One logit per image, shape (batch,).
        """
        x = self.hidden_layers(x)
        x = self.classifier(x)
        # Drop only the trailing size-1 logit dimension. A bare `squeeze()` would also
        # drop the batch dimension when the batch holds a single image.
        return x.squeeze(-1)

class DCGAN(nn.Module):
    """
    Container module bundling a discriminator (`netD`) and a generator (`netG`) that are
    built from the same image size, channel count and hidden-channel list.
    """

    netD: Discriminator
    netG: Generator

    def __init__(
        self,
        latent_dim_size: int = 100,
        img_size: int = 64,
        img_channels: int = 3,
        hidden_channels: list[int] = [128, 256, 512],
    ):
        """
        Args:
            latent_dim_size: length of the latent vector fed to the generator
            img_size: side length of the (square) images
            img_channels: number of image channels
            hidden_channels: hidden channel counts, passed unchanged to both networks
        """
        super().__init__()

        # Keep the configuration around so callers can read it back from the model.
        self.latent_dim_size, self.img_size = latent_dim_size, img_size
        self.img_channels, self.hidden_channels = img_channels, hidden_channels

        # The discriminator is created first, then the generator.
        self.netD = Discriminator(
            img_size=img_size,
            img_channels=img_channels,
            hidden_channels=hidden_channels,
        )
        self.netG = Generator(
            latent_dim_size=latent_dim_size,
            img_size=img_size,
            img_channels=img_channels,
            hidden_channels=hidden_channels,
        )

In [7]:
def initialize_weights(model: nn.Module) -> None:
    """
    Re-initializes `model` in place following the DCGAN paper (end of page 3).

    Weights of (transposed) convolutions and linear layers are drawn from N(0, 0.02).
    BatchNorm2d weights are drawn from N(1, 0.02) and BatchNorm2d biases are set to 0.
    Biases of convolutional and linear layers, and all other modules, are left untouched.
    """
    std = 0.02
    weight_layers = (ConvTranspose2d, Conv2d, Linear)

    for layer in model.modules():
        if isinstance(layer, weight_layers):
            nn.init.normal_(layer.weight.data, 0.0, std)
            continue

        if isinstance(layer, BatchNorm2d):
            nn.init.normal_(layer.weight.data, 1.0, std)
            nn.init.constant_(layer.bias.data, 0.0)

In [8]:
# Sanity check of the architecture: push a small batch of latent vectors through the
# generator, then feed the generated images to the discriminator, printing a
# torchinfo summary of each network.
model = DCGAN().to(device)
x = t.randn(3, 100).to(device)

generator_report = torchinfo.summary(model.netG, input_data=x)
print(generator_report, end="\n\n")

discriminator_report = torchinfo.summary(model.netD, input_data=model.netG(x))
print(discriminator_report)

Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [3, 3, 64, 64]            --
├─Sequential: 1-1                        [3, 512, 8, 8]            --
│    └─Linear: 2-1                       [3, 32768]                3,276,800
│    └─Rearrange: 2-2                    [3, 512, 8, 8]            --
│    └─BatchNorm2d: 2-3                  [3, 512, 8, 8]            1,024
│    └─ReLU: 2-4                         [3, 512, 8, 8]            --
├─Sequential: 1-2                        [3, 3, 64, 64]            --
│    └─ConvTranspose2d: 2-5              [3, 256, 16, 16]          2,097,408
│    └─BatchNorm2d: 2-6                  [3, 256, 16, 16]          512
│    └─ReLU: 2-7                         [3, 256, 16, 16]          --
│    └─ConvTranspose2d: 2-8              [3, 128, 32, 32]          524,416
│    └─BatchNorm2d: 2-9                  [3, 128, 32, 32]          256
│    └─ReLU: 2-10                        [3, 128, 32, 32]    

In [9]:
@dataclass
class DCGANArgs:
    """
    Class for the arguments to the DCGAN (training and architecture).
    Note, we use field(default_factory=...) when our default value is a mutable object,
    so that every instance gets its own copy instead of sharing one list.
    """

    # architecture
    # length of the latent vector fed to the generator
    latent_dim_size: int = 100
    # hidden channel counts passed to the DCGAN constructor
    hidden_channels: list[int] = field(default_factory=lambda: [128, 256, 512])

    # data & training
    # number of images per batch drawn from the training DataLoader
    batch_size: int = 64
    # number of full passes over the training set
    epochs: int = 3
    # learning rate and betas shared by the generator and discriminator Adam optimizers
    lr: float = 0.0002
    betas: tuple[float, float] = (0.5, 0.999)
    # max gradient norm used for clipping in both training steps; None disables clipping
    clip_grad_norm: float | None = None

    # logging
    # when False, nothing is sent to Weights & Biases
    use_wandb: bool = True
    # project and run name forwarded to wandb.init
    wandb_project: str | None = "day5-gan"
    wandb_name: str | None = None
    # generated samples are logged every this many training steps
    log_every_n_steps: int = 250


class DCGANTrainer:
    """
    Trains a DCGAN on the dataset returned by `get_dataset`.

    Each training step updates the discriminator first and then the generator, each with
    its own Adam optimizer. Losses are computed with `BCEWithLogitsLoss`, so the
    discriminator is expected to output raw logits.
    """

    def __init__(self, args: DCGANArgs):
        """
        Builds the data loader, the model, the loss and both optimizers.

        Args:
            args: architecture, training and logging configuration.
        """
        self.args = args
        self.trainset = get_dataset()
        self.trainloader = DataLoader(
            self.trainset, batch_size=args.batch_size, shuffle=True, num_workers=8
        )

        # Read the image shape from a single sample rather than a whole batch: this
        # avoids starting the loader's worker processes just to inspect a shape.
        img_channels, img_height, img_width = self.trainset[0][0].shape
        assert img_height == img_width

        model = DCGAN(args.latent_dim_size, img_height, img_channels, args.hidden_channels)
        # Apply the DCGAN-paper initialization. Without this call the networks would
        # train from PyTorch's default initialization and `initialize_weights` would
        # never be used anywhere in the notebook.
        initialize_weights(model)
        self.model = model.to(device).train()

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.optG = t.optim.Adam(self.model.netG.parameters(), lr=args.lr, betas=args.betas)
        self.optD = t.optim.Adam(self.model.netD.parameters(), lr=args.lr, betas=args.betas)

    def training_step_discriminator(
        self,
        img_real: Float[Tensor, "batch channels height width"],
        img_fake: Float[Tensor, "batch channels height width"],
    ) -> Float[Tensor, ""]:
        """
        Performs a gradient step on the discriminator to maximize
        log(D(x)) + log(1-D(G(z))), given a batch of real and a batch of generated images.
        Logs the loss to wandb if enabled.

        Args:
            img_real: batch of real images.
            img_fake: batch of generated images; it is detached here so that no gradient
                reaches the generator.

        Returns:
            The discriminator loss (sum of the loss on fake and on real images).
        """
        self.optD.zero_grad()

        d_g_z = self.model.netD(img_fake.detach())
        labels_zeros = t.zeros_like(d_g_z)
        loss_fake = self.loss_fn(d_g_z, labels_zeros)  # -log(1 - D(G(z)))

        d_x = self.model.netD(img_real)
        labels_ones = t.ones_like(d_x)
        loss_real = self.loss_fn(d_x, labels_ones)  # -log(D(x))

        loss = loss_fake + loss_real

        loss.backward()
        if self.args.clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.model.netD.parameters(), self.args.clip_grad_norm)
        self.optD.step()

        if self.args.use_wandb:
            wandb.log(dict(lossD=loss), step=self.step)

        return loss

    def training_step_generator(
        self, img_fake: Float[Tensor, "batch channels height width"]
    ) -> Float[Tensor, ""]:
        """
        Performs a gradient step on the generator to maximize log(D(G(z))). Logs to wandb if enabled.

        Args:
            img_fake: batch of generated images, still attached to the generator's graph.

        Returns:
            The generator loss.
        """
        self.optG.zero_grad()

        d_g_z = self.model.netD(img_fake)
        labels_ones = t.ones_like(d_g_z)
        # Non-saturating version: minimize -log(D(G(z))) by labelling the fakes as real.
        loss = self.loss_fn(d_g_z, labels_ones)

        loss.backward()
        if self.args.clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.model.netG.parameters(), self.args.clip_grad_norm)
        self.optG.step()

        if self.args.use_wandb:
            wandb.log(dict(lossG=loss), step=self.step)

        return loss

    @t.inference_mode()
    def log_samples(self) -> None:
        """
        Performs evaluation by generating 10 instances of random noise and passing them through
        the generator, then optionally logging the results to Weights & Biases.

        The noise is the same at every call, so logged images can be compared across steps.
        """
        assert self.step > 0, (
            "First call should come after a training step. Remember to increment `self.step`."
        )
        self.model.netG.eval()

        # Draw the fixed noise from a private generator. Calling `t.manual_seed(42)` here
        # would reset the *global* RNG on every call and make the training noise drawn in
        # `train` repeat the same sequence after each logging step.
        noise_rng = t.Generator().manual_seed(42)
        noise = t.randn(10, self.model.latent_dim_size, generator=noise_rng).to(device)
        # Get generator output
        output = self.model.netG(noise)
        # Clip values to make the visualization clearer
        output = output.clamp(output.quantile(0.01), output.quantile(0.99))
        # Log to weights and biases
        if self.args.use_wandb:
            output = einops.rearrange(output, "b c h w -> b h w c").cpu().numpy()
            wandb.log({"images": [wandb.Image(arr) for arr in output]}, step=self.step)

        self.model.netG.train()

    def train(self) -> DCGAN:
        """
        Performs a full training run.

        After every epoch both networks are checkpointed to disk and, if wandb is enabled,
        uploaded as an artifact.

        Returns:
            The trained model.
        """
        self.step = 0
        if self.args.use_wandb:
            wandb.init(project=self.args.wandb_project, name=self.args.wandb_name)

        try:
            for epoch in range(self.args.epochs):
                progress_bar = tqdm(self.trainloader, total=len(self.trainloader), ascii=True)

                for img_real, _ in progress_bar:
                    img_real = img_real.to(device)
                    # Match the number of fakes to the real batch, which can be smaller
                    # than `batch_size` for the last batch of an epoch.
                    z = t.randn(img_real.shape[0], self.model.latent_dim_size).to(device)
                    img_fake = self.model.netG(z)
                    lossD = self.training_step_discriminator(img_real, img_fake.detach())

                    lossG = self.training_step_generator(img_fake)

                    self.step += 1
                    progress_bar.set_description(f"{epoch=}, {lossD=:.4f}, {lossG=:.4f}, batches={self.step}")

                    if self.step % self.args.log_every_n_steps == 0:
                        self.log_samples()

                gen_path, disc_path = self.save_checkpoint(epoch)

                if self.args.use_wandb:
                    self.log_artifact(gen_path, disc_path, epoch)
        finally:
            # Close the wandb run even if training is interrupted or raises.
            if self.args.use_wandb:
                wandb.finish()

        return self.model

    def save_checkpoint(self, epoch: int) -> tuple[str, str]:
        """
        Saves the generator and discriminator state dicts under `checkpoints/`.

        Args:
            epoch: epoch index used in the file names.

        Returns:
            The paths of the generator and discriminator checkpoint files, in that order.
        """
        os.makedirs("checkpoints", exist_ok=True)

        gen_path = f"checkpoints/generator_epoch{epoch}.pt"
        disc_path = f"checkpoints/discriminator_epoch{epoch}.pt"

        t.save(self.model.netG.state_dict(), gen_path)
        t.save(self.model.netD.state_dict(), disc_path)

        return gen_path, disc_path

    def log_artifact(self, gen_path: str, disc_path: str, epoch: int) -> None:
        """
        Uploads both checkpoint files to wandb as a single model artifact.

        Args:
            gen_path: path of the generator checkpoint.
            disc_path: path of the discriminator checkpoint.
            epoch: epoch index used in the artifact name.
        """
        artifact = wandb.Artifact(f"dcgan_epoch_{epoch}", type="model")
        artifact.add_file(gen_path)
        artifact.add_file(disc_path)
        wandb.log_artifact(artifact)



In [11]:
# Arguments for CelebA
# Training configuration for CelebA. Keyword order does not matter for the dataclass.
args = DCGANArgs(
    epochs=5,
    batch_size=32,  # if you get OOM errors, reduce this!
    hidden_channels=[128, 256, 512],
    use_wandb=True,
)

# Constructing the trainer builds the dataset (downloading it if needed) and the model.
trainer = DCGANTrainer(args)
# dcgan = trainer.train()

Downloading...
From (original): https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
From (redirected): https://drive.usercontent.google.com/download?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM&confirm=t&uuid=a010d80a-8361-46c6-8451-f76a1a018116
To: /content/data/celeba/img_align_celeba.zip

  0%|          | 0.00/1.44G [00:00<?, ?B/s][A
  0%|          | 1.57M/1.44G [00:00<04:24, 5.45MB/s][A
  0%|          | 2.62M/1.44G [00:00<03:23, 7.09MB/s][A
  1%|          | 7.34M/1.44G [00:00<01:22, 17.5MB/s][A
  1%|          | 9.44M/1.44G [00:00<01:28, 16.2MB/s][A
  2%|▏         | 30.4M/1.44G [00:01<00:37, 37.9MB/s][A
  4%|▎         | 53.5M/1.44G [00:01<00:19, 71.5MB/s][A
  4%|▍         | 63.4M/1.44G [00:01<00:29, 47.4MB/s][A
  6%|▌         | 80.7M/1.44G [00:01<00:28, 47.8MB/s][A
  7%|▋         | 101M/1.44G [00:02<00:22, 60.1MB/s] [A
  8%|▊         | 114M/1.44G [00:02<00:20, 64.0MB/s][A
  8%|▊         | 122M/1.44G [00:02<00:21, 62.5MB/s][A
 10%|▉         | 143M/1.44G [00:02<00:15, 82.0MB/s]

In [None]:
# Rebuild the architecture used for training and restore the generator weights saved
# at the end of epoch index 2.
args = DCGANArgs(
    hidden_channels=[128, 256, 512],
    batch_size=32,
    epochs=5,
    use_wandb=False,
)
# DCGAN takes the individual architecture values, not a DCGANArgs object. Passing `args`
# positionally would bind it to `latent_dim_size` and fail when the first Linear layer
# is built. Image size and channel count keep their defaults (64 and 3).
model = DCGAN(
    latent_dim_size=args.latent_dim_size,
    hidden_channels=args.hidden_channels,
).to(device)
# `map_location` lets a checkpoint saved on one device (e.g. a GPU) load on another.
model.netG.load_state_dict(t.load("checkpoints/generator_epoch2.pt", map_location=device))