# Model training



Dataset: https://www.kaggle.com/datasets/gpiosenka/100-bird-species

In [None]:
# piq provides the FID and SSIM metrics imported further down in this notebook.
!pip install piq

In [None]:
# Mount Google Drive at /content/drive; the dataset and the checkpoints are
# stored under drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Check that the Kaggle API token (kaggle.json) is present in the working directory.
!ls -lha kaggle.json

In [None]:
!pip install -q kaggle

In [None]:
# Copy the token into ~/.kaggle, where the kaggle CLI looks for it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

In [None]:
# Restrict the token to its owner. Use the same ~/.kaggle path as the copy
# above instead of a hard-coded /root so the two cells cannot disagree.
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!pwd

In [None]:
!kaggle datasets download -d "gpiosenka/100-bird-species"

In [None]:
# Create the target folder first: mv fails when the destination directory
# does not exist yet.
!mkdir -p drive/MyDrive/bhw2images/
!mv "100-bird-species.zip" drive/MyDrive/bhw2images/

In [None]:
!unzip "drive/MyDrive/bhw2images/100-bird-species.zip" -d "drive/MyDrive/bhw2images/bird-species/"

In [None]:
# Count the entries of the extracted train/ directory.
!ls 'drive/MyDrive/bhw2images/bird-species/train' | wc -l

In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# One seed shared by Python's and PyTorch's random number generators.
manualSeed = 2007
# Ask PyTorch to use deterministic algorithms.
torch.use_deterministic_algorithms(True)
torch.manual_seed(manualSeed)
random.seed(manualSeed)

In [None]:
# --- Data ---
dataroot = "drive/MyDrive/bhw2images/bird-species/train/"  # root folder handed to ImageFolder
workers = 2          # DataLoader worker processes
batch_size = 128     # images per training batch
image_size = 64      # images are resized and center-cropped to this size
nc = 3               # image channels (generator output / discriminator input)

# --- Model ---
nz = 100             # channels of the latent vector fed to the generator
ngf = 64             # base number of feature maps in the generator
ndf = 64             # base number of feature maps in the discriminator

# --- Optimisation ---
num_epochs = 1000    # passes over the dataloader per training cell
lr = 2e-4            # Adam learning rate for both networks
beta1 = 0.5          # first Adam beta for both networks

In [None]:
# Preprocessing: resize, center-crop to image_size, convert to a tensor and
# normalise every channel with mean 0.5 and std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
dataset = dset.ImageFolder(root=dataroot, transform=preprocess)

# Shuffled loader; the last incomplete batch of an epoch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    drop_last=True,
)

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preview the first 64 images of one training batch.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu()
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

In [None]:
def weights_init(m):
    """Re-initialise a layer in place; meant to be passed to ``Module.apply``.

    Layers whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02).  Layers whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias.  Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    """DCGAN generator built from transposed convolutions.

    A latent input with ``nz`` channels and 1x1 spatial size is upsampled
    4 -> 8 -> 16 -> 32 -> 64 pixels, ending in ``nc`` channels passed
    through ``Tanh``.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage is ConvTranspose2d -> BatchNorm2d -> ReLU.
        hidden_stages = [
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        ]
        layers = []
        for in_ch, out_ch, stride, pad in hidden_stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Output stage: no batch norm, Tanh activation.
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        # Layers are registered in the same order as before, so the
        # state_dict keys ("main.0.weight", ...) are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

In [None]:
# Build the generator on the selected device and re-initialise its layers
# with weights_init (Module.apply returns the module itself).
netG = Generator().to(device).apply(weights_init)
print(netG)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator built from strided convolutions.

    Takes images with ``nc`` channels, halves the spatial size in each of
    four stride-2 convolutions and ends with a single-channel 4x4
    convolution followed by ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()
        # Input stage: convolution without batch norm.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three stages that each double the channel count:
        # ndf -> 2*ndf -> 4*ndf -> 8*ndf.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output stage: one score per image, squashed by Sigmoid.
        layers += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        # Layers are registered in the same order as before, so the
        # state_dict keys ("main.0.weight", ...) are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

In [None]:
# Build the discriminator on the selected device and re-initialise its layers
# with weights_init (Module.apply returns the module itself).
netD = Discriminator().to(device).apply(weights_init)
print(netD)

In [None]:
def plot_losses_and_samples(G_losses, D_losses, dataloader, img_list, ssim_list, fid_list, iters, plot_every):
    """Plot the training curves and show real images next to the latest fakes.

    Args:
        G_losses: Generator loss values, one per training iteration.
        D_losses: Discriminator loss values, one per training iteration.
        dataloader: Loader from which one batch of real images is drawn.
        img_list: Image grids of generated samples; only the last one is shown.
        ssim_list: SSIM values, one per metric snapshot.
        fid_list: FID values, one per metric snapshot.
        iters: Index of the current training iteration.
        plot_every: Number of iterations between two regular snapshots.
    """

    def snapshot_iterations(values):
        # Regular snapshots are taken every ``plot_every`` iterations starting
        # at 0; an extra one may be taken at the very last iteration, which is
        # generally not a multiple of ``plot_every``.  Deriving the positions
        # from the number of recorded values (clamped to ``iters``) keeps x
        # and y the same length, so plotting cannot fail with a shape
        # mismatch, and puts that final snapshot at its real position.
        return np.minimum(np.arange(len(values)) * plot_every, iters)

    plt.figure(figsize=(40, 10))

    # Loss curves, subsampled so that roughly 200 points are drawn at most.
    plt.subplot(1, 3, 1)
    plt.title("Generator and Discriminator Loss During Training")
    st = len(G_losses) // 200 + 1
    plt.plot(np.arange(0, len(G_losses), st), G_losses[::st], label="G")
    plt.plot(np.arange(0, len(D_losses), st), D_losses[::st], label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()

    # One panel per quality metric.
    for position, (metric_name, values) in enumerate((("SSIM", ssim_list), ("FID", fid_list)), start=2):
        plt.subplot(1, 3, position)
        plt.title(metric_name)
        plt.plot(snapshot_iterations(values), values)
        plt.xlabel("iterations")
        plt.ylabel(metric_name)
    plt.show()

    real_batch = next(iter(dataloader))

    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    # The grid is only displayed, so it is built from the batch as loaded
    # instead of being moved to the training device and back.
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=5, normalize=True), (1, 2, 0)))

    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images (iter " + str(iters) + ")")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()

In [None]:
# Binary cross-entropy applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Fixed latent batch, reused for every snapshot so the sample grids are
# generated from the same inputs.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Target values for real and for generated images.
real_label, fake_label = 1., 0.

# Both networks are optimised by Adam with identical settings.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netD, netG)
)

In [None]:
from piq import FID
from piq import ssim, SSIMLoss

def collate_fn(data):
    """Collate dataset samples into a ``{'images': batch}`` dict.

    The loaders built with this function are handed to
    ``FID.compute_feats``; the dict layout is presumably the one piq
    expects from its loaders (see the piq documentation).

    Args:
        data: List of samples from the wrapped dataset.  Each sample is a
            sequence whose first element is an image tensor.

    Returns:
        Dict with the single key ``'images'`` holding all images of the
        batch stacked along a new leading dimension.  Every sample is kept,
        so the function also works for batch sizes larger than 1.
    """
    return {'images': torch.stack([sample[0] for sample in data])}

def calc_ssim_and_fid(dataset, n_pics=1000):
    """Compare ``n_pics`` generated images with ``n_pics`` random real ones.

    Args:
        dataset: Indexable dataset whose items have the image tensor first.
            Images are assumed to be normalised to [-1, 1], like the output
            of the generator's Tanh - confirm for other datasets.
        n_pics: Number of real and of generated images to compare.

    Returns:
        Tuple ``(ssim_index, fid)`` as returned by piq's ``ssim`` and ``FID``.
    """
    # Real images: indices drawn one at a time, with replacement.
    real_pics = torch.stack(
        [dataset[torch.randint(len(dataset), (1,))][0] for _ in range(n_pics)]
    )

    # Fake images: a single generator pass over fresh noise, no gradients.
    with torch.no_grad():
        latent = torch.randn(n_pics, nz, 1, 1, device=device)
        fake_pics = netG(latent).detach().cpu()

    # Shift and scale both batches in place: x -> (x + 1) / 2.
    fake_pics.add_(1).div_(2)
    real_pics.add_(1).div_(2)

    def as_loader(pics):
        # Wrap a tensor of images into a one-image-per-batch loader for FID.
        return torch.utils.data.DataLoader(
            dataset=torch.utils.data.TensorDataset(pics),
            batch_size=1,
            collate_fn=collate_fn,
        )

    fid_metric = FID()
    fake_feats = fid_metric.compute_feats(as_loader(fake_pics))
    real_feats = fid_metric.compute_feats(as_loader(real_pics))
    fid = fid_metric(fake_feats, real_feats)

    ssim_index = ssim(real_pics, fake_pics, data_range=1.)
    return ssim_index, fid

In [None]:
from tqdm import tqdm

# Training history. All of it is reset here, so running this cell starts a
# fresh record (the networks themselves are not re-created in this cell).
img_list = []    # grids of samples generated from fixed_noise, one per snapshot
G_losses = []    # generator loss, one value per iteration
D_losses = []    # discriminator loss, one value per iteration
ssim_list = []   # SSIM, one value per snapshot
fid_list = []    # FID, one value per snapshot
plot_every = 1000  # iterations between two snapshots
iters = 0          # global iteration counter across epochs

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, part 1: real batch with "real" targets ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        # D_x, D_G_z1 and D_G_z2 are only used by the commented-out log line below.
        D_x = output.mean().item()

        # --- Discriminator update, part 2: generated batch with "fake" targets ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # The same label tensor is reused and overwritten in place.
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients of both parts have accumulated; errD is kept for logging only.
        errD = errD_real + errD_fake
        optimizerD.step()


        # --- Generator update: generated batch scored against "real" targets ---
        netG.zero_grad()
        label.fill_(real_label)
        # Second pass through the discriminator, this time after its update
        # and without detach so that gradients flow back into netG.
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # if i % 50 == 0:
        #     print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
        #           % (epoch, num_epochs, i, len(dataloader),
        #              errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Snapshot every plot_every iterations and at the very last iteration:
        # sample from fixed_noise, compute the metrics, plot, write a checkpoint.
        if (iters % plot_every == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                # Rebinds `fake`; the training batch is no longer needed at this point.
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            ssim_index, fid = calc_ssim_and_fid(dataset)
            ssim_list.append(ssim_index)
            fid_list.append(fid)
            plot_losses_and_samples(G_losses, D_losses, dataloader, img_list, ssim_list, fid_list, iters, plot_every)
            # The checkpoint file name carries the global iteration counter.
            torch.save({'epoch': epoch, 'netD_state_dict': netD.state_dict(), 'netG_state_dict': netG.state_dict(), 'optimizerD_state_dict': optimizerD.state_dict(), 'optimizerG_state_dict': optimizerG.state_dict()}, "drive/MyDrive/bhw2images/checkpoint_" + str(iters)+".tar")
        iters += 1

In [None]:
# Fresh loader over the same dataset: shuffled, incomplete last batch dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    drop_last=True,
)

In [None]:
# Training loop that reuses the existing state: iters, the history lists,
# the networks and the optimizers are not reset in this cell, so running it
# continues from whatever they currently hold.
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update, part 1: real batch with "real" targets ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        # D_x, D_G_z1 and D_G_z2 are only used by the commented-out log line below.
        D_x = output.mean().item()

        # --- Discriminator update, part 2: generated batch with "fake" targets ---
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # The same label tensor is reused and overwritten in place.
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients of both parts have accumulated; errD is kept for logging only.
        errD = errD_real + errD_fake
        optimizerD.step()


        # --- Generator update: generated batch scored against "real" targets ---
        netG.zero_grad()
        label.fill_(real_label)
        # Second pass through the discriminator, this time after its update
        # and without detach so that gradients flow back into netG.
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # if i % 50 == 0:
        #     print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
        #           % (epoch, num_epochs, i, len(dataloader),
        #              errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Snapshot every plot_every iterations and at the very last iteration:
        # sample from fixed_noise, compute the metrics, plot, write a checkpoint.
        if (iters % plot_every == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                # Rebinds `fake`; the training batch is no longer needed at this point.
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            ssim_index, fid = calc_ssim_and_fid(dataset)
            ssim_list.append(ssim_index)
            fid_list.append(fid)
            plot_losses_and_samples(G_losses, D_losses, dataloader, img_list, ssim_list, fid_list, iters, plot_every)
            # The checkpoint file name carries the global iteration counter.
            torch.save({'epoch': epoch, 'netD_state_dict': netD.state_dict(), 'netG_state_dict': netG.state_dict(), 'optimizerD_state_dict': optimizerD.state_dict(), 'optimizerG_state_dict': optimizerG.state_dict()}, "drive/MyDrive/bhw2images/checkpoint_" + str(iters)+".tar")
        iters += 1

In [None]:
def calc_ssim_and_fid_on_real_data(dataset, n_pics=1000):
    """Compute SSIM and FID between two random samples of real images.

    Both batches are drawn from ``dataset``; no generator is involved.

    Args:
        dataset: Indexable dataset whose items have the image tensor first.
            Images are assumed to be normalised to [-1, 1] - confirm for
            other datasets.
        n_pics: Number of images in each of the two batches.

    Returns:
        Tuple ``(ssim_index, fid)`` as returned by piq's ``ssim`` and ``FID``.
    """

    def sample_batch():
        # Indices are drawn one at a time, with replacement.
        picked = [dataset[torch.randint(len(dataset), (1,))][0] for _ in range(n_pics)]
        return torch.stack(picked)

    def as_loader(pics):
        # Wrap a tensor of images into a one-image-per-batch loader for FID.
        return torch.utils.data.DataLoader(
            dataset=torch.utils.data.TensorDataset(pics),
            batch_size=1,
            collate_fn=collate_fn,
        )

    real_pics = sample_batch()
    # Named "fake" only to mirror calc_ssim_and_fid; these are real images too.
    fake_pics = sample_batch()

    # Shift and scale both batches in place: x -> (x + 1) / 2.
    fake_pics.add_(1).div_(2)
    real_pics.add_(1).div_(2)

    fid_metric = FID()
    fake_feats = fid_metric.compute_feats(as_loader(fake_pics))
    real_feats = fid_metric.compute_feats(as_loader(real_pics))
    fid = fid_metric(fake_feats, real_feats)

    ssim_index = ssim(real_pics, fake_pics, data_range=1.)
    return ssim_index, fid

In [None]:
# Reference values: the same two metrics computed between two random samples
# of real images, with no generator involved.
ssim_index, fid = calc_ssim_and_fid_on_real_data(dataset)
print("SSIM:", ssim_index)
print("FID:", fid)

# Model inference



In [None]:
def inference_model(netG):
    """Show 64 real images next to 64 images sampled from ``netG``.

    Args:
        netG: Generator called on a fresh batch of random noise.
    """
    # Left panel: the first 64 images of one batch from the module-level loader.
    real_batch = next(iter(dataloader))
    real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu()

    # Right panel: one generator pass over fresh noise, without gradients.
    with torch.no_grad():
        latent = torch.randn(64, nz, 1, 1, device=device)
        generated = netG(latent).detach().cpu()
    fake_grid = vutils.make_grid(generated, padding=2, normalize=True)

    plt.figure(figsize=(15, 15))
    panels = (("Real Images", real_grid), ("Fake Images", fake_grid))
    for position, (title, grid) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.axis("off")
        plt.title(title)
        plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.show()

In [None]:
path = "/content/drive/MyDrive/bhw2images/checkpoint_222000.tar"
checkpoint = torch.load(path)
netG.load_state_dict(checkpoint["netG_state_dict"])

In [None]:
inference_model(netG)

In [None]:
ssim_index, fid = calc_ssim_and_fid(dataset)
print("SSIM:", ssim_index)
print("FID:", fid)

In [None]:
path = "/content/drive/MyDrive/bhw2images/checkpoint_107000.tar"
checkpoint = torch.load(path)
netG.load_state_dict(checkpoint["netG_state_dict"])

In [None]:
inference_model(netG)

In [None]:
ssim_index, fid = calc_ssim_and_fid(dataset)
print("SSIM:", ssim_index)
print("FID:", fid)