In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
from torchsummary import summary
from torchvision import models
from pycocotools.coco import COCO
import torchvision.transforms as transforms
import json
import skimage.io as io
import matplotlib.pyplot as plt
from tqdm import tqdm

import sys, os
# Put the parent of the current working directory on the import path so that
# the packages imported further down in this cell can be found
# (presumably they live at the repository root -- confirm).
_parent_dir = os.path.abspath("../")
sys.path.append(_parent_dir)

%load_ext autoreload
%autoreload 2

from models.gan.model import Generator, Discriminator
from trainers.gan import GanTrainer
from datasets.coco_dataset import CocoDataset, CocoPairsDataset
import utils.data
import utils.functionnal

## Data loading

In [2]:
# Directory holding the MS-COCO annotation JSON files.
# The default is the original author's external volume; set the
# COCO_ANNOTATIONS_DIR environment variable to run the notebook on another
# machine without editing this cell.
_annotations_dir = os.environ.get(
    "COCO_ANNOTATIONS_DIR",
    "/Volumes/F_LEDOYEN/ms_coco/annotations",
)

# Template turning an annotation file stem into its full path:
# data_dir_pattern.format("name") -> "<annotations dir>/name.json".
# A trailing slash on the directory is dropped so the result never contains "//".
data_dir_pattern = _annotations_dir.rstrip("/") + "/{}.json"

In [3]:
# NOTE(review): `data_files` is read here but no visible earlier cell defines
# it -- on a fresh kernel this cell would raise NameError. The cell that built
# the original (nested) dict of annotation file stems appears to be missing;
# restore it before this one.
def _stem_to_path(k, v):
    """Expand one leaf value into a full path via `data_dir_pattern`; the key is unused."""
    return data_dir_pattern.format(v)


# Rebind `data_files` to the mapped structure returned by the helper.
data_files = utils.functionnal.map_nested_dicts(data_files, _stem_to_path)

In [4]:
def _ensure_three_channels(x):
    """Return `x` repeated 3 times along dim 0 when dim 0 has size 1; otherwise return `x` unchanged."""
    if x.shape[0] == 1:
        return x.repeat(3, 1, 1)
    return x


# Per-channel normalisation statistics.
# NOTE(review): these look like the standard ImageNet mean/std -- confirm that
# is what the downstream model expects.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in this order to every sample.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Lambda(_ensure_three_channels),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]

transform = transforms.Compose(_preprocess_steps)

In [5]:
# Build the COCO index from the "coco_anns_all" annotation file and start the
# `data` registry with it under the same key. The captured output of this cell
# reports roughly a minute of load time.
_coco_all = COCO(data_files["coco_anns_all"])
data = {"coco_anns_all": _coco_all}

loading annotations into memory...
Done (t=60.37s)
creating index...
index created!


In [9]:
def _build_entry(k, v):
    """Turn one (key, value) leaf of `data_files` into the object stored in `data`.

    - "pairs"   -> utils.data.coco_pairs_dataset(...) built from `v`
    - "singles" -> utils.data.coco_dataset(...) built from `v`
    - any other key -> whatever is already stored in `data` under that key
    """
    if k == "pairs":
        return utils.data.coco_pairs_dataset(data["coco_anns_all"], v, transform)
    if k == "singles":
        return utils.data.coco_dataset(data["coco_anns_all"], v, transform)
    # Keys that are neither "pairs" nor "singles" reuse the existing entry;
    # this raises KeyError if `data` has no such key.
    return data[k]


# Merge the mapped structure into `data` in place.
_built_entries = utils.functionnal.map_nested_dicts(data_files, func=_build_entry)
data.update(_built_entries)

In [12]:
# Split the "sport" / "singles" dataset with the arguments .8 and .2
# (presumably train/test fractions -- confirm against utils.data.split_dataset)
# and label the resulting parts, in order, "train" and "test".
_split_names = ["train", "test"]
_split_parts = utils.data.split_dataset(data["sport"]["singles"], .8, .2)
datasets = dict(zip(_split_names, _split_parts))

# Keyword arguments shared by every DataLoader created below.
params = dict(
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
max_epochs = 100

# One DataLoader per split, keyed by the split name.
dataloaders = {
    split_name: torch_data.DataLoader(split_ds, **params)
    for split_name, split_ds in datasets.items()
}