# Global imports

In [None]:
!pip install torchsummary

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchsummary import summary

# Global config

In [None]:
# Prefer the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Dimensionality of the latent noise vector fed to the generator.
input_dim = 100

# Optimisation settings.
batch_size = 128
epochs = 50
lr = 0.0002

# Weight file names for the generator / discriminator.
# NOTE(review): not referenced by any visible cell; per-epoch checkpoints are saved under other names.
g_model_path = 'g_model.pth'
d_model_path = 'd_model.pth'

# Prepare the dataset

In [None]:
# MNIST training split; ToTensor() turns each image into a 1x28x28 float tensor in [0, 1].
train_dataset = datasets.MNIST(root="./data/", train=True, transform=transforms.ToTensor(), download=True)
# Use the configured batch_size rather than repeating the literal 128, so changing the
# config cell actually changes the loader's batch size.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Model
Using a DCGAN-style architecture.

## Generator

In [None]:
class Generator(nn.Module):
    def __init__(self, input_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 32 * 32)
        self.br1 = nn.Sequential(
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc2 = nn.Linear(32 * 32, 128 * 7 * 7)
        self.br2 = nn.Sequential(
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU()
        )
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.br1(self.fc1(x))
        x = self.br2(self.fc2(x))
        x = x.reshape(-1, 128, 7, 7)
        x = self.conv1(x)
        output = self.conv2(x)
        return output
    
# Build the generator and move its parameters to the selected device.
G = Generator(input_dim)
G.to(device)
# Print a layer summary for one latent vector of the configured size. Previously this
# was hard-coded to 100, which would break as soon as input_dim changed.
summary(G, (input_dim,))

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Score 1x28x28 images with a per-sample probability of being real.

    Two 5x5 convolutions (LeakyReLU, slope 0.2), each followed by 2x2 max pooling,
    reduce the spatial size 28 -> 24 -> 12 -> 8 -> 4. Two linear layers then map the
    ``64 * 4 * 4`` features to a single sigmoid output per image.
    """

    def __init__(self):
        super().__init__()
        # Creation order and attribute names are preserved: they fix the state_dict keys.
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, stride=1), nn.LeakyReLU(0.2))
        self.pl1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1), nn.LeakyReLU(0.2))
        self.pl2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Sequential(nn.Linear(64 * 4 * 4, 1024), nn.LeakyReLU(0.2))
        self.fc2 = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())

    def forward(self, x):
        """Return a tensor of shape (N,) holding D's probability for each image."""
        for stage in (self.conv1, self.pl1, self.conv2, self.pl2):
            x = stage(x)
        flat = x.view(x.size(0), -1)
        probs = self.fc2(self.fc1(flat))
        return probs.squeeze(1)
    
# nn.Module.to() moves parameters in place and returns the module itself,
# so construction and device placement can be chained.
D = Discriminator().to(device)
summary(D, (1, 28, 28))

# Training
BCE targets are built with **torch.ones_like** (label 1) and **torch.zeros_like** (label 0).

In [None]:
optim_G = torch.optim.Adam(G.parameters(), lr=lr)
optim_D = torch.optim.Adam(D.parameters(), lr=lr)
criterion = nn.BCELoss()

for epoch in range(epochs):
    tot_loss_G = 0.0
    tot_loss_D = 0.0
    for i, (x, _) in enumerate(train_loader):
        real_data = x.to(device)
        # The last batch of an epoch can be smaller than batch_size; size the fake
        # batches from the real one so both halves of the D loss stay balanced.
        n = real_data.size(0)

        # ---- Train Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ----
        optim_D.zero_grad()
        real_pred = D(real_data)
        # *_like tensors already live on real_pred's device; no .to(device) needed.
        loss_real = criterion(real_pred, torch.ones_like(real_pred))

        # detach(): the discriminator update must not backpropagate through G.
        fake_data = G(torch.randn(n, input_dim, device=device)).detach()
        fake_pred = D(fake_data)
        loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optim_D.step()
        tot_loss_D += loss_D.item()

        # ---- Train Generator: non-saturating loss, push D(G(z)) -> 1 ----
        # D's parameters also receive gradients here; they are cleared by
        # optim_D.zero_grad() at the start of the next iteration.
        optim_G.zero_grad()
        fake_x = G(torch.randn(n, input_dim, device=device))
        fake_outputs = D(fake_x)
        loss_G = criterion(fake_outputs, torch.ones_like(fake_outputs))
        loss_G.backward()
        optim_G.step()
        tot_loss_G += loss_G.item()

        if (i + 1) % 50 == 0:
            # Average over the i + 1 batches processed so far this epoch
            # (dividing by i over-reported the running mean).
            seen = i + 1
            print("epoch = {}, batch_round = {}/{}, loss_G = {}, loss_D = {}".format(
                epoch, seen, len(train_loader), tot_loss_G / seen, tot_loss_D / seen))

    torch.save(G.state_dict(), 'g_ckpt_epoch={}.pth'.format(epoch))
    torch.save(D.state_dict(), 'd_ckpt_epoch={}.pth'.format(epoch))

    # Save a grid of 64 samples. eval() + no_grad() so that sampling neither builds
    # an autograd graph nor updates G's BatchNorm running statistics.
    G.eval()
    with torch.no_grad():
        img = G(torch.randn(64, input_dim, device=device))
    G.train()
    save_image(img, 'epoch_%d.png' % epoch)

# Testing

In [None]:
# Load trained generator weights (path points at a Kaggle input dataset).
# NOTE(review): torch.load unpickles the file; for checkpoints from untrusted sources
# pass weights_only=True (supported since PyTorch 1.13, the default since 2.6).
G.load_state_dict(torch.load('/kaggle/input/ganmnist/model (1).pth', map_location=device))
G.eval()

samples = 10  # images per interpolation row
num = 10      # number of rows, one output file each

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i in range(num):
        # Linearly interpolate between two random latent vectors a and b:
        # `samples` evenly spaced points, from a (k = 0) to b (k = samples - 1).
        random_a = torch.randn(input_dim, device=device)
        random_b = torch.randn(input_dim, device=device)
        step = (random_b - random_a) / (samples - 1)
        k = torch.arange(samples, device=device, dtype=random_a.dtype).unsqueeze(1)
        x = random_a + k * step          # shape (samples, input_dim)
        img = G(x)
        save_image(img, 'test_%d.png' % i, nrow=samples)