# BertonGan CelebA training file

In [None]:
# Environment setup. When running inside Google Colab, fetch the project
# repository into ./berton_gan and put that directory on the import path so the
# `src` and `experiments` packages used by later cells can be imported.
import os
import subprocess
import sys

IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
	_repo_dir = os.path.abspath('berton_gan')
	# Clone directly into `berton_gan` (the original cloned to `berton-gan` and
	# then renamed it). Skip the clone when a previous run of this cell already
	# fetched the repo: re-running `git clone` + `mv` would otherwise nest a
	# second copy at berton_gan/berton-gan.
	if not os.path.isdir(_repo_dir):
		# check=True makes a failed clone raise here instead of surfacing later
		# as a confusing ImportError on `src`.
		subprocess.run(
			['git', 'clone', 'https://github.com/Herb-Wright/berton-gan/', _repo_dir],
			check=True,
		)
	# Avoid accumulating duplicate sys.path entries when the cell is re-run.
	if _repo_dir not in sys.path:
		sys.path.append(_repo_dir)

import torch

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from src import download_celeba

# Download the CelebA dataset via the project helper. The returned object is
# kept as `train_data`, although no other cell in this notebook references it
# (the dataloader cell builds its own CelebALoader).
train_data = download_celeba()

In [None]:
# Name of this run. It is passed to load_last_model, used as the prefix of the
# '<name>/<epoch>' checkpoint name given to save_checkpoint, and passed to the
# Colab download at the end of the notebook.
EXPERIMENT_NAME = 'celeba_experiment_1'

In [None]:
from torchvision.transforms import ColorJitter, Compose, RandomRotation

from src import CelebALoader

# Batch-size parameters handed positionally to CelebALoader. Their exact roles
# are defined by that loader — TODO confirm what n and N each control.
n = 3
N = 32

# Light augmentation: small brightness/contrast/hue jitter, then a random
# rotation with degrees=4 (torchvision samples the angle from [-4, 4]).
transform = Compose(
	[
		ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
		RandomRotation(4),
	]
)

# Loader consumed by the training cell.
dataloader = CelebALoader(n, N, transform=transform)

In [None]:
from src import train_all_at_once
from experiments.utils import load_last_model, save_checkpoint

# training hyperparameters
EPOCHS = 50
LR = 1e-2


def _save_epoch_checkpoint(gan, md, epoch):
	"""Checkpoint callback for train_all_at_once.

	Forwards the model and `md` to save_checkpoint under the name
	'<EXPERIMENT_NAME>/<epoch>' and returns save_checkpoint's result.
	"""
	return save_checkpoint(gan, md, f'{EXPERIMENT_NAME}/{epoch}')


# Resume from the most recent saved model for this experiment. The epoch
# returned alongside it becomes `epochs_start` for training.
berton_gan, epoch = load_last_model('celeba', EXPERIMENT_NAME, verbose=True)

# NOTE(review): DEVICE from the setup cell is never passed anywhere in this
# notebook; confirm that load_last_model / train_all_at_once handle device
# placement themselves.
train_all_at_once(
	berton_gan,
	dataloader,
	EPOCHS,
	optimizer_options=dict(lr=LR),
	epochs_start=epoch,
	save_func=_save_epoch_checkpoint,
	verbose=True,
)

In [None]:
# TODO: add an evaluation of model quality here (e.g. inspect generated samples).
# Switch the model to evaluation mode (assumes berton_gan is a torch nn.Module,
# whose eval() returns the module itself — confirm).
berton_gan.eval()

In [None]:
# if in colab, download the experiment
if IN_COLAB:
	from experiments.utils import colab_download_experiment
	colab_download_experiment(EXPERIMENT_NAME)