In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter


In [2]:
writer = SummaryWriter('./logs')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x[0] / 255.0, x[1]))])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)

z_dim = 100
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)



In [3]:
class Generator(nn.Module):

    def __init__(self, g_input_dim, g_output_dim) -> None:
        super().__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)

        return torch.sigmoid(self.fc4(x))


class Discriminator(nn.Module):

    def __init__(self, d_input_dim) -> None:
        super().__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)

        return torch.sigmoid(self.fc4(x))    

In [4]:
G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(d_input_dim=mnist_dim).to(device)
criterion = nn.BCELoss()

writer.add_graph(G, input_to_model=torch.randn(batch_size, z_dim))
writer.add_graph(D, input_to_model=torch.randn(mnist_dim))

# optimizer
lr = 0.0002
g_optimizer = torch.optim.Adam(G.parameters(), lr = lr)
d_optimizer = torch.optim.Adam(D.parameters(), lr = lr)

	%input.5 : Float(32, 256, strides=[256, 1], requires_grad=1, device=cpu) = aten::dropout(%input.3, %27, %28) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
	%input.11 : Float(32, 512, strides=[512, 1], requires_grad=1, device=cpu) = aten::dropout(%input.9, %33, %34) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
	%input : Float(32, 1024, strides=[1024, 1], requires_grad=1, device=cpu) = aten::dropout(%input.15, %39, %40) # d:\development\miniconda\lib\site-packages\torch\nn\functional.py:1252:0
This may cause errors in trace checking. To disable trace checking, pass check_trace=False to torch.jit.trace()
  _check_trace(
Tensor-likes are not close!

Mismatched elements: 25080 / 25088 (100.0%)
Greatest absolute difference: 0.08442860841751099 at index (20, 333) (up to 1e-05 allowed)
Greatest relative difference: 0.18322071385465546 at index (0, 17) (up to 1e-05 allowed)
  _check_trace(
	%input.5 : Float(1024, strides=[1], requires_

In [5]:
def d_train(x):
    D.zero_grad()

    x_real, y_real = x.view(-1, mnist_dim).to(device), torch.ones(batch_size, 1).to(device)

    d_output = D(x_real)
    d_real_loss = criterion(d_output, y_real)
    d_real_score = d_output

    z = torch.randn(batch_size, z_dim).to(device)
    x_fake, y_fake = G(z), torch.zeros(batch_size, 1).to(device)

    d_output = D(x_fake)
    d_fake_loss = criterion(d_output, y_fake)
    d_fake_score = d_output

    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item()


def g_train(x):
    G.zero_grad()
    z = torch.randn(batch_size, z_dim).to(device)
    y = torch.randn(batch_size, 1).to(device)

    g_output = G(z)
    d_output =  D(g_output)
    g_loss = criterion(d_output, y)

    g_loss.backward()
    g_optimizer.step()

    return g_loss.item()

In [6]:
epochs = 10
step = 0
for epoch in range(epochs):
    d_losses, g_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        step += 1
        d_losses.append(d_train(x))
        g_losses.append(g_train(x))
        print('[%d/%d]: [%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch, epochs,batch_idx, len(train_loader), torch.mean(torch.FloatTensor(d_losses)), torch.mean(torch.FloatTensor(g_losses))))
        writer.add_scalar('g_loss', torch.mean(torch.FloatTensor(g_losses)), step)
        writer.add_scalar('d_loss', torch.mean(torch.FloatTensor(d_losses)), step)
        if batch_idx % 10 == 0:
            with torch.no_grad():
                test_z = torch.randn(batch_size, z_dim).to(device)
                generated = G(test_z)
                img = img = torchvision.utils.make_grid(generated.view(generated.size(0), 1, 28, 28))
                writer.add_image(f'mnist_{epoch}_{batch_idx}', img, global_step=step)

writer.close()                

[0/10]: [0/1875]: loss_d: 1.386, loss_g: 0.652
[0/10]: [1/1875]: loss_d: 1.352, loss_g: 0.570
[0/10]: [2/1875]: loss_d: 1.319, loss_g: 0.561
[0/10]: [3/1875]: loss_d: 1.286, loss_g: 0.563
[0/10]: [4/1875]: loss_d: 1.253, loss_g: 0.546
[0/10]: [5/1875]: loss_d: 1.215, loss_g: 0.509
[0/10]: [6/1875]: loss_d: 1.178, loss_g: 0.548
[0/10]: [7/1875]: loss_d: 1.143, loss_g: 0.500
[0/10]: [8/1875]: loss_d: 1.105, loss_g: 0.487
[0/10]: [9/1875]: loss_d: 1.066, loss_g: 0.432
[0/10]: [10/1875]: loss_d: 1.026, loss_g: 0.505
[0/10]: [11/1875]: loss_d: 0.986, loss_g: 0.472
[0/10]: [12/1875]: loss_d: 0.945, loss_g: 0.480
[0/10]: [13/1875]: loss_d: 0.903, loss_g: 0.445
[0/10]: [14/1875]: loss_d: 0.863, loss_g: 0.341
[0/10]: [15/1875]: loss_d: 0.825, loss_g: 0.319
[0/10]: [16/1875]: loss_d: 0.789, loss_g: 0.271
[0/10]: [17/1875]: loss_d: 0.752, loss_g: 0.328
[0/10]: [18/1875]: loss_d: 0.718, loss_g: 0.292
[0/10]: [19/1875]: loss_d: 0.686, loss_g: 0.261
[0/10]: [20/1875]: loss_d: 0.657, loss_g: 0.240
[0