In [1]:
import os
import time
import sys
import torch
import datetime
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
import glob
import numpy as np
import random
import itertools
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

sys.path.append('/home/antixk/Anand/')
print(sys.path)


from NeuralBlocks.models.cyclegan import ResNetGenerator
from NeuralBlocks.models.cyclegan import Discriminator

['/home/antixk/Anand/NeuralBlocks/notebooks', '/home/antixk/miniconda3/envs/main/lib/python37.zip', '/home/antixk/miniconda3/envs/main/lib/python3.7', '/home/antixk/miniconda3/envs/main/lib/python3.7/lib-dynload', '', '/home/antixk/miniconda3/envs/main/lib/python3.7/site-packages', '/home/antixk/miniconda3/envs/main/lib/python3.7/site-packages/IPython/extensions', '/home/antixk/.ipython', '/home/antixk/Anand/']


In [2]:
# ---------------------------------------------------------------------------
# Run configuration. Every tunable value used by the cells below lives here.
# ---------------------------------------------------------------------------

# Number of passes the training loop makes over the training dataloader.
NUM_EPOCH = 200
# Batch size of the training DataLoader.
BATCH_SIZE = 1
# Size of the random crop fed to the networks; also used to derive the shape
# of the discriminator target tensors (IMG_HEIGHT // 16, IMG_WIDTH // 16).
IMG_HEIGHT = 256
IMG_WIDTH = 256
# NOTE(review): not referenced by any cell visible in this notebook.
NUM_CHANNELS = 3
# A grid of sample translations is written every SAMPLE_INTERVAL batches.
SAMPLE_INTERVAL = 100
# Presumably the number of epochs between model checkpoints — TODO confirm
# against the checkpoint condition at the end of the training cell.
CHECKPOINT_INTERVAL = 10
# Passed as `num_res_blocks` to both generators.
NUM_RESIDUAL_BLOCKS = 9
# Learning rate shared by the three Adam optimizers.
LR = 0.0002
# Worker processes used by the training DataLoader.
NUM_WORKERS = 8
# Dataset root; ImageDataset globs "<root>/<mode>/A" and "<root>/<mode>/B".
DATA_PATH = "/home/antixk/Anand/NeuralBlocks/data_utils/datasets/monet2photo/"
# Forwarded as `norm=` to the generator and discriminator constructors.
# NOTE(review): the ResNetGenerator / Discriminator classes defined in this
# notebook accept `norm` but never read it (they always use InstanceNorm2d).
NORM = 'BN'

In [3]:
class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Each item is a dict ``{'A': ..., 'B': ...}`` holding one transformed image
    from each domain. File lists are sorted, and indices wrap around with a
    modulo so the shorter domain is reused when the two folders differ in size.

    Args:
        root: Directory containing the ``<mode>/A`` and ``<mode>/B`` folders.
        transforms_: Sequence of transforms applied to every image. ``None``
            (the default) is treated as "no transforms".
        unaligned: When true, the B image is drawn at random instead of being
            picked by ``index``.
        mode: Name of the sub-folder under ``root`` (``'train'`` by default).
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        # The default None cannot be iterated by Compose, so fall back to an
        # empty pipeline instead of failing on the first __getitem__.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as an RGB image, closing the file afterwards."""
        # Grayscale / RGBA / palette files would otherwise yield tensors whose
        # channel count does not match a 3-channel Normalize or the rest of the
        # batch. convert() loads the pixel data, so the file can be closed.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``{'A': image_A, 'B': image_B}`` for ``index`` (wrapping per domain)."""
        item_A = self.transform(self._load_rgb(self.files_A[index % len(self.files_A)]))

        if self.unaligned:
            # Unpaired training: any B image may accompany the A image.
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        item_B = self.transform(self._load_rgb(path_B))

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        """Length of the larger of the two domains."""
        return max(len(self.files_A), len(self.files_B))

In [4]:
# Preprocessing pipeline shared by the train and test loaders: resize to 112%
# of the crop height (bicubic), take a random IMG_HEIGHT x IMG_WIDTH crop,
# flip horizontally at random, convert to a tensor and normalise every channel
# with mean 0.5 / std 0.5.
transforms_ = [
    transforms.Resize(int(IMG_HEIGHT * 1.12), Image.BICUBIC),
    transforms.RandomCrop((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training split: unpaired sampling, shuffled, BATCH_SIZE images per step.
train_dataset = ImageDataset(DATA_PATH, transforms_=transforms_, unaligned=True)
dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Test split: only used to render sample grids, five images per domain at a time.
test_dataset = ImageDataset(DATA_PATH, transforms_=transforms_, unaligned=True, mode='test')
val_dataloader = DataLoader(
    test_dataset,
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


In [5]:
class ResidualBlock(nn.Module):
    """Residual unit: ``x + f(x)`` where ``f`` is two padded 3x3 conv stages.

    Each stage is ReflectionPad2d(1) -> Conv2d(3x3) -> InstanceNorm2d, with a
    ReLU between the two stages. Channel count and spatial size are preserved.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # Pad by 1 then convolve with a 3x3 kernel: height/width unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # Stages are built in order so parameter initialisation order is unchanged.
        stage_one = conv_stage()
        stage_two = conv_stage()
        self.conv_block = nn.Sequential(*stage_one, nn.ReLU(inplace=True), *stage_two)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        branch = self.conv_block(x)
        return x + branch

class ResNetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem (64 channels) -> two stride-2 convolutions (128, 256
    channels) -> ``num_res_blocks`` residual blocks -> two stride-2 transposed
    convolutions (128, 64 channels) -> 7x7 output convolution followed by Tanh.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        num_res_blocks: Number of ResidualBlock modules in the bottleneck.
        norm: Accepted but not read by this implementation; InstanceNorm2d is
            always used.
    """

    def __init__(self, in_channels=3, out_channels=3, num_res_blocks=9, norm='BN'):
        super().__init__()

        width = 64

        # Stem: reflection padding keeps the spatial size under the 7x7 kernel.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: halve the resolution and double the channels, twice.
        for _ in range(2):
            doubled = width * 2
            layers.extend([
                nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            width = doubled

        # Bottleneck of residual blocks at the widest channel count.
        for _ in range(num_res_blocks):
            layers.append(ResidualBlock(width))

        # Decoder: double the resolution and halve the channels, twice.
        for _ in range(2):
            halved = width // 2
            layers.extend([
                nn.ConvTranspose2d(width, halved, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            width = halved

        # Head: width is back at 64 here; Tanh bounds the output to [-1, 1].
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, out_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full generator stack."""
        return self.model(x)

In [6]:

class Discriminator(nn.Module):
    """Convolutional patch discriminator producing a 1-channel score map.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels), each followed
    by LeakyReLU(0.2); all but the first also apply InstanceNorm2d. A final
    zero-pad plus 4x4 convolution maps the features to one output channel.

    Args:
        in_channels: Channels of the input image.
        norm: Accepted but not read by this implementation; InstanceNorm2d is
            always used.
    """

    def __init__(self, in_channels=3, norm='BN'):
        super().__init__()

        layers = []
        previous = in_channels
        for stage, width in enumerate((64, 128, 256, 512)):
            # Each stage halves the spatial resolution.
            layers.append(nn.Conv2d(previous, width, 4, stride=2, padding=1))
            if stage != 0:
                # The first stage is left un-normalised.
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            previous = width

        # Pad left/top by one pixel so the last 4x4 conv keeps the map size.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the per-patch score map for ``img``."""
        return self.model(img)

In [5]:
# One generator per translation direction (A->B and B->A) and one
# discriminator per image domain; every network is moved to the GPU.
G_AB = ResNetGenerator(
    in_channels=3,
    out_channels=3,
    num_res_blocks=NUM_RESIDUAL_BLOCKS,
    norm=NORM,
).cuda()
G_BA = ResNetGenerator(
    in_channels=3,
    out_channels=3,
    num_res_blocks=NUM_RESIDUAL_BLOCKS,
    norm=NORM,
).cuda()

D_A = Discriminator(in_channels=3, norm=NORM).cuda()
D_B = Discriminator(in_channels=3, norm=NORM).cuda()

In [6]:
# Adam optimizers (lr=LR, betas=(0.5, 0.999)). Both generators are updated by
# a single optimizer; each discriminator has its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=LR,
    betas=(0.5, 0.999),
)
optimizer_D_A = torch.optim.Adam(
    D_A.parameters(),
    lr=LR,
    betas=(0.5, 0.999),
)
optimizer_D_B = torch.optim.Adam(
    D_B.parameters(),
    lr=LR,
    betas=(0.5, 0.999),
)

In [7]:
# Loss functions: mean-squared error for the adversarial term, L1 for the
# cycle-consistency and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Weights applied to the cycle and identity terms in the generator loss;
# the identity weight is half the cycle weight.
lambda_cyc = 10
lambda_id = lambda_cyc * 0.5

In [8]:
# Make sure the output folders for sample grids and model checkpoints exist
# (no error if they are already there).
for output_dir in ('images/', 'saved_models/'):
    os.makedirs(output_dir, exist_ok=True)

In [8]:
class ReplayBuffer():
    """Bounded pool of previously seen samples with random replacement.

    While the pool holds fewer than ``max_size`` entries, every incoming
    sample is stored and returned unchanged. Once it is full, each incoming
    sample is, with probability one half, swapped with a randomly chosen
    stored sample (the stored one is returned); otherwise it is returned
    directly and the pool is left untouched.

    Args:
        max_size: Maximum number of stored samples; must be positive.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same length.

        Args:
            data: Tensor whose first dimension indexes the samples.

        Returns:
            The selected samples concatenated along the first dimension.
        """
        batch_out = []
        for sample in data.data:
            # Restore the leading batch dimension dropped by iteration.
            sample = torch.unsqueeze(sample, 0)

            if len(self.data) < self.max_size:
                # Pool not full yet: store the sample and pass it through.
                self.data.append(sample)
                batch_out.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Hand back a stored sample and put the new one in its slot.
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)

        return Variable(torch.cat(batch_out))

In [10]:
def sample_images(batches_done):
    """Save a grid of real and translated test images to ``images/<batches_done>.png``.

    One batch is drawn from ``val_dataloader``; the grid stacks, in order,
    real A, G_AB(real A), real B and G_BA(real B), five images per row.

    Args:
        batches_done: Value formatted into the output file name.
    """
    imgs = next(iter(val_dataloader))

    # Follow the generator's device rather than a module-level `Tensor` alias,
    # so this function does not depend on a cell that is defined later.
    device = next(G_AB.parameters()).device
    real_A = imgs['A'].to(device=device, dtype=torch.float)
    real_B = imgs['B'].to(device=device, dtype=torch.float)

    # Sampling needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

    img_sample = torch.cat((real_A, fake_B, real_B, fake_A), 0)
    save_image(img_sample, 'images/%s.png' % (batches_done), nrow=5, normalize=True)

In [9]:
# ---------------------------------------------------------------------------
# Training loop. Per batch: one generator update (adversarial + cycle +
# identity losses), then one update of each discriminator using the replay
# buffers. Progress is printed every batch, a sample grid is saved every
# SAMPLE_INTERVAL batches and checkpoints every CHECKPOINT_INTERVAL epochs.
# ---------------------------------------------------------------------------

# Tensor type used to move batches to the GPU as float32.
Tensor = torch.cuda.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Shape of the discriminator (PatchGAN) output for one image: four stride-2
# stages reduce each spatial dimension by a factor of 2**4.
patch = (1, IMG_HEIGHT // 2**4, IMG_WIDTH // 2**4)

prev_time = time.time()
for epoch in range(NUM_EPOCH):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch['A'].type(Tensor))
        real_B = Variable(batch['B'].type(Tensor))

        # Adversarial ground truths, sized to the actual batch
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image of its target domain
        # should return it unchanged.
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: generated images should be scored as valid
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating there and back should recover the input
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G =    loss_GAN + \
                    lambda_cyc * loss_cycle + \
                    lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Approximate time left: remaining batches times the duration of the
        # batch that just finished.
        batches_done = epoch * len(dataloader) + i
        batches_left = NUM_EPOCH * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log ("%.3f" for the identity term; the previous "%f.3" printed
        # the full float followed by a literal ".3").
        print("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %.3f] [G loss: %.3f, adv: %.3f, cycle: %.3f, identity: %.3f] ETA: %s" %
                                                        (epoch, NUM_EPOCH,
                                                        i, len(dataloader),
                                                        loss_D.item(), loss_G.item(),
                                                        loss_GAN.item(), loss_cycle.item(),
                                                        loss_identity.item(), time_left))

        # If at sample interval save image
        if batches_done % SAMPLE_INTERVAL == 0:
            sample_images(batches_done)

    # NOTE: no learning-rate schedulers are defined in this notebook, so the
    # learning rate stays at LR for the whole run.

    # Save model checkpoints every CHECKPOINT_INTERVAL epochs (-1 disables).
    # This used to read `opt.checkpoint_interval`, but no `opt` object exists
    # in this notebook, so the first epoch ended with a NameError.
    if CHECKPOINT_INTERVAL != -1 and epoch % CHECKPOINT_INTERVAL == 0:
        torch.save(G_AB.state_dict(), 'saved_models/G_AB_%d.pth' % (epoch))
        torch.save(G_BA.state_dict(), 'saved_models/G_BA_%d.pth' % (epoch))
        torch.save(D_A.state_dict(), 'saved_models/D_A_%d.pth' % ( epoch))
        torch.save(D_B.state_dict(), 'saved_models/D_B_%d.pth' % ( epoch))


torch.Size([1, 3, 256, 256])
[Epoch 0/200] [Batch 0/6287] [D loss: 0.560] [G loss: 7.478, adv: 1.011, cycle: 0.430, identity: 0.433184.3] ETA: 15 days, 11:47:50.340252


NameError: name 'sample_images' is not defined