In [None]:
import os
import yaml
import logging
from datetime import datetime

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision import models

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb


In [None]:

# Central configuration for the notebook: dataset selection and loader
# settings, plus per-task training hyper-parameters and checkpoint folders.
config = {
    'dataset': {
        'name_classification': 'cifar10',
        'name_colorization': 'cifar100',
        'batch_size': 128,
        'num_workers': 0
    },
    'training_classification': {
        'epochs': 20,
        'learning_rate': 1e-4,
        'save_interval': 5,   # write a checkpoint every N epochs
        'checkpoint_dir': './checkpoints'
    },
    'training_colorization': {
        'epochs': 200,
        'learning_rate': 1e-4,
        'save_interval': 10,  # write a checkpoint every N epochs
        'checkpoint_dir': './checkpoints_colorization'
    }
}

# Compute device used by every later `.to(device)` call. It was previously
# never defined, so a fresh "Restart & Run All" failed with a NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Make sure both checkpoint folders exist before any training cell runs.
os.makedirs(config['training_classification']['checkpoint_dir'], exist_ok=True)
os.makedirs(config['training_colorization']['checkpoint_dir'], exist_ok=True)


In [None]:

# Per-channel CIFAR10 statistics used to standardize the RGB inputs.
_class_mean = (0.4914, 0.4822, 0.4465)
_class_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop as augmentation,
# followed by tensor conversion and normalization.
_class_train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(_class_mean, _class_std),
]
transform_train_class = transforms.Compose(_class_train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
_class_eval_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_class_mean, _class_std),
]
transform_test_class = transforms.Compose(_class_eval_steps)


In [None]:
# Normalization statistics applied to the images that feed the colorization
# task (the same values the classification transforms use).
_color_mean = (0.4914, 0.4822, 0.4465)
_color_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random flip + padded random crop, then tensor
# conversion and normalization.
_color_train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(_color_mean, _color_std),
]
transform_train_color = transforms.Compose(_color_train_steps)

# Evaluation pipeline: tensor conversion and normalization only.
_color_eval_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_color_mean, _color_std),
]
transform_test_color = transforms.Compose(_color_eval_steps)


In [None]:
class ColorizationDataset(Dataset):
    """Wrap an image dataset so each item becomes an ``(L, ab)`` pair in Lab space.

    Every item of the wrapped dataset must be an ``(image, label)`` tuple whose
    image is a normalized ``[3, H, W]`` tensor. The image is de-normalized,
    converted from RGB to Lab with ``rgb2lab`` and split into a scaled
    lightness channel (model input) and scaled ab channels (regression target).
    The label is discarded.

    Args:
        dataset: Indexable dataset yielding ``(image_tensor, label)`` tuples.
        transform: Optional callable applied to the image tensor before the
            color-space conversion. ``None`` (the default) leaves the image
            untouched. This argument used to be stored but never applied.
        mean: Per-channel mean that the upstream normalization subtracted.
        std: Per-channel std that the upstream normalization divided by.
    """

    def __init__(self, dataset, transform=None,
                 mean=(0.4914, 0.4822, 0.4465),
                 std=(0.2023, 0.1994, 0.2010)):
        self.dataset = dataset
        self.transform = transform
        # float64 arrays keep the arithmetic in the same precision as before.
        self.mean = np.asarray(mean, dtype=np.float64)
        self.std = np.asarray(std, dtype=np.float64)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def _to_unit_rgb(self, img):
        """Undo normalization: ``[3, H, W]`` tensor -> ``[H, W, 3]`` array in [0, 1]."""
        hwc = img.numpy().transpose((1, 2, 0))
        return np.clip(hwc * self.std + self.mean, 0, 1)

    def __getitem__(self, idx):
        """Return ``(L, ab)`` float tensors of shape ``[1, H, W]`` and ``[2, H, W]``."""
        img, _ = self.dataset[idx]
        if self.transform is not None:
            img = self.transform(img)

        lab = rgb2lab(self._to_unit_rgb(img))
        L = lab[:, :, 0] / 100.0    # lightness divided by 100
        ab = lab[:, :, 1:] / 128.0  # chroma channels divided by 128

        L = torch.from_numpy(L).unsqueeze(0).float()         # [1, H, W]
        ab = torch.from_numpy(ab).permute(2, 0, 1).float()   # [2, H, W]

        return L, ab


In [None]:
def get_classification_dataloaders(config):
    """Build train/validation/test loaders for the classification task.

    The official training split is divided 80/20 into train and validation
    subsets. The split is drawn from a seeded generator so it is the same on
    every run, and the validation subset reads the images through the
    evaluation transform so validation metrics are not computed on randomly
    flipped or cropped images.

    Args:
        config: Dict with a ``'dataset'`` section providing
            ``'name_classification'`` (``'cifar10'`` or ``'cifar100'``),
            ``'batch_size'`` and ``'num_workers'``. An optional
            ``'split_seed'`` (default 42) controls the train/val split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, num_classes)``.

    Raises:
        ValueError: If the dataset name is not supported.
    """
    dataset_cfg = config['dataset']
    dataset_name = dataset_cfg['name_classification'].lower()
    batch_size = dataset_cfg['batch_size']
    num_workers = dataset_cfg['num_workers']

    if dataset_name == 'cifar10':
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif dataset_name == 'cifar100':
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    else:
        raise ValueError("Unsupported dataset. Choose 'cifar10' or 'cifar100'.")

    train_set = dataset_cls(root='./data', train=True, download=True, transform=transform_train_class)
    # Same training images, but without augmentation; used only for validation.
    train_set_eval = dataset_cls(root='./data', train=True, download=True, transform=transform_test_class)
    test_set = dataset_cls(root='./data', train=False, download=True, transform=transform_test_class)

    total_train = len(train_set)
    val_size = int(0.2 * total_train)
    train_size = total_train - val_size

    split_rng = torch.Generator().manual_seed(dataset_cfg.get('split_seed', 42))
    train_dataset, val_split = random_split(train_set, [train_size, val_size], generator=split_rng)
    # Re-point the validation indices at the non-augmented view of the data.
    val_dataset = torch.utils.data.Subset(train_set_eval, val_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader, num_classes

def get_colorization_dataloaders(config):
    """Build train/validation/test loaders that yield ``(L, ab)`` pairs.

    The official training split is divided 80/20 into train and validation
    subsets using a seeded generator, so the split is reproducible. The
    validation subset reads the images through the evaluation transform, so
    it is not affected by the random flip/crop augmentation. Each split is
    then wrapped in ``ColorizationDataset``.

    Args:
        config: Dict with a ``'dataset'`` section providing
            ``'name_colorization'`` (``'cifar10'`` or ``'cifar100'``),
            ``'batch_size'`` and ``'num_workers'``. An optional
            ``'split_seed'`` (default 42) controls the train/val split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If the dataset name is not supported.
    """
    dataset_cfg = config['dataset']
    dataset_name = dataset_cfg['name_colorization'].lower()
    batch_size = dataset_cfg['batch_size']
    num_workers = dataset_cfg['num_workers']

    if dataset_name == 'cifar10':
        dataset_cls = torchvision.datasets.CIFAR10
    elif dataset_name == 'cifar100':
        dataset_cls = torchvision.datasets.CIFAR100
    else:
        raise ValueError("Unsupported dataset. Choose 'cifar10' or 'cifar100'.")

    full_train_set = dataset_cls(root='./data', train=True, download=True, transform=transform_train_color)
    # Same training images, but without augmentation; used only for validation.
    full_train_set_eval = dataset_cls(root='./data', train=True, download=True, transform=transform_test_color)
    test_set = dataset_cls(root='./data', train=False, download=True, transform=transform_test_color)

    total_train = len(full_train_set)
    val_size = int(0.2 * total_train)
    train_size = total_train - val_size

    split_rng = torch.Generator().manual_seed(dataset_cfg.get('split_seed', 42))
    train_dataset, val_split = random_split(full_train_set, [train_size, val_size], generator=split_rng)
    # Re-point the validation indices at the non-augmented view of the data.
    val_dataset = torch.utils.data.Subset(full_train_set_eval, val_split.indices)

    train_color_dataset = ColorizationDataset(train_dataset, transform=None)
    val_color_dataset = ColorizationDataset(val_dataset, transform=None)
    test_color_dataset = ColorizationDataset(test_set, transform=None)

    train_loader = DataLoader(train_color_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_color_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_color_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader


In [None]:

# Build the classification loaders once and report the size of every split.
(
    train_loader_class,
    val_loader_class,
    test_loader_class,
    num_classes,
) = get_classification_dataloaders(config)

print(f"Number of classes (Classification): {num_classes}")
print(f"Training samples: {len(train_loader_class.dataset)}")
print(f"Validation samples: {len(val_loader_class.dataset)}")
print(f"Testing samples: {len(test_loader_class.dataset)}")


In [None]:

# Build the colorization loaders once and report the size of every split.
(
    train_loader_color,
    val_loader_color,
    test_loader_color,
) = get_colorization_dataloaders(config)

print(f"Colorization - Training samples: {len(train_loader_color.dataset)}")
print(f"Colorization - Validation samples: {len(val_loader_color.dataset)}")
print(f"Colorization - Testing samples: {len(test_loader_color.dataset)}")



In [None]:
def get_classification_model(num_classes, pretrained=True):
    """Create a ResNet-50 whose final fully connected layer emits ``num_classes`` outputs.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: Forwarded to ``models.resnet50``.

    Returns:
        The ResNet-50 module with its ``fc`` layer replaced by a fresh
        ``nn.Linear``.
    """
    backbone = models.resnet50(pretrained=pretrained)
    # Keep the feature width of the original head; only the output size changes.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone


In [None]:
# Instantiate the classifier, move it to the compute device and show its layers.
classification_model = get_classification_model(
    num_classes=num_classes,
    pretrained=True,
).to(device)
print(classification_model)



In [None]:
class ColorizationResNet(nn.Module):
    """ResNet-50 encoder plus a transposed-convolution decoder predicting ab channels.

    The encoder's first convolution is replaced by a single-channel version so
    the network accepts a one-channel (lightness) image. The decoder applies
    five stride-2 transposed convolutions followed by a 3x3 convolution and
    ``Tanh``, producing two output channels. The shape comments below assume a
    32x32 input.

    Args:
        pretrained: Forwarded to ``models.resnet50``. When true, the new
            single-channel stem is initialized with the channel-wise mean of
            the original 3-channel stem weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.encoder = models.resnet50(pretrained=pretrained)

        original_conv1 = self.encoder.conv1
        self.encoder.conv1 = nn.Conv2d(
            1,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d expects a bool here. Passing the bias tensor itself
            # only worked because the ResNet stem has no bias (None is falsy).
            bias=original_conv1.bias is not None,
        )

        if pretrained:
            with torch.no_grad():
                # Collapse the RGB filters into one grayscale filter.
                self.encoder.conv1.weight = nn.Parameter(
                    original_conv1.weight.mean(dim=1, keepdim=True)
                )
                if original_conv1.bias is not None:
                    self.encoder.conv1.bias = nn.Parameter(original_conv1.bias.clone())

        # The classification head is not used by forward(); make it a no-op.
        self.encoder.fc = nn.Identity()

        # Five upsampling stages (each doubles H and W), built in the same
        # order as before so the Sequential indices / state_dict keys match.
        stage_channels = [2048, 1024, 512, 256, 128, 64]
        decoder_layers = []
        for in_ch, out_ch in zip(stage_channels[:-1], stage_channels[1:]):
            decoder_layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        decoder_layers.append(nn.Conv2d(64, 2, kernel_size=3, padding=1))  # [Batch,2,32,32]
        decoder_layers.append(nn.Tanh())  # bounds predictions to [-1, 1]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Map a ``[Batch, 1, H, W]`` input to a ``[Batch, 2, H', W']`` ab prediction."""
        # Encoder (avgpool and fc are intentionally skipped)
        x = self.encoder.conv1(x)    # [Batch, 64, 16, 16]
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)  # [Batch, 64, 8, 8]

        x = self.encoder.layer1(x)   # [Batch, 256, 8, 8]
        x = self.encoder.layer2(x)   # [Batch, 512, 4, 4]
        x = self.encoder.layer3(x)   # [Batch, 1024, 2, 2]
        x = self.encoder.layer4(x)   # [Batch, 2048, 1, 1]

        # Decoder
        x = self.decoder(x)          # [Batch,2,32,32]
        return x


In [None]:
# Instantiate the colorization network, move it to the compute device and show its layers.
colorization_model = ColorizationResNet(pretrained=True).to(device)
print(colorization_model)


In [None]:
def train_classification_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader`` and report loss and accuracy.

    Both metrics are averaged over the samples actually processed. Dividing
    the loss by ``len(dataloader.dataset)`` instead gave a wrong value
    whenever the loader skipped samples (e.g. ``drop_last=True`` or a custom
    sampler), and was inconsistent with how accuracy was computed.

    Args:
        model: Module mapping a batch of inputs to class scores.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor that supports ``backward()``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``; ``(0.0, 0.0)`` if the loader
        yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_count = inputs.size(0)
        # criterion returns a batch mean, so weight it by the batch size.
        loss_sum += loss.item() * batch_count

        _, predicted = torch.max(outputs, 1)
        num_correct += (predicted == labels).sum().item()
        num_seen += batch_count

    if num_seen == 0:
        return 0.0, 0.0
    return loss_sum / num_seen, 100 * num_correct / num_seen

def validate_classification(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Both metrics are averaged over the samples actually processed, so the
    result stays correct when the loader skips samples (``drop_last``, custom
    sampler) and an empty loader no longer triggers a ZeroDivisionError.

    Args:
        model: Module mapping a batch of inputs to class scores.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``; ``(0.0, 0.0)`` if the loader
        yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_count = inputs.size(0)
            # criterion returns a batch mean, so weight it by the batch size.
            loss_sum += loss.item() * batch_count

            _, predicted = torch.max(outputs, 1)
            num_correct += (predicted == labels).sum().item()
            num_seen += batch_count

    if num_seen == 0:
        return 0.0, 0.0
    return loss_sum / num_seen, 100 * num_correct / num_seen


In [None]:
def train_colorization_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass of the colorization model and return the mean loss.

    The loss is averaged over the samples actually processed. Dividing by
    ``len(dataloader.dataset)`` gave a wrong value whenever the loader skipped
    samples (e.g. ``drop_last=True``) and crashed on an empty dataset.

    Args:
        model: Module mapping lightness inputs to ab predictions.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor that supports ``backward()``.
        optimizer: Optimizer updating the trainable parameters.
        device: Device the batches are moved to.

    Returns:
        Mean per-sample loss as a float; ``0.0`` if the loader yields no samples.

    Raises:
        ValueError: If the model output shape differs from the target shape.
    """
    model.train()
    loss_sum = 0.0
    num_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)    # Shape: [batch_size, 1, 32, 32]
        targets = targets.to(device)  # Shape: [batch_size, 2, 32, 32]

        outputs = model(inputs)       # Shape: [batch_size, 2, 32, 32]

        # Fail loudly instead of letting the loss broadcast mismatched shapes.
        if outputs.shape != targets.shape:
            print("Mismatch in shapes:")
            print(f"Outputs shape: {outputs.shape}")
            print(f"Targets shape: {targets.shape}")
            raise ValueError("Shape mismatch between outputs and targets.")

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_count = inputs.size(0)
        # criterion returns a batch mean, so weight it by the batch size.
        loss_sum += loss.item() * batch_count
        num_seen += batch_count

    if num_seen == 0:
        return 0.0
    return loss_sum / num_seen

def validate_colorization(model, dataloader, criterion, device):
    """Evaluate the colorization model without gradient tracking.

    The loss is averaged over the samples actually processed, so the result
    stays correct when the loader skips samples and an empty loader no longer
    triggers a ZeroDivisionError.

    Args:
        model: Module mapping lightness inputs to ab predictions.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        device: Device the batches are moved to.

    Returns:
        Mean per-sample loss as a float; ``0.0`` if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            loss = criterion(model(inputs), targets)

            batch_count = inputs.size(0)
            # criterion returns a batch mean, so weight it by the batch size.
            loss_sum += loss.item() * batch_count
            num_seen += batch_count

    if num_seen == 0:
        return 0.0
    return loss_sum / num_seen


In [None]:

# Start from a fresh classifier for the training run below.
classification_model = get_classification_model(
    num_classes=num_classes,
    pretrained=True,
).to(device)
print(classification_model)


In [None]:
# Loss and optimizer for the classification task; the step size comes from the config.
criterion_class = nn.CrossEntropyLoss()
optimizer_class = optim.Adam(
    classification_model.parameters(),
    lr=config['training_classification']['learning_rate'],
)


In [None]:
# Classification training settings pulled from the config.
num_epochs_class = config['training_classification']['epochs']
save_interval_class = config['training_classification']['save_interval']
checkpoint_dir_class = config['training_classification']['checkpoint_dir']

for epoch in range(1, num_epochs_class + 1):
    train_loss, train_acc = train_classification_one_epoch(classification_model, train_loader_class, criterion_class, optimizer_class, device)
    val_loss, val_acc = validate_classification(classification_model, val_loader_class, criterion_class, device)

    print(f"Epoch [{epoch}/{num_epochs_class}]")
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%\n")

    # Always checkpoint the last epoch as well: the weight-transfer step later
    # in the notebook loads `resnet50_classification_epoch_{num_epochs_class}.pth`,
    # which was never written when `epochs` was not a multiple of `save_interval`.
    is_final_epoch = epoch == num_epochs_class
    if epoch % save_interval_class == 0 or is_final_epoch:
        checkpoint_path_class = os.path.join(checkpoint_dir_class, f"resnet50_classification_epoch_{epoch}.pth")
        torch.save(classification_model.state_dict(), checkpoint_path_class)
        print(f"Checkpoint saved at {checkpoint_path_class}\n")


In [None]:
def transfer_weights_class_to_color(model_class, model_color):
    """Copy classifier backbone weights into the colorization model's encoder.

    The classifier's state-dict keys are un-prefixed (``layer1.0.conv1.weight``)
    while the colorization model stores its backbone under ``encoder.``
    (``encoder.layer1.0.conv1.weight``). The previous implementation compared
    the raw keys, so nothing ever matched and no weight was transferred even
    though a success message was printed. Keys are now mapped onto the
    ``encoder.`` namespace before matching.

    The stem convolution (``encoder.conv1``) and the classification head
    (``encoder.fc``) are never transferred.

    Args:
        model_class: Trained classification model (source of the weights).
        model_color: Colorization model whose ``encoder`` receives the weights;
            updated in place.

    Raises:
        ValueError: If no parameter of ``model_class`` maps onto ``model_color``.
    """
    source_state = model_class.state_dict()
    target_state = model_color.state_dict()
    excluded_prefixes = ('encoder.conv1', 'encoder.fc')

    pretrained_dict = {}
    for name, tensor in source_state.items():
        # Accept keys that already carry the prefix; otherwise add it.
        target_name = name if name in target_state else f"encoder.{name}"
        if target_name not in target_state:
            continue
        if target_name.startswith(excluded_prefixes):
            continue
        pretrained_dict[target_name] = tensor

    if not pretrained_dict:
        raise ValueError("No classification weights matched the colorization model's encoder.")

    target_state.update(pretrained_dict)
    model_color.load_state_dict(target_state)

    print("Transferred pretrained weights to colorization model (excluding encoder.conv1 and encoder.fc layers).")


In [None]:
# Start from a fresh colorization network for the transfer and training steps below.
colorization_model = ColorizationResNet(pretrained=True).to(device)
print(colorization_model)


In [None]:

# Checkpoint written after the final classification epoch.
classification_checkpoint = os.path.join(
    checkpoint_dir_class,
    f"resnet50_classification_epoch_{num_epochs_class}.pth",
)

if not os.path.exists(classification_checkpoint):
    print(f"Checkpoint not found at {classification_checkpoint}. Please ensure the classification model is trained and checkpointed.")
else:
    # Restore the trained classifier, then hand its weights to the colorization model.
    saved_state = torch.load(classification_checkpoint, map_location=device)
    classification_model.load_state_dict(saved_state)
    transfer_weights_class_to_color(classification_model, colorization_model)


In [None]:
# Regression loss on the ab channels; only the decoder's parameters are optimized.
criterion_color = nn.MSELoss()
optimizer_color = optim.Adam(
    colorization_model.decoder.parameters(),
    lr=config['training_colorization']['learning_rate'],
)


In [None]:
# Freeze the encoder: none of its parameters should receive gradients.
for param in colorization_model.encoder.parameters():
    setattr(param, 'requires_grad', False)


In [None]:
# Colorization training settings pulled from the config.
num_epochs_color = config['training_colorization']['epochs']
save_interval_color = config['training_colorization']['save_interval']
checkpoint_dir_color = config['training_colorization']['checkpoint_dir']

for epoch in range(1, num_epochs_color + 1):
    # One optimization pass followed by an evaluation on the validation split.
    train_loss = train_colorization_one_epoch(
        colorization_model, train_loader_color, criterion_color, optimizer_color, device
    )
    val_loss = validate_colorization(
        colorization_model, val_loader_color, criterion_color, device
    )

    print(f"Epoch [{epoch}/{num_epochs_color}]")
    print(f"Training Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}\n")

    # Only epochs that are a multiple of the save interval write a checkpoint.
    if epoch % save_interval_color != 0:
        continue

    checkpoint_path_color = os.path.join(checkpoint_dir_color, f"resnet50_colorization_epoch_{epoch}.pth")
    torch.save(colorization_model.state_dict(), checkpoint_path_color)
    print(f"Checkpoint saved at {checkpoint_path_color}\n")


In [None]:
def lab_to_rgb(L, ab):
    """Convert scaled L and ab tensors back into an RGB image array.

    The inputs are rescaled by the inverse of the factors used when the
    training pairs were built (L by 100, ab by 128), stacked into one Lab
    image and converted with ``lab2rgb``.

    Tensors are detached first, so the function also works on outputs that
    still require grad; ``.numpy()`` raises on such tensors otherwise.

    Args:
        L: Lightness tensor that squeezes to ``[H, W]``.
        ab: Chroma tensor that squeezes to ``[2, H, W]``.

    Returns:
        ``[H, W, 3]`` array of RGB values clipped to ``[0, 1]``.
    """
    lightness = L.detach().squeeze().cpu().numpy() * 100.0   # Denormalize L
    chroma = ab.detach().squeeze().cpu().numpy() * 128.0     # Denormalize ab

    height, width = lightness.shape[0], lightness.shape[1]
    lab = np.zeros((height, width, 3))
    lab[:, :, 0] = lightness
    lab[:, :, 1:] = chroma.transpose((1, 2, 0))  # [2, H, W] -> [H, W, 2]

    return np.clip(lab2rgb(lab), 0, 1)


In [None]:
# Take the first test batch and keep a handful of samples for display.
dataiter = iter(test_loader_color)
inputs, targets = next(dataiter)

num_samples = 5
inputs_sample, targets_sample = (
    batch[:num_samples].to(device) for batch in (inputs, targets)
)

# Predict the ab channels without tracking gradients.
colorization_model.eval()
with torch.no_grad():
    outputs = colorization_model(inputs_sample)

# One row per sample: grayscale input, ground truth, model output.
fig, axs = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))
for i in range(num_samples):
    L = inputs_sample[i]
    panels = (
        # Zero ab channels render the lightness channel alone.
        ('Grayscale Input', lab_to_rgb(L, torch.zeros_like(outputs[i]))),
        ('Ground Truth', lab_to_rgb(L, targets_sample[i])),
        ('Colorized Output', lab_to_rgb(L, outputs[i])),
    )
    for col, (title, image) in enumerate(panels):
        axis = axs[i, col]
        axis.imshow(image)
        axis.set_title(title)
        axis.axis('off')

plt.tight_layout()
plt.show()
