In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import torchvision
from torchvision import transforms
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


In [5]:
writer = SummaryWriter('./logs')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5, ), std=(0.5, ))])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [6]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
    
    # forward method
    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [7]:
z_dim = 100
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)

# build network
G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

writer.add_graph(G, input_to_model=torch.randn(batch_size, z_dim))
writer.add_graph(D, input_to_model=torch.randn(batch_size, mnist_dim))

# optimizer
lr = 0.0002
g_optimizer = optim.Adam(G.parameters(), lr = lr)
d_optimizer = optim.Adam(D.parameters(), lr = lr)

# loss
criterion = nn.BCELoss() 

	%input.5 : Float(32, 1024, strides=[1024, 1], requires_grad=1, device=cpu) = aten::dropout(%input.3, %27, %28) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
	%input.11 : Float(32, 512, strides=[512, 1], requires_grad=1, device=cpu) = aten::dropout(%input.9, %33, %34) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
	%input : Float(32, 256, strides=[256, 1], requires_grad=1, device=cpu) = aten::dropout(%input.15, %39, %40) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
Tensor-likes are not close!

Mismatched elements: 32 / 32 (100.0%)
Greatest absolute difference: 0.06438559293746948 at index (23, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.1335170647211262 at index (23, 0) (up to 1e-05 allowed)
  _check_trace(


In [8]:
summary(G, input_size=(batch_size, z_dim))

Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [32, 784]                 --
├─Linear: 1-1                            [32, 256]                 25,856
├─Linear: 1-2                            [32, 512]                 131,584
├─Linear: 1-3                            [32, 1024]                525,312
├─Linear: 1-4                            [32, 784]                 803,600
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
Total mult-adds (M): 47.56
Input size (MB): 0.01
Forward/backward pass size (MB): 0.66
Params size (MB): 5.95
Estimated Total Size (MB): 6.62

In [9]:
summary(D, input_size=(batch_size, mnist_dim))

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [32, 1]                   --
├─Linear: 1-1                            [32, 1024]                803,840
├─Linear: 1-2                            [32, 512]                 524,800
├─Linear: 1-3                            [32, 256]                 131,328
├─Linear: 1-4                            [32, 1]                   257
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
Total mult-adds (M): 46.73
Input size (MB): 0.10
Forward/backward pass size (MB): 0.46
Params size (MB): 5.84
Estimated Total Size (MB): 6.40

In [10]:
def d_train(x):
    D.zero_grad()

    x_real, y_real = x.view(-1, mnist_dim).to(device), torch.ones(batch_size, 1).to(device)

    print(x_real.shape, y_real.shape)
    d_output = D(x_real)

    print(d_output.shape, y_real.shape)
    d_real_loss = criterion(d_output, y_real)
    d_real_score = d_output

    z = torch.randn(batch_size, z_dim).to(device)
    x_fake, y_fake = G(z), torch.zeros(batch_size, 1).to(device)

    d_output = D(x_fake)
    d_fake_loss = criterion(d_output, y_fake)
    d_fake_score = d_output

    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item()


def g_train(x):
    G.zero_grad()
    z = torch.randn(batch_size, z_dim).to(device)
    y = torch.ones(batch_size, 1).to(device)

    g_output = G(z)
    d_output =  D(g_output)
    g_loss = criterion(d_output, y)

    g_loss.backward()
    g_optimizer.step()

    return g_loss.item()

In [11]:
epochs = 10
step = 0
for epoch in range(epochs):
    d_losses, g_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        step += 1
        d_losses.append(d_train(x))
        g_losses.append(g_train(x))
        print('[%d/%d]: [%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch, epochs,batch_idx, len(train_loader), torch.mean(torch.FloatTensor(d_losses)), torch.mean(torch.FloatTensor(g_losses))))
        writer.add_scalar('g_loss', torch.mean(torch.FloatTensor(g_losses)), step)
        writer.add_scalar('d_loss', torch.mean(torch.FloatTensor(d_losses)), step)
        if batch_idx % 10 == 0:
            with torch.no_grad():
                test_z = torch.randn(batch_size, z_dim).to(device)
                generated = G(test_z)
                img = img = torchvision.utils.make_grid(generated.view(generated.size(0), 1, 28, 28))
                writer.add_image(f'mnist_{epoch}_{batch_idx}', img, global_step=step)
    
    if epoch % 10 == 0:
        D.eval()
        G.eval()
        torch.save({
        'epoch': epoch,
        'd_model_state_dict': D.state_dict(),
        'g_model_state_dict': G.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
        'd_loss': d_losses,
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'g_loss': g_losses,
        }, f'./checkpoint/epoch{epoch}_weight.pth')
        D.train()
        G.train()

writer.close()                

torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [0/1875]: loss_d: 1.391, loss_g: 0.691
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [1/1875]: loss_d: 1.325, loss_g: 0.688
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [2/1875]: loss_d: 1.265, loss_g: 0.686
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [3/1875]: loss_d: 1.209, loss_g: 0.682
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [4/1875]: loss_d: 1.157, loss_g: 0.677
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [5/1875]: loss_d: 1.112, loss_g: 0.670
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [6/1875]: loss_d: 1.076, loss_g: 0.661
torch.Size([32, 784]) torch.Size([32, 1])
torch.Size([32, 1]) torch.Size([32, 1])
[0/10]: [7/1875

In [None]:
torch.save(D, './model/discriminator.pt')
torch.save(G, './model/generator.pt')