In [170]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # Prevents OMP conflicts

In [172]:
import torch
import random
import numpy as np

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()


In [174]:
import torch

def check_gpu():
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available - Falling back to CPU")
        return torch.device("cpu")
    
    try:
        # Get first available GPU
        device = torch.device("cuda:0")
        
        # Test communication
        test_tensor = torch.tensor([1.0]).to(device)
        if test_tensor.item() == 1.0:
            print(f"✅ GPU active: {torch.cuda.get_device_name(0)}")
            return device
    except RuntimeError as e:
        print(f"❌ GPU test failed: {str(e)}")
    
    print("⚠️ Falling back to CPU")
    return torch.device("cpu")

device = check_gpu()
print("Using device:", device)

✅ GPU active: NVIDIA GeForce RTX 4060 Laptop GPU
Using device: cuda:0


In [176]:
import torch
torch.cuda.empty_cache()  # Try freeing memory

In [178]:
import torch
import sys
import subprocess

print("="*50)
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Version: {torch.version.cuda if hasattr(torch.version, 'cuda') else 'N/A'}")
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"Device Count: {torch.cuda.device_count()}")
print("-"*50)

try:
    nvidia_smi = subprocess.check_output('nvidia-smi', shell=True).decode()
    print("nvidia-smi output:\n", nvidia_smi)
except Exception as e:
    print(f"nvidia-smi error: {str(e)}")

print("="*50)

Python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:17:27) [MSC v.1929 64 bit (AMD64)]
PyTorch: 2.7.0+cu118
CUDA Version: 11.8
CUDA Available: True
Device Count: 1
--------------------------------------------------
nvidia-smi output:
 Thu May 29 12:13:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.97                 Driver Version: 555.97         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4060 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   37C    P8              1W /  115W |     795MiB /   8188MiB |    

In [180]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))


CUDA available: True
Device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 4060 Laptop GPU


In [182]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [184]:
#To be used when defining the model 
# model.to(device)
# tensor.to(device)


In [186]:
%pip install torch
%pip install kaggle
%pip install torchvision
%pip install scikit-image
%pip install numpy

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


# Dataset Preparation

In [188]:
#!kaggle datasets download -d requiemonk/sentinel12-image-pairs-segregated-by-terrain

In [189]:
# %%capture
# !unzip sentinel12-image-pairs-segregated-by-terrain.zip
# !rm -rf sentinel12-image-pairs-segregated-by-terrain.zip

In [190]:
import os

opt = []
sar = []
root_dir = '../v_2'

# Verify root directory exists
if not os.path.exists(root_dir):
    raise FileNotFoundError(f"Root directory {root_dir} not found!")

for category in os.listdir(root_dir):
    # Skip hidden files/folders
    if category.startswith('.') or category == '.ipynb_checkpoints':
        continue
        
    category_path = os.path.join(root_dir, category)
    
    # Skip non-directories
    if not os.path.isdir(category_path):
        continue
        
    # Find s1 and s2 subdirectories
    subdirs = [d for d in os.listdir(category_path) 
              if os.path.isdir(os.path.join(category_path, d))]
    
    # Process each subdirectory
    for subdir in subdirs:
        subdir_path = os.path.join(category_path, subdir)
        
        # Collect SAR images
        if subdir == 's1':
            sar.extend([
                os.path.join(subdir_path, f) 
                for f in os.listdir(subdir_path) 
                if f.lower().endswith('.png')
            ])
            
        # Collect Optical images
        elif subdir == 's2':
            opt.extend([
                os.path.join(subdir_path, f) 
                for f in os.listdir(subdir_path) 
                if f.lower().endswith('.png')
            ])

# Final sorting and validation
opt = sorted(opt)
sar = sorted(sar)
print(f"Found {len(opt)} optical images, {len(sar)} SAR images")

Found 16000 optical images, 16000 SAR images


In [191]:
# Check first few paths
print("Sample SAR paths:")
print(sar[:2])
print("\nSample Optical paths:")
print(opt[:2])

# Verify all paths exist
all_paths = sar + opt
missing = [p for p in all_paths if not os.path.exists(p)]
print(f"\nMissing files: {len(missing)}")

Sample SAR paths:
['../v_2\\agri\\s1\\ROIs1868_summer_s1_59_p10.png', '../v_2\\agri\\s1\\ROIs1868_summer_s1_59_p100.png']

Sample Optical paths:
['../v_2\\agri\\s2\\ROIs1868_summer_s2_59_p10.png', '../v_2\\agri\\s2\\ROIs1868_summer_s2_59_p100.png']

Missing files: 0


# **Implementing Colorization Model**

In [193]:
import cv2
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt



## Preparing dataset for colorization model

In [195]:
import cv2
def rgb_to_lab_cv2(pil_img):
    # Fix 1: Correct variable names and conversion flags
    img_rgb = np.array(pil_img.convert("RGB")).astype("float32") / 255.0
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)  # Fixed conversion flag
    img_lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab)   # Fixed conversion flag
    
    # Fix 2: Proper tensor conversion
    img_lab = torch.from_numpy(img_lab.transpose(2, 0, 1)).float()
    
    # Fix 3: Correct channel indexing
    L = (img_lab[0:1, ...] / 50.0) - 1.0    # [0,100] -> [-1,1]
    ab = (img_lab[1:3, ...] - 128.0) / 128.0  # Normalize ab channels
    
    return L, ab

In [196]:
img = Image.open(opt[0]).resize((256, 256))
L, ab = rgb_to_lab_cv2(img)
print("L:", L.shape, "ab:", ab.shape)


L: torch.Size([1, 256, 256]) ab: torch.Size([2, 256, 256])


In [197]:
def create_patches(img_tensor, patch_size=224):
    """Split tensor into patches (img_tensor: [C, H, W])"""
    patches = []
    c, h, w = img_tensor.shape
    
    for i in range(0, h, patch_size):
        for j in range(0, w, patch_size):
            if i + patch_size <= h and j + patch_size <= w:
                patch = img_tensor[:, i:i+patch_size, j:j+patch_size]
                patches.append(patch)
    return torch.stack(patches) if patches else None

In [198]:
class ColorizationDataset(Dataset):
    def __init__(self, color_paths, transform=None, img_size=224):
        self.color_paths = color_paths
        self.transform = transform
        self.img_size = img_size
        
    def __len__(self):
        return len(self.color_paths)
    
    def __getitem__(self, idx):
        # Fix 1: Variable name correction
        pil_img = Image.open(self.color_paths[idx]).convert("RGB")
        
        # Apply transformations before conversion to Lab
        if self.transform:
            pil_img = self.transform(pil_img)
        else:
            pil_img = pil_img.resize((self.img_size, self.img_size))
            
        L, ab = rgb_to_lab_cv2(pil_img)
        return L, ab

# Add this above dataset creation
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224))
])

In [199]:
dataset = ColorizationDataset(opt[:1000])

# Split dataset into training and validation
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Data loaders
# Update your data loaders:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, 
                          num_workers=0, pin_memory=True)  # num_workers=0
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False,
                        num_workers=0, pin_memory=True)     # num_workers=0
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                         num_workers=0, drop_last=True)      # num_workers=0


In [200]:
# Load a batch and print the shapes of the patches
for L_patches, ab_patches in train_loader:
    print(f"L_patches shape: {L_patches.shape}")
    print(f"ab_patches shape: {ab_patches.shape}")
    print(L_patches[0])
    break

L_patches shape: torch.Size([16, 1, 224, 224])
ab_patches shape: torch.Size([16, 2, 224, 224])
tensor([[[ 0.2850,  0.1689,  0.1709,  ...,  0.8491,  0.7821,  0.6760],
         [ 0.3840,  0.4141,  0.4686,  ...,  0.9592,  0.9485,  0.8821],
         [ 0.4254,  0.2477,  0.2405,  ...,  0.7600,  0.8180,  0.8152],
         ...,
         [-0.8616, -0.7687, -0.7786,  ...,  0.7795,  0.6974,  0.5947],
         [-0.8596, -0.8247, -0.8203,  ...,  0.5887,  0.7018,  0.7974],
         [-0.7771, -0.7858, -0.7433,  ..., -0.0098,  0.0870,  0.2286]]])


## Implementing the Encoder

In [202]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights, DenseNet121_Weights

class EnsembleEncoder(nn.Module):
    def __init__(self):
        super(EnsembleEncoder, self).__init__()

        # Load pre-trained ResNet50 and DenseNet121
        self.resnet50 = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.densenet121 = models.densenet121(weights=DenseNet121_Weights.DEFAULT)

        self.resnet50 = nn.Sequential(*list(self.resnet50.children())[:-2])
        # self.densenet121 = nn.Sequential(*list(self.densenet121.children())[:-1])
        self.densenet121.classifier = nn.Identity()


        # Custom layers for fusion
        self.conv1x1_resnet50 = nn.ModuleList([
            nn.Conv2d(256, 128, kernel_size=1),
            nn.Conv2d(512, 256, kernel_size=1),
            nn.Conv2d(1024, 512, kernel_size=1),
            nn.Conv2d(2048, 1024, kernel_size=1)
        ])

        self.conv1x1_densenet121 = nn.ModuleList([
            nn.Conv2d(256, 128, kernel_size=1),
            nn.Conv2d(512, 256, kernel_size=1),
            nn.Conv2d(1024, 512, kernel_size=1),
            nn.Conv2d(1024, 1024, kernel_size=1)
        ])

        self.fusion_blocks = nn.ModuleList([
            self.fusion_block(128, 128),
            self.fusion_block(256, 256),
            self.fusion_block(512, 512),
            self.fusion_block(1024, 1024)
        ])

    # Fusion block
    def fusion_block(self, in_channels_resnet, in_channels_densenet):
        return nn.Sequential(
            nn.Conv2d(in_channels_resnet + in_channels_densenet, in_channels_resnet, kernel_size=1),
            nn.BatchNorm2d(in_channels_resnet),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Forward pass through ResNet50
        resnet_features = []
        resnet_input = x
        for i, layer in enumerate(self.resnet50.children()):
            resnet_input = layer(resnet_input)
            if i in [4, 5, 6, 7]:  # Extract features after specific layers
                resnet_features.append(self.conv1x1_resnet50[i-4](resnet_input))

        # Forward pass through DenseNet121
        densenet_features = []
        idx = 0
        densenet_input = x
        for i, layer in enumerate(self.densenet121.features):
            densenet_input = layer(densenet_input)
            if i in [ 4, 6, 8, 11]:
                densenet_features.append(self.conv1x1_densenet121[idx](densenet_input))
                idx += 1


        fused_features = []
        for i in range(4):
            fused = torch.cat((resnet_features[i], densenet_features[i]), dim=1)
            fused = self.fusion_blocks[i](fused)
            fused_features.append(fused)

        return fused_features

In [203]:
# # Test the fixed encoder
# encoder = EnsembleEncoder()
# dummy_input = torch.randn(1, 1, 224, 224)

# # Should output 3 feature maps with proper shapes
# features = encoder(dummy_input)
# print("\nFinal Feature Shapes:")
# for i, f in enumerate(features):
#     print(f"Level {i}: {f.shape}")

## Implementing the Decoder

In [205]:
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        # Decoder block 1: Takes input from Fusion Block 4
        self.decode1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 7x7 -> 14x14
        )

        # Decoder block 2: Takes input from Decoder Block 1 + Fusion Block 3 (512 + 512 channels)
        self.decode2 = nn.Sequential(
            nn.Conv2d(512 + 512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 14x14 -> 28x28
        )

        # Decoder block 3: Takes input from Decoder Block 2 + Fusion Block 2 (256 + 256 channels)
        self.decode3 = nn.Sequential(
            nn.Conv2d(256 + 256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 28x28 -> 56x56
        )

        # Decoder block 4: Takes input from Decoder Block 3 + Fusion Block 1 (128 + 128 channels)
        self.decode4 = nn.Sequential(
            nn.Conv2d(128 + 128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 56x56 -> 112x112
        )

        # Final decoder block: Reduce to 2 channels (ab channels)
        self.decode5 = nn.Sequential(
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),            
            nn.BatchNorm2d(2),            
            nn.Tanh(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 112x112 -> 224x224
        )

    def forward(self, features_7x7, features_14x14, features_28x28, features_56x56):
        x = self.decode1(features_7x7)
        x = torch.cat([x, features_14x14], dim=1)
        x = self.decode2(x)

        x = torch.cat([x, features_28x28], dim=1)
        x = self.decode3(x)

        x = torch.cat([x, features_56x56], dim=1)
        x = self.decode4(x)

        output = self.decode5(x)

        return output

## Checking our model

In [207]:
import torch
import torch.nn as nn

class ColorizationModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ColorizationModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        features_56x56, features_28x28, features_14x14, features_7x7 = self.encoder(x)

        output = self.decoder(features_7x7, features_14x14, features_28x28, features_56x56)

        return output

encoder = EnsembleEncoder().to(device)
decoder = Decoder().to(device)

model = ColorizationModel(encoder, decoder)

# input data
L_patches = torch.randn(1, 3, 224, 224).to(device)

output = model(L_patches)

print("Output shape:", output.shape)  # output shape should be [1, 2, 224, 224]

Output shape: torch.Size([1, 2, 224, 224])


## Training the model

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize encoder and decoder
encoder = EnsembleEncoder().to(device)
decoder = Decoder().to(device)

# Freeze the encoder parameters as they are pre-trained
for param in encoder.parameters():
    param.requires_grad = False

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(decoder.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

# Training loop
num_epochs = 20
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # Training phase
    encoder.eval()
    decoder.train()
    running_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Training)")
    for i, (L_batch, ab_batch) in enumerate(train_bar):
        L, ab = L_batch.to(device), ab_batch.to(device)
        L = L.repeat(1, 3, 1, 1)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        features_56x56, features_28x28, features_14x14, features_7x7 = encoder(L)
        output = decoder(features_7x7, features_14x14, features_28x28, features_56x56)

        # Compute loss
        loss = criterion(output, ab)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate running loss
        running_loss += loss.item()

        # Update progress bar
        train_bar.set_postfix(loss=f"{running_loss/(i+1):.4f}")

    # Validation phase
    decoder.eval()
    val_loss = 0.0

    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)")
    with torch.no_grad():
        for i, (L_batch, ab_batch) in enumerate(val_bar):
            L, ab = L_batch.to(device), ab_batch.to(device)
            L = L.repeat(1, 3, 1, 1)  

            # Forward pass
            features_56x56, features_28x28, features_14x14, features_7x7 = encoder(L)
            output = decoder(features_7x7, features_14x14, features_28x28, features_56x56)

            # Compute validation loss
            loss = criterion(output, ab)
            val_loss += loss.item()

            val_bar.set_postfix(loss=f"{val_loss/(i+1):.4f}")

    # Calculate average losses
    avg_train_loss = running_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    # Save the best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(decoder.state_dict(), 'model_1.pth')
        print(f"Model saved with validation loss: {best_val_loss:.4f}")

    # Step the scheduler
    scheduler.step(avg_val_loss)

print("Training complete.")


Epoch 1/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.98it/s, loss=0.8593]
Epoch 1/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.15it/s, loss=0.9242]


Epoch 1/20, Training Loss: 0.8593, Validation Loss: 0.9242
Model saved with validation loss: 0.9242


Epoch 2/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:19<00:00,  2.01it/s, loss=0.7602]
Epoch 2/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.04it/s, loss=0.7754]


Epoch 2/20, Training Loss: 0.7602, Validation Loss: 0.7754
Model saved with validation loss: 0.7754


Epoch 3/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.94it/s, loss=0.6969]
Epoch 3/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.31it/s, loss=0.7475]


Epoch 3/20, Training Loss: 0.6969, Validation Loss: 0.7475
Model saved with validation loss: 0.7475


Epoch 4/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.98it/s, loss=0.6398]
Epoch 4/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.02it/s, loss=0.5258]


Epoch 4/20, Training Loss: 0.6398, Validation Loss: 0.5258
Model saved with validation loss: 0.5258


Epoch 5/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.92it/s, loss=0.5912]
Epoch 5/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:05<00:00,  1.98it/s, loss=0.5705]


Epoch 5/20, Training Loss: 0.5912, Validation Loss: 0.5705


Epoch 6/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.94it/s, loss=0.5396]
Epoch 6/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.01it/s, loss=0.5202]


Epoch 6/20, Training Loss: 0.5396, Validation Loss: 0.5202
Model saved with validation loss: 0.5202


Epoch 7/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:21<00:00,  1.90it/s, loss=0.4989]
Epoch 7/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:05<00:00,  2.00it/s, loss=0.4004]


Epoch 7/20, Training Loss: 0.4989, Validation Loss: 0.4004
Model saved with validation loss: 0.4004


Epoch 8/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:21<00:00,  1.90it/s, loss=0.4564]
Epoch 8/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:04<00:00,  2.02it/s, loss=0.3989]


Epoch 8/20, Training Loss: 0.4564, Validation Loss: 0.3989
Model saved with validation loss: 0.3989


Epoch 9/20 (Training): 100%|██████████████████████████████████████████████| 40/40 [00:20<00:00,  1.93it/s, loss=0.4328]
Epoch 9/20 (Validation): 100%|████████████████████████████████████████████| 10/10 [00:05<00:00,  1.96it/s, loss=0.3395]


Epoch 9/20, Training Loss: 0.4328, Validation Loss: 0.3395
Model saved with validation loss: 0.3395


Epoch 10/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:21<00:00,  1.90it/s, loss=0.3964]
Epoch 10/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.96it/s, loss=0.3355]


Epoch 10/20, Training Loss: 0.3964, Validation Loss: 0.3355
Model saved with validation loss: 0.3355


Epoch 11/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.97it/s, loss=0.3753]
Epoch 11/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.96it/s, loss=0.3331]


Epoch 11/20, Training Loss: 0.3753, Validation Loss: 0.3331
Model saved with validation loss: 0.3331


Epoch 12/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.91it/s, loss=0.3578]
Epoch 12/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.98it/s, loss=0.5091]


Epoch 12/20, Training Loss: 0.3578, Validation Loss: 0.5091


Epoch 13/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.92it/s, loss=0.3308]
Epoch 13/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:04<00:00,  2.06it/s, loss=0.2376]


Epoch 13/20, Training Loss: 0.3308, Validation Loss: 0.2376
Model saved with validation loss: 0.2376


Epoch 14/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.92it/s, loss=0.3139]
Epoch 14/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.98it/s, loss=0.2680]


Epoch 14/20, Training Loss: 0.3139, Validation Loss: 0.2680


Epoch 15/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.91it/s, loss=0.2881]
Epoch 15/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.96it/s, loss=0.2029]


Epoch 15/20, Training Loss: 0.2881, Validation Loss: 0.2029
Model saved with validation loss: 0.2029


Epoch 16/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.94it/s, loss=0.3316]
Epoch 16/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.98it/s, loss=0.2753]


Epoch 16/20, Training Loss: 0.3316, Validation Loss: 0.2753


Epoch 17/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:21<00:00,  1.90it/s, loss=0.2919]
Epoch 17/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.99it/s, loss=0.2165]


Epoch 17/20, Training Loss: 0.2919, Validation Loss: 0.2165


Epoch 18/20 (Training): 100%|█████████████████████████████████████████████| 40/40 [00:20<00:00,  1.94it/s, loss=0.2564]
Epoch 18/20 (Validation): 100%|███████████████████████████████████████████| 10/10 [00:05<00:00,  1.98it/s, loss=0.2039]


Epoch 18/20, Training Loss: 0.2564, Validation Loss: 0.2039


Epoch 19/20 (Training):  32%|██████████████▋                              | 13/40 [00:06<00:14,  1.93it/s, loss=0.2425]

In [None]:
torch.cuda.empty_cache()

## Load the model

In [None]:
decoder = Decoder().to(device)
decoder.load_state_dict(torch.load('model_1.pth', map_location=device))

## Inference

In [None]:
# Get a batch from the test loader
dataiter = iter(test_loader)
L_batch, ab_batch = next(dataiter)
L_batch, ab_batch = next(dataiter)
L_batch, ab_batch = L_batch.to(device), ab_batch.to(device)
L_batch = L_batch.repeat(1, 3, 1, 1)

encoder.eval()
decoder.eval()
with torch.no_grad():
    features_56x56, features_28x28, features_14x14, features_7x7 = encoder(L_batch)

    predicted_ab = decoder(features_7x7, features_14x14, features_28x28, features_56x56)


In [None]:
L_batch = L_batch[:, 0, :, :]
L_batch = L_batch.unsqueeze(1)

L_batch = (L_batch + 1) * 0.5 * 100
predicted_ab = ((predicted_ab + 1) * 0.5 * (127 + 128)) - 128
ab_batch = ((ab_batch + 1) * 0.5 * (127 + 128)) - 128

# Combine L and ab channels
predicted_lab = torch.cat([L_batch, predicted_ab], dim=1)
real_lab = torch.cat([L_batch, ab_batch], dim=1)


predicted_lab = predicted_lab.cpu().numpy()
real_lab = real_lab.cpu().numpy()

In [None]:
import numpy as np
from skimage.color import lab2rgb
import matplotlib.pyplot as plt
import cv2


# Iterate over the batch
for i in range(8):

    lab_image = predicted_lab[i]
    real_img = real_lab[i]

    # Transpose to (height, width, 3) for skimage
    lab_image = lab_image.transpose(1, 2, 0).astype(np.float64)
    real_img = real_img.transpose(1, 2, 0).astype(np.float64)

    rgb_image = lab2rgb(lab_image)
    real_rgb = lab2rgb(real_img)

    plt.subplot(1, 2, 1)
    plt.imshow(real_rgb)
    plt.axis('off')
    plt.title('Real Color Image')
    plt.subplot(1, 2, 2)
    plt.imshow(rgb_image)
    plt.title('Predicted Color Image')
    plt.axis('off')
    plt.show()

In [None]:
# real_rgb = to_rgb_safe(L, ab_real)
# plt.imshow(real_rgb)
# plt.title("Ground Truth Check")
# plt.axis("off")
# plt.show()


In [None]:
# # Manual dummy LAB: mid-gray with blue tint
# lab = np.zeros((224, 224, 3), dtype=np.float32)
# lab[:, :, 0] = 50         # L: mid-brightness
# lab[:, :, 1] = 0          # a: neutral
# lab[:, :, 2] = -50        # b: blue tint

# rgb = lab2rgb(lab)

# plt.imshow(rgb)
# plt.title("Manual LAB to RGB Test")
# plt.axis("off")
# plt.show()


In [None]:
# import numpy as np
# from skimage.color import lab2rgb
# import matplotlib.pyplot as plt
# import torch

# def to_rgb_safe(L_tensor, ab_tensor):
#     # Remove batch dim
#     L = L_tensor.squeeze(0).cpu().numpy()
#     ab = ab_tensor.squeeze(0).cpu().numpy()

#     # Denormalize properly
#     L = (L + 1) * 50                # [0, 100]
#     ab = ab * 127.5                 # [-128, 127]

#     # Stack to (H, W, 3)
#     lab = np.zeros((224, 224, 3), dtype=np.float32)
#     lab[:, :, 0] = L[0]            # L channel
#     lab[:, :, 1] = ab[0]           # a channel
#     lab[:, :, 2] = ab[1]           # b channel

#     # LAB to RGB
#     rgb = lab2rgb(lab)
#     return rgb

# # Choose a sample
# i = 0
# L = L_batch[i].unsqueeze(0)
# ab_real = ab_batch[i].unsqueeze(0)
# ab_pred = predicted_ab[i].unsqueeze(0)

# # Convert
# real_rgb = to_rgb_safe(L, ab_real)
# pred_rgb = to_rgb_safe(L, ab_pred)

# # Plot
# plt.figure(figsize=(8, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(real_rgb)
# plt.title("Real Color Image")
# plt.axis("off")

# plt.subplot(1, 2, 2)
# plt.imshow(pred_rgb)
# plt.title("Predicted Color Image")
# plt.axis("off")

# plt.show()


## Evaulting the Model

In [None]:
def prediction(model, test_loader):
    encoder.eval()
    model.eval()
    original_images = []
    predicted_images = []

    for L_batch, ab_batch in tqdm(test_loader):
        L_batch, ab_batch = L_batch.to(device), ab_batch.to(device)
        input = L_batch.repeat(1, 3, 1, 1)

        with torch.no_grad():
            features_56x56, features_28x28, features_14x14, features_7x7 = encoder(input)

            predicted_ab = model(features_7x7, features_14x14, features_28x28, features_56x56)            

        L_batch = (L_batch + 1) * 0.5 * 100        
        predicted_ab = ((predicted_ab + 1) * 0.5 * (127 + 128)) - 128
        ab_batch = ((ab_batch + 1) * 0.5 * (127 + 128)) - 128

        # Combine L and ab channels
        predicted_lab = torch.cat([L_batch, predicted_ab], dim=1)
        actual_lab = torch.cat([L_batch, ab_batch], dim=1)

        predicted_lab = predicted_lab.cpu().numpy()
        actual_lab = actual_lab.cpu().numpy()

        predicted_images.extend(predicted_lab)
        original_images.extend(actual_lab)

    return original_images, predicted_images

original_images, predicted_images = prediction(decoder, test_loader)

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.color import lab2rgb
import numpy as np

def evaluate_model(original_images, predicted_images):

    total_ssim = 0.0
    total_psnr = 0.0
    total_samples = 0
    for original_img, predicted_img in zip(original_images, predicted_images):
        original_img = lab2rgb(original_img.transpose(1, 2, 0))
        predicted_img = lab2rgb(predicted_img.transpose(1, 2, 0))

        ssim_value = ssim(original_img, predicted_img, multichannel=True, channel_axis=2, data_range=1.0)
        psnr_value = psnr(original_img, predicted_img, data_range=1.0)
        total_ssim += ssim_value
        total_psnr += psnr_value
        total_samples += 1

    average_ssim = total_ssim / total_samples
    average_psnr = total_psnr / total_samples

    return average_ssim, average_psnr

ssim_value, psnr_value = evaluate_model(original_images, predicted_images)
print(f"Average SSIM: {ssim_value:.4f}")
print(f"Average PSNR: {psnr_value:.4f}")



# Results

We are getting quite good results but after observing carefully we saw that the model is learning green color more than the other colors.