In [1]:
import torch
from torch import nn, Tensor
import numpy as np
from torchvision.utils import save_image

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

# 1. Dataset

In [3]:
import torchvision

# Side length, in pixels, that every image is resized to.
img_size = 32

# Preprocessing pipeline: resize to img_size x img_size, convert to a tensor,
# then normalize with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((img_size, img_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split stored under ./mnist_data (download=True is passed so
# the files are fetched when they are not already there).
images = torchvision.datasets.MNIST(
    root='mnist_data',
    train=True,
    download=True,
    transform=transform,
)

In [4]:
# Number of images per optimisation step.
BATCH_SIZE = 64

# Shuffled mini-batch iterator over the MNIST training images.
dataloader = torch.utils.data.DataLoader(
    dataset=images,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# 2. Model

In [5]:
# Length of the noise vector fed to the generator.
latent_dim = 100

# Generated images are single-channel, img_size x img_size pixels (C, H, W).
channels = 1
img_shape = (channels,) + (img_size,) * 2

In [6]:
class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to an image tensor.

    The network is a stack of ``Linear`` layers (128 -> 256 -> 512 -> 1024)
    with ``LeakyReLU(0.2)`` activations and ``BatchNorm1d`` after every hidden
    layer except the first, followed by a final ``Linear`` + ``Tanh`` whose
    output is reshaped to ``(batch, *img_shape)``. Because of the ``Tanh`` the
    output values lie in [-1, 1].

    Args:
        z_dim: Size of the latent input vector. Defaults to the notebook-level
            ``latent_dim`` when omitted.
        out_shape: Shape ``(C, H, W)`` of one generated image. Defaults to the
            notebook-level ``img_shape`` when omitted.

    The resolved sizes are stored on the instance (``self.latent_dim`` and
    ``self.img_shape``), so ``forward`` does not depend on the notebook
    globals still holding the values they had when the model was built.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super().__init__()
        # Fall back to the notebook settings so a bare `Generator()` still works.
        self.latent_dim = latent_dim if z_dim is None else int(z_dim)
        self.img_shape = tuple(img_shape if out_shape is None else out_shape)

        hidden_sizes = (128, 256, 512, 1024)

        # First hidden layer has no batch normalisation.
        layers = [
            nn.Linear(self.latent_dim, hidden_sizes[0]),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining hidden layers: Linear -> BatchNorm1d -> LeakyReLU.
        for n_in, n_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output layer: one unit per pixel, squashed by Tanh.
        layers += [
            nn.Linear(hidden_sizes[-1], int(np.prod(self.img_shape))),
            nn.Tanh(),
        ]

        # Kept under the name `model` so module indices and state-dict keys
        # are the same as before.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Latent batch; expected shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *self.img_shape)``.
        """
        img = self.model(z)
        # Un-flatten the per-pixel outputs into image layout.
        return img.view(img.size(0), *self.img_shape)

In [7]:
# Build the generator and move its parameters to the selected device;
# the trailing expression displays the module structure.
generator = Generator().to(device)
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Linear(in_features=512, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Linear(in_features=1024, out_features=1024, bias=True)
    (12): Tanh()
  )
)

# 3. Training

In [9]:
import os
# A grid of generated samples is written every `save_interval` epochs.
save_interval = 10

# Directory that receives the sample grids; fine if it already exists.
os.makedirs("images_l2", exist_ok=True)

In [12]:
EPOCHS = 200
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001)
criterion = nn.MSELoss()
# One entry per epoch: the mean of the per-batch generator losses.
hist = {
        "train_G_loss": []
}

# NOTE(review): each noise vector z is drawn independently of the real batch
# it is compared against, so the pixel-wise MSE gives the generator no signal
# tying a particular z to a particular image. That is consistent with the loss
# flattening out almost immediately in the log below - confirm this is the
# intended "L2-only" baseline.
for epoch in range(EPOCHS):
    running_loss = 0.0

    for imgs, _ in dataloader:
        real_imgs = imgs.to(device)

        # --- Train Generator ---
        optimizer_G.zero_grad()

        # Sample the noise directly on the target device instead of creating
        # it on the CPU and copying it over.
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)

        gen_imgs = generator(z)
        G_loss = criterion(gen_imgs, real_imgs)
        running_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

    # Average over batches (the last batch may be smaller than the others).
    total_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{EPOCHS}], total_loss: {total_loss:.4f}")

    hist["train_G_loss"].append(total_loss)

    if epoch % save_interval == 0:
        # Save up to 25 samples from the last batch of the epoch; detach()
        # drops the autograd graph without relying on the legacy `.data`.
        save_image(gen_imgs.detach()[:25], f"images_l2/epoch_{epoch}.png", nrow=5, normalize=True)

Epoch [1/200], total_loss: 0.2267
Epoch [2/200], total_loss: 0.2265
Epoch [3/200], total_loss: 0.2265
Epoch [4/200], total_loss: 0.2264
Epoch [5/200], total_loss: 0.2263
Epoch [6/200], total_loss: 0.2263
Epoch [7/200], total_loss: 0.2263
Epoch [8/200], total_loss: 0.2262
Epoch [9/200], total_loss: 0.2262
Epoch [10/200], total_loss: 0.2262
Epoch [11/200], total_loss: 0.2261
Epoch [12/200], total_loss: 0.2261
Epoch [13/200], total_loss: 0.2261
Epoch [14/200], total_loss: 0.2260
Epoch [15/200], total_loss: 0.2260
Epoch [16/200], total_loss: 0.2260
Epoch [17/200], total_loss: 0.2260
Epoch [18/200], total_loss: 0.2260
Epoch [19/200], total_loss: 0.2260
Epoch [20/200], total_loss: 0.2260
Epoch [21/200], total_loss: 0.2260
Epoch [22/200], total_loss: 0.2260
Epoch [23/200], total_loss: 0.2259
Epoch [24/200], total_loss: 0.2260
Epoch [25/200], total_loss: 0.2260
Epoch [26/200], total_loss: 0.2259
Epoch [27/200], total_loss: 0.2259
Epoch [28/200], total_loss: 0.2259
Epoch [29/200], total_loss: 0