In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import import_ipynb
from model_conditional import Generator, Discriminator, initialize_weights, gradient_penalty

importing Jupyter notebook from model_conditional.ipynb


In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Optimisation ---
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 5
CRITIC_ITERATIONS = 5  # critic updates performed per generator update
LAMBDA_GP = 10  # weight of the gradient-penalty term in the critic loss

# --- Data ---
IMG_SIZE = 64
CHANNELS_IMG = 1
NUM_CLASSES = 10

# --- Model ---
# Z_DIM was previously assigned twice in this cell; it is now defined once so
# that editing it cannot be silently overridden by a later duplicate.
Z_DIM = 100
GEN_EMBEDDING = 100
FEATURES_DISC = 64
FEATURES_GEN = 64

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Preprocessing applied to every image: resize to IMG_SIZE, convert to a tensor,
# then normalise each channel with mean 0.5 and std 0.5.
#
# The pipeline is stored under the name `transforms` because the dataset cell
# passes it by that name. That assignment shadows the `transforms` module alias
# from the import cell, so the transform classes are looked up through
# `torchvision.transforms` instead. Without this, running the cell a second time
# fails because the name no longer refers to the module.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(IMG_SIZE),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [4]:
# MNIST training split, stored under dataset/ (downloaded if missing) and
# preprocessed with the pipeline bound to `transforms`.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms,
    download=True,
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build both networks on the target device, then apply the project's custom
# weight initialisation to each of them.
gen = Generator(
    Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING
).to(device)
disc = Discriminator(
    CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES, IMG_SIZE
).to(device)
initialize_weights(gen)
initialize_weights(disc)

In [5]:
# Separate Adam optimisers for the generator and the critic, both using the
# same learning rate and betas.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# A fixed batch of 32 latent vectors.
# NOTE(review): the training loop in this notebook samples fresh noise for its
# image grids and never reads `fixed_noise` — confirm whether it is still needed.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# One TensorBoard writer per image stream; `step` is the global step used when
# logging image grids. Plain strings: these paths have no placeholders, so the
# f-string prefix was unnecessary.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

# Put both networks in training mode before the training loop starts.
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
  (embed): Embedding(10

In [6]:
# Conditional WGAN-GP training loop.
# For every batch the critic is updated CRITIC_ITERATIONS times, then the
# generator is updated once. Every 100 batches the losses are printed and a grid
# of real and generated images is written to TensorBoard.
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(loader):
        real = real.to(device)
        labels = labels.to(device)
        # The last batch of an epoch can be smaller than BATCH_SIZE, so anything
        # paired with `labels` must be sized from the batch itself.
        cur_batch_size = real.shape[0]

        # TRAIN CRITIC: max E[critic(real)] - E[critic(fake)], expressed as a
        # loss to minimise, plus the gradient penalty weighted by LAMBDA_GP.
        for _ in range(CRITIC_ITERATIONS):
            # Sized with cur_batch_size (not BATCH_SIZE) so noise and labels
            # always have the same number of rows.
            noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
            fake = gen(noise, labels)
            disc_real = disc(real, labels).reshape(-1)
            disc_fake = disc(fake, labels).reshape(-1)
            gp = gradient_penalty(disc, labels, real, fake, device=device)
            loss_disc = (
                -(torch.mean(disc_real) - torch.mean(disc_fake)) + LAMBDA_GP * gp
            )
            disc.zero_grad()
            # retain_graph keeps the generator's part of the graph alive: the
            # generator step below backpropagates through this same `fake`.
            loss_disc.backward(retain_graph=True)
            opt_disc.step()

        # TRAIN GENERATOR: min -E[critic(gen_fake)]
        output = disc(fake, labels).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log image grids to TensorBoard.
        if batch_idx % 100 == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the next line's indentation in the printed message.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(noise, labels)
                # Take out (up to) 32 examples from each stream.
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                # `normalize` was previously misspelled here, so the fake grid
                # was not normalised like the real one.
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/5] Batch 0/938                 Loss D: -13.8317, loss G: 10.6475


KeyboardInterrupt: 