In [1]:
import numpy as np
import os, math, sys
import glob
import random

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

In [2]:
# --- Training schedule ---------------------------------------------------
n_epochs = 200            # full passes over the training data
batch_size = 64           # samples per mini-batch

# --- Data ----------------------------------------------------------------
dataset_name = "val_256/"  # directory that is globbed for *.jpg files
n_cpu = 8                 # worker processes used by the DataLoader
img_size = 128            # images are resized to img_size x img_size
mask_size = 16            # side length of the square region that is masked out
channels = 3              # colour channels per image

# --- Adam optimiser ------------------------------------------------------
lr = 2e-4                 # learning rate
b1 = 0.5                  # decay rate of the first-moment (mean) estimate
b2 = 0.999                # decay rate of the second-moment (variance) estimate

# --- Discriminator output (PatchGAN) -------------------------------------
# The discriminator applies three stride-2 convolutions to a
# mask_size x mask_size patch, so each spatial side shrinks by 2 ** 3.
patch_h = mask_size // 2 ** 3
patch_w = mask_size // 2 ** 3
patch = (1, patch_h, patch_w)

### Define Dataset Class

In [3]:
class ImageDataset(Dataset):
    """Dataset of JPEG images paired with a randomly masked copy.

    Each item is a tuple ``(img, masked_img, masked_part)``:

    * ``img``         - the transformed image as a float tensor scaled to [-1, 1]
    * ``masked_img``  - a copy of ``img`` with one square region set to 1
    * ``masked_part`` - the original pixels of that square region

    Args:
        root: Directory that is searched (non-recursively) for ``*.jpg`` files.
        transforms_: Optional list of PIL-level torchvision transforms applied
            before tensor conversion. ``None`` means "no extra transforms".
        img_size: Side length of the (square) image after ``transforms_``.
            The transforms are expected to produce this size; it is only used
            to choose the mask position.
        mask_size: Side length of the square mask. Must not exceed ``img_size``.
        mode: Stored on the instance; not otherwise used by this class.

    Raises:
        ValueError: If ``mask_size`` is larger than ``img_size``.
    """

    def __init__(self, root, transforms_=None, img_size=128, mask_size=16, mode="train"):
        if mask_size > img_size:
            raise ValueError(
                "mask_size (%d) must not exceed img_size (%d)" % (mask_size, img_size)
            )
        # Compose(None) cannot be called, so fall back to an empty pipeline.
        self.transform = transforms.Compose(list(transforms_) if transforms_ is not None else [])
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode
        self.files = sorted(glob.glob("%s/*.jpg" % root))
        # Disabled hold-out split kept for reference:
        # self.files = self.files[:-4000] if mode == "train" else self.files[-4000:]

    def apply_random_mask(self, img):
        """Mask one random ``mask_size`` x ``mask_size`` square of ``img``.

        Args:
            img: Tensor indexed as ``(channel, row, column)``.

        Returns:
            ``(masked_img, masked_part)`` where ``masked_img`` is a clone of
            ``img`` with the square filled with 1 and ``masked_part`` is the
            corresponding slice of the untouched ``img``.
        """
        # torch.randint excludes the upper bound, hence the +1: the mask may
        # sit flush against the bottom/right edge, and mask_size == img_size
        # is valid. The torch RNG is used (not NumPy's global RNG) because
        # PyTorch seeds it separately in every DataLoader worker; older
        # PyTorch releases do not reseed NumPy, which repeats mask positions
        # across workers.
        max_offset = self.img_size - self.mask_size
        y1, x1 = torch.randint(0, max_offset + 1, (2,)).tolist()
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2]
        masked_img = img.clone()
        masked_img[:, y1:y2, x1:x2] = 1

        return masked_img, masked_part

    def __getitem__(self, index):
        """Load, transform and mask the image at ``index`` (wraps around)."""
        path = self.files[index % len(self.files)]
        # The context manager closes the file handle; convert("RGB") forces
        # exactly three channels for grayscale, palette, RGBA and CMYK inputs
        # so the 3-channel Normalize below always applies.
        with Image.open(path) as handle:
            img = handle.convert("RGB")

        img = self.transform(img)
        img = self.to_tensor(img)

        masked_img, aux = self.apply_random_mask(img)

        return img, masked_img, aux

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

In [4]:
# Strongly colour-jittered copy of the dataset (hue/saturation up to 0.5)
# with random horizontal and vertical flips. Sizes come from the config cell
# so the resize and the mask geometry stay in sync with img_size / mask_size.
transform = [
    transforms.Resize((img_size, img_size)),
    transforms.ColorJitter(hue=.50, saturation=.50),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
]
os.makedirs("aug_data", exist_ok=True)
augmentation_data = ImageDataset(
    dataset_name,
    transforms_=transform,
    img_size=img_size,
    mask_size=mask_size,
)

In [5]:
# Mildly colour-jittered copy of the dataset (hue/saturation up to 0.2)
# with random horizontal and vertical flips. Sizes come from the config cell
# so the resize and the mask geometry stay in sync with img_size / mask_size.
transform = [
    transforms.Resize((img_size, img_size)),
    transforms.ColorJitter(hue=.20, saturation=.20),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
]
os.makedirs("aug_data", exist_ok=True)
augmentation_data_1 = ImageDataset(
    dataset_name,
    transforms_=transform,
    img_size=img_size,
    mask_size=mask_size,
)

### Get Train/Test Dataloaders

In [6]:
# PIL-level transforms for the un-augmented copy of the dataset. Tensor
# conversion and normalisation are applied inside ImageDataset.__getitem__,
# so they are intentionally not listed here.
transforms_ = [
    transforms.Resize((img_size, img_size), Image.BICUBIC),
]

# Forward the configured sizes explicitly; relying on the class defaults
# would silently desynchronise the mask from `patch` if the config changes.
base_data = ImageDataset(
    dataset_name,
    transforms_=transforms_,
    img_size=img_size,
    mask_size=mask_size,
)

# Train on the plain images plus both colour-jittered variants.
dataloader = DataLoader(
    ConcatDataset([base_data, augmentation_data_1, augmentation_data]),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

<h3><center>Model Architecture</center></h3>
<img src="https://miro.medium.com/max/700/1*fJpamgw0yBZZRNEuex07hw.png" width="1000" height="1000"/>
<h4></h4>
<h4><center>Image Source:  <a href="https://arxiv.org/abs/1604.07379">Context Encoders: Feature Learning by Inpainting [Deepak Pathak et al.]</a></center></h4>

In [7]:
class Generator(nn.Module):
    """Encoder-decoder that predicts the masked-out patch of an image.

    The network applies five stride-2 convolutions, a 1x1 bottleneck with
    4000 channels, four stride-2 transposed convolutions and two further
    stride-2 convolutions, so the spatial size shrinks by a net factor of 8.
    A 128x128 input therefore yields a 16x16 output in the Tanh range [-1, 1].

    Args:
        channels: Number of channels of both the input image and the
            generated patch.
    """

    def __init__(self, channels=3):
        super(Generator, self).__init__()

        # NOTE(review): the second positional argument of nn.BatchNorm2d is
        # `eps`, not `momentum`. The original code passed 0.8 positionally,
        # which was probably meant as a momentum value. It is spelled out as
        # `eps=0.8` here so behaviour (and compatibility with existing
        # checkpoints) is unchanged - confirm whether momentum was intended.

        def encoder(in_feat, out_feat, normalize=True):
            """Stride-2 conv (halves H and W) + optional BatchNorm + LeakyReLU."""
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def decoder(in_feat, out_feat, normalize=True):
            """Stride-2 transposed conv (doubles H and W) + optional BatchNorm + ReLU."""
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, eps=0.8))
            layers.append(nn.ReLU())
            return layers

        # Layer order is kept exactly as before so state_dict keys still match
        # previously saved checkpoints.
        self.model = nn.Sequential(
            *encoder(channels, 64, normalize=False),
            *encoder(64, 64),
            *encoder(64, 128),
            *encoder(128, 256),
            *encoder(256, 512),
            nn.Conv2d(512, 4000, 1),
            *decoder(4000, 512),
            *decoder(512, 256),
            *decoder(256, 128),
            *decoder(128, 64),
            nn.Conv2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32, eps=0.8),
            nn.ReLU(),
            nn.Conv2d(32, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16, eps=0.8),
            # Emit `channels` feature maps (previously hard-coded to 3) so the
            # generated patch matches the input for non-RGB data as well.
            nn.Conv2d(16, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the generated patch for the masked image batch ``x``."""
        return self.model(x)
    

class Discriminator(nn.Module):
    """Convolutional critic that scores an image patch.

    Three stride-2 3x3 convolutions followed by two stride-1 3x3 convolutions
    reduce each spatial side by a factor of 8 and end in a single-channel
    score map.

    Args:
        channels: Number of channels of the patches being scored.
    """

    def __init__(self, channels=3):
        super().__init__()

        # (in_channels, out_channels, stride, apply InstanceNorm?)
        stage_specs = (
            (channels, 16, 2, False),
            (16, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )

        stack = []
        for c_in, c_out, step, use_norm in stage_specs:
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1))
            if use_norm:
                stack.append(nn.InstanceNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection to one score per spatial location.
        stack.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return the score map for the patch batch ``img``."""
        scores = self.model(img)
        return scores

In [8]:
# Number of mini-batches the DataLoader yields per epoch.
len(dataloader)

1711

### Train Context-Encoder GAN

In [9]:
def weights_init_normal(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    * Modules whose class name contains "Conv": weights drawn from N(0, 0.02).
    * Modules whose class name contains "BatchNorm2d": weights drawn from
      N(1, 0.02) and biases set to 0.
    * Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
    
# Loss functions: MSE on the discriminator scores and L1 between the
# generated and the true patch.
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator and discriminator
generator = Generator(channels=channels)
discriminator = Discriminator(channels=channels)

# Load pretrained models
#generator.load_state_dict(torch.load("generator.pth"))
#discriminator.load_state_dict(torch.load("discriminator.pth"))
# print("Using pre-trained Context-Encoder GAN model!")

# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell no longer fails on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for module in (generator, discriminator, adversarial_loss, pixelwise_loss):
    module.to(device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Tensor type used by the training loop to move batches onto `device`.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

In [10]:
!nvidia-smi

Tue Jul 13 07:48:06 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 470.42.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:04:00.0  On |                  N/A |
| 31%   54C    P2    31W / 200W |   4816MiB /  6073MiB |     19%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
from torch.utils.tensorboard import SummaryWriter

# Directory that receives the generator/discriminator checkpoints.
os.makedirs("saved_models", exist_ok=True)
# TensorBoard event files for this run are written below runs/.
writer = SummaryWriter(log_dir='runs/context_encoder')


In [None]:

# Adversarial training of the context encoder.
# Per batch: one generator step (0.001 * adversarial + 0.999 * pixel-wise L1),
# then one discriminator step on the real and the detached generated patches.
# try/finally makes sure the TensorBoard writer is flushed and closed even
# when training is interrupted or a batch raises.
try:
    for epoch in range(n_epochs):

        ### Training ###
        # Running sums over the epoch; divided by the batch count for display.
        gen_adv_loss, gen_pixel_loss, disc_loss, gen_loss = 0, 0, 0, 0
        steps_per_epoch = len(dataloader)
        tqdm_bar = tqdm(dataloader, desc=f'Training Epoch {epoch} ', total=steps_per_epoch)
        # `imgs` (the full unmasked image) is not needed for the losses, so it
        # is left where the DataLoader put it instead of being copied over.
        for i, (imgs, masked_imgs, masked_parts) in enumerate(tqdm_bar):

            # Configure input (torch.autograd.Variable is a deprecated no-op
            # wrapper, so plain tensors are used).
            masked_imgs = masked_imgs.type(Tensor)
            masked_parts = masked_parts.type(Tensor)

            # Adversarial ground truths, one value per discriminator output cell
            valid = torch.ones(
                masked_imgs.shape[0], *patch,
                dtype=masked_imgs.dtype, device=masked_imgs.device,
            )
            fake = torch.zeros_like(valid)

            ## Train Generator ##
            optimizer_G.zero_grad()

            # Generate a batch of patches
            gen_parts = generator(masked_imgs)

            # Adversarial and pixelwise loss
            g_adv = adversarial_loss(discriminator(gen_parts), valid)
            g_pixel = pixelwise_loss(gen_parts, masked_parts)
            # Total loss
            g_loss = 0.001 * g_adv + 0.999 * g_pixel

            g_loss.backward()
            optimizer_G.step()

            ## Train Discriminator ##
            # zero_grad also discards the gradients the generator step left on
            # the discriminator parameters.
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(masked_parts), valid)
            # detach(): the discriminator update must not reach the generator
            fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

            d_loss.backward()
            optimizer_D.step()

            gen_adv_loss += g_adv.item()
            gen_pixel_loss += g_pixel.item()
            disc_loss += d_loss.item()
            gen_loss += g_loss.item()
            batches_seen = i + 1
            tqdm_bar.set_postfix(
                gen_adv_loss=gen_adv_loss / batches_seen,
                gen_pixel_loss=gen_pixel_loss / batches_seen,
                disc_loss=disc_loss / batches_seen,
            )

            # Log the running epoch means every 50 batches.
            if i % 50 == 49:
                global_step = epoch * steps_per_epoch + i
                writer.add_scalar('generator loss', gen_loss / batches_seen, global_step)
                # The tag spelling is kept as-is so new points join the
                # existing TensorBoard series.
                writer.add_scalar('descremenator loss', disc_loss / batches_seen, global_step)

        # Overwrite the checkpoints at the end of every epoch.
        torch.save(generator.state_dict(), "saved_models/generator.pth")
        torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")
finally:
    writer.close()

Training Epoch 0 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 1 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 2 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 3 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 4 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 5 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 6 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 7 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 8 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 9 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 10 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 11 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 12 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 13 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 14 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 15 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 16 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 17 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 18 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 19 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 20 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 21 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 22 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 23 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 24 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 25 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 26 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 27 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 28 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 29 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 30 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 31 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 32 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 33 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 34 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 35 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 36 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 37 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 38 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 39 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 40 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 41 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 42 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 43 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 44 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 45 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 46 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 47 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 48 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 49 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 50 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 51 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 52 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 53 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 54 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 55 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 56 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 57 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 58 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 59 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 60 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 61 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 62 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 63 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 64 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 65 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 66 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 67 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 68 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 69 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 70 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 71 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 72 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 73 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 74 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 75 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 76 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 77 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 78 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 79 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 80 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 81 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 82 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 83 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 84 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 85 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 86 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 87 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 88 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 89 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 90 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 91 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 92 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 93 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 94 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 95 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 96 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 97 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 98 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 99 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 100 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 101 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 102 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 103 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 104 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 105 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 106 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 107 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 108 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 109 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 110 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 111 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 112 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 113 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 114 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 115 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 116 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 117 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 118 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 119 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 120 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 121 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 122 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 123 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 124 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 125 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 126 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 127 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 128 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 129 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 130 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 131 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 132 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 133 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 134 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 135 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 136 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 137 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 138 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 139 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 140 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 141 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 142 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 143 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 144 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 145 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 146 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 147 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 148 :   0%|          | 0/1711 [00:00<?, ?it/s]

Training Epoch 149 :   0%|          | 0/1711 [00:00<?, ?it/s]