In [None]:
!pip install scikit-image
!pip install jupyter

In [None]:
import random
import sys

import numpy as np
import torch
import torch.utils.data as data

# NOTE(review): wildcard imports hide where names come from. CocoLab presumably
# comes from the dataset module and Pretrain/GanTrain from the trainer module;
# switch to explicit imports once confirmed.
from python.data.dataset import *
from python.models.generator import UNet
from python.models.discriminator import PatchGAN
from python.models.utils import init_weights
from python.train.trainer import *

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG the pipeline can touch: Python's `random`, NumPy and PyTorch.
# `torch.manual_seed` seeds the CPU default generator *and* all CUDA devices.
# Seeding only the CUDA generators, as before, left CPU-side randomness
# unseeded, including DataLoader shuffling and worker seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Optional: send stdout (training logs) to a file instead of the notebook.
# Set to a path such as "train.log" to enable. Line buffering keeps the file
# current while a long run is in progress.
LOG_FILE = None
if LOG_FILE is not None:
    sys.stdout = open(LOG_FILE, "a", buffering=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
dataset = "data/Coco"  # dataset root directory
version = "2017"       # COCO release passed to CocoLab

# Shared loader settings, so the train and test pipelines cannot silently drift
# apart. IMAGE_SIZE is forwarded as CocoLab's `size` argument (presumably the
# side length images are resized/cropped to; TODO confirm in CocoLab).
IMAGE_SIZE = 256
BATCH_SIZE = 4
NUM_WORKERS = 4

dataset_train = CocoLab(dataset, version=version, size=IMAGE_SIZE, train=True)
train_loader = data.DataLoader(
    dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

dataset_test = CocoLab(dataset, version=version, size=IMAGE_SIZE, train=False)
# NOTE(review): the evaluation loader is also shuffled. Left as-is because the
# trainers below receive (test_loader, train_loader) in that order. Confirm the
# trainer signatures before switching this to shuffle=False.
test_loader = data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

In [None]:
# Generator: UNet with 1 input channel and 2 output channels.
# Discriminator: PatchGAN over 3 input channels.
# NOTE(review): presumably L -> ab colorization, with the discriminator seeing
# the full Lab image (1 + 2 = 3 channels). Confirm in the model code.
generator = UNet(1, 2).to(device)
discriminator = PatchGAN(3).to(device)

# Re-initialise all weights with init_weights (per the original notes: a
# Gaussian centred at 0 with std=0.02). nn.Module.apply returns the module
# itself, so the trailing `;` stops Jupyter from dumping the model repr.
generator.apply(init_weights);
discriminator.apply(init_weights);

### Pretrain the generator

Train the generator on its own with an L1 loss before adversarial training.

In [None]:
# Generator-only pretraining (no discriminator is passed to Pretrain).
# NOTE(review): loaders are passed as (test_loader, train_loader). Confirm this
# matches Pretrain's parameter order; otherwise training would use the test split.
trainer = Pretrain(generator, test_loader, train_loader)

In [None]:
# Run pretraining.
# NOTE(review): the positional arguments are presumably the epoch count (2)
# and the loss name ("l1"), and generator_file names the saved generator.
# TODO confirm against Pretrain.train.
trainer.train(
    2,
    "l1",
    generator_file="test",
)

In [None]:
# Plot the pretraining run. "test" is the same name passed as generator_file
# above. NOTE(review): presumably make_plot loads results saved under that
# name; confirm.
trainer.make_plot("test")

### Adversarial (GAN) training

Continue training the same generator jointly with the PatchGAN discriminator.

In [None]:
# Adversarial training of the (in-memory, pretrained) generator against the
# discriminator. reg_R1=True presumably enables R1 gradient-penalty
# regularisation of the discriminator; confirm in GanTrain.
# NOTE(review): same (test_loader, train_loader) order as Pretrain; confirm.
trainer = GanTrain(generator, discriminator, test_loader, train_loader, reg_R1=True)

In [None]:
# NOTE(review): the trainer above is built with reg_R1=True, yet this run's
# results are saved under file_name="gan_no_r1". One of the two is likely
# stale. Either set reg_R1=False or rename the run before comparing experiments.
trainer.train(
    2,
    generator_file="generator",
    discriminator_file="discriminator",
    file_name="gan_no_r1",
)

In [None]:
# Plot the adversarial run.
# NOTE(review): this passes "test" (the generator_file name used during
# pretraining) rather than this run's file_name ("gan_no_r1"). Confirm which
# results make_plot loads; this may be re-plotting the pretraining curves.
trainer.make_plot("test")