In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchsummary import summary
from torchvision import models

In [2]:
# Release cached-but-unused CUDA memory held by this kernel (e.g. from an earlier run).
# On a CPU-only machine there is nothing to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [3]:
def lab_to_rgb(L, ab):
    """
    Convert a normalized (L, ab) tensor pair back to an RGB image.

    Args:
        L:  torch tensor of shape (1, H, W) holding L* / 100 (roughly [0, 1]).
        ab: torch tensor of shape (2, H, W) holding (a*, b*) / 128 (roughly [-1, 1]).

    Returns:
        numpy float64 array of shape (H, W, 3), RGB values clipped to [0, 1].

    Raises:
        ValueError: if L / ab do not have the shapes above or their H, W differ
            (e.g. a model that predicts only one chroma channel).
    """
    if L.dim() != 3 or L.shape[0] != 1:
        raise ValueError(f"L must have shape (1,H,W), got {tuple(L.shape)}")
    if ab.dim() != 3 or ab.shape[0] != 2:
        raise ValueError(f"ab must have shape (2,H,W), got {tuple(ab.shape)}")
    if L.shape[1:] != ab.shape[1:]:
        raise ValueError(
            f"L and ab spatial sizes differ: {tuple(L.shape[1:])} vs {tuple(ab.shape[1:])}"
        )

    # detach() so this also works on tensors that are still part of an autograd graph.
    L_np = L[0].detach().cpu().numpy().astype(np.float64)                    # (H,W)
    ab_np = ab.detach().cpu().numpy().transpose(1, 2, 0).astype(np.float64)  # (H,W,2)

    # Undo the dataset normalization: L* in [0, 100], a*/b* in about [-128, 127].
    L_np = L_np * 100.0
    ab_np = ab_np * 128.0

    lab = np.concatenate([L_np[..., np.newaxis], ab_np], axis=2)  # (H,W,3)
    return np.clip(lab2rgb(lab), 0, 1)

In [4]:
class RGB2LabDataset(Dataset):
    """
    Image-folder dataset yielding (L, ab) pairs in normalized CIE-Lab space.

    Every regular file directly inside ``image_dir`` whose name ends with one of
    ``extensions`` (case-insensitive) is loaded as RGB, resized to
    ``image_size`` x ``image_size`` and converted with ``skimage.color.rgb2lab``.

    Each item is a tuple ``(L, ab)`` of float32 tensors:
        L:  (1, H, W), L* / 100     -> roughly [0, 1]
        ab: (2, H, W), (a*, b*) / 128 -> roughly [-1, 1]

    Raises:
        RuntimeError: if no matching image files are found.
    """

    def __init__(self, image_dir, image_size=256, extensions=('.jpg','.jpeg','.png','.bmp','.webp')):
        # str.endswith only accepts a str or a tuple, so accept any iterable of suffixes.
        suffixes = tuple(ext.lower() for ext in extensions)
        # sorted(): os.listdir order is arbitrary, so sorting makes the index -> file
        # mapping (and therefore any seeded random_split) reproducible.
        # isfile(): skip directories that merely have an image-like name.
        self.image_paths = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(image_dir, name))
        )
        if len(self.image_paths) == 0:
            raise RuntimeError(f"No images found in {image_dir}")
        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),  # -> float tensor (3,H,W) in [0, 1]
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle; convert() returns a new image
        # that no longer depends on it.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("RGB")
        x = self.transform(img)
        x_np = x.permute(1, 2, 0).numpy().astype(np.float32)  # (H,W,3)
        lab = rgb2lab(x_np).astype("float32")

        L = lab[..., 0] / 100.0
        ab = lab[..., 1:] / 128.0

        L = torch.from_numpy(L).unsqueeze(0)       # (1,H,W)
        ab = torch.from_numpy(ab).permute(2, 0, 1)  # (2,H,W)
        return L, ab

In [17]:
class ResNet_UNet_AE(nn.Module):
    """
    U-Net colorization network with a frozen ResNet34 encoder fused at the bottleneck.

    The input goes through two encoders in parallel:
      * a trainable 5-stage U-Net encoder (enc1..enc5 plus one more max-pool), and
      * a frozen ResNet34 feature extractor (every layer except avgpool and fc).
    Both yield 512 channels at 1/32 resolution. They are concatenated, mixed by a
    1x1 conv, then decoded back to full resolution with additive skip connections.

    Args:
        out_channels: number of predicted channels. Defaults to 2, matching the
            normalized (a, b) targets produced by RGB2LabDataset.
        pretrained: if True, load ``ResNet34_Weights.DEFAULT`` for the backbone.
            If False, the backbone is randomly initialised (still frozen), which
            avoids the weight download.

    Input:  (B, 1, H, W) tensor, expected to be L / 100 as produced by
            RGB2LabDataset. H and W must be divisible by 32 so the two bottlenecks
            and every skip connection have matching sizes.
    Output: (B, out_channels, H, W) tensor with no final activation.
    """

    # ImageNet channel statistics expected by torchvision's pretrained weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, out_channels=2, pretrained=True):
        super().__init__()

        # -----------------------------------
        # RESNET34 BACKBONE (FROZEN)
        # -----------------------------------
        weights = models.ResNet34_Weights.DEFAULT if pretrained else None
        resnet = models.resnet34(weights=weights)
        # Drop avgpool + fc  → (B,512,H/32,W/32)
        self.resnet_conv = nn.Sequential(*list(resnet.children())[:-2])

        for p in self.resnet_conv.parameters():
            p.requires_grad = False

        # Normalization constants for the backbone input. Non-persistent buffers move
        # with .to(device) but are not stored in state_dict, so older checkpoints of
        # this class still load.
        self.register_buffer(
            "imagenet_mean",
            torch.tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "imagenet_std",
            torch.tensor(self._IMAGENET_STD).view(1, 3, 1, 1),
            persistent=False,
        )

        # -----------------------------------
        # UNET ENCODER
        # -----------------------------------
        self.enc1 = self.block(1, 32)      # H
        self.enc2 = self.block(32, 64)     # H/2
        self.enc3 = self.block(64, 128)    # H/4
        self.enc4 = self.block(128, 256)   # H/8
        self.enc5 = self.block(256, 512)   # H/16
        self.pool = nn.MaxPool2d(2)

        # -----------------------------------
        # BOTTLENECK FUSION
        # (ResNet bottleneck + UNet bottleneck)
        # -----------------------------------
        self.fuse = nn.Conv2d(512 + 512, 512, kernel_size=1)

        # -----------------------------------
        # UNET DECODER (mirrors encoder)
        # -----------------------------------
        self.up5 = self.up_block(512, 512)     # H/16, matches e5
        self.up4 = self.up_block(512, 256)     # H/8,  matches e4
        self.up3 = self.up_block(256, 128)     # H/4,  matches e3
        self.up2 = self.up_block(128, 64)      # H/2,  matches e2
        self.up1 = self.up_block(64, 32)       # H,    matches e1

        # FINAL OUTPUT LAYER: one channel per predicted chroma component.
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def train(self, mode=True):
        """
        Set train/eval mode, but always keep the frozen backbone in eval mode.

        ``requires_grad=False`` only stops weight updates. BatchNorm layers in train
        mode would still overwrite their running mean/var with statistics of the
        training batches, silently changing the "frozen" backbone.
        """
        super().train(mode)
        self.resnet_conv.eval()
        return self

    # -----------------------------
    # CONV BLOCK
    # -----------------------------
    def block(self, in_ch, out_ch):
        """Two 3x3 conv + BatchNorm + ReLU layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    # -----------------------------
    # UPSAMPLING BLOCK
    # -----------------------------
    def up_block(self, in_ch, out_ch):
        """2x transposed-conv upsampling followed by two 3x3 conv + ReLU layers."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    # -----------------------------
    # FORWARD PASS
    # -----------------------------
    def forward(self, x):
        """Map (B,1,H,W) input to (B,out_channels,H,W); H, W divisible by 32."""

        # -----------------------------
        # UNET ENCODER
        # -----------------------------
        e1 = self.enc1(x)               # H
        e2 = self.enc2(self.pool(e1))   # H/2
        e3 = self.enc3(self.pool(e2))   # H/4
        e4 = self.enc4(self.pool(e3))   # H/8
        e5 = self.enc5(self.pool(e4))   # H/16
        e6 = self.pool(e5)              # H/32  (UNet bottleneck)

        # -----------------------------
        # RESNET ENCODER (frozen)
        # -----------------------------
        x_res = x.repeat(1, 3, 1, 1)    # 1-channel → 3-channel grayscale "RGB"
        # The pretrained weights were trained on ImageNet-normalized inputs.
        x_res = (x_res - self.imagenet_mean) / self.imagenet_std
        res = self.resnet_conv(x_res)   # (B,512,H/32,W/32)

        # -----------------------------
        # BOTTLENECK FUSION
        # -----------------------------
        fused = torch.cat([e6, res], dim=1)  # (B,1024,H/32,W/32)
        bottleneck = self.fuse(fused)        # (B,512,H/32,W/32)

        # -----------------------------
        # DECODER WITH ADDITIVE SKIPS
        # -----------------------------
        d5 = self.up5(bottleneck) + e5   # H/16
        d4 = self.up4(d5) + e4           # H/8
        d3 = self.up3(d4) + e3           # H/4
        d2 = self.up2(d3) + e2           # H/2
        d1 = self.up1(d2) + e1           # H

        # final 1x1 conv
        return self.out_conv(d1)         # (B,out_channels,H,W)


In [18]:
def train(model, loader, criterion, optimizer, device):
    """
    Run one training epoch and return the average per-sample loss.

    Args:
        model: network called on each L batch; its output is compared with ab.
        loader: iterable of ``(L, ab)`` tensor batches (e.g. a DataLoader).
        criterion: loss function; assumed to use mean reduction over the batch.
        optimizer: optimizer stepping the model's parameters.
        device: device (or device string) each batch is moved to.

    Returns:
        Batch losses weighted by batch size, so a short final batch is not
        over-weighted. Returns ``float("nan")`` if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for L, ab in loader:
        L = L.to(device)
        ab = ab.to(device)
        optimizer.zero_grad()
        outputs = model(L)
        loss = criterion(outputs, ab)
        loss.backward()
        optimizer.step()

        batch_size = L.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    # An empty loader (e.g. an empty split) reports NaN instead of dividing by zero.
    return total_loss / n_samples if n_samples else float("nan")

def evaluate(model, loader, criterion, device):
    """
    Compute the average per-sample loss over ``loader`` without updating the model.

    Args:
        model: network called on each L batch; its output is compared with ab.
        loader: iterable of ``(L, ab)`` tensor batches (e.g. a DataLoader).
        criterion: loss function; assumed to use mean reduction over the batch.
        device: device (or device string) each batch is moved to.

    Returns:
        Batch losses weighted by batch size, so a short final batch is not
        over-weighted. Returns ``float("nan")`` if ``loader`` yields no batches.
        The model is left in eval mode.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for L, ab in loader:
            L = L.to(device)
            ab = ab.to(device)
            outputs = model(L)
            loss = criterion(outputs, ab)

            batch_size = L.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    # An empty loader (e.g. an empty split) reports NaN instead of dividing by zero.
    return total_loss / n_samples if n_samples else float("nan")

In [19]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [20]:
# Reproducible train/eval split of every image in the working directory.
SPLIT_SEED = 42        # fixes which images land in train vs. eval across runs
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32

dataset = RGB2LabDataset(".", image_size=256)

train_size = int(TRAIN_FRACTION * len(dataset))
eval_size = len(dataset) - train_size
if train_size == 0 or eval_size == 0:
    # e.g. a single image gives train_size == 0 and an empty training loader
    raise ValueError(
        f"Train/eval split of {len(dataset)} images with "
        f"TRAIN_FRACTION={TRAIN_FRACTION} leaves one side empty"
    )

# A dedicated seeded generator keeps the split stable without touching the global RNG.
train_dataset, eval_dataset = random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [21]:
# Only the import is needed here. The frozen ResNet34 backbone is built inside
# ResNet_UNet_AE.__init__, so a separate module-level copy would never be used
# and would just load the pretrained weights a second time.
from torchvision import models

In [22]:
# Model, loss and optimizer. Adam receives every parameter; the frozen backbone
# weights never receive a .grad, so Adam skips them during step().
model = ResNet_UNet_AE()
model = model.to(device)
criterion = nn.SmoothL1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

INPUT_SHAPE = (1, 256, 256)  # (channels, height, width) of one L image
summary(model, input_size=INPUT_SHAPE, device=str(device))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             320
       BatchNorm2d-2         [-1, 32, 256, 256]              64
              ReLU-3         [-1, 32, 256, 256]               0
            Conv2d-4         [-1, 32, 256, 256]           9,248
       BatchNorm2d-5         [-1, 32, 256, 256]              64
              ReLU-6         [-1, 32, 256, 256]               0
         MaxPool2d-7         [-1, 32, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          18,496
       BatchNorm2d-9         [-1, 64, 128, 128]             128
             ReLU-10         [-1, 64, 128, 128]               0
           Conv2d-11         [-1, 64, 128, 128]          36,928
      BatchNorm2d-12         [-1, 64, 128, 128]             128
             ReLU-13         [-1, 64, 128, 128]               0
        MaxPool2d-14           [-1, 64,

In [1]:
num_epochs = 10
# One pass over the training split, then a pass over the eval split, per epoch.
for epoch_no in range(1, num_epochs + 1):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    eval_loss = evaluate(model, eval_loader, criterion, device)
    print(f"Epoch {epoch_no}/{num_epochs}, Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}")

NameError: name 'train' is not defined

In [None]:
import matplotlib.pyplot as plt

def visualize_batches(model, loader, device, num_batches=3, num_show_per_batch=5):
    """
    Plot the L input, ground-truth colorization and predicted colorization side by side.

    Takes up to ``num_show_per_batch`` samples from each of the first ``num_batches``
    batches of ``loader`` and draws one row of three panels per sample. The figure
    has exactly as many rows as samples collected, so short batches or a short loader
    leave no blank rows.

    Args:
        model: network mapping an L batch to an ab batch (2 channels).
        loader: iterable of ``(L, ab)`` tensor batches.
        device: device the batches are moved to.
        num_batches: maximum number of batches to draw from.
        num_show_per_batch: maximum number of samples shown per batch.

    The model is left in eval mode.
    """
    model.eval()

    # Phase 1: run the model and convert each sample to displayable numpy images.
    samples = []  # (batch index, L image (H,W), RGB ground truth, RGB prediction)
    with torch.no_grad():
        for b_idx, (L_batch, ab_batch) in enumerate(loader):
            if b_idx >= num_batches:
                break

            L_batch, ab_batch = L_batch.to(device), ab_batch.to(device)
            pred_ab_batch = model(L_batch)

            n = min(num_show_per_batch, L_batch.size(0))
            for i in range(n):
                L = L_batch[i]
                samples.append((
                    b_idx,
                    L[0].cpu().numpy(),
                    lab_to_rgb(L, ab_batch[i]),
                    lab_to_rgb(L, pred_ab_batch[i]),
                ))

    if not samples:
        print("visualize_batches: loader produced no samples to show")
        return

    # Phase 2: draw an exact-size grid, one row per collected sample.
    n_rows = len(samples)
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, n_rows * 3), squeeze=False)
    for (b_idx, gray, rgb_gt, rgb_pred), (ax_l, ax_gt, ax_pred) in zip(samples, axes):
        # Fixed [0, 1] range: L is stored as L*/100, so every row shares one scale
        # instead of being contrast-stretched per image.
        ax_l.imshow(gray, cmap='gray', vmin=0.0, vmax=1.0)
        ax_l.set_title(f"Batch {b_idx} – L")

        ax_gt.imshow(rgb_gt)
        ax_gt.set_title("Ground Truth")

        ax_pred.imshow(rgb_pred)
        ax_pred.set_title("Prediction")

        for ax in (ax_l, ax_gt, ax_pred):
            ax.axis('off')

    fig.tight_layout(pad=1.5)
    plt.show()

# ==== Example call: 5 eval batches, up to 10 samples from each ====
visualize_batches(
    model,
    eval_loader,
    device,
    num_batches=5,
    num_show_per_batch=10,
)