In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import matplotlib.pyplot as plt
import os

  from .autonotebook import tqdm as notebook_tqdm


# **GANs**

## **Data**

In [2]:

transform = transforms.Compose(
    [transforms.ToTensor(), # ToTensor scales the image to [0,1]
    transforms.Normalize( # converts pixels to [-1;1], min_value = (0-0.5)/0.5 = -1, max_value = (1-0.5)/0.5 = 1 for pixels in range(0,1)
        mean=(0.5),
        std=(0.5)
    )]
)

train_data = torchvision.datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform
)

print(len(train_data))

60000


In [3]:
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True)

for x, y in train_loader:
    print(x.shape, y.shape)
    print(torch.min(x), torch.max(x))
    break 

torch.Size([128, 1, 28, 28]) torch.Size([128])
tensor(-1.) tensor(1.)


## **Discriminator**

In [4]:
D = nn.Sequential(
    nn.Linear(784, 512),
    nn.LeakyReLU(0.2),
    nn.Linear(512, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),    
    )

## **Generator**

In [5]:
latent_dim = 100
G =  nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(256, momentum=0.7),

    nn.Linear(256, 512),
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(512, momentum=0.7),

    nn.Linear(512, 1024),
    nn.LeakyReLU(0.2),
    nn.BatchNorm1d(1024, momentum=0.7),

    nn.Linear(1024, 784),
    nn.Tanh() # keeps output values in the [-1,1] range
)

## **Device, loss, optimizer**

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print (device)
D.to(device)
G.to(device)

cpu


Sequential(
  (0): Linear(in_features=100, out_features=256, bias=True)
  (1): LeakyReLU(negative_slope=0.2)
  (2): BatchNorm1d(256, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
  (3): Linear(in_features=256, out_features=512, bias=True)
  (4): LeakyReLU(negative_slope=0.2)
  (5): BatchNorm1d(512, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
  (6): Linear(in_features=512, out_features=1024, bias=True)
  (7): LeakyReLU(negative_slope=0.2)
  (8): BatchNorm1d(1024, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
  (9): Linear(in_features=1024, out_features=784, bias=True)
  (10): Tanh()
)

In [7]:
criterion = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [8]:
# to save G outputs
if not os.path.exists('gan_images'):
    os.makedirs('gan_images')

In [9]:
def rescale_image(img):
    """Scale the image back to [0;1]"""
    return (img+1)/2

## **Training**

In [12]:
ones_ = torch.ones(batch_size, 1).to(device)
zeros_ = torch.zeros(batch_size, 1).to(device)

d_losses, g_losses = [], []

for epoch in range(200):
    for inputs, _ in train_loader: # targets are not used
        # reshaping in case the batch size doesn't divide evenly 
        # and is not equal to the specified one
        n = inputs.size(0)
        inputs.reshape(n, 784).to(device) # reshape to N*D

        ones= ones_[:n]
        zeros = zeros_[:n]

        # --- train discriminator ---
        # real images
        real_outputs = D(inputs)
        d_loss_real = criterion(real_outputs, ones)

        # fake images
        noise = torch.randn(n, latent_dim).to(device) 
        fake_images = G(noise)
        fake_outputs = D(fake_images)
        d_loss_fake = criterion(fake_outputs, zeros)

        d_loss = np.mean(d_loss_real, d_loss_fake)
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()  
        d_optimizer.step()  

        # --- train generator ---
        for _ in range(2): # we train G twice as much as D
            noise = torch.randn(n, latent_dim).to(device) 
            fake_images = G(noise)
            fake_outputs = D(fake_images)
            g_loss = criterion(fake_outputs, ones_) # /!\ this time we reverse the labels

            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            g_loss.backward()  
            g_optimizer.step() 
            
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    print(f'Epochs: {epoch+1}, G loss: {g_loss.item()}, D loss: {d_loss.item()}')
    # torchvision.utils.save_image(rescale_image(fake_images), f"gan_images/{epoch+1}.png")

0


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x28 and 784x512)