In [1]:
import os, torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

from models.vis_utils import show_img_batch
from models.discriminator_utils import DiscConvBlock
from models.generator_utils import GenConvTransposeBlock
from models.utils import get_noise, weights_init
from models.DCGAN.mnist import (Generator as MnistGenerator, 
                                Discriminator as MnistDiscriminator)
from models.DCGAN.train import train
%matplotlib inline

In [2]:
# --- Loss ---
# BCEWithLogitsLoss takes raw logits, so the discriminator output is not
# passed through a sigmoid before the loss.
criterion = nn.BCEWithLogitsLoss()

# --- Sizes ---
n_samples = 100                 # noise vectors drawn for the smoke-test forward pass
z_dim = batch_size = 128        # latent dimension and mini-batch size
hidden_dim = 64                 # base channel width handed to both networks

# --- Adam optimizer settings ---
lr, beta_1, beta_2 = 2e-4, 0.5, 0.999

In [3]:
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# computes (x - 0.5) / 0.5, i.e. rescales them to [-1, 1].
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# download=False: the dataset files must already be present under ../Data/.
mnist_dt = MNIST(root="../Data/", transform=mnist_transforms, download=False)
mnist_dl = DataLoader(dataset=mnist_dt, batch_size=batch_size, shuffle=True)

In [4]:
# Prefer Apple's Metal backend when present; otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Build both networks and move them to the selected device.
gen = MnistGenerator(z_dim=z_dim, hidden_dim=hidden_dim).to(device)
disc = MnistDiscriminator(im_chan=1, hidden_dim=hidden_dim).to(device)

# Both optimizers share the same Adam settings.
adam_kwargs = dict(lr=lr, betas=(beta_1, beta_2))
gen_opt = torch.optim.Adam(gen.parameters(), **adam_kwargs)
disc_opt = torch.optim.Adam(disc.parameters(), **adam_kwargs)

# Run the project's weights_init over each network (generator first, then
# discriminator). The optimizers above keep working because they hold
# references to the same parameter tensors.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

# Smoke test: push one batch of noise through generator and discriminator
# to confirm the two networks fit together.
noise_input = get_noise(n_samples, z_dim, device)
fake = gen(noise_input)
output = disc(fake)

In [6]:
# NOTE(review): the recorded output of this cell is
# "NameError: name 'mnist_dl' is not defined", so the cell does not survive
# Restart & Run All. Presumably `train` (imported from models.DCGAN.train)
# looks up `mnist_dl` as a global of its own module rather than receiving the
# loader as an argument -- TODO confirm and pass the loader/models explicitly,
# or drop this cell in favour of the inline training loop below.
train(50)

NameError: name 'mnist_dl' is not defined

In [None]:
# DCGAN training loop for MNIST.
#
# Each iteration performs one discriminator update followed by one generator
# update. Every `display_step` completed steps the mean losses of that window
# are printed and a real/fake batch is shown; every `save_step` completed
# steps the networks, optimizers and latest mean losses are checkpointed.
n_epochs = 100
display_step = 5000
save_step = 10000
cur_step = 0            # number of completed optimisation steps
generator_loss = 0      # running sums over the current logging window
discriminator_loss = 0

# Latest windowed means. Initialised here so the checkpoint code never hits an
# undefined name if `save_step` is changed to something `display_step` does
# not divide.
mean_gen_loss = float("nan")
mean_disc_loss = float("nan")

# torch.save does not create missing parent directories.
checkpoint_dir = "./assets/DCGAN"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(n_epochs):

    for real, _ in tqdm(mnist_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch_size
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        disc_fake = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_real = disc(real)
        # Targets are shaped after the real-image logits they are compared with.
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the generator step below runs fresh forward passes,
        # so this graph is never reused and can be freed immediately.
        disc_loss.backward()
        disc_opt.step()

        discriminator_loss += disc_loss.item()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(noise)
        gen_fake = disc(fake)
        # The generator is rewarded when the discriminator labels its output as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        generator_loss += gen_fake_loss.item()

        # Count the step before the periodic checks so that every logging
        # window contains exactly `display_step` accumulated losses.
        cur_step += 1

        ### Visualization code ###
        if cur_step % display_step == 0:
            mean_gen_loss = generator_loss / display_step
            mean_disc_loss = discriminator_loss / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0
            discriminator_loss = 0
            # (x + 1) / 2 undoes the Normalize((0.5,), (0.5,)) applied to the
            # real images; the same mapping is applied to the fakes.
            show_img_batch((real + 1) / 2)
            # detach() keeps plotting away from the autograd graph.
            show_img_batch((fake.detach() + 1) / 2)

        ### Checkpointing ###
        if cur_step % save_step == 0:
            torch.save({
                'generator': gen.state_dict(),
                'discriminator': disc.state_dict(),
                'gen_opt': gen_opt.state_dict(),
                'disc_opt': disc_opt.state_dict(),
                'generator_loss': mean_gen_loss,
                'discriminator_loss': mean_disc_loss
            }, f"{checkpoint_dir}/mnist_epoch_{epoch}_step_{cur_step}.pth")
            print(f"Models, optimizers, and mean losses saved at step {cur_step}")


In [21]:
from torchvision.datasets import ImageFolder

In [None]:
# CelebA preprocessing: convert to a tensor, centre-crop to 178x178, then
# downsample to 64x64. No Normalize step is applied here, so pixel values
# stay in the [0, 1] range produced by ToTensor.
# NOTE(review): if the generator's final block emits values in [-1, 1]
# (TODO confirm in GenConvTransposeBlock), real and generated images are on
# different scales; a Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step would
# align them.
celeba_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.CenterCrop((178, 178)),
        transforms.Resize((64, 64)),
    ]
)
celeba = ImageFolder("../Data/celeba", transform=celeba_transforms)
celeba_dl = DataLoader(dataset=celeba, batch_size=128, shuffle=True)

# Pull a single batch and preview it as a sanity check of the pipeline.
img_batch, labels = next(iter(celeba_dl))
show_img_batch(img_batch, size=(3, 64, 64))

In [29]:
class CelebGenerator(nn.Module):
    """Generator for the CelebA DCGAN.

    A stack of five GenConvTransposeBlock layers that starts from a noise
    vector viewed as a (z_dim, 1, 1) feature map. The first block uses
    stride 1 with no padding, the remaining four use stride 2 with padding 1.
    Presumably this grows the map 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per
    side -- TODO confirm against GenConvTransposeBlock.

    Args:
        z_dim: length of the input noise vector.
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width; inner layers use multiples of it.
    """

    def __init__(self, z_dim=256, im_chan=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim

        # Channel widths of the feature maps between consecutive blocks.
        widths = [hidden_dim * mult for mult in (16, 8, 4, 2)]

        # Stem: stride 1, no padding.
        blocks = [
            GenConvTransposeBlock(z_dim, widths[0], kernel_size=4, stride=1, padding=0)
        ]
        # Body: stride 2, padding 1, stepping through the channel plan.
        for c_in, c_out in zip(widths, widths[1:]):
            blocks.append(
                GenConvTransposeBlock(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
        # Head: maps to image channels; final=True selects the block's output variant.
        blocks.append(
            GenConvTransposeBlock(widths[-1], im_chan, kernel_size=4, stride=2, padding=1, final=True)
        )
        self.gen = nn.Sequential(*blocks)

    def unsqueeze_noise(self, x):
        """View a batch of noise vectors as (batch, z_dim, 1, 1) feature maps."""
        batch = len(x)
        return x.view(batch, self.z_dim, 1, 1)

    def forward(self, x):
        """Generate images from a batch of noise vectors."""
        feature_map = self.unsqueeze_noise(x)
        return self.gen(feature_map)
    
class CelebDiscriminator(nn.Module):
    """Discriminator for the CelebA DCGAN.

    A stack of five DiscConvBlock layers: four with kernel 4, stride 2 and
    padding 1, followed by a final kernel-4, stride-1 block with a single
    output channel. Presumably a 64x64 input is reduced to one logit per
    image -- TODO confirm DiscConvBlock's default padding.

    Args:
        im_chan: number of channels of the input image.
        hidden_dim: base channel width; inner layers use multiples of it.
    """

    def __init__(self, im_chan=3, hidden_dim=64):
        super().__init__()
        # The discriminator has no latent dimension. The former
        # `self.z_dim = z_dim` read a notebook-level global (z_dim is not a
        # parameter here) and raised NameError without it, so it is removed.
        self.disc = nn.Sequential(
            DiscConvBlock(im_chan,         hidden_dim * 1, kernel_size=4, stride=2, padding=1),
            DiscConvBlock(hidden_dim * 1,  hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            DiscConvBlock(hidden_dim * 2,  hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            DiscConvBlock(hidden_dim * 2,  hidden_dim * 4, kernel_size=4, stride=2, padding=1),
            DiscConvBlock(hidden_dim * 4,  1, kernel_size=4, stride=1, final=True)
        )

    def forward(self, x):
        """Score a batch of images; the result is reshaped to a single column."""
        return self.disc(x).view(-1, 1)

In [30]:
# Training params
criterion = nn.BCEWithLogitsLoss()
display_step = 2500
batch_size = 128

# Model params
im_chan = 3
z_dim = 128
size = (3,64,64)

# Optimizer params
lr = 0.0002
beta_1 = 0.5 
beta_2 = 0.999

In [31]:
# Initialize the generator and discriminator
gen = CelebGenerator(z_dim=z_dim, im_chan=im_chan).to(device)
disc = CelebDiscriminator(im_chan=im_chan).to(device)

# Initialize the optimizers for generator and discriminator
gen_opt = torch.optim.Adam(gen.parameters(),   lr=lr, betas=(beta_1, beta_2))
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [None]:
# DCGAN training loop for CelebA.
#
# Each iteration performs one discriminator update followed by one generator
# update. Every `display_step` completed steps the mean losses of that window
# are printed and a real/fake batch is shown.
n_epochs = 100
cur_step = 0            # number of completed optimisation steps
generator_loss = 0      # running sums over the current logging window
discriminator_loss = 0

for epoch in range(n_epochs):

    for real, _ in tqdm(celeba_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch_size
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(noise)
        # detach(): the discriminator update must not backpropagate into the generator.
        disc_fake = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_real = disc(real)
        # Targets are shaped after the real-image logits they are compared with.
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the generator step below runs fresh forward passes,
        # so this graph is never reused and can be freed immediately.
        disc_loss.backward()
        disc_opt.step()

        discriminator_loss += disc_loss.item()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(noise)
        gen_fake = disc(fake)
        # The generator is rewarded when the discriminator labels its output as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        generator_loss += gen_fake_loss.item()

        # Count the step before the periodic check so that every logging
        # window contains exactly `display_step` accumulated losses.
        cur_step += 1

        ### Visualization code ###
        if cur_step % display_step == 0:
            mean_gen_loss = generator_loss / display_step
            mean_disc_loss = discriminator_loss / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0
            discriminator_loss = 0
            # NOTE(review): images are shown without rescaling. If the
            # generator output lies in [-1, 1] (TODO confirm), the fake batch
            # needs (fake + 1) / 2 before display.
            show_img_batch(real, size=size)
            # detach() keeps plotting away from the autograd graph.
            show_img_batch(fake.detach(), size=size)
