Code write by kushwanth

Experiments for better training of swin ir

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

# ---------- Dataset (no cropping) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert("RGB")  # assumed 256x256
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert("RGB")  # assumed 1024x1024
        return self.to_tensor(lr), self.to_tensor(hr)


# ---------- SwinIR Model ----------
class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=96, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        return self.proj(x)


def window_partition(x, window_size):
    B,C,H,W = x.shape
    x = x.view(B, C, H//window_size, window_size, W//window_size, window_size)
    windows = x.permute(0,2,4,3,5,1).contiguous().view(-1, window_size*window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,5,1,3,2,4).contiguous().view(B, -1, H, W)
    return x

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size

    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        return self.conv(self.blocks(x)) + x

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if (scale & (scale-1)) == 0:
            for _ in range(int(math.log2(scale))):
                m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        elif scale == 3:
            m += [nn.Conv2d(n_feats, n_feats*9, 3,1,1), nn.PixelShuffle(3)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)

class SwinIR(nn.Module):
    def __init__(self, embed_dim=96, depths=[6,6,6,6], num_heads=[6,6,6,6], window_size=8, scale=4):
        super().__init__()
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)

    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        return self.upsample(x0 + xb)

# ---------- Losses ----------
class CharbonnierLoss(nn.Module):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
    def forward(self, sr, hr):
        diff = sr - hr
        return torch.mean(torch.sqrt(diff*diff + self.eps**2))

class PerceptualLoss(nn.Module):
    def __init__(self, layer_ids=[3,8,15], weights=[1.0,1.0,1.0]):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        for p in vgg.parameters(): p.requires_grad=False
        self.slices = nn.ModuleList()
        prev = 0
        for lid in layer_ids:
            self.slices.append(nn.Sequential(*vgg[prev:lid]))
            prev = lid
        self.weights = weights
    def forward(self, sr, hr):
        loss = 0
        for w, slice in zip(self.weights, self.slices):
            loss += w * F.l1_loss(slice(sr), slice(hr))
        return loss

# ---------- Training ----------
def train(model, loader, optimizer, device,
          loss_type='l1', gan=None, perceptual=None,
          epochs=25, save_path='best_swinir.pth'):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()
    criterion = {
        'l1': nn.L1Loss(),
        'mse': nn.MSELoss(),
        'charb': CharbonnierLoss()
    }.get(loss_type, nn.L1Loss())

    for epoch in range(epochs):
        model.train()
        loop = tqdm(loader, desc=f'Epoch {epoch+1}/{epochs}')
        for lr, hr in loop:
            lr, hr = lr.to(device), hr.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = criterion(sr, hr)
                if gan and perceptual:
                    d_real = gan.discriminator(hr)
                    d_fake = gan.discriminator(sr.detach())
                    adv_loss = gan.adversarial_loss(d_fake, True)
                    p_loss = perceptual(sr, hr)
                    loss = loss + gan.lambda_gan*adv_loss + gan.lambda_p*p_loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loop.set_postfix(loss=loss.item())

        # validation PSNR
        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr, hr = lr.to(device), hr.to(device)
                sr = model(lr)
                mse = F.mse_loss(sr, hr)
                psnr = 20*torch.log10(1.0/torch.sqrt(mse))
                psnr_sum += psnr.item(); cnt+=1
                if cnt >= 10: break
        avg_psnr = psnr_sum / cnt
        print(f'Epoch {epoch+1} | PSNR: {avg_psnr:.2f}')
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), save_path)

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset('/kaggle/input/paintings/resized_dataset/resized_dataset', '/kaggle/input/paintings/1024data')
    loader  = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    
    model = SwinIR(embed_dim=96,
                   depths=[6,6,6,6],
                   num_heads=[6,6,6,6],
                   window_size=8, scale=4).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    # Classical SR with L1 pixel loss
    train(model, loader, optimizer, device, loss_type='l1', epochs=10)
    scheduler.step()


In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint as checkpoint
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# ---------- Dataset (fixed resizing) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir, lr_size=256, hr_size=1024):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_size = lr_size
        self.hr_size = hr_size
        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert("RGB")
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert("RGB")
        lr = lr.resize((self.lr_size, self.lr_size), Image.BICUBIC)
        hr = hr.resize((self.hr_size, self.hr_size), Image.BICUBIC)
        return self.to_tensor(lr), self.to_tensor(hr)

# ---------- Utility Functions ----------
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H//window_size, window_size, W//window_size, window_size)
    windows = x.permute(0,2,4,3,5,1).contiguous().view(-1, window_size*window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,5,1,3,2,4).contiguous().view(B, -1, H, W)
    return x

# ---------- Model Components ----------
class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=64, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size

    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        def run(x_in):
            out = self.blocks(x_in)
            out = self.conv(out)
            return out + x_in
        return checkpoint.checkpoint(run, x)

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if (scale & (scale-1)) == 0:
            for _ in range(int(math.log2(scale))):
                m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        elif scale == 3:
            m += [nn.Conv2d(n_feats, n_feats*9, 3,1,1), nn.PixelShuffle(3)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)

class SwinIR(nn.Module):
    def __init__(self, embed_dim=64, depths=[4,4], num_heads=[4,4], window_size=8, scale=4):
        super().__init__()
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)
    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        return self.upsample(x0 + xb)

# ---------- PSNR Calculation ----------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop (with resume) ----------
def train(model, loader, optimizer, device, epochs=25, save_path='best_swinir.pth'):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()

    # Resume from checkpoint if available
    if os.path.exists(save_path):
        print(f"Loading checkpoint '{save_path}'...")
        state_dict = torch.load(save_path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        # compute its PSNR
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr, hr = lr.to(device), hr.to(device)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        best_psnr = psnr_sum / cnt
        print(f"Resumed best PSNR: {best_psnr:.2f}")

    for epoch in range(epochs):
        model.train()
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for lr, hr in loop:
            lr, hr = lr.to(device), hr.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = F.l1_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loop.set_postfix(loss=loss.item())

        # Evaluate PSNR
        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr, hr = lr.to(device), hr.to(device)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        avg_psnr = psnr_sum / cnt
        print(f"Epoch {epoch+1} | PSNR: {avg_psnr:.2f}")

        # Save if improved
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best PSNR: {best_psnr:.2f}")
        torch.cuda.empty_cache()

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/1024data'
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    model = SwinIR(
        embed_dim=64, depths=[4,4], num_heads=[4,4], window_size=8, scale=4
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    train(model, loader, optimizer, device, epochs=20)
    scheduler.step()


In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

# ---------- Dataset (random aligned crop for patch training) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir, hr_patch=256, scale=2):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // scale

        # only keep matched files
        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert('RGB')
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert('RGB')

        # random crop on HR
        i, j, h, w = transforms.RandomCrop.get_params(hr, (self.hr_patch, self.hr_patch))
        hr_crop = TF.crop(hr, i, j, h, w)

        # aligned crop on LR
        lr_i, lr_j = i // self.scale, j // self.scale
        lr_crop = TF.crop(lr, lr_i, lr_j, self.lr_patch, self.lr_patch)

        return self.to_tensor(lr_crop), self.to_tensor(hr_crop)

# ---------- Utility Functions ----------
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H//window_size, window_size, W//window_size, window_size)
    windows = x.permute(0,2,4,3,5,1).contiguous().view(-1, window_size*window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,5,1,3,2,4).contiguous().view(B, -1, H, W)
    return x

# ---------- Model Components ----------
class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=32, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size
    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.blocks(x)
        out = self.conv(out)
        return out + x

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if scale == 2:
            m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)

class SwinIR(nn.Module):
    def __init__(self, embed_dim=32, depths=[2,2], num_heads=[2,2], window_size=8, scale=2):
        super().__init__()
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)
    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        return self.upsample(x0 + xb)

# ---------- PSNR Calculation ----------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop (with resume) ----------
def train(model, loader, optimizer, device, epochs=25, save_path='best_swinir_x2.pth'):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()

    if os.path.exists(save_path):
        state = torch.load(save_path, map_location=device)
        model.load_state_dict(state)
        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        best_psnr = psnr_sum / cnt
        print(f"Resumed best PSNR: {best_psnr:.2f}")

    for epoch in range(epochs):
        model.train()
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for lr, hr in loop:
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = F.l1_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loop.set_postfix(loss=loss.item())

        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        avg_psnr = psnr_sum / cnt
        print(f"Epoch {epoch+1} | PSNR: {avg_psnr:.2f}")
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best PSNR: {best_psnr:.2f}")
        torch.cuda.empty_cache()

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/512_data',
        hr_patch=256, scale=2
    )
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True
    )
    model = SwinIR(
        embed_dim=32,
        depths=[2,2],
        num_heads=[2,2],
        window_size=8,
        scale=2
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    train(model, loader, optimizer, device, epochs=20)
    scheduler.step()


In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm



# ---------- Dataset (random aligned crop for patch training) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def _init_(self, lr_dir, hr_dir, hr_patch=256, scale=2):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // scale

        # only keep matched files
        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")

        self.to_tensor = transforms.ToTensor()

    def _len_(self):
        return len(self.files)

    def _getitem_(self, idx):
        fname = self.files[idx]
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert('RGB')
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert('RGB')

        # random crop on HR
        i, j, h, w = transforms.RandomCrop.get_params(hr, (self.hr_patch, self.hr_patch))
        hr_crop = TF.crop(hr, i, j, h, w)

        # aligned crop on LR
        lr_i, lr_j = i // self.scale, j // self.scale
        lr_crop = TF.crop(lr, lr_i, lr_j, self.lr_patch, self.lr_patch)

        return self.to_tensor(lr_crop), self.to_tensor(hr_crop)

# ---------- Utility Functions ----------
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H//window_size, window_size, W//window_size, window_size)
    windows = x.permute(0,2,4,3,5,1).contiguous().view(-1, window_size*window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,5,1,3,2,4).contiguous().view(B, -1, H, W)
    return x

# ---------- Model Components ----------
class PatchEmbed(nn.Module):
    def _init_(self, in_chans=3, embed_dim=32, patch_size=1):
        super()._init_()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def _init_(self, dim, num_heads, window_size):
        super()._init_()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def _init_(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super()._init_()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size
    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def _init_(self, dim, depth, num_heads, window_size):
        super()._init_()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.blocks(x)
        out = self.conv(out)
        return out + x

class Upsampler(nn.Sequential):
    def _init_(self, scale, n_feats):
        m = []
        if scale == 2:
            m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super()._init_(*m)

class SwinIR(nn.Module):
    def _init_(self, embed_dim=32, depths=[2,2], num_heads=[2,2], window_size=8, scale=2):
        super()._init_()
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)
    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        return self.upsample(x0 + xb)

# ---------- PSNR Calculation ----------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop (with resume) ----------
def train(model, loader, optimizer, device, epochs=25, save_path='best_swinir_x2.pth'):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()

    if os.path.exists(RESUME_CHECKPOINT):
        state = torch.load(RESUME_CHECKPOINT, map_location=device)
        model.load_state_dict(state)
        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        best_psnr = psnr_sum / cnt
        print(f"Resumed best PSNR: {best_psnr:.2f}")

    for epoch in range(epochs):
        model.train()
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for lr, hr in loop:
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = F.l1_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loop.set_postfix(loss=loss.item())

        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        avg_psnr = psnr_sum / cnt
        print(f"Epoch {epoch+1} | PSNR: {avg_psnr:.2f}")
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best PSNR: {best_psnr:.2f}")
        torch.cuda.empty_cache()

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/512_data',
        hr_patch=256, scale=2
    )
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True
    )
    model = SwinIR(
        embed_dim=128,
        depths=[6,6,6,6],
        num_heads=[8,8,8,8],
        window_size=16,
        scale=2
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    train(model, loader, optimizer, device, epochs=20)
    scheduler.step()

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm



RESUME_CHECKPOINT ='/kaggle/input/ipaths/best_swinir_x2.pth'
BEST_MODEL_PATH  =' /kaggle/working/best_swinir_x2.pth'
# ---------- Dataset (random aligned crop for patch training) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir, hr_patch=256, scale=2):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // scale

        # only keep matched files
        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert('RGB')
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert('RGB')

        # random crop on HR
        i, j, h, w = transforms.RandomCrop.get_params(hr, (self.hr_patch, self.hr_patch))
        hr_crop = TF.crop(hr, i, j, h, w)

        # aligned crop on LR
        lr_i, lr_j = i // self.scale, j // self.scale
        lr_crop = TF.crop(lr, lr_i, lr_j, self.lr_patch, self.lr_patch)

        return self.to_tensor(lr_crop), self.to_tensor(hr_crop)

# ---------- Utility Functions ----------
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H//window_size, window_size, W//window_size, window_size)
    windows = x.permute(0,2,4,3,5,1).contiguous().view(-1, window_size*window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    x = x.permute(0,5,1,3,2,4).contiguous().view(B, -1, H, W)
    return x

# ---------- Model Components ----------
class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=32, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size
    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.blocks(x)
        out = self.conv(out)
        return out + x

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if scale == 2:
            m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)

class SwinIR(nn.Module):
    def __init__(self, embed_dim=32, depths=[2,2], num_heads=[2,2], window_size=8, scale=2):
        super().__init__()
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)
    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        return self.upsample(x0 + xb)

# ---------- PSNR Calculation ----------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop (with resume) ----------
def train(model, loader, optimizer, device, epochs=25, save_path='best_swinir_x2.pth'):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()

    if os.path.exists(RESUME_CHECKPOINT):
        state = torch.load(RESUME_CHECKPOINT, map_location=device)
        model.load_state_dict(state)
        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        best_psnr = psnr_sum / cnt
        print(f"Resumed best PSNR: {best_psnr:.2f}")

    for epoch in range(epochs):
        model.train()
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for lr, hr in loop:
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = F.l1_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loop.set_postfix(loss=loss.item())

        model.eval()
        psnr_sum, cnt = 0.0, 0
        with torch.no_grad():
            for lr, hr in loader:
                lr = lr.to(device, non_blocking=True)
                hr = hr.to(device, non_blocking=True)
                sr = model(lr)
                psnr_sum += calc_psnr(sr, hr).item()
                cnt += 1
                if cnt >= 10: break
        avg_psnr = psnr_sum / cnt
        print(f"Epoch {epoch+1} | PSNR: {avg_psnr:.2f}")
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best PSNR: {best_psnr:.2f}")
        torch.cuda.empty_cache()

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/512_data',
        hr_patch=256, scale=2
    )
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True
    )
    model = SwinIR(
        embed_dim=128,
        depths=[6,6,6,6],
        num_heads=[8,8,8,8],
        window_size=16,
        scale=2
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    train(model, loader, optimizer, device, epochs=20)
    scheduler.step()


In [2]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

# ---------- Checkpoint paths ----------
RESUME_CHECKPOINT = '/kaggle/input/ipaths/best_swinir_x2.pth'
BEST_MODEL_PATH  = '/kaggle/working/best_swinir_x2.pth'
BATCH_SIZE       = 2       # match your original batch size
LR               = 2e-4    # learning rate for optimizer
EPOCHS           = 20      # number of epochs

# ---------- Dataset (aligned crop + augment) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir, hr_patch=256, scale=2):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // scale

        lr_files = {f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        hr_files = {f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)}
        self.files = sorted(lr_files & hr_files)
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")

        self.to_tensor = transforms.ToTensor()
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.RandomRotation(90)
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert('RGB')
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert('RGB')

        i, j, _, _ = transforms.RandomCrop.get_params(hr, (self.hr_patch, self.hr_patch))
        hr_crop = TF.crop(hr, i, j, self.hr_patch, self.hr_patch)
        lr_crop = TF.crop(lr, i // self.scale, j // self.scale, self.lr_patch, self.lr_patch)

        # same augmentation on both
        hr_crop = self.augment(hr_crop)
        lr_crop = self.augment(lr_crop)

        return self.to_tensor(lr_crop), self.to_tensor(hr_crop)

# ---------- SwinIR Model (global residual skip) ----------
def window_partition(x, window_size):
    B,C,H,W = x.shape
    x = x.view(B,C,H//window_size,window_size,W//window_size,window_size)
    return x.permute(0,2,4,3,5,1).reshape(-1, window_size*window_size, C)

def window_reverse(windows, window_size, H, W):
    B = windows.shape[0] // (H*W//window_size//window_size)
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    return x.permute(0,5,1,3,2,4).reshape(B, -1, H, W)

class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=32, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size
    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.blocks(x)
        out = self.conv(out)
        return out + x

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if scale == 2:
            m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)


class SwinIR(nn.Module):
    def __init__(self, embed_dim=32, depths=[2,2], num_heads=[2,2], window_size=8, scale=2):
        super().__init__()
        self.scale = scale
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)

    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        feat = x0 + xb
        sr = self.upsample(feat)
        base = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False)
        return sr + base

# ---------- PSNR Calc ----------
def calc_psnr(sr, hr):
    mse = F.mse_loss(sr, hr)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop ----------
def train(model, loader, optimizer, scheduler, device, epochs=20):
    best_psnr = 0.0
    scaler = torch.amp.GradScaler()

    # resume
    if os.path.exists(RESUME_CHECKPOINT):
        print(f"Loading pretrained weights from {RESUME_CHECKPOINT}")
        model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))
        model.eval()
        psnr_sum = 0.0
        with torch.no_grad():
            for i, (lr, hr) in enumerate(loader):
                if i >= 10: break
                lr, hr = lr.to(device), hr.to(device)
                psnr_sum += calc_psnr(model(lr), hr).item()
        best_psnr = psnr_sum / 10
        print(f"Resumed best PSNR: {best_psnr:.2f}")

    for epoch in range(epochs):
        model.train()
        for lr, hr in tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}"):
            lr, hr = lr.to(device), hr.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type='cuda'):
                sr = model(lr)
                loss = F.mse_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        # validate
        model.eval()
        psnr_sum = 0.0
        with torch.no_grad():
            for i, (lr, hr) in enumerate(loader):
                if i >= 10: break
                lr, hr = lr.to(device), hr.to(device)
                psnr_sum += calc_psnr(model(lr), hr).item()
        avg_psnr = psnr_sum / 10
        print(f"Epoch {epoch+1} | PSNR: {avg_psnr:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"Saved new best PSNR: {best_psnr:.2f}")

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/512_data', hr_patch=256, scale=2
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
    model = SwinIR(embed_dim=128, depths=[4,4], num_heads=[8,8,8,8], window_size=16, scale=2).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    total_steps = 20 * len(loader)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_steps, pct_start=0.05, anneal_strategy='cos'
    )

    train(model, loader, optimizer, scheduler, device, epochs=20)


Loading pretrained weights from /kaggle/input/ipaths/best_swinir_x2.pth


  model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))


Resumed best PSNR: 4.11


Epoch 1/20: 100%|██████████| 1981/1981 [23:57<00:00,  1.38it/s]


Epoch 1 | PSNR: 13.15
Saved new best PSNR: 13.15


Epoch 2/20: 100%|██████████| 1981/1981 [23:56<00:00,  1.38it/s]


Epoch 2 | PSNR: 13.12


Epoch 3/20: 100%|██████████| 1981/1981 [23:56<00:00,  1.38it/s]


Epoch 3 | PSNR: 12.47


Epoch 4/20: 100%|██████████| 1981/1981 [23:56<00:00,  1.38it/s]


Epoch 4 | PSNR: 12.60


Epoch 5/20: 100%|██████████| 1981/1981 [23:56<00:00,  1.38it/s]


Epoch 5 | PSNR: 12.03


Epoch 6/20:  33%|███▎      | 650/1981 [07:51<16:06,  1.38it/s]


KeyboardInterrupt: 

In [3]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

# ---------- Checkpoint & Hyperparameters ----------
#RESUME_CHECKPOINT = '/kaggle/input/ipaths/best_swinir_x2.pth'
RESUME_CHECKPOINT = False
BEST_MODEL_PATH  = '/kaggle/working/best_swinir_x2.pth'
BATCH_SIZE       = 4
LR               = 2e-4
EPOCHS           = 50
VAL_SPLIT        = 0.1  # fraction for validation

# ---------- Dataset (aligned random crop, no augment) ----------
class DIV2KDataset(Dataset):
    SUPPORTED_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, lr_dir, hr_dir, hr_patch=256, scale=2):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // scale

        lr_files = [f for f in os.listdir(lr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)]
        hr_files = [f for f in os.listdir(hr_dir) if f.lower().endswith(self.SUPPORTED_EXTS)]
        self.files = sorted(set(lr_files) & set(hr_files))
        if not self.files:
            raise RuntimeError(f"No matching images in {lr_dir} and {hr_dir}")

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        hr = Image.open(os.path.join(self.hr_dir, fname)).convert('RGB')
        lr = Image.open(os.path.join(self.lr_dir, fname)).convert('RGB')

        i, j, _, _ = transforms.RandomCrop.get_params(hr, (self.hr_patch, self.hr_patch))
        hr_crop = TF.crop(hr, i, j, self.hr_patch, self.hr_patch)
        lr_crop = TF.crop(lr, i // self.scale, j // self.scale, self.lr_patch, self.lr_patch)

        return self.to_tensor(lr_crop), self.to_tensor(hr_crop)

# ---------- SwinIR Model (global residual skip) ----------
def window_partition(x, window_size):
    B,C,H,W = x.shape
    x = x.view(B,C,H//window_size,window_size,W//window_size,window_size)
    return x.permute(0,2,4,3,5,1).reshape(-1, window_size*window_size, C)

def window_reverse(windows, window_size, H, W):
    B = windows.shape[0] // (H*W//window_size//window_size)
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    return x.permute(0,5,1,3,2,4).reshape(B, -1, H, W)

class PatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=32, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size)
    def forward(self, x): return self.proj(x)

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.heads, C//self.heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1,2).reshape(B_, N, C)
        return self.proj(out)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=8, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
        self.window_size = window_size
    def forward(self, x):
        B,C,H,W = x.shape
        shortcut = x
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm1(x).reshape(B, H, W, C).permute(0,3,1,2)
        x_windows = window_partition(x, self.window_size)
        attn_windows = self.attn(x_windows)
        x = window_reverse(attn_windows, self.window_size, H, W)
        x = x + shortcut
        x = x.permute(0,2,3,1).reshape(B, H*W, C)
        x = self.norm2(x)
        x = self.mlp(x).reshape(B, H, W, C).permute(0,3,1,2)
        return x

class RSTB(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = [SwinTransformerBlock(dim, num_heads, window_size) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
    def forward(self, x):
        out = self.blocks(x)
        out = self.conv(out)
        return out + x

class Upsampler(nn.Sequential):
    def __init__(self, scale, n_feats):
        m = []
        if scale == 2:
            m += [nn.Conv2d(n_feats, n_feats*4, 3,1,1), nn.PixelShuffle(2)]
        else:
            raise ValueError(f"Scale {scale} not supported")
        m += [nn.Conv2d(n_feats, 3, 3,1,1)]
        super().__init__(*m)


class SwinIR(nn.Module):
    def __init__(self, embed_dim=32, depths=[2,2], num_heads=[2,2], window_size=8, scale=2):
        super().__init__()
        self.scale = scale
        self.shallow = PatchEmbed(3, embed_dim, 1)
        self.body = nn.Sequential(*[
            RSTB(embed_dim, depths[i], num_heads[i], window_size)
            for i in range(len(depths))
        ])
        self.upsample = Upsampler(scale, embed_dim)

    def forward(self, x):
        x0 = self.shallow(x)
        xb = self.body(x0)
        feat = x0 + xb
        sr = self.upsample(feat)
        base = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False)
        return sr + base

# ---------- Luma PSNR Calculation ----------
def calc_psnr_y(sr, hr):
    # sr, hr: [B,3,H,W] in [0,1]
    # convert to luma channel
    coeff = torch.tensor([0.299, 0.587, 0.114], device=sr.device).view(1,3,1,1)
    sr_y = (sr * coeff).sum(dim=1, keepdim=True)
    hr_y = (hr * coeff).sum(dim=1, keepdim=True)
    mse = F.mse_loss(sr_y, hr_y)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)) if mse > 0 else torch.tensor(100.)

# ---------- Training Loop with Validation ----------
def train(model, train_loader, val_loader, optimizer, scheduler, device, epochs=EPOCHS):
    best_psnr = 0.0
    scaler = torch.cuda.amp.GradScaler()

    # resume
    if RESUME_CHECKPOINT and os.path.exists(RESUME_CHECKPOINT):
        print(f"Loading pretrained weights from {RESUME_CHECKPOINT}")
        model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))
        model.eval()
        psnr_sum = 0.0
        with torch.no_grad():
            for lr, hr in val_loader:
                lr, hr = lr.to(device), hr.to(device)
                psnr_sum += calc_psnr_y(model(lr), hr).item()
        best_psnr = psnr_sum / len(val_loader)
        print(f"Resumed best PSNR on Y: {best_psnr:.2f}")

    for epoch in range(epochs):
        # train
        model.train()
        for lr, hr in tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{epochs}"):
            lr, hr = lr.to(device), hr.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                sr = model(lr)
                loss = F.mse_loss(sr, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        # validate
        model.eval()
        psnr_sum = 0.0
        with torch.no_grad():
            for lr, hr in val_loader:
                lr, hr = lr.to(device), hr.to(device)
                psnr_sum += calc_psnr_y(model(lr), hr).item()
        avg_psnr = psnr_sum / len(val_loader)
        print(f"Val Epoch {epoch+1} PSNR (Y): {avg_psnr:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"Saved new best PSNR: {best_psnr:.2f}")

# ---------- Main ----------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = DIV2KDataset(
        '/kaggle/input/paintings/resized_dataset/resized_dataset',
        '/kaggle/input/paintings/512_data',
        hr_patch=256, scale=2
    )
    # train/val split
    val_size = int(len(dataset) * VAL_SPLIT)
    train_size = len(dataset) - val_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

    model = SwinIR(embed_dim=96, depths=[4,4], num_heads=[4,4], window_size=4, scale=2).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    total_steps = EPOCHS * len(train_loader)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LR,
        total_steps=total_steps,
        pct_start=0.05,
        anneal_strategy='cos'
    )

    train(model, train_loader, val_loader, optimizer, scheduler, device)


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Train Epoch 1/50: 100%|██████████| 892/892 [03:51<00:00,  3.85it/s]


Val Epoch 1 PSNR (Y): 26.35
Saved new best PSNR: 26.35


Train Epoch 2/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 2 PSNR (Y): 25.72


Train Epoch 3/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 3 PSNR (Y): 25.54


Train Epoch 4/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 4 PSNR (Y): 25.77


Train Epoch 5/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 5 PSNR (Y): 25.60


Train Epoch 6/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 6 PSNR (Y): 25.49


Train Epoch 7/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 7 PSNR (Y): 26.04


Train Epoch 8/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 8 PSNR (Y): 25.27


Train Epoch 9/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 9 PSNR (Y): 25.66


Train Epoch 10/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 10 PSNR (Y): 25.74


Train Epoch 11/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 11 PSNR (Y): 25.96


Train Epoch 12/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 12 PSNR (Y): 25.67


Train Epoch 13/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 13 PSNR (Y): 25.59


Train Epoch 14/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 14 PSNR (Y): 25.18


Train Epoch 15/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 15 PSNR (Y): 26.28


Train Epoch 16/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 16 PSNR (Y): 25.29


Train Epoch 17/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 17 PSNR (Y): 25.80


Train Epoch 18/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 18 PSNR (Y): 25.74


Train Epoch 19/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 19 PSNR (Y): 25.93


Train Epoch 20/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 20 PSNR (Y): 26.38
Saved new best PSNR: 26.38


Train Epoch 21/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 21 PSNR (Y): 25.97


Train Epoch 22/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 22 PSNR (Y): 25.64


Train Epoch 23/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 23 PSNR (Y): 25.85


Train Epoch 24/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 24 PSNR (Y): 26.10


Train Epoch 25/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 25 PSNR (Y): 25.64


Train Epoch 26/50: 100%|██████████| 892/892 [03:50<00:00,  3.88it/s]


Val Epoch 26 PSNR (Y): 26.18


Train Epoch 27/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 27 PSNR (Y): 25.51


Train Epoch 28/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 28 PSNR (Y): 26.72
Saved new best PSNR: 26.72


Train Epoch 29/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 29 PSNR (Y): 26.54


Train Epoch 30/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 30 PSNR (Y): 26.67


Train Epoch 31/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 31 PSNR (Y): 26.85
Saved new best PSNR: 26.85


Train Epoch 32/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 32 PSNR (Y): 26.94
Saved new best PSNR: 26.94


Train Epoch 33/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 33 PSNR (Y): 26.98
Saved new best PSNR: 26.98


Train Epoch 34/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 34 PSNR (Y): 26.95


Train Epoch 35/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 35 PSNR (Y): 26.87


Train Epoch 36/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 36 PSNR (Y): 26.83


Train Epoch 37/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 37 PSNR (Y): 27.19
Saved new best PSNR: 27.19


Train Epoch 38/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 38 PSNR (Y): 26.95


Train Epoch 39/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 39 PSNR (Y): 27.07


Train Epoch 40/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 40 PSNR (Y): 27.15


Train Epoch 41/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 41 PSNR (Y): 27.00


Train Epoch 42/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 42 PSNR (Y): 27.11


Train Epoch 43/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 43 PSNR (Y): 27.36
Saved new best PSNR: 27.36


Train Epoch 44/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 44 PSNR (Y): 26.93


Train Epoch 45/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 45 PSNR (Y): 27.12


Train Epoch 46/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 46 PSNR (Y): 27.12


Train Epoch 47/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 47 PSNR (Y): 27.15


Train Epoch 48/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 48 PSNR (Y): 27.17


Train Epoch 49/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 49 PSNR (Y): 27.11


Train Epoch 50/50: 100%|██████████| 892/892 [03:50<00:00,  3.87it/s]


Val Epoch 50 PSNR (Y): 27.01
