# Training 

In [2]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import torch.utils
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Resize, RandomCrop, ToTensor, Compose
from torchvision.utils import save_image
import torch
import torch.nn as nn
import torch.optim as optim
from util import get_hole, get_mask, crop
from generator import Generator
from discriminator import Discriminator
from tqdm import tqdm
import random

# Inclusive bounds passed to random.randint when sampling the hole size for
# generator training (W/H presumably width/height in pixels -- confirm in util).
MIN_HOLEW = 96
MAX_HOLEW = 128
MIN_HOLEH = 96
MAX_HOLEH = 128

# Epoch budgets for the three phases: generator only, discriminator only,
# and joint ("mutual") training.
EPOCH_G = 20
EPOCH_D = 15
EPOCH_M = 100

batch_size = 64

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Echo the selected device so the cell output shows whether CUDA was picked up.
device

device(type='cuda')

In [4]:

# Shared preprocessing: resize, take a random 256x256 crop, convert to a tensor.
tsfm = Compose([Resize(256), RandomCrop((256, 256)), ToTensor()])

# Both splits are read from the same local copy with the same transform.
# Nothing is downloaded here, so the dataset must already exist under ../data.
celeba_options = dict(root="../data", download=False, transform=tsfm)

training_data = datasets.CelebA(split='train', **celeba_options)
test_data = datasets.CelebA(split='test', **celeba_options)

RuntimeError: Dataset not found or corrupted. You can use download=True to download it

In [None]:

# Mean pixel value of the training set, one entry per channel, hard-coded from
# an earlier computation and reshaped to (1, 3, 1, 1).
mpv = torch.tensor(
    [0.50925811, 0.42336759, 0.37791181], device=device
).reshape(1, 3, 1, 1)

# Sketch of how the values above were obtained (kept for reference, not run):
# mpv = np.zeros((3,))
# for x in training_data:
#    r = x[0][0]
#    g = x[0][1]
#    b = x[0][2]
#   mpv += (torch.mean(r), torch.mean(g), torch.mean(b))
# mpv /= len(training_data)

def collate_fn(batch):
    """Stack the first element of every sample into a single batch tensor.

    Each sample is indexed with ``[0]``; whatever else the sample carries
    (e.g. the dataset's target) is dropped. The result gains a new leading
    batch dimension.
    """
    images = [sample[0] for sample in batch]
    return torch.stack(images, dim=0)

# Identical batching for both splits: shuffled, fixed batch size, images only
# (collate_fn drops everything but the first element of each sample).
loader_options = dict(batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

train_dataloader = DataLoader(training_data, **loader_options)
test_dataloader = DataLoader(test_data, **loader_options)

# Generator and its optimizer. When GPATH is set, both are restored from the
# checkpoint so training resumes instead of starting from scratch.
generator = Generator().to(device)
optimizer = optim.Adadelta(generator.parameters(), lr=0.1)

GPATH = "../model/gen-mutual-20-loss13.787893346045166.pth"
if GPATH is not None:
    # map_location remaps tensors that were saved on a different device
    # (e.g. a CUDA checkpoint opened on a CPU-only machine) onto `device`.
    checkpoint = torch.load(GPATH, map_location=device)
    generator.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Discriminator and its optimizer, restored the same way from DPATH.
discriminator = Discriminator().to(device)
OptimizerD = optim.Adadelta(discriminator.parameters(), lr=0.01)
DPATH = "../model/dis-14-loss-2.8606153897929893e-06.pth"
if DPATH is not None:
    checkpoint = torch.load(DPATH, map_location=device)
    discriminator.load_state_dict(checkpoint['model_state_dict'])
    OptimizerD.load_state_dict(checkpoint['optimizer_state_dict'])


The code above imports the necessary modules, defines the constants, and loads the data and models using PyTorch. Note that the mean pixel value of the training set is precomputed to save time. Training uses the standard CelebA dataset with its predefined training and test splits. In total, during the training phase the generator is trained alone for 20 epochs, the discriminator is trained alone for 15 epochs, and the two are then trained jointly for 21 epochs.

In [None]:
# Phase 1: train the generator alone with an MSE loss restricted to the hole.
for epoch in range(EPOCH_G):
    generator.train()
    loop = tqdm(train_dataloader, desc="Generator")
    # The label does not change within an epoch, so set it once.
    loop.set_description(f"Epoch {epoch}/{EPOCH_G}")
    loss_tot = 0
    for x in loop:
        x = x.to(device)
        # One-channel mask with the same batch and spatial size as the images.
        shape = (x.shape[0], 1, x.shape[2], x.shape[3])
        hole = get_hole((random.randint(MIN_HOLEW, MAX_HOLEW),
                         random.randint(MIN_HOLEH, MAX_HOLEH)))
        mask = get_mask(shape, hole).to(device)
        # Replace the masked region with the mean pixel value and append the
        # mask as an extra input channel.
        net_in = x - x * mask + mpv * mask
        net_input = torch.cat((net_in, mask), dim=1)

        # Clear gradients *before* the backward pass so that anything left over
        # from an interrupted iteration or a previous cell cannot leak into
        # this update.
        optimizer.zero_grad()
        out = generator(net_input)
        loss = nn.functional.mse_loss(x * mask, out * mask)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loop.set_postfix(loss=batch_loss)
        # No manual loop.update(): iterating over the tqdm object already
        # advances the bar once per batch.
        loss_tot += batch_loss
    # NOTE(review): checkpoints are written under /content/model while the
    # checkpoint-loading cells read from ../model -- confirm both point at the
    # same directory in the environment this runs in.
    torch.save({'epoch': epoch,
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, "/content/model/regen-" + str(epoch) + "-loss-" + str(loss_tot / len(loop)) + ".pth")
    loop.close()

Epoch 0/20: 100%|██████████| 2544/2544 [18:15<00:00,  2.32it/s, loss=0.00398]
Epoch 1/20: 100%|██████████| 2544/2544 [18:17<00:00,  2.32it/s, loss=0.00342]
Epoch 2/20: 100%|██████████| 2544/2544 [18:16<00:00,  2.32it/s, loss=0.00616]
Epoch 3/20: 100%|██████████| 2544/2544 [18:19<00:00,  2.31it/s, loss=0.0033]
Epoch 4/20: 100%|██████████| 2544/2544 [18:18<00:00,  2.32it/s, loss=0.00453]
Epoch 5/20: 100%|██████████| 2544/2544 [18:19<00:00,  2.31it/s, loss=0.00363]
Epoch 6/20: 100%|██████████| 2544/2544 [18:17<00:00,  2.32it/s, loss=0.00373]
Epoch 7/20: 100%|██████████| 2544/2544 [18:21<00:00,  2.31it/s, loss=0.00344]
Epoch 8/20: 100%|██████████| 2544/2544 [18:21<00:00,  2.31it/s, loss=0.00303]
Epoch 9/20: 100%|██████████| 2544/2544 [18:21<00:00,  2.31it/s, loss=0.0041]
Epoch 10/20: 100%|██████████| 2544/2544 [18:21<00:00,  2.31it/s, loss=0.0025]
Epoch 11/20:   0%|          | 10/2544 [00:04<20:00,  2.11it/s, loss=0.00295]


KeyboardInterrupt: 

In [None]:
# Qualitative check of the generator-only checkpoint on one test batch.
GEN_TEST_PATH = "../model/regen-10-loss-0.003284191969820919.pth"
generator_test = Generator().to(device)
# map_location lets a checkpoint saved on another device load onto `device`.
checkpoint_test = torch.load(GEN_TEST_PATH, map_location=device)
generator_test.load_state_dict(checkpoint_test['model_state_dict'])
# Inference mode: any layers that behave differently during training
# (e.g. batch norm or dropout, if the Generator uses them) are switched over.
generator_test.eval()
with torch.no_grad():
    x = next(iter(test_dataloader)).to(device)
    shape = (x.shape[0], 1, x.shape[2], x.shape[3])
    hole = get_hole((random.randint(MIN_HOLEW, MAX_HOLEW),
                     random.randint(MIN_HOLEH, MAX_HOLEH)))
    mask = get_mask(shape, hole).to(device)
    # x is deliberately overwritten with the masked images: the saved grid
    # shows the masked inputs followed by the generator outputs.
    x = x - x * mask + mpv * mask
    net_input = torch.cat((x, mask), dim=1)
    out = generator_test(net_input)
    imgs = torch.cat((x.cpu(), out.cpu()), dim=0)
    save_image(imgs, "../result/test1.jpg", nrow=len(x))

We show the result of training the generator only. The output is quite blurry because only the MSE loss is used.
![result of generator](../result/test1.jpg)

In [None]:
# Phase 2: train the discriminator alone against the (frozen) generator.
BCEloss = nn.BCELoss()

for epoch in range(EPOCH_D):
    discriminator.train()
    loop = tqdm(train_dataloader, desc="Discriminator")
    # The label does not change within an epoch, so set it once.
    loop.set_description(f"Epoch {epoch}/{EPOCH_D}")
    loss_tot = 0
    for x in loop:
        shape = (x.shape[0], 1, x.shape[2], x.shape[3])
        x = x.to(device)
        batch_len = len(x)

        # Fake branch: complete a masked image with the generator and ask the
        # discriminator to label it 0.
        holeC = get_hole((128, 128))
        maskC = get_mask(shape, holeC).to(device)
        net_input = x - x * maskC + mpv * maskC
        inputC = torch.cat((net_input, maskC), dim=1)
        # The generator is not updated in this phase, so skip building its
        # autograd graph; the output was detached anyway.
        with torch.no_grad():
            outC = generator(inputC)
        global_inputC = outC.detach()
        local_inputC = crop(global_inputC, holeC)
        resultC = discriminator((local_inputC.to(device), global_inputC.to(device)))
        fake_labels = torch.zeros((batch_len, 1), dtype=torch.float, device=device)
        lossC = BCEloss(resultC, fake_labels)

        # Real branch: an independently drawn crop of the untouched image,
        # labelled 1.
        holeD = get_hole((128, 128))
        local_inputD = crop(x, holeD)
        resultD = discriminator((local_inputD.to(device), x))
        real_labels = torch.ones((batch_len, 1), dtype=torch.float, device=device)
        lossD = BCEloss(resultD, real_labels)

        loss_overall = (lossC + lossD) / 2
        # Clear gradients *before* the backward pass so nothing left over from
        # an interrupted iteration or another cell leaks into this update.
        OptimizerD.zero_grad()
        loss_overall.backward()
        loss_tot += loss_overall.item()
        OptimizerD.step()
        # 'loss' is the running total for the epoch, not the per-batch value.
        loop.set_postfix({'loss': loss_tot, 'lossC': lossC.item(), 'lossD': lossD.item()})
        # No manual loop.update(): iterating over the tqdm object already
        # advances the bar once per batch.
    # NOTE(review): checkpoints are written under /content/model while the
    # checkpoint-loading cells read from ../model -- confirm both point at the
    # same directory in the environment this runs in.
    torch.save({'epoch': epoch,
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': OptimizerD.state_dict(),
                'loss': loss_overall,
                }, "/content/model/dis-" + str(epoch) + "-loss-" + str(loss_tot / len(loop)) + ".pth")
    loop.close()


Epoch 0/15: 100%|██████████| 2544/2544 [13:35<00:00,  3.12it/s, loss=9.44, lossC=1.44e-5, lossD=9.31e-5]
Epoch 1/15: 100%|██████████| 2544/2544 [13:36<00:00,  3.11it/s, loss=0.101, lossC=1.25e-5, lossD=6.66e-5]
Epoch 2/15: 100%|██████████| 2544/2544 [13:37<00:00,  3.11it/s, loss=0.0538, lossC=3.53e-5, lossD=2.69e-5]
Epoch 3/15: 100%|██████████| 2544/2544 [13:38<00:00,  3.11it/s, loss=0.0361, lossC=5.59e-6, lossD=1.2e-5]
Epoch 4/15: 100%|██████████| 2544/2544 [13:37<00:00,  3.11it/s, loss=0.0276, lossC=7.95e-6, lossD=8.68e-6]
Epoch 5/15: 100%|██████████| 2544/2544 [13:38<00:00,  3.11it/s, loss=0.0213, lossC=2.51e-6, lossD=2e-6]
Epoch 6/15: 100%|██████████| 2544/2544 [13:39<00:00,  3.11it/s, loss=0.0174, lossC=1.69e-6, lossD=7.26e-6]
Epoch 7/15: 100%|██████████| 2544/2544 [13:37<00:00,  3.11it/s, loss=0.0152, lossC=3.55e-6, lossD=7.1e-6]
Epoch 8/15: 100%|██████████| 2544/2544 [13:39<00:00,  3.11it/s, loss=0.0134, lossC=1.09e-6, lossD=5.55e-6]
Epoch 9/15: 100%|██████████| 2544/2544 [13:37

In [None]:
# Phase 3: alternate discriminator and generator updates on every batch.
BCEloss = nn.BCELoss()
# Weight of the adversarial term; constant, so build it once.
alpha = torch.tensor(0.004, dtype=torch.float32).to(device)

for epoch in range(EPOCH_M):
    lossG_tot = 0
    generator.train()
    discriminator.train()
    loop = tqdm(train_dataloader, desc="Mutual training")
    # The label does not change within an epoch, so set it once.
    loop.set_description(f"Epoch {epoch}/{EPOCH_M}")
    for x in loop:
        x = x.to(device)
        shape = (x.shape[0], 1, x.shape[2], x.shape[3])
        batch_len = len(x)
        real_labels = torch.ones((batch_len, 1), dtype=torch.float, device=device)
        fake_labels = torch.zeros((batch_len, 1), dtype=torch.float, device=device)

        # ---- Discriminator step -------------------------------------------
        holeC = get_hole((128, 128))
        maskC = get_mask(shape, holeC).to(device)
        net_input = x - x * maskC + mpv * maskC
        inputC = torch.cat((net_input, maskC), dim=1)
        # Keep the graph on outC: it is reused for the generator step below.
        outC = generator(inputC)
        global_inputC = outC.detach()
        local_inputC = crop(global_inputC, holeC)
        resultC = discriminator((local_inputC.to(device), global_inputC.to(device)))
        lossC = BCEloss(resultC, fake_labels)

        holeD = get_hole((128, 128))
        local_inputD = crop(x, holeD)
        resultD = discriminator((local_inputD.to(device), x))
        lossD = BCEloss(resultD, real_labels)
        lossD_overall = (lossC + lossD) * alpha / 2

        # Zero the discriminator's gradients right before its own backward
        # pass. The generator step below backpropagates *through* the
        # discriminator and fills its .grad buffers; without this reset those
        # gradients would be added to the next discriminator update.
        OptimizerD.zero_grad()
        lossD_overall.backward()
        OptimizerD.step()

        # ---- Generator step -----------------------------------------------
        lossG = nn.functional.mse_loss(x * maskC, outC * maskC)
        lossG_tot += lossG.item()
        outputD1 = discriminator((crop(outC, holeC).to(device), outC.to(device)))
        lossG_overall = (lossG + alpha * BCEloss(outputD1, real_labels)) / 2
        optimizer.zero_grad()
        lossG_overall.backward()
        optimizer.step()

        # 'lossG' is the running MSE total for the epoch, not a per-batch value.
        loop.set_postfix({'lossG': lossG_tot, 'lossD': lossD_overall.item()})
        # No manual loop.update(): iterating over the tqdm object already
        # advances the bar once per batch.
    # NOTE(review): checkpoints are written under /content/model while the
    # checkpoint-loading cells read from ../model -- confirm both point at the
    # same directory in the environment this runs in.
    torch.save({'epoch': epoch,
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': OptimizerD.state_dict(),
                'loss': lossD_overall,
                }, "/content/model/dis-mutual-" + str(epoch) + ".pth")
    torch.save({'epoch': epoch,
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': lossG_overall,
                }, "/content/model/gen-mutual-" + str(epoch) + "-loss" + str(lossG_tot) + ".pth")
    loop.close()


Epoch 0/100: 100%|██████████| 2544/2544 [23:24<00:00,  1.81it/s, lossG=13.3, lossD=0.00387]
Epoch 1/100: 100%|██████████| 2544/2544 [23:19<00:00,  1.82it/s, lossG=12.8, lossD=0.00314]
Epoch 2/100:  67%|██████▋   | 1695/2544 [15:33<07:47,  1.82it/s, lossG=8.53, lossD=0.00331]


KeyboardInterrupt: 

# **For inference**

After training the two networks jointly, we can run inference. The test results are shown in the main Jupyter notebook, project.ipynb; please refer to that notebook for inference.