In [None]:
# Mount Google Drive at /content/drive. Later cells read and write model
# checkpoints and sample images under a folder inside this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
# Preprocessing: ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them
# to [-1, 1], matching the range of the generator's tanh output.
# Bound to `transform` (not `transforms`) so the torchvision module is not
# shadowed -- otherwise re-running this cell fails on `transforms.Compose`.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

batch = 100  # mini-batch size shared by the loaders, training and sampling

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch, shuffle=False)

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator: latent vector -> flattened sample.

    Args:
        in_features: length of the input noise vector z.
        out_features: length of each generated output vector (in this
            notebook, a flattened image).

    The final tanh bounds every output value to [-1, 1].
    """

    def __init__(self, in_features, out_features):
        super(Generator, self).__init__()
        # Attribute names fc1..fc4 are kept so saved state_dicts still load.
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, out_features)

    def forward(self, x):
        # F.dropout defaults to training=True, so dropout must be tied to the
        # module's mode explicitly; otherwise it stays active after .eval().
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = torch.tanh(self.fc4(x))
        return x

class Discriminator(nn.Module):
    """Fully connected GAN discriminator: flattened sample -> probability.

    Args:
        in_features: length of each input vector (in this notebook, a
            flattened image).

    Returns, per sample, a single sigmoid output in (0, 1), i.e. shape
    ``(N, 1)`` for an ``(N, in_features)`` input.
    """

    def __init__(self, in_features):
        super(Discriminator, self).__init__()
        # Attribute names fc1..fc4 are kept so saved state_dicts still load.
        self.fc1 = nn.Linear(in_features, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)

    def forward(self, x):
        # F.dropout defaults to training=True, so dropout must be tied to the
        # module's mode explicitly; otherwise it stays active after .eval().
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = torch.sigmoid(self.fc4(x))
        return x

In [None]:
z_features = 100  # length of the generator's latent noise vector

# Both networks are MLPs, so each image is handled as one flat vector.
# NOTE: `.shape` of an image tensor is (height, width), so the names `col`/`row`
# are swapped relative to their contents; only their product is used here.
col, row = train_dataset.data[0].shape
data_features = col * row

# nn.Module.to() moves parameters in place and returns the module itself.
G = Generator(in_features=z_features, out_features=data_features).to(device)
D = Discriminator(in_features=data_features).to(device)
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [None]:
# D already applies a sigmoid, so plain BCE on probabilities is the loss.
criterion = nn.BCELoss()

lr = 0.0002  # learning rate shared by both Adam optimizers
D_optimizer = optim.Adam(D.parameters(), lr=lr)
G_optimizer = optim.Adam(G.parameters(), lr=lr)

In [None]:
def G_train(x):
    """Perform one generator update and return its loss.

    Fresh noise is pushed through G and scored by D against "real" (1)
    labels, i.e. G minimises BCE(D(G(z)), 1) = -log D(G(z)).

    Args:
        x: the current real batch from the data loader. Only its first
            dimension is used, so G produces as many samples as the real
            batch has (this also covers a smaller final batch).

    Returns:
        The generator loss for this step as a Python float.
    """
    G.zero_grad()

    n_samples = x.size(0)
    z = torch.randn(n_samples, z_features, device=device)
    y = torch.ones(n_samples, 1, device=device)

    D_output = D(G(z))
    G_loss = criterion(D_output, y)

    # backward() also leaves gradients on D's parameters; D_train clears them
    # with D.zero_grad() before D's next optimizer step.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
def D_train(x):
    """Perform one discriminator update and return its loss.

    D is trained to output 1 on the real batch and 0 on an equally sized
    batch of generator samples; the two BCE terms are summed.

    Args:
        x: the current real batch from the data loader; it is flattened to
            ``(N, data_features)`` before being scored.

    Returns:
        The summed real + fake discriminator loss as a Python float.
    """
    D.zero_grad()

    x_real = x.view(-1, data_features).to(device)
    # Size labels by the actual batch so a smaller final batch still matches.
    n_samples = x_real.size(0)
    y_real = torch.ones(n_samples, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    z = torch.randn(n_samples, z_features, device=device)
    # detach(): only D is updated here, so don't backpropagate into G.
    x_fake = G(z).detach()
    y_fake = torch.zeros(n_samples, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
# Experiment root on the mounted Drive; checkpoints and samples live below it.
filepath = "/content/drive/MyDrive/Colab Notebooks/2022/MNIST GAN"

# Separate checkpoint directories for the generator and the discriminator.
G_modelpath, D_modelpath = f"{filepath}/model/G_state", f"{filepath}/model/D_state"

In [None]:
# Resume from saved checkpoints, but only when all four files are present, so
# G and D are never restored from mismatched runs. On a fresh run nothing is
# loaded and training starts from the newly initialised models.
checkpoint_files = {
    'G_model': f"{G_modelpath}/G_model_state_dict.pt",
    'G_optim': f"{G_modelpath}/G_optim_state_dict.pt",
    'D_model': f"{D_modelpath}/D_model_state_dict.pt",
    'D_optim': f"{D_modelpath}/D_optim_state_dict.pt",
}
if all(os.path.exists(path) for path in checkpoint_files.values()):
    # map_location remaps tensors saved on another device (e.g. a GPU runtime)
    # onto the current one, so CPU-only sessions can still resume.
    G.load_state_dict(torch.load(checkpoint_files['G_model'], map_location=device))
    G_optimizer.load_state_dict(torch.load(checkpoint_files['G_optim'], map_location=device))
    D.load_state_dict(torch.load(checkpoint_files['D_model'], map_location=device))
    D_optimizer.load_state_dict(torch.load(checkpoint_files['D_optim'], map_location=device))
else:
    print("No complete checkpoint found; training from scratch.")

In [None]:
n_epoch = 200
best_g_loss = float('inf')

# torch.save does not create missing parent directories.
os.makedirs(G_modelpath, exist_ok=True)
os.makedirs(D_modelpath, exist_ok=True)

for epoch in range(1, n_epoch + 1):
    # Back to training mode in case an earlier cell switched to eval mode.
    G.train()
    D.train()

    D_losses, G_losses = [], []
    for x, _ in train_loader:
        # One discriminator step, then one generator step, per real batch.
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    loss_d = torch.mean(torch.FloatTensor(D_losses)).item()
    loss_g = torch.mean(torch.FloatTensor(G_losses)).item()
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, loss_d, loss_g))

    # Checkpoint whenever the mean generator loss reaches a new minimum.
    # NOTE(review): in GAN training a lower G loss does not necessarily mean
    # better samples; consider also saving periodically.
    if loss_g < best_g_loss:
        best_g_loss = loss_g
        # state_dict() returns tensors that share storage with the live model /
        # optimizer, so deep-copy to keep a real snapshot of the best state.
        best_G_model_state = copy.deepcopy(G.state_dict())
        best_G_optim_state = copy.deepcopy(G_optimizer.state_dict())
        torch.save(best_G_model_state, f"{G_modelpath}/G_model_state_dict.pt")
        torch.save(best_G_optim_state, f"{G_modelpath}/G_optim_state_dict.pt")

        torch.save(D.state_dict(), f"{D_modelpath}/D_model_state_dict.pt")
        torch.save(D_optimizer.state_dict(), f"{D_modelpath}/D_optim_state_dict.pt")


[1/200]: loss_d: 0.644, loss_g: 4.829
[2/200]: loss_d: 0.026, loss_g: 10.272
[3/200]: loss_d: 0.213, loss_g: 10.108
[4/200]: loss_d: 0.637, loss_g: 4.472
[5/200]: loss_d: 0.417, loss_g: 3.810
[6/200]: loss_d: 0.259, loss_g: 4.424
[7/200]: loss_d: 0.263, loss_g: 4.725
[8/200]: loss_d: 0.343, loss_g: 4.090
[9/200]: loss_d: 0.394, loss_g: 3.455
[10/200]: loss_d: 0.391, loss_g: 3.434
[11/200]: loss_d: 0.394, loss_g: 3.251
[12/200]: loss_d: 0.502, loss_g: 2.827
[13/200]: loss_d: 0.565, loss_g: 2.549
[14/200]: loss_d: 0.545, loss_g: 2.611
[15/200]: loss_d: 0.660, loss_g: 2.220
[16/200]: loss_d: 0.667, loss_g: 2.156
[17/200]: loss_d: 0.710, loss_g: 2.052
[18/200]: loss_d: 0.773, loss_g: 1.942
[19/200]: loss_d: 0.768, loss_g: 1.955
[20/200]: loss_d: 0.785, loss_g: 1.892
[21/200]: loss_d: 0.812, loss_g: 1.815
[22/200]: loss_d: 0.826, loss_g: 1.787
[23/200]: loss_d: 0.834, loss_g: 1.759
[24/200]: loss_d: 0.860, loss_g: 1.733
[25/200]: loss_d: 0.912, loss_g: 1.617
[26/200]: loss_d: 0.888, loss_g:

In [None]:
# Manually checkpoint the *current* weights (e.g. after interrupting training).
# NOTE: this writes to the same files the training loop uses for its
# lowest-G-loss checkpoint, so it overwrites that checkpoint despite the
# `best_*` variable names below.
os.makedirs(G_modelpath, exist_ok=True)  # torch.save does not create directories
os.makedirs(D_modelpath, exist_ok=True)

best_G_model_state = G.state_dict()
best_G_optim_state = G_optimizer.state_dict()
torch.save(best_G_model_state, f"{G_modelpath}/G_model_state_dict.pt")
torch.save(best_G_optim_state, f"{G_modelpath}/G_optim_state_dict.pt")

torch.save(D.state_dict(), f"{D_modelpath}/D_model_state_dict.pt")
torch.save(D_optimizer.state_dict(), f"{D_modelpath}/D_optim_state_dict.pt")

In [None]:
number = 1  # index embedded in the output file name
os.makedirs(f"{filepath}/samples", exist_ok=True)

# Sample in eval mode (disables dropout in models that honour self.training),
# then restore whatever mode G was in before.
was_training = G.training
G.eval()
try:
    with torch.no_grad():
        test_z = torch.randn(batch, z_features, device=device)
        generated = G(test_z)

        # G ends in tanh, so outputs are in [-1, 1] (the same range as the
        # Normalize(0.5, 0.5) training data). save_image clamps to [0, 1], which
        # would turn every negative value black, so undo the normalisation first.
        # MNIST images are 28x28 (data_features == 784).
        images = generated.view(-1, 1, 28, 28) * 0.5 + 0.5
        save_image(images, f"{filepath}/samples/sample{number}.png")
finally:
    G.train(was_training)