## Train Cartoon GAN

Jupyter notebook version of the ```src.networks.train.py``` file.
This file contains all cells to train the model from scratch.

#### Configure drive (only if needed)

In [1]:

from google.colab import drive
import os
drive.mount("/content/drive")

PROJECT_DIRECTORY = "drive/MyDrive/DeepL"
os.chdir(PROJECT_DIRECTORY)

!ls
!pip install -r requirements.txt

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
data  log_file_test.log  papers  requirements.txt  src	weights


In [2]:
import torch
import torch.optim as optim


cuda = torch.cuda.is_available()
print(cuda)

if cuda:
  print(torch.cuda.get_device_name(0))
  !nvidia-smi

device = "cuda" if cuda else "cpu"

True
Tesla K80
Sat Dec 25 19:19:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P8    33W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+------------------------------------------------------------------------

#### Main parameters

Main (pre-)training parameters are defined here.
It is easier to change some parameters in `.ipynb` file instead of `config.py` while using Google colaboratory.

In [3]:
from src.models.parameters import CartoonGanParameters

BATCH_SIZE = 8

pretraining_parameters = CartoonGanParameters(
    epochs=10,
    gen_lr=0.0002,
    disc_lr=0.0002,
    batch_size=BATCH_SIZE,
    conditional_lambda=10,
    gen_beta1=0.5,
    gen_beta2=0.999,
    disc_beta1=0.5,
    disc_beta2=0.999
)

training_parameters = CartoonGanParameters(
    epochs=10,
    gen_lr=0.0002,
    disc_lr=0.0002,
    batch_size=BATCH_SIZE,
    conditional_lambda=10,
    gen_beta1=0.5,
    gen_beta2=0.999,
    disc_beta1=0.5,
    disc_beta2=0.999
)


In [4]:
import logging

logging.basicConfig(filename="log_file_test.log",
  filemode='a',
  format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
  datefmt='%H:%M:%S',
  level=logging.DEBUG
)

### Prepare data

In [5]:
from src.extraction.__main__ import *

# Extract the data if needed

"""create_all_frames_csv()
create_all_images_csv()
create_train_test_frames()
create_train_test_images()"""



'create_all_frames_csv()\ncreate_all_images_csv()\ncreate_train_test_frames()\ncreate_train_test_images()'

In [6]:
from src.dataset.dataset_cartoon import CartoonDataset
from src.dataset.dataset_pictures import PicturesDataset
from src.preprocessing.preprocessor import Preprocessor

from torch.utils.data import DataLoader

preprocessor = Preprocessor(size=256)

pictures_dataset = PicturesDataset(
    train=True,
    transform=preprocessor.picture_preprocessor(),
    size=None # keep all data
)

cartoons_dataset = CartoonDataset(
    train=True,
    transform=preprocessor.cartoon_preprocessor(),
    size=None # keep all data
)

train_pictures_loader = DataLoader(
    dataset=pictures_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # drop last incomplete batch
    drop_last=True,
    num_workers=2
)

train_cartoons_loader = DataLoader(
    dataset=cartoons_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2
)



### Load model

In [7]:
from src.models.cartoon_gan import CartoonGan

cartoon_gan = CartoonGan(
    nb_resnet_blocks=8,
    nb_channels_picture=3,
    nb_channels_cartoon=3,
    nb_channels_1_h_l_gen=64,
    nb_channels_1_h_l_disc=32
)

In [8]:
pretrained_gen = os.path.join("weights", "pretrained", "pretrained_gen_1.pkl")
pretrained_disc = os.path.join("weights", "pretrained", "pretrained_disc_1.pkl")

cartoon_gan.load_model(pretrained_gen, pretrained_disc)

### Pretrain model

In [None]:

cartoon_gan.pretrain(
    pictures_loader=train_pictures_loader,
    parameters=pretraining_parameters
)

  9%|▉         | 376/4045 [15:25<2:09:51,  2.12s/it]

### Training
