<a href="https://colab.research.google.com/github/Eereenah/deep-learning/blob/master/Deoldify_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DS-330 Final Project: Deoldify

Converting a color image to grayscale is simple. The inverse operation is far from trivial, so automatic image colorization has long been a popular research area in deep learning. Because the black-and-white domain carries relatively little information, adding color can give an image new semantic meaning.

This notebook features a modified implementation of the pix2pix network, adapted to the colorization task.
![example](https://i.imgur.com/4oB9X0c.jpg)

## Imports:

In [0]:
import scipy
import numpy as np
import cv2
import os
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
# from viz import updatable_display2
import seaborn as sns

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch import nn, optim
from torchvision import transforms, datasets
import torchvision.utils as vutils
from skimage import color

# use_cuda = True
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Dataset

Because of limited computational resources, we use a subset of the COCO dataset: the `test2017` split (about 40K images) for training and the `val2017` split (5K images) for testing.

In [0]:
! wget images.cocodataset.org/zips/val2017.zip

In [0]:
! wget images.cocodataset.org/zips/test2017.zip

In [0]:
! unzip -qq *.zip

In [0]:
! rm *.zip

Images for training:

In [0]:
! find 'test2017/' -maxdepth 1 -type f -printf "." | wc -c

Images for testing:

In [0]:
! find 'val2017/' -maxdepth 1 -type f -printf "." | wc -c

In [0]:
from google.colab import drive
drive.mount('/content/drive')

## Pre-Processing

Helper functions for reading and resizing the images:



In [0]:
import scipy
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

def getPaths(root):
    for file in os.listdir(root):
        if 'jpg' in file:
            yield(os.path.join(root,file))

def readimg(l):
    im = cv2.imread(l)
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

def resize(img, size):
    img = np.array(img, dtype=np.uint8)
    if len(img.shape) == 2:
      img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return cv2.resize(img, (size, size)) 

def black_and_white(img):
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
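
The `black_and_white` helper discards information, which is why the inverse mapping is ambiguous. As a minimal sketch (assuming the standard luminance weights 0.299, 0.587, 0.114 that RGB-to-gray conversion uses; the two colors below are hypothetical values picked for illustration), two visibly different colors can collapse to the same gray level:

```python
def luminance(rgb):
    """Weighted sum used by the standard RGB-to-gray conversion."""
    r, g, b = rgb
    return 0.299 * r + 0.587 * g + 0.114 * b

neutral_gray = (100, 100, 100)
reddish = (200, 60, 44)   # hypothetical color chosen to collide with neutral_gray

gray_a = round(luminance(neutral_gray))
gray_b = round(luminance(reddish))
print(gray_a, gray_b)   # both colors map to the same 8-bit gray level
```

A colorization model therefore has to rely on context (edges, textures, object identity) rather than on the gray value alone.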

In [0]:
root = 'test2017/'

Creating a custom dataset class that yields (color, black-and-white) image pairs:

In [0]:
class Load_Dataset(Dataset):
    """Dataset yielding (normalized color image, grayscale image) tensor pairs."""
    def __init__(self, root_dir, size, transform = None):

        self.paths = list(getPaths(root_dir))
        self.size = size
        self.img_transform = transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
               ])
        self.bw_transform = transforms.Compose([
                   transforms.ToTensor(),
               ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = resize(readimg(self.paths[idx]), self.size)
        return self.img_transform(img), self.bw_transform(black_and_white(img)[:,:,None])

In [0]:
dataset = Load_Dataset(root,128)
dataset[0][0].shape,dataset[0][1].shape

In [0]:
dataset[0][0].min(), dataset[0][0].max()
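
The color target is normalized with mean 0.5 and standard deviation 0.5 per channel, which maps pixel values to $[-1, 1]$, the same range as the generator's final `Tanh`. The NumPy sketch below mimics that mapping and its inverse on a few pixel values; it is an illustration only (the helper names `normalize` and `denormalize` are introduced here, and the channel reordering done by `ToTensor` is ignored):

```python
import numpy as np

def normalize(img_uint8):
    """Mimic ToTensor() followed by Normalize(mean=0.5, std=0.5) on a uint8 array."""
    scaled = img_uint8.astype(np.float32) / 255.0   # ToTensor scales uint8 data to [0, 1]
    return (scaled - 0.5) / 0.5                      # Normalize then maps [0, 1] to [-1, 1]

def denormalize(x):
    """Invert normalize(): map [-1, 1] back to uint8 values in [0, 255]."""
    return np.clip(np.rint((x + 1.0) / 2.0 * 255.0), 0, 255).astype(np.uint8)

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
normalized = normalize(pixels)
restored = denormalize(normalized)
print(normalized.min(), normalized.max(), restored)
```

The grayscale input only goes through `ToTensor`, so it stays in $[0, 1]$.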

## Model

The pix2pix model uses

* a U-Net model as the generator $G$, which takes an image $y$ as input and produces another image $x$
* a convolutional discriminator model $D$, which takes both $x$ and $y$ as input and tries to guess whether $x \sim p_x$ (real) or $x \sim p_G$ (generated)

The model is trained with the GAN loss plus an additional $L_1$ loss on the output of the generator $G$. The two losses are balanced by a hyperparameter $\lambda$ (set to 100 in the original paper).
Additionally, for $L_{GAN}$, we use a gradient penalty to enforce the Lipschitz constraint (WGAN-GP). The final objective function is defined as:

$$ G^* = \arg\min_G\max_D \mathcal{L}_{cGAN}(G,D)_{WP} + \lambda \mathcal{L}_{L1}(G) $$

where $\mathcal{L}_{cGAN}(G,D)_{WP}$ is the sum of the conditional GAN loss and the WGAN gradient penalty.
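
As a hedged expansion of the two terms, written to match the loss functions implemented later in this notebook: $x$ is the color image, $y$ the grayscale input, $\hat{x}$ a random interpolation between a real and a generated sample, and $\lambda_{GP}$ the penalty weight (a symbol introduced here for illustration; the code uses the value 10).

$$ \mathcal{L}_{cGAN}(G,D)_{WP} = \mathbb{E}_{x,y}\left[D(x,y)\right] - \mathbb{E}_{y}\left[D(G(y),y)\right] - \lambda_{GP}\,\mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x},y) \rVert_2 - 1\right)^2\right] $$

$$ \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert x - G(y) \rVert_1\right] $$

$D$ maximizes $\mathcal{L}_{cGAN}(G,D)_{WP}$, which is the same as minimizing $\mathbb{E}[D(G(y),y)] - \mathbb{E}[D(x,y)]$ plus the penalty. $G$ only influences the middle term, so it minimizes $-\mathbb{E}[D(G(y),y)] + \lambda \mathcal{L}_{L1}(G)$. The code averages the absolute error over pixels, so its $L_1$ term matches the expression above up to a constant factor.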


For the network architecture, we adapt that of pix2pix. In particular, the generator borrows the U-Net architecture with skip connections, which let the input and output sides share low-level information such as the location of edges. This information is crucial for the colorization task. We use snippets from the PyTorch implementation of the generator/discriminator in our project:

In [0]:
class gated_resnet(nn.Module):
    """
    Gated Residual Block
    """
    def __init__(self, num_filters, kernel_size, padding, nonlinearity=nn.ReLU, dropout=0.2, dilation=1,batchNormObject=nn.BatchNorm2d):
        super(gated_resnet, self).__init__()
        self.gated = True
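        # the gated variant needs twice the channels: one half is the signal, the other half the gate
        num_hidden_filters = 2 * num_filters if self.gated else num_filters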
        self.conv_input = nn.Conv2d(num_filters, num_hidden_filters, kernel_size=kernel_size,stride=1,padding=padding,dilation=dilation )
        self.dropout = nn.Dropout2d(dropout)
        self.nonlinearity = nonlinearity()
        self.batch_norm1 = batchNormObject(num_hidden_filters)
        self.conv_out = nn.Conv2d(num_hidden_filters, num_hidden_filters, kernel_size=kernel_size,stride=1,padding=padding,dilation=dilation )
        self.batch_norm2 = batchNormObject(num_filters)

    def forward(self, og_x):
        x = self.conv_input(og_x)
        x = self.batch_norm1(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.conv_out(x)
        if self.gated:
            a, b = torch.chunk(x, 2, dim=1)
            c3 = a * torch.sigmoid(b)  # F.sigmoid is deprecated in favor of torch.sigmoid
        else:
            c3 = x
        out = og_x + c3
        out = self.batch_norm2(out)
        return out
    
class ResidualBlock(nn.Module):
    """
    Residual Block
    """
    def __init__(self, num_filters, kernel_size, padding, nonlinearity=nn.ReLU, dropout=0.2, dilation=1,batchNormObject=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        num_hidden_filters = num_filters
        self.conv1 = nn.Conv2d(num_filters, num_hidden_filters, kernel_size=kernel_size,stride=1,padding=padding,dilation=dilation )
        self.dropout = nn.Dropout2d(dropout)
        self.nonlinearity = nonlinearity(inplace=False)
        self.batch_norm1 = batchNormObject(num_hidden_filters)
        self.conv2 = nn.Conv2d(num_hidden_filters, num_hidden_filters, kernel_size=kernel_size,stride=1,padding=padding,dilation=dilation )
        self.batch_norm2 = batchNormObject(num_filters)

    def forward(self, og_x):
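        # apply dropout to the block input, then the first convolution
        # (previously conv1 was fed og_x, which discarded the dropout output)
        x = self.dropout(og_x)
        x = self.conv1(x)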
        x = self.batch_norm1(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)
        out = og_x + x
        out = self.batch_norm2(out)
        out = self.nonlinearity(out)
        return out
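
To illustrate the gating used in `gated_resnet`, here is a minimal NumPy sketch of the channel split followed by `a * sigmoid(b)`. The helper names and the input values are hypothetical; the split mirrors `torch.chunk(x, 2, dim=1)`.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_activation(x):
    """Split channels in half and gate the first half with the sigmoid of the second."""
    a, b = np.split(x, 2, axis=1)      # same role as torch.chunk(x, 2, dim=1)
    return a * sigmoid(b)

# batch of 1, 4 channels, 2x2 spatial grid -> 2 output channels
x = np.zeros((1, 4, 2, 2))
x[:, :2] = 3.0        # signal channels
x[:, 2] = 10.0        # strongly positive gate: signal channel 0 is mostly kept
x[:, 3] = -10.0       # strongly negative gate: signal channel 1 is mostly suppressed
out = gated_activation(x)
print(out.shape, out[0, 0, 0, 0], out[0, 1, 0, 0])
```

This is why the gated block doubles the number of hidden filters: the convolution has to produce both the signal and its gate.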

In [0]:
class ConvolutionalEncoder(nn.Module):
    """
    Convolutional Encoder providing skip connections
    """
    def __init__(self,n_features_input,num_hidden_features,kernel_size,padding,n_resblocks,dropout_min=0,dropout_max=0.2, blockObject=ResidualBlock,batchNormObject=nn.BatchNorm2d):
        """
        n_features_input (int): number of input features
        num_hidden_features (list(int)): number of features for each stage
        kernel_size (int): convolution kernel size
        padding (int): convolution padding
        n_resblocks (int): number of residual blocks at each stage
        dropout_min, dropout_max (float): dropout probability, interpolated linearly from the first to the last stage
        blockObject (nn.Module): Residual block to use. Default is ResidualBlock
        batchNormObject (nn.Module): normalization layer. Default is nn.BatchNorm2d
        """
        super(ConvolutionalEncoder,self).__init__()
        self.n_features_input = n_features_input
        self.num_hidden_features = num_hidden_features
        self.stages = nn.ModuleList()
        # one dropout probability per stage, interpolated linearly from dropout_min to dropout_max
        dropout = iter([(1-t)*dropout_min + t*dropout_max for t in np.linspace(0, 1, len(num_hidden_features))])
        # input convolution block
        block = [nn.Conv2d(n_features_input, num_hidden_features[0], kernel_size=kernel_size,stride=1, padding=padding)]
        # draw the probability once per stage so that n_resblocks > 1 does not exhaust the schedule
        p = next(dropout)
        for _ in range(n_resblocks):
            block += [blockObject(num_hidden_features[0], kernel_size, padding, dropout=p,batchNormObject=batchNormObject)]
        self.stages.append(nn.Sequential(*block))
        # layers
        for features_in,features_out in [num_hidden_features[i:i+2] for i in range(0,len(num_hidden_features), 1)][:-1]:
            # downsampling
            block = [nn.MaxPool2d(2),nn.Conv2d(features_in, features_out, kernel_size=1,padding=0 ),batchNormObject(features_out),nn.ReLU()]
            #block = [nn.Conv2d(features_in, features_out, kernel_size=kernel_size,stride=2,padding=padding ),nn.BatchNorm2d(features_out),nn.ReLU()]
            # residual blocks
            p = next(iter(dropout))
            for _ in range(n_resblocks):
                block += [blockObject(features_out, kernel_size, padding, dropout=p,batchNormObject=batchNormObject)]
            self.stages.append(nn.Sequential(*block)) 
            
    def forward(self,x):
        skips = []
        for stage in self.stages:
            x = stage(x)
            skips.append(x)
        return x,skips
    def getInputShape(self):
        return (-1,self.n_features_input,-1,-1)
    def getOutputShape(self):
        return (-1,self.num_hidden_features[-1], -1,-1)
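
The encoder assigns one dropout probability per stage by interpolating linearly between `dropout_min` and `dropout_max`. A small sketch of that schedule, using the values passed to the networks later in this notebook (the helper name `dropout_schedule` is introduced here for illustration):

```python
import numpy as np

def dropout_schedule(dropout_min, dropout_max, n_stages):
    """Linear interpolation from dropout_min (first stage) to dropout_max (last stage)."""
    return [float((1 - t) * dropout_min + t * dropout_max)
            for t in np.linspace(0, 1, n_stages)]

schedule = dropout_schedule(0.1, 0.3, 4)
print([round(p, 3) for p in schedule])   # [0.1, 0.167, 0.233, 0.3]
```

Shallow stages, which carry edge-level detail, are regularized less than deep stages. The decoder walks the same schedule in reverse.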

In [0]:
class ConvolutionalDecoder(nn.Module):
    """
    Convolutional Decoder taking skip connections
    """
    def __init__(self,n_features_output,num_hidden_features,kernel_size,padding,n_resblocks,dropout_min=0,dropout_max=0.2,blockObject=ResidualBlock,batchNormObject=nn.BatchNorm2d):
        """
        n_features_output (int): number of output features
        num_hidden_features (list(int)): number of features for each stage
        kernel_size (int): convolution kernel size
        padding (int): convolution padding
        n_resblocks (int): number of residual blocks at each stage
        dropout_min, dropout_max (float): dropout probability, interpolated linearly across the stages
        blockObject (nn.Module): Residual block to use. Default is ResidualBlock
        batchNormObject (nn.Module): normalization layer. Default is nn.BatchNorm2d
        """
        super(ConvolutionalDecoder,self).__init__()
        self.n_features_output = n_features_output
        self.num_hidden_features = num_hidden_features
        self.upConvolutions = nn.ModuleList()
        self.skipMergers = nn.ModuleList()
        self.residualBlocks = nn.ModuleList()
        dropout = iter([(1-t)*dropout_min + t*dropout_max   for t in np.linspace(0,1,(len(num_hidden_features)))][::-1])
        # input convolution block
        # layers
        for features_in,features_out in [num_hidden_features[i:i+2] for i in range(0,len(num_hidden_features), 1)][:-1]:
            # upsampling
            self.upConvolutions.append(nn.Sequential(nn.ConvTranspose2d(features_in, features_out, kernel_size=3, stride=2,padding=1,output_padding=1),batchNormObject(features_out),nn.ReLU()))
            self.skipMergers.append(nn.Conv2d(2*features_out, features_out, kernel_size=kernel_size,stride=1, padding=padding))
            # residual blocks
            block = []
            p = next(iter(dropout))
            for _ in range(n_resblocks):
                block += [blockObject(features_out, kernel_size, padding, dropout=p,batchNormObject=batchNormObject)]
            self.residualBlocks.append(nn.Sequential(*block))   
        # output convolution block
        block = [nn.Conv2d(num_hidden_features[-1],n_features_output, kernel_size=kernel_size,stride=1, padding=padding)]
        self.output_convolution = nn.Sequential(*block)

    def forward(self,x, skips):
        for up,merge,conv,skip in zip(self.upConvolutions,self.skipMergers, self.residualBlocks,skips):
            x = up(x)
            cat = torch.cat([x,skip],1)
            x = merge(cat)
            x = conv(x)
        return self.output_convolution(x)
    def getInputShape(self):
        return (-1,self.num_hidden_features[0],-1,-1)
    def getOutputShape(self):
        return (-1,self.n_features_output, -1,-1)

In [0]:
class DilatedConvolutions(nn.Module):
    """
    Sequential dilated convolutions
    """
    def __init__(self, n_channels, n_convolutions, dropout):
        super(DilatedConvolutions, self).__init__()
        kernel_size = 3
        padding = 1
        self.dropout = nn.Dropout2d(dropout)
        self.non_linearity = nn.ReLU(inplace=True)
        self.strides = [2**(k+1) for k in range(n_convolutions)]
        convs = [nn.Conv2d(n_channels, n_channels, kernel_size=kernel_size,dilation=s, padding=s) for s in self.strides ]
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for c in convs:
            self.convs.append(c)
            self.bns.append(nn.BatchNorm2d(n_channels))
    def forward(self,x):
        skips = []
        for (c,bn,s) in zip(self.convs,self.bns,self.strides):
            x_in = x
            x = c(x)
            x = bn(x)
            x = self.non_linearity(x)
            x = self.dropout(x)
            x = x_in + x
            skips.append(x)
        return x,skips
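
`DilatedConvolutions` sets `padding` equal to `dilation` for every 3x3 convolution, so the spatial size is preserved while the receptive field grows quickly. The sketch below checks this with the standard convolution output-size formula; the bottleneck size of 16 assumes 128x128 inputs and a depth of 4, as configured later in the notebook.

```python
def conv_output_size(size, kernel_size, padding, dilation, stride=1):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

n_convolutions = 4
dilations = [2 ** (k + 1) for k in range(n_convolutions)]   # [2, 4, 8, 16]

size = 16              # assumed bottleneck resolution (128 / 2**3)
receptive_field = 1
for d in dilations:
    size = conv_output_size(size, kernel_size=3, padding=d, dilation=d)
    receptive_field += d * (3 - 1)   # a 3x3 conv with dilation d widens the field by 2*d

print(size, receptive_field)   # 16 61
```

Four dilated layers therefore cover a 61x61 window of the bottleneck feature map (larger than the 16x16 map itself) without any further downsampling.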

In [0]:
class DilatedConvolutions2(nn.Module):
    """
    Sequential dilated convolutions (built from residual blocks)
    """
    def __init__(self, n_channels, n_convolutions,dropout,kernel_size,blockObject=ResidualBlock,batchNormObject=nn.BatchNorm2d):
        super(DilatedConvolutions2, self).__init__()
        self.dilatations = [2**(k+1) for k in range(n_convolutions)]
        self.blocks = nn.ModuleList([blockObject(n_channels, kernel_size, d, dropout=dropout, dilation=d,batchNormObject=batchNormObject) for d in self.dilatations ])
    def forward(self,x):
        skips = []
        for b in self.blocks:
            x = b(x)
            skips.append(x)
        return x, skips

In [0]:
class UNet(nn.Module):
    """
    U-Net model with dynamic number of layers, Residual Blocks, Dilated Convolutions, Dropout and Group Normalization
    """
    def __init__(self, in_channels, out_channels, num_hidden_features,n_resblocks,num_dilated_convs, dropout_min=0, dropout_max=0, gated=False, padding=1, kernel_size=3,group_norm=32):
        """
        initialize the model
        Args:
            in_channels (int): number of input channels (image=3)
            out_channels (int): number of output channels (n_classes)
            num_hidden_features (list(int)): number of hidden features for each layer (the number of layers is the length of this list)
            n_resblocks (int): number of residual blocks at each layer
            num_dilated_convs (int): number of dilated convolutions at the last layer
            dropout_min, dropout_max (float): floats in [0,1]: dropout probabilities for the first and last layer
            gated (bool): use gated convolutions, default is False
            padding (int): padding for the convolutions
            kernel_size (int): kernel size for the convolutions
            group_norm (int): number of groups to use for Group Normalization, default is 32, if zero: use nn.BatchNorm2d
        """
        super(UNet, self).__init__()
        if group_norm > 0:
            for h in num_hidden_features:
                assert h%group_norm==0, "Number of features at each layer must be divisible by 'group_norm'"
        blockObject = gated_resnet if gated else ResidualBlock
        # factory that returns a normalization layer instance for a given channel count
        batchNormObject = lambda n_features : nn.GroupNorm(group_norm,n_features) if group_norm > 0 else nn.BatchNorm2d(n_features)
        self.encoder = ConvolutionalEncoder(in_channels,num_hidden_features,kernel_size,padding,n_resblocks,dropout_min=dropout_min,dropout_max=dropout_max,blockObject=blockObject,batchNormObject=batchNormObject)
        if num_dilated_convs > 0:
            #self.dilatedConvs = DilatedConvolutions2(num_hidden_features[-1], num_dilated_convs,dropout_max,kernel_size,blockObject=blockObject,batchNormObject=batchNormObject)
            self.dilatedConvs = DilatedConvolutions(num_hidden_features[-1],num_dilated_convs,dropout_max) # <v11 uses dilatedConvs2
        else:
            self.dilatedConvs = None
        self.decoder = ConvolutionalDecoder(out_channels,num_hidden_features[::-1],kernel_size,padding,n_resblocks,dropout_min=dropout_min,dropout_max=dropout_max,blockObject=blockObject,batchNormObject=batchNormObject)
        
    def forward(self, x):
        x,skips = self.encoder(x)
        if self.dilatedConvs is not None:
            x,dilated_skips = self.dilatedConvs(x)
            for d in dilated_skips:
                x += d
            x += skips[-1]
        x = self.decoder(x,skips[:-1][::-1])
        return x
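
To make the skip-connection wiring in `UNet.forward` concrete, the following pure-Python sketch traces (channels, resolution) through the encoder and decoder. It only does shape bookkeeping under the assumptions stated in the comments; `trace_unet_shapes` is a hypothetical helper, not part of the model.

```python
def trace_unet_shapes(num_hidden_features, size):
    """Return encoder shapes and (upsampled, skip) pairs as (channels, resolution) tuples."""
    # Encoder: stage 0 keeps the resolution, every later stage halves it (MaxPool2d(2)).
    encoder = [(c, size // 2 ** i) for i, c in enumerate(num_hidden_features)]
    # The deepest output is the bottleneck; the remaining outputs are skips, deepest first,
    # which corresponds to skips[:-1][::-1] in UNet.forward.
    skips = encoder[:-1][::-1]
    channels, res = encoder[-1]
    decoder_pairs = []
    for skip_channels, skip_res in skips:
        res *= 2                    # the stride-2 transposed convolution doubles the resolution
        channels = skip_channels    # and maps features_in to features_out
        assert (channels, res) == (skip_channels, skip_res)   # shapes must match for torch.cat
        decoder_pairs.append(((channels, res), (skip_channels, skip_res)))
    return encoder, decoder_pairs

encoder, decoder_pairs = trace_unet_shapes([32, 64, 128, 256], 128)
print(encoder)
print(decoder_pairs)
```

Each decoder stage concatenates its upsampled features with the skip of the same shape, which is why the merging convolution takes `2*features_out` input channels.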

Define the generator $G$ and the discriminator $D$:

In [0]:
class Generator(nn.Module):
    def __init__(self, kwargs):
        super(Generator, self).__init__()
        self.unet = UNet(**kwargs)
        self.tanh = nn.Tanh()

    def forward(self, x):
        return self.tanh(self.unet(x))

In [0]:
class Discriminator(nn.Module):
    def __init__(self, kwargs):
        super(Discriminator, self).__init__()
        self.encoder = ConvolutionalEncoder(**kwargs)
        self.convout = nn.Conv2d(kwargs['num_hidden_features'][-1],1,kernel_size=3,padding=1)

    def forward(self, input):

        output,skips = self.encoder(input)
        output = self.convout(output)
        return output
    
class Identity(nn.Module):
    def __init__(self,features):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

## Defining the model:

Define hyperparameters:

In [0]:
batch_size = 32
#m_test = int(np.sqrt(batch_size))
m_test = 5
m_test = m_test - 1 if m_test**2 > batch_size else m_test
lr_g = 1e-4
lr_d = 1e-5
lambda_l1 = 100

in_channels = 1
out_channels = 3
n_features_zero = 32
group_norm = 8
n_resblocks = 1
num_dilated_convs = 4
depth = 4
kernel_size = 3
padding = 1

In [0]:
num_hidden_features = [n_features_zero * 2**k for k in range(depth)]
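
A quick check of what these hyperparameters imply (a sketch that repeats the values from the cell above): the feature count doubles at every stage, and each count must be divisible by `group_norm` for Group Normalization to work.

```python
n_features_zero, depth, group_norm = 32, 4, 8   # values from the hyperparameter cell

num_hidden_features = [n_features_zero * 2 ** k for k in range(depth)]
channels_per_group = [h // group_norm for h in num_hidden_features]

# UNet asserts this condition for every stage when group_norm > 0
assert all(h % group_norm == 0 for h in num_hidden_features)
print(num_hidden_features, channels_per_group)   # [32, 64, 128, 256] [4, 8, 16, 32]
```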

Define the generator and critic networks with the given hyperparameters:

In [0]:
generator = UNet(in_channels, 
                 out_channels, 
                 num_hidden_features,
                 n_resblocks,
                 num_dilated_convs,
                 dropout_min = 0.1, 
                 dropout_max = 0.3, 
                 gated = False, 
                 padding = 1, 
                 kernel_size = 3,
                 group_norm = group_norm)

generator = nn.Sequential(generator, nn.Tanh())

In [0]:
discriminator = Discriminator({'n_features_input':in_channels + out_channels,
                              'num_hidden_features':num_hidden_features,
                              'kernel_size':kernel_size,
                              'padding':padding,
                              'n_resblocks':n_resblocks,
                              'dropout_min':0.1,
                              'dropout_max':0.3, 
                              'blockObject':ResidualBlock,
                              'batchNormObject':Identity})

Initialize the weights (Xavier for weights and 0 for biases):

In [0]:
def weights_init(m):
    if isinstance(m, nn.Conv2d): 
        if m.weight is not None:
            init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.Linear):
        if m.weight is not None:
            init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
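
For reference, Xavier (Glorot) uniform initialization draws weights from $U(-a, a)$ with $a = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. The sketch below computes that bound for a convolution layer; `xavier_uniform_bound` is a hypothetical helper written for illustration, with fan-in and fan-out counted as channels times kernel area.

```python
import math

def xavier_uniform_bound(in_channels, out_channels, kernel_size, gain=1.0):
    """Bound a of the uniform distribution U(-a, a) used by Xavier initialization."""
    kernel_area = kernel_size * kernel_size
    fan_in = in_channels * kernel_area
    fan_out = out_channels * kernel_area
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

# a 3x3 convolution from 32 to 64 channels, as in the second encoder stage's residual path
bound = xavier_uniform_bound(32, 64, 3)
print(round(bound, 4))   # 0.0833
```

Wider layers get a smaller bound, which keeps the variance of activations roughly constant across layers.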

### Training

#### WGAN-GP Gradient Penalty

In [0]:
def calc_gradient_penalty(netD, real_data, fake_data):
    lmbd = 10 
    alpha = torch.rand(real_data.size(0), 1).requires_grad_()
    alpha = alpha[:,:,None,None]
    alpha = alpha.expand(real_data.size())

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()),
                              create_graph = True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lmbd
    return gradient_penalty
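
The penalty pushes the norm of the critic's gradient at interpolated points toward 1. The function above takes the norm over the channel dimension at each spatial location; the common WGAN-GP formulation instead uses one norm per sample over the whole gradient. The toy NumPy sketch below illustrates the per-sample variant with a linear critic $D(x) = \langle w, x \rangle$, whose gradient is known in closed form. It is an assumption-laden illustration (hypothetical helper name, no autograd), not a replacement for `calc_gradient_penalty`.

```python
import numpy as np

def gradient_penalty_linear(w, real, fake, lmbd=10.0, rng=None):
    """Gradient penalty for a toy linear critic D(x) = <w, x>."""
    rng = np.random.default_rng(0) if rng is None else rng
    batch = real.shape[0]
    alpha = rng.random((batch, 1, 1, 1))               # one mixing weight per sample
    interpolates = alpha * real + (1 - alpha) * fake   # points between real and fake samples
    # For a linear critic the gradient at every interpolate is simply w.
    grads = np.broadcast_to(w, interpolates.shape)
    norms = np.sqrt((grads.reshape(batch, -1) ** 2).sum(axis=1))   # per-sample L2 norm
    return lmbd * ((norms - 1.0) ** 2).mean()

real = np.ones((4, 3, 2, 2))
fake = np.zeros((4, 3, 2, 2))
w_unit = np.full((3, 2, 2), 1 / np.sqrt(12))   # gradient norm 1: no penalty
penalty_unit = gradient_penalty_linear(w_unit, real, fake)
penalty_double = gradient_penalty_linear(2 * w_unit, real, fake)   # norm 2: penalty 10 * (2 - 1)**2
print(penalty_unit, penalty_double)
```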

### Train Discriminator

In [0]:
def train_discriminator(optimizer, real_data, real_labels, fake_data, fake_labels):
    optimizer.zero_grad()
    # train on real data
    real_input = torch.cat([real_data,real_labels],1)
    prediction_real = discriminator(real_input).squeeze()
    # train on fake data
    fake_input = torch.cat([fake_data,fake_labels],1)
    prediction_fake = discriminator(fake_input).squeeze()
    # gradients penalty (Lipschitz condition) (WGAN-GP)
    penalty = calc_gradient_penalty(discriminator,real_input, fake_input)
    loss = prediction_fake.mean() - prediction_real.mean() + penalty
    loss.backward()
    optimizer.step()
    return loss

### Train Generator

In [0]:
def train_generator(optimizer, fake_data, fake_labels, true_data):
    optimizer.zero_grad()
    fake_input = torch.cat([fake_data,fake_labels],1)
    prediction = discriminator(fake_input).squeeze() 
    G_loss = - torch.mean(prediction)
    L1 = torch.abs(fake_data - true_data).mean()
    loss = G_loss + lambda_l1 * L1
    loss.backward()
    optimizer.step()
    return G_loss,L1
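
To see how the two training objectives combine their terms, here is a small numeric sketch with NumPy arrays standing in for critic outputs and images. All numbers and helper names are hypothetical; the formulas mirror `train_discriminator` and `train_generator` above.

```python
import numpy as np

def critic_loss(pred_real, pred_fake, penalty):
    """Wasserstein critic loss: score fakes low, reals high, plus the gradient penalty."""
    return pred_fake.mean() - pred_real.mean() + penalty

def generator_loss(pred_fake, fake, target, lambda_l1=100.0):
    """Adversarial term plus the L1 reconstruction term weighted by lambda_l1."""
    adversarial = -pred_fake.mean()
    l1 = np.abs(fake - target).mean()
    return adversarial + lambda_l1 * l1

pred_real = np.array([1.0, 3.0])     # hypothetical critic scores for real pairs
pred_fake = np.array([-1.0, 1.0])    # hypothetical critic scores for generated pairs
fake = np.array([0.2, -0.4])         # hypothetical generated pixel values
target = np.array([0.0, 0.0])        # hypothetical ground-truth pixel values

d_loss = critic_loss(pred_real, pred_fake, penalty=0.5)   # 0.0 - 2.0 + 0.5
g_loss = generator_loss(pred_fake, fake, target)          # -0.0 + 100 * 0.3
print(d_loss, g_loss)
```

With `lambda_l1 = 100`, the reconstruction term dominates the generator loss unless the critic scores are large, which keeps the colorization close to the ground truth early in training.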

## Define the training:

Setting parameters for the training:

In [0]:
epochs = 20
starting_epoch = 0
g_error = 0
gen_steps = 1
gen_train_freq = 1
global_step = 0

In [0]:
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
data_loader2 = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [0]:
d_optimizer = optim.Adam(discriminator.parameters(), lr = lr_d)
g_optimizer = optim.Adam(generator.parameters(), lr = lr_g)

Initializing the weights (or loading them from a checkpoint):

In [0]:
generator.apply(weights_init)
discriminator.apply(weights_init)

In [0]:
generator.load_state_dict(torch.load('drive/My Drive/Deoldify/Model_3/Checkpoints/gen_epoch_7,gloss=-1.805633.pth'))
discriminator.load_state_dict(torch.load('drive/My Drive/Deoldify/Model_3/Checkpoints/disc_epoch_7,gloss=0.793554.pth'))

Training:

In [0]:
global_step = 0
for epoch in range(starting_epoch, epochs):
  for n_batch, (real_data,label_batch) in enumerate(data_loader):
      with torch.no_grad():
          fake_data = generator(label_batch)
      d_error = train_discriminator(d_optimizer, real_data,label_batch, fake_data,label_batch)
      if global_step % gen_train_freq == 0:
          for _ in range(gen_steps):
              real_data, label_batch = next(iter(data_loader2))
              fake_data = generator(label_batch)
              g_error,l1_error = train_generator(g_optimizer,fake_data,label_batch,real_data)
              g_error,l1_error = g_error.item(),l1_error.item()
      print('step ', global_step, ' epoch ', epoch, ' d_error ', d_error.item(), ' g_error ', g_error)
      global_step += 1
      if global_step % 5 == 0:
          test_images = fake_data.permute(0,2,3,1).data.cpu().numpy()
          i_ = 0
          plt.figure(figsize=(5*m_test, 5*m_test)) 
          plt.subplots_adjust(wspace=0, hspace=0)
          for l in range(m_test**2):
              tile = test_images[l]
              tile = (tile-tile.min())/(tile.max()-tile.min())
              plt.subplot(m_test, m_test, i_+1) 
              plt.imshow(tile); plt.axis('off')
              i_ += 1
          plt.savefig('drive/My Drive/Deoldify/Model_3/Train/epoch' + str(epoch) + 'step' + str(global_step) + '.png', bbox_inches='tight')
          plt.close()  # release the figure so that figures do not accumulate over the training run
      if global_step % 50 == 0:
        torch.save(generator.state_dict(),
                  '%s/gen_epoch_%d,gloss=%.6f.pth' % (
                  'drive/My Drive/Deoldify/Model_3/Checkpoints/', epoch, g_error))
        torch.save(discriminator.state_dict(),
                  '%s/disc_epoch_%d,gloss=%.6f.pth' % (
                  'drive/My Drive/Deoldify/Model_3/Checkpoints/', epoch, d_error))
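
The preview code rescales every tile by its own minimum and maximum, which stretches the contrast of each image independently and divides by zero when a tile is constant. The sketch below shows a guarded version of that rescaling and an alternative that inverts the $[-1, 1]$ normalization instead. Both helpers are hypothetical names introduced for illustration.

```python
import numpy as np

def to_displayable(tile, eps=1e-8):
    """Per-tile min-max rescaling to [0, 1], guarded against constant tiles."""
    lo, hi = tile.min(), tile.max()
    if hi - lo < eps:
        return np.zeros_like(tile)   # constant image: avoid 0/0
    return (tile - lo) / (hi - lo)

def from_tanh_range(tile):
    """Map generator outputs from [-1, 1] to [0, 1] without changing relative contrast."""
    return np.clip((tile + 1.0) / 2.0, 0.0, 1.0)

constant_tile = np.full((2, 2, 3), 0.25)
stretched = to_displayable(np.array([-0.5, 0.0, 0.5]))
mapped = from_tanh_range(np.array([-1.0, 0.0, 1.0]))
print(to_displayable(constant_tile).max(), stretched, mapped)
```

`from_tanh_range` keeps dark outputs dark, so it gives a more faithful picture of what the generator actually produced.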

## Inference

Setting the path for the test images folder and loading the dataset:

In [0]:
test_root = 'val2017/'

In [0]:
test_dataset = Load_Dataset(test_root,128)
test_dataset[0][0].shape,test_dataset[0][1].shape

Batch-size for testing (default set to 1):

In [0]:
test_batch_size = 1
m_test = int(np.sqrt(test_batch_size))

In [0]:
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size = test_batch_size, shuffle=True)
test_data_loader2 = torch.utils.data.DataLoader(test_dataset, batch_size = test_batch_size, shuffle=True)

Loading the weights from the checkpoints:

In [0]:
generator.load_state_dict(torch.load('drive/My Drive/Deoldify/Model_3/Checkpoints/gen_epoch_6,gloss=0.751362.pth'))
discriminator.load_state_dict(torch.load('drive/My Drive/Deoldify/Model_3/Checkpoints/disc_epoch_6,gloss=0.436743.pth'))

Testing process:

In [0]:
global_test_step = 0
for n_batch, (real_data,label_batch) in enumerate(test_data_loader):
    with torch.no_grad():
        fake_data = generator(label_batch)
    if global_test_step % gen_train_freq == 0:
        for _ in range(gen_steps):
            real_data, label_batch = next(iter(test_data_loader2))
            fake_data = generator(label_batch)
    global_test_step += 1
    if global_test_step % 1 == 0:
        test_images = fake_data.permute(0,2,3,1).data.cpu().numpy()
        real_images = real_data.permute(0,2,3,1).data.cpu().numpy()
        i_ = 0
        plt.figure(figsize=(5*m_test, 5*m_test)) 
        plt.subplots_adjust(wspace=0, hspace=0)
        for l in range(m_test**2):
            tile = test_images[l]
            tile = (tile-tile.min())/(tile.max()-tile.min())
            plt.subplot(m_test, m_test, i_+1) 
            plt.imshow(tile); plt.axis('off')
            i_ += 1
        plt.savefig('drive/My Drive/Deoldify/Model_3/Test/colorized_step' + str(global_test_step) + '.png', bbox_inches='tight')
        plt.show()

        i_ = 0
        plt.figure(figsize=(5*m_test, 5*m_test)) 
        plt.subplots_adjust(wspace=0, hspace=0)
        for l in range(m_test**2):
            tile = real_images[l]
            tile = (tile-tile.min())/(tile.max()-tile.min())
            plt.subplot(m_test, m_test, i_+1) 
            plt.imshow(tile); plt.axis('off')
            i_ += 1
        plt.savefig('drive/My Drive/Deoldify/Model_3/Test/original_step' + str(global_test_step) + '.png', bbox_inches='tight')
        plt.show()