In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
from torch.utils.data import DataLoader
import numpy as np
import random
from tqdm import tqdm
from matplotlib import pyplot as plt

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(0)

In [2]:
transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.5,], std=[0.5,])
])

trainset = tv.datasets.MNIST("MNIST/", train=True, transform=transform, download=True)
validset = tv.datasets.MNIST("MNIST/", train=False, transform=transform, download=True)
dataset = trainset + validset
train_loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0, pin_memory=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to MNIST/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:05<00:00, 1760158.13it/s]


Extracting MNIST/MNIST\raw\train-images-idx3-ubyte.gz to MNIST/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to MNIST/MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 141395.72it/s]


Extracting MNIST/MNIST\raw\train-labels-idx1-ubyte.gz to MNIST/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to MNIST/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:03<00:00, 425842.03it/s] 


Extracting MNIST/MNIST\raw\t10k-images-idx3-ubyte.gz to MNIST/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to MNIST/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4543412.54it/s]

Extracting MNIST/MNIST\raw\t10k-labels-idx1-ubyte.gz to MNIST/MNIST\raw






In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.D = nn.Sequential(
            # input is (1) x 28 x 28
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64) x 14 x 14
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (128) x 7 x 7
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (256) x 4 x 4
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            # state size. (1) x 1 x 1
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.D(x)

In [4]:
class Generator(nn.Module):
    def __init__(self, noize_dim):

        super(Generator, self).__init__()

        self.G = nn.Sequential(
            # input is (100) x 1 x 1
            nn.ConvTranspose2d( noize_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # state size. (256) x 4 x 4
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # state size. (128) x 8 x 8
            nn.ConvTranspose2d( 128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # state size. (64) x 16 x 16
            nn.ConvTranspose2d( 64, 1, 4, 2, 3, bias=False),
            # state size. (1) x 28 x 28
            nn.Tanh()
        )

    def forward(self, x):
        return self.G(x)
        
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
noize_dim = 100
G = Generator(noize_dim).to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr = 1e-3)
D_optimizer = optim.Adam(D.parameters(), lr = 1e-3)

In [None]:
def D_train():
    D_optimizer.zero_grad()
    x_real = x.to(device)
    y_real = torch.ones(x.size(0)).to(device)
    x_real_predict = D(x_real)
    D_real_loss = criterion(x_real_predict.view(-1), y_real)
    D_real_loss.backward()

    noise = torch.tensor(torch.randn(x.size(0), noize_dim, 1, 1)).to(device)
    y_fake = torch.zeros(x.size(0)).to(device)
    x_fake = G(noise)
    x_fake_predict = D(x_fake)
    D_fake_loss = criterion(x_fake_predict.view(-1), y_fake)
    D_fake_loss.backward()

    D_total_loss = D_real_loss + D_fake_loss
    D_optimizer.step()

    return D_total_loss.item()

def G_train():
    G_optimizer.zero_grad()
    noise = torch.tensor(torch.randn(x.size(0), noize_dim, 1, 1)).to(device)
    y_target = torch.ones(x.size(0)).to(device)
    x_fake = G(noise)
    y_fake = D(x_fake)
    G_loss = criterion(y_fake.view(-1), y_target)
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

epochs = 1000
early_stopping = 100
stop_cnt = 0
show_loss = True
best_loss = float('inf')
loss_record = {'Discriminator': [], 'Generator': []}

for epoch in range(epochs):
    train_pbar = tqdm(train_loader, position=0, leave=True)
    D_record, G_record = [], []
    for idx, (x, _) in enumerate(train_pbar):
        D_loss = D_train()
        G_loss = G_train()

        D_record.append(D_loss)
        G_record.append(G_loss)
        
        train_pbar.set_description(f'Train Epoch {epoch}')
        train_pbar.set_postfix({'D_loss': f'{D_loss:.3f}', 'G_loss': f'{G_loss:.3f}'})
    
    D_loss = sum(D_record) / len(D_record)
    G_loss = sum(G_record) / len(G_record)

    loss_record['Discriminator'].append(D_loss)
    loss_record['Generator'].append(G_loss)

    if G_loss < best_loss:
        best_loss = G_loss
        torch.save(D.state_dict(), 'D_model.ckpt')
        torch.save(G.state_dict(), 'G_model.ckpt')
        print(f'Saving Model With Loss {best_loss:.5f}')
        stop_cnt = 0
    else:
        stop_cnt += 1

    if stop_cnt == early_stopping:
        output = "Model can't improve, stop training"
        print('-' * (len(output) + 2))
        print(f'|{output}|')
        print('-' * (len(output) + 2))
        break

    print(f'D_Loss: {D_loss:.5f} G_Loss: {G_loss:.5f}', end='| ')
    print(f'Best Loss: {best_loss:.5f}', end='\n\n')

  noise = torch.tensor(torch.randn(x.size(0), noize_dim, 1, 1)).to(device)
  noise = torch.tensor(torch.randn(x.size(0), noize_dim, 1, 1)).to(device)
Train Epoch 0:  19%|█▉        | 105/547 [00:23<01:42,  4.31it/s, D_loss=0.001, G_loss=11.049]

In [None]:
import cv2

G = Generator(noize_dim)
G.load_state_dict(torch.load('G_model.ckpt'))
G.eval().to(device)
noize = torch.tensor(torch.randn(1, noize_dim, 1, 1)).to(device) 
fake = G(noize)
fake = np.array(fake.detach().cpu())
for cnt, img in enumerate(fake):
    npimg = (img/2+0.5)*255        
    npimg = np.transpose(npimg, (1, 2, 0))      
    #cv2.imwrite(f'fake_image/fake_{cnt}.png', npimg.astype('uint8'))
plt.imshow(npimg)