In [None]:
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torchvision import models, utils

# custom dataloader
from src.dataloader import DIV2KDataModule
# Generator network
from src.generator import Generator
# Discriminator network
from src.discriminator import Discriminator
# module for VGG-based perceptual loss
from src.vgg_wrapper import VGGLoss
# training functions
from src.training_functions import pretrain_generator, train, save_models

# Fix the random seeds up front (before the data module is set up and the
# networks are built) so that Restart-&-Run-All reproduces the same weight
# initialisation and any torch/numpy/random-driven data shuffling or
# augmentation. Note: some cuDNN GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Set up the DIV2K data module.
# NOTE(review): despite its name, `dataloader` holds a DIV2KDataModule, not a
# torch DataLoader; the actual loaders are requested from it later in this
# notebook via .train_dataloader(...) / .val_dataloader(...).
dataloader = DIV2KDataModule()
# Called once, before any loader is requested. Presumably this builds/indexes
# the train and validation datasets; TODO confirm in src/dataloader.py.
dataloader.setup()

In [None]:
# --- Hyperparameters ---
SCALE_FACTOR = 4  # upscaling factor: 64x64 low-res -> 256x256 high-res
LR = 1e-4  # learning rate for both generator pretraining and adversarial training
NUM_EPOCHS = 7  # epochs of adversarial (generator + discriminator) training
GEN_PRETRAIN_NUM_EPOCHS = 3  # epochs of generator-only pretraining

# --- Batch sizes ---
# 64 is a good pretraining batch size on a Colab T4 GPU; weaker GPUs may need
# much smaller values (e.g. 16 works on a GeForce GTX 1050).
# Adversarial training should use roughly half the pretraining batch size
# (presumably because the discriminator must also fit in GPU memory), so it is
# derived here rather than hard-coded. max(1, ...) keeps it valid if
# PRETRAIN_BATCH_SIZE is lowered to 1.
PRETRAIN_BATCH_SIZE = 64
TRAIN_BATCH_SIZE = max(1, PRETRAIN_BATCH_SIZE // 2)
VAL_BATCH_SIZE = TRAIN_BATCH_SIZE

# Build fresh networks on every run of this cell, so re-running it restarts
# training instead of continuing from previously trained weights.
G = Generator(scale_factor=SCALE_FACTOR).to(device)
D = Discriminator().to(device)

# Stage 1: pretrain the generator on its own (no discriminator involved).
train_loader_pre = dataloader.train_dataloader(batch_size=PRETRAIN_BATCH_SIZE)
G = pretrain_generator(G, train_loader_pre, pretrain_epochs=GEN_PRETRAIN_NUM_EPOCHS, lr=LR)

# Stage 2: adversarial training of generator and discriminator, evaluated on
# the validation split.
train_loader = dataloader.train_dataloader(batch_size=TRAIN_BATCH_SIZE)
val_loader = dataloader.val_dataloader(batch_size=VAL_BATCH_SIZE)

G, D = train(G, D, train_loader, val_loader=val_loader, num_epochs=NUM_EPOCHS, lr=LR)

In [None]:
# Persist the trained generator (G) and discriminator (D).
# NOTE(review): the destination path and file format are decided inside
# src.training_functions.save_models. Check there before re-running, since it
# may overwrite checkpoints from an earlier run.
save_models(G, D)