# SAR Colorization V5: ColorSAR Protocol + Novel CBAM Architecture

## Key Insight
ColorSAR reaches SSIM ~0.94 because its evaluation keeps the **ground truth's L (lightness)** channel: the model predicts only the color (ab) channels, so the image structure comes from the ground truth rather than from the model.

## Our Approach
- **Same Protocol**: Predict ab channels from SAR, reconstruct using GT's L channel
- **Novel Architecture**: ResNet34 + **CBAM Attention** (vs their ResNet50+DenseNet ensemble)
- **Fair Comparison**: Using the same evaluation protocol puts both methods on the same footing for SSIM


In [None]:
# Install the notebook's third-party dependencies (IPython shell escape; -q keeps pip output quiet).
!pip install -q torch torchvision matplotlib numpy tqdm scikit-image

In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import glob, os
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from skimage import color
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
from google.colab import drive
import shutil

# Mount Google Drive so the dataset archive and the output folders are reachable.
drive.mount('/content/drive')
DRIVE_BASE = '/content/drive/MyDrive/SAR image colorization/SAR dataset'
LOCAL_DIR = '/content/patches'
VIS_DIR = os.path.join(DRIVE_BASE, 'visualizations_v5')
os.makedirs(LOCAL_DIR, exist_ok=True)
os.makedirs(VIS_DIR, exist_ok=True)

ZIP = os.path.join(DRIVE_BASE, 'patches.zip')
if os.path.exists(ZIP):
    # Copy the archive off Drive first, then extract from the local copy.
    shutil.copy2(ZIP, '/content/patches.zip')
    # unpack_archive overwrites existing files, so re-running this cell does not
    # stall on the interactive "replace?" prompt that `unzip` shows without -o.
    shutil.unpack_archive('/content/patches.zip', LOCAL_DIR, 'zip')
    npz = [os.path.join(d,f) for d,_,fs in os.walk(LOCAL_DIR) for f in fs if f.endswith('.npz')]
    # Point LOCAL_DIR at the folder holding the first .npz found
    # (assumes all patches live in a single directory — TODO confirm).
    if npz: LOCAL_DIR = os.path.dirname(npz[0])
    print(f'Dataset: {len(npz)} files')
else:
    # Make the missing archive visible instead of continuing silently with an
    # empty patch directory.
    print(f'WARNING: dataset archive not found: {ZIP}')

In [None]:
class LabDataset(Dataset):
    '''Paired SAR / RGB patches stored as ``.npz`` files, served in Lab space.

    Each item is the tuple ``(sar, ab, L, rgb)``:
      * sar: SAR patch scaled to [-1, 1] (model input)
      * ab:  Lab chroma of the RGB patch divided by 128, roughly [-1, 1] (target)
      * L:   Lab lightness mapped from [0, 100] to [-1, 1] (used at evaluation)
      * rgb: resized RGB patch in [0, 1] (used for metrics)

    Args:
        root: directory containing the ``*.npz`` patch files.
        size: side length every patch is resized to.

    Raises (from ``__getitem__``):
        IndexError: if the index is outside ``[-len, len)``.
        RuntimeError: if no file in the dataset can be loaded.
    '''
    def __init__(self, root, size=128):
        # Sorted so that the index -> file mapping does not depend on the
        # filesystem's listing order.
        self.files = sorted(glob.glob(os.path.join(root, '*.npz')))
        self.size = size
        self.resize = transforms.Resize((size, size), antialias=True)

    def __len__(self): return len(self.files)

    @staticmethod
    def _to_uint8(arr):
        '''Scale data whose maximum is <= 1 by 255, then clip into the uint8 range.'''
        scaled = arr * 255 if arr.max() <= 1 else arr
        # Clip before casting so out-of-range values saturate instead of wrapping.
        return np.clip(scaled, 0, 255).astype('uint8')

    def _load_sample(self, path):
        '''Load one ``.npz`` file and convert it to the four output tensors.'''
        # The context manager closes the archive; arrays are read while it is open.
        with np.load(path) as data:
            # Prefer the denoised SAR, then the raw SAR, else the first stored
            # array. Only the selected array is read from disk.
            sar_key = next((k for k in ('sar_denoised', 'sar_raw') if k in data.files),
                           data.files[0])
            sar = data[sar_key]
            rgb = data['rgb']
        # A 3-D SAR array is reduced to its first slice
        # (assumes channel-first layout — TODO confirm).
        if sar.ndim == 3: sar = sar[0]

        sar = self._to_uint8(sar)
        rgb = self._to_uint8(rgb)

        # PIL image -> tensor in [0, 1] -> resize to (size, size)
        sar_t = self.resize(transforms.ToTensor()(Image.fromarray(sar)))
        rgb_t = self.resize(transforms.ToTensor()(Image.fromarray(rgb)))

        # Convert the resized RGB (H, W, C) to Lab
        lab = color.rgb2lab(rgb_t.permute(1,2,0).numpy())

        # L: 0..100 -> -1..1, ab: ~-128..128 -> ~-1..1
        L = torch.tensor(lab[:,:,0:1] / 50.0 - 1.0).permute(2,0,1).float()
        ab = torch.tensor(lab[:,:,1:] / 128.0).permute(2,0,1).float()

        # SAR: 0..1 -> -1..1
        sar_norm = sar_t * 2 - 1

        return sar_norm.float(), ab, L, rgb_t

    def __getitem__(self, idx):
        import logging

        n = len(self.files)
        # Raise IndexError for out-of-range indices so that plain iteration
        # over the dataset terminates.
        if not -n <= idx < n:
            raise IndexError(f'index {idx} out of range for dataset of size {n}')

        # Best-effort loading: a sample that fails is replaced by the next file,
        # but every file is tried at most once so a fully broken dataset
        # raises a clear error instead of recursing without bound.
        last_err = None
        for offset in range(n):
            path = self.files[(idx + offset) % n]
            try:
                return self._load_sample(path)
            except Exception as err:
                logging.getLogger(__name__).warning('Skipping sample %s: %r', path, err)
                last_err = err
        raise RuntimeError(f'No loadable sample among {n} files') from last_err

# Build the dataset and split it 90% / 10% into training and validation loaders.
BATCH = 32
ds = LabDataset(LOCAL_DIR)
n_train = int(0.9 * len(ds))
n_val = len(ds) - n_train
train_ds, val_ds = random_split(ds, [n_train, n_val])
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2)
print(f'Train: {len(train_ds)}, Val: {len(val_ds)}')

In [None]:
# CBAM Module (Our Novel Contribution)
class CBAM(nn.Module):
    '''CBAM-style attention: a channel gate followed by a spatial gate.

    The channel gate is a bottleneck MLP applied to the globally average-pooled
    features; the spatial gate is a single convolution over the per-pixel
    channel mean and channel max. Both gates end in a sigmoid and multiply
    the input, so the output has the same shape as the input.

    Args:
        c: number of input channels.
        r: reduction ratio of the channel-gate bottleneck.
        k: kernel size of the spatial-gate convolution. Padding is ``k // 2``,
           which keeps the spatial size only for odd ``k``.
    '''
    def __init__(self, c, r=16, k=7):
        super().__init__()
        # c // r is 0 when c < r; a zero-width bottleneck would leave the gate
        # driven by the output bias alone, so keep at least one hidden unit.
        hidden = max(1, c // r)
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(c, hidden), nn.ReLU(),
            nn.Linear(hidden, c), nn.Sigmoid()
        )
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, k, padding=k//2), nn.Sigmoid()
        )

    def forward(self, x):
        # Channel attention: (N, C) gate broadcast over H and W
        ca = self.ca(x).unsqueeze(-1).unsqueeze(-1)
        x = x * ca
        # Spatial attention: (N, 1, H, W) gate from channel-wise mean and max
        pooled = torch.cat([x.mean(1, keepdim=True), x.max(1, keepdim=True)[0]], 1)
        return x * self.sa(pooled)

class ColorizerNet(nn.Module):
    '''U-Net-style colorizer: ResNet34 encoder with CBAM at every stage.

    Takes a single-channel image and returns two channels (ab) squashed by
    tanh, at the input resolution. The stem output is not used as a skip
    connection; encoder stages 1-3 are concatenated into the decoder.
    '''
    def __init__(self):
        super().__init__()
        backbone = models.resnet34(weights='IMAGENET1K_V1')

        # Stem: a fresh 1-channel conv stands in for the backbone's first conv;
        # the remaining stem layers and the four stages are reused.
        self.adapt = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.pool = backbone.maxpool
        self.e1 = backbone.layer1
        self.e2 = backbone.layer2
        self.e3 = backbone.layer3
        self.e4 = backbone.layer4

        # One CBAM per encoder stage, sized to that stage's channel count
        self.cbam1 = CBAM(64)
        self.cbam2 = CBAM(128)
        self.cbam3 = CBAM(256)
        self.cbam4 = CBAM(512)

        # Decoder: input widths of up3..up1 include the concatenated skip
        self.up4 = self._stage(512, 256)
        self.up3 = self._stage(512, 128)
        self.up2 = self._stage(256, 64)
        self.up1 = self._stage(128, 64)
        self.out = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(32, 2, 1), nn.Tanh()  # ab channels
        )

    @staticmethod
    def _stage(in_ch, out_ch):
        '''3x3 conv + ReLU followed by 2x upsampling.'''
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
        )

    def forward(self, x):
        # Encoder (channels, resolution relative to the input H)
        stem = self.relu(self.bn1(self.adapt(x)))     # 64, H/2
        s1 = self.cbam1(self.e1(self.pool(stem)))     # 64, H/4
        s2 = self.cbam2(self.e2(s1))                  # 128, H/8
        s3 = self.cbam3(self.e3(s2))                  # 256, H/16
        s4 = self.cbam4(self.e4(s3))                  # 512, H/32

        # Decoder: upsample, then fuse with the matching encoder feature
        feat = self.up4(s4)                           # 256, H/16
        for stage, skip in ((self.up3, s3), (self.up2, s2), (self.up1, s1)):
            feat = stage(torch.cat([feat, skip], 1))  # ends at 64, H/2
        return self.out(feat)                         # 2, H

# Instantiate the network on the selected device, with its optimizer and loss.
model = ColorizerNet()
model = model.to(device)
opt = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
print('Model ready: ResNet34 + CBAM -> ab prediction')

In [None]:
def lab_to_rgb(L, ab):
    '''Reconstruct an RGB image from a lightness map and ab chroma.

    Args:
        L: tensor holding one H x W lightness map scaled to -1..1; leading
           singleton dimensions (e.g. 1 x H x W) are accepted.
        ab: tensor holding the matching 2 x H x W chroma scaled to -1..1;
            a leading singleton batch dimension is accepted.

    Returns:
        H x W x 3 array of RGB values clipped to [0, 1].
    '''
    # detach() so tensors that still require grad can be converted to NumPy.
    L_t = L.detach().cpu()
    ab_t = ab.detach().cpu()

    # Take H and W from the last two dims and reshape explicitly; a bare
    # squeeze() would also drop a height or width of 1.
    h, w = L_t.shape[-2:]

    # L: -1..1 -> 0..100, ab: -1..1 -> -128..128
    L_np = ((L_t.reshape(h, w).numpy() + 1) * 50).clip(0, 100)
    ab_np = (ab_t.reshape(2, h, w).permute(1,2,0).numpy() * 128).clip(-128, 128)

    lab = np.zeros((h, w, 3))
    lab[:,:,0] = L_np
    lab[:,:,1:] = ab_np
    return color.lab2rgb(lab).clip(0, 1)

def evaluate(epoch):
    '''Score and visualise the first validation batch for one epoch.

    Up to four samples of the first batch from ``val_dl`` are reconstructed
    from the ground-truth L channel plus the predicted ab channels (ColorSAR
    protocol). Their mean PSNR and SSIM are printed, and a SAR / ground truth /
    prediction figure is saved to ``VIS_DIR`` as ``epoch_{epoch}.png``.

    Args:
        epoch: epoch index, used in the printed line and in the file name.
    '''
    # Remember the current mode and restore it even if inference fails.
    was_training = model.training
    model.eval()
    try:
        sar, _, L_gt, rgb_gt = next(iter(val_dl))
        sar = sar.to(device)
        with torch.no_grad():
            ab_pred = model(sar)
    finally:
        model.train(was_training)

    # Metrics using GT's L channel (ColorSAR protocol)
    psnrs, ssims = [], []
    fig, ax = plt.subplots(4, 3, figsize=(12, 16))

    for i in range(min(4, sar.shape[0])):
        # Reconstruct using GT's L + predicted ab
        pred_rgb = lab_to_rgb(L_gt[i], ab_pred[i:i+1])
        true_rgb = rgb_gt[i].permute(1,2,0).numpy()

        p = psnr(true_rgb, pred_rgb, data_range=1.0)
        s = ssim(true_rgb, pred_rgb, data_range=1.0, channel_axis=-1, win_size=3)
        psnrs.append(p); ssims.append(s)

        ax[i,0].imshow((sar[i,0].cpu()+1)/2, cmap='gray'); ax[i,0].set_title('SAR')
        ax[i,1].imshow(true_rgb); ax[i,1].set_title('Ground Truth')
        ax[i,2].imshow(pred_rgb); ax[i,2].set_title(f'Pred SSIM:{s:.2f}')

    # Hide the frames of every panel, including rows left empty by a short batch.
    for a in ax.ravel(): a.axis('off')

    print(f'Epoch {epoch} | PSNR: {np.mean(psnrs):.2f} | SSIM: {np.mean(ssims):.4f}')
    plt.tight_layout()
    plt.savefig(os.path.join(VIS_DIR, f'epoch_{epoch}.png'))
    plt.show()
    # Release the figure so open figures do not pile up across epochs.
    plt.close(fig)

# Training: MSE regression of the ab channels from SAR, with an evaluation
# after every epoch and periodic checkpoints on Drive.
EPOCHS = 50
for epoch in range(EPOCHS):
    # Make sure the model is in training mode at the start of every epoch.
    model.train()
    loop = tqdm(train_dl, desc=f'Epoch {epoch}')
    for sar, ab, _, _ in loop:  # L and RGB are only needed at evaluation time
        sar, ab = sar.to(device), ab.to(device)
        opt.zero_grad()
        pred = model(sar)
        loss = criterion(pred, ab)
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item())

    evaluate(epoch)
    # Checkpoint every 5 epochs and always after the last one; otherwise the
    # weights of the final epoch (EPOCHS - 1) would never be saved.
    if epoch % 5 == 0 or epoch == EPOCHS - 1:
        torch.save(model.state_dict(), os.path.join(DRIVE_BASE, f'model_v5_{epoch}.pth'))

print('Training Complete!')