In [None]:
import torch
import torch.nn as nn
import time
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable

In [None]:
# Load the MNIST training split. ToTensor yields pixels in [0, 1]; the
# Normalize(mean=0.5, std=0.5) step rescales them to [-1, 1].
train_img = datasets.MNIST(
    root='MNIST/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]),
)
batch_size = 100
# Use `batch_size` (not a literal) so the loader and the rest of the notebook
# cannot silently disagree. drop_last=True makes every batch exactly
# `batch_size` images.
train_data = DataLoader(dataset=train_img, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
def one_hot_vector(labels, C = 10):
    """Encode integer class labels as a float one-hot matrix.

    Args:
        labels: 1-D int64 tensor of class indices, each in ``[0, C)``.
        C: number of classes, i.e. the width of the returned matrix.

    Returns:
        A float32 tensor of shape ``(labels.size(0), C)`` on the same device
        as ``labels``. Row ``i`` is 1.0 at column ``labels[i]`` and 0.0 elsewhere.
    """
    # The result follows `labels.device` instead of the notebook-global
    # `device`. That global is defined in a later cell, and using it risked an
    # index/target device mismatch. The deprecated `Variable` wrapper has been
    # a no-op since PyTorch 0.4 and is dropped.
    return torch.nn.functional.one_hot(labels, num_classes=C).float()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `batch_size` is already set in the data-loading cell. Redefining it here
# would let the two copies drift apart (loader batch vs. noise batch).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build the generator and discriminator on the selected device
# (left-to-right evaluation keeps the original G-then-D construction order).
G, D = Generator().to(device), Discriminator().to(device)

# One Adam optimizer per network, both with the same learning rate.
optim_G, optim_D = (torch.optim.Adam(net.parameters(), lr = 0.001) for net in (G, D))

# Binary cross-entropy on D's scores. This expects D to emit probabilities in
# [0, 1] (presumably a final sigmoid in Discriminator; confirm there).
criterion = torch.nn.BCELoss()

In [None]:
# train
# Conditional-GAN training loop. Each step updates G once (trying to make D
# score its fakes as real), then updates D once on a concatenated real+fake
# batch. Relies on G, D, optim_G, optim_D, criterion, device, train_data and
# one_hot_vector from earlier cells.
NUM_EPOCHS = 200
NOISE_DIM = 100    # length of the latent vector z fed to G
NUM_CLASSES = 10   # MNIST digits 0-9
SAMPLE_DIR = "/content/drive/MyDrive/Deep Learning/GAN/GAN Result"

total_batch = len(train_data)  # 600 with batch size 100

# Labels for the 10x10 preview grid: every row is the digits 0..9, so
# column j is always conditioned on digit j.
sample_label = torch.arange(NUM_CLASSES, device=device).repeat(10)
sample_oh = one_hot_vector(sample_label, NUM_CLASSES)

for epoch in range(NUM_EPOCHS):
    g_total, d_total = 0.0, 0.0
    for x, y in train_data:
        # Size everything from the actual batch, so the fake batch, the real
        # batch and the BCE targets can never disagree.
        n = x.size(0)
        x = x.view(n, -1).to(device)  # flatten each image to a vector
        x_oh = one_hot_vector(y.to(device), NUM_CLASSES)

        # Random N(0, 1) noise and random target digits for the fake batch.
        z = torch.randn(n, NOISE_DIM, device=device)
        z_label = torch.randint(NUM_CLASSES, (n,), device=device)
        z_oh = one_hot_vector(z_label, NUM_CLASSES)
        fake_img = G(z, z_oh)

        real = torch.ones(n, 1, device=device)
        fake = torch.zeros(n, 1, device=device)

        # train Generator: push D to label the fakes as real.
        optim_G.zero_grad()
        g_cost = criterion(D(fake_img, z_oh), real)
        # This backward also fills D's .grad; optim_D.zero_grad() below
        # clears those before D's own step.
        g_cost.backward()
        optim_G.step()

        # train Discriminator on real + (detached) fake images, so no
        # gradient flows back into G.
        fake_img = fake_img.detach()
        optim_D.zero_grad()
        d_cost = criterion(
            D(torch.cat((x, fake_img)), torch.cat((x_oh, z_oh))),
            torch.cat((real, fake)),
        )
        d_cost.backward()
        optim_D.step()

        # .item() accumulates plain floats. Summing the loss tensors would
        # keep every step's autograd graph alive for the whole epoch.
        g_total += g_cost.item()
        d_total += d_cost.item()
    avg_cost = [g_total / total_batch, d_total / total_batch]

    if (epoch+1) % 10 == 0 or epoch < 10:
        print(f"Epoch : {epoch + 1}, Generator : {avg_cost[0]}, Discriminator : {avg_cost[1]}")
        # Preview only: no autograd graph needed. Use the same N(0, 1) noise
        # distribution G was trained on.
        with torch.no_grad():
            z = torch.randn(sample_label.size(0), NOISE_DIM, device=device)
            preview = G(z, sample_oh).reshape([-1, 1, 28, 28])
        img_grid = make_grid(preview, nrow = 10, normalize = True)
        save_image(img_grid, f"{SAMPLE_DIR}/{epoch + 1}.png")