In [3]:
#mount drive
from google.colab import drive
drive.mount('/content/drive')
!ls

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
drive  sample_data


In [4]:
# move into project directory
repo_name = "Image-Colorization"
%cd /content/drive/MyDrive/Personal-Projects/$repo_name
!ls

/content/drive/MyDrive/Personal-Projects/Image-Colorization
common	     datautils	  Index.ipynb  output			 README.md
config.yaml  experiments  index.py     preprocess_imagenette.py  requirements.txt
data	     Index_bc.py  models       project-structure.md	 run.yaml


In [5]:
# set up environment
# Set INSTALL_DEPS = True on a fresh runtime that is missing the dependencies.
# A flag (rather than a bare string literal) keeps this cell from echoing the
# install commands as output when installation is skipped.
INSTALL_DEPS = False
if INSTALL_DEPS:
    # run_line_magic('pip', ...) is the call form of `%pip install ...`, which installs
    # into the running kernel's environment (unlike `!pip`).
    get_ipython().run_line_magic('pip', 'install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118')
    get_ipython().run_line_magic('pip', 'install matplotlib numpy pandas pyyaml opencv-python')

'\n!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n!pip install matplotlib numpy pandas pyyaml opencv-python\n'

In [6]:
# This cell extracts the dataset archives.
# The data is not yet publicly hosted; the archives live in the private data/ folder.
#!tar -xvzf data/imagenette2-320.tgz
#!unzip -qq data/imagenette2-320-processed.zip -d data/


In [7]:
# setup some imports
#custom imports
from common.transforms import ToTensor
from datautils.datareader import DataReader
from datautils.datareader import ImagenetteReader
from datautils.dataset import CustomDataset
from datautils.dataset import CustomImagenetDataset
from common.utils import get_exp_params, init_config, get_config, save2config, get_saved_model, get_modelinfo, get_model_data
from models.unet import UNet
from models.conv_net import ConvNet
from models.custom_models import get_model

#py imports
import random
import numpy as np
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from experiments.experiments import Experiment
from common.visualization import Visualization
from experiments.test_model import ModelTester

In [8]:
# Create the project directories/config, load the resolved configuration and show it.
init_config()
config = get_config()
print('Config parameters', end='\n\n')
print(config)

Config parameters

{'X_key': 'L', 'data_dir': '/content/drive/MyDrive/Personal-Projects/Image-Colorization/data', 'device': 'cuda', 'output_dir': '/content/drive/MyDrive/Personal-Projects/Image-Colorization/output', 'root_dir': '/content/drive/MyDrive/Personal-Projects/Image-Colorization', 'use_gpu': True, 'y_key': 'AB'}


In [9]:
# Load the experiment parameters (transform, training, model and dataset settings) and show them.
exp_params = get_exp_params()
print('Experiment parameters', end='\n\n')
print(exp_params)

Experiment parameters

{'transform': {'resize_dim': 256, 'crop_dim': 224}, 'train': {'shuffle_data': True, 'batch_size': 32, 'val_split_method': 'fixed-split', 'k': 3, 'val_percentage': 20, 'loss': 'l1', 'epoch_interval': 20, 'num_epochs': 1500}, 'model': {'name': 'conv_net', 'optimizer': 'Adam', 'lr': 0.0001, 'weight_decay': 1e-07, 'amsgrad': True, 'momentum': 0.8, 'build_on_pretrained': False, 'pretrained_filename': '/models/checkpoints/last_model.pt'}, 'dataset': {'name': 'imagenette', 'size': 'subset'}}


In [10]:
# initialize randomness seed
# Seed every RNG source used in this notebook (Python, NumPy, PyTorch CPU + CUDA)
# so runs are repeatable.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Explicitly seed every visible GPU (manual_seed only covers the current device).
torch.cuda.manual_seed_all(seed)
# cuDNN: request deterministic kernels AND disable the benchmark autotuner, which
# may otherwise select different (non-deterministic) algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [11]:
# read and load custom data

# save X_key and y_key
save2config('X_key', 'L')
save2config('y_key', 'AB')

dataset_name = exp_params['dataset']['name']
if dataset_name == 'custom':
    # preprocess data or load preprocessed data (custom)
    dr = DataReader()
    ds = dr.get_split_data()
    Ltr, ABtr, ftr_len = ds['Ltr'], ds['ABtr'], ds['ftr_len']
    Lte, ABte, te_len = ds['Lte'], ds['ABte'], ds['te_len']
    print('Shape of X and y:', ds['Ltr'].shape, ds['ABtr'].shape)

    # transform data
    # NOTE(review): the training cell does not pass this pipeline to Experiment for the
    # custom dataset — confirm whether it is still needed.
    composed_transforms = transforms.Compose([
        ToTensor()
    ])
    # convert to dataset
    ftr_dataset = CustomDataset(Ltr, ABtr, ftr_len)
    te_dataset = CustomDataset(Lte, ABte, te_len)
    # fraction of each split kept for the quick-iteration subsets
    train_subset_frac, test_subset_frac = 0.01, 0.1
elif dataset_name == 'imagenette':
    dr = ImagenetteReader()
    train_paths, test_paths = dr.get_data_filepaths()
    ftr_dataset = CustomImagenetDataset(train_paths)
    te_dataset = CustomImagenetDataset(test_paths)
    # fraction of each split kept for the quick-iteration subsets
    train_subset_frac, test_subset_frac = 0.01, 0.01
else:
    # ValueError (not SystemError, which is reserved for interpreter-internal errors)
    raise ValueError(f'Invalid dataset name passed: {dataset_name!r}')

# Small prefix subsets for fast experiments, built the same way for both dataset types.
smlen = int(train_subset_frac * len(ftr_dataset))
smftr_dataset = torch.utils.data.Subset(ftr_dataset, list(range(smlen)))
smtelen = int(test_subset_frac * len(te_dataset))
smfte_dataset = torch.utils.data.Subset(te_dataset, list(range(smtelen)))
# First (up to) 10 training samples for visual spot checks. The count is clamped so a
# small training set cannot produce out-of-range Subset indices (the previous imagenette
# version indexed range(10) into the 1% subset, which fails whenever that subset has < 10 items).
smftrte_dataset = torch.utils.data.Subset(ftr_dataset, list(range(min(10, len(ftr_dataset)))))

print('Full train dataset length:', len(ftr_dataset))
print('Test dataset length:', len(te_dataset))
print('Subset train dataset length:', smlen)
print('Subset test dataset length:', smtelen, '\n')



Full train dataset length: 9469
Test dataset length: 3925
Subset train dataset length: 94
Subset test dataset length: 39 



In [None]:
# model training
dataset_name = exp_params['dataset']['name']
if dataset_name not in ('custom', 'imagenette'):
    raise ValueError(f'Invalid dataset name passed: {dataset_name!r}')

# 'subset' trains on the small prefix subset (smftr_dataset) for fast iteration;
# any other size value trains on the full training set (ftr_dataset).
use_subset = exp_params['dataset']['size'] == 'subset'
train_dataset = smftr_dataset if use_subset else ftr_dataset

if dataset_name == 'custom':
    # custom data is trained without an explicit transform pipeline
    exp = Experiment(exp_params["model"]["name"], train_dataset)
else:
    # imagenette: normalize with the standard torchvision ImageNet mean/std, then
    # resize and center-crop to the configured dimensions.
    # transforms.ToTensor() is intentionally omitted; presumably CustomImagenetDataset
    # already yields tensors — TODO confirm.
    composed_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.Resize(exp_params['transform']['resize_dim']),
        transforms.CenterCrop(exp_params['transform']['crop_dim'])
    ])
    exp = Experiment(exp_params["model"]["name"], train_dataset, composed_transforms, 'imagenette')

model_history = exp.train()

Running straight split
	Running Epoch 1


		Running through training set: 100%|██████████| 3/3 [00:48<00:00, 16.22s/it]
		Running through validation set: 100%|██████████| 1/1 [00:10<00:00, 10.22s/it]


	Epoch 1 Training Loss: 0.16857648209521645
	Epoch 1 Validation Loss: 0.14794088900089264
	Running Epoch 20


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.65s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 20 Training Loss: 0.09103772436317645
	Epoch 20 Validation Loss: 0.1294366866350174
	Running Epoch 40


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.53s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 40 Training Loss: 0.08427496450512033
	Epoch 40 Validation Loss: 0.11794447153806686
	Running Epoch 60


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.53s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]


	Epoch 60 Training Loss: 0.08155597040527746
	Epoch 60 Validation Loss: 0.1223636046051979
	Running Epoch 80


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.63s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.26it/s]


	Epoch 80 Training Loss: 0.0810054893556394
	Epoch 80 Validation Loss: 0.11728950589895248
	Running Epoch 100


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 100 Training Loss: 0.07980803125783016
	Epoch 100 Validation Loss: 0.11616071313619614
	Running Epoch 120


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.60s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]


	Epoch 120 Training Loss: 0.08022807420868623
	Epoch 120 Validation Loss: 0.11876237392425537
	Running Epoch 140


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.59s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 140 Training Loss: 0.0779797779886346
	Epoch 140 Validation Loss: 0.11280305683612823
	Running Epoch 160


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.59s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 160 Training Loss: 0.0774587808470977
	Epoch 160 Validation Loss: 0.12329990416765213
	Running Epoch 180


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.56s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]


	Epoch 180 Training Loss: 0.0756089891258039
	Epoch 180 Validation Loss: 0.12393628060817719
	Running Epoch 200


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.61s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


	Epoch 200 Training Loss: 0.07505855983809422
	Epoch 200 Validation Loss: 0.12292800098657608
	Running Epoch 220


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.54s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


	Epoch 220 Training Loss: 0.0722540452292091
	Epoch 220 Validation Loss: 0.11586832255125046
	Running Epoch 240


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.57s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


	Epoch 240 Training Loss: 0.07084737561250988
	Epoch 240 Validation Loss: 0.1128871813416481
	Running Epoch 260


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.62s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]


	Epoch 260 Training Loss: 0.06996911568076987
	Epoch 260 Validation Loss: 0.11421284824609756
	Running Epoch 280


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.57s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]


	Epoch 280 Training Loss: 0.0665276650535433
	Epoch 280 Validation Loss: 0.11522675305604935
	Running Epoch 300


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.61s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


	Epoch 300 Training Loss: 0.06445007732040003
	Epoch 300 Validation Loss: 0.13009312748908997
	Running Epoch 320


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.54s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


	Epoch 320 Training Loss: 0.06166466796084454
	Epoch 320 Validation Loss: 0.12056007981300354
	Running Epoch 340


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.59s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s]


	Epoch 340 Training Loss: 0.06513830235129908
	Epoch 340 Validation Loss: 0.13604629039764404
	Running Epoch 360


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]


	Epoch 360 Training Loss: 0.061260436318422616
	Epoch 360 Validation Loss: 0.11582630127668381
	Running Epoch 380


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.54s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]


	Epoch 380 Training Loss: 0.05746364122942874
	Epoch 380 Validation Loss: 0.11662029474973679
	Running Epoch 400


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.56s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


	Epoch 400 Training Loss: 0.05553573468013814
	Epoch 400 Validation Loss: 0.11929696053266525
	Running Epoch 420


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


	Epoch 420 Training Loss: 0.05246473201795628
	Epoch 420 Validation Loss: 0.12661117315292358
	Running Epoch 440


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.57s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


	Epoch 440 Training Loss: 0.05423596305282492
	Epoch 440 Validation Loss: 0.11858334392309189
	Running Epoch 460


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.55s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


	Epoch 460 Training Loss: 0.049869401282385775
	Epoch 460 Validation Loss: 0.11326243728399277
	Running Epoch 480


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.54s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


	Epoch 480 Training Loss: 0.04763695460401083
	Epoch 480 Validation Loss: 0.11487895995378494
	Running Epoch 500


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.56s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


	Epoch 500 Training Loss: 0.04692097086655466
	Epoch 500 Validation Loss: 0.11492665857076645
	Running Epoch 520


		Running through training set: 100%|██████████| 3/3 [00:07<00:00,  2.53s/it]
		Running through validation set: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


	Epoch 520 Training Loss: 0.05150097667386657
	Epoch 520 Validation Loss: 0.12535089254379272


In [None]:
# test model on validation set
dataset_name = exp_params['dataset']['name']
if dataset_name not in ('custom', 'imagenette'):
    raise ValueError(f'Invalid dataset name passed: {dataset_name!r}')

# Load the saved model data for the configured model name (same call for both dataset types).
# Note: this rebinds the global `model_history`, replacing the value returned by
# exp.train() in the training cell.
model, model_history, _ = get_model_data(exp_params["model"]["name"])
model_info = get_modelinfo('')
print("\nModel validation results")
# visualization results
vis = Visualization(model_info, model_history)
vis.get_results()

In [None]:
# test fine-tuned model on training set
dataset_name = exp_params['dataset']['name']
if dataset_name not in ('custom', 'imagenette'):
    raise ValueError(f'Invalid dataset name passed: {dataset_name!r}')

# Rebuild the model architecture and load the last saved checkpoint onto the configured device
# (shared by both dataset types).
checkpoint_path = "models/checkpoints/last_model.pt"
model = get_model(exp_params["model"]["name"])
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(config["device"])))

if dataset_name == 'custom':
    # model testing with small subset of training dataset (custom)
    print("\n\nTesting Saved Model on Training set subset")
    mt = ModelTester(model, smftrte_dataset)
    # NOTE(review): the third argument is "best_model" while the weights come from
    # last_model.pt — confirm which checkpoint is intended.
    mt.test_and_plot(ds["RGBtr"], ABtr, "best_model", True)
else:
    # model testing with small subset of training dataset (imagenette);
    # same normalize / resize / center-crop pipeline as the training cell
    composed_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.Resize(exp_params['transform']['resize_dim']),
        transforms.CenterCrop(exp_params['transform']['crop_dim'])
    ])
    print("\n\nTesting Saved Model on subset of training set")
    mt = ModelTester(model, smftrte_dataset, composed_transforms)
    mt.test_imagenette_and_plot(True)


In [None]:
# test fine-tuned model on test set
# Disabled by default; set run_test_set_eval = True to evaluate the saved checkpoint
# on the held-out test set. Kept as real (flag-guarded) code rather than a string
# literal so it is syntax-checked and does not echo as cell output.
run_test_set_eval = False

if run_test_set_eval:
    dataset_name = exp_params['dataset']['name']
    if dataset_name not in ('custom', 'imagenette'):
        raise ValueError(f'Invalid dataset name passed: {dataset_name!r}')

    model = get_model(exp_params["model"]["name"])
    model.load_state_dict(torch.load("models/checkpoints/last_model.pt", map_location=torch.device(config["device"])))

    # 'subset' evaluates on the small test subset (smfte_dataset); anything else on the full test set
    use_subset = exp_params['dataset']['size'] == 'subset'
    eval_dataset = smfte_dataset if use_subset else te_dataset

    if dataset_name == 'custom':
        # model testing on the test set (custom)
        if use_subset:
            print("\n\nTesting Saved Model on subset of test set")
        else:
            print("\n\nTesting Saved Model on full test set")
        mt = ModelTester(model, eval_dataset)
        mt.test_and_plot(ds["RGBte"], ABte, "best_model", True)
    else:
        # model testing on the test set (imagenette); the transform pipeline is built here
        # instead of relying on a composed_transforms global left over from another cell
        composed_transforms = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize(exp_params['transform']['resize_dim']),
            transforms.CenterCrop(exp_params['transform']['crop_dim'])
        ])
        if use_subset:
            print("\n\nTesting Saved Model on subset of test set")
        else:
            print("\n\nTesting Saved Model on full test set")
        mt = ModelTester(model, eval_dataset, composed_transforms)
        mt.test_imagenette_and_plot(True)