In [18]:
import cv2
import math
import random
import numpy as np
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Using device:", DEVICE)

# =====================
# PATHS
# =====================
DATA_ROOT = Path("../data/datasets")

RAILSEM_DIR = Path("../data/datasets/RailSem19/jpgs/rs19_val")

BDD_DIR     = DATA_ROOT / "BDD100K/night_frames"

OUT_DIR = Path("../outputs/models/deblur")
OUT_DIR.mkdir(parents=True, exist_ok=True)

LOG_DIR = Path("../outputs/logs")
LOG_DIR.mkdir(parents=True, exist_ok=True)

# =====================
# TRAIN PARAMS
# =====================
IMG_SIZE = 256
BATCH_SIZE = 4
LR = 1e-4

EPOCHS_STAGE1 = 60   # RailSem
EPOCHS_STAGE2 = 40   # BDD Night


[INFO] Using device: cuda


In [19]:
def motion_blur_kernel(size, angle):
    kernel = np.zeros((size, size))
    kernel[size // 2, :] = np.ones(size)
    M = cv2.getRotationMatrix2D((size//2, size//2), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (size, size))
    kernel /= kernel.sum()
    return kernel

def apply_motion_blur(img):
    k = random.choice([7, 9, 11, 15])
    angle = random.uniform(-15, 15)
    kernel = motion_blur_kernel(k, angle)
    return cv2.filter2D(img, -1, kernel)


In [20]:
class DeblurDataset(Dataset):
    def __init__(self, img_dir):
        self.images = list(img_dir.glob("*.jpg"))

        if len(self.images) == 0:
            raise RuntimeError(f"No images found in {img_dir}")

        print(f"[INFO] Loaded {len(self.images)} images from {img_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.images[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.astype(np.float32) / 255.0
        blur  = apply_motion_blur(img).astype(np.float32) / 255.0

        return (
            torch.from_numpy(blur).permute(2,0,1),
            torch.from_numpy(sharp).permute(2,0,1)
        )


In [21]:
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class DeblurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 64, 3, padding=1)
        self.body = nn.Sequential(*[ResBlock(64) for _ in range(12)])
        self.tail = nn.Conv2d(64, 3, 3, padding=1)

    def forward(self, x):
        x = torch.relu(self.head(x))
        x = self.body(x)
        return torch.clamp(self.tail(x), 0, 1)


In [22]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    psnr_vals, ssim_vals = [], []

    for blur, sharp in loader:
        blur, sharp = blur.to(DEVICE), sharp.to(DEVICE)
        out = model(blur)

        out_np = out[0].permute(1,2,0).cpu().numpy()
        gt_np  = sharp[0].permute(1,2,0).cpu().numpy()

        psnr_vals.append(psnr(gt_np, out_np, data_range=1.0))
        ssim_vals.append(
            ssim(gt_np, out_np, channel_axis=2, data_range=1.0)
        )

    return np.mean(psnr_vals), np.mean(ssim_vals)


In [23]:
def save_checkpoint(model, optimizer, epoch, best_psnr, path):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_psnr": best_psnr
    }, path)

def save_epoch(model, epoch, tag):
    torch.save(
        model.state_dict(),
        OUT_DIR / f"{tag}_epoch_{epoch:03d}.pth"
    )


In [24]:
def train_stage(
    model, train_loader, val_loader,
    optimizer, start_epoch, epochs, tag
):
    best_psnr = -1

    for epoch in range(start_epoch, start_epoch + epochs):
        model.train()
        total_loss = 0

        for blur, sharp in tqdm(train_loader, desc=f"{tag} Epoch {epoch}"):
            blur, sharp = blur.to(DEVICE), sharp.to(DEVICE)

            optimizer.zero_grad()
            out = model(blur)
            loss = nn.L1Loss()(out, sharp)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        val_psnr, val_ssim = evaluate(model, val_loader)

        print(
            f"[{tag}] Epoch {epoch} | "
            f"Loss={avg_loss:.4f} | "
            f"PSNR={val_psnr:.2f} | "
            f"SSIM={val_ssim:.4f}"
        )

        save_epoch(model, epoch, tag)

        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(
                model.state_dict(),
                OUT_DIR / f"best_{tag}.pth"
            )

        save_checkpoint(
            model, optimizer, epoch, best_psnr,
            OUT_DIR / f"{tag}_checkpoint.pth"
        )


In [None]:
rail_train = DeblurDataset(RAILSEM_DIR)

rail_loader = DataLoader(
    rail_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
rail_val   = DeblurDataset(RAILSEM_DIR)
rail_val_loader = DataLoader(
    rail_val, batch_size=1,
    shuffle=False, num_workers=2
)

model = DeblurNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)

train_stage(
    model,
    rail_loader,
    rail_val_loader,
    optimizer,
    start_epoch=0,
    epochs=EPOCHS_STAGE1,
    tag="railsem"
)


[INFO] Loaded 8500 images from ..\data\datasets\RailSem19\jpgs\rs19_val
[INFO] Loaded 8500 images from ..\data\datasets\RailSem19\jpgs\rs19_val


railsem Epoch 0:   0%|          | 0/2125 [00:00<?, ?it/s]

In [None]:
bdd_train = DeblurDataset(BDD_DIR)
bdd_val   = DeblurDataset(BDD_DIR)

bdd_loader = DataLoader(
    bdd_train, batch_size=BATCH_SIZE,
    shuffle=True, num_workers=4, pin_memory=True
)

bdd_val_loader = DataLoader(
    bdd_val, batch_size=1,
    shuffle=False, num_workers=2
)

optimizer = optim.Adam(model.parameters(), lr=LR * 0.5)

train_stage(
    model,
    bdd_loader,
    bdd_val_loader,
    optimizer,
    start_epoch=EPOCHS_STAGE1,
    epochs=EPOCHS_STAGE2,
    tag="bdd_night"
)

torch.save(
    model.state_dict(),
    OUT_DIR / "best_model_deblur.pth"
)

print("✅ Final deblur model saved")


AssertionError: No images found in ..\data\datasets\BDD100K\night_frames

In [None]:
# ============================================================
# SINGLE CELL – COMPLETE DEBLUR TRAINING WITH PSNR & SSIM
# ============================================================

import cv2, torch, random, numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# ---------------- CONFIG ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Using device:", DEVICE)

DATA_ROOT = Path("../data/datasets/RailSem19")
TRAIN_DIR = DATA_ROOT / "images_train"
VAL_DIR   = DATA_ROOT / "images_val"

OUT_DIR = Path("../outputs/models/deblur")
OUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 50

# ---------------- MOTION BLUR ----------------
def motion_blur_kernel(size, angle):
    k = np.zeros((size, size))
    k[size // 2] = np.ones(size)
    M = cv2.getRotationMatrix2D((size//2, size//2), angle, 1)
    k = cv2.warpAffine(k, M, (size, size))
    return k / k.sum()

def apply_motion_blur(img):
    ksize = random.choice([7, 9, 11, 15])
    angle = random.uniform(-15, 15)
    return cv2.filter2D(img, -1, motion_blur_kernel(ksize, angle))

# ---------------- DATASET ----------------
class DeblurDataset(Dataset):
    def __init__(self, folder):
        self.files = list(folder.glob("*.jpg")) + list(folder.glob("*.png"))
        assert len(self.files) > 0, f"No images in {folder}"

    def __len__(self): return len(self.files)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.files[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.astype(np.float32) / 255.0
        blur  = apply_motion_blur(img).astype(np.float32) / 255.0

        return (
            torch.from_numpy(blur).permute(2,0,1),
            torch.from_numpy(sharp).permute(2,0,1)
        )

# ---------------- MODEL ----------------
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.c1 = nn.Conv2d(c, c, 3, padding=1)
        self.c2 = nn.Conv2d(c, c, 3, padding=1)
        self.r  = nn.ReLU(inplace=True)
    def forward(self, x):
        return x + self.c2(self.r(self.c1(x)))

class DeblurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Conv2d(3, 64, 3, padding=1)
        self.b = nn.Sequential(*[ResBlock(64) for _ in range(12)])
        self.t = nn.Conv2d(64, 3, 3, padding=1)
    def forward(self, x):
        return torch.clamp(self.t(self.b(torch.relu(self.h(x)))), 0, 1)

# ---------------- METRICS ----------------
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    P, S = [], []
    for b, s in loader:
        b, s = b.to(DEVICE), s.to(DEVICE)
        o = model(b)[0].permute(1,2,0).cpu().numpy()
        s = s[0].permute(1,2,0).cpu().numpy()
        P.append(psnr(s, o, data_range=1.0))
        S.append(ssim(s, o, channel_axis=2, data_range=1.0))
    return float(np.mean(P)), float(np.mean(S))

# ---------------- LOADERS ----------------
train_ds = DeblurDataset(TRAIN_DIR)
val_ds   = DeblurDataset(VAL_DIR)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)

# ---------------- TRAIN ----------------
model = DeblurNet().to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.L1Loss()
best_psnr = 0.0

for epoch in range(EPOCHS):
    model.train()
    loss_sum = 0

    for blur, sharp in tqdm(train_loader, desc=f"Epoch {epoch}"):
        blur, sharp = blur.to(DEVICE), sharp.to(DEVICE)
        opt.zero_grad()
        out = model(blur)
        loss = loss_fn(out, sharp)
        loss.backward()
        opt.step()
        loss_sum += loss.item()

    loss_avg = loss_sum / len(train_loader)
    v_psnr, v_ssim = evaluate(model, val_loader)

    print(
        f"[Epoch {epoch}] "
        f"TrainLoss={loss_avg:.4f} | "
        f"PSNR={v_psnr:.2f} dB | "
        f"SSIM={v_ssim:.4f}"
    )

    torch.save(
        {"epoch": epoch, "model": model.state_dict(),
         "loss": loss_avg, "psnr": v_psnr, "ssim": v_ssim},
        OUT_DIR / f"epoch_{epoch}.pth"
    )

    if v_psnr > best_psnr:
        best_psnr = v_psnr
        torch.save({"model": model.state_dict()},
                   OUT_DIR / "best_model_deblur.pth")

print("✅ Training complete")


[INFO] Using device: cuda


AssertionError: No images in ..\data\datasets\RailSem19\images_train

In [3]:
# ============================================================
# RAILSEM19 DEBLUR TRAINING – CORRECT STRUCTURE – ONE CELL
# ============================================================

import cv2, torch, random, numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# ---------------- CONFIG ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Using device:", DEVICE)

DATA_ROOT = Path("../data/datasets/RailSem19/jpgs")
TRAIN_DIR = DATA_ROOT / "rs19"
VAL_DIR   = DATA_ROOT / "rs19_val"

OUT_DIR = Path("../outputs/models/deblur")
OUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 40

# ---------------- MOTION BLUR ----------------
def motion_blur_kernel(size, angle):
    kernel = np.zeros((size, size))
    kernel[size // 2] = np.ones(size)
    M = cv2.getRotationMatrix2D((size//2, size//2), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (size, size))
    return kernel / kernel.sum()

def apply_motion_blur(img):
    k = random.choice([9, 11, 15])
    a = random.uniform(-20, 20)
    return cv2.filter2D(img, -1, motion_blur_kernel(k, a))

# ---------------- DATASET ----------------
class RailSemDeblurDataset(Dataset):
    def __init__(self, folder):
        self.files = sorted(list(folder.glob("*.jpg")))
        assert len(self.files) > 0, f"No images found in {folder}"

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.files[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.astype(np.float32) / 255.0
        blur  = apply_motion_blur(img).astype(np.float32) / 255.0

        return (
            torch.from_numpy(blur).permute(2,0,1),
            torch.from_numpy(sharp).permute(2,0,1)
        )

# ---------------- MODEL ----------------
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.c1 = nn.Conv2d(c, c, 3, padding=1)
        self.c2 = nn.Conv2d(c, c, 3, padding=1)
        self.r  = nn.ReLU(inplace=True)
    def forward(self, x):
        return x + self.c2(self.r(self.c1(x)))

class DeblurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Conv2d(3, 64, 3, padding=1)
        self.b = nn.Sequential(*[ResBlock(64) for _ in range(12)])
        self.t = nn.Conv2d(64, 3, 3, padding=1)
    def forward(self, x):
        return torch.clamp(self.t(self.b(torch.relu(self.h(x)))), 0, 1)

# ---------------- METRICS ----------------
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    P, S = [], []
    for b, s in loader:
        b, s = b.to(DEVICE), s.to(DEVICE)
        o = model(b)[0].permute(1,2,0).cpu().numpy()
        s = s[0].permute(1,2,0).cpu().numpy()
        P.append(psnr(s, o, data_range=1.0))
        S.append(ssim(s, o, channel_axis=2, data_range=1.0))
    return float(np.mean(P)), float(np.mean(S))

# ---------------- LOADERS ----------------
train_ds = RailSemDeblurDataset(TRAIN_DIR)
val_ds   = RailSemDeblurDataset(VAL_DIR)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)

print(f"[INFO] Train images: {len(train_ds)}")
print(f"[INFO] Val images:   {len(val_ds)}")

# ---------------- TRAIN ----------------
model = DeblurNet().to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.L1Loss()
best_psnr = 0.0

for epoch in range(EPOCHS):
    model.train()
    loss_sum = 0

    for blur, sharp in tqdm(train_loader, desc=f"Epoch {epoch}"):
        blur, sharp = blur.to(DEVICE), sharp.to(DEVICE)
        opt.zero_grad()
        out = model(blur)
        loss = loss_fn(out, sharp)
        loss.backward()
        opt.step()
        loss_sum += loss.item()

    loss_avg = loss_sum / len(train_loader)
    v_psnr, v_ssim = evaluate(model, val_loader)

    print(
        f"[Epoch {epoch}] "
        f"TrainLoss={loss_avg:.4f} | "
        f"PSNR={v_psnr:.2f} dB | "
        f"SSIM={v_ssim:.4f}"
    )

    torch.save(
        {"epoch": epoch, "model": model.state_dict(),
         "loss": loss_avg, "psnr": v_psnr, "ssim": v_ssim},
        OUT_DIR / f"epoch_{epoch}.pth"
    )

    if v_psnr > best_psnr:
        best_psnr = v_psnr
        torch.save({"model": model.state_dict()},
                   OUT_DIR / "best_model_deblur.pth")

print("✅ RailSem19 Deblur training finished")


[INFO] Using device: cuda


AssertionError: No images found in ..\data\datasets\RailSem19\jpgs\rs19

In [None]:
# ============================================================
# RAILSEM19 DEBLUR TRAINING – AUTO SPLIT (YOUR STRUCTURE)
# ============================================================

import cv2, torch, random, numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# ---------------- CONFIG ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Using device:", DEVICE)

DATA_DIR = Path("../data/datasets/RailSem19/jpgs/rs19_val")
assert DATA_DIR.exists(), f"Folder missing: {DATA_DIR}"

OUT_DIR = Path("../outputs/models/deblur")
OUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256
BATCH_SIZE = 4
LR = 1e-4
EPOCHS = 40
TRAIN_SPLIT = 0.9

# ---------------- MOTION BLUR ----------------
def motion_blur_kernel(size, angle):
    k = np.zeros((size, size))
    k[size//2] = np.ones(size)
    M = cv2.getRotationMatrix2D((size//2, size//2), angle, 1)
    k = cv2.warpAffine(k, M, (size, size))
    return k / k.sum()

def apply_motion_blur(img):
    k = random.choice([9, 11, 15])
    a = random.uniform(-20, 20)
    return cv2.filter2D(img, -1, motion_blur_kernel(k, a))

# ---------------- DATASET ----------------
class RailSemDeblurDataset(Dataset):
    def __init__(self, folder):
        self.files = sorted(list(folder.glob("*.jpg")))
        assert len(self.files) > 0, f"No images found in {folder}"

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.files[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.astype(np.float32) / 255.0
        blur  = apply_motion_blur(img).astype(np.float32) / 255.0

        return (
            torch.from_numpy(blur).permute(2,0,1),
            torch.from_numpy(sharp).permute(2,0,1)
        )

# ---------------- MODEL ----------------
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.c1 = nn.Conv2d(c, c, 3, padding=1)
        self.c2 = nn.Conv2d(c, c, 3, padding=1)
        self.r  = nn.ReLU(inplace=True)
    def forward(self, x):
        return x + self.c2(self.r(self.c1(x)))

class DeblurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Conv2d(3, 64, 3, padding=1)
        self.b = nn.Sequential(*[ResBlock(64) for _ in range(12)])
        self.t = nn.Conv2d(64, 3, 3, padding=1)
    def forward(self, x):
        return torch.clamp(self.t(self.b(torch.relu(self.h(x)))), 0, 1)

# ---------------- METRICS ----------------
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    P, S = [], []
    for b, s in loader:
        b, s = b.to(DEVICE), s.to(DEVICE)
        o = model(b)[0].permute(1,2,0).cpu().numpy()
        s = s[0].permute(1,2,0).cpu().numpy()
        P.append(psnr(s, o, data_range=1.0))
        S.append(ssim(s, o, channel_axis=2, data_range=1.0))
    return float(np.mean(P)), float(np.mean(S))

# ---------------- SPLIT ----------------
full_ds = RailSemDeblurDataset(DATA_DIR)
n_train = int(len(full_ds) * TRAIN_SPLIT)
n_val   = len(full_ds) - n_train
train_ds, val_ds = random_split(full_ds, [n_train, n_val])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)

print(f"[INFO] Total images: {len(full_ds)}")
print(f"[INFO] Train images: {len(train_ds)}")
print(f"[INFO] Val images:   {len(val_ds)}")

# ---------------- TRAIN ----------------
model = DeblurNet().to(DEVICE)
opt = optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.L1Loss()
best_psnr = 0.0

for epoch in range(EPOCHS):
    model.train()
    loss_sum = 0

    for blur, sharp in tqdm(train_loader, desc=f"Epoch {epoch}"):
        blur, sharp = blur.to(DEVICE), sharp.to(DEVICE)
        opt.zero_grad()
        out = model(blur)
        loss = loss_fn(out, sharp)
        loss.backward()
        opt.step()
        loss_sum += loss.item()

    loss_avg = loss_sum / len(train_loader)
    v_psnr, v_ssim = evaluate(model, val_loader)

    print(
        f"[Epoch {epoch}] "
        f"TrainLoss={loss_avg:.4f} | "
        f"PSNR={v_psnr:.2f} dB | "
        f"SSIM={v_ssim:.4f}"
    )

    torch.save(
        {"epoch": epoch, "model": model.state_dict(),
         "loss": loss_avg, "psnr": v_psnr, "ssim": v_ssim},
        OUT_DIR / f"epoch_{epoch}.pth"
    )

    if v_psnr > best_psnr:
        best_psnr = v_psnr
        torch.save({"model": model.state_dict()},
                   OUT_DIR / "best_model_deblur.pth")

print("✅ RailSem19 Deblur training finished")


[INFO] Using device: cuda
[INFO] Total images: 8500
[INFO] Train images: 7650
[INFO] Val images:   850


Epoch 0:   0%|          | 0/1913 [00:00<?, ?it/s]

In [4]:
# ============================================================
# RAILSEM19 DEBLUR TRAINING – FIXED & MORE STABLE VERSION
# ============================================================

import cv2
import cv2
import torch
import random
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn as nn               # ← fixed
import torch.optim as optim         # ← use this instead of aliasing to nn
from torch.utils.data import Dataset, DataLoader, random_split
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# ────────────────────────────────────────────────
#  CONFIG
# ────────────────────────────────────────────────

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {DEVICE}")

DATA_DIR = Path("../data/datasets/RailSem19/jpgs/rs19_val")
assert DATA_DIR.exists(), f"Folder not found: {DATA_DIR}"

OUT_DIR = Path("../outputs/models/deblur")
OUT_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE    = 256
BATCH_SIZE  = 1          # start small → increase later if stable
LR          = 1e-4
EPOCHS      = 100
TRAIN_SPLIT = 0.9

NUM_WORKERS    = 0       # 0 = safest, increase to 2–4 later if needed
PIN_MEMORY     = torch.cuda.is_available()

# ────────────────────────────────────────────────
#  MOTION BLUR KERNEL
# ────────────────────────────────────────────────

def motion_blur_kernel(size, angle):
    k = np.zeros((size, size), dtype=np.float32)
    k[size//2, :] = 1.0
    M = cv2.getRotationMatrix2D((size//2, size//2), angle, 1.0)
    k = cv2.warpAffine(k, M, (size, size), flags=cv2.INTER_LINEAR)
    return k / k.sum()

def apply_motion_blur(img):
    try:
        kernel_size = random.choice([9, 11, 13, 15])
        angle = random.uniform(-35, 35)
        kernel = motion_blur_kernel(kernel_size, angle)
        blurred = cv2.filter2D(img, -1, kernel)
        return blurred
    except Exception as e:
        print(f"Blur failed: {e}")
        return img  # fallback to sharp if blur crashes

# ────────────────────────────────────────────────
#  DATASET
# ────────────────────────────────────────────────

class RailSemDeblurDataset(Dataset):
    def __init__(self, folder: Path):
        self.files = sorted(folder.glob("*.jpg"))
        if not self.files:
            raise ValueError(f"No .jpg files found in {folder}")
        print(f"[DATASET] Found {len(self.files)} images")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        img = cv2.imread(str(path))
        if img is None:
            raise IOError(f"Cannot read image: {path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

        sharp = img.astype(np.float32) / 255.0
        blur  = apply_motion_blur(img).astype(np.float32) / 255.0

        return (
            torch.from_numpy(blur ).permute(2, 0, 1),
            torch.from_numpy(sharp).permute(2, 0, 1)
        )

# ────────────────────────────────────────────────
#  MODEL
# ────────────────────────────────────────────────

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class DeblurNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.head   = nn.Conv2d(3,   64, 3, padding=1)
        self.body   = nn.Sequential(*[ResBlock(64) for _ in range(12)])
        self.tail   = nn.Conv2d(64,  3,  3, padding=1)

    def forward(self, x):
        feat = torch.relu(self.head(x))
        feat = self.body(feat)
        out  = self.tail(feat)
        return torch.clamp(out, 0.0, 1.0)

# ────────────────────────────────────────────────
#  EVALUATION
# ────────────────────────────────────────────────

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    psnrs, ssims = [], []
    for blur, sharp in loader:
        blur  = blur.to(DEVICE)
        sharp = sharp.to(DEVICE)
        pred = model(blur)

        # take first image in batch for metric
        p = pred[0].permute(1,2,0).cpu().numpy()
        s = sharp[0].permute(1,2,0).cpu().numpy()

        psnrs.append(psnr(s, p, data_range=1.0))
        ssims.append(ssim(s, p, channel_axis=2, data_range=1.0))

    return float(np.mean(psnrs)), float(np.mean(ssims))

# ────────────────────────────────────────────────
#  MAIN
# ────────────────────────────────────────────────

if __name__ == "__main__":
    # ── Dataset + split ────────────────────────────────
    full_dataset = RailSemDeblurDataset(DATA_DIR)

    n_total = len(full_dataset)
    n_train = int(n_total * TRAIN_SPLIT)
    n_val   = n_total - n_train

    train_ds, val_ds = random_split(full_dataset, [n_train, n_val])

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )

    print(f"[INFO] Total: {n_total}  |  Train: {len(train_ds)}  |  Val: {len(val_ds)}")
    print(f"[INFO] Batch size: {BATCH_SIZE}  |  Workers: {NUM_WORKERS}  |  Pin memory: {PIN_MEMORY}")

    # ── Quick test: can we load one batch? ─────────────
    print("\nTesting first batch ... ", end="")
    try:
        test_blur, test_sharp = next(iter(train_loader))
        print("OK")
        print(f"   shapes → blur: {test_blur.shape}   sharp: {test_sharp.shape}")
    except Exception as e:
        print("FAILED")
        print(e)
        exit(1)

    # ── Model + Optimizer + Loss ───────────────────────
    model = DeblurNet().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.L1Loss()

    best_psnr = -1.0
    best_path = OUT_DIR / "best_model_deblur.pth"

    # ── Training loop ──────────────────────────────────
    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        batch_count = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
        for blur, sharp in pbar:
            blur  = blur.to(DEVICE)
            sharp = sharp.to(DEVICE)

            optimizer.zero_grad()
            pred = model(blur)
            loss = loss_fn(pred, sharp)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

            pbar.set_postfix(loss=f"{loss.item():.4f}")

            # occasional print
            if batch_count % 100 == 0:
                print(f"  [batch {batch_count:4d}] loss={loss.item():.4f}")

        avg_loss = total_loss / batch_count

        # ── Validation ────────────────────────────────
        val_psnr, val_ssim = evaluate(model, val_loader)

        print(f"[Epoch {epoch:2d}] "
              f"Loss={avg_loss:.4f} | "
              f"PSNR={val_psnr:.2f} dB | "
              f"SSIM={val_ssim:.4f}")

        # Save every epoch
        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "loss": avg_loss,
                "val_psnr": val_psnr,
                "val_ssim": val_ssim,
            },
            OUT_DIR / f"epoch_{epoch:03d}.pth"
        )

        # Save best model
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(model.state_dict(), best_path)
            print(f"  → New best model saved (PSNR {val_psnr:.2f})")

    print("\n" + "═" * 60)
    print("Training finished. Best PSNR:", f"{best_psnr:.2f}")
    print("Best model saved at:", best_path)

[INFO] Using device: cuda
[DATASET] Found 8500 images
[INFO] Total: 8500  |  Train: 7650  |  Val: 850
[INFO] Batch size: 1  |  Workers: 0  |  Pin memory: True

Testing first batch ... OK
   shapes → blur: torch.Size([1, 3, 256, 256])   sharp: torch.Size([1, 3, 256, 256])


Epoch 1/100:   1%|▏         | 101/7650 [00:10<12:18, 10.22it/s, loss=0.0449]

  [batch  100] loss=0.0433


Epoch 1/100:   3%|▎         | 200/7650 [00:19<10:37, 11.69it/s, loss=0.0360]

  [batch  200] loss=0.0400


Epoch 1/100:   4%|▍         | 301/7650 [00:31<13:32,  9.04it/s, loss=0.0291]

  [batch  300] loss=0.0459


Epoch 1/100:   5%|▌         | 401/7650 [00:40<10:40, 11.32it/s, loss=0.0319]

  [batch  400] loss=0.0457


Epoch 1/100:   7%|▋         | 500/7650 [00:50<11:17, 10.56it/s, loss=0.0415]

  [batch  500] loss=0.0331


Epoch 1/100:   8%|▊         | 600/7650 [00:59<10:39, 11.03it/s, loss=0.0319]

  [batch  600] loss=0.0325


Epoch 1/100:   9%|▉         | 702/7650 [01:09<10:52, 10.65it/s, loss=0.0344]

  [batch  700] loss=0.0314


Epoch 1/100:  10%|█         | 800/7650 [01:19<10:20, 11.05it/s, loss=0.0211]

  [batch  800] loss=0.0486


Epoch 1/100:  12%|█▏        | 901/7650 [01:29<10:43, 10.49it/s, loss=0.0418]

  [batch  900] loss=0.0323


Epoch 1/100:  13%|█▎        | 1000/7650 [01:38<09:54, 11.19it/s, loss=0.0481]

  [batch 1000] loss=0.0476


Epoch 1/100:  14%|█▍        | 1100/7650 [01:48<10:46, 10.13it/s, loss=0.0774]

  [batch 1100] loss=0.0390


Epoch 1/100:  16%|█▌        | 1201/7650 [01:58<10:14, 10.49it/s, loss=0.0297]

  [batch 1200] loss=0.0656


Epoch 1/100:  17%|█▋        | 1302/7650 [02:08<10:49,  9.78it/s, loss=0.0358]

  [batch 1300] loss=0.0258


Epoch 1/100:  18%|█▊        | 1401/7650 [02:18<09:53, 10.52it/s, loss=0.0198]

  [batch 1400] loss=0.0187


Epoch 1/100:  20%|█▉        | 1501/7650 [02:27<12:38,  8.10it/s, loss=0.0488]

  [batch 1500] loss=0.0317


Epoch 1/100:  21%|██        | 1601/7650 [02:38<10:31,  9.57it/s, loss=0.0298]

  [batch 1600] loss=0.0722


Epoch 1/100:  22%|██▏       | 1701/7650 [02:47<09:19, 10.64it/s, loss=0.0406]

  [batch 1700] loss=0.0458


Epoch 1/100:  24%|██▎       | 1801/7650 [02:56<08:36, 11.33it/s, loss=0.0292]

  [batch 1800] loss=0.0333


Epoch 1/100:  25%|██▍       | 1902/7650 [03:07<08:32, 11.23it/s, loss=0.0290]

  [batch 1900] loss=0.0257


Epoch 1/100:  26%|██▌       | 1999/7650 [03:16<08:17, 11.36it/s, loss=0.0719]

  [batch 2000] loss=0.0719


Epoch 1/100:  27%|██▋       | 2102/7650 [03:25<08:13, 11.24it/s, loss=0.0404]

  [batch 2100] loss=0.0348


Epoch 1/100:  29%|██▉       | 2200/7650 [03:35<09:37,  9.44it/s, loss=0.0322]

  [batch 2200] loss=0.0370


Epoch 1/100:  30%|███       | 2300/7650 [03:45<08:08, 10.96it/s, loss=0.0189]

  [batch 2300] loss=0.0430


Epoch 1/100:  31%|███▏      | 2401/7650 [03:54<11:00,  7.95it/s, loss=0.0448]

  [batch 2400] loss=0.0349


Epoch 1/100:  33%|███▎      | 2502/7650 [04:05<07:44, 11.08it/s, loss=0.0300]

  [batch 2500] loss=0.0307


Epoch 1/100:  34%|███▍      | 2601/7650 [04:14<07:31, 11.19it/s, loss=0.0392]

  [batch 2600] loss=0.0155


Epoch 1/100:  35%|███▌      | 2700/7650 [04:24<07:15, 11.36it/s, loss=0.0342]

  [batch 2700] loss=0.0150


Epoch 1/100:  37%|███▋      | 2800/7650 [04:34<07:37, 10.59it/s, loss=0.0284]

  [batch 2800] loss=0.0475


Epoch 1/100:  38%|███▊      | 2899/7650 [04:44<07:16, 10.87it/s, loss=0.0417]

  [batch 2900] loss=0.0417


Epoch 1/100:  39%|███▉      | 3001/7650 [04:54<06:55, 11.20it/s, loss=0.0238]

  [batch 3000] loss=0.0124


Epoch 1/100:  41%|████      | 3102/7650 [05:04<07:06, 10.65it/s, loss=0.0388]

  [batch 3100] loss=0.0312


Epoch 1/100:  42%|████▏     | 3200/7650 [05:14<06:53, 10.76it/s, loss=0.0199]

  [batch 3200] loss=0.0336


Epoch 1/100:  43%|████▎     | 3302/7650 [05:25<07:57,  9.11it/s, loss=0.0193]

  [batch 3300] loss=0.0434


Epoch 1/100:  44%|████▍     | 3400/7650 [05:37<07:56,  8.92it/s, loss=0.0279]

  [batch 3400] loss=0.0196


Epoch 1/100:  46%|████▌     | 3501/7650 [05:48<07:23,  9.36it/s, loss=0.0514]

  [batch 3500] loss=0.0462


Epoch 1/100:  47%|████▋     | 3601/7650 [05:59<07:21,  9.17it/s, loss=0.0257]

  [batch 3600] loss=0.0305


Epoch 1/100:  48%|████▊     | 3701/7650 [06:10<06:26, 10.22it/s, loss=0.0252]

  [batch 3700] loss=0.0253


Epoch 1/100:  50%|████▉     | 3800/7650 [06:21<05:52, 10.93it/s, loss=0.0370]

  [batch 3800] loss=0.0093


Epoch 1/100:  51%|█████     | 3901/7650 [06:34<07:16,  8.58it/s, loss=0.0190]

  [batch 3900] loss=0.0387


Epoch 1/100:  52%|█████▏    | 4001/7650 [06:45<07:00,  8.67it/s, loss=0.0326]

  [batch 4000] loss=0.0295


Epoch 1/100:  54%|█████▎    | 4100/7650 [06:56<06:48,  8.69it/s, loss=0.0464]

  [batch 4100] loss=0.0464


Epoch 1/100:  55%|█████▍    | 4201/7650 [07:09<08:01,  7.17it/s, loss=0.0315]

  [batch 4200] loss=0.0237


Epoch 1/100:  56%|█████▌    | 4301/7650 [07:20<05:54,  9.46it/s, loss=0.0338]

  [batch 4300] loss=0.0263


Epoch 1/100:  58%|█████▊    | 4400/7650 [07:32<05:17, 10.22it/s, loss=0.0387]

  [batch 4400] loss=0.0367


Epoch 1/100:  59%|█████▉    | 4502/7650 [07:42<05:11, 10.10it/s, loss=0.0383]

  [batch 4500] loss=0.0354


Epoch 1/100:  60%|██████    | 4601/7650 [07:52<04:40, 10.88it/s, loss=0.0308]

  [batch 4600] loss=0.0278


Epoch 1/100:  61%|██████▏   | 4702/7650 [08:02<04:38, 10.59it/s, loss=0.0427]

  [batch 4700] loss=0.0485


Epoch 1/100:  63%|██████▎   | 4801/7650 [08:12<05:12,  9.12it/s, loss=0.0395]

  [batch 4800] loss=0.0441


Epoch 1/100:  64%|██████▍   | 4901/7650 [08:23<05:16,  8.68it/s, loss=0.0489]

  [batch 4900] loss=0.0250


Epoch 1/100:  65%|██████▌   | 5001/7650 [08:33<04:15, 10.36it/s, loss=0.0311]

  [batch 5000] loss=0.0384


Epoch 1/100:  67%|██████▋   | 5101/7650 [08:44<04:26,  9.56it/s, loss=0.0227]

  [batch 5100] loss=0.0308


Epoch 1/100:  68%|██████▊   | 5202/7650 [08:55<03:40, 11.09it/s, loss=0.0349]

  [batch 5200] loss=0.0291


Epoch 1/100:  69%|██████▉   | 5301/7650 [09:05<03:48, 10.27it/s, loss=0.0324]

  [batch 5300] loss=0.0335


Epoch 1/100:  71%|███████   | 5400/7650 [09:16<03:30, 10.69it/s, loss=0.0386]

  [batch 5400] loss=0.0108


Epoch 1/100:  72%|███████▏  | 5500/7650 [09:25<03:09, 11.36it/s, loss=0.0273]

  [batch 5500] loss=0.0508


Epoch 1/100:  73%|███████▎  | 5602/7650 [09:36<03:27,  9.87it/s, loss=0.0337]

  [batch 5600] loss=0.0373


Epoch 1/100:  75%|███████▍  | 5701/7650 [09:47<03:02, 10.68it/s, loss=0.0385]

  [batch 5700] loss=0.0335


Epoch 1/100:  76%|███████▌  | 5801/7650 [09:57<03:09,  9.75it/s, loss=0.0227]

  [batch 5800] loss=0.0231


Epoch 1/100:  77%|███████▋  | 5901/7650 [10:09<03:00,  9.68it/s, loss=0.0263]

  [batch 5900] loss=0.0351


Epoch 1/100:  78%|███████▊  | 6001/7650 [10:19<03:11,  8.60it/s, loss=0.0286]

  [batch 6000] loss=0.0509


Epoch 1/100:  80%|███████▉  | 6102/7650 [10:30<02:29, 10.35it/s, loss=0.0151]

  [batch 6100] loss=0.0354


Epoch 1/100:  81%|████████  | 6201/7650 [10:41<02:34,  9.36it/s, loss=0.0383]

  [batch 6200] loss=0.0295


Epoch 1/100:  82%|████████▏ | 6300/7650 [10:51<02:36,  8.63it/s, loss=0.0294]

  [batch 6300] loss=0.0384


Epoch 1/100:  84%|████████▎ | 6400/7650 [11:02<02:18,  9.01it/s, loss=0.0106]

  [batch 6400] loss=0.0106


Epoch 1/100:  85%|████████▍ | 6502/7650 [11:12<01:48, 10.60it/s, loss=0.0210]

  [batch 6500] loss=0.0145


Epoch 1/100:  86%|████████▋ | 6599/7650 [11:22<01:41, 10.36it/s, loss=0.0257]

  [batch 6600] loss=0.0257


Epoch 1/100:  88%|████████▊ | 6700/7650 [11:33<01:44,  9.06it/s, loss=0.0308]

  [batch 6700] loss=0.0366


Epoch 1/100:  89%|████████▉ | 6801/7650 [11:44<01:21, 10.36it/s, loss=0.0227]

  [batch 6800] loss=0.0309


Epoch 1/100:  90%|█████████ | 6901/7650 [11:54<01:10, 10.58it/s, loss=0.0405]

  [batch 6900] loss=0.0361


Epoch 1/100:  92%|█████████▏| 7000/7650 [12:04<01:00, 10.79it/s, loss=0.0245]

  [batch 7000] loss=0.0381


Epoch 1/100:  93%|█████████▎| 7102/7650 [12:14<00:55,  9.84it/s, loss=0.0472]

  [batch 7100] loss=0.0312


Epoch 1/100:  94%|█████████▍| 7201/7650 [12:24<00:53,  8.40it/s, loss=0.0274]

  [batch 7200] loss=0.0419


Epoch 1/100:  95%|█████████▌| 7301/7650 [12:34<00:34, 10.14it/s, loss=0.0349]

  [batch 7300] loss=0.0454


Epoch 1/100:  97%|█████████▋| 7402/7650 [12:45<00:24, 10.22it/s, loss=0.0290]

  [batch 7400] loss=0.0335


Epoch 1/100:  98%|█████████▊| 7501/7650 [12:56<00:15,  9.32it/s, loss=0.0225]

  [batch 7500] loss=0.0226


Epoch 1/100:  99%|█████████▉| 7600/7650 [13:06<00:05,  9.48it/s, loss=0.0164]

  [batch 7600] loss=0.0244


Epoch 1/100: 100%|██████████| 7650/7650 [13:11<00:00,  9.67it/s, loss=0.0265]


[Epoch  1] Loss=0.0365 | PSNR=25.95 dB | SSIM=0.7306
  → New best model saved (PSNR 25.95)


Epoch 2/100:   1%|▏         | 101/7650 [00:09<12:09, 10.34it/s, loss=0.0118]

  [batch  100] loss=0.0293


Epoch 2/100:   3%|▎         | 201/7650 [00:18<10:48, 11.49it/s, loss=0.0321]

  [batch  200] loss=0.0068


Epoch 2/100:   4%|▍         | 301/7650 [00:28<14:12,  8.62it/s, loss=0.0246]

  [batch  300] loss=0.0343


Epoch 2/100:   5%|▌         | 402/7650 [00:38<10:52, 11.11it/s, loss=0.0272]

  [batch  400] loss=0.0414


Epoch 2/100:   7%|▋         | 500/7650 [00:47<10:22, 11.49it/s, loss=0.0721]

  [batch  500] loss=0.0404


Epoch 2/100:   8%|▊         | 599/7650 [00:57<12:28,  9.41it/s, loss=0.0204]

  [batch  600] loss=0.0204


Epoch 2/100:   9%|▉         | 701/7650 [01:07<09:54, 11.69it/s, loss=0.0230]

  [batch  700] loss=0.0319


Epoch 2/100:  10%|█         | 801/7650 [01:16<10:22, 11.00it/s, loss=0.0243]

  [batch  800] loss=0.0065


Epoch 2/100:  12%|█▏        | 899/7650 [01:26<11:23,  9.88it/s, loss=0.0391]

  [batch  900] loss=0.0391


Epoch 2/100:  13%|█▎        | 1000/7650 [01:37<11:58,  9.26it/s, loss=0.0317]

  [batch 1000] loss=0.0520


Epoch 2/100:  14%|█▍        | 1100/7650 [01:46<09:37, 11.34it/s, loss=0.0314]

  [batch 1100] loss=0.0328


Epoch 2/100:  16%|█▌        | 1201/7650 [01:56<10:38, 10.10it/s, loss=0.0313]

  [batch 1200] loss=0.0446


Epoch 2/100:  17%|█▋        | 1301/7650 [02:06<10:51,  9.75it/s, loss=0.0242]

  [batch 1300] loss=0.0178


Epoch 2/100:  18%|█▊        | 1401/7650 [02:16<12:01,  8.66it/s, loss=0.0268]

  [batch 1400] loss=0.0415


Epoch 2/100:  20%|█▉        | 1501/7650 [02:26<13:23,  7.65it/s, loss=0.0447]

  [batch 1500] loss=0.0733


Epoch 2/100:  21%|██        | 1600/7650 [02:37<10:18,  9.78it/s, loss=0.0426]

  [batch 1600] loss=0.0258


Epoch 2/100:  22%|██▏       | 1700/7650 [02:46<08:50, 11.22it/s, loss=0.0452]

  [batch 1700] loss=0.0407


Epoch 2/100:  24%|██▎       | 1801/7650 [02:55<08:33, 11.39it/s, loss=0.0174]

  [batch 1800] loss=0.0234


Epoch 2/100:  25%|██▍       | 1900/7650 [03:05<08:25, 11.37it/s, loss=0.0510]

  [batch 1900] loss=0.0196


Epoch 2/100:  26%|██▌       | 2002/7650 [03:15<08:52, 10.61it/s, loss=0.0367]

  [batch 2000] loss=0.0382


Epoch 2/100:  27%|██▋       | 2102/7650 [03:25<08:04, 11.45it/s, loss=0.0439]

  [batch 2100] loss=0.0282


Epoch 2/100:  29%|██▉       | 2202/7650 [03:35<08:58, 10.12it/s, loss=0.0207]

  [batch 2200] loss=0.0503


Epoch 2/100:  30%|███       | 2300/7650 [03:45<08:25, 10.58it/s, loss=0.0372]

  [batch 2300] loss=0.0261


Epoch 2/100:  31%|███▏      | 2400/7650 [03:55<08:40, 10.08it/s, loss=0.0226]

  [batch 2400] loss=0.0232


Epoch 2/100:  33%|███▎      | 2502/7650 [04:04<07:37, 11.25it/s, loss=0.0468]

  [batch 2500] loss=0.0568


Epoch 2/100:  34%|███▍      | 2602/7650 [04:14<08:27,  9.95it/s, loss=0.0223]

  [batch 2600] loss=0.0554


Epoch 2/100:  35%|███▌      | 2700/7650 [04:24<07:38, 10.79it/s, loss=0.0150]

  [batch 2700] loss=0.0407


Epoch 2/100:  37%|███▋      | 2800/7650 [04:34<07:50, 10.31it/s, loss=0.0132]

  [batch 2800] loss=0.0435


Epoch 2/100:  38%|███▊      | 2902/7650 [04:44<07:30, 10.54it/s, loss=0.0390]

  [batch 2900] loss=0.0502


Epoch 2/100:  39%|███▉      | 3002/7650 [04:54<07:29, 10.34it/s, loss=0.0392]

  [batch 3000] loss=0.0327


Epoch 2/100:  41%|████      | 3100/7650 [05:04<07:57,  9.53it/s, loss=0.0386]

  [batch 3100] loss=0.0206


Epoch 2/100:  42%|████▏     | 3200/7650 [05:13<07:06, 10.44it/s, loss=0.0332]

  [batch 3200] loss=0.0237


Epoch 2/100:  43%|████▎     | 3301/7650 [05:23<06:23, 11.34it/s, loss=0.0377]

  [batch 3300] loss=0.0195


Epoch 2/100:  44%|████▍     | 3401/7650 [05:33<06:44, 10.50it/s, loss=0.0307]

  [batch 3400] loss=0.0234


Epoch 2/100:  46%|████▌     | 3501/7650 [05:43<06:24, 10.78it/s, loss=0.0233]

  [batch 3500] loss=0.0507


Epoch 2/100:  47%|████▋     | 3602/7650 [05:52<06:22, 10.57it/s, loss=0.0227]

  [batch 3600] loss=0.0339


Epoch 2/100:  48%|████▊     | 3701/7650 [06:02<06:11, 10.63it/s, loss=0.0347]

  [batch 3700] loss=0.0373


Epoch 2/100:  50%|████▉     | 3801/7650 [06:11<05:52, 10.91it/s, loss=0.0257]

  [batch 3800] loss=0.0150


Epoch 2/100:  51%|█████     | 3901/7650 [06:20<05:29, 11.36it/s, loss=0.0147]

  [batch 3900] loss=0.0221


Epoch 2/100:  52%|█████▏    | 4001/7650 [06:31<05:57, 10.22it/s, loss=0.0355]

  [batch 4000] loss=0.0376


Epoch 2/100:  54%|█████▎    | 4102/7650 [06:41<05:06, 11.57it/s, loss=0.0139]

  [batch 4100] loss=0.0496


Epoch 2/100:  55%|█████▍    | 4201/7650 [06:50<05:11, 11.09it/s, loss=0.0395]

  [batch 4200] loss=0.0054


Epoch 2/100:  56%|█████▌    | 4302/7650 [07:01<04:49, 11.58it/s, loss=0.0302]

  [batch 4300] loss=0.0333


Epoch 2/100:  58%|█████▊    | 4402/7650 [07:10<04:57, 10.91it/s, loss=0.0334]

  [batch 4400] loss=0.0242


Epoch 2/100:  59%|█████▉    | 4502/7650 [07:20<04:48, 10.89it/s, loss=0.0251]

  [batch 4500] loss=0.0329


Epoch 2/100:  60%|██████    | 4601/7650 [07:30<04:59, 10.18it/s, loss=0.0264]

  [batch 4600] loss=0.0242


Epoch 2/100:  61%|██████▏   | 4702/7650 [07:39<04:22, 11.22it/s, loss=0.0355]

  [batch 4700] loss=0.0406


Epoch 2/100:  63%|██████▎   | 4801/7650 [07:49<04:34, 10.37it/s, loss=0.0423]

  [batch 4800] loss=0.0329


Epoch 2/100:  64%|██████▍   | 4901/7650 [07:58<04:43,  9.71it/s, loss=0.0229]

  [batch 4900] loss=0.0240


Epoch 2/100:  65%|██████▌   | 5000/7650 [08:08<04:24, 10.00it/s, loss=0.0064]

  [batch 5000] loss=0.0136


Epoch 2/100:  67%|██████▋   | 5102/7650 [08:18<04:00, 10.61it/s, loss=0.0215]

  [batch 5100] loss=0.0387


Epoch 2/100:  68%|██████▊   | 5202/7650 [08:27<04:14,  9.61it/s, loss=0.0398]

  [batch 5200] loss=0.0232


Epoch 2/100:  69%|██████▉   | 5300/7650 [08:37<03:32, 11.06it/s, loss=0.0137]

  [batch 5300] loss=0.0444


Epoch 2/100:  71%|███████   | 5401/7650 [08:47<03:26, 10.90it/s, loss=0.0253]

  [batch 5400] loss=0.0380


Epoch 2/100:  72%|███████▏  | 5501/7650 [08:56<03:21, 10.65it/s, loss=0.0569]

  [batch 5500] loss=0.0455


Epoch 2/100:  73%|███████▎  | 5601/7650 [09:06<03:23, 10.07it/s, loss=0.0375]

  [batch 5600] loss=0.0320


Epoch 2/100:  75%|███████▍  | 5701/7650 [09:16<02:57, 11.00it/s, loss=0.0281]

  [batch 5700] loss=0.0368


Epoch 2/100:  76%|███████▌  | 5802/7650 [09:25<02:52, 10.68it/s, loss=0.0203]

  [batch 5800] loss=0.0357


Epoch 2/100:  77%|███████▋  | 5902/7650 [09:35<02:36, 11.19it/s, loss=0.0353]

  [batch 5900] loss=0.0586


Epoch 2/100:  78%|███████▊  | 6001/7650 [09:44<02:36, 10.57it/s, loss=0.0273]

  [batch 6000] loss=0.0216


Epoch 2/100:  80%|███████▉  | 6102/7650 [09:54<02:19, 11.13it/s, loss=0.0484]

  [batch 6100] loss=0.0418


Epoch 2/100:  81%|████████  | 6202/7650 [10:03<02:03, 11.77it/s, loss=0.0097]

  [batch 6200] loss=0.0288


Epoch 2/100:  82%|████████▏ | 6302/7650 [10:13<01:57, 11.47it/s, loss=0.0219]

  [batch 6300] loss=0.0212


Epoch 2/100:  84%|████████▎ | 6401/7650 [10:22<02:12,  9.40it/s, loss=0.0223]

  [batch 6400] loss=0.0174


Epoch 2/100:  85%|████████▍ | 6500/7650 [10:33<02:13,  8.61it/s, loss=0.0232]

  [batch 6500] loss=0.0442


Epoch 2/100:  86%|████████▋ | 6601/7650 [10:44<01:41, 10.33it/s, loss=0.0201]

  [batch 6600] loss=0.0307


Epoch 2/100:  88%|████████▊ | 6702/7650 [10:54<01:25, 11.06it/s, loss=0.0218]

  [batch 6700] loss=0.0250


Epoch 2/100:  89%|████████▉ | 6801/7650 [11:04<01:33,  9.13it/s, loss=0.0117]

  [batch 6800] loss=0.0224


Epoch 2/100:  90%|█████████ | 6899/7650 [11:14<01:10, 10.69it/s, loss=0.0454]

  [batch 6900] loss=0.0454


Epoch 2/100:  92%|█████████▏| 7000/7650 [11:25<01:06,  9.83it/s, loss=0.0270]

  [batch 7000] loss=0.0204


Epoch 2/100:  93%|█████████▎| 7101/7650 [11:35<00:56,  9.74it/s, loss=0.0296]

  [batch 7100] loss=0.0185


Epoch 2/100:  94%|█████████▍| 7202/7650 [11:45<00:46,  9.73it/s, loss=0.0263]

  [batch 7200] loss=0.0384


Epoch 2/100:  95%|█████████▌| 7301/7650 [11:55<00:35,  9.96it/s, loss=0.0228]

  [batch 7300] loss=0.0194


Epoch 2/100:  97%|█████████▋| 7401/7650 [12:06<00:24, 10.20it/s, loss=0.0460]

  [batch 7400] loss=0.0256


Epoch 2/100:  98%|█████████▊| 7501/7650 [12:16<00:14, 10.09it/s, loss=0.0376]

  [batch 7500] loss=0.0162


Epoch 2/100:  99%|█████████▉| 7601/7650 [12:26<00:04, 10.38it/s, loss=0.0160]

  [batch 7600] loss=0.0258


Epoch 2/100: 100%|██████████| 7650/7650 [12:31<00:00, 10.18it/s, loss=0.0132]


[Epoch  2] Loss=0.0311 | PSNR=26.91 dB | SSIM=0.7765
  → New best model saved (PSNR 26.91)


Epoch 3/100:   1%|▏         | 102/7650 [00:09<12:25, 10.13it/s, loss=0.0246]

  [batch  100] loss=0.0490


Epoch 3/100:   3%|▎         | 201/7650 [00:19<13:11,  9.41it/s, loss=0.0175]

  [batch  200] loss=0.0249


Epoch 3/100:   4%|▍         | 302/7650 [00:29<10:14, 11.95it/s, loss=0.0290]

  [batch  300] loss=0.0300


Epoch 3/100:   5%|▌         | 401/7650 [00:38<12:13,  9.88it/s, loss=0.0404]

  [batch  400] loss=0.0245


Epoch 3/100:   7%|▋         | 502/7650 [00:47<11:04, 10.76it/s, loss=0.0310]

  [batch  500] loss=0.0205


Epoch 3/100:   8%|▊         | 601/7650 [00:57<10:09, 11.56it/s, loss=0.0228]

  [batch  600] loss=0.0209


Epoch 3/100:   9%|▉         | 701/7650 [01:06<11:13, 10.32it/s, loss=0.0270]

  [batch  700] loss=0.0204


Epoch 3/100:  10%|█         | 802/7650 [01:16<10:48, 10.57it/s, loss=0.0348]

  [batch  800] loss=0.0475


Epoch 3/100:  12%|█▏        | 902/7650 [01:26<09:56, 11.31it/s, loss=0.0130]

  [batch  900] loss=0.0241


Epoch 3/100:  13%|█▎        | 1000/7650 [01:36<11:10,  9.91it/s, loss=0.0122]

  [batch 1000] loss=0.0292


Epoch 3/100:  14%|█▍        | 1102/7650 [01:46<09:18, 11.73it/s, loss=0.0238]

  [batch 1100] loss=0.0373


Epoch 3/100:  16%|█▌        | 1200/7650 [01:55<09:24, 11.42it/s, loss=0.0302]

  [batch 1200] loss=0.0265


Epoch 3/100:  17%|█▋        | 1300/7650 [02:06<14:19,  7.39it/s, loss=0.0400]

  [batch 1300] loss=0.0380


Epoch 3/100:  18%|█▊        | 1401/7650 [02:17<09:23, 11.09it/s, loss=0.0428]

  [batch 1400] loss=0.0295


Epoch 3/100:  20%|█▉        | 1500/7650 [02:29<12:07,  8.46it/s, loss=0.0349]

  [batch 1500] loss=0.0130


Epoch 3/100:  20%|██        | 1561/7650 [02:35<10:07, 10.03it/s, loss=0.0277]


KeyboardInterrupt: 

In [5]:
# ============================================================
# RAILWAY DEBLUR TRAINING PIPELINE (SYNTHETIC MOTION BLUR)
# ============================================================

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =====================
# CONFIG
# =====================
DATA_ROOT = Path("../data/datasets/RailSem19/JPEGImages")
SAVE_DIR  = Path("../outputs/models/deblur")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256
BATCH_SIZE = 4
EPOCHS = 30
LR = 1e-4

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Device:", DEVICE)

# =====================
# SYNTHETIC BLUR
# =====================
def apply_motion_blur(img, max_kernel=25):
    k = np.random.randint(7, max_kernel)
    kernel = np.zeros((k, k))
    kernel[k // 2, :] = np.ones(k)
    kernel /= k
    return cv2.filter2D(img, -1, kernel)

def add_night_noise(img):
    noise = np.random.normal(0, 8, img.shape)
    noisy = img + noise
    return np.clip(noisy, 0, 255).astype(np.uint8)

# =====================
# DATASET
# =====================
class RailBlurDataset(Dataset):
    def __init__(self, root):
        self.images = list(root.glob("*.jpg"))
        assert len(self.images) > 0, "No images found!"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.images[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.copy()
        blur  = apply_motion_blur(img)
        blur  = add_night_noise(blur)

        sharp = torch.from_numpy(sharp / 255.0).permute(2,0,1).float()
        blur  = torch.from_numpy(blur  / 255.0).permute(2,0,1).float()

        return blur, sharp

dataset = RailBlurDataset(DATA_ROOT)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

print(f"[INFO] Training samples: {len(dataset)}")

# =====================
# MODEL
# =====================
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class DeblurGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 64, 7, padding=3)
        self.body = nn.Sequential(*[ResBlock(64) for _ in range(9)])
        self.tail = nn.Conv2d(64, 3, 7, padding=3)

    def forward(self, x):
        x = torch.relu(self.head(x))
        x = self.body(x)
        return torch.sigmoid(self.tail(x))

model = DeblurGenerator().to(DEVICE)

# =====================
# TRAINING SETUP
# =====================
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LR)

scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

# =====================
# TRAIN LOOP
# =====================
print("[INFO] Starting training...")

best_loss = 1e9

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for blur, sharp in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        blur  = blur.to(DEVICE)
        sharp = sharp.to(DEVICE)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=(DEVICE.type == "cuda")):
            pred = model(blur)
            loss = criterion(pred, sharp)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    epoch_loss /= len(loader)
    print(f"[Epoch {epoch+1}] L1 Loss: {epoch_loss:.4f}")

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(
            {"model_state": model.state_dict()},
            SAVE_DIR / "best_model_deblur.pth"
        )
        print("✅ Saved best model")

print("🎉 TRAINING COMPLETE")

[INFO] Device: cuda


AssertionError: No images found!

In [1]:
# ============================================================
# RAILWAY DEBLUR TRAINING (RailSem19 rs19_val)
# ============================================================

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

# =====================
# CONFIG
# =====================
DATA_DIR = Path("../data/datasets/RailSem19/jpgs/rs19_val")
SAVE_DIR = Path("../outputs/models/deblur")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256
BATCH_SIZE = 4
EPOCHS = 30
LR = 1e-4

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Device:", DEVICE)

# =====================
# SYNTHETIC BLUR (PROPER)
# =====================
def motion_blur(img):
    k = random.choice([9, 11, 15, 21])
    angle = random.uniform(-20, 20)  # train motion variation

    kernel = np.zeros((k, k))
    kernel[k // 2, :] = np.ones(k)
    kernel /= k

    M = cv2.getRotationMatrix2D((k//2, k//2), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (k, k))

    return cv2.filter2D(img, -1, kernel)

def night_noise(img):
    noise = np.random.normal(0, 10, img.shape)
    noisy = img + noise
    return np.clip(noisy, 0, 255).astype(np.uint8)

# =====================
# DATASET (FIXED)
# =====================
class RailSemBlurDataset(Dataset):
    def __init__(self, root):
        self.images = list(root.rglob("*.jpg"))
        assert len(self.images) > 0, "❌ No images found"
        print(f"[INFO] Loaded {len(self.images)} images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = cv2.imread(str(self.images[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        sharp = img.copy()
        blur = motion_blur(img)
        blur = night_noise(blur)

        sharp = torch.from_numpy(sharp / 255.0).permute(2,0,1).float()
        blur  = torch.from_numpy(blur  / 255.0).permute(2,0,1).float()

        return blur, sharp

dataset = RailSemBlurDataset(DATA_DIR)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,   # 🔥 CRITICAL FIX
    pin_memory=True
)


# =====================
# MODEL
# =====================
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class DeblurGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 64, 7, padding=3)
        self.body = nn.Sequential(*[ResBlock(64) for _ in range(9)])
        self.tail = nn.Conv2d(64, 3, 7, padding=3)

    def forward(self, x):
        x = torch.relu(self.head(x))
        x = self.body(x)
        return torch.sigmoid(self.tail(x))

model = DeblurGenerator().to(DEVICE)

# =====================
# TRAINING
# =====================
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

best_loss = float("inf")
print("[INFO] Training started")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for blur, sharp in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        blur = blur.to(DEVICE)
        sharp = sharp.to(DEVICE)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=(DEVICE.type == "cuda")):
            pred = model(blur)
            loss = criterion(pred, sharp)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    epoch_loss /= len(loader)
    print(f"[Epoch {epoch+1}] L1 Loss: {epoch_loss:.4f}")

    torch.save(
        {
            "epoch": epoch,
            "loss": epoch_loss,
            "model_state": model.state_dict()
        },
        SAVE_DIR / f"epoch_{epoch+1}.pth"
    )

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(
            {"model_state": model.state_dict()},
            SAVE_DIR / "best_model_deblur.pth"
        )
        print("✅ Best model updated")

print("🎉 TRAINING COMPLETE")


[INFO] Device: cuda
[INFO] Loaded 8500 images


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))


[INFO] Training started


  with torch.cuda.amp.autocast(enabled=(DEVICE.type == "cuda")):
Epoch 1/30: 100%|██████████| 2125/2125 [10:16<00:00,  3.45it/s]


[Epoch 1] L1 Loss: 0.0505
✅ Best model updated


Epoch 2/30: 100%|██████████| 2125/2125 [10:00<00:00,  3.54it/s]


[Epoch 2] L1 Loss: 0.0460
✅ Best model updated


Epoch 3/30: 100%|██████████| 2125/2125 [09:44<00:00,  3.63it/s]


[Epoch 3] L1 Loss: 0.0448
✅ Best model updated


Epoch 4/30: 100%|██████████| 2125/2125 [11:11<00:00,  3.16it/s]


[Epoch 4] L1 Loss: 0.0438
✅ Best model updated


Epoch 5/30: 100%|██████████| 2125/2125 [08:47<00:00,  4.03it/s]


[Epoch 5] L1 Loss: 0.0433
✅ Best model updated


Epoch 6/30: 100%|██████████| 2125/2125 [08:36<00:00,  4.11it/s]


[Epoch 6] L1 Loss: 0.0426
✅ Best model updated


Epoch 7/30: 100%|██████████| 2125/2125 [10:04<00:00,  3.51it/s]


[Epoch 7] L1 Loss: 0.0422
✅ Best model updated


Epoch 8/30: 100%|██████████| 2125/2125 [08:57<00:00,  3.95it/s]


[Epoch 8] L1 Loss: 0.0418
✅ Best model updated


Epoch 9/30: 100%|██████████| 2125/2125 [08:39<00:00,  4.09it/s]


[Epoch 9] L1 Loss: 0.0414
✅ Best model updated


Epoch 10/30: 100%|██████████| 2125/2125 [09:42<00:00,  3.65it/s]


[Epoch 10] L1 Loss: 0.0411
✅ Best model updated


Epoch 11/30: 100%|██████████| 2125/2125 [09:02<00:00,  3.92it/s]


[Epoch 11] L1 Loss: 0.0409
✅ Best model updated


Epoch 12/30: 100%|██████████| 2125/2125 [08:31<00:00,  4.16it/s]


[Epoch 12] L1 Loss: 0.0406
✅ Best model updated


Epoch 13/30: 100%|██████████| 2125/2125 [08:22<00:00,  4.23it/s]


[Epoch 13] L1 Loss: 0.0404
✅ Best model updated


Epoch 14/30: 100%|██████████| 2125/2125 [08:26<00:00,  4.20it/s]


[Epoch 14] L1 Loss: 0.0400
✅ Best model updated


Epoch 15/30: 100%|██████████| 2125/2125 [08:21<00:00,  4.23it/s]


[Epoch 15] L1 Loss: 0.0400
✅ Best model updated


Epoch 16/30: 100%|██████████| 2125/2125 [08:21<00:00,  4.24it/s]


[Epoch 16] L1 Loss: 0.0396
✅ Best model updated


Epoch 17/30: 100%|██████████| 2125/2125 [08:21<00:00,  4.24it/s]


[Epoch 17] L1 Loss: 0.0395
✅ Best model updated


Epoch 18/30: 100%|██████████| 2125/2125 [08:21<00:00,  4.24it/s]


[Epoch 18] L1 Loss: 0.0393
✅ Best model updated


Epoch 19/30: 100%|██████████| 2125/2125 [09:10<00:00,  3.86it/s]


[Epoch 19] L1 Loss: 0.0392
✅ Best model updated


Epoch 20/30: 100%|██████████| 2125/2125 [09:46<00:00,  3.62it/s]


[Epoch 20] L1 Loss: 0.0390
✅ Best model updated


Epoch 21/30: 100%|██████████| 2125/2125 [12:31<00:00,  2.83it/s]


[Epoch 21] L1 Loss: 0.0390
✅ Best model updated


Epoch 22/30: 100%|██████████| 2125/2125 [12:41<00:00,  2.79it/s]


[Epoch 22] L1 Loss: 0.0389
✅ Best model updated


Epoch 23/30: 100%|██████████| 2125/2125 [11:00<00:00,  3.22it/s]


[Epoch 23] L1 Loss: 0.0388
✅ Best model updated


Epoch 24/30: 100%|██████████| 2125/2125 [10:39<00:00,  3.32it/s]


[Epoch 24] L1 Loss: 0.0386
✅ Best model updated


Epoch 25/30: 100%|██████████| 2125/2125 [10:32<00:00,  3.36it/s]


[Epoch 25] L1 Loss: 0.0385
✅ Best model updated


Epoch 26/30: 100%|██████████| 2125/2125 [14:03<00:00,  2.52it/s]


[Epoch 26] L1 Loss: 0.0384
✅ Best model updated


Epoch 27/30: 100%|██████████| 2125/2125 [14:18<00:00,  2.47it/s]


[Epoch 27] L1 Loss: 0.0383
✅ Best model updated


Epoch 28/30:  26%|██▌       | 553/2125 [04:05<11:38,  2.25it/s]


KeyboardInterrupt: 