## Train Cartoon GAN

Jupyter notebook version of the `src.networks.train.py` file.
This notebook contains all the cells needed to train the model from scratch.

#### Configure drive (only if needed)

In [None]:
# Google Colab only: the setup below is disabled by wrapping it in a string
# literal. Remove the enclosing triple quotes to mount Google Drive at
# /content/drive, change into PROJECT_DIRECTORY and list its contents.
"""
from google.colab import drive
import os
drive.mount("/content/drive")

PROJECT_DIRECTORY = "drive/MyDrive/DeepL"
os.chdir(PROJECT_DIRECTORY)

!ls
# !pip install -r requirements.txt
"""

In [1]:
import torch
import torch.optim as optim


# True when PyTorch reports that CUDA is available.
cuda = torch.cuda.is_available()
print(cuda)

if cuda:
  # Print the name of GPU 0, then run `nvidia-smi` through the IPython shell
  # escape (`!`); that line is notebook syntax, not plain Python.
  print(torch.cuda.get_device_name(0))
  !nvidia-smi

# Device string: "cuda" when CUDA is available, "cpu" otherwise.
device = "cuda" if cuda else "cpu"

False


#### Main parameters

Main (pre-)training parameters are defined here.
When using Google Colaboratory, it is easier to change some parameters in this `.ipynb` file than in `config.py`.

In [2]:
from src.models.parameters import CartoonGanParameters

# Batch size given to both parameter sets below; the DataLoaders created in
# the "Prepare data" cell use the same constant.
BATCH_SIZE = 16

# Settings handed to `cartoon_gan.pretrain(...)`.
# NOTE(review): `gen_lr` / `disc_lr` look like the generator and discriminator
# learning rates — confirm against CartoonGanParameters.
pretraining_parameters = CartoonGanParameters(
    batch_size=BATCH_SIZE,
    epochs=10,
    nb_resnet_blocks=3,
    gen_lr=1e-3,
    disc_lr=1e-2,
    conditionnal_lambda=2,
)

# Settings for the main training stage. The values are currently the same as
# the pretraining ones, but this is a separate object.
training_parameters = CartoonGanParameters(
    batch_size=BATCH_SIZE,
    epochs=10,
    nb_resnet_blocks=3,
    gen_lr=1e-3,
    disc_lr=1e-2,
    conditionnal_lambda=2,
)


### Prepare data

In [None]:
# Third-party imports first, then project-local ones. Each loader class is
# imported exactly once.
from torch.utils.data import DataLoader

from src.dataset.cartoon_loader import CartoonDatasetLoader
from src.dataset.pictures_loader import PicturesDatasetLoader
from src.preprocessing.preprocessor import Preprocessor

# A single Preprocessor (built with size=256) supplies the transform for each
# dataset.
preprocessor = Preprocessor(size=256)

# Both datasets are created with train=True and their matching transform.
pictures_dataset = PicturesDatasetLoader(
    train=True,
    transform=preprocessor.picture_preprocessor
)
cartoons_dataset = CartoonDatasetLoader(
    train=True,
    transform=preprocessor.cartoon_preprocessor
)

# Shuffled loaders over each dataset, batched with the notebook-wide
# BATCH_SIZE and loaded by two worker processes.
train_pictures_loader = DataLoader(
    dataset=pictures_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # drop last incomplete batch
    drop_last=True,
    num_workers=2
)

train_cartoons_loader = DataLoader(
    dataset=cartoons_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # drop last incomplete batch
    drop_last=True,
    num_workers=2
)

### Pretraining

In [None]:
from src.models.cartoon_gan import CartoonGan

# Build the model with its default constructor arguments; the pretraining and
# training cells below call methods on this instance.
# NOTE(review): the `device` computed earlier is not passed here — confirm
# that CartoonGan selects its device on its own.
cartoon_gan = CartoonGan()

In [None]:
# Run the pretraining stage on the pictures loader with the pretraining
# parameter set. The cartoons loader is not passed to this call.
cartoon_gan.pretrain(
    parameters=pretraining_parameters,
    pictures_loader=train_pictures_loader,
)

### Training


# Start the main training stage.
# NOTE(review): this call passes no arguments, while `train_cartoons_loader`
# and `training_parameters` are defined above but never used. It presumably
# should receive the two loaders and `training_parameters` — confirm against
# the signature of CartoonGan.train before running.
cartoon_gan.train()