In [1]:
import torch
from torch import nn, Tensor
import numpy as np
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"
device

device(type='cuda')

# 1. Dataset

In [2]:
import torchvision

img_size = 32

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5],
                                     std=[0.5])
])

images = torchvision.datasets.MNIST(root='./mnist_data', train=True, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


3.0%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz



100.0%
100.0%


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [3]:
BATCH_SIZE = 64
dataloader = torch.utils.data.DataLoader(images, batch_size=BATCH_SIZE, shuffle=True)

# 2. Model

In [4]:
channels = 1
img_shape = (channels, img_size, img_size)
latent_dim = 100

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.init_size = img_size // 4
        self.fc = nn.Linear(latent_dim, 128*self.init_size**2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),

            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),

            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1),
            
            nn.Tanh(),
        )
        
    def forward(self, z):
        x = self.fc(z)
        x = x.view(x.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(x)
        return img

In [6]:
class Descriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 
            nn.LeakyReLU(0.2, inplace=True), 
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
        )

        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1), 
            nn.Sigmoid()
        )

        
    def forward(self, img):
        x = self.model(img)
        x = x.view(x.shape[0], -1)
        validity = self.adv_layer(x)
        return validity

In [7]:
generator = Generator()
discriminator = Descriminator()

In [8]:
generator.to(device)

Generator(
  (fc): Linear(in_features=100, out_features=8192, bias=True)
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode='nearest')
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Upsample(scale_factor=2.0, mode='nearest')
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)

In [9]:
discriminator.to(device)

Descriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout2d(p=0.25, inplace=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Dropout2d(p=0.25, inplace=False)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Dropout2d(p=0.25, inplace=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): Dropout2d(p=0.25, inplace=False)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_laye

# 3. Training

In [10]:
import os
os.makedirs("images", exist_ok=True)

save_interval = 10

In [11]:
EPOCHS = 200

optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

criterion = nn.BCELoss()

hist = {
        "train_G_loss": [],
        "train_D_loss": [],
    }

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        valid = torch.ones(imgs.shape[0], 1).to(device)
        fake = torch.zeros(imgs.shape[0], 1).to(device)

        # --- Train Generator --- 
        optimizer_G.zero_grad()
        # Noise input for Generator
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))).to(device)

        gen_imgs = generator(z)
        G_loss = criterion(discriminator(gen_imgs), valid)
        running_G_loss += G_loss.item()

        G_loss.backward()
        optimizer_G.step()

        # --- Train Descriminator --- 
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), valid)
        fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()
    
    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    if epoch % save_interval == 0:
        save_image(gen_imgs.data[:25], f"images/epoch_{epoch}.png", nrow=5, normalize=True)


Epoch [1/200], Train G Loss: 0.7535, Train D Loss: 0.6503
Epoch [2/200], Train G Loss: 0.9038, Train D Loss: 0.5882
Epoch [3/200], Train G Loss: 0.9960, Train D Loss: 0.5532
Epoch [4/200], Train G Loss: 1.0860, Train D Loss: 0.5264
Epoch [5/200], Train G Loss: 1.1196, Train D Loss: 0.5185
Epoch [6/200], Train G Loss: 1.1534, Train D Loss: 0.5080
Epoch [7/200], Train G Loss: 1.2054, Train D Loss: 0.4951
Epoch [8/200], Train G Loss: 1.2627, Train D Loss: 0.4848
Epoch [9/200], Train G Loss: 1.2725, Train D Loss: 0.4715
Epoch [10/200], Train G Loss: 1.3162, Train D Loss: 0.4731
Epoch [11/200], Train G Loss: 1.3545, Train D Loss: 0.4612
Epoch [12/200], Train G Loss: 1.3612, Train D Loss: 0.4628
Epoch [13/200], Train G Loss: 1.3908, Train D Loss: 0.4605
Epoch [14/200], Train G Loss: 1.4238, Train D Loss: 0.4378
Epoch [15/200], Train G Loss: 1.4731, Train D Loss: 0.4335
Epoch [16/200], Train G Loss: 1.4985, Train D Loss: 0.4264
Epoch [17/200], Train G Loss: 1.5483, Train D Loss: 0.4145
Epoch 