In [1]:
# Reload imported modules automatically before executing each cell, so edits to
# the project packages (gan, utils, ...) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the cell that produced them.
%matplotlib inline

In [34]:
import torch
from torch import nn, optim
from torchvision import transforms as T
from torchvision.utils import make_grid, save_image
from torch.utils.data import Dataset

from gan import models, build_cycle_gan_trainer, kl_cycle_gan_loss_step
from utils.benchmark import train
from utils.display import display_images
from utils.checkpoints import load_checkpoint
from utils.datasets import DomainDataset
from __datasets__ import VINSDataset

In [4]:
# Location of the VINS dataset, relative to this notebook; exposed later as
# ``config.dataset_path``.
VINS_DATASET_PATH = "../../pytorch/datasets/vins"
# Experiment name passed as the second positional argument — presumably used to
# name the writer / checkpoint outputs; TODO confirm in models.CycleGanConfig.
EXPERIMENT_NAME = "CycleGan-VINS"

config = models.CycleGanConfig(
    VINS_DATASET_PATH,
    EXPERIMENT_NAME,

    # training / logging
    batch_size=8,
    norm=nn.InstanceNorm2d,
    writer=True,
    lr=2e-4,
    p=0,  # NOTE(review): meaning not visible here (dropout probability?) — confirm

    # architecture
    inp_channels=3,
    hidden_channels=64,
    out_channels=3,
    downsample=3,
    residuals=7,
    n=0,  # NOTE(review): meaning not visible here — confirm in models.CycleGanConfig
    blocks=(64, 128, 256, 512),
)

In [39]:
class VINSGanDataset(Dataset):
    """Paired-image view of a ``VINSDataset`` for CycleGAN training.

    Each sample is a dict with two tensors:

    * ``"image1"`` — the (already transformed) image from the wrapped dataset;
    * ``"image2"`` — a copy of it in which one randomly chosen bounding box is
      filled with ``-1``.

    Boxes are read as ``(x_min, y_min, x_max, y_max)`` and applied to a
    channel-first ``(C, H, W)`` image: rows ``y_min:y_max``, columns
    ``x_min:x_max``.
    """

    def __init__(self, dataset: "VINSDataset"):
        # Any indexable yielding dicts with "image" and "bboxes" works here.
        self.dataset = dataset

    def __getitem__(self, item):
        """Return one paired sample, or a stacked batch when ``item`` is a slice.

        Args:
            item: integer index, or a ``slice``. A slice returns
                ``{"image1": (B, C, H, W), "image2": (B, C, H, W)}``.

        Returns:
            dict with keys ``"image1"`` and ``"image2"``.

        Samples without any bounding box get an unmasked copy as ``image2``.
        Previously ``torch.randint`` raised on an empty range for them.
        """
        if isinstance(item, slice):
            # Don't forward slices to the wrapped dataset. Doing so raised
            # "expected Tensor as element 0 in argument 0, but got list".
            return self._get_slice(item)

        out = self.dataset[item]
        image1 = out["image"]
        bboxes = out["bboxes"]
        image2 = image1.clone()

        if len(bboxes) > 0:
            # .item() yields a plain int, so this works for a list of boxes and
            # for an (N, 4) tensor. A 1-element index tensor would keep a leading
            # dim on tensors and make bbox[1] fail.
            bbox = bboxes[torch.randint(len(bboxes), (1,)).item()]
            # int() lets float coordinates be used as slice bounds.
            x_min, y_min, x_max, y_max = (int(bbox[k]) for k in range(4))
            # NOTE(review): the image is resized before we get it. Confirm that
            # the wrapped dataset rescales the boxes to the resized image.
            image2[:, y_min:y_max, x_min:x_max] = -1

        return {
            "image1": image1,
            "image2": image2,
        }

    def _get_slice(self, index: slice):
        """Build a batch by fetching each sample individually and stacking per key.

        Raises RuntimeError (from ``torch.stack``) if the slice is empty.
        """
        samples = [self[i] for i in range(*index.indices(len(self)))]
        return {
            key: torch.stack([sample[key] for sample in samples])
            for key in ("image1", "image2")
        }

    def __len__(self):
        """Number of samples, identical to the wrapped dataset."""
        return len(self.dataset)


# Per-image preprocessing: resize to 64x64, convert to a (C, H, W) tensor in
# [0, 1], rescale to [-1, 1], then move to the training device.
# NOTE(review): moving to the device inside the transform (and using a lambda)
# is likely to break a multi-worker DataLoader — confirm workers stay at 0.
vins_image_transform = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    lambda x: x.to(config.device),
])

# Android subset of VINS (downloaded if missing), wrapped into image pairs.
vins_android = VINSDataset(
    DIR=config.dataset_path,
    SET="Android",
    download=True,
    sub_sample=1,
    image_transform=vins_image_transform,
)
ds = VINSGanDataset(vins_android)
len(ds)

740

In [40]:
# Both generators and both discriminators come from the shared config.
generatorA, generatorB, discriminatorA, discriminatorB = models.build_CycleGan(config)

# One Adam optimizer updates the two generators jointly ...
generator_params = [*generatorA.parameters(), *generatorB.parameters()]
optimizerG = optim.Adam(generator_params, lr=config.lr, betas=config.betas)

# ... and a second one updates the two discriminators jointly.
discriminator_params = [*discriminatorA.parameters(), *discriminatorB.parameters()]
optimizerD = optim.Adam(discriminator_params, lr=config.lr, betas=config.betas)

In [46]:
def data_extractor(DATA):
    """Split a batch dict from ``VINSGanDataset`` into the trainer's input pair.

    Returns:
        ``(DATA["image1"], DATA["image2"])``: the original images and their
        copies with one bounding box filled with -1.
    """
    return DATA["image1"], DATA["image2"]


# Number of fixed samples handed to the trainer as ``fixed_inp``. 9 presumably
# renders as a 3x3 grid in the writer; TODO confirm in build_cycle_gan_trainer.
N_FIXED = 9

# Fetch samples by integer index and stack them here. ``ds[0:9]`` forwarded a
# slice to the dataset and raised
# "TypeError: expected Tensor as element 0 in argument 0, but got list".
# Each sample is also evaluated once instead of twice.
fixed_samples = [ds[i] for i in range(N_FIXED)]
fixed_inp = (
    torch.stack([sample["image1"] for sample in fixed_samples]),
    torch.stack([sample["image2"] for sample in fixed_samples]),
)

trainer = build_cycle_gan_trainer(
    generatorA, generatorB, discriminatorA, discriminatorB,
    optimizerG, optimizerD,
    kl_cycle_gan_loss_step,
    data_extractor,
    writer=config.writer, writer_period=100, fixed_inp=fixed_inp,
    save_path=None, save_period=500,
)