## Train Cartoon GAN

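Jupyter notebook version of the `src.networks.train.py` script.
This notebook contains all the cells needed to train the model; training starts from the pretrained generator and discriminator weights loaded in the *Load model* section.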

#### Configure Google Drive (only when running on Colab)

In [1]:

from google.colab import drive
import os
drive.mount("/content/drive")
!cd ..
!ls
PROJECT_DIRECTORY = "drive/My Drive/DeepL/"
os.chdir(PROJECT_DIRECTORY)

!ls


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
drive  sample_data
data  log_file_test.log  papers  requirements.txt  src	weights


In [2]:
!ls
!pip install -r requirements.txt

data  log_file_test.log  papers  requirements.txt  src	weights


In [3]:
import subprocess

import torch
import torch.optim as optim

# Seed torch's global RNG (on all devices) so that DataLoader shuffling and any
# random weight initialisation are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

cuda = torch.cuda.is_available()
print(cuda)

if cuda:
    print(torch.cuda.get_device_name(0))
    # Capture nvidia-smi's output explicitly so it is displayed in the cell
    # (and the cell stays valid plain Python, unlike a `!` shell magic).
    print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)

device = "cuda" if cuda else "cpu"

True
Tesla P100-PCIE-16GB
Mon Dec 27 17:30:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------------------------------

In [4]:
import logging

# The log file is created relative to the current working directory; mode "a"
# appends to it across runs.
LOG_FILE = "log_file_test.log"
# `%(msecs)03d` zero-pads milliseconds: with plain `%d`, 5 ms would render as
# ",5" and be easily misread as 500 ms.
LOG_FORMAT = "%(asctime)s,%(msecs)03d %(name)s %(levelname)s %(message)s"
LOG_DATEFMT = "%H:%M:%S"

# NOTE: basicConfig does nothing if the root logger already has handlers, so
# re-running this cell does not attach a second file handler.
logging.basicConfig(
    filename=LOG_FILE,
    filemode="a",
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
    level=logging.DEBUG,
)

#### Main parameters

In [1]:
from src.models.parameters import CartoonGanParameters

# Training hyper-parameters. The generator and the discriminator use the same
# learning rate and the same (beta1, beta2) pair.
BATCH_SIZE = 8
INPUT_SIZE = 256
LEARNING_RATE = 0.0002
BETA1, BETA2 = 0.5, 0.999

training_parameters = CartoonGanParameters(
    epochs=10,
    input_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    conditional_lambda=10,
    gen_lr=LEARNING_RATE,
    gen_beta1=BETA1,
    gen_beta2=BETA2,
    disc_lr=LEARNING_RATE,
    disc_beta1=BETA1,
    disc_beta2=BETA2,
)

ModuleNotFoundError: No module named 'src'

### Prepare data

In [6]:
from src.dataset.dataset_cartoon import CartoonDataset
from src.dataset.dataset_pictures import PicturesDataset
from src.preprocessing.filters import Filter
from src.preprocessing.transformations import Transform

from torch.utils.data import DataLoader

IMAGE_SIZE = (256, 256)

filter_data = Filter(new_size=IMAGE_SIZE)
transform = Transform(new_size=IMAGE_SIZE, crop_mode="center")

# Training splits of both domains, each with its domain-specific filter and
# transform.
cartoons_dataset = CartoonDataset(
    train=True,
    filter_data=filter_data.cartoon_filter,
    transform=transform.cartoon_transform,
)
pictures_dataset = PicturesDataset(
    train=True,
    filter_data=filter_data.picture_filter,
    transform=transform.picture_transform,
)


def make_train_loader(dataset):
    """Wrap ``dataset`` in a shuffled training DataLoader of BATCH_SIZE batches.

    ``drop_last=True`` discards the final incomplete batch, so every batch
    holds exactly BATCH_SIZE samples.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=2,
    )


train_pictures_loader = make_train_loader(pictures_dataset)
train_cartoons_loader = make_train_loader(cartoons_dataset)


torch.Size([3, 256, 256])
torch.Size([3, 256, 256])


### Load model

In [7]:
from src.models.cartoon_gan import CartoonGan
from src.models.utils.parameters import Architecture

# Build the CartoonGAN using the FIXED architecture variant.
selected_architecture = Architecture.FIXED
cartoon_gan = CartoonGan(architecture=selected_architecture)

In [8]:
# `os` is imported here because the only other import of it lives in the
# optional Colab cell; without this, the cell fails when run locally.
import os

# Pretrained generator/discriminator checkpoints, resolved relative to the
# current working directory (the project folder).
# NOTE(review): the save cell uses `config.WEIGHTS_FOLDER`; consider resolving
# these paths from it as well so loading does not depend on the cwd.
PRETRAINED_DIR = os.path.join("weights", "pretrained")
pretrained_gen = os.path.join(PRETRAINED_DIR, "pretrained_gen_11.pkl")
pretrained_disc = os.path.join(PRETRAINED_DIR, "pretrained_disc_11.pkl")

cartoon_gan.load_model(pretrained_gen, pretrained_disc)

### Training the model

In [9]:

# Run adversarial training. Note that the `dataset_cartoon` argument receives
# the cartoon DataLoader, not the dataset object itself.
train_inputs = {
    "picture_loader": train_pictures_loader,
    "dataset_cartoon": train_cartoons_loader,
    "parameters": training_parameters,
}
cartoon_gan.train(**train_inputs)

735
735


100%|██████████| 735/735 [16:02<00:00,  1.31s/it]
 37%|███▋      | 273/735 [04:27<07:33,  1.02it/s]
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: ignored

In [10]:
# `os` is imported here because the only other import of it lives in the
# optional Colab cell; without this, the cell fails when run locally.
import os

from src import config

# Create the weights folder if needed, so that the result of a long training
# run is not lost to a missing-directory error at save time.
os.makedirs(config.WEIGHTS_FOLDER, exist_ok=True)

# NOTE: the file names are fixed, so each run overwrites the previous
# checkpoint.
cartoon_gan.save_model(
    os.path.join(config.WEIGHTS_FOLDER, "trained_gen_0.pt"),
    os.path.join(config.WEIGHTS_FOLDER, "trained_disc_0.pt"),
)