<a href="https://colab.research.google.com/github/MathBorgess/into_pytorch/blob/master/generative_models/gans_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn, optim
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [12]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, dense_units):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(input_dim, dense_units),
            nn.LeakyReLU(0.1), # a default choice for GANs
            nn.Linear(dense_units, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.disc(x)

In [17]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_dim, dense_units):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(latent_dim, dense_units),
            nn.LeakyReLU(0,1),
            nn.Linear(dense_units, img_dim),
            nn.Tanh()
        )
    def forward(self, x):
        return self.gen(x)

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-4
latent_dim = 8*8
image_dim = 28 * 28 * 1
dense_units = 252
batch_size = 32
epochs = 50

In [18]:
disc = Discriminator(image_dim, dense_units).to(device)
gen = Generator(latent_dim, image_dim, dense_units).to(device)

fixed_noise = torch.randn((batch_size, latent_dim)).to(device)
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

disc_optimizer = optim.Adam(disc.parameters(), lr=lr)
gen_optimizer = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()

writer_fake = SummaryWriter(f'runs/GAN/fake')
writer_real = SummaryWriter(f'runs/GAN/real')
step = 0

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74426547.21it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 36268171.80it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 66375140.57it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18023206.02it/s]


Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [19]:
for epoch in range(epochs):
    for batch_idx, (x, _) in enumerate(loader):
        x = x.to(device)
        x = x.view(-1, image_dim)

        # Disc training: max log(D(real)) + log(1 - D(G(x)))
        # BCELoss = -w_n [y_n log(x_n) + (1-y_n) log(1 - x_n)]
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake = gen(noise)
        disc_real = disc(x).view(-1)

        # only y_n log(x_n)
        loss_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake.detach()).view(-1)
        # only (1 - y_n)log(1 - x_n)
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (loss_fake + loss_real) / 2
        disc.zero_grad()
        # if the retain is not true, the nodes of the variable graph will be
        # removed from to save memory, so we wouldn't be able to use the fake again
        lossD.backward(retain_graph=True)
        disc_optimizer.step()

        # Train Generator: min log(1-D(G(z))) <-> max log(D(G(z)))
        output_disc = disc(fake).view(-1).to(device)
        loss_gen = criterion(output_disc, torch.ones_like(output_disc))
        gen.zero_grad()
        loss_gen.backward()
        gen_optimizer.step()


        with torch.no_grad():
            if batch_idx == 0:
                print(f' epoch {epoch}/{epochs} gen_loss: {loss_gen} disc_loss: {lossD}')

                fake_grid = torchvision.utils.make_grid(gen(noise).reshape(-1, 1, 28, 28), normalize=True)
                real_grid = torchvision.utils.make_grid(x.reshape(-1, 1, 28, 28), normalize=True)


                writer_fake.add_image(f'fake_',fake_grid, global_step=step)
                writer_real.add_image(f'real_', real_grid, global_step=step)

                step += 1

 epoch 0/50 gen_loss: 0.6919145584106445 disc_loss: 0.6896236538887024
 epoch 1/50 gen_loss: 1.248140811920166 disc_loss: 0.38303080201148987
 epoch 2/50 gen_loss: 0.927990734577179 disc_loss: 0.6187950372695923
 epoch 3/50 gen_loss: 1.0287690162658691 disc_loss: 0.44701051712036133
 epoch 4/50 gen_loss: 0.9006974697113037 disc_loss: 0.5167999863624573
 epoch 5/50 gen_loss: 1.4555137157440186 disc_loss: 0.35877102613449097
 epoch 6/50 gen_loss: 1.2283070087432861 disc_loss: 0.36854785680770874
 epoch 7/50 gen_loss: 1.281175136566162 disc_loss: 0.49029773473739624
 epoch 8/50 gen_loss: 1.6893389225006104 disc_loss: 0.2908696234226227
 epoch 9/50 gen_loss: 1.689479947090149 disc_loss: 0.39602190256118774
 epoch 10/50 gen_loss: 1.3565142154693604 disc_loss: 0.4471641182899475
 epoch 11/50 gen_loss: 1.4847629070281982 disc_loss: 0.47024962306022644
 epoch 12/50 gen_loss: 3.22965669631958 disc_loss: 0.1002495288848877
 epoch 13/50 gen_loss: 2.212141990661621 disc_loss: 0.22362785041332245
 

KeyboardInterrupt: ignored

In [20]:
%tensorboad --logdir runs/GAN

UsageError: Line magic function `%tensorboad` not found.
