In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [4]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5)
])

In [5]:
batch_size = 1

In [6]:
mnist_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms, download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [7]:
dataiter = iter(mnist_dataset)
image, _ = next(dataiter)
print(f"Image shape = {image.shape}")
data_length = image.shape[1] * image.shape[2]
print(f"Data Length = {data_length}")

Image shape = torch.Size([1, 28, 28])
Data Length = 784


In [8]:
z_dim = 100 
lr = 0.0002

In [9]:
# https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-an-mnist-handwritten-digits-from-scratch-in-keras/

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.fc1 = nn.Linear(z_dim, 6272)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        
        self.convT2 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        self.convT3 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.leaky_relu3 = nn.LeakyReLU(0.2)
        
        self.conv4 = nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid4 = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.leaky_relu1(out)
        out = out.view(batch_size, 128, 7, 7)

        out = self.convT2(out)
        out = self.leaky_relu2(out)

        out = self.convT3(out)
        out = self.leaky_relu3(out)

        out = self.conv4(out)
        out = self.sigmoid4(out)

        return out

In [10]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.dropout1 = nn.Dropout(p=0.4)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.dropout2 = nn.Dropout(p=0.4)

        self.fc3 = nn.Linear(64*7*7, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.leaky_relu1(out)
        out = self.dropout1(out)
        
        out = self.conv2(out)
        out = self.leaky_relu2(out)
        out = self.dropout2(out)
        out = out.view(batch_size, 7*7*64)

        out = self.fc3(out)

        return out

In [11]:
G = Generator()
D = Discriminator()

G.to(device)
D.to(device)

Discriminator(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (leaky_relu1): LeakyReLU(negative_slope=0.2)
  (dropout1): Dropout(p=0.4, inplace=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (leaky_relu2): LeakyReLU(negative_slope=0.2)
  (dropout2): Dropout(p=0.4, inplace=False)
  (fc3): Linear(in_features=3136, out_features=1, bias=True)
)

In [12]:
print(G)

Generator(
  (fc1): Linear(in_features=100, out_features=6272, bias=True)
  (leaky_relu1): LeakyReLU(negative_slope=0.2)
  (convT2): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (leaky_relu2): LeakyReLU(negative_slope=0.2)
  (convT3): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (leaky_relu3): LeakyReLU(negative_slope=0.2)
  (conv4): Conv2d(128, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (sigmoid4): Sigmoid()
)


In [13]:
criterion = nn.BCELoss()

G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

In [14]:
def train_G():
    G.zero_grad()

    # z = torch.randn(batch_size, z_dim, requires_grad=True).to(device)
    # y = torch.ones(batch_size, 1, requires_grad=True).to(device)
    z = Variable(torch.randn(batch_size, z_dim).to(device))
    y = Variable(torch.ones(batch_size, 1).to(device))

    G_output = G(z)

    D_output = D(G_output)
    G_loss = criterion(D_output, y)
    G_optimizer.step()

    return G_loss.item()

In [15]:
def train_D(x):
    # z = torch.randn(batch_size, z_dim, requires_grad=True).to(device)
    # x_fake = G(z)
    # x_real = x.view(batch_size, data_length).to(device)
    # y_real = torch.ones(batch_size, 1).to(device)
    # y_fake = torch.zeros(batch_size, 0).to(device)

    x_real, y_real = x.view(batch_size, data_length), torch.ones(batch_size, 1)
    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
    z = Variable(torch.randn(batch_size, z_dim).to(device))
    x_fake, y_fake = G(z), Variable(torch.zeros(batch_size, 1).to(device))

    D.zero_grad()

    D_real_output = D(x_real.view(batch_size, 1, 28, 28))
    D_real_loss = criterion(D_real_output, y_real)
    
    D_fake_output = D(x_fake)
    D_fake_loss = criterion(D_fake_output, y_fake)
    
    D_total_loss = D_real_loss + D_fake_loss
    D_total_loss.backward()
    D_optimizer.step()
    
    return D_total_loss.item()

In [16]:
filepath = "/content/drive/MyDrive/Colab Notebooks/2022/MNIST GAN/model"

In [17]:
import os.path

if os.path.exists(f"{filepath}/mode/G_model_state_dict.pt"):
    G.load_state_dict(torch.load(f"{filepath}/G_model_state_dict.pt"))
    G_optimizer.load_state_dict(torch.load(f"{filepath}/G_optim_state_dict.pt"))
    D.load_state_dict(torch.load(f"{filepath}/D_model_state_dict.pt"))
    D_optimizer.load_state_Dict(torch.load(f"{filepath}/D_optim_state_dict.pt"))
    print('Load Complete')
else:
    print('Not Loadable')

Not Loadable


In [18]:
torch.version.cuda

'11.6'

In [19]:
n_epoch = 200
best_g_loss = float('inf')
torch.backends.cudnn.enabled = False
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(data_loader):
        G_losses.append(train_G())
        D_losses.append(train_D(x))
        
        
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
    if loss_g < best_g_loss:
        best_loss = loss_g
        best_G_model_state = G.state_dict()
        best_G_optim_state = G_optimizer.state_dict()
        torch.save(best_G_model_state, f"{G_modelpath}/G_model_state_dict.pt")
        torch.save(best_G_optim_state, f"{G_modelpath}/G_optim_state_dict.pt")

        torch.save(D.state_dict(), f"{D_modelpath}/D_model_state_dict.pt")
        torch.save(D_optimizer.state_dict(), f"{D_modelpath}/D_optim_state_dict.pt")


RuntimeError: ignored