In [2]:
from __future__ import print_function
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import argparse
import random
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import os

In [15]:
LEARNING_RATE, BATCH_SIZE, IMAGE_SIZE, EPOCHS, noise_channels, gen_features, disc_features, image_channels = 0.0005 , 64, 64, 300, 256, 64, 64, 3 

In [16]:
data_transforms = transforms.Compose([
        transforms.Resize((64,64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
])


In [22]:
dataset = datasets.ImageFolder(root="archive", transform=data_transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [23]:
class Generator(nn.Module):
    def __init__(self, noise_channels, image_channels, features):
        super(Generator, self).__init__()
        

        self.model = nn.Sequential(
            
            # Transpose block 1
            nn.ConvTranspose2d(noise_channels, features*16, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),

            # Transpose block 2
            nn.ConvTranspose2d(features*16, features*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*8),
            nn.ReLU(),

            # Transpose block 3
            nn.ConvTranspose2d(features*8, features*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*4),
            nn.ReLU(),

            # Transpose block 4
            nn.ConvTranspose2d(features*4, features*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*2),
            nn.ReLU(),

            # Last transpose block (different)
            nn.ConvTranspose2d(features*2, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.model(x)

In [24]:
class Discriminator(nn.Module):
    def __init__(self, image_channels, features):
        super(Discriminator, self).__init__()
        

        self.model = nn.Sequential(
            

            nn.Conv2d(image_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            # Conv block 2 
            nn.Conv2d(features, features*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*2),
            nn.LeakyReLU(0.2),
    
            # Conv block 3
            nn.Conv2d(features*2, features*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*4),
            nn.LeakyReLU(0.2),

            # Conv block 4
            nn.Conv2d(features*4, features*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*8),
            nn.LeakyReLU(0.2),

            # Conv block 5 (different)
            nn.Conv2d(features*8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

In [26]:
generator       = Generator(noise_channels, image_channels, gen_features).to(device)
discriminator   = Discriminator(image_channels, disc_features).to(device)
gen_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
disc_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()
fake_label = 0
real_label = 1

In [27]:
generator.train()
discriminator.train()

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (12): Sigmoid()
  )
)

In [28]:
fixed_noise = torch.randn(64, noise_channels, 1, 1).to(device)

In [30]:
import torchvision.utils as vutils
step = 0

for epoch in range(EPOCHS):
    

    for batch_idx, (data, target) in enumerate(dataloader):
        

        data = data.to(device)

        batch_size = data.shape[0]
        
        discriminator.zero_grad()
        label = (torch.ones(batch_size) * 0.9).to(device)
        output = discriminator(data).reshape(-1)
        real_disc_loss = criterion(output, label)
        d_x = output.mean().item()

        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = generator(noise)
        label = (torch.ones(batch_size) * 0.1).to(device)
        output = discriminator(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        disc_loss = real_disc_loss + fake_disc_loss


        disc_loss.backward()
        disc_optimizer.step()

        generator.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = discriminator(fake).reshape(-1)
        gen_loss = criterion(output, label)
        gen_loss.backward()
        gen_optimizer.step()

        if batch_idx % 30 == 0:
            step += 1
            
            print(
                f"Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)} ===== Disc loss: {disc_loss:.4f} ===== Gen loss: {gen_loss:.4f}"
            )

            with torch.no_grad():
                fake_images = generator(fixed_noise)

  
    fake_images = generator(fixed_noise)
    save_path = "generated_images_epoch_{}.png".format(epoch)
    vutils.save_image(fake_images, save_path, normalize=True)
    print("Generated images saved at '{}'".format(save_path))

Epoch: 0 ===== Batch: 0/68 ===== Disc loss: 0.9547 ===== Gen loss: 3.6233


KeyboardInterrupt: 

In [40]:
torch.save(generator, "gen.pt")
