In [1]:
# Import third-party libraries (torch, numpy, torchvision) plus the project-local utils helpers and Generator model
import torch
from torch import nn, Tensor
import numpy as np
from torchvision.utils import save_image
import os

import utils
from gan_models import Generator

# Ask the project helper which compute device to use; the bare `device`
# expression below echoes the chosen value as the cell output.
device = utils.get_device()
device


'cpu'

# 1. Dataset

In [2]:
# Dataset configuration: mini-batch size and the image side length
# (img_size is reused later to build the generator's img_shape)
BATCH_SIZE = 64
img_size = 32

# Build the MNIST dataloader through the project helper
dataloader = utils.get_mnist_dataloader(
    batch_size=BATCH_SIZE,
    img_size=img_size,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:04<00:00, 2.11MB/s]


Extracting mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 115kB/s]


Extracting mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 926kB/s] 


Extracting mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 1.18MB/s]

Extracting mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to mnist_data\MNIST\raw






# 2. Model

In [3]:
# Model hyperparameters
latent_dim = 100  # length of the noise vector fed to the generator
channels = 1
# (channels, height, width) with a square image of side img_size
img_shape = (channels, img_size, img_size)

In [4]:
# Build the generator from the imported class, then move it to the selected
# device; the trailing expression also displays the module as the cell output
generator = Generator(img_shape=img_shape, latent_dim=latent_dim)
generator.to(device)

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Linear(in_features=1024, out_features=1024, bias=True)
    (12): Tanh()
  )
)

# 3. Training

In [5]:
# Interval used by the training loop's `epoch % save_interval` check
save_interval = 10

# Output directory, created through the project helper
output_dir = "./images_L2"
utils.create_dir(output_dir)

In [6]:
# Train the generator alone with an L2 (MSE) objective: each batch of random
# noise is mapped to images and compared against the real batch.
# NOTE(review): this assumes generator(z) returns a tensor whose shape matches
# the dataloader's image batches — confirm against gan_models.Generator.
EPOCHS = 200
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
criterion = nn.MSELoss()

# Training history: one mean-loss entry is appended per completed epoch.
hist = {
    "train_G_loss": []
}

for epoch in range(EPOCHS):
    total_loss = 0.0

    # Labels are not used by this objective, so they are discarded.
    for imgs, _ in dataloader:

        real_imgs = imgs.to(device)

        # --- Train Generator ---
        optimizer_G.zero_grad()

        # Noise input for Generator, allocated directly on the target device
        # (one latent vector per real image, so the last smaller batch works too).
        z = torch.randn((imgs.shape[0], latent_dim), device=device)

        gen_imgs = generator(z)
        G_loss = criterion(gen_imgs, real_imgs)
        total_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

    # Mean of the per-batch losses over the epoch.
    total_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{EPOCHS}], total_loss: {total_loss:.4f}")

    hist["train_G_loss"].append(total_loss)

    if epoch % save_interval == 0:
        # Save up to 25 samples from the last batch of this epoch as a 5-wide grid.
        # The path is built from output_dir so it always matches the directory
        # created earlier (the previous hard-coded "images_l2" differed in case
        # from "./images_L2" and breaks on case-sensitive filesystems).
        # detach() drops the autograd graph before the tensor is written out.
        sample_path = os.path.join(output_dir, f"epoch_{epoch}.png")
        save_image(gen_imgs.detach()[:25], sample_path, nrow=5, normalize=True)

Epoch [1/200], total_loss: 0.2495
Epoch [2/200], total_loss: 0.2292
Epoch [3/200], total_loss: 0.2279
Epoch [4/200], total_loss: 0.2274
Epoch [5/200], total_loss: 0.2273
Epoch [6/200], total_loss: 0.2272
Epoch [7/200], total_loss: 0.2270
Epoch [8/200], total_loss: 0.2271
Epoch [9/200], total_loss: 0.2270
Epoch [10/200], total_loss: 0.2270
Epoch [11/200], total_loss: 0.2270
Epoch [12/200], total_loss: 0.2269
Epoch [13/200], total_loss: 0.2269
Epoch [14/200], total_loss: 0.2268
Epoch [15/200], total_loss: 0.2268
Epoch [16/200], total_loss: 0.2268
Epoch [17/200], total_loss: 0.2268
Epoch [18/200], total_loss: 0.2268
Epoch [19/200], total_loss: 0.2267
Epoch [20/200], total_loss: 0.2267
Epoch [21/200], total_loss: 0.2267
Epoch [22/200], total_loss: 0.2267


KeyboardInterrupt: 