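# GAN

Trains a GAN whose generator is the ResU-Net decoder (`UnetDecoder`) and whose discriminator is a convolutional `Discriminator`, on the "sport" subset of MS-COCO.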

In [95]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
from torchsummary import summary
from torchvision import models
from pycocotools.coco import COCO
import torchvision.transforms as transforms
import json
import skimage.io as io
import matplotlib.pyplot as plt
from tqdm import tqdm

import sys, os
sys.path.append(os.path.abspath("../"))

%load_ext autoreload
%autoreload 2

from models.res_u_net.decoder import UnetDecoder
from models.res_u_net.std_blocks import UpBlock
from models.gan.model import Generator, Discriminator
from trainers.gan import GanTrainer
from datasets.coco_dataset import CocoDataset, CocoPairsDataset
import utils.data
import utils.functionnal

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Dataloader

In [9]:
# Lookup table of file stems used below: the full-COCO annotations file and,
# per subset ("all", "sport"), the "singles" and "pairs" file stems.
utils.data.data_files

{'coco_anns_all': 'instances_all2017',
 'all': {'singles': 'imgs', 'pairs': 'pairs'},
 'sport': {'singles': 'imgs_sport', 'pairs': 'pairs_sport'}}

In [10]:
# Root of the local MS-COCO copy. Set the COCO_ROOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
COCO_ROOT = os.environ.get("COCO_ROOT", "/Volumes/F_LEDOYEN/ms_coco")
# Pattern for JSON files in the annotations folder; "{}" receives a file stem
# from `utils.data.data_files`.
data_dir_pattern = os.path.join(COCO_ROOT, "annotations", "{}.json")

# Loading the full annotations file is slow (~90 s in the saved run).
coco = COCO(data_dir_pattern.format(utils.data.data_files["coco_anns_all"]))

loading annotations into memory...
Done (t=91.84s)
creating index...
index created!


In [141]:
# Build the "sport" singles dataset; its file is resolved through the same
# annotations-folder pattern used to load `coco`.
dataset = utils.data.coco_dataset(
    coco,
    data_dir_pattern.format(utils.data.data_files["sport"]["singles"]),
    utils.data.transform,
)

# Named `loader_params` so it is not confused with the trainer `params` dict below.
loader_params = {
    "batch_size": 64,
    "shuffle": True,
    "num_workers": 0,
    # The trainer is later given a fixed latent tensor of exactly 64 samples; dropping
    # the final partial batch keeps real and generated batch sizes equal.
    # NOTE(review): confirm GanTrainer pairs every real batch with that latent tensor.
    "drop_last": True,
}

dataloader = torch_data.DataLoader(dataset, **loader_params)

## Model

### Generator

In [124]:
# Only the generator is built here; the discriminator has its own cell below
# (it used to be constructed here too, producing a throwaway duplicate).
#
# NOTE(review): `channels` is not defined anywhere in this notebook. It leaked from a
# deleted/unsaved cell, so this cell fails on a fresh kernel. The printed model shows
# the decoder starting 1024 -> 512 -> 256 -> ...; TODO define the full list explicitly.
if "channels" not in globals():
    raise NameError(
        "`channels` (the decoder's channel widths, starting 1024, 512, 256, ...) "
        "must be defined before building the generator"
    )

generator = UnetDecoder(
    channels,
    UpBlock.UPSAMPLING_BILINEAR,
    (128, 128),  # passed through as-is; presumably the output spatial size — confirm
    copy_n_crop=False,
)

In [125]:
# Display the generator's module tree.
generator

UnetDecoder(
  (up_blocks): ModuleDict(
    (1024;512): UpBlock(
      (upsample): Sequential(
        (0): Upsample(scale_factor=2.0, mode=bilinear)
        (1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (conv_blocks): Sequential(
        (0): ConvBlock(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU()
        )
        (1): ConvBlock(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU()
        )
      )
    )
    (512;256): UpBlock(
      (upsample): Sequential(
        (0): Upsample(scale_factor=2.0, mode=bilinear)
        (1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (conv_blocks): Sequential(
        (0): ConvBlock(
      

### Discriminator

In [132]:
# Positional arguments, as reflected in the printed model below: the network takes
# 3 input channels and its first Conv2d emits 64 feature maps.
disc_in_channels, disc_base_features = 3, 64
discriminator = Discriminator(disc_in_channels, disc_base_features)

In [138]:
# Display the discriminator's module tree.
discriminator

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 640, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(640, eps=1e-05, 

## Trainer

In [None]:
lr = 2e-4  # Adam learning rate, shared by the generator and discriminator optimizers

In [143]:
# Both networks are optimised with Adam using the same learning rate and betas.
adam_kwargs = {"lr": lr, "betas": (0.5, 0.999)}

# Keyword arguments for GanTrainer.
params = dict(
    generator=generator,
    discriminator=discriminator,
    n_epochs=1,
    dataloader=dataloader,
    criterion=nn.BCELoss(),
    discriminator_optimizer=optim.Adam(discriminator.parameters(), **adam_kwargs),
    generator_optimizer=optim.Adam(generator.parameters(), **adam_kwargs),
)

In [150]:
# Build the trainer from the keyword arguments assembled in `params`.
trainer = GanTrainer(**params)

In [151]:
# Fixed latent tensor handed to the trainer, drawn uniformly from [0, 1).
# Its leading 64 matches the DataLoader batch size, and 1024 matches the first
# decoder width in the printed generator.
# NOTE(review): GAN latents are conventionally drawn from a standard normal
# (torch.randn); confirm the uniform distribution is intended.
latent_shape = (64, 1024, 2, 2)
trainer.input_latent_space = torch.rand(*latent_shape)

In [154]:
# Run training. In the saved run the first batch took ~97 s and the cell was
# then stopped manually (KeyboardInterrupt), so no full epoch completed.
trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]
0it [00:00, ?it/s][A

torch.Size([64]) torch.Size([64])



1it [01:36, 96.54s/it][A
  0%|          | 0/1 [01:36<?, ?it/s]


KeyboardInterrupt: 