In [1]:
# 1) Imports & CUDA Setup
import os, glob, time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import STL10
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

# Enable cuDNN's benchmark mode so it can auto-select convolution algorithms.
cudnn.benchmark = True

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)


Using device: cuda


In [2]:
# 2) On-Disk Preprocessing to 64×64 (run once)
# Root folder that receives one sub-directory of resized PNGs per STL10 split.
preproc_root = './data/preprocessed_64'
# Transform that resizes every image to exactly 64×64.
resize64 = T.Resize((64,64))

for split in ('train', 'test'):
    # Destination folder for this split; created on first run.
    dst = os.path.join(preproc_root, split)
    os.makedirs(dst, exist_ok=True)

    # Samples are (image, label) pairs; only the image is used here.
    ds = STL10(root='./data', split=split, download=True)
    print(f"Preprocessing {split} ({len(ds)} images)…")

    progress = tqdm(ds, desc=f'{split}→64×64')
    for idx, sample in enumerate(progress):
        out_path = os.path.join(dst, f'{idx:04d}.png')
        if os.path.exists(out_path):
            # Already written by a previous run; keep the cell cheap to re-run.
            continue
        img, _ = sample
        resize64(img).save(out_path)


Preprocessing train (5000 images)…


train→64×64: 100%|██████████| 5000/5000 [00:05<00:00, 846.19it/s]


Preprocessing test (8000 images)…


test→64×64: 100%|██████████| 8000/8000 [00:09<00:00, 856.04it/s]


In [3]:
# 3) Disk-Based Dataset (64×64)
class Preproc64Dataset(Dataset):
    """Dataset of (grayscale input, RGB target) tensor pairs read from PNG files.

    Every ``*.png`` directly inside ``folder`` becomes one sample; files are
    visited in lexicographically sorted path order.
    """

    def __init__(self, folder):
        """Collect the sorted PNG paths found in ``folder``."""
        pattern = f"{folder}/*.png"
        self.paths = sorted(glob.glob(pattern))
        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Return the number of PNG files that were found."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(inp, tgt)``.

        ``tgt`` is the RGB image as a tensor; ``inp`` is the single-channel
        grayscale version of the same image as a tensor.
        """
        path = self.paths[idx]
        rgb = Image.open(path).convert('RGB')
        # Derive the one-channel network input from the colour image.
        gray = T.functional.rgb_to_grayscale(rgb, 1)
        inp = self.to_tensor(gray)
        tgt = self.to_tensor(rgb)
        return inp, tgt

# Build one dataset per split from the PNG folders under `preproc_root`.
train_ds = Preproc64Dataset(f'{preproc_root}/train')
test_ds  = Preproc64Dataset(f'{preproc_root}/test')
print("Train:", len(train_ds), "Test:", len(test_ds))

# Fail fast when a folder holds no PNGs (e.g. the preprocessing cell was not
# run). An empty dataset would otherwise only surface later as a less obvious
# error, such as dividing by a dataset length of zero when averaging losses.
assert len(train_ds) > 0 and len(test_ds) > 0, (
    f"No preprocessed PNGs found under {preproc_root!r}; "
    "run the preprocessing cell first."
)


Train: 5000 Test: 8000


In [4]:
# 4) DataLoaders
# Settings shared by both loaders; only the shuffling differs between them.
_loader_kwargs = dict(
    batch_size=16,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
# Evaluation batches keep the dataset order.
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)


In [5]:
# 5) TinyColorizer Model
class TinyColorizer(nn.Module):
    """Small convolutional encoder/decoder mapping 1-channel images to 3 channels.

    Encoder: four stages of Conv3x3 -> ReLU -> BatchNorm -> MaxPool(2) with
    16, 32, 64 and 128 output channels, halving the spatial size each stage.
    Decoder: four stride-2 transposed convolutions (64, 32, 16, 16 channels),
    each followed by ReLU and BatchNorm, doubling the spatial size each stage,
    then a Conv3x3 to 3 channels and a Sigmoid so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder stages, kept in one flat Sequential.
        in_ch = 1
        encoder_layers = []
        for out_ch in (16, 32, 64, 128):
            encoder_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            encoder_layers.append(nn.ReLU(inplace=True))
            encoder_layers.append(nn.BatchNorm2d(out_ch))
            encoder_layers.append(nn.MaxPool2d(kernel_size=2))
            in_ch = out_ch
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder stages: each transposed conv (stride 2, padding 1,
        # output_padding 1) exactly doubles height and width.
        decoder_layers = []
        for out_ch in (64, 32, 16, 16):
            decoder_layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch),
            ])
            in_ch = out_ch
        # Output head: project to 3 channels and squash into [0, 1].
        decoder_layers.append(nn.Conv2d(in_ch, 3, kernel_size=3, padding=1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode the result back to a 3-channel image."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Instantiate the network, move it to the selected device, and print its layers.
model = TinyColorizer()
model = model.to(device)
print(model)


TinyColorizer(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): BatchNorm2d(128

In [None]:
# 6) Training Loop (5 epochs demo)
# Pixel-wise MSE between the predicted and ground-truth RGB images.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Mixed precision is only enabled on CUDA. On the CPU fallback both autocast
# and the gradient scaler become no-ops, so the same loop runs unchanged.
# The device-agnostic torch.autocast / torch.amp.GradScaler API replaces the
# deprecated torch.cuda.amp entry points.
use_amp = device.type == 'cuda'
scaler  = torch.amp.GradScaler('cuda', enabled=use_amp)

# best_val tracks the lowest validation loss seen so far; `wait` counts the
# consecutive epochs without improvement and triggers early stopping once it
# reaches `patience`.
num_epochs, best_val, patience, wait = 5, float('inf'), 3, 0

for ep in range(1, num_epochs+1):
    start = time.time()

    # ---- training pass ----
    model.train()
    running = 0.0
    for inp, tgt in train_loader:
        # non_blocking lets host-to-device copies overlap with compute when
        # the loader yields pinned-memory batches.
        inp = inp.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            out = model(inp)
            loss = criterion(out, tgt)
        # Scale the loss before backward, then step/update through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Weight by batch size so the epoch figure is a per-sample mean even
        # when the last batch is smaller.
        running += loss.item() * inp.size(0)
    train_loss = running / len(train_loader.dataset)

    # ---- validation pass ----
    # NOTE(review): test_loader doubles as the validation set here, so the
    # early-stopping decision and the saved checkpoint are tuned on test data.
    model.eval()
    vrun = 0.0
    with torch.no_grad():
        for inp, tgt in test_loader:
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            with torch.autocast(device_type=device.type, enabled=use_amp):
                vrun += criterion(model(inp), tgt).item() * inp.size(0)
    val_loss = vrun / len(test_loader.dataset)

    print(f"Epoch {ep}/{num_epochs}  Train: {train_loss:.4f}  Val: {val_loss:.4f}  Time: {time.time()-start:.1f}s")

    # Checkpoint on improvement; otherwise count towards early stopping.
    # NOTE(review): the best weights are only written to disk. `model` keeps
    # the last epoch's weights unless the checkpoint is loaded back explicitly.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(),'best_colorizer64.pt')
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

print("Training complete.")


  scaler    = GradScaler()


In [None]:
# 7) Quick Visual Check on 10 Samples
import matplotlib.pyplot as plt

model.eval()

# Pick up to 10 evenly spaced test samples. The stride is clamped to at least
# 1 (datasets with fewer than 10 items) and the index list is capped at 10
# (lengths that are not a multiple of 10 would otherwise yield an 11th sample
# with no matching subplot row).
max_rows = 10
stride = max(1, len(test_ds) // max_rows)
indices = list(range(0, len(test_ds), stride))[:max_rows]
if not indices:
    raise ValueError("test_ds is empty; nothing to visualise.")
samples = [test_ds[i] for i in indices]

# squeeze=False keeps `axes` two-dimensional even when only one row is shown.
fig, axes = plt.subplots(len(samples), 3, figsize=(6, 2 * len(samples)), squeeze=False)
for i,(inp,tgt) in enumerate(samples):
    with torch.no_grad():
        # Add a batch axis for the forward pass.
        out = model(inp.unsqueeze(0).to(device))
    # Channel-first tensors -> channel-last arrays for imshow.
    out_img = out[0].cpu().permute(1,2,0).numpy()
    tgt_img = tgt.permute(1,2,0).numpy()
    # The input has a single channel: dropping it leaves a 2-D (H, W) image,
    # so no permute is needed (a 3-axis permute on a 2-D tensor raises).
    gray_img= inp.squeeze(0).numpy()
    axes[i,0].imshow(gray_img, cmap='gray'); axes[i,0].set_title('Input'); axes[i,0].axis('off')
    axes[i,1].imshow(out_img);               axes[i,1].set_title('Predicted');axes[i,1].axis('off')
    axes[i,2].imshow(tgt_img);               axes[i,2].set_title('GroundTruth');axes[i,2].axis('off')
plt.tight_layout()
plt.show()
