In [None]:
import os
import zipfile
import cv2
import numpy as np
import matplotlib.pyplot as plt
import shutil
import time
import torch
import torch.nn as nn
import torch.optim as optim

from google.colab import drive
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from datetime import datetime
from tqdm import tqdm
from skimage import color
from skimage.color import lab2rgb, rgb2lab
from skimage.transform import resize
from skimage import io
from IPython.display import display

In [None]:
# Mount Google Drive (remounting if already mounted) so the zipped dataset on Drive is reachable.
drive.mount('/content/drive/', force_remount=True)

# Quick sanity check: list what sits in the rescaled-data folder on Drive.
print("Files in the current directory:")
rescaled_listing = os.listdir("/content/drive/MyDrive/TUDelft/Seminar_Computer_Vision/CVbyDL/DATA/rescaled")
print(rescaled_listing)

In [None]:
# Google Drive is sometimes slow to list (or even fully "see") the files in a large
# directory, so the dataset is stored on Drive as ZIP archives and extracted to the local
# Colab disk once per session. This avoids silently missing images.

# Map each ZIP archive on Drive to the local folder its images are extracted into.
zip_files = {
    "/content/drive/MyDrive/TUDelft/Seminar_Computer_Vision/CVbyDL/DATA/rescaled/test/input-test-set-rescaled.zip": "/content/unzipped_data/test/input",
    "/content/drive/MyDrive/TUDelft/Seminar_Computer_Vision/CVbyDL/DATA/rescaled/train/input-train-set-rescaled.zip": "/content/unzipped_data/train/input",
}

# Extract every archive, stripping the archive's top-level folder so that
# "<top-folder>/img.jpg" ends up as "<extraction_path>/img.jpg".
for zip_path, extraction_path in tqdm(zip_files.items()):
    # makedirs also creates missing parents (e.g. /content/unzipped_data/train).
    os.makedirs(extraction_path, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            parts = file_info.filename.split('/')
            # Drop the top-level folder. Entries that already sit at the archive root keep
            # their own name; otherwise they would map onto the extraction folder itself.
            relative_path = '/'.join(parts[1:]) if len(parts) > 1 else parts[0]
            if not relative_path:
                continue  # the top-level folder entry itself
            new_path = os.path.join(extraction_path, relative_path)

            if file_info.is_dir():
                os.makedirs(new_path, exist_ok=True)
                continue

            os.makedirs(os.path.dirname(new_path), exist_ok=True)
            # Stream the member to disk instead of reading it fully into memory first.
            with zip_ref.open(file_info) as source, open(new_path, 'wb') as target:
                shutil.copyfileobj(source, target)

# For verification, print how many files ended up directly inside each extraction folder.
for extraction_path in zip_files.values():
    print(f"Files in {extraction_path} directory:")
    print(sum(os.path.isfile(os.path.join(extraction_path, name)) for name in os.listdir(extraction_path)))


In [None]:
# Local folder that holds the unzipped train/test images on the Colab VM.
DATASET_LOCATION = "/content/unzipped_data"

# Tag for this experiment. It names the output folder on Drive and is used as the
# prefix of saved checkpoint files.
UNIQUE_ID = "original_architecture"
OUT_LOCATION = "/content/drive/MyDrive/TUDelft/Seminar_Computer_Vision/CVbyDL/the_experiment/" + UNIQUE_ID

print("UNIQUE_ID : ", UNIQUE_ID)

## Architecture

In [None]:
import torch
import torch.nn as nn

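# If you want to reduce the number of feature maps per layer, divide the in and out channels of every Conv2d by the same power of two (for the blog we used 1/8 and 1/16).
# GIF.fc1's in_features (channels x 7 x 7 = 25088 at full width) must be scaled by the same factor.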

# NOTE : You need to rescale all the input images to 224x224!

# Shared Low Level Features
class SLLF(nn.Module):
    '''
    Shared low-level features network (see table 1 for the details).

    Six 3x3 convolutions with 64, 128, 128, 256, 256 and 512 output channels. conv1, conv3
    and conv5 use stride 2 and halve the height and width, so a (B, 1, 224, 224) luminance
    input becomes a (B, 512, 28, 28) feature map. Each conv except the last is followed by
    BatchNorm, and every conv is followed by ReLU.
    '''

    def __init__(self):
        super(SLLF, self).__init__()

        # Layer attribute names and registration order are kept stable: they define the
        # state_dict keys that saved checkpoints rely on.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)      # 224 -> 112
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)    # 112 -> 112
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)   # 112 -> 56
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)   # 56 -> 56
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)   # 56 -> 28
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)   # 28 -> 28

        self.relu = nn.ReLU()

    def forward(self, x):
        normalized_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in normalized_stages:
            x = self.relu(bn(conv(x)))
        # The last stage has no BatchNorm.
        return self.relu(self.conv6(x))


# Global Image Features
class GIF(nn.Module):
    '''
    Global image features network.

    Takes the (B, 512, 28, 28) output of SLLF through four 3x3 convolutions (conv1 and
    conv3 use stride 2: 28 -> 14 -> 7), flattens the (B, 512, 7, 7) result and applies
    three fully connected layers (25088 -> 1024 -> 512 -> 256). Returns a (B, 256) vector
    for the fusion layer. fc1 hard-codes 25088 = 512 * 7 * 7 input features, so the input
    map must be 28x28 (i.e. a 224x224 image).
    '''

    def __init__(self):
        super(GIF, self).__init__()

        # Attribute names and registration order are kept stable for state_dict compatibility.
        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)   # 28 -> 14
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)   # 14 -> 14
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)   # 14 -> 7
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)   # 7 -> 7

        self.fc1 = nn.Linear(25088, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)

        self.relu = nn.ReLU()

    def forward(self, x):
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = self.relu(bn(conv(x)))
        x = self.relu(self.conv4(x))

        # (B, 512, 7, 7) -> (B, 25088) so the feature map can feed the linear layers.
        x = x.flatten(start_dim=1)

        # All three FC layers are ReLU-activated, including fc3, whose output goes to the fusion layer.
        for fc in (self.fc1, self.fc2, self.fc3):
            x = self.relu(fc(x))
        return x


# Mid Level Features
class MLF(nn.Module):
    '''
    Mid-level features network.

    Two stride-1 3x3 convolutions that keep the spatial size: 512 -> 512 channels
    (BatchNorm + ReLU), then 512 -> 256 channels (ReLU only). With the (B, 512, 28, 28)
    SLLF output this yields (B, 256, 28, 28).
    '''

    def __init__(self):
        super(MLF, self).__init__()

        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)  # 28 -> 28
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)  # 28 -> 28

        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.conv2(hidden))


class ColorizationNetwork(nn.Module):
    '''
    Decoder that turns fused (B, 256, H, W) features into (B, 2, 8H, 8W) ab predictions.

    For the default 28x28 input: conv 256->128 @28, nearest upsample to 56,
    conv 128->64 and 64->64 @56, upsample to 112, conv 64->32 @112, a sigmoid output conv
    32->2 @112, and a final nearest upsample to 224. The sigmoid keeps predictions in [0, 1],
    the range of the normalized ab targets.
    '''

    def __init__(self):
        super(ColorizationNetwork, self).__init__()

        # Attribute names and registration order are kept stable for state_dict compatibility.
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)  # 28 -> 28
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)   # 56 -> 56
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)    # 56 -> 56
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)    # 112 -> 112
        self.bn4 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)     # 112 -> 112

        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')  # applied after the sigmoid

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x is the fusion-layer output (or MLF output in the no-global-features variant).
        x = self.upsample1(self.relu(self.bn1(self.conv1(x))))

        for conv, bn in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = self.relu(bn(conv(x)))
        x = self.upsample2(x)

        x = self.relu(self.bn4(self.conv4(x)))

        ab = self.sigmoid(self.conv5(x))
        return self.upsample3(ab)


class FusionLayer(nn.Module):
    '''
    Fuses the global feature vector into every spatial position of the mid-level features.

    The global vector (B, C_g) is broadcast over the H x W grid of the mid-level map
    (B, C_m, H, W), concatenated after it along the channel axis, and passed through a
    3x3 convolution + ReLU. conv1 expects C_m + C_g = 512 channels (256 + 256 with the
    default GIF/MLF) and outputs 256 channels at the same H x W.
    '''

    def __init__(self):
        super(FusionLayer, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1)  # HxW -> HxW
        self.relu = nn.ReLU()

    def forward(self, glf, mlf):
        '''
        Args:
            glf: global features, shape (batch, channels), e.g. (B, 256) from GIF.
            mlf: mid-level feature map, shape (batch, channels, H, W), e.g. (B, 256, 28, 28) from MLF.
        Returns:
            Fused feature map of shape (batch, 256, H, W).
        '''
        # Take H x W from the mid-level map instead of hard-coding 28x28, so other input
        # resolutions (and the channel count of glf) are handled.
        height, width = mlf.shape[-2:]
        # (B, C) -> (B, C, 1, 1) -> (B, C, H, W); expand is a broadcast view, not a copy.
        glf = glf[:, :, None, None].expand(-1, -1, height, width)
        fused = torch.cat((mlf, glf), dim=1)
        return self.relu(self.conv1(fused))

### With Global Features Network

In [None]:
class FullNetworkGLF(nn.Module):
    '''
    Full colorization model with the global-features branch.

    luminance -> SLLF -> MLF (mid-level) and GIF (global) -> FusionLayer -> ColorizationNetwork.
    Input: (B, 1, 224, 224) normalized L channel (GIF's fc1 assumes a 224x224 image).
    Output: (B, 2, 224, 224) ab predictions in [0, 1] (sigmoid output of ColorizationNetwork).
    '''

    def __init__(self):
        super(FullNetworkGLF, self).__init__()
        # Attribute names are the state_dict key prefixes of saved checkpoints: keep them.
        self.sllf = SLLF()
        self.glf = GIF()
        self.mlf = MLF()
        self.fusionLayer = FusionLayer()
        self.colorizationNetwork = ColorizationNetwork()

    def forward(self, x):
        # Invoke submodules through __call__ rather than .forward() so registered hooks run.
        low_level = self.sllf(x)
        mid_level = self.mlf(low_level)
        global_level = self.glf(low_level)
        fused = self.fusionLayer(global_level, mid_level)
        return self.colorizationNetwork(fused)


In [None]:
# Standalone instance of the full GLF model (e.g. for inspecting the architecture).
# TrainingLoop builds its own network (self.net), so this `model` is not used for training.
model = FullNetworkGLF()

### Architecture (without Global Features Network)

In [None]:
class FullNetworkNoGLF(nn.Module):
    '''
    Ablation of the colorization model without the global-features branch.

    luminance -> SLLF -> MLF -> ColorizationNetwork. MLF already emits the 256 channels
    that ColorizationNetwork expects, so no fusion layer is needed.
    Input: (B, 1, H, W) normalized L channel. Output: (B, 2, H, W) ab predictions in [0, 1].
    '''

    def __init__(self):
        super(FullNetworkNoGLF, self).__init__()
        # Attribute names are the state_dict key prefixes of saved checkpoints: keep them.
        self.sllf = SLLF()
        self.mlf = MLF()
        self.colorizationNetwork = ColorizationNetwork()

    def forward(self, x):
        # Invoke submodules through __call__ rather than .forward() so registered hooks run.
        low_level = self.sllf(x)
        mid_level = self.mlf(low_level)
        return self.colorizationNetwork(mid_level)


In [None]:
# Rebinds `model` to the variant without global features. TrainingLoop builds its own
# network (self.net), so this instance is not used for training either.
model = FullNetworkNoGLF()

# Dataloader

In [None]:
class FilmPicturesDataset(Dataset):
    '''
    Image-folder dataset for colorization.

    Every regular file directly inside ``root_dir`` (non-recursive) whose extension is in
    VALID_EXTENSIONS (case-insensitive) is one item. Items are returned as
    ``(luminance, ab)`` via rgb_to_normalized_lab: a (1, H, W) L channel scaled by 1/100
    and a (2, H, W) ab target scaled as (ab + 128) / 256, both float32 tensors.
    '''

    # Whitelisting extensions keeps stray files such as .DS_Store out of the dataset.
    VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.imgs = self.make_dataset()

    def make_dataset(self):
        '''
        Return the sorted list of image paths found in root_dir.

        os.listdir order is arbitrary. Sorting makes the item order reproducible, and with it
        the test-time output file names, which are derived from batch and item indices.
        '''
        return sorted(
            os.path.join(self.root_dir, name)
            for name in os.listdir(self.root_dir)
            if name.lower().endswith(self.VALID_EXTENSIONS)
            and os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        luminance, ground_truth_a_b = rgb_to_normalized_lab(self.imgs[idx])
        return luminance, ground_truth_a_b

def convert_to_grayscale(image_path):
    '''
    Load an image from disk as a single-channel, channel-first grayscale tensor.

    Args:
        image_path: path to an image file readable by OpenCV.
    Returns:
        torch tensor of shape (1, H, W) with the dtype OpenCV returns for
        IMREAD_GRAYSCALE (8-bit).
    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    '''
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Add a leading channel dimension: PyTorch expects channel-first (1 x H x W).
    return torch.from_numpy(img).unsqueeze(0)

def rgb_to_normalized_lab(image_path, size=(224, 224)):
    '''
    Load an image, resize it, and convert it to the normalized Lab representation used for training.

    Normalization (any inverse must undo exactly this):
        L   (0..100)  ->  L / 100
        a,b           ->  (ab + 128) / 256

    Args:
        image_path: path to an image readable by skimage.io.imread. Grayscale and RGBA
            images are converted to 3-channel RGB first, as rgb2lab requires.
        size: (height, width) to resize to. The default (224, 224) is what the network expects.
    Returns:
        (luminance, ab): float32 torch tensors of shape (1, H, W) and (2, H, W).
    '''
    img = io.imread(image_path)

    # rgb2lab needs exactly three channels: promote grayscale, composite away alpha.
    if img.ndim == 2:
        img = color.gray2rgb(img)
    elif img.shape[-1] == 4:
        img = color.rgba2rgb(img)

    # resize also converts the image to float in [0, 1], which rgb2lab accepts.
    img = resize(img, size)
    img_lab = rgb2lab(img)

    img_lab[:, :, :1] = img_lab[:, :, :1] / 100.0
    img_lab[:, :, 1:] = (img_lab[:, :, 1:] + 128.0) / 256.0

    # HxWxC -> CxHxW, because PyTorch expects channel-first tensors.
    img_lab = torch.from_numpy(np.transpose(img_lab, (2, 0, 1)).astype(np.float32))

    # Slice with [:1] rather than [0] so the channel axis is kept: (1, H, W), not (H, W).
    luminance = img_lab[:1, :, :]
    ab = img_lab[1:, :, :]
    return luminance, ab

def lab_to_rgb(luminance, ab):
    '''
    Undo the rgb_to_normalized_lab normalization for one image and convert it to RGB.

    Args:
        luminance: tensor of shape (1, H, W) holding L / 100.
        ab: tensor of shape (2, H, W) holding (ab + 128) / 256.
    Returns:
        RGB image of shape (H, W, 3) as returned by skimage's lab2rgb.
    '''
    # detach().cpu() lets tensors taken straight from a (GPU) model be converted too.
    luminance = luminance.detach().cpu().numpy() * 100.0
    # Exact inverse of the forward normalization (ab + 128) / 256.
    ab = ab.detach().cpu().numpy() * 256.0 - 128.0

    # CxHxW -> HxWxC, the channel-last layout skimage expects.
    img_stack = np.dstack((luminance.transpose((1, 2, 0)), ab.transpose((1, 2, 0))))
    return lab2rgb(img_stack.astype(np.float64))

# Utils

In [None]:
def model_prediction_to_rgb(luminance, ab_pred):
    '''
    Join the input luminance with the model's predicted ab channels and convert to RGB.
    Used for viewing the outputs during model inference.

    Args:
        luminance: tensor of shape (1, H, W) holding L / 100 (as made by rgb_to_normalized_lab).
        ab_pred: tensor of shape (2, H, W) holding predicted (ab + 128) / 256 values.
    Returns:
        RGB image of shape (H, W, 3) as returned by skimage.color.lab2rgb.
    '''
    # Undo the training normalization. The ab inverse must match (ab + 128) / 256 exactly;
    # any other scale/offset shifts or shrinks all predicted colors.
    luminance = luminance.detach().cpu().numpy() * 100.0
    ab_pred = ab_pred.detach().cpu().numpy() * 256.0 - 128.0

    # CxHxW -> HxWxC, the channel-last layout skimage expects.
    img_stack = np.dstack((luminance.transpose((1, 2, 0)), ab_pred.transpose((1, 2, 0))))
    return color.lab2rgb(img_stack.astype(np.float64))


# Trainer

In [None]:
class TrainingLoop:
    '''
    Holds the datasets, dataloaders, model, loss and optimizer for one experiment, and
    provides per-epoch training, test-set inference, and a train-then-test driver.
    '''

    def __init__(self, batch_size, epochs, train_dir, val_dir, test_dir, start_epoch=0,
                 use_global_features=True):
        '''
        Initializes the datasets and dataloaders from the given directories, the model, the
        loss and the optimizer.

        Args:
            batch_size: images per batch for both loaders.
            epochs: epoch index to stop at (exclusive); run() trains range(start_epoch, epochs).
            train_dir, test_dir: folders of images (see FilmPicturesDataset).
            val_dir: stored on the instance; not used by this class.
            start_epoch: first epoch index to train. It exists so training can resume from a
                checkpoint, because the model usually cannot be trained in one go on Colab GPUs.
            use_global_features: True -> FullNetworkGLF, False -> FullNetworkNoGLF.
        '''
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device : {self.device}")

        # Hyperparameters first: the dataloaders below read self.batch_size.
        self.batch_size = batch_size
        self.start_epoch = start_epoch
        self.epochs = epochs

        self.train_dir = train_dir
        self.val_dir = val_dir
        self.test_dir = test_dir
        self.output_dir = OUT_LOCATION

        self.trainset = FilmPicturesDataset(self.train_dir)
        self.testset = FilmPicturesDataset(self.test_dir)
        self.trainloader = DataLoader(dataset=self.trainset, batch_size=self.batch_size, shuffle=True)
        self.testloader = DataLoader(dataset=self.testset, batch_size=self.batch_size, shuffle=False)

        # The network must exist (and sit on its device) before the optimizer is built
        # over its parameters.
        self.net = FullNetworkGLF() if use_global_features else FullNetworkNoGLF()
        self.net.to(self.device)

        self.mse = nn.MSELoss(reduction='sum')
        self.optimizer = optim.Adadelta(self.net.parameters())

    def train(self, epoch):
        '''
        Train the model for one epoch (0-based index ``epoch``) and print the mean batch loss.

        Every 20th epoch (1-based count) the weights are saved to
        ``<output_dir>/models_saved/<UNIQUE_ID>_model_epoch<epoch>.pt``, with the 0-based
        index in the file name. This is needed because the Colab runtime often disconnects
        (inactivity, compute limits) before training finishes.
        '''
        self.net.train()
        epoch_loss = 0.0
        num_batches = len(self.trainloader)

        for batch_no, (luminance, ab) in enumerate(self.trainloader):
            luminance, ab = luminance.to(self.device), ab.to(self.device)

            self.optimizer.zero_grad()
            ab_pred = self.net(luminance)
            # nn.MSELoss takes (input, target).
            loss = self.mse(ab_pred, ab)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            print(f'Epoch {epoch+1} / {self.epochs} | Batch Number : {batch_no + 1} / {num_batches} -> Batch Loss : {batch_loss}')

        # Guard against an empty training folder.
        epoch_loss /= max(num_batches, 1)

        if (epoch + 1) % 20 == 0:
            model_folder = os.path.join(self.output_dir, "models_saved")
            os.makedirs(model_folder, exist_ok=True)
            model_path = os.path.join(model_folder, f"{UNIQUE_ID}_model_epoch{epoch}.pt")
            torch.save(self.net.state_dict(), model_path)
            print(f"Model saved to : {model_path}")

        print(f"Epoch loss: {epoch_loss}")

    def test(self, show_image=False):
        '''
        Colorize every test image and save it as ``<output_dir>/<batch_no>_<index>.png``.

        Args:
            show_image: if True, also display each colorized image in the notebook.
        '''
        os.makedirs(self.output_dir, exist_ok=True)
        self.net.to(self.device)
        # Evaluation mode: BatchNorm uses its running statistics.
        self.net.eval()

        with torch.no_grad():
            for batch_no, (luminance, _) in enumerate(self.testloader):
                ab_pred = self.net(luminance.to(self.device)).cpu()

                for i in range(luminance.shape[0]):
                    img = model_prediction_to_rgb(luminance[i], ab_pred[i])
                    # Round and clip before the uint8 cast so values are not truncated
                    # and out-of-range values cannot wrap around.
                    img = np.clip(np.round(img * 255.0), 0, 255).astype(np.uint8)
                    io.imsave(os.path.join(self.output_dir, f"{batch_no}_{i}.png"), img)
                    if show_image:
                        display(Image.fromarray(img))  # display one image at a time in the notebook

                print(f"Batch {batch_no + 1} / {len(self.testloader)}")

        print("Saved all photos to " + self.output_dir)

    def run(self):
        '''Train epochs range(start_epoch, epochs), then run test() once.'''
        for epoch in range(self.start_epoch, self.epochs):
            print(f"Epoch : {epoch + 1} / {self.epochs}")
            self.train(epoch)
        self.test()


# Main

In [None]:
# Full training run from scratch: epochs 0..99 at batch size 16, then inference on test/input.
# (val_dir is stored by TrainingLoop but not used.)
training_loop = TrainingLoop(
    batch_size=16,
    epochs=100,
    train_dir=f"{DATASET_LOCATION}/train/input",
    val_dir=f"{DATASET_LOCATION}/test/input",
    test_dir=f"{DATASET_LOCATION}/test/input",
)

In [None]:
# Trains epochs range(start_epoch, epochs), then writes the colorized test images to OUT_LOCATION.
training_loop.run()

# Loading from checkpoint

Used in two cases:

1. Training more (Continuing training)
2. Testing on certain images


In [None]:
# Path of the checkpoint FILE to load (despite the variable name, this is not a folder).
# NOTE(review): TrainingLoop.train saves checkpoints as f"{UNIQUE_ID}_model_epoch{epoch}.pt"
# inside models_saved/, never as "model.pt". Rename the checkpoint you want, or point this path at it.
checkpoint_dir = f"{OUT_LOCATION}/models_saved/model.pt"

### Continuing training

In [None]:
# Run all the cells up to (but not including) the "Main" section first.
# Checkpoint file names carry the 0-based epoch index (e.g. "..._epoch19.pt" is written after
# the 20th epoch), so resume with start_epoch = <index in the file name> + 1.
start_epoch = 20
end_epoch = 100  # epoch index (exclusive) at which the resumed run should stop

In [None]:
# TrainingLoop.run trains range(start_epoch, epochs), so `epochs` is the absolute end epoch and
# start_epoch must be passed explicitly. Otherwise training restarts at epoch 0 with wrong epoch
# numbers, and its periodic checkpoints overwrite the ones saved by the first run.
training_loop = TrainingLoop(
    batch_size=16,
    epochs=end_epoch,
    train_dir=f"{DATASET_LOCATION}/train/input",
    val_dir=f"{DATASET_LOCATION}/test/input",
    test_dir=f"{DATASET_LOCATION}/test/input",
    start_epoch=start_epoch,
)

In [None]:
# `trainer` was never defined; the loop built above is `training_loop`.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime (and vice versa).
training_loop.net.load_state_dict(torch.load(checkpoint_dir, map_location=training_loop.device))

In [None]:
# Continue training from start_epoch with the loaded weights, then run inference on the test set.
training_loop.run()

### Testing the model

In [None]:
# Same checkpoint path as in "Loading from checkpoint", redefined so this section can run on its own.
# NOTE(review): the trainer saves f"{UNIQUE_ID}_model_epoch{epoch}.pt", not "model.pt".
# Point this path at the checkpoint you want to evaluate.
checkpoint_dir = f"{OUT_LOCATION}/models_saved/model.pt"

In [None]:
# Inference-only loop: epochs=0 because run() is not called here, only test().
training_loop = TrainingLoop(
    batch_size=16,
    epochs=0,
    train_dir=f"{DATASET_LOCATION}/train/input",
    val_dir=f"{DATASET_LOCATION}/test/input",
    test_dir=f"{DATASET_LOCATION}/test/input",
)

In [None]:
# `trainer` was never defined; use the `training_loop` built above.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime (and vice versa).
training_loop.net.load_state_dict(torch.load(checkpoint_dir, map_location=training_loop.device))

In [None]:
# Colorize the test set with the loaded weights and save the images to OUT_LOCATION.
training_loop.test()


## Test specific images

In [None]:
# test() colorizes the images in test_dir (TrainingLoop never uses val_dir), so the specific
# images must be passed as test_dir. Before this change they were only passed as val_dir,
# and test/input was colorized instead.
# NOTE(review): the unzip cell only populates train/input and test/input. Make sure
# test/Cinema exists and contains the images to colorize.
training_loop = TrainingLoop(
    batch_size=16,
    epochs=10,
    train_dir=f"{DATASET_LOCATION}/train/input",
    val_dir=f"{DATASET_LOCATION}/test/Cinema",
    test_dir=f"{DATASET_LOCATION}/test/Cinema",
)

In [None]:
# `trainer` was never defined; use the `training_loop` built above.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime (and vice versa).
training_loop.net.load_state_dict(torch.load(checkpoint_dir, map_location=training_loop.device))

In [None]:
# Colorize the selected images, saving them to OUT_LOCATION and also displaying each one inline.
training_loop.test(show_image=True)