In [1]:
# Mount Google Drive first: every checkpoint / sample path used later in this
# notebook lives under /content/drive.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image

In [3]:
# Prefer the GPU when the runtime has one; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))
])

In [5]:
batch_size = 100

In [6]:
mnist_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [7]:
# Peek at the first training sample to derive the flattened input length
# that both MLPs are sized from.
dataiter = iter(mnist_dataset)
image, _ = next(dataiter)  # label is not needed here
print(f"Image shape = {image.shape}")

height, width = image.shape[1], image.shape[2]
mnist_length = height * width
print(f"Data Length = {mnist_length}")

Image shape = torch.Size([1, 28, 28])
Data Length = 784


In [8]:
# Latent noise dimension and the Adam learning rate shared by G and D.
z_dim, lr = 100, 0.0002

In [9]:
class Generator(nn.Module):
    """MLP generator: latent vector -> flattened image with values in [-1, 1].

    Args:
        latent_dim: size of the input noise vector. Defaults to the
            notebook-level ``z_dim`` when omitted.
        out_features: length of the flattened output image. Defaults to the
            notebook-level ``mnist_length`` when omitted.
    """

    def __init__(self, latent_dim=None, out_features=None):
        super().__init__()
        # Fall back to the notebook globals lazily so `Generator()` keeps its
        # original behaviour while explicit sizes need no globals at all.
        if latent_dim is None:
            latent_dim = z_dim
        if out_features is None:
            out_features = mnist_length
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, out_features)

    def forward(self, x):
        """Map a (batch, latent_dim) noise tensor to (batch, out_features)."""
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # tanh matches the [-1, 1] range produced by Normalize(0.5, 0.5).
        return torch.tanh(self.fc4(x))

In [10]:
class Discriminator(nn.Module):
    """MLP discriminator: flattened image -> probability that it is real.

    Args:
        in_features: length of the flattened input. Defaults to the
            notebook-level ``mnist_length`` when omitted.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, in_features=None, dropout=0.3):
        super().__init__()
        if in_features is None:
            in_features = mnist_length
        self.p_drop = dropout  # plain float; not part of state_dict
        self.fc1 = nn.Linear(in_features, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in (0, 1)."""
        # F.dropout defaults to training=True, which would keep dropping
        # units even after D.eval(); tie it to the module's mode instead.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, self.p_drop, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, self.p_drop, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, self.p_drop, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [11]:
# nn.Module.to() moves parameters in place and returns the module itself,
# so construction and device placement can be chained.
G = Generator().to(device)
D = Discriminator().to(device)
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [12]:
# D ends in a sigmoid, so plain binary cross-entropy on probabilities.
criterion = nn.BCELoss()

# One independent Adam optimizer per network, sharing the same learning rate.
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=lr) for net in (G, D))

In [13]:
def G_train():
    """Run one generator update and return the scalar loss as a float.

    Samples ``batch_size`` latent vectors, scores G's output with D, and
    trains G so that D labels those fakes as real (target = 1).
    """
    G.zero_grad()

    z = torch.randn(batch_size, z_dim, device=device)
    label = torch.ones(batch_size, 1, device=device)

    loss = criterion(D(G(z)), label)
    # backward() also deposits gradients in D's parameters; D_train calls
    # D.zero_grad() before its own step, so they are never applied.
    loss.backward()
    G_optimizer.step()

    return loss.item()

In [14]:
def D_train(images):
    """Run one discriminator update on a batch of real images.

    Args:
        images: batch of real images already on ``device``, each flattening
            to ``mnist_length`` values.

    Returns:
        float: BCE(D(real), 1) + BCE(D(fake), 0) for this batch.
    """
    D.zero_grad()

    images = images.view(-1, mnist_length)
    # Size labels and noise from the actual batch: the final DataLoader batch
    # is smaller whenever len(dataset) is not a multiple of batch_size.
    n = images.size(0)

    labels_real = torch.ones(n, 1, device=device)
    loss_real = criterion(D(images), labels_real)

    z = torch.randn(n, z_dim, device=device)
    labels_fake = torch.zeros(n, 1, device=device)
    # detach(): this step only updates D, so don't build/backprop G's graph.
    fake_images = G(z).detach()
    loss_fake = criterion(D(fake_images), labels_fake)

    loss_total = loss_real + loss_fake
    loss_total.backward()
    D_optimizer.step()

    return loss_total.item()

In [15]:
# Project folder on the mounted Google Drive; checkpoints and samples go below it.
filepath = "/content/drive/MyDrive/Colab Notebooks/2022/MNIST GAN"

In [16]:
import os.path

# Resume from a previous run if a complete checkpoint exists on Drive.
# Each file is restored into the object listed next to it.
checkpoint_targets = {
    "G_model_state_dict.pt": G,
    "G_optim_state_dict.pt": G_optimizer,
    "D_model_state_dict.pt": D,
    "D_optim_state_dict.pt": D_optimizer,
}
checkpoint_paths = {name: f"{filepath}/model/{name}" for name in checkpoint_targets}

# Require all four files: a partial checkpoint would otherwise raise
# FileNotFoundError halfway through loading.
if all(os.path.exists(path) for path in checkpoint_paths.values()):
    for name, target in checkpoint_targets.items():
        # map_location lets checkpoints saved on a GPU runtime load on a
        # CPU-only one (and vice versa).
        target.load_state_dict(torch.load(checkpoint_paths[name], map_location=device))
    print('Load Complete')
else:
    print('Load Fail')

Load Complete


In [17]:
n_epoch = 200
best_g_loss = float('inf')
model_dir = f"{filepath}/model"
# torch.save does not create parent directories.
os.makedirs(model_dir, exist_ok=True)
n_batches = len(data_loader)

for epoch in range(n_epoch):
    G.train()
    D.train()
    D_loss_list, G_loss_list = [], []
    for batch_idx, (images, _) in enumerate(data_loader):
        images = images.to(device)
        D_loss_list.append(D_train(images))
        G_loss_list.append(G_train())

        if batch_idx % 100 == 0:
            print(f"Iter {batch_idx}/{n_batches} Complete")

    loss_d = torch.tensor(D_loss_list).mean().item()
    loss_g = torch.tensor(G_loss_list).mean().item()
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, loss_d, loss_g))

    # Checkpoint whenever the epoch-mean generator loss improves. The
    # previous version assigned to a misspelled `best_loss`, so
    # `best_g_loss` stayed at inf and this branch fired every epoch.
    if loss_g < best_g_loss:
        best_g_loss = loss_g
        # deepcopy: state_dict() returns views of the live tensors, which
        # would keep changing as training continues.
        best_G_model_state = copy.deepcopy(G.state_dict())
        best_G_optim_state = copy.deepcopy(G_optimizer.state_dict())
        torch.save(best_G_model_state, f"{model_dir}/G_model_state_dict.pt")
        torch.save(best_G_optim_state, f"{model_dir}/G_optim_state_dict.pt")

        torch.save(D.state_dict(), f"{model_dir}/D_model_state_dict.pt")
        torch.save(D_optimizer.state_dict(), f"{model_dir}/D_optim_state_dict.pt")


Iter 0/600 Complete


KeyboardInterrupt: ignored

In [21]:
# Generate one grid of samples from the current generator and save it to Drive.
with torch.no_grad():
    # Create the noise directly on `device`: a bare `z.to(device)` returns a
    # new tensor and leaves `z` on the CPU, which breaks when G is on a GPU.
    z = torch.randn(batch_size, z_dim, device=device)
    sample = G(z)
    # G ends in tanh ([-1, 1]) but save_image clamps to [0, 1]; undo the
    # Normalize(0.5, 0.5) mapping so the dark half of the range isn't clipped.
    sample_01 = sample * 0.5 + 0.5
    os.makedirs(f"{filepath}/samples", exist_ok=True)
    save_image(sample_01.view(sample_01.size(0), 1, 28, 28), f"{filepath}/samples/sample.png")