# Real Experiment: STL-10 Colorization at 128×128

In [None]:
# 1) Imports & CUDA setup
import os, glob, time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import STL10
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from skimage.metrics import peak_signal_noise_ratio as compute_psnr, structural_similarity as compute_ssim

# Let cuDNN benchmark its convolution algorithms; worthwhile here because the
# input size is fixed for the whole run.
cudnn.benchmark = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print("Using device:", device)


In [None]:
# 2) Preprocess STL-10 → on-disk 128×128 PNG (run only once)
# Each split is written to <pre32>/<split>/<index>.png. Re-running the cell is
# cheap: images that already exist on disk are skipped *before* being decoded.
# NOTE: the name `pre32` is historical; the folder holds 128×128 images.
pre32 = './data/preprocessed_128'
resize = T.Resize((128,128))

for split in ['train','test']:
    folder = os.path.join(pre32, split)
    os.makedirs(folder, exist_ok=True)
    ds = STL10(root='./data', split=split, download=True)
    print(f"→ {split}: {len(ds)} images")
    for i in tqdm(range(len(ds)), desc=f"{split}→128"):
        out = os.path.join(folder, f"{i:04d}.png")
        if os.path.exists(out):
            continue  # already preprocessed; avoid decoding the image at all
        img, _ = ds[i]
        # Write to a temp name and rename atomically so an interrupted run can
        # never leave a truncated PNG that later runs would skip as "done".
        # The ".tmp" suffix keeps partial files out of any "*.png" glob; the
        # explicit format is needed because Pillow cannot infer it from ".tmp".
        tmp = out + '.tmp'
        resize(img).save(tmp, format='PNG')
        os.replace(tmp, out)


In [None]:
# 3) Disk-based Dataset class
class Preproc128(Dataset):
    """Dataset of (grayscale input, RGB target) tensor pairs read from PNG files.

    Args:
        folder: Directory whose ``*.png`` files (non-recursive) form the
            dataset. Samples are ordered by sorted file path.
    """

    def __init__(self, folder: str):
        self.files = sorted(glob.glob(f"{folder}/*.png"))
        self.to_tensor = T.ToTensor()

    def __len__(self) -> int:
        """Return the number of PNG files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx: int):
        """Load one image and return ``(inp, tgt)``.

        ``tgt`` is the image converted to RGB and passed through ``ToTensor``;
        ``inp`` is the single-channel grayscale version of the same image,
        also passed through ``ToTensor``.
        """
        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed once the pixel data has been copied by convert().
        with Image.open(self.files[idx]) as fh:
            img = fh.convert('RGB')
        tgt = self.to_tensor(img)
        gray = T.functional.rgb_to_grayscale(img, 1)
        inp = self.to_tensor(gray)
        return inp, tgt

# One dataset per preprocessed split, built from the folders written above.
train_ds, test_ds = (
    Preproc128(f'{pre32}/{split_name}') for split_name in ('train', 'test')
)
print("Train:", len(train_ds), "Test:", len(test_ds))


In [None]:
# 4) DataLoaders
# Both loaders share batch size, worker count and pinned memory; only the
# training loader shuffles.
_loader_opts = dict(batch_size=32, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

_n_train_batches = len(train_loader)
_n_test_batches = len(test_loader)
print("Loaders ready →", _n_train_batches, "batches train,", _n_test_batches, "batches test")


In [None]:
# 5) LiteColorizer model (128×128)
class LiteColorizer(nn.Module):
    """Small convolutional encoder/decoder mapping 1-channel input to 3 channels.

    The encoder stacks four Conv → ReLU → BatchNorm → MaxPool(2) stages
    (1→32→64→128→256 channels), halving the resolution each time. The decoder
    stacks four stride-2 transposed-convolution stages (256→128→64→32→32),
    each doubling the resolution, followed by a 3×3 convolution to 3 channels
    and a Sigmoid. For inputs whose height and width are multiples of 16 the
    output has the same spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        # Encoder: each stage appends conv, activation, norm, 2× downsample.
        down_layers = []
        in_ch = 1
        for out_ch in (32, 64, 128, 256):
            down_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            down_layers.append(nn.ReLU(inplace=True))
            down_layers.append(nn.BatchNorm2d(out_ch))
            down_layers.append(nn.MaxPool2d(kernel_size=2))
            in_ch = out_ch
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: four 2× upsampling stages, continuing from the encoder's
        # final channel count, then the RGB head.
        up_layers = []
        for out_ch in (128, 64, 32, 32):
            up_layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
            )
            up_layers.append(nn.ReLU(inplace=True))
            up_layers.append(nn.BatchNorm2d(out_ch))
            in_ch = out_ch
        up_layers.append(nn.Conv2d(in_ch, 3, kernel_size=3, padding=1))
        up_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` and decode the result into a 3-channel prediction."""
        features = self.encoder(x)
        return self.decoder(features)

# Build the network, then move its parameters and buffers to the chosen device.
model = LiteColorizer()
model = model.to(device)
print(model)


In [None]:
# Debug: test how long it takes to pull one batch
# The measurement includes creating the loader iterator, not just reading
# the batch itself.
import time

t0 = time.time()
_loader_iter = iter(train_loader)
batch = next(_loader_iter)
_elapsed = time.time() - t0
print("Loaded batch in", _elapsed, "seconds:", batch[0].shape, batch[1].shape)


In [None]:
# 6) Training loop (20 epochs)
# Trains with MSE loss and Adam, using mixed precision when running on CUDA.
# After every epoch the model is evaluated on `test_loader` (MSE, PSNR, SSIM);
# the weights with the lowest evaluation MSE are saved to 'lite128_best.pt',
# and training stops early after `patience` epochs without improvement.
# NOTE(review): the test split doubles as the validation set here, so the
# reported metrics are also the model-selection criterion.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Only enable AMP on CUDA; on CPU autocast/GradScaler would just emit warnings
# and disable themselves.
use_amp   = device.type == 'cuda'
scaler    = GradScaler(enabled=use_amp)
epochs, best_val, patience, wait = 20, float('inf'), 5, 0

for ep in range(1, epochs+1):
    t0 = time.time()
    # — train —
    model.train()
    train_loss = 0.0
    for inp,tgt in tqdm(train_loader, desc=f"Epoch {ep} Train"):
        # The loaders pin memory, so the host→device copies can be asynchronous.
        inp = inp.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            out = model(inp)
            loss = criterion(out, tgt)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Weight by batch size so the epoch average is per-sample even when
        # the last batch is smaller.
        train_loss += loss.item()*inp.size(0)
    train_loss /= len(train_loader.dataset)
    # — val —
    model.eval()
    val_loss, psnr_sum, ssim_sum = 0.0, 0.0, 0.0
    with torch.no_grad():
        for inp,tgt in tqdm(test_loader, desc=f"Epoch {ep} Val"):
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            with autocast(enabled=use_amp):
                out = model(inp)
                val_loss += criterion(out, tgt).item()*inp.size(0)
            # gather for metrics (on CPU); cast to float32 first because the
            # prediction may be half precision under autocast.
            out_np = out.float().cpu().permute(0,2,3,1).numpy()
            tgt_np = tgt.float().cpu().permute(0,2,3,1).numpy()
            for pred_img, true_img in zip(out_np, tgt_np):
                psnr_sum += compute_psnr(true_img, pred_img, data_range=1.0)
                # `channel_axis` replaces the deprecated `multichannel=True`;
                # images are H×W×C after the permute above.
                ssim_sum += compute_ssim(true_img, pred_img, channel_axis=-1, data_range=1.0)
    val_loss /= len(test_loader.dataset)
    avg_psnr = psnr_sum / len(test_loader.dataset)
    avg_ssim = ssim_sum / len(test_loader.dataset)
    print(f"Epoch {ep}/{epochs}  Train:{train_loss:.4f}  Val:{val_loss:.4f}  "
          f"PSNR:{avg_psnr:.2f}dB  SSIM:{avg_ssim:.3f}  Time:{time.time()-t0:.1f}s")
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), 'lite128_best.pt')
        wait = 0
    else:
        wait += 1
        if wait>=patience:
            print("Early stop triggered.")
            break

print("✅ Experiment complete.")


In [None]:
# 7) Visualize 10 Diverse Test Results
# Shows up to 10 evenly spaced test samples, one per row:
# grayscale input | model prediction | ground-truth RGB.
# NOTE: `model` holds the weights from the last epoch run, which may differ
# from the best checkpoint saved in 'lite128_best.pt'.
import matplotlib.pyplot as plt
model.eval()

n_total = len(test_ds)
if n_total == 0:
    raise ValueError("test_ds is empty; nothing to visualize")
# Never ask for more rows than there are samples, and cap the index list so a
# dataset size that is not a multiple of the row count cannot yield an extra
# index that would overflow the axes grid.
n_rows = min(10, n_total)
step = n_total // n_rows  # >= 1 because n_rows <= n_total
indices = list(range(0, n_total, step))[:n_rows]

# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_rows, 3, figsize=(6, 2.4*n_rows), squeeze=False)
for i, idx in enumerate(indices):
    inp,tgt = test_ds[idx]
    with torch.no_grad():
        out = model(inp.unsqueeze(0).to(device))[0].cpu()
    # Tensors are channel-first; imshow expects H×W (gray) or H×W×C (color).
    axs[i,0].imshow(inp.squeeze(0), cmap='gray'); axs[i,0].axis('off')
    axs[i,1].imshow(out.permute(1,2,0)); axs[i,1].axis('off')
    axs[i,2].imshow(tgt.permute(1,2,0)); axs[i,2].axis('off')
    axs[i,0].set_title("Input")
    axs[i,1].set_title("Predicted")
    axs[i,2].set_title("GT")
plt.tight_layout()


---
**Summary:**  
- Labeled STL-10 splits (5 000 train, 8 000 test; the test split doubles as the validation set) resized to **128×128**; training takes ~10–15 min.  
- Best model saved as `lite128_best.pt`.  
- Reported MSE, PSNR, SSIM each epoch and visualized 10 samples.  
