In [1]:
import os
import time
import sys
import torch
import torchvision
import datetime
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import glob
import numpy as np
import random
import itertools
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from tqdm import tqdm_notebook
# Root directory that contains the NeuralBlocks package. Later paths are built
# by plain string concatenation on it (PATH + "NeuralBlocks/..."), so it is
# normalised to end with a path separator. Set the NEURALBLOCKS_ROOT
# environment variable instead of editing this machine-specific default
# (another machine used '/home/robot/Anand/').
PATH = os.path.join(os.environ.get('NEURALBLOCKS_ROOT', '/home/antixk/Anand/'), '')
# Guard so re-running this cell does not stack duplicate sys.path entries.
if PATH not in sys.path:
    sys.path.append(PATH)


from NeuralBlocks.models.cyclegan import ResNetGenerator
from NeuralBlocks.models.cyclegan import Discriminator
from NeuralBlocks.trainers import GANTrainer
from NeuralBlocks.trainers.ganloss import GANLoss

In [2]:
# --- Experiment configuration ------------------------------------------------
NUM_EPOCH = 200
BATCH_SIZE = 2
IMG_HEIGHT = 256
IMG_WIDTH = 256
NUM_CHANNELS = 3
SAMPLE_INTERVAL = 100
CHECKPOINT_INTERVAL = 10
NUM_RESIDUAL_BLOCKS = 2
LR = 0.0002
NUM_WORKERS = 8
IMG_SIZE = 32  # side length MNIST digits are resized to; G and D are sized from it
# NOTE(review): BATCH_SIZE, IMG_HEIGHT/IMG_WIDTH, NUM_CHANNELS, NUM_RESIDUAL_BLOCKS
# and NUM_WORKERS are not read by the cells in this notebook (the MNIST
# dataloader uses batch_size=64 and 1-channel images); they look like leftovers
# from a CycleGAN setup -- confirm before relying on them.
DATA_PATH = PATH+"NeuralBlocks/data_utils/datasets/MNIST/"
SAVE_PATH = PATH+"NeuralBlocks/experiments/MNIST/"

os.makedirs(SAVE_PATH, exist_ok=True)

NORM = 'MSN'

# Seed the Python, NumPy and PyTorch RNGs before any model is built or data is
# shuffled, so weight initialisation and batch order are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Let cuDNN autotune convolution algorithms for the fixed input size. This can
# make GPU results non-bit-reproducible even with the seeds above.
cudnn.benchmark = True

IMG_DIR = 'images_{}/'.format(NORM)
SAVE_DIR = 'saved_models_{}/'.format(NORM)

In [3]:
# MNIST training split, resized to IMG_SIZE x IMG_SIZE. Normalize(0.5, 0.5) maps
# ToTensor's [0, 1] range onto [-1, 1], the same range as the generator's Tanh.
# Data is stored under DATA_PATH from the config cell rather than a path
# relative to whatever directory the kernel happened to start in.
os.makedirs(DATA_PATH, exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        DATA_PATH,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(IMG_SIZE), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=64,
    shuffle=True,
)

In [4]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image in [-1, 1].

    A linear layer projects ``z`` to a ``128 x (img_size//4) x (img_size//4)``
    feature map. Two ``Upsample(2)`` + 3x3 conv stages grow it to
    ``img_size x img_size``, and a final conv + ``Tanh`` produces the output.

    Args:
        latent_dim: Length of the input noise vector (default 128, matching the
            ``latent_dim`` handed to the trainer).
        img_size: Output height/width; defaults to the notebook's ``IMG_SIZE``.
            Should be divisible by 4, otherwise the output side is
            ``4 * (img_size // 4)`` and will not match the data.
        channels: Number of output image channels (1 for MNIST).

    Raises:
        ValueError: if ``img_size < 4`` (the seed feature map would be empty).
    """

    def __init__(self, latent_dim=128, img_size=None, channels=1):
        super(Generator, self).__init__()
        if img_size is None:
            img_size = IMG_SIZE
        if img_size < 4:
            raise ValueError("img_size must be >= 4, got {}".format(img_size))

        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # BatchNorm2d's second positional parameter is ``eps`` (not momentum):
        # the common ``BatchNorm2d(C, 0.8)`` idiom sets eps=0.8, which swamps the
        # batch variance and largely disables normalisation. Use the default eps.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape ``(batch, latent_dim)`` to ``(batch, channels, H, W)``."""
        out = self.l1(z)
        # Reshape the flat projection into the 128-channel seed feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score in (0, 1) per image.

    Four stride-2 3x3 conv blocks are used (LeakyReLU + Dropout2d, plus
    BatchNorm on all but the first), followed by a linear layer and a sigmoid.

    NOTE: this class shadows the ``Discriminator`` imported from
    ``NeuralBlocks.models.cyclegan`` in the first cell.

    Args:
        img_size: Input height/width; defaults to the notebook's ``IMG_SIZE``.
            Any size works, not only multiples of 16.
        channels: Number of input image channels (1 for MNIST).
    """

    def __init__(self, img_size=None, channels=1):
        super(Discriminator, self).__init__()
        if img_size is None:
            img_size = IMG_SIZE

        def discriminator_block(in_filters, out_filters, bn=True):
            # Halve the spatial size with a stride-2 conv, then regularise.
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # Default eps: BatchNorm2d's 2nd positional arg is eps, so the
                # ``BatchNorm2d(C, 0.8)`` idiom would set eps=0.8, not momentum.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Spatial size after the four convs. A k=3, s=2, p=1 conv maps n to
        # ceil(n / 2); plain ``img_size // 16`` is wrong unless img_size is a
        # multiple of 16 (e.g. 28 -> 2, not 1).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Score a batch ``(batch, channels, H, W)``; returns shape ``(batch, 1)``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity
# Build the two networks defined in this cell (generator first, then the
# discriminator). This Discriminator is the local class, not the cyclegan one
# imported at the top.
G, D = Generator(), Discriminator()

In [5]:
# One Adam optimizer per network, sharing the learning rate and beta1 = 0.5.
ADAM_BETAS = (0.5, 0.999)
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=LR, betas=ADAM_BETAS) for net in (G, D)
)

In [6]:
class WGAN_loss(GANLoss):
    """Binary cross-entropy GAN loss for a discriminator that outputs probabilities.

    NOTE(review): despite the name, this is *not* a Wasserstein loss. It applies
    ``BCELoss`` to sigmoid probabilities (the notebook's ``Discriminator`` ends
    in ``nn.Sigmoid``). The class name is kept so existing callers keep working.
    """

    def __init__(self):
        # NOTE(review): the trainer cell failed with "__init__() missing 1
        # required positional argument: 'losses'". If that came from
        # GANLoss.__init__ (called here with no arguments), it must be supplied
        # here -- confirm against NeuralBlocks.trainers.ganloss.
        super(WGAN_loss, self).__init__()
        self.adversarial_loss = torch.nn.BCELoss()

    def compute_loss(self, images, G_x, D_x, D_G_x):
        """Compute the generator and discriminator losses for one batch.

        Args:
            images: Batch of real images. Not used in the computation;
                presumably kept for the signature the trainer calls.
            G_x: Generated images. Not used here.
            D_x: Discriminator probabilities on the real batch.
            D_G_x: Discriminator probabilities on the generated batch, still
                attached to the generator's graph.

        Returns:
            ``(g_loss, d_loss)``:
            - ``g_loss`` is the non-saturating generator loss ``BCE(D(G(z)), 1)``.
            - ``d_loss`` is the mean of the real (target 1) and fake (target 0)
              BCE terms.
        """
        # Build targets with *_like so they follow the discriminator output's
        # device, dtype and shape. Forcing them onto CUDA crashes on CPU-only
        # machines and mismatches when D lives on a non-default GPU.
        real_target = torch.ones_like(D_x)
        fake_as_real_target = torch.ones_like(D_G_x)
        fake_target = torch.zeros_like(D_G_x)

        # The generator is rewarded when D labels its samples as real.
        g_loss = self.adversarial_loss(D_G_x, fake_as_real_target)

        # detach() keeps the discriminator loss from backpropagating into G.
        real_loss = self.adversarial_loss(D_x, real_target)
        fake_loss = self.adversarial_loss(D_G_x.detach(), fake_target)
        d_loss = (real_loss + fake_loss) / 2

        return g_loss, d_loss

In [7]:
loss = WGAN_loss()

# NOTE(review): when last executed, this cell raised
#   TypeError: __init__() missing 1 required positional argument: 'losses'
# The traceback is truncated, so it may come from WGAN_loss() (GANLoss.__init__)
# or from GANTrainer(...). TODO: check both signatures and supply ``losses``.
# latent_dim must equal the Generator's input size (128 in the model cell).
trainer = GANTrainer(
    G_model=G,
    G_optimizer=optimizer_G,
    D_model=D,
    D_optimizer=optimizer_D,
    dataloader=dataloader,
    gan_loss=loss,
    latent_dim=128,
)

TypeError: __init__() missing 1 required positional argument: 'losses'

In [None]:
# NOTE(review): runs only 2 epochs although the config cell sets NUM_EPOCH = 200.
# This is presumably a quick smoke test; pass NUM_EPOCH for a full training run.
SMOKE_TEST_EPOCHS = 2
trainer.run(num_epochs=SMOKE_TEST_EPOCHS)