# Image Colorization as pre-training task

In [31]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score,confusion_matrix
import numpy as np
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import time
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Creating the dataset

In [5]:
class CIFAR10Dataset(Dataset):
    def __init__(self, split='train', transform=None):
        if split == 'train':
            cifar_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
        elif split == 'test':
            cifar_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
        
        self.data = []
        self.labels = []
        self.transform = transform
        
        for i, (img, label) in enumerate(cifar_dataset):
            img_gray = rgb2gray(np.array(img)) # convert to lightness channel
        

            self.data.append(img_gray)
            self.labels.append(label)
            
        self.data = np.array(self.data)
        self.labels = np.array(self.labels)
       
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img = self.data[idx]
        label = self.labels[idx]


        if self.transform:
            img = self.transform(img)
        
        return img, label
    

In [6]:
train_transform = transforms.Compose([
    #transforms.RandomHorizontalFlip(), # isto depois da erro no type dos dados, ver pq e corrigir
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.ToTensor()
])

trainset = CIFAR10Dataset(split='train', transform=train_transform)

testset = CIFAR10Dataset(split='test', transform=test_transform)


batch_size = 32

trainloader = DataLoader(trainset, batch_size=batch_size)

testloader = DataLoader(testset, batch_size=batch_size)



Files already downloaded and verified
[6 9 9 ... 9 1 1]
Files already downloaded and verified
[3 8 8 ... 5 1 7]


In [7]:
# choose a picture at random
im_minibatch, label_minibatch = next(iter(testloader))
im, label = im_minibatch[0].cpu(), label_minibatch[0].cpu()

# store image size dimensions
image_size = tuple(im.shape)
label_size = tuple(label.shape)

print(image_size)
print(label, "->" ,label_size)

(1, 32, 32)
tensor(3, dtype=torch.int32) -> ()


### Preparing the Model

In [51]:
class ClassificationNet(nn.Module):
    def initialize_parameters_random(self):
        for module in self.modules():
            print(module)
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                init.kaiming_normal_(module.weight)  # Randomly initialize weights
                if module.bias is not None:
                    init.constant_(module.bias, 0)  # Initialize biases to zeros
                    
    def __init__(self,random_initialization, input_size=128):
        super(ClassificationNet, self).__init__()
        MIDLEVEL_FEATURE_SIZE = 2048
        num_classes = 10

        ## First half: ResNet
        resnet = models.resnet18(num_classes=365) 
        # Change first conv layer to accept single-channel (grayscale) input
        resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1)) 
        # Extract midlevel features from ResNet-gray
        self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
        
        ## Second half: Classification
        self.classifier = nn.Sequential(
                nn.Linear(MIDLEVEL_FEATURE_SIZE, num_classes),
                nn.Softmax(dim=1)
            )
        
        if random_initialization:
            self.initialize_parameters_random()
        

    

    def forward(self, input):
        # Pass input through the encoder (modified ResNet)
        features = self.midlevel_resnet(input)
        features = features.view(features.size(0), -1)  # Flatten the features
        
        # Pass the flattened features through the classifier
        output = self.classifier(features)
        
        return output  


## Initialize the weights randomly

In [9]:
model = ClassificationNet(random_initialization=True)

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.to(device)

ClassificationNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

ClassificationNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

## Defining the training loop

In [16]:


def train_model(model, trainloader, testloader, num_epochs=20):
    start=time.time()
    losses = []
    for epoch in range(0,num_epochs):

        model.train()  # Put the network in train mode
        for i, (x_batch, y_batch) in enumerate(trainloader):
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).long()  # Move the data to the device that is used

            optimizer.zero_grad()  # Set all currenly stored gradients to zero 
            # convert y_batch to one hot encoding
            y_batch = F.one_hot(y_batch, num_classes=10)

            y_pred = model(x_batch)
      
            loss = criterion(y_pred.float(), y_batch.float())

            loss.backward()

            optimizer.step()

            # Compute relevant metrics

            elapsed = time.time() - start  # Keep track of how much time has elapsed

            # Show progress every 75 batches 
            if not i % 75:
                print(f'epoch: {epoch}, time: {elapsed:.3f}s, loss: {loss.item()}')

        

        model.eval()  # Put the network in eval mode
        f1_score_epoch = 0
        accuracy_epoch = 0
        recall_epoch = 0
        precision_epoch = 0

        for i, (x_batch, y_batch) in enumerate(testloader):
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).long()  
            # convert y_batch to one hot encoding
        
            y_batch = F.one_hot(y_batch, num_classes=10)
            y_pred = model(x_batch)

            loss = criterion(y_pred.float(), y_batch.float())

            # Convert y_batch and y_pred to their respective classes
            y_batch = torch.argmax(y_batch, dim=1)
            y_pred = torch.argmax(y_pred, dim=1)

            f1 = f1_score(y_batch.cpu(), y_pred.cpu(), average='weighted', zero_division=0)
            accuracy = accuracy_score(y_batch.cpu(), y_pred.cpu())
            recall = recall_score(y_batch.cpu(), y_pred.cpu(), average='weighted', zero_division=0)
            precision = precision_score(y_batch.cpu(), y_pred.cpu(), average='weighted', zero_division=0)

            f1_score_epoch += f1
            accuracy_epoch += accuracy
            recall_epoch += recall
            precision_epoch += precision

            # Do something with the calculated scores
            
        # Example: Print the scores for each 
        f1_score_epoch = f1_score_epoch / len(testloader)
        accuracy_epoch = accuracy_epoch / len(testloader)
        recall_epoch = recall_epoch / len(testloader)
        precision_epoch = precision_epoch / len(testloader)

        print(f"Epoch {epoch+1}: F1 Score: {f1_score_epoch:.4f}, Accuracy: {accuracy_epoch:.4f}, Recall: {recall_epoch:.4f}, Precision: {precision_epoch:.4f}")
        

        print(f'loss: {loss.item()}')
        losses.append(loss.item())

    return losses

## Training with random parameters


In [None]:
losses_random_init = train_model(model, trainloader, testloader, num_epochs=20)

### Extract metrics from the testset (randomized parameters)

In [28]:
def extract_metrics(model, testloader, losses):
    model.eval()  # Put the network in eval mode
 
    y_pred_list = []
    y_batch_list = []

    for i, (x_batch, y_batch) in enumerate(testloader):
        x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).long()  
        # convert y_batch to one hot encoding

        y_batch = F.one_hot(y_batch, num_classes=10)
        y_pred = model(x_batch)

        loss = criterion(y_pred.float(), y_batch.float())

        # Convert y_batch and y_pred to their respective classes
        y_batch = torch.argmax(y_batch, dim=1)
        y_pred = torch.argmax(y_pred, dim=1)

        y_pred_list.append(y_pred.cpu())
        y_batch_list.append(y_batch.cpu())

    y_pred_list = torch.cat(y_pred_list, dim=0)
    y_batch_list = torch.cat(y_batch_list, dim=0)
    # confusion matrix
    print("accuracy: ", accuracy_score(y_batch_list, y_pred_list))
    print("f1 score: ", f1_score(y_batch_list, y_pred_list, average='weighted'))
    print("recall: ", recall_score(y_batch_list, y_pred_list, average='weighted'))
    print("precision: ", precision_score(y_batch_list, y_pred_list, average='weighted'))
    print(confusion_matrix(y_batch_list, y_pred_list))

    #plot loss vs epoch graph
    plt.plot(losses)
    plt.title('loss vs epochs')
    plt.xlabel('epochs')
    plt.ylabel('loss')



In [None]:
extract_metrics(model, testloader, losses_random_init)

## Training model with loaded weights

In [48]:
# function for loaduing weights of a trained model
def load_weights(weights_dir):
    files = os.listdir(weights_dir)
    weight_paths = [os.path.join(weights_dir, basename) for basename in files]
    # get the latest file in the directory
    final_weight_file = os.path.basename(max(weight_paths, key=os.path.getctime))

    # first model needs to be loaded
    model = ClassificationNet(random_initialization=False).to(device)

    # fixes odd error when state_dict has prescript "model."" in keys
    state_dict = torch.load(os.path.join(weights_dir, final_weight_file), map_location=device)
    for key in list(state_dict.keys()):
        if key.startswith("model."):
            state_dict[key[6:]] = state_dict.pop(key)
        if(key not in model.state_dict()):
            state_dict.pop(key)
   
    model.load_state_dict(state_dict, strict=False)
    
    print('Loaded weights: ' + final_weight_file)
    
    return model

In [52]:
model_loaded_weights = load_weights('weights')

Loaded weights: epoch-10_accuracy-0.000.pth


In [53]:
losses_loaded_weights = train_model(model_loaded_weights, trainloader, testloader, num_epochs=20)

epoch: 0, time: 0.083s, loss: 0.09528755396604538
epoch: 0, time: 5.129s, loss: 0.09628479182720184
epoch: 0, time: 12.750s, loss: 0.09314993768930435
epoch: 0, time: 18.060s, loss: 0.09268101304769516
epoch: 0, time: 23.847s, loss: 0.0930139422416687
epoch: 0, time: 31.038s, loss: 0.0911705419421196
epoch: 0, time: 38.878s, loss: 0.09542544186115265
epoch: 0, time: 45.448s, loss: 0.09534415602684021
epoch: 0, time: 50.628s, loss: 0.09683115780353546
epoch: 0, time: 56.120s, loss: 0.09531386196613312
epoch: 0, time: 62.363s, loss: 0.0948512926697731
epoch: 0, time: 66.636s, loss: 0.09035952389240265
epoch: 0, time: 72.240s, loss: 0.09520880877971649
epoch: 0, time: 76.995s, loss: 0.0914224162697792
epoch: 0, time: 81.704s, loss: 0.08991773426532745
epoch: 0, time: 86.489s, loss: 0.09130150079727173
epoch: 0, time: 94.462s, loss: 0.09331996738910675
epoch: 0, time: 101.603s, loss: 0.09361673891544342
epoch: 0, time: 109.063s, loss: 0.09344004094600677
epoch: 0, time: 117.137s, loss: 0.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Epoch 1: F1 Score: 0.0650, Accuracy: 0.0854, Recall: 0.0854, Precision: 0.0691
loss: 0.09873798489570618
epoch: 1, time: 135.677s, loss: 0.09528755396604538
epoch: 1, time: 142.063s, loss: 0.09628479182720184
epoch: 1, time: 148.839s, loss: 0.09314993768930435
epoch: 1, time: 155.229s, loss: 0.09268101304769516
epoch: 1, time: 161.553s, loss: 0.0930139422416687
epoch: 1, time: 166.123s, loss: 0.0911705419421196
epoch: 1, time: 174.308s, loss: 0.09542544186115265
epoch: 1, time: 178.746s, loss: 0.09534415602684021
epoch: 1, time: 183.485s, loss: 0.09683115780353546
epoch: 1, time: 189.001s, loss: 0.09531386196613312
epoch: 1, time: 193.715s, loss: 0.0948512926697731
epoch: 1, time: 198.442s, loss: 0.09035952389240265
epoch: 1, time: 204.465s, loss: 0.09520880877971649
epoch: 1, time: 210.235s, loss: 0.0914224162697792
epoch: 1, time: 217.582s, loss: 0.08991773426532745
epoch: 1, time: 224.402s, loss: 0.09130150079727173
epoch: 1, time: 231.409s, loss: 0.09331996738910675
epoch: 1, time:

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Epoch 2: F1 Score: 0.0650, Accuracy: 0.0854, Recall: 0.0854, Precision: 0.0691
loss: 0.09873798489570618
epoch: 2, time: 269.867s, loss: 0.09528755396604538
epoch: 2, time: 275.342s, loss: 0.09628479182720184
epoch: 2, time: 282.075s, loss: 0.09314993768930435
epoch: 2, time: 289.906s, loss: 0.09268101304769516
epoch: 2, time: 294.934s, loss: 0.0930139422416687
epoch: 2, time: 300.147s, loss: 0.0911705419421196
epoch: 2, time: 307.408s, loss: 0.09542544186115265
epoch: 2, time: 314.111s, loss: 0.09534415602684021
epoch: 2, time: 319.217s, loss: 0.09683115780353546


KeyboardInterrupt: 