In [1]:
pip install torchvision -y -c pytorch


Usage:   
  /opt/conda/bin/python3.7 -m pip install [options] <requirement specifier> [package-index-options] ...
  /opt/conda/bin/python3.7 -m pip install [options] -r <requirements file> [package-index-options] ...
  /opt/conda/bin/python3.7 -m pip install [options] [-e] <vcs project url> ...
  /opt/conda/bin/python3.7 -m pip install [options] [-e] <local project path> ...
  /opt/conda/bin/python3.7 -m pip install [options] <archive url/path> ...

no such option: -y
Note: you may need to restart the kernel to use updated packages.


In [2]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0   9387      0 --:--:-- --:--:-- --:--:--  9369
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200515 ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly+20200515-cp37-cp37m-linux_x86_64.whl...
\ [1 files][ 91.0 MiB/ 91.0 MiB]                                                
Operation completed over 1 objects/91.0 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200515-cp37-cp37m-linux_x86_64.whl...
| [1 files][119.5 MiB/119.5 MiB]              

In [3]:
USE_TPU = True
MULTI_CORE = False

import os
import torch

 
DATA_DIR = '../input/imagenet-mni/imagenet/'
OUT_DIR = './result/'
MODEL_DIR = './models/'
CHECKPOINT_DIR = './'

TRAIN_DIR = DATA_DIR+"train/"  # UPDATE
TEST_DIR = DATA_DIR+"test/" # UPDATE

os.makedirs(TRAIN_DIR, exist_ok=True)
os.makedirs(TEST_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# DATA INFORMATION
IMAGE_SIZE = 224
BATCH_SIZE = 10
GRADIENT_PENALTY_WEIGHT = 10
NUM_EPOCHS = 1
KEEP_CKPT = 2
# save_model_path = MODEL_DIR
DEVICE=0
if USE_TPU:
    import torch_xla.core.xla_model as xm
    if not MULTI_CORE:
        DEVICE = xm.xla_device()

if not USE_TPU:
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = 'cpu'




In [4]:
import os
import cv2 
import numpy as np


class DATA():
    def __init__(self, dirname, max_len=None):
        self.dir_path = dirname
        self.filelist = os.listdir(self.dir_path)[:max_len]
        self.batch_size =  BATCH_SIZE
        self.size = len(self.filelist)
        self.data_index = 0
    def __len__(self):
        return len(self.filelist)
    
    def __getitem__(self, item):
        img = []
        label = []
        itemfilelist = ''
        filename = os.path.join(self.dir_path, self.filelist[item])
        itemfilelist = self.filelist[item]
        greyimg, colorimg = self.read_img(filename)
        img = greyimg
        label = colorimg
        img = np.asarray(img)/255 # values between 0 and 1
        label = np.asarray(label)/255 # values between 0 and 1
        return img, label, itemfilelist

    def read_img(self, filename):
        img = cv2.imread(filename, 3)
        height, width, channels = img.shape
        min_hw = int(min(height,width)/2)
        img = img[int(height/2)-min_hw:int(height/2)+min_hw,int(width/2)-min_hw:int(width/2)+min_hw,:]
        labimg = cv2.cvtColor(cv2.resize(img, ( IMAGE_SIZE,  IMAGE_SIZE)), cv2.COLOR_RGB2Lab) ## Changed BGR to RGB
        return np.reshape(labimg[:,:,0], (1,  IMAGE_SIZE,  IMAGE_SIZE)), np.reshape(labimg[:, :, 1:], (2, IMAGE_SIZE,  IMAGE_SIZE))

    def generate_batch(self):
        batch = []
        labels = []
        filelist = []
        for i in range(self.batch_size):
            filename = os.path.join(self.dir_path, self.filelist[self.data_index])
            filelist.append(self.filelist[self.data_index])
            greyimg, colorimg = self.read_img(filename)
            batch.append(greyimg)
            labels.append(colorimg)
            self.data_index = (self.data_index + 1) % self.size
        batch = np.asarray(batch)/255 # values between 0 and 1
        labels = np.asarray(labels)/255 # values between 0 and 1
        return batch, labels, filelist


In [14]:

import os
import cv2
import torch 
import glob
import numpy as np
import matplotlib.pyplot as plt

if  USE_TPU:
    import torch_xla.core.xla_model as xm

def preprocess(imgs):
    try:
        imgs = imgs.detach().numpy()
    except:
        pass
    imgs = imgs * 255
    imgs[imgs>255] = 255
    imgs[imgs<0] = 0 
    return imgs.astype(np.uint8) # torch.unit8

def reconstruct(batchX, predictedY, filelist):

    batchX = batchX.reshape(224,224,1) 
    predictedY = predictedY.reshape(224,224,2)
    result = np.concatenate((batchX, predictedY), axis=2)
    result = cv2.cvtColor(result, cv2.COLOR_Lab2RGB)
    save_results_path =  OUT_DIR
    if not os.path.exists(save_results_path):
        os.makedirs(save_results_path)
    save_path = os.path.join(save_results_path, filelist +  "_reconstructed.jpg" )
    cv2.imwrite(save_path, result)
    return result
    
def reconstruct_no(batchX, predictedY):

    batchX = batchX.reshape(224,224,1) 
    predictedY = predictedY.reshape(224,224,2)
    
    result = np.concatenate((batchX, predictedY), axis=2)
    result = cv2.cvtColor(result, cv2.COLOR_Lab2RGB)
    return result


def imag_gird(axrow, orig, batchL, preds, epoch):
    fig , ax = plt.subplots(1,3, figsize=(15,15))
    ax[0].imshow(orig)
    ax[0].set_title('Original Image')

    ax[1].imshow(np.tile(batchL,(1,1,3)))
    ax[1].set_title('L Image with Channels reapeated(Input)') 

    ax[2].imshow(preds)
    ax[2].set_title('Pred Image')
    plt.savefig(f'sample_preds_{epoch}')
    plt.close()
  # plt.show()

def plot_some(test_data, colorization_model, device, epoch):
    with torch.no_grad():
        indexes = [0, 2, 9]
        for idx in indexes: 
            batchL, realAB, filename = test_data[idx]
            filepath =  TRAIN_DIR+filename
            batchL = batchL.reshape(1,1,224,224)
            realAB = realAB.reshape(1,2,224,224)
            batchL_3 = torch.tensor(np.tile(batchL, [1, 3, 1, 1]))
            batchL_3 = batchL_3.to(device)
            batchL = torch.tensor(batchL).to(device).double()
            realAB = torch.tensor(realAB).to(device).double()

            colorization_model.eval()
            batch_predAB, _ = colorization_model(batchL_3)
            img = cv2.imread(filepath)
            batch_predAB = batch_predAB.cpu().numpy().reshape((224,224,2))
            batchL = batchL.cpu().numpy().reshape((224,224,1))
            realAB = realAB.cpu().numpy().reshape((224,224,2))
            orig = cv2.imread(filepath)
            orig = cv2.resize(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB), (224,224))
            # orig = reconstruct_no(preprocess(batchL), preprocess(realAB))
            preds = reconstruct_no(preprocess(batchL), preprocess(batch_predAB))
            imag_gird(0, orig, batchL, preds, epoch)
            plt.show()

def create_checkpoint(epoch, netG, optG, netD, optD, max_checkpoint, save_path= CHECKPOINT_DIR):
    print('Saving Model and Optimizer weights.....')
    checkpoint = {
        'epoch' : epoch,
        'generator_state_dict' :netG.state_dict(),
        'generator_optimizer': optG.state_dict(),
        'discriminator_state_dict': netD.state_dict(),
        'discriminator_optimizer': optD.state_dict()
    }
    if  USE_TPU:
        xm.save(checkpoint, '2_checkpoint.pt')
    else:
        torch.save(checkpoint, '2_checkpoint.pt')
    print('Weights Saved !!')
    del checkpoint
    files = glob.glob(os.path.expanduser(f"{save_path}*"))
    sorted_files = sorted(files, key=lambda t: -os.stat(t).st_mtime)
    if len(sorted_files) > max_checkpoint:
        os.remove(sorted_files[-1])



def load_checkpoint(checkpoint_directory, netG, optG, netD, optD, device):
    load_from_checkpoint = False
    files = glob.glob(os.path.expanduser(f"{checkpoint_directory}*"))
    for file in files:
        if file.endswith('.pt'):
            load_from_checkpoint=True
            break

    if load_from_checkpoint:
        print('Loading Model and optimizer states from checkpoint....')
        sorted_files = sorted(files, key=lambda t: -os.stat(t).st_mtime)
        checkpoint = torch.load(f'{sorted_files[0]}')
        epoch_checkpoint = checkpoint['epoch'] + 1
        netG.load_state_dict(checkpoint['generator_state_dict'])
        netG.to(device)

        optG.load_state_dict(checkpoint['generator_optimizer'])

        netD.load_state_dict(checkpoint['discriminator_state_dict'])
        netD.to(device)

        optD.load_state_dict(checkpoint['discriminator_optimizer'])
        print('Loaded States !!!')
        print(f'It looks like the this states belong to epoch {epoch_checkpoint-1}.')
        print(f'so the model will train for { NUM_EPOCHS - (epoch_checkpoint-1)} more epochs.')
         

        return netG, optG, netD, optD, epoch_checkpoint
    else:
        print('There are no checkpoints in the mentioned directoy, the Model will train from scratch.')
        epoch_checkpoint = 1
        return netG, optG, netD, optD, epoch_checkpoint
    


def plot_gan_loss(G_losses, D_losses):
    plt.figure(figsize=(10,5))
    plt.title(f"Generator and Discriminator Loss During Training ")
    plt.plot(G_losses,label="G")
    plt.plot(D_losses,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f'GANLOSS{epoch}.pdf',figsize=(30,30))


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


bias=True

class discriminator_model(nn.Module):

    def __init__(self):
        super(discriminator_model, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(4,4),padding=1,stride=(2,2),bias=bias) # 64, 112, 112
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(4,4), padding=1, stride=(2,2), bias=bias) # 128, 56, 56
        self.conv3 = nn.Conv2d(128,256, kernel_size=(4,4), padding=1, stride=(2,2), bias=bias) # 256, 28, 28, 2
        self.conv4 = nn.Conv2d(256,512, kernel_size=(4,4), padding=3, stride=(1,1), bias=bias) # 512, 28, 28
        self.conv5 = nn.Conv2d(512,1, kernel_size=(4,4), padding=3, stride=(1,1), bias=bias) # 1, 
        self.leaky_relu = nn.LeakyReLU(0.3)

    def forward(self,input):

        net = self.conv1(input)               #[-1, 64, 112, 112]
        net = self.leaky_relu(net)          #[-1, 64, 112, 112]    
        net = self.conv2(net)               #[-1, 128, 56, 56] 
        net = self.leaky_relu(net)          #[-1, 128, 56, 56] 
        net = self.conv3(net)               #[-1, 256, 28, 28]
        net = self.leaky_relu(net)          #[-1, 256, 28, 28]   
        net = self.conv4(net)               #[-1, 512, 27, 27]
        net = self.leaky_relu(net)          #[-1, 512, 27, 27]
        net = self.conv5(net)               #[-1, 1, 26, 26]
        return net

class colorization_model(nn.Module):
    def __init__(self):
        super(colorization_model, self).__init__()

        self.VGG_model = torchvision.models.vgg16(pretrained=True)
        self.VGG_model = nn.Sequential(*list(self.VGG_model.features.children())[:-8]) #[None, 512, 28, 28]
        self.VGG_model = self.VGG_model.double()
        self.relu = nn.ReLU()
        self.lrelu = nn.LeakyReLU(0.3)
        self.global_features_conv1 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(2,2), bias=bias) #[None, 512, 14, 14]
        self.global_features_bn1 = nn.BatchNorm2d(512,eps=0.001,momentum=0.99)
        self.global_features_conv2 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 14, 14]
        self.global_features_bn2 = nn.BatchNorm2d(512,eps=0.001,momentum=0.99)
        self.global_features_conv3 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(2,2), bias=bias) #[None, 512, 7, 7]
        self.global_features_bn3 = nn.BatchNorm2d(512,eps=0.001,momentum=0.99)
        self.global_features_conv4 = nn.Conv2d(512, 512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 7, 7]
        self.global_features_bn4 = nn.BatchNorm2d(512,eps=0.001,momentum=0.99)

        self.global_features2_flatten = nn.Flatten()
        self.global_features2_dense1 = nn.Linear(512*7*7,1024)
        self.midlevel_conv1 = nn.Conv2d(512,512, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias) #[None, 512, 28, 28]
        self.global_features2_dense2 = nn.Linear(1024,512)
        self.midlevel_bn1 = nn.BatchNorm2d(512, eps=0.001,momentum=0.99)
        self.global_features2_dense3 = nn.Linear(512,256)
        self.midlevel_conv2 = nn.Conv2d(512,256, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)

        self.midlevel_bn2 = nn.BatchNorm2d(256,eps=0.001,momentum=0.99)

         #[None, 256, 28, 28]
        # self.midlevel_bn2 = nn.BatchNorm2d(256)#,,eps=0.001,momentum=0.99)

        self.global_featuresClass_flatten = nn.Flatten()
        self.global_featuresClass_dense1 = nn.Linear(512*7*7, 4096)
        self.global_featuresClass_dense2 = nn.Linear(4096, 4096)
        self.global_featuresClass_dense3 = nn.Linear(4096, 1000)
        self.softmax = nn.Softmax()

        self.outputmodel_conv1 = nn.Conv2d(512, 256, kernel_size=(1,1), padding=0, stride=(1,1),  bias=bias) 
        self.outputmodel_conv2 = nn.Conv2d(256, 128, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
        self.outputmodel_conv3 = nn.Conv2d(128, 64, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
        self.outputmodel_conv4 = nn.Conv2d(64, 64, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
        self.outputmodel_conv5 = nn.Conv2d(64, 32, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
        self.outputmodel_conv6 = nn.Conv2d(32, 2, kernel_size=(3,3), padding=1, stride=(1,1), bias=bias)
        self.outputmodel_upsample = nn.Upsample(scale_factor=(2,2))
        self.outputmodel_bn1 = nn.BatchNorm2d(128)
        self.outputmodel_bn2 = nn.BatchNorm2d(64)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self,input_img):

        # VGG Without Top Layers

        vgg_out = self.VGG_model(torch.tensor(input_img).double())

        #Global Features

        global_features = self.relu(self.global_features_conv1(vgg_out))  #[None, 512, 14, 14]
        global_features = self.global_features_bn1(global_features) #[None, 512, 14, 14]
        global_features = self.relu(self.global_features_conv2(global_features)) #[None, 512, 14, 14]
        global_features = self.global_features_bn2(global_features) #[None, 512, 14, 14]

        global_features = self.relu(self.global_features_conv3(global_features)) #[None, 512, 7, 7]
        global_features = self.global_features_bn3(global_features)  #[None, 512, 7, 7]
        global_features = self.relu(self.global_features_conv4(global_features)) #[None, 512, 7, 7]
        global_features = self.global_features_bn4(global_features) #[None, 512, 7, 7]

        global_features2 = self.global_features2_flatten(global_features) #[None, 512*7*7]

        global_features2 = self.global_features2_dense1(global_features2) #[None, 1024]
        global_features2 = self.global_features2_dense2(global_features2) #[None, 512]
        global_features2 = self.global_features2_dense3(global_features2) #[None, 256]
        global_features2 = global_features2.unsqueeze(2).expand(-1,256,28*28) #[None, 256, 784]
        global_features2 = global_features2.view((-1,256,28,28)) #[None, 256, 28, 28]

        global_featureClass = self.global_featuresClass_flatten(global_features) #[None, 512*7*7]
        global_featureClass = self.global_featuresClass_dense1(global_featureClass) #[None, 4096]
        global_featureClass = self.global_featuresClass_dense2(global_featureClass) #[None, 4096]
        global_featureClass = self.softmax(self.global_featuresClass_dense3(global_featureClass))#[None, 1000]

        # Mid Level Features
        midlevel_features = self.midlevel_conv1(vgg_out.double()) #[None, 512, 28, 28]
        midlevel_features = self.midlevel_bn1(midlevel_features) #[None, 512, 28, 28]
        midlevel_features = self.midlevel_conv2(midlevel_features) #[None, 256, 28, 28]
        midlevel_features = self.midlevel_bn2(midlevel_features) #[None, 256, 28, 28]

        # Fusion of (VGG16 + MidLevel) + (VGG16 + Global)

        modelFusion = torch.cat([midlevel_features, global_features2],dim=1)

        # Fusion Colorization

        outputmodel = self.relu(self.outputmodel_conv1(modelFusion)) # None, 256, 28, 28
        outputmodel = self.relu(self.outputmodel_conv2(outputmodel)) # None, 128, 28, 28

        outputmodel = self.outputmodel_upsample(outputmodel) # None, 128, 56, 56
        outputmodel = self.outputmodel_bn1(outputmodel) # None, 128, 56, 56
        outputmodel = self.relu(self.outputmodel_conv3(outputmodel)) # None, 64, 56, 56
        outputmodel = self.relu(self.outputmodel_conv4(outputmodel)) # None, 64, 56, 56 

        outputmodel = self.outputmodel_upsample(outputmodel) # None, 64, 112, 112
        outputmodel = self.outputmodel_bn2(outputmodel) # None, 64, 112, 112
        outputmodel = self.relu(self.outputmodel_conv5(outputmodel)) # None, 32, 112, 112
        outputmodel = self.sigmoid(self.outputmodel_conv6(outputmodel)) # None, 2, 112, 112
        outputmodel = self.outputmodel_upsample(outputmodel) # None, 2, 224, 224

        return outputmodel, global_featureClass


class GAN(nn.Module):
    def __init__(self, netG, netD):
        super(GAN, self).__init__()

        self.netG = netG
        self.netD = netD

    def forward(self, trainL, trainL_3):

        for param in self.netD.parameters():
            param.requires_grad= False

        predAB, classVector = self.netG(trainL_3)
        predLAB = torch.cat([trainL, predAB], dim=1)
        discpred = self.netD(predLAB)

        return predAB, classVector, discpred



[('__call__', <function LevelMapper.__call__ at 0x7fcf936e1050>), ('__init__', <function LevelMapper.__init__ at 0x7fcf936dbf80>)]
[('__call__', <function BalancedPositiveNegativeSampler.__call__ at 0x7fcf935787a0>), ('__init__', <function BalancedPositiveNegativeSampler.__init__ at 0x7fcf93578710>)]
[('__init__', <function BoxCoder.__init__ at 0x7fcf93585f80>), ('decode', <function BoxCoder.decode at 0x7fcf9358a170>), ('decode_single', <function BoxCoder.decode_single at 0x7fcf9358a200>), ('encode', <function BoxCoder.encode at 0x7fcf9358a050>), ('encode_single', <function BoxCoder.encode_single at 0x7fcf9358a0e0>)]
[('__call__', <function Matcher.__call__ at 0x7fcf93585ef0>), ('__init__', <function Matcher.__init__ at 0x7fcf93585d40>), ('set_low_quality_matches_', <function Matcher.set_low_quality_matches_ at 0x7fcf93585e60>)]
[('__init__', <function ImageList.__init__ at 0x7fcf9358a4d0>), ('to', <function ImageList.to at 0x7fcf9358a440>)]
[('__init__', <function Timebase.__init__ at

In [7]:
 
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F


if  USE_TPU:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp



def train(train_loader, GAN_Model, netD, VGG_MODEL, optG, optD, device, losses):
    
    batch = 0
    
    def wgan_loss(prediction, real_or_not):
        if real_or_not:
            return -torch.mean(prediction.float())
        else:
            return torch.mean(prediction.float())

    def gp_loss(y_pred, averaged_samples, gradient_penalty_weight):

        gradients = torch.autograd.grad(y_pred,averaged_samples,
                                  grad_outputs=torch.ones(y_pred.size(), device=device),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = (((gradients+1e-16).norm(2, dim=1) - 1) ** 2).mean() * gradient_penalty_weight
        return gradient_penalty
    for trainL, trainAB, _ in tqdm(iter(train_loader)):
        batch += 1  

        trainL_3 = torch.tensor(np.tile(trainL, [1,3,1,1]), device=device)

        trainL = torch.tensor(trainL, device=device).double()
        trainAB = torch.tensor(trainAB, device=device).double()
        # trainL_3 = trainL_3.to(device).double()

        predictVGG = F.softmax(VGG_MODEL(trainL_3))

        ############ GAN MODEL ( Training Generator) ###################
        optG.zero_grad()
        predAB, classVector, discpred = GAN_Model(trainL, trainL_3)
        D_G_z1 = discpred.mean().item()
        Loss_KLD = nn.KLDivLoss(size_average='False')(classVector.log().float(), predictVGG.detach().float()) * 0.003
        Loss_MSE = nn.MSELoss()(predAB.float(), trainAB.float())
        Loss_WL = wgan_loss(discpred.float(), True) * 0.1 
        Loss_G = Loss_KLD + Loss_MSE + Loss_WL
        Loss_G.backward()

        if  USE_TPU:
            if  MULTI_CORE:
                xm.optimizer_step(optG)
            else:
                xm.optimizer_step(optG, barrier=True)
        else:
            optG.step()

        losses['G_losses'].append(Loss_G.item())
        losses['EPOCH_G_losses'].append(Loss_G.item())



        ################################################################

        ############### Discriminator Training #########################

        for param in netD.parameters():
            param.requires_grad = True

        optD.zero_grad()
        predLAB = torch.cat([trainL, predAB], dim=1)
        discpred = netD(predLAB.detach())
        D_G_z2 = discpred.mean().item()
        realLAB = torch.cat([trainL, trainAB], dim=1)
        discreal = netD(realLAB)
        D_x = discreal.mean().item()

        weights = torch.randn((trainAB.size(0),1,1,1), device=device)          
        averaged_samples = (weights * trainAB ) + ((1 - weights) * predAB.detach())
        averaged_samples = torch.autograd.Variable(averaged_samples, requires_grad=True)
        avg_img = torch.cat([trainL, averaged_samples], dim=1)
        discavg = netD(avg_img)

        Loss_D_Fake = wgan_loss(discpred, False)
        Loss_D_Real = wgan_loss(discreal, True)
        Loss_D_avg = gp_loss(discavg, averaged_samples,  GRADIENT_PENALTY_WEIGHT)

        Loss_D = Loss_D_Fake + Loss_D_Real + Loss_D_avg
        Loss_D.backward()
        if  USE_TPU:
            if  MULTI_CORE:
                xm.optimzer_step(optD)
            else:
                xm.optimizer_step(optD, barrier=True)
        else:
            optD.step()

        losses['D_losses'].append(Loss_D.item())
        losses['EPOCH_D_losses'].append(Loss_D.item())
        # Output training stats
        if batch % 100 == 0:
            print('Loss_D: %.8f | Loss_G: %.8f | D(x): %.8f | D(G(z)): %.8f / %.8f | MSE: %.8f | KLD: %.8f | WGAN_F(G): %.8f | WGAN_F(D): %.8f | WGAN_R(D): %.8f | WGAN_A(D): %.8f'
                % (Loss_D.item(), Loss_G.item(), D_x, D_G_z1, D_G_z2,Loss_MSE.item(),Loss_KLD.item(),Loss_WL.item(), Loss_D_Fake.item(), Loss_D_Real.item(), Loss_D_avg.item()))



In [None]:
 
import time
import torch
import torchvision
import warnings
warnings.filterwarnings('ignore')

import gc

if  USE_TPU:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp


def map_fn(index=None, flags=None):
    global DEVICE
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(1234)

    train_data =  DATA( TRAIN_DIR) 

    if  MULTI_CORE:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_data,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_data)

    train_loader = torch.utils.data.DataLoader(
      train_data,
      batch_size=flags['batch_size'] if  MULTI_CORE else  BATCH_SIZE,
      sampler=train_sampler,
      num_workers=flags['num_workers'] if  MULTI_CORE else 4,
      drop_last=True,
      pin_memory=True)

    if  MULTI_CORE:
        DEVICE = xm.xla_device()
    else:
        DEVICE =  DEVICE


    netG =  colorization_model().double()
    netD =  discriminator_model().double()

    VGG_modelF = torchvision.models.vgg16(pretrained=True).double()
    VGG_modelF.requires_grad_(False)

    netG = netG.to(DEVICE)
    netD = netD.to(DEVICE)

    VGG_modelF = VGG_modelF.to(DEVICE)

    optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
    
    ## Trains
    train_start = time.time()
    losses = {
      'G_losses' : [],
      'D_losses' : [],
      'EPOCH_G_losses' : [],
      'EPOCH_D_losses' : [],
      'G_losses_eval' : []
    }

    netG, optG, netD, optD, epoch_checkpoint =  load_checkpoint( CHECKPOINT_DIR, netG, optG, netD, optD, DEVICE)
    netGAN = GAN(netG, netD)
    for epoch in range(NUM_EPOCHS):#(epoch_checkpoint,flags['num_epochs']+1 if  MULTI_CORE else  NUM_EPOCHS+1):
        print('\n')
        print('#'*8,f'EPOCH-{epoch}','#'*8)
        losses['EPOCH_G_losses'] = []
        losses['EPOCH_D_losses'] = []
        if  MULTI_CORE:
            para_train_loader = pl.ParallelLoader(train_loader, [DEVICE]).per_device_loader(DEVICE)
            train(para_train_loader, netGAN, netD, VGG_modelF, optG, optD, device=DEVICE, losses=losses)
            elapsed_train_time = time.time() - train_start
            print("Process", index, "finished training. Train time was:", elapsed_train_time) 
        else:
            train(train_loader, netGAN, netD, VGG_modelF, optG, optD, device=DEVICE, losses=losses)
        #########################CHECKPOINTING#################################
        create_checkpoint(epoch, netG, optG, netD, optD, max_checkpoint= KEEP_CKPT, save_path =  CHECKPOINT_DIR)
        ########################################################################
        plot_some(train_data, netG, DEVICE, epoch)
        gc.collect()
# Configures training (and evaluation) parameters

def run():
    if  MULTI_CORE:
        flags = {}
        flags['batch_size'] =  BATCH_SIZE
        flags['num_workers'] = 8
        flags['num_epochs'] =  NUM_EPOCHS
        flags['seed'] = 1234
        xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')
    else:
        map_fn()
    # print(flags)
if __name__=='__main__':
    run()

Loading Model and optimizer states from checkpoint....
Loaded States !!!
It looks like the this states belong to epoch 1.
so the model will train for 0 more epochs.


######## EPOCH-0 ########


  3%|▎         | 100/3472 [01:18<17:51,  3.15it/s] 

Loss_D: 0.01443609 | Loss_G: 0.00559527 | D(x): -0.00064835 | D(G(z)): -0.00329137 / -0.00329142 | MSE: 0.00525244 | KLD: 0.00001369 | WGAN_F(G): 0.00032914 | WGAN_F(D): -0.00329142 | WGAN_R(D): 0.00064835 | WGAN_A(D): 0.01707916


  6%|▌         | 200/3472 [01:49<16:47,  3.25it/s]

Loss_D: 0.06543993 | Loss_G: 0.00205643 | D(x): -0.00423222 | D(G(z)): -0.00548476 / -0.00548473 | MSE: 0.00149599 | KLD: 0.00001197 | WGAN_F(G): 0.00054847 | WGAN_F(D): -0.00548473 | WGAN_R(D): 0.00423222 | WGAN_A(D): 0.06669244


  9%|▊         | 300/3472 [02:20<16:31,  3.20it/s]

Loss_D: 0.03462139 | Loss_G: 0.00321000 | D(x): -0.00436684 | D(G(z)): -0.00636351 / -0.00636350 | MSE: 0.00256231 | KLD: 0.00001133 | WGAN_F(G): 0.00063635 | WGAN_F(D): -0.00636350 | WGAN_R(D): 0.00436684 | WGAN_A(D): 0.03661805


 12%|█▏        | 400/3472 [02:51<15:51,  3.23it/s]

Loss_D: 0.03301423 | Loss_G: 0.00297843 | D(x): -0.00701017 | D(G(z)): -0.00893800 / -0.00893778 | MSE: 0.00207334 | KLD: 0.00001132 | WGAN_F(G): 0.00089378 | WGAN_F(D): -0.00893778 | WGAN_R(D): 0.00701017 | WGAN_A(D): 0.03494184


 14%|█▍        | 500/3472 [03:21<15:26,  3.21it/s]

Loss_D: 0.17794327 | Loss_G: 0.00361153 | D(x): -0.00704190 | D(G(z)): -0.00864531 / -0.00864533 | MSE: 0.00273524 | KLD: 0.00001176 | WGAN_F(G): 0.00086453 | WGAN_F(D): -0.00864533 | WGAN_R(D): 0.00704190 | WGAN_A(D): 0.17954671


 17%|█▋        | 600/3472 [03:52<14:49,  3.23it/s]

Loss_D: 0.10536683 | Loss_G: 0.00518783 | D(x): -0.00455053 | D(G(z)): -0.00700978 / -0.00700956 | MSE: 0.00447424 | KLD: 0.00001263 | WGAN_F(G): 0.00070096 | WGAN_F(D): -0.00700956 | WGAN_R(D): 0.00455053 | WGAN_A(D): 0.10782586


 20%|██        | 700/3472 [04:23<14:16,  3.24it/s]

Loss_D: 0.00326395 | Loss_G: 0.00309915 | D(x): -0.00871344 | D(G(z)): -0.01051914 / -0.01051904 | MSE: 0.00203537 | KLD: 0.00001188 | WGAN_F(G): 0.00105190 | WGAN_F(D): -0.01051904 | WGAN_R(D): 0.00871344 | WGAN_A(D): 0.00506955


 23%|██▎       | 800/3472 [04:54<13:45,  3.24it/s]

Loss_D: 0.02265310 | Loss_G: 0.00500074 | D(x): -0.00429392 | D(G(z)): -0.00740372 / -0.00740364 | MSE: 0.00424714 | KLD: 0.00001324 | WGAN_F(G): 0.00074036 | WGAN_F(D): -0.00740364 | WGAN_R(D): 0.00429392 | WGAN_A(D): 0.02576282


 26%|██▌       | 900/3472 [05:27<13:26,  3.19it/s]

Loss_D: 0.01740765 | Loss_G: 0.00328792 | D(x): 0.00267292 | D(G(z)): 0.00017949 / 0.00017941 | MSE: 0.00329465 | KLD: 0.00001121 | WGAN_F(G): -0.00001794 | WGAN_F(D): 0.00017941 | WGAN_R(D): -0.00267292 | WGAN_A(D): 0.01990117


 29%|██▉       | 1000/3472 [05:58<12:42,  3.24it/s]

Loss_D: 0.01829929 | Loss_G: 0.00572035 | D(x): 0.00000378 | D(G(z)): -0.00330666 / -0.00330668 | MSE: 0.00537677 | KLD: 0.00001291 | WGAN_F(G): 0.00033067 | WGAN_F(D): -0.00330668 | WGAN_R(D): -0.00000378 | WGAN_A(D): 0.02160975


 32%|███▏      | 1100/3472 [06:29<12:13,  3.23it/s]

Loss_D: 0.02885067 | Loss_G: 0.00136784 | D(x): -0.00035925 | D(G(z)): -0.00172110 / -0.00172101 | MSE: 0.00118627 | KLD: 0.00000947 | WGAN_F(G): 0.00017210 | WGAN_F(D): -0.00172101 | WGAN_R(D): 0.00035925 | WGAN_A(D): 0.03021244


 35%|███▍      | 1200/3472 [07:00<11:52,  3.19it/s]

Loss_D: 0.01186235 | Loss_G: 0.00519455 | D(x): 0.00048600 | D(G(z)): -0.00245960 / -0.00245937 | MSE: 0.00493589 | KLD: 0.00001272 | WGAN_F(G): 0.00024594 | WGAN_F(D): -0.00245937 | WGAN_R(D): -0.00048600 | WGAN_A(D): 0.01480772


 37%|███▋      | 1300/3472 [07:32<11:15,  3.22it/s]

Loss_D: 0.00703790 | Loss_G: 0.00193879 | D(x): 0.00140477 | D(G(z)): -0.00010486 / -0.00010485 | MSE: 0.00191614 | KLD: 0.00001216 | WGAN_F(G): 0.00001049 | WGAN_F(D): -0.00010485 | WGAN_R(D): -0.00140477 | WGAN_A(D): 0.00854752


 40%|████      | 1400/3472 [08:02<10:45,  3.21it/s]

Loss_D: 0.04071800 | Loss_G: 0.00430467 | D(x): 0.00517391 | D(G(z)): 0.00188279 / 0.00188272 | MSE: 0.00448030 | KLD: 0.00001264 | WGAN_F(G): -0.00018827 | WGAN_F(D): 0.00188272 | WGAN_R(D): -0.00517391 | WGAN_A(D): 0.04400920


 43%|████▎     | 1500/3472 [08:34<10:14,  3.21it/s]

Loss_D: 0.03734224 | Loss_G: 0.00267973 | D(x): 0.00640244 | D(G(z)): 0.00382764 / 0.00382766 | MSE: 0.00304860 | KLD: 0.00001389 | WGAN_F(G): -0.00038277 | WGAN_F(D): 0.00382766 | WGAN_R(D): -0.00640244 | WGAN_A(D): 0.03991703


 46%|████▌     | 1600/3472 [09:04<09:44,  3.20it/s]

Loss_D: 0.00637902 | Loss_G: 0.00140540 | D(x): 0.00798287 | D(G(z)): 0.00645625 / 0.00645616 | MSE: 0.00203689 | KLD: 0.00001413 | WGAN_F(G): -0.00064562 | WGAN_F(D): 0.00645616 | WGAN_R(D): -0.00798287 | WGAN_A(D): 0.00790573


 49%|████▉     | 1700/3472 [09:36<09:16,  3.18it/s]

Loss_D: 0.01748156 | Loss_G: 0.00333506 | D(x): 0.00845416 | D(G(z)): 0.00515166 / 0.00515174 | MSE: 0.00383690 | KLD: 0.00001333 | WGAN_F(G): -0.00051517 | WGAN_F(D): 0.00515174 | WGAN_R(D): -0.00845416 | WGAN_A(D): 0.02078398


 52%|█████▏    | 1800/3472 [10:07<08:46,  3.18it/s]

Loss_D: 0.02822577 | Loss_G: 0.00072233 | D(x): 0.00739532 | D(G(z)): 0.00603169 / 0.00603163 | MSE: 0.00131363 | KLD: 0.00001187 | WGAN_F(G): -0.00060316 | WGAN_F(D): 0.00603163 | WGAN_R(D): -0.00739532 | WGAN_A(D): 0.02958946


 55%|█████▍    | 1900/3472 [10:38<08:15,  3.17it/s]

Loss_D: 0.00624681 | Loss_G: 0.00190393 | D(x): 0.00960845 | D(G(z)): 0.00698490 / 0.00698503 | MSE: 0.00259025 | KLD: 0.00001219 | WGAN_F(G): -0.00069850 | WGAN_F(D): 0.00698503 | WGAN_R(D): -0.00960845 | WGAN_A(D): 0.00887023


 58%|█████▊    | 2000/3472 [11:09<07:43,  3.17it/s]

Loss_D: 0.01904126 | Loss_G: 0.00144547 | D(x): 0.00591287 | D(G(z)): 0.00384131 / 0.00384128 | MSE: 0.00181679 | KLD: 0.00001281 | WGAN_F(G): -0.00038413 | WGAN_F(D): 0.00384128 | WGAN_R(D): -0.00591287 | WGAN_A(D): 0.02111285


 60%|██████    | 2100/3472 [11:40<07:09,  3.20it/s]

Loss_D: 0.02405037 | Loss_G: 0.00376421 | D(x): 0.00930608 | D(G(z)): 0.00564807 / 0.00564809 | MSE: 0.00431727 | KLD: 0.00001176 | WGAN_F(G): -0.00056481 | WGAN_F(D): 0.00564809 | WGAN_R(D): -0.00930608 | WGAN_A(D): 0.02770836


 63%|██████▎   | 2200/3472 [12:11<06:36,  3.21it/s]

Loss_D: 0.01523497 | Loss_G: 0.00286786 | D(x): 0.00904536 | D(G(z)): 0.00613711 / 0.00613714 | MSE: 0.00346908 | KLD: 0.00001250 | WGAN_F(G): -0.00061371 | WGAN_F(D): 0.00613714 | WGAN_R(D): -0.00904536 | WGAN_A(D): 0.01814318


 66%|██████▌   | 2300/3472 [12:42<06:07,  3.19it/s]

Loss_D: 0.02549177 | Loss_G: 0.00247224 | D(x): 0.00618219 | D(G(z)): 0.00382088 / 0.00382089 | MSE: 0.00284206 | KLD: 0.00001227 | WGAN_F(G): -0.00038209 | WGAN_F(D): 0.00382089 | WGAN_R(D): -0.00618219 | WGAN_A(D): 0.02785307


 69%|██████▉   | 2400/3472 [13:13<05:36,  3.19it/s]

Loss_D: 0.01196991 | Loss_G: 0.00417518 | D(x): 0.00424047 | D(G(z)): 0.00093941 / 0.00093943 | MSE: 0.00425637 | KLD: 0.00001275 | WGAN_F(G): -0.00009394 | WGAN_F(D): 0.00093943 | WGAN_R(D): -0.00424047 | WGAN_A(D): 0.01527096


 72%|███████▏  | 2500/3472 [13:44<05:05,  3.18it/s]

Loss_D: 0.00680956 | Loss_G: 0.00489796 | D(x): 0.00543607 | D(G(z)): 0.00169227 / 0.00169227 | MSE: 0.00505540 | KLD: 0.00001179 | WGAN_F(G): -0.00016923 | WGAN_F(D): 0.00169227 | WGAN_R(D): -0.00543607 | WGAN_A(D): 0.01055337


 75%|███████▍  | 2600/3472 [14:16<04:30,  3.22it/s]

Loss_D: 0.00359157 | Loss_G: 0.00408825 | D(x): 0.00775151 | D(G(z)): 0.00451842 / 0.00451832 | MSE: 0.00452956 | KLD: 0.00001053 | WGAN_F(G): -0.00045183 | WGAN_F(D): 0.00451832 | WGAN_R(D): -0.00775151 | WGAN_A(D): 0.00682475


 78%|███████▊  | 2700/3472 [14:47<04:00,  3.22it/s]

Loss_D: -0.00037056 | Loss_G: 0.00323599 | D(x): 0.00588384 | D(G(z)): 0.00338734 / 0.00338740 | MSE: 0.00356417 | KLD: 0.00001056 | WGAN_F(G): -0.00033874 | WGAN_F(D): 0.00338740 | WGAN_R(D): -0.00588384 | WGAN_A(D): 0.00212589


 81%|████████  | 2800/3472 [15:18<03:32,  3.16it/s]

Loss_D: 0.00961222 | Loss_G: 0.00238491 | D(x): 0.00547094 | D(G(z)): 0.00330648 / 0.00330663 | MSE: 0.00270595 | KLD: 0.00000962 | WGAN_F(G): -0.00033066 | WGAN_F(D): 0.00330663 | WGAN_R(D): -0.00547094 | WGAN_A(D): 0.01177654


 84%|████████▎ | 2900/3472 [15:50<03:02,  3.14it/s]

Loss_D: -0.00023033 | Loss_G: 0.00505923 | D(x): 0.01161445 | D(G(z)): 0.00813605 / 0.00813581 | MSE: 0.00586256 | KLD: 0.00001026 | WGAN_F(G): -0.00081358 | WGAN_F(D): 0.00813581 | WGAN_R(D): -0.01161445 | WGAN_A(D): 0.00324831


 86%|████████▋ | 3000/3472 [16:21<02:28,  3.17it/s]

Loss_D: -0.00066302 | Loss_G: 0.00808441 | D(x): 0.01204837 | D(G(z)): 0.00746851 / 0.00746837 | MSE: 0.00882051 | KLD: 0.00001073 | WGAN_F(G): -0.00074684 | WGAN_F(D): 0.00746837 | WGAN_R(D): -0.01204837 | WGAN_A(D): 0.00391698


 89%|████████▉ | 3100/3472 [16:52<01:55,  3.21it/s]

Loss_D: 0.02601753 | Loss_G: 0.00161393 | D(x): 0.01119963 | D(G(z)): 0.00843883 / 0.00843886 | MSE: 0.00244419 | KLD: 0.00001363 | WGAN_F(G): -0.00084389 | WGAN_F(D): 0.00843886 | WGAN_R(D): -0.01119963 | WGAN_A(D): 0.02877829


 92%|█████████▏| 3200/3472 [17:23<01:25,  3.19it/s]

Loss_D: 0.02180595 | Loss_G: 0.00148547 | D(x): 0.01028687 | D(G(z)): 0.00806926 / 0.00806929 | MSE: 0.00227811 | KLD: 0.00001429 | WGAN_F(G): -0.00080693 | WGAN_F(D): 0.00806929 | WGAN_R(D): -0.01028687 | WGAN_A(D): 0.02402353


 95%|█████████▌| 3300/3472 [17:54<00:53,  3.21it/s]

Loss_D: 0.00391066 | Loss_G: 0.00196430 | D(x): 0.01150701 | D(G(z)): 0.00933607 / 0.00933610 | MSE: 0.00288512 | KLD: 0.00001279 | WGAN_F(G): -0.00093361 | WGAN_F(D): 0.00933610 | WGAN_R(D): -0.01150701 | WGAN_A(D): 0.00608157


 98%|█████████▊| 3400/3472 [18:25<00:23,  3.13it/s]

Loss_D: 0.00819048 | Loss_G: 0.00147069 | D(x): 0.01349594 | D(G(z)): 0.01083410 / 0.01083406 | MSE: 0.00254234 | KLD: 0.00001175 | WGAN_F(G): -0.00108341 | WGAN_F(D): 0.01083406 | WGAN_R(D): -0.01349594 | WGAN_A(D): 0.01085236


100%|██████████| 3472/3472 [18:48<00:00,  3.08it/s]


Saving Model and Optimizer weights.....


In [4]:
!ls ./checkpoint



__notebook_source__.ipynb
