In [1]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
from dataloader import *
from _generator import *
from discriminator import *
from losses import * 
import utilities
import random
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# Google Drive folders (available once Drive is mounted under /content/drive).
# BASE_SAVE_PATH is where this notebook reads and writes checkpoints, loss
# pickles and sample images; BASE_PRETRAIN_PATH is not referenced by the
# cells visible in this notebook.
BASE_SAVE_PATH = '/content/drive/MyDrive/adjustParams/'
BASE_PRETRAIN_PATH = '/content/drive/MyDrive/trained_models/'

In [2]:
# Mount Google Drive at /content/drive so that the checkpoint and dataset
# paths used throughout this notebook (all under /content/drive/MyDrive/) resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Initial setup: networks, losses, optimizers and data loaders


In [3]:
# networks

# Generator: build on the selected device, restore the checkpoint file tagged
# "e5" from the save folder, then switch to training mode.
generator = Generator().to(device)
generator.load_model(f"{BASE_SAVE_PATH}generator_checkpoint_e5.pth")
generator = generator.train()

# Discriminator: same three steps, using its own "e5" checkpoint file.
discriminator = Discriminator().to(device)
discriminator.load_model(f"{BASE_SAVE_PATH}discriminator_checkpoint_e5.pth")
discriminator = discriminator.train()

# Network handed to the content and grayscale-style losses below.
VGG = getVGGConv4_4().to(device)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




In [4]:
# losses

# The two losses that are constructed around the VGG network
# (built in the same order as before: content first, then grayscale style).
content_loss, grayscale_loss = (
  ContentLoss(VGG).to(device),
  GrayscaleStyleLoss(VGG).to(device),
)

# Losses that do not receive the VGG network.
color_recon_loss = ColorReconLoss().to(device)
# Mean-squared error between discriminator predictions and target labels.
adversarial_loss = nn.MSELoss().to(device)


In [5]:
# optimizers

# maybe come back and add weight decay

# Adam over the generator at 1e-4; the training loop in this notebook does not
# reference this optimizer (it steps gen_optim instead).
pre_train_optim = optim.Adam(generator.parameters(), lr=1e-4)

# Adversarial-phase optimizers: the discriminator's learning rate (1.6e-4)
# is twice the generator's (8e-5).
gen_optim = optim.Adam(generator.parameters(), lr=8e-5)
dis_optim = optim.Adam(discriminator.parameters(), lr=1.6e-4)

In [6]:
# Dataset folders on the mounted Drive.
PHOTOS_PATH = '/content/drive/MyDrive/dataset/train_photo/'
ANIME_PATH = '/content/drive/MyDrive/dataset/Shinkai/style/'
SMOOTH_PATH = '/content/drive/MyDrive/dataset/Shinkai/smooth/'

# photo_dataloader : photo batches iterated by the generator training loop.
# anime_dataloader : style batches (requested with grayscale=True) consumed by
#                    the grayscale style loss.
# dis_dataloader   : (batch, labels) pairs iterated by the discriminator loop.
photo_dataloader = getPhotoDataloader(PHOTOS_PATH)
anime_dataloader = getAnimeDataloader(ANIME_PATH, grayscale=True)
dis_dataloader = getPhotoAndAnimeDataloader(ANIME_PATH, SMOOTH_PATH, PHOTOS_PATH)

## Train the Network here:

In [None]:
import pickle
from google.colab import files

# ---- Training configuration -------------------------------------------------
EPOCHS = 10          # exclusive upper bound of the epoch counter
START_EPOCH = 6      # first epoch index to run (the networks were restored from the e5 checkpoints)
G_TO_D_RATIO = 5     # full passes over photo_dataloader per epoch before the discriminator trains
LAMBDA_ADV = 100.    # weight of the adversarial (MSE) term
LAMBDA_CON = 0.01    # weight of the content term
LAMBDA_GRA = 10.     # weight of the grayscale style term
LAMBDA_COL = 5.      # weight of the colour reconstruction term
DISCRIMINATOR_FIRST = False  # when True, the first epoch skips its generator passes (one-shot flag)
RANDOM_SKIP = 0.3    # probability with which a discriminator batch is skipped
# Iteration budget per epoch, from the original notes:
#   discriminator: ~6,609 batches = 26,438 photos / BATCH_SIZE(4)
#   generator:      8,320 batches =  6,656 images * G_TO_D_RATIO / BATCH_SIZE
# NOTE(review): the loop below SKIPS a batch with probability RANDOM_SKIP, so about
# 70% of the discriminator batches (~4,626) are trained on, i.e. roughly 1 discriminator
# step per 1.8 generator steps. The original note ("discriminator * RANDOM_SKIP /
# generator is roughly 1 to 4") treats RANDOM_SKIP as the fraction that is KEPT.
# Behaviour is left as it was; confirm which ratio is intended.


def _save_generator_losses():
  """Write the four generator loss histories to their pickle files in BASE_SAVE_PATH."""
  utilities.saveListToPickle(BASE_SAVE_PATH + 'content_loss.pkl', content_loss_list)
  utilities.saveListToPickle(BASE_SAVE_PATH + 'gray_loss.pkl', gray_loss_list)
  utilities.saveListToPickle(BASE_SAVE_PATH + 'color_loss.pkl', color_loss_list)
  utilities.saveListToPickle(BASE_SAVE_PATH + 'adv_loss.pkl', adv_loss_list)


def _save_discriminator_losses():
  """Write the discriminator loss history to its pickle file in BASE_SAVE_PATH."""
  utilities.saveListToPickle(f'{BASE_SAVE_PATH}dis_loss.pkl', discriminator_loss_list)


anime_iter = iter(anime_dataloader)

# "Real" target for the generator's adversarial loss. This initial shape matches a
# batch of 4; it is rebuilt inside the loop whenever the discriminator output has
# a different shape, so the loop no longer depends on the batch size being 4.
fake_true_labels = torch.ones((4,1,64,64)).to(device)

# Resume the loss histories written by earlier runs.
content_loss_list = utilities.readListFromPickle(BASE_SAVE_PATH + 'content_loss.pkl')
gray_loss_list = utilities.readListFromPickle(BASE_SAVE_PATH + 'gray_loss.pkl')
color_loss_list = utilities.readListFromPickle(BASE_SAVE_PATH + 'color_loss.pkl')
adv_loss_list = utilities.readListFromPickle(BASE_SAVE_PATH + 'adv_loss.pkl')
discriminator_loss_list = utilities.readListFromPickle(f'{BASE_SAVE_PATH}dis_loss.pkl')


for e in range(START_EPOCH, EPOCHS):
  print("training generator")
  for r in range(G_TO_D_RATIO):
    if DISCRIMINATOR_FIRST:
      # One-shot: skip the generator passes of this epoch only.
      DISCRIMINATOR_FIRST = False
      break

    # Generated batch and the photos that produced it, kept as a matching pair
    # for the sample images written at the end of this pass.
    sample_images = None
    sample_photos = None

    for p_batch_idx, photo_batch in enumerate(photo_dataloader):
      # train the generator

      # Draw the next style batch; rewind the loader when it is exhausted so a
      # loader without an odd-sized final batch cannot raise StopIteration here.
      try:
        anime_batch = next(anime_iter)
      except StopIteration:
        anime_iter = iter(anime_dataloader)
        anime_batch = next(anime_iter)
      anime_batch = anime_batch.to(device)

      if anime_batch.shape != photo_batch.shape:
        # Shape mismatch (e.g. a short final batch): rewind the style loader and
        # skip this photo batch, as before.
        anime_iter = iter(anime_dataloader)
        continue

      gen_optim.zero_grad()

      # pass through generator network
      photo_batch = photo_batch.to(device)
      gen_images = generator(photo_batch)

      # pass through discriminator network
      pred_labels = discriminator(gen_images)

      # Keep the all-ones target the same shape/device/dtype as the prediction.
      if fake_true_labels.shape != pred_labels.shape:
        fake_true_labels = torch.ones_like(pred_labels)

      # calculate losses
      con_loss = content_loss(gen_images, photo_batch)
      gra_loss = grayscale_loss(gen_images, anime_batch)
      col_loss = color_recon_loss(gen_images, photo_batch)
      adv_loss = adversarial_loss(pred_labels, fake_true_labels)

      loss = LAMBDA_ADV * adv_loss + LAMBDA_CON * con_loss + \
              LAMBDA_GRA * gra_loss + LAMBDA_COL * col_loss

      # backpropagate
      loss.backward()
      gen_optim.step()

      # save in list
      content_loss_list.append(con_loss.item())
      gray_loss_list.append(gra_loss.item())
      color_loss_list.append(col_loss.item())
      adv_loss_list.append(adv_loss.item())

      # Remember the pair that belongs together for the end-of-pass samples.
      sample_images = gen_images.detach()
      sample_photos = photo_batch

      if p_batch_idx % 500 == 499:
        # periodic progress log + loss history checkpoint
        print("generator epoch:", e, "r:", r, "p_batch_idx", p_batch_idx, "loss:", loss.item())
        _save_generator_losses()

    # Persist the tail of the pass too (batches after the last multiple of 500).
    _save_generator_losses()

    if sample_images is None:
      print("no generator step ran in this pass; skipping sample images")
    else:
      print("saving sample generated images...")
      unique_identifier = f"e{e}r{r}idx"
      utilities.save_torch_as_images(BASE_SAVE_PATH, sample_images, unique_identifier=f'{unique_identifier}', is_standardized_image=True)
      utilities.save_torch_as_images(BASE_SAVE_PATH, sample_images, unique_identifier=f'_{unique_identifier}', is_standardized_image=True, adjust_brightness=True, imgs=sample_photos)
      print("done!")
    generator.save_model(f"{BASE_SAVE_PATH}generator_checkpoint_e{e}_r{r}.pth")

  generator.save_model(f"{BASE_SAVE_PATH}generator_checkpoint_e{e}.pth")

  print("training discriminator")
  for batch_idx, (photo_batch, labels) in enumerate(dis_dataloader):
    # trains the discriminator

    if random.random() < RANDOM_SKIP:
      # discriminator is trained too often, so randomly skip RANDOM_SKIP of the
      # batches to prevent the discriminator from converging too quickly
      continue

    dis_optim.zero_grad()

    # send data to cuda if available
    photo_batch = photo_batch.to(device)
    labels = labels.to(device)

    # pass through discriminator and get loss
    # NOTE(review): the saved run output contains what looks like a warning raised
    # from F.mse_loss; confirm pred_labels and labels have identical shapes for
    # every batch, including a short final one.
    pred_labels = discriminator(photo_batch)
    loss = LAMBDA_ADV * adversarial_loss(pred_labels, labels)

    # backpropagate
    loss.backward()
    dis_optim.step()

    discriminator_loss_list.append(loss.item())

    if batch_idx % 500 == 499:
      # periodic progress log + loss history checkpoint
      print("discriminator epoch:", e, "batch_idx", batch_idx, "loss:", loss.item())
      _save_discriminator_losses()

  # Persist the tail of the epoch as well, then checkpoint the discriminator.
  _save_discriminator_losses()
  discriminator.save_model(f"{BASE_SAVE_PATH}discriminator_checkpoint_e{e}.pth")

generator.save_model(f"{BASE_SAVE_PATH}generator_final.pth")
discriminator.save_model(f"{BASE_SAVE_PATH}discriminator_final.pth")

training generator




generator epoch: 6 r: 0 p_batch_idx 499 loss: 1.3788108825683594
generator epoch: 6 r: 0 p_batch_idx 999 loss: 1.024960994720459
generator epoch: 6 r: 0 p_batch_idx 1499 loss: 0.7302474975585938
saving sample generated images...
done!
generator epoch: 6 r: 1 p_batch_idx 499 loss: 1.0336525440216064
generator epoch: 6 r: 1 p_batch_idx 999 loss: 0.9195556640625
generator epoch: 6 r: 1 p_batch_idx 1499 loss: 0.7550180554389954
saving sample generated images...
done!
generator epoch: 6 r: 2 p_batch_idx 499 loss: 0.6827806234359741
generator epoch: 6 r: 2 p_batch_idx 999 loss: 0.744933009147644
generator epoch: 6 r: 2 p_batch_idx 1499 loss: 0.6778417825698853
saving sample generated images...
done!
generator epoch: 6 r: 3 p_batch_idx 499 loss: 0.9143848419189453
generator epoch: 6 r: 3 p_batch_idx 999 loss: 0.7374331951141357
generator epoch: 6 r: 3 p_batch_idx 1499 loss: 0.5849680304527283
saving sample generated images...
done!
generator epoch: 6 r: 4 p_batch_idx 499 loss: 0.8655440211296

  return F.mse_loss(input, target, reduction=self.reduction)


discriminator epoch: 6 batch_idx 499 loss: 0.3091875910758972
discriminator epoch: 6 batch_idx 999 loss: 0.2818688750267029
discriminator epoch: 6 batch_idx 1499 loss: 0.21203672885894775
discriminator epoch: 6 batch_idx 1999 loss: 0.06371842324733734
discriminator epoch: 6 batch_idx 2499 loss: 0.05788814648985863
discriminator epoch: 6 batch_idx 3499 loss: 0.10336916893720627
discriminator epoch: 6 batch_idx 3999 loss: 0.07945617288351059
discriminator epoch: 6 batch_idx 4999 loss: 0.06535618007183075
discriminator epoch: 6 batch_idx 5499 loss: 0.08059097826480865
discriminator epoch: 6 batch_idx 6499 loss: 0.1356891542673111
training generator
generator epoch: 7 r: 0 p_batch_idx 499 loss: 0.8095471858978271
generator epoch: 7 r: 0 p_batch_idx 999 loss: 0.6764535903930664
generator epoch: 7 r: 0 p_batch_idx 1499 loss: 0.8352744579315186
saving sample generated images...
done!
generator epoch: 7 r: 1 p_batch_idx 499 loss: 0.750571608543396
generator epoch: 7 r: 1 p_batch_idx 999 loss: 

## Quick sanity checks and learning curves

In [None]:
# Score the most recent generator output with the discriminator. Depends on
# `gen_images` being left in the kernel by the training cell, so this only works
# after that cell has run (or been interrupted) in the same session.
pred_labels = discriminator(gen_images)

In [None]:
# Dump the raw discriminator predictions currently bound to `pred_labels`.
print(pred_labels)

In [None]:
# Score whatever `photo_batch` currently holds in the kernel; both training loops
# assign this name, so its content depends on which loop ran last.
discriminator(photo_batch)

In [None]:
# Show the labels of the last batch drawn in the discriminator training loop.
print(labels)

In [None]:
import matplotlib.pyplot as plt

# learning curves

# Value passed to utilities.listToAvg as `interval` for every curve.
INTERVAL = 25

# Figure title -> pickle file (inside BASE_SAVE_PATH) holding that loss history.
# Figures are produced in this order.
LOSS_CURVE_FILES = {
  "content": 'content_loss.pkl',
  "color": 'color_loss.pkl',
  "gray": 'gray_loss.pkl',
  "adv": 'adv_loss.pkl',
}

# One figure per loss. The averaged series get their own names so the raw
# `*_loss_list` histories used by the training cell are not overwritten.
averaged_loss_curves = {}
for curve_title, pickle_name in LOSS_CURVE_FILES.items():
  raw_history = utilities.readListFromPickle(BASE_SAVE_PATH + pickle_name)
  averaged_loss_curves[curve_title] = utilities.listToAvg(raw_history, interval=INTERVAL)

for curve_title, averaged_history in averaged_loss_curves.items():
  fig, ax = plt.subplots()
  ax.plot(averaged_history)
  ax.set_title(curve_title)
  ax.set_xlabel(f"point index after listToAvg (interval={INTERVAL})")
  ax.set_ylabel("loss")

# Render the figures without echoing the Line2D repr as cell output.
plt.show()