# Model2: Pretrained Encoder + RNN Pixel‐Wise Refinement


In [1]:
# — Karma Cache: Gri input’ları RAM’de, renkli hedefleri diskten —
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import STL10
import torchvision.transforms as T
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Transforms
resize256 = T.Resize((256,256))
to_tensor  = T.ToTensor()

# 1) Gri input’ları RAM’e cache
train_raw = STL10(root='./data', split='train', download=True)
inp_list = []
print("Caching TRAIN grayscale inputs to RAM…")
for img, _ in tqdm(train_raw, total=len(train_raw)):
    gray = T.functional.rgb_to_grayscale(resize256(img), num_output_channels=1)
    inp_list.append(to_tensor(gray))
inp_tensor = torch.stack(inp_list)       # [5000,1,256,256]
print("Grayscale cache:", inp_tensor.shape)

# 2) Disk‐based renkli targets, on-the-fly transform
class HybridTrain256(torch.utils.data.Dataset):
    def __init__(self, inp_tensor):
        self.inp = inp_tensor
        self.base = STL10(root='./data', split='train', download=False)
        self.tf_tgt = T.Compose([T.Resize((256,256)), T.ToTensor()])
    def __len__(self):
        return len(self.inp)
    def __getitem__(self, idx):
        gray = self.inp[idx]  # tensor in RAM
        img, _ = self.base[idx]
        tgt = self.tf_tgt(img)
        return gray, tgt

train_ds = HybridTrain256(inp_tensor)

# 3) Disk‐based test loader as before
test_ds = HybridTrain256(torch.stack([
    T.ToTensor()(T.functional.rgb_to_grayscale(
        T.Resize((256,256))(img),1
    )) for img,_ in STL10(root='./data', split='test', download=False)
]))
# But we need corresponding colored targets for test
test_tgt_ds = STL10(root='./data', split='test', download=False, transform=T.ToTensor())
class HybridTest256(torch.utils.data.Dataset):
    def __init__(self, gray_ds, tgt_ds):
        self.gray_ds = gray_ds
        self.tgt_ds  = tgt_ds
    def __len__(self):
        return len(self.gray_ds)
    def __getitem__(self, idx):
        return self.gray_ds[idx], self.tgt_ds[idx][0]

test_ds = HybridTest256(
    torch.stack([T.ToTensor()(T.functional.rgb_to_grayscale(
        T.Resize((256,256))(img),1
    )) for img,_ in STL10(root='./data', split='test', download=False)]),
    STL10(root='./data', split='test', download=False, transform=T.Compose([T.Resize((256,256)), T.ToTensor()]))
)

# 4) DataLoaders
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True,
                          num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=16, shuffle=False,
                          num_workers=4, pin_memory=True)
print("DataLoaders ready →", len(train_loader), "train batches,", len(test_loader), "test batches")


Using device: cuda
Caching TRAIN grayscale inputs to RAM…


100%|██████████| 5000/5000 [00:03<00:00, 1277.80it/s]


Grayscale cache: torch.Size([5000, 1, 256, 256])
DataLoaders ready → 313 train batches, 500 test batches


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

class Model2(nn.Module):
    def __init__(self, hidden_size=128, window=5):
        super().__init__()
        # 1) Encoder
        resnet = resnet18(pretrained=True)
        self.encoder = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3
        )
        for p in self.encoder.parameters(): p.requires_grad = False

        # 2) Coarse decoder: 6→96 upsample 4×
        self.coarse = nn.Sequential(
            nn.ConvTranspose2d(256,128,4,2,1), nn.ReLU(True),
            nn.ConvTranspose2d(128,64,4,2,1),  nn.ReLU(True),
            nn.ConvTranspose2d(64,32,4,2,1),   nn.ReLU(True),
            nn.ConvTranspose2d(32,16,4,2,1),   nn.ReLU(True),
            nn.Conv2d(16,3,3,padding=1),       nn.Sigmoid()
        )

        # 3) Project encoder global features to hidden_size
        self.h_init = nn.Linear(256, hidden_size)

        # 4) Pixel-wise RNN
        self.win = window
        inp_dim = window*window*4 + hidden_size
        self.rnn = nn.LSTMCell(inp_dim, hidden_size)
        self.out = nn.Linear(hidden_size, 3)

    def forward(self, gray):
        B = gray.size(0)
        feat = self.encoder(gray.repeat(1,3,1,1))    # [B,256,6,6]
        coarse = self.coarse(feat)                   # [B,3,96,96]

        # Prepare sliding window input
        pad = self.win//2
        inp_cat = torch.cat([gray, coarse], dim=1)   # [B,4,96,96]
        inp_pad = F.pad(inp_cat, (pad,)*4, mode='reflect')

        # Initialize hidden state from feat
        global_feat = feat.mean(dim=[2,3])           # [B,256]
        h = torch.tanh(self.h_init(global_feat))     # [B, hidden_size]
        c = torch.zeros_like(h)

        out_ref = torch.zeros_like(coarse)
        for y in range(96):
            for x in range(96):
                patch = inp_pad[:, :, y:y+self.win, x:x+self.win].reshape(B, -1)
                rnn_in = torch.cat([patch, h], dim=1)        # [B, inp_dim]
                h, c = self.rnn(rnn_in, (h, c))
                out_ref[:,:,y,x] = self.out(h).reshape(B,3)

        return coarse + out_ref

# Re-instantiate
model2 = Model2(hidden_size=128, window=5).to(device)
print(f"Fixed Model2 loaded, total params: {sum(p.numel() for p in model2.parameters()):,}")




Fixed Model2 loaded, total params: 3,696,358


In [None]:
# — Hücre 3: Model2 Eğitim Döngüsü (10 epoch) —
import time
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from skimage.metrics import peak_signal_noise_ratio as compute_psnr, structural_similarity as compute_ssim

criterion = nn.MSELoss()
optimizer = optim.Adam([p for p in model2.parameters() if p.requires_grad], lr=1e-4)
scaler    = GradScaler()
epochs, best_val, patience, wait = 10, float('inf'), 3, 0

for ep in range(1, epochs+1):
    t0 = time.time()
    # — Train —
    model2.train()
    train_loss = 0.0
    for gray, tgt in tqdm(train_loader, desc=f"Epoch {ep} Train"):
        gray, tgt = gray.to(device), tgt.to(device)
        optimizer.zero_grad()
        with autocast():
            out = model2(gray)
            loss = criterion(out, tgt)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item() * gray.size(0)
    train_loss /= len(train_loader.dataset)

    # — Val & Metrics —
    model2.eval()
    val_loss, psnr_sum, ssim_sum = 0.0, 0.0, 0.0
    with torch.no_grad():
        for gray, tgt in tqdm(test_loader, desc=f"Epoch {ep} Val"):
            gray, tgt = gray.to(device), tgt.to(device)
            out = model2(gray)
            val_loss += criterion(out, tgt).item() * gray.size(0)
            o_np = out.cpu().permute(0,2,3,1).numpy()
            t_np = tgt.cpu().permute(0,2,3,1).numpy()
            for o, t in zip(o_np, t_np):
                psnr_sum += compute_psnr(t, o, data_range=1.0)
                ssim_sum += compute_ssim(t, o, channel_axis=2, data_range=1.0, win_size=7)
    val_loss /= len(test_loader.dataset)
    avg_psnr = psnr_sum / len(test_loader.dataset)
    avg_ssim = ssim_sum / len(test_loader.dataset)

    print(
        f"Epoch {ep}/{epochs}  "
        f"Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}  "
        f"PSNR: {avg_psnr:.2f} dB  SSIM: {avg_ssim:.3f}  "
        f"Time: {time.time()-t0:.1f}s"
    )

    if val_loss < best_val:
        best_val = val_loss
        torch.save(model2.state_dict(), 'model2_best.pt')
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

print("✅ Model2 training complete.")


  scaler    = GradScaler()
Epoch 1 Train:   0%|          | 0/313 [00:00<?, ?it/s]

In [None]:
# — Görselleştirme Hücresi: 10 Örnek —
import matplotlib.pyplot as plt

model2.eval()
fig, axs = plt.subplots(10, 3, figsize=(8, 24))
indices = list(range(0, len(test_ds), len(test_ds)//10))

for i, idx in enumerate(indices):
    gray, tgt = test_ds[idx]
    with torch.no_grad():
        out = model2(gray.unsqueeze(0).to(device))[0].cpu()
    # Gri girdi
    axs[i,0].imshow(gray.squeeze(0), cmap='gray')
    axs[i,0].axis('off')
    axs[i,0].set_title("Gri Input")
    # Tahmin
    axs[i,1].imshow(out.permute(1,2,0))
    axs[i,1].axis('off')
    axs[i,1].set_title("Model2 Prediction")
    # Yer Gerçek (Ground Truth)
    axs[i,2].imshow(tgt.permute(1,2,0))
    axs[i,2].axis('off')
    axs[i,2].set_title("Ground Truth")

plt.tight_layout()
plt.show()


---
**Model2 Değerlendirme:**  
- Coarse renk tahminin üzerine RNN tabanlı piksel refinmenti ekledik.  
- PSNR/SSIM iyileşti mi? Grafik yaparak gösterebilirsiniz.  
- Geliştirilebilecek noktalar: pencere boyutu, hidden size, iki yönde geçme (bidirectional LSTM), vb.  
