# Deep Colorization
### Deep learning final project for conversion of gray scale images to rgb
### Contributors: Bhumi Bhanushali, Avinash Hemaeshwara Raju, Kathan Nilesh Mehta, Atulay Ravishankar

### Download Dataset

In [0]:
# !wget -N images.cocodataset.org/zips/train2017.zip
# !wget -N images.cocodataset.org/zips/val2017.zip
# !wget -N images.cocodataset.org/zips/test2017.zip
# !pip install tensorboard
# !tensorboard --logdir=runs

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Import Modules

In [0]:
import os
import time
import numpy as np 
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import cv2
from torchvision.utils import save_image

### Configuration

In [0]:
class Configuration:
    model_file_name = 'checkpoint.pt'
    load_model_to_train = True
    load_model_to_test = False
    device = "cuda" if torch.cuda.is_available() else "cpu"
    point_batches = 500

### Hyper Parameters

In [0]:
class HyperParameters:
    epochs = 50
    batch_size_train = 32
    batch_size_val = 16
    learning_rate = 0.001
    num_workers = 16
    learning_rate_decay = 0.5

In [0]:
config = Configuration()
hparams = HyperParameters()
print('Device:',config.device)

Device: cuda


### Custom Dataloader

In [0]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, process_type):
        self.root_dir = root_dir
        self.files = [f for f in os.listdir(root_dir)]
        self.process_type = process_type
        print('File[0]:',self.files[0],'| Total Files:', len(self.files), '| Process:',self.process_type,)

    def __len__(self):
         return len(self.files)

    def __getitem__(self, index):
        try:
            #*** Read the image from file ***
            self.rgb_img = cv2.imread(os.path.join(self.root_dir,self.files[index])).astype(np.float32) 
            self.rgb_img /= 255.0 
            
            #*** Resize the color image to pass to encoder ***
            rgb_encoder_img = cv2.resize(self.rgb_img, (224, 224))
            
            #*** Resize the color image to pass to decoder ***
            rgb_inception_img = cv2.resize(self.rgb_img, (300, 300))
            
            ''' Encoder Images '''
            #*** Convert the encoder color image to normalized lab space ***
            self.lab_encoder_img = cv2.cvtColor(rgb_encoder_img,cv2.COLOR_BGR2Lab) 
            
            #*** Splitting the lab images into l-channel, a-channel, b-channel ***
            l_encoder_img, a_encoder_img, b_encoder_img = self.lab_encoder_img[:,:,0],self.lab_encoder_img[:,:,1],self.lab_encoder_img[:,:,2]
            
            #*** Normalizing l-channel between [-1,1] ***
            l_encoder_img = l_encoder_img/50.0 - 1.0
            
            #*** Repeat the l-channel to 3 dimensions ***
            l_encoder_img = torchvision.transforms.ToTensor()(l_encoder_img)
            l_encoder_img = l_encoder_img.expand(3,-1,-1)
            
            #*** Normalize a and b channels and concatenate ***
            a_encoder_img = (a_encoder_img/128.0)
            b_encoder_img = (b_encoder_img/128.0)
            a_encoder_img = torch.stack([torch.Tensor(a_encoder_img)])
            b_encoder_img = torch.stack([torch.Tensor(b_encoder_img)])
            ab_encoder_img = torch.cat([a_encoder_img, b_encoder_img], dim=0)
            
            ''' Inception Images '''
            #*** Convert the inception color image to lab space ***
            self.lab_inception_img = cv2.cvtColor(rgb_inception_img,cv2.COLOR_BGR2Lab)
            
            #*** Extract the l-channel of inception lab image *** 
            l_inception_img = self.lab_inception_img[:,:,0]/50.0 - 1.0
             
            #*** Convert the inception l-image to torch Tensor and stack it in 3 channels ***
            l_inception_img = torchvision.transforms.ToTensor()(l_inception_img)
            l_inception_img = l_inception_img.expand(3,-1,-1)
            
            ''' return images to data-loader '''
            rgb_encoder_img = torchvision.transforms.ToTensor()(rgb_encoder_img)
            return l_encoder_img, ab_encoder_img, l_inception_img, rgb_encoder_img, self.files[index]
        
        except Exception as e:
            print('Exception at ',self.files[index], e)
            return torch.tensor(-1), torch.tensor(-1), torch.tensor(-1), torch.tensor(-1), 'Error'

    def show_rgb(self, index):
        self.__getitem__(index)
        print("RGB image size:", self.rgb_img.shape)        
        cv2.imshow(self.rgb_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def show_lab_encoder(self, index):
        self.__getitem__(index)
        print("Encoder Lab image size:", self.lab_encoder_img.shape)
        cv2.imshow(self.lab_encoder_img)
        c2.waitKey(0)
        cv2.destroyAllWindows()

    def show_lab_inception(self, index):
        self.__getitem__(index)
        print("Inception Lab image size:", self.lab_inception_img.shape)
        cv2.imshow(self.lab_inception_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    
    def show_other_images(self, index):
        a,b,c,d,_ = self.__getitem__(index)
        print("Encoder l channel image size:",a.shape)
        cv2.imshow((a.detach().numpy().transpose(1,2,0)))
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        print("Encoder ab channel image size:",b.shape)
        cv2.imshow((b.detach().numpy().transpose(1,2,0)[:,:,0]))
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        cv2.imshow((b.detach().numpy().transpose(1,2,0)[:,:,1]))
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        print("Inception l channel image size:",c.shape)
        cv2.imshow(c.detach().numpy().transpose(1,2,0))
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        print("Color resized image size:",d.shape)
        cv2.imshow(d.detach().numpy().transpose(1,2,0))
        cv2.waitKey(0)
        cv2.destroyAllWindows()

In [0]:
!unzip "/content/drive/My Drive/My ImageNet/train.zip" 

Archive:  /content/drive/My Drive/My ImageNet/train.zip
   creating: train/
  inflating: train/n01491361_4551.JPEG  
 extracting: train/n01491361_4555.JPEG  
  inflating: train/n01491361_4572.JPEG  
  inflating: train/n01491361_458.JPEG  
  inflating: train/n01491361_4586.JPEG  
  inflating: train/n01491361_4599.JPEG  
  inflating: train/n01491361_4615.JPEG  
  inflating: train/n01491361_462.JPEG  
  inflating: train/n01491361_4629.JPEG  
  inflating: train/n01491361_463.JPEG  
  inflating: train/n01491361_4634.JPEG  
  inflating: train/n01491361_4646.JPEG  
  inflating: train/n01491361_4658.JPEG  
  inflating: train/n01491361_4677.JPEG  
  inflating: train/n01491361_468.JPEG  
  inflating: train/n01491361_47.JPEG  
  inflating: train/n01491361_470.JPEG  
  inflating: train/n01491361_4701.JPEG  
  inflating: train/n01491361_4703.JPEG  
  inflating: train/n01491361_4704.JPEG  
 extracting: train/n01491361_4707.JPEG  
  inflating: train/n01491361_4716.JPEG  
  inflating: train/n01491361_

In [0]:
!unzip "/content/drive/My Drive/My ImageNet/val.zip"

Archive:  /content/drive/My Drive/My ImageNet/val.zip
   creating: val/
  inflating: val/ILSVRC2012_val_00000001.JPEG  
  inflating: val/ILSVRC2012_val_00000002.JPEG  
  inflating: val/ILSVRC2012_val_00000003.JPEG  
  inflating: val/ILSVRC2012_val_00000004.JPEG  
  inflating: val/ILSVRC2012_val_00000005.JPEG  
  inflating: val/ILSVRC2012_val_00000006.JPEG  
 extracting: val/ILSVRC2012_val_00000007.JPEG  
  inflating: val/ILSVRC2012_val_00000008.JPEG  
  inflating: val/ILSVRC2012_val_00000009.JPEG  
 extracting: val/ILSVRC2012_val_00000010.JPEG  
  inflating: val/ILSVRC2012_val_00000011.JPEG  
 extracting: val/ILSVRC2012_val_00000012.JPEG  
  inflating: val/ILSVRC2012_val_00000013.JPEG  
  inflating: val/ILSVRC2012_val_00000014.JPEG  
  inflating: val/ILSVRC2012_val_00000015.JPEG  
  inflating: val/ILSVRC2012_val_00000016.JPEG  
 extracting: val/ILSVRC2012_val_00000017.JPEG  
  inflating: val/ILSVRC2012_val_00000018.JPEG  
  inflating: val/ILSVRC2012_val_00000019.JPEG  
  inflating: val

In [0]:
# train_dataset.show_rgb(0)
# train_dataset.show_lab_encoder(0)
# train_dataset.show_lab_inception(0)
# train_dataset.show_other_images(0)

### Encoder

In [0]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1), 
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
            
        )

    def forward(self, x):
        self.model = self.model.float()
        return self.model(x.float())

### Fusion Layer

In [0]:
class FusionLayer(nn.Module):
    def __init__(self):
        super(FusionLayer,self).__init__()

    def forward(self, inputs, mask=None):
        ip, emb = inputs
        emb = torch.stack([torch.stack([emb],dim=2)],dim=3)
        emb = emb.repeat(1,1,ip.shape[2],ip.shape[3])
        fusion = torch.cat((ip,emb),1)
        return fusion

### Decoder

In [0]:
class Decoder(nn.Module):
    def __init__(self, input_depth):
        super(Decoder,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=input_depth, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0),


            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0),


            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2.0),

            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1, stride=1, padding=0),     
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

### Network Definition

In [0]:
class Colorization(nn.Module):
    def __init__(self, depth_after_fusion):
        super(Colorization,self).__init__()
        self.encoder = Encoder()
        self.fusion = FusionLayer()
        self.after_fusion = nn.Conv2d(in_channels=1256, out_channels=depth_after_fusion,kernel_size=1, stride=1,padding=0)
        self.bnorm = nn.BatchNorm2d(256)
        self.decoder = Decoder(depth_after_fusion)

    def forward(self, img_l, img_emb):
        img_enc = self.encoder(img_l)
        fusion = self.fusion([img_enc, img_emb])
        fusion = self.after_fusion(fusion)
        fusion = self.bnorm(fusion)
        return self.decoder(fusion)

def init_weights(m):
    if type(m) == nn.Conv2d or type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight.data)

### Architecture Pipeline

In [0]:
if config.load_model_to_train or config.load_model_to_test:
    checkpoint = torch.load("/content/drive/My Drive/My ImageNet/Models/checkpoint6.pt",map_location=torch.device(config.device))
    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(config.device) 
    optimizer = checkpoint['optimizer']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded pretrain model | Previous train loss:',checkpoint['train_loss'])
else:
    model = Colorization(256).to(config.device) 
    model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(),lr=hparams.learning_rate, weight_decay=1e-6)

inception_model = models.inception_v3(pretrained=True).float().to(config.device)
inception_model = inception_model.float()
inception_model.eval()
loss_criterion = torch.nn.MSELoss(reduction='mean').to(config.device)
milestone_list  = list(range(0,hparams.epochs,2))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)
writer = SummaryWriter()

Loaded pretrain model | Previous train loss: 0.38778666923015787 | Previous validation loss:


Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/checkpoints/inception_v3_google-1a9a5a14.pth
100%|██████████| 104M/104M [00:06<00:00, 16.0MB/s]


In [0]:
print(model)

Colorization(
  (encoder): Encoder(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affi

### Data Loaders

In [0]:
if not config.load_model_to_test:
    train_dataset = CustomDataset('/content/train','train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size_train, shuffle=True, num_workers=hparams.num_workers)
    

    validataion_dataset = CustomDataset('/content/val','validation')
    validation_dataloader = torch.utils.data.DataLoader(validataion_dataset, batch_size=hparams.batch_size_val, shuffle=False, num_workers=hparams.num_workers)

File[0]: n02927161_12575.JPEG | Total Files: 65400 | Process: train
File[0]: ILSVRC2012_val_00002199.JPEG | Total Files: 5000 | Process: validation


In [0]:
if not config.load_model_to_test:
    print('Train:',len(train_dataloader), '| Total Images:',len(train_dataloader)*hparams.batch_size_train)
    print('Valid:',len(validation_dataloader), '| Total Images:',len(validation_dataloader)*hparams.batch_size_val)

Train: 2044 | Total Images: 65408
Valid: 313 | Total Images: 5008


### Training & Validation Pipeline

In [0]:
if not config.load_model_to_test:
    for epoch in range(20,hparams.epochs):
        print('Starting epoch:',epoch+1)

        #*** Training step ***
        loop_start = time.time()
        avg_loss = 0.0
        batch_loss = 0.0
        main_start = time.time()
        model.train()

        for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb, file_name) in enumerate(train_dataloader):
            #*** Skip bad data ***
            if not img_l_encoder.ndim:
                continue

            #*** Move data to GPU if available ***
            img_l_encoder = img_l_encoder.to(config.device)
            img_ab_encoder = img_ab_encoder.to(config.device)
            img_l_inception = img_l_inception.to(config.device)

            #*** Initialize Optimizer ***
            optimizer.zero_grad()

            #*** Forward Propagation ***
            img_embs = inception_model(img_l_inception.float())
            output_ab = model(img_l_encoder,img_embs)

            #*** Back propogation ***
            loss = loss_criterion(output_ab, img_ab_encoder.float())
            loss.backward()

            #*** Weight Update ****
            optimizer.step()

            #*** Loss Calculation ***
            avg_loss += loss.item()
            batch_loss += loss.item()

            #*** Print stats after every point_batches ***
            if idx%config.point_batches==0: 
                loop_end = time.time()   
                print('Batch:',idx, '| Processing time for',config.point_batches,':',loop_end-loop_start,'s | Batch Loss:', (batch_loss/config.point_batches)*100)
                loop_start = time.time()
                batch_loss = 0.0

        #*** Print Training Data Stats ***
        train_loss = avg_loss/len(train_dataloader)*hparams.batch_size_train
        writer.add_scalar('Loss/train', train_loss, epoch)
        print('Training Loss:',train_loss,'| Processed in ',time.time()-main_start,'s')

        #*** Save the Model to disk ***
        model_file_name = "/content/drive/My Drive/My ImageNet/Models/checkpoint"+str(epoch)+".pt"
        checkpoint = {'model': model,'model_state_dict': model.state_dict(), 'optimizer' : optimizer,'optimizer_state_dict' : optimizer.state_dict(), 'train_loss':train_loss}
        torch.save(checkpoint, model_file_name)
        print("Model saved at:", model_file_name)

        #*** Validation Step ***       
        avg_loss = 0.0
        loop_start = time.time()
        #*** Intialize Model to Eval Mode for validation ***
        model.eval()
        for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb, file_name) in enumerate(validation_dataloader):
            #*** Skip bad data ***
            if not img_l_encoder.ndim:
                continue

            #*** Move data to GPU if available ***
            img_l_encoder = img_l_encoder.to(config.device)
            img_ab_encoder = img_ab_encoder.to(config.device)
            img_l_inception = img_l_inception.to(config.device)

            #*** Forward Propagation ***
            img_embs = inception_model(img_l_inception.float())
            output_ab = model(img_l_encoder,img_embs)

            #*** Loss Calculation ***
            loss = loss_criterion(output_ab, img_ab_encoder.float())
            avg_loss += loss.item()

        val_loss = avg_loss/len(validation_dataloader)*hparams.batch_size_val
        writer.add_scalar('Loss/validation', val_loss, epoch)
        print('Validation Loss:', val_loss,'| Processed in ',time.time()-loop_start,'s')

        #*** Reduce Learning Rate ***
        scheduler.step(val_loss)

        

Starting epoch: 21
Batch: 0 | Processing time for 500 : 5.968760967254639 s | Batch Loss: 0.002352140471339226
Batch: 500 | Processing time for 500 : 433.007691860199 s | Batch Loss: 1.081930439174175
Batch: 1000 | Processing time for 500 : 431.5642263889313 s | Batch Loss: 1.0676004709675908
Batch: 1500 | Processing time for 500 : 432.11999130249023 s | Batch Loss: 1.0768110217526554
Batch: 2000 | Processing time for 500 : 431.66616773605347 s | Batch Loss: 1.077400539536029
Training Loss: 0.3441285025437751 | Processed in  1771.7097947597504 s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Model saved at: /content/drive/My Drive/My ImageNet/Models/checkpoint20.pt
Validation Loss: 0.19106617960305258 | Processed in  59.02648138999939 s
Starting epoch: 22
Batch: 0 | Processing time for 500 : 7.690591335296631 s | Batch Loss: 0.002081815153360367
Batch: 500 | Processing time for 500 : 443.29405879974365 s | Batch Loss: 1.0774534903466702
Batch: 1000 | Processing time for 500 : 442.49887585639954 s | Batch Loss: 1.082331484090537
Batch: 1500 | Processing time for 500 : 442.6472897529602 s | Batch Loss: 1.070736100897193
Batch: 2000 | Processing time for 500 : 442.5085611343384 s | Batch Loss: 1.069081163033843
Training Loss: 0.3440801344972301 | Processed in  1816.9682426452637 s
Model saved at: /content/drive/My Drive/My ImageNet/Models/checkpoint21.pt
Validation Loss: 0.191645985998856 | Processed in  59.20410656929016 s
Starting epoch: 23
Batch: 0 | Processing time for 500 : 6.0818867683410645 s | Batch Loss: 0.0015373071655631065
Batch: 500 | Processing time for 500 : 44

### Inference

In [0]:
!unzip "/content/drive/My Drive/My ImageNet/test.zip"

Archive:  /content/drive/My Drive/My ImageNet/test.zip
   creating: test/
  inflating: test/ILSVRC2012_test_00000001.JPEG  
  inflating: test/ILSVRC2012_test_00000002.JPEG  
  inflating: test/ILSVRC2012_test_00000003.JPEG  
  inflating: test/ILSVRC2012_test_00000004.JPEG  
  inflating: test/ILSVRC2012_test_00000005.JPEG  
 extracting: test/ILSVRC2012_test_00000006.JPEG  
  inflating: test/ILSVRC2012_test_00000007.JPEG  
  inflating: test/ILSVRC2012_test_00000008.JPEG  
  inflating: test/ILSVRC2012_test_00000009.JPEG  
  inflating: test/ILSVRC2012_test_00000010.JPEG  
 extracting: test/ILSVRC2012_test_00000011.JPEG  
 extracting: test/ILSVRC2012_test_00000012.JPEG  
  inflating: test/ILSVRC2012_test_00000013.JPEG  
 extracting: test/ILSVRC2012_test_00000014.JPEG  
  inflating: test/ILSVRC2012_test_00000015.JPEG  
 extracting: test/ILSVRC2012_test_00000016.JPEG  
  inflating: test/ILSVRC2012_test_00000017.JPEG  
 extracting: test/ILSVRC2012_test_00000018.JPEG  
 extracting: test/ILSVRC20

In [0]:
test_dataset = CustomDataset('/content/test','test')
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
print('Test: ',len(test_dataloader), '| Total Image:',len(test_dataloader))

File[0]: ILSVRC2012_test_00004379.JPEG | Total Files: 7000 | Process: test
Test:  7000 | Total Image: 7000


##### Convert Tensor Image -> Numpy Image -> Color  Image -> Tensor Image

In [0]:
def concatente_and_colorize(im_lab, img_ab):
    # Assumption is that im_lab is of size [1,3,224,224]
    #print(im_lab.size(),img_ab.size())
    np_img = im_lab[0].cpu().detach().numpy().transpose(1,2,0)
    lab = np.empty([*np_img.shape[0:2], 3],dtype=np.float32)
    lab[:, :, 0] = np.squeeze(((np_img + 1) * 50))
    lab[:, :, 1:] = img_ab[0].cpu().detach().numpy().transpose(1,2,0) * 127
    np_img = cv2.cvtColor(lab,cv2.COLOR_Lab2RGB) 
    color_im = torch.stack([torchvision.transforms.ToTensor()(np_img)],dim=0)
    return color_im

In [0]:
#*** Inference Step ***
avg_loss = 0.0
loop_start = time.time()
batch_start = time.time()
batch_loss = 0.0

for idx,(img_l_encoder, img_ab_encoder, img_l_inception, img_rgb, file_name) in enumerate(test_dataloader):
        #*** Skip bad data ***
        if not img_l_encoder.ndim:
            continue
            
        #*** Move data to GPU if available ***
        img_l_encoder = img_l_encoder.to(config.device)
        img_ab_encoder = img_ab_encoder.to(config.device)
        img_l_inception = img_l_inception.to(config.device)
        
        #*** Intialize Model to Eval Mode ***
        model.eval()
        
        #*** Forward Propagation ***
        img_embs = inception_model(img_l_inception.float())
        output_ab = model(img_l_encoder,img_embs)
        
        #*** Adding l channel to ab channels ***
        color_img = concatente_and_colorize(torch.stack([img_l_encoder[:,0,:,:]],dim=1),output_ab)
        color_img_jpg = color_img[0].detach().numpy().transpose(1,2,0)
        # cv2.imshow(color_img_jpg)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
        # cv2.imwrite('outputs/'+file_name[0],color_img_jpg*255)
        save_image(color_img[0],'/content/drive/My Drive/My ImageNet/Outputs/'+file_name[0])

#       #*** Printing to Tensor Board ***
        grid = torchvision.utils.make_grid(color_img)
        writer.add_image('Output Lab Images', grid, 0)
        
        #*** Loss Calculation ***
        loss = loss_criterion(output_ab, img_ab_encoder.float())
        avg_loss += loss.item()
        batch_loss += loss.item()

        if idx%config.point_batches==0: 
            batch_end = time.time()   
            print('Batch:',idx, '| Processing time for',config.point_batches,':',str(batch_end-batch_start)+'s', '| Batch Loss:', batch_loss/config.point_batches)
            batch_start = time.time()
            batch_loss = 0.0
        
test_loss = avg_loss/len(test_dataloader)
print('Test Loss:',avg_loss/len(test_dataloader),'| Processed in ',str(time.time()-loop_start)+'s')
writer.close() 


Batch: 0 | Processing time for 500 : 0.09049606323242188s | Batch Loss: 3.226979821920395e-05
Batch: 500 | Processing time for 500 : 41.32481288909912s | Batch Loss: 0.011836117892235052
Batch: 1000 | Processing time for 500 : 42.09878492355347s | Batch Loss: 0.011495277063688263
Batch: 1500 | Processing time for 500 : 40.68585395812988s | Batch Loss: 0.012319104019930818
Batch: 2000 | Processing time for 500 : 41.07590675354004s | Batch Loss: 0.011777173503127415
Batch: 2500 | Processing time for 500 : 41.33838677406311s | Batch Loss: 0.012586484990002646
Batch: 3000 | Processing time for 500 : 40.94072437286377s | Batch Loss: 0.011471823458210566
Batch: 3500 | Processing time for 500 : 41.4523561000824s | Batch Loss: 0.012225633652691612
Batch: 4000 | Processing time for 500 : 41.432453870773315s | Batch Loss: 0.011042956133198459
Batch: 4500 | Processing time for 500 : 41.63643002510071s | Batch Loss: 0.012260420368285849
Batch: 5000 | Processing time for 500 : 41.337135314941406s |

In [0]:
writer.close()