In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

In [2]:
torch.__version__

'1.12.0+cu116'

In [3]:
# data init (-1,1)
transform = transforms.Compose([
    transforms.ToTensor(),  # 0-1:channel, high, witch, tensor
    transforms.Normalize(0.5, 0.5)
])


In [4]:
# import inside dataset
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      transform=transform,
                                      download=True)

In [5]:
dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

In [6]:
imgs, _ = next(iter(dataloader))

In [7]:
imgs.shape

torch.Size([64, 1, 28, 28])

In [8]:
# Generator
# input : 100-size noise
# output: (1,28,28) imgs
# linear 1: 100----256  input use ReLU
# linear 2: 256----512
# linear 3: 512----28*28
# reshape : 28*28--1,28,28  output use tanh
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):  # x: 100-size noise input.
        img = self.main(x)
        img = img.view(-1, 28, 28)
        return img

In [9]:
# Discriminator
# input (1,28,28)
# output True or False. use sigmoid
# BCEloss calculate
# hin Discriminator recommend use leakyReLU
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1, 2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        self.main(x)
        return x

In [10]:
# init
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [11]:
# init module
gen = Generator().to(device)
dis = Discriminator().to(device)

In [12]:
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)
g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)

In [13]:
loss_func = torch.nn.BCELoss()

In [14]:
# PLT
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(prediction.size(0)):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

In [15]:
test_input = torch.randn(16, 100, device=device)

In [16]:
# GAN train
D_loss = []
G_loss = []

for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)  # return number of train
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        d_optim.zero_grad()
        real_output = dis(img)  # input real img into dis
        d_real_loss = loss_func(real_output, torch.ones_like(real_output))

        d_real_loss.backward()

        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())  # dis input gen's img
        d_fake_loss = loss_func(fake_output, torch.zeros_like(fake_output))

        d_fake_loss.backward()

        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        g_optim.zero_grad()
        fake_output = dis(gen_img)
        g_loss = loss_func(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:', epoch)
        gen_img_plot(gen, test_input)


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn