In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
from torchmetrics.functional import structural_similarity_index_measure as ssim

In [8]:
#Augmentations
import numpy as np
import cv2
import random


def radial_vignette(img, strength=0.2):
    """
    Darken a 2-D image towards its corners with a quadratic radial falloff
    (a simple limb-darkening approximation).

    img: 2-D array (its shape is unpacked as ``h, w``).
    strength: attenuation applied at the farthest pixel; 0.0 leaves the image unchanged.
    Returns a uint8 array with the same shape as ``img``.
    """
    rows, cols = img.shape
    yy, xx = np.ogrid[:rows, :cols]
    center_y, center_x = rows / 2, cols / 2
    dist = np.sqrt((xx - center_x)**2 + (yy - center_y)**2)
    dist = dist / dist.max()  # farthest pixel from the centre maps to 1

    falloff = 1 - strength * (dist**2)
    darkened = np.clip(img.astype(np.float32) * falloff, 0, 255)
    return darkened.astype(np.uint8)

def add_poisson_noise(img, scale_low=5.0, scale_high=100.0):
    """
    Simulate photon (shot) noise by resampling the image from a Poisson distribution.

    img: array; anything that is not float32 is converted to float32 and divided by 255,
        float32 input is used as-is.
    scale_low, scale_high: bounds of the uniformly drawn scale that multiplies the
        intensities before sampling (smaller scale -> noisier result).
    Returns a uint8 array.
    """
    rates = img if img.dtype == np.float32 else img.astype(np.float32) / 255.0
    photons_per_unit = random.uniform(scale_low, scale_high)
    counts = np.random.poisson(rates * photons_per_unit)
    normalized = np.clip(counts / float(photons_per_unit), 0.0, 1.0)
    return (normalized * 255).astype(np.uint8)

def random_rotation(img, angle_range=20):
    """
    Rotate a 2-D image about (w//2, h//2) by an angle drawn uniformly from
    [-angle_range, +angle_range] degrees.

    The output keeps the input size; pixels outside the source are filled by reflection.
    """
    rows, cols = img.shape
    theta = random.uniform(-angle_range, angle_range)
    pivot = (cols//2, rows//2)
    affine = cv2.getRotationMatrix2D(pivot, theta, 1.0)
    return cv2.warpAffine(
        img,
        affine,
        (cols, rows),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )


def random_crop_resize(img, crop_scale=0.8, out_size=(256, 256)):
    """
    Take a random crop of a 2-D image and resize it to ``out_size``.

    img: 2-D array (its shape is unpacked as ``h, w``).
    crop_scale: fraction of each side that is kept. Values above 1.0 are treated
        as "keep the whole side".
    out_size: size handed to cv2.resize.

    If the requested crop would be empty (tiny image or tiny ``crop_scale``), the
    whole image is resized instead of passing an empty array to cv2.resize.
    """
    h, w = img.shape
    # never ask for a crop larger than the image itself
    ch = min(int(h * crop_scale), h)
    cw = min(int(w * crop_scale), w)
    if ch <= 0 or cw <= 0:
        return cv2.resize(img, out_size, interpolation=cv2.INTER_LINEAR)
    y = random.randint(0, h - ch)
    x = random.randint(0, w - cw)
    crop = img[y:y+ch, x:x+cw]
    return cv2.resize(crop, out_size, interpolation=cv2.INTER_LINEAR)


def gaussian_blur(img, sigma_range=(0.5, 2.0)):
    """
    Blur the image with a Gaussian whose sigma is drawn uniformly from ``sigma_range``.

    The kernel side is 2 * round(3 * sigma) + 1, i.e. always odd.
    """
    std = random.uniform(*sigma_range)
    half_width = round(3*std)
    side = int(2 * half_width + 1)
    return cv2.GaussianBlur(img, (side, side), sigmaX=std)


def brightness_contrast_jitter(img, brightness=0.2, contrast=0.3):
    """
    Apply a random linear intensity change: ``img * gain + offset``.

    brightness: offset is drawn uniformly from ±brightness * 255.
    contrast: gain is drawn uniformly from 1 ± contrast.
    Returns a uint8 array clipped to [0, 255].
    """
    # the offset is drawn before the gain; both come from the same `random` stream
    offset = random.uniform(-brightness, brightness) * 255
    gain = 1.0 + random.uniform(-contrast, contrast)
    adjusted = img.astype(np.float32) * gain + offset
    return np.clip(adjusted, 0, 255).astype(np.uint8)

def random_noise(img, noise_level=0.10):
    """
    Add zero-mean Gaussian noise.

    img: array; anything that is not float32 is converted to float32 and divided by 255,
        float32 input is used as-is.
    noise_level: standard deviation of the noise on the [0, 1] intensity scale.
    Returns a uint8 array.
    """
    unit = img if img.dtype == np.float32 else img.astype(np.float32) / 255.0
    perturbation = np.random.normal(0, noise_level, unit.shape)
    corrupted = np.clip(unit + perturbation, 0.0, 1.0)
    return (corrupted * 255).astype(np.uint8)

In [12]:
# Download the SwinIR network definition (models/network_swinir.py on the `main` branch of
# JingyunLiang/SwinIR, i.e. not pinned to a commit) and save it locally as swinir_model.py.
!curl -L https://raw.githubusercontent.com/JingyunLiang/SwinIR/main/models/network_swinir.py -o swinir_model.py



  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 36832  100 36832    0     0  69592      0 --:--:-- --:--:-- --:--:-- 69889


In [14]:
from swinir_model import SwinIR

# SolarDataset reads its images with cv2.IMREAD_GRAYSCALE and yields [1, H, W] tensors,
# so the network is built for a single input channel rather than RGB.
model = SwinIR(
    upscale=4,
    in_chans=1,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6,6,6,6],
    embed_dim=60,
    num_heads=[6,6,6,6],
    mlp_ratio=2,
    upsampler='pixelshuffle'
)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import random
import numpy as np

class SolarDataset(Dataset):
    """
    Paired low-/high-resolution grayscale image dataset.

    The files of ``lr_dir`` and ``hr_dir`` are listed, sorted by path and paired by position.

    lr_dir, hr_dir: directories holding the LR and HR images.
    augment: when True, degradations are applied to the LR image and one shared random
        rotation is applied to both images.
    img_size: (width, height) handed to cv2.resize for the HR image.
    max_images: optional cap on the number of pairs (the first N after sorting).
    lr_size: (width, height) handed to cv2.resize for the LR image; defaults to ``img_size``.

    Raises ValueError when the two directories do not yield the same number of files.
    """

    def __init__(self, lr_dir, hr_dir, augment=True, img_size=(256,256),max_images=None, lr_size=None):
        self.lr_files = sorted(os.path.join(lr_dir, f) for f in os.listdir(lr_dir))
        self.hr_files = sorted(os.path.join(hr_dir, f) for f in os.listdir(hr_dir))
        if max_images is not None:
            self.lr_files = self.lr_files[:max_images]
            self.hr_files = self.hr_files[:max_images]
        # pairs are matched by position, so unequal counts mean mismatched pairs
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR and HR file counts differ: {len(self.lr_files)} vs {len(self.hr_files)}"
            )
        self.augment = augment
        self.img_size = img_size
        self.lr_size = img_size if lr_size is None else lr_size

    def __len__(self):
        """Number of LR/HR pairs."""
        return len(self.lr_files)

    @staticmethod
    def _rotate_pair(lr, hr, angle_range):
        """Rotate both images by the same random angle so they stay spatially aligned."""
        angle = random.uniform(-angle_range, angle_range)
        rotated = []
        for img in (lr, hr):
            h, w = img.shape
            M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
            rotated.append(
                cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
            )
        return rotated[0], rotated[1]

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` as float32 tensors of shape [1, H, W] scaled by 1/255."""
        lr = cv2.imread(self.lr_files[idx], cv2.IMREAD_GRAYSCALE)
        hr = cv2.imread(self.hr_files[idx], cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising when a file cannot be decoded
        if lr is None or hr is None:
            raise RuntimeError(f"Failed to load image pair: {self.lr_files[idx]}, {self.hr_files[idx]}")

        lr = cv2.resize(lr, self.lr_size)
        hr = cv2.resize(hr, self.img_size)

        if self.augment:
            # photometric degradations only make sense on the LR input
            lr = radial_vignette(lr, strength=random.uniform(0.05,0.2))
            lr = add_poisson_noise(lr)
            # a geometric change has to hit the target too, otherwise a pixel-wise
            # loss compares a rotated input against an unrotated target
            lr, hr = self._rotate_pair(lr, hr, angle_range=15)
            lr = gaussian_blur(lr, sigma_range=(0.5,1.5))
            lr = brightness_contrast_jitter(lr, brightness=0.15, contrast=0.2)
            lr = random_noise(lr, noise_level=0.05)

        # Normalize and convert to tensor [C,H,W]
        lr = torch.tensor(lr, dtype=torch.float32).unsqueeze(0)/255.0
        hr = torch.tensor(hr, dtype=torch.float32).unsqueeze(0)/255.0

        return lr, hr


In [16]:
max_images = None  # None -> use every image found in each directory
batch_size = 2

# Only the training split is augmented; validation and testing see the images as stored.
train_dataset = SolarDataset(
    lr_dir="new_dataset/training/low_res",
    hr_dir="new_dataset/training/high_res",
    augment=True,
    max_images=max_images,
)
val_dataset = SolarDataset(
    lr_dir="new_dataset/validation/low_res",
    hr_dir="new_dataset/validation/high_res",
    augment=False,
    max_images=max_images,
)
test_dataset = SolarDataset(
    lr_dir="new_dataset/testing/low_res",
    hr_dir="new_dataset/testing/high_res",
    augment=False,
    max_images=max_images,
)

# Shuffle only while training so evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [6]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
# (same selection rule as the training cells further down).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

'cuda'

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from pytorch_msssim import ssim

# -----------------------------------
# DEVICE SETUP
# -----------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

# -----------------------------------
# LOSS FUNCTION
# -----------------------------------
criterion = nn.L1Loss()

# -----------------------------------
# MODEL & OPTIMIZER
# -----------------------------------
# SwinIR comes from swinir_model.py (imported in an earlier cell).
# in_chans=1: the dataset yields single-channel (grayscale) tensors.
# .to(device): the training loop moves every batch to `device`, so the weights must live there too.
model = SwinIR(
    upscale=4,
    in_chans=1,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6,6,6,6],
    embed_dim=60,
    num_heads=[6,6,6,6],
    mlp_ratio=2,
    upsampler='pixelshuffle'
).to(device)


# Built after the move so the optimizer holds the on-device parameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)

# -----------------------------------
# CHECKPOINT DIRECTORY
# -----------------------------------
checkpoint_dir = "checkpoints_swin_l1"
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_ssim = -1.0  # below any value the checkpoint test will see, so the first epoch always saves
epochs = 50

# -----------------------------------
# HELPER METRIC FUNCTIONS
# -----------------------------------
def psnr(pred, target):
    """
    Peak signal-to-noise ratio in dB for a peak value of 1.0.

    The MSE is taken over every element of ``pred - target``.
    Returns a 0-dim tensor. Identical inputs (MSE of 0) give a finite cap of 100.0,
    also as a tensor, so callers can call ``.item()`` on either result.
    """
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        return torch.tensor(100.0, device=mse.device)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def pixelwise_error(pred, target):
    """Mean absolute difference between ``pred`` and ``target``, as a 0-dim tensor."""
    return (pred - target).abs().mean()

# -----------------------------------
# TRAINING LOOP
# -----------------------------------
scaler = torch.cuda.amp.GradScaler()  # loss scaling for the autocast forward pass
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
    for lr_imgs, hr_imgs in train_bar:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # weight by batch size so the epoch figure is a per-sample mean
        train_loss += loss.item() * lr_imgs.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- validation pass ----
    model.eval()
    val_loss = pixel_err = psnr_val = ssim_val = 0.0

    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            outputs = model(lr_imgs)

            val_loss += criterion(outputs, hr_imgs).item() * lr_imgs.size(0)
            pixel_err += pixelwise_error(outputs, hr_imgs).item() * lr_imgs.size(0)
            psnr_val += psnr(outputs, hr_imgs).item() * lr_imgs.size(0)
            ssim_val += ssim(outputs, hr_imgs, data_range=1.0, size_average=True).item() * lr_imgs.size(0)

    # turn the batch-size-weighted sums into per-sample means
    n_val = len(val_loader.dataset)
    val_loss, pixel_err, psnr_val, ssim_val = (
        total / n_val for total in (val_loss, pixel_err, psnr_val, ssim_val)
    )

    print(" | ".join([
        f"Epoch {epoch+1}/{epochs}",
        f"Train Loss: {train_loss:.6f}",
        f"Val Loss: {val_loss:.6f}",
        f"Pixel Err: {pixel_err:.6f}",
        f"PSNR: {psnr_val:.2f} dB",
        f"SSIM: {ssim_val:.4f}",
    ]))

    # ---- keep only the checkpoint with the best validation SSIM ----
    if ssim_val > best_val_ssim:
        best_val_ssim = ssim_val
        ckpt_path = os.path.join(checkpoint_dir, "best_swin_sr.pth")
        best_state = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_ssim': best_val_ssim
        }
        torch.save(best_state, ckpt_path)
        print(f"✅ Best model updated (SSIM={ssim_val:.4f}) at epoch {epoch+1}")


In [None]:

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics.functional import structural_similarity_index_measure as ssim
import os

# -----------------------------
# Metrics
# -----------------------------
def pixelwise_error(pred, target):
    """Mean absolute difference between ``pred`` and ``target``, as a Python float."""
    abs_diff = (pred - target).abs()
    return abs_diff.mean().item()

def psnr(pred, target, max_val=1.0):
    """
    Peak signal-to-noise ratio in dB, returned as a Python float.

    max_val: peak signal value used in the ratio.
    Identical inputs (MSE of 0) give ``float('inf')``.
    """
    mse = ((pred - target) ** 2).mean()
    if mse == 0:
        return float('inf')
    ratio = max_val / torch.sqrt(mse)
    return 20 * torch.log10(ratio).item()

def ssim_metric(pred, target):
    """SSIM between ``pred`` and ``target`` with ``data_range=1.0``, as a Python float."""
    score = ssim(pred, target, data_range=1.0)
    return score.item()

# -----------------------------
# Load model
# -----------------------------
# NOTE(review): SwinSRNet is not defined in any cell shown here; it presumably comes from a
# removed cell or an external module. Confirm it matches the architecture the checkpoint was trained with.
model = SwinSRNet(sw_model_name="swin_tiny_patch4_window7_224", pretrained=False, upscale=4).to(device)

checkpoint = torch.load("checkpoints_swin_l1/best_swin_sr.pth", map_location=device)
# The training cells save a dict that stores the weights under 'model_state_dict'
# (next to the epoch and optimizer state); a file holding a bare state_dict is still accepted.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()

os.makedirs("reconstructed_images_swin", exist_ok=True)

# -----------------------------
# Testing and Visualization
# -----------------------------
psnr_list = []
ssim_list = []
pixel_err_list = []

count = 0  # running index over all test images, used for file names and the preview limit
for lr_imgs, hr_imgs in tqdm(test_loader, desc="Testing"):
    lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
    with torch.no_grad():
        # keep predictions inside [0, 1], the range the metrics (data_range=1.0) assume
        outputs = torch.clamp(model(lr_imgs), 0.0, 1.0)

    for i in range(lr_imgs.size(0)):
        pred_img = outputs[i].detach().cpu().float()
        true_img = hr_imgs[i].detach().cpu().float()

        # the SSIM helper is given a leading batch dimension
        pred_img_ssim = pred_img.unsqueeze(0)
        true_img_ssim = true_img.unsqueeze(0)

        psnr_list.append(psnr(pred_img, true_img))
        ssim_list.append(ssim_metric(pred_img_ssim, true_img_ssim))
        pixel_err_list.append(pixelwise_error(pred_img, true_img))

        recon_img_path = f"reconstructed_images_swin/recon_{count}.png"
        # Fix the colour limits: without vmin/vmax each image is rescaled to its own
        # min/max, so the saved PNG would not show the predicted intensities.
        plt.imsave(recon_img_path, pred_img.squeeze().numpy(), cmap='gray', vmin=0.0, vmax=1.0)

        if count < 5:
            fig, axes = plt.subplots(1,2, figsize=(6,3))
            axes[0].imshow(true_img.squeeze().numpy(), cmap='gray', vmin=0.0, vmax=1.0)
            axes[0].set_title("Original")
            axes[0].axis('off')
            axes[1].imshow(pred_img.squeeze().numpy(), cmap='gray', vmin=0.0, vmax=1.0)
            axes[1].set_title("Reconstructed")
            axes[1].axis('off')
            plt.show()
            plt.close(fig)  # release the figure once it has been displayed

        count += 1

print(f"Average PSNR: {np.mean(psnr_list):.2f} dB")
print(f"Average SSIM: {np.mean(ssim_list):.4f}")
print(f"Average Pixel Error: {np.mean(pixel_err_list):.6f}")
print("Reconstructed images saved in 'reconstructed_images_swin/' folder")


Testing:   0%|          | 0/104 [00:07<?, ?it/s]


RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([1, 8, 1, 256, 256]) and torch.Size([1, 1, 256, 256]).

In [None]:
# train_swin_solar_sr.py
import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from pytorch_msssim import ssim as ms_ssim

# ---------------------------
# Reproducibility / device
# ---------------------------
SEED = 42
random.seed(SEED)        # Python's `random` module
np.random.seed(SEED)     # NumPy global RNG
torch.manual_seed(SEED)  # PyTorch RNG

# Prefer the GPU when PyTorch can see one; seed the CUDA generators in that case.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running on:", device)

# ---------------------------
# Augmentations (your functions)
# ---------------------------
def radial_vignette(img, strength=0.2):
    """
    Darken a 2-D image towards its corners with a quadratic radial falloff.

    strength: attenuation applied at the farthest pixel; 0.0 leaves the image unchanged.
    Returns a uint8 array with the same shape as ``img``.
    """
    n_rows, n_cols = img.shape
    row_idx, col_idx = np.ogrid[:n_rows, :n_cols]
    mid_row, mid_col = n_rows / 2, n_cols / 2
    radius = np.sqrt((col_idx - mid_col)**2 + (row_idx - mid_row)**2)
    radius = radius / radius.max()  # farthest pixel from the centre maps to 1
    attenuation = 1 - strength * (radius**2)
    shaded = np.clip(img.astype(np.float32) * attenuation, 0, 255)
    return shaded.astype(np.uint8)

def add_poisson_noise(img, scale_low=5.0, scale_high=100.0):
    """
    Simulate photon (shot) noise by resampling the image from a Poisson distribution.

    Input that is not float32 is converted to float32 and divided by 255; float32 input
    is used as-is. The scale multiplying the intensities before sampling is drawn
    uniformly from [scale_low, scale_high]. Returns a uint8 array.
    """
    intensities = img if img.dtype == np.float32 else img.astype(np.float32) / 255.0
    gain = random.uniform(scale_low, scale_high)
    sampled = np.random.poisson(intensities * gain) / float(gain)
    return (np.clip(sampled, 0.0, 1.0) * 255).astype(np.uint8)

def random_rotation(img, angle_range=20):
    """
    Rotate a 2-D image about (w//2, h//2) by an angle drawn uniformly from
    [-angle_range, +angle_range] degrees; the size is kept and borders are reflected.
    """
    n_rows, n_cols = img.shape
    degrees = random.uniform(-angle_range, angle_range)
    rot_mat = cv2.getRotationMatrix2D((n_cols//2, n_rows//2), degrees, 1.0)
    return cv2.warpAffine(
        img,
        rot_mat,
        (n_cols, n_rows),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )

def random_crop_resize(img, crop_scale=0.8, out_size=(256, 256)):
    """
    Take a random crop covering ``crop_scale`` of each side and resize it to ``out_size``.

    When the crop would have zero height or width, the whole image is resized instead.
    """
    n_rows, n_cols = img.shape
    crop_h = int(n_rows * crop_scale)
    crop_w = int(n_cols * crop_scale)
    if crop_h != 0 and crop_w != 0:
        # row offset is drawn before the column offset
        top = random.randint(0, n_rows - crop_h)
        left = random.randint(0, n_cols - crop_w)
        img = img[top:top+crop_h, left:left+crop_w]
    return cv2.resize(img, out_size, interpolation=cv2.INTER_LINEAR)

def gaussian_blur(img, sigma_range=(0.5, 2.0)):
    """
    Blur the image with a Gaussian whose sigma is drawn uniformly from ``sigma_range``.

    The kernel side is 2 * round(3 * sigma) + 1, forced to be odd.
    """
    std = random.uniform(*sigma_range)
    # OR-ing with 1 bumps an even size up by one and leaves an odd size untouched
    side = int(2 * round(3*std) + 1) | 1
    return cv2.GaussianBlur(img, (side, side), sigmaX=std)

def brightness_contrast_jitter(img, brightness=0.2, contrast=0.3):
    """
    Apply a random linear intensity change ``img * gain + shift``.

    shift is drawn uniformly from ±brightness * 255, gain from 1 ± contrast.
    Returns a uint8 array clipped to [0, 255].
    """
    # the shift is drawn first, then the gain (both from the same `random` stream)
    shift = 255 * random.uniform(-brightness, brightness)
    gain = 1.0 + random.uniform(-contrast, contrast)
    result = img.astype(np.float32) * gain + shift
    return np.clip(result, 0, 255).astype(np.uint8)

def random_noise(img, noise_level=0.10):
    """
    Add zero-mean Gaussian noise with standard deviation ``noise_level`` on the [0, 1] scale.

    Input that is not float32 is converted to float32 and divided by 255; float32 input
    is used as-is. Returns a uint8 array.
    """
    base = img if img.dtype == np.float32 else img.astype(np.float32) / 255.0
    jitter = np.random.normal(0, noise_level, base.shape).astype(np.float32)
    return (np.clip(base + jitter, 0.0, 1.0) * 255).astype(np.uint8)

# ---------------------------
# Model: try official SwinIR else fallback timm-based simple head
# ---------------------------
# Define build_model(): preferred path wraps the SwinIR class from swinir_model.py;
# if that import (or anything else in the try body) raises, a timm-based fallback is defined instead.
try:
    # If network_swinir.py is present and defines SwinIR
    from swinir_model import SwinIR    # user-provided network_swinir.py saved as swinir_model.py
    print("Using SwinIR from swinir_model.py")
    def build_model(in_chans=1, upscale=4):
        """Return a SwinIR instance with the fixed configuration below for the given channels/scale."""
        # instantiate a typical SwinIR-like config for medium model (you can change)
        return SwinIR(
            upscale=upscale,
            in_chans=in_chans,
            img_size=64,
            window_size=8,
            img_range=1.0,
            depths=[6,6,6,6],
            embed_dim=60,
            num_heads=[6,6,6,6],
            mlp_ratio=2,
            upsampler='pixelshuffle'
        )
except Exception as e:
    print("Official SwinIR import failed or not found. Falling back to a timm-based Swin-T backbone + simple reconstruction head.")
    print("Import error was:", e)
    # fallback model using timm
    try:
        from timm import create_model
    except Exception:
        # any failure importing timm is reported as "not installed"
        raise RuntimeError("timm not installed. Install it with `pip install timm` or provide network_swinir.py in the working directory.")
    class SwinFallbackSR(nn.Module):
        """
        Fallback network: a timm backbone followed by a small convolutional head.

        NOTE(review): the head upsamples the backbone's last feature map by `upscale` only;
        whether the output reaches the HR target size depends on how much the backbone
        downsamples — confirm before relying on this path.
        NOTE(review): the head's Conv2d layers assume the backbone returns channels-first
        (N, C, H, W) feature maps — verify against the installed timm version.
        """
        def __init__(self, sw_model_name="swin_tiny_patch4_window7_224", pretrained=True, in_chans=1, upscale=4):
            super().__init__()
            self.in_chans = in_chans
            self.upscale = upscale
            # create backbone that returns intermediate feature maps
            self.backbone = create_model(sw_model_name, pretrained=pretrained, features_only=True, in_chans=in_chans)
            # get channels of the last feature map
            try:
                backbone_out = self.backbone.feature_info[-1]['num_chs']
            except Exception:
                # fallback guess
                backbone_out = 768
            # reconstruction: upsample then conv to output channels
            self.reconstruct = nn.Sequential(
                nn.Conv2d(backbone_out, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=False),
                nn.Conv2d(256, in_chans, kernel_size=3, padding=1),
                nn.Sigmoid()   # ensure output in [0,1]
            )
        def forward(self, x):
            """Run the backbone, take its last feature map and pass it through the reconstruction head."""
            feats = self.backbone(x)[-1]    # last feature map
            out = self.reconstruct(feats)
            return out
    def build_model(in_chans=1, upscale=4):
        """Return a SwinFallbackSR with its default backbone name and pretrained=True."""
        return SwinFallbackSR(in_chans=in_chans, upscale=upscale)

# ---------------------------
# Dataset
# ---------------------------
class SolarDataset(Dataset):
    """
    Paired low-/high-resolution grayscale image dataset.

    Image files (.png/.jpg/.jpeg/.bmp/.tif) of ``lr_dir`` and ``hr_dir`` are listed,
    sorted by path and paired by position.

    lr_dir, hr_dir: directories holding the LR and HR images.
    augment: when True, random degradations are applied to the LR image and, when
        drawn, one shared random rotation is applied to both images.
    img_size: (width, height) handed to cv2.resize for the HR image.
    max_images: optional cap on the number of pairs (the first N after sorting).
    lr_size: (width, height) handed to cv2.resize for the LR image; defaults to ``img_size``.
    """

    _EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.tif')

    def __init__(self, lr_dir, hr_dir, augment=True, img_size=(256,256), max_images=None, lr_size=None):
        self.lr_files = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.lower().endswith(self._EXTENSIONS)])
        self.hr_files = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.lower().endswith(self._EXTENSIONS)])
        if max_images is not None:
            self.lr_files = self.lr_files[:max_images]
            self.hr_files = self.hr_files[:max_images]
        assert len(self.lr_files) == len(self.hr_files), "LR and HR file counts differ!"
        self.augment = augment
        self.img_size = img_size
        self.lr_size = img_size if lr_size is None else lr_size

    def __len__(self):
        """Number of LR/HR pairs."""
        return len(self.lr_files)

    @staticmethod
    def _rotate_pair(lr, hr, angle_range):
        """Rotate both images by the same random angle so they stay spatially aligned."""
        angle = random.uniform(-angle_range, angle_range)
        rotated = []
        for img in (lr, hr):
            h, w = img.shape
            M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
            rotated.append(
                cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
            )
        return rotated[0], rotated[1]

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` as float32 tensors of shape [1, H, W] scaled by 1/255."""
        lr = cv2.imread(self.lr_files[idx], cv2.IMREAD_GRAYSCALE)
        hr = cv2.imread(self.hr_files[idx], cv2.IMREAD_GRAYSCALE)
        if lr is None or hr is None:
            raise RuntimeError(f"Failed to load image pair: {self.lr_files[idx]}, {self.hr_files[idx]}")
        lr = cv2.resize(lr, self.lr_size, interpolation=cv2.INTER_LINEAR)
        hr = cv2.resize(hr, self.img_size, interpolation=cv2.INTER_LINEAR)

        if self.augment:
            # fixed order; each step is applied independently with its own probability
            if random.random() < 0.9:
                lr = radial_vignette(lr, strength=random.uniform(0.03,0.18))
            if random.random() < 0.8:
                lr = add_poisson_noise(lr)
            if random.random() < 0.6:
                # a geometric change has to hit the target too, otherwise a pixel-wise
                # loss compares a rotated input against an unrotated target
                lr, hr = self._rotate_pair(lr, hr, angle_range=12)
            if random.random() < 0.7:
                lr = gaussian_blur(lr, sigma_range=(0.3,1.2))
            if random.random() < 0.7:
                lr = brightness_contrast_jitter(lr, brightness=0.12, contrast=0.18)
            if random.random() < 0.5:
                lr = random_noise(lr, noise_level=0.04)

        # convert to tensor CxHxW in float32 [0,1]
        lr_t = torch.tensor(lr, dtype=torch.float32).unsqueeze(0) / 255.0
        hr_t = torch.tensor(hr, dtype=torch.float32).unsqueeze(0) / 255.0
        return lr_t, hr_t

# ---------------------------
# Config and DataLoaders
# ---------------------------
max_images = None
batch_size = 8
img_size = (256,256)   # HR target size
upscale = 4
# The network enlarges its input by `upscale`, so the LR input has to be `upscale` times
# smaller than the HR target for the prediction and the target to have the same shape.
lr_size = (img_size[0] // upscale, img_size[1] // upscale)

train_dataset = SolarDataset("new_dataset/training/low_res",
                             "new_dataset/training/high_res",
                             augment=True, img_size=img_size, max_images=max_images,
                             lr_size=lr_size)

val_dataset = SolarDataset("new_dataset/validation/low_res",
                           "new_dataset/validation/high_res",
                           augment=False, img_size=img_size, max_images=max_images,
                           lr_size=lr_size)

test_dataset = SolarDataset("new_dataset/testing/low_res",
                            "new_dataset/testing/high_res",
                            augment=False, img_size=img_size, max_images=max_images,
                            lr_size=lr_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

# ---------------------------
# Build model, loss, optimizer
# ---------------------------
in_chans = 1  # the dataset yields single-channel tensors
try:
    model = build_model(upscale=upscale, in_chans=in_chans)
except TypeError:
    # retry positionally in case build_model does not accept these keyword names
    model = build_model(in_chans, upscale)
# The training and validation loops clamp the network output to [0, 1] themselves.
model = model.to(device)

criterion = nn.L1Loss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-6,
)

# ---------------------------
# Metrics
# ---------------------------
def psnr_torch(pred, target):
    """
    Peak signal-to-noise ratio in dB for a peak value of 1.0, returned as a tensor.

    Identical inputs (MSE of 0) give a finite cap of 100.0 on ``pred``'s device.
    """
    mse = ((pred - target) ** 2).mean()
    if mse.item() == 0:
        return torch.tensor(100.0, device=pred.device)
    rmse = torch.sqrt(mse)
    return 20.0 * torch.log10(1.0 / rmse)

def pixelwise_error(pred, target):
    """Mean absolute difference between ``pred`` and ``target``, as a 0-dim tensor."""
    residual = pred - target
    return residual.abs().mean()

# ---------------------------
# Training loop (no mixed precision)
# ---------------------------
checkpoint_dir = "checkpoints_swin_l1"
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_ssim = -1.0  # below any value the checkpoint test will see, so the first epoch always saves
epochs = 50

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
    for lr_imgs, hr_imgs in pbar:
        lr_imgs = lr_imgs.to(device)   # float32
        hr_imgs = hr_imgs.to(device)
        optimizer.zero_grad()
        outputs = model(lr_imgs)

        # Optimise on the raw prediction: clamping before the loss zeroes the gradient of
        # every out-of-range pixel, so the network would get no signal to pull those
        # pixels back into [0, 1]. Clamping is kept for the validation metrics only.
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()
        # weight by batch size so the epoch figure is a per-sample mean
        train_loss += loss.item() * lr_imgs.size(0)
        pbar.set_postfix({"loss": f"{loss.item():.6f}"})

    train_loss /= len(train_loader.dataset)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    pixel_err = 0.0
    psnr_val = 0.0
    ssim_val = 0.0

    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            # metrics assume values in [0, 1] (data_range=1.0), so clamp the prediction here
            outputs = torch.clamp(model(lr_imgs), 0.0, 1.0)

            val_loss += criterion(outputs, hr_imgs).item() * lr_imgs.size(0)
            pixel_err += pixelwise_error(outputs, hr_imgs).item() * lr_imgs.size(0)
            psnr_val += psnr_torch(outputs, hr_imgs).item() * lr_imgs.size(0)
            # `ms_ssim` is the name this file gives to pytorch_msssim's `ssim` import
            ssim_val += ms_ssim(outputs, hr_imgs, data_range=1.0, size_average=True).item() * lr_imgs.size(0)

    n_val = len(val_loader.dataset)
    val_loss /= n_val
    pixel_err /= n_val
    psnr_val /= n_val
    ssim_val /= n_val

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | Pixel Err: {pixel_err:.6f} | PSNR: {psnr_val:.2f} dB | SSIM: {ssim_val:.4f}")

    # ---- checkpoints: always write "latest", and "best" when validation SSIM improves ----
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'ssim': ssim_val
    }
    ckpt_latest = os.path.join(checkpoint_dir, "latest_swin_sr.pth")
    torch.save(checkpoint, ckpt_latest)

    if ssim_val > best_val_ssim:
        best_val_ssim = ssim_val
        ckpt_best = os.path.join(checkpoint_dir, "best_swin_sr.pth")
        torch.save(checkpoint, ckpt_best)
        print(f"✅ Best model updated (SSIM={ssim_val:.4f}) at epoch {epoch+1}")

# Test inference: run the model on every test batch and write the first prediction of each batch as a PNG
model.eval()
out_dir = "sr_results"
os.makedirs(out_dir, exist_ok=True)
with torch.no_grad():
    for i, (lr_imgs, hr_imgs) in enumerate(tqdm(test_loader, desc="Testing")):
        lr_imgs = lr_imgs.to(device)
        outputs = model(lr_imgs).clamp(0.0, 1.0)
        # first sample, first channel -> 8-bit image
        first_pred = outputs[0,0].cpu().numpy()
        out_np = (first_pred * 255.0).astype(np.uint8)
        out_path = os.path.join(out_dir, f"sr_{i:04d}.png")
        cv2.imwrite(out_path, out_np)
print("Training finished. Best SSIM:", best_val_ssim)


Running on: cuda


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Using SwinIR from swinir_model.py


Epoch 1/50 [Train]:   0%|          | 0/776 [00:00<?, ?it/s]