# Global import

In [1]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchsummary import summary

# Global config

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

input_dim = 100
batch_size = 128
epochs = 50
g_model_path = 'g_model.pth'
d_model_path = 'd_model.pth'

lr = 0.0002

# Prepare for dataset

In [4]:
train_dataset = datasets.MNIST(root="./data/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# Model
using DCGAN

## Generator

In [5]:
class Generator(nn.Module):
    def __init__(self, input_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 32 * 32)
        self.br1 = nn.Sequential(
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc2 = nn.Linear(32 * 32, 128 * 7 * 7)
        self.br2 = nn.Sequential(
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU()
        )
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.br1(self.fc1(x))
        x = self.br2(self.fc2(x))
        x = x.reshape(-1, 128, 7, 7)
        x = self.conv1(x)
        output = self.conv2(x)
        return output
    
G = Generator(input_dim)
G.to(device)
summary(G, (100,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1024]         103,424
       BatchNorm1d-2                 [-1, 1024]           2,048
              ReLU-3                 [-1, 1024]               0
            Linear-4                 [-1, 6272]       6,428,800
       BatchNorm1d-5                 [-1, 6272]          12,544
              ReLU-6                 [-1, 6272]               0
   ConvTranspose2d-7           [-1, 64, 14, 14]         131,136
       BatchNorm2d-8           [-1, 64, 14, 14]             128
              ReLU-9           [-1, 64, 14, 14]               0
  ConvTranspose2d-10            [-1, 1, 28, 28]           1,025
          Sigmoid-11            [-1, 1, 28, 28]               0
Total params: 6,679,105
Trainable params: 6,679,105
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forw

## Discriminator

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1),
            nn.LeakyReLU(0.2)
        )
        self.pl1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, stride=1),
            nn.LeakyReLU(0.2)
        )
        self.pl2 = nn.MaxPool2d(2, stride=2)
        self.fc1 = nn.Sequential(
            nn.Linear(64 * 4 * 4, 1024),
            nn.LeakyReLU(0.2)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.pl1(x)
        x = self.conv2(x)
        x = self.pl2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        output = x.squeeze(1)
        return output
    
D = Discriminator()
D.to(device)
summary(D, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 24, 24]             832
         LeakyReLU-2           [-1, 32, 24, 24]               0
         MaxPool2d-3           [-1, 32, 12, 12]               0
            Conv2d-4             [-1, 64, 8, 8]          51,264
         LeakyReLU-5             [-1, 64, 8, 8]               0
         MaxPool2d-6             [-1, 64, 4, 4]               0
            Linear-7                 [-1, 1024]       1,049,600
         LeakyReLU-8                 [-1, 1024]               0
            Linear-9                    [-1, 1]           1,025
          Sigmoid-10                    [-1, 1]               0
Total params: 1,102,721
Trainable params: 1,102,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.40
Params size (MB): 4.21
Estimat

# Trainning
using **torch.ones_like** && **torch.zeros_like**

In [7]:
optim_G = torch.optim.Adam(G.parameters(), lr=lr)
optim_D = torch.optim.Adam(D.parameters(), lr=lr)
criterion = nn.BCELoss()

best_loss = float('inf')

for epoch in range(epochs):
    # Train Discriminator
    tot_loss_G = 0
    tot_loss_D = 0
    for i, (x, _) in enumerate(train_loader):
        optim_D.zero_grad()
        real_data = x.to(device)
        real_pred = D(real_data)
        loss_real = criterion(real_pred, torch.ones_like(real_pred).to(device))

        fake_data = G(torch.randn([batch_size, input_dim]).to(device))
        fake_pred = D(fake_data)
        loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred).to(device))

        loss_D = loss_real + loss_fake
        tot_loss_D += loss_D.item()
        
        loss_D.backward()
        optim_D.step()
            
        # Train Generator
        optim_G.zero_grad()
            
        fake_x = G(torch.randn([batch_size, input_dim]).to(device))
        fake_outputs = D(fake_x)
        loss_G = criterion(fake_outputs, torch.ones_like(fake_outputs).to(device))
        tot_loss_G += loss_G.item()
        
        loss_G.backward()
        optim_G.step()
        
        if (i + 1) % 50 == 0:
            print("epoch = {}, batch_round = {}/{}, loss_G = {}, loss_D = {}".format(epoch, i, len(train_loader), tot_loss_G / i, tot_loss_D / i))
            
    if tot_loss_G / len(train_loader) < best_loss:
        print("update best loss:", tot_loss_G / len(train_loader))
        best_loss = loss_G.item()
        torch.save(G.state_dict(), g_model_path)
        torch.save(D.state_dict(), d_model_path)
    
    x = torch.randn(64, input_dim).to(device)
    img = G(x)
    save_image(img, 'epoch_%d.png' % epoch)
        

epoch = 0, batch_round = 49/469, loss_G = 3.6696824017836125, loss_D = 0.33946854944283866
epoch = 0, batch_round = 99/469, loss_G = 4.78378872979771, loss_D = 0.18010431461299609
epoch = 0, batch_round = 149/469, loss_G = 3.943889849537971, loss_D = 0.2906956854245107
epoch = 0, batch_round = 199/469, loss_G = 3.6141315755532615, loss_D = 0.31287916226022355
epoch = 0, batch_round = 249/469, loss_G = 3.407738257364097, loss_D = 0.33088770751508484
epoch = 0, batch_round = 299/469, loss_G = 3.246359002032009, loss_D = 0.35735973170382224
epoch = 0, batch_round = 349/469, loss_G = 3.060054861201256, loss_D = 0.4193537076668317
epoch = 0, batch_round = 399/469, loss_G = 2.90685639748896, loss_D = 0.46183237965851276
epoch = 0, batch_round = 449/469, loss_G = 2.7873918913255022, loss_D = 0.48914301808807326
update best loss: 2.7415952734601525
epoch = 1, batch_round = 49/469, loss_G = 1.768719948067957, loss_D = 0.8024258564929573
epoch = 1, batch_round = 99/469, loss_G = 1.71386056897616