In [1]:
# =========================
# 1️⃣ Imports
# =========================
import os, cv2, random, torch, numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_msssim import ssim
from math import log10
from tqdm import tqdm
from PIL import Image


In [2]:
# =========================
# 2️⃣ Augmentation Functions
# =========================
def radial_vignette(img, strength=0.2):
    h, w = img.shape[:2]
    y, x = np.ogrid[:h, :w]
    cy, cx = h / 2, w / 2
    r = np.sqrt((x - cx)**2 + (y - cy)**2)
    r = r / r.max()
    mask = 1 - strength * (r**2)
    vignette = img.astype(np.float32) * mask[..., None]
    vignette = np.clip(vignette, 0, 255)
    return vignette.astype(np.uint8)

def add_poisson_noise(img, scale_low=5.0, scale_high=100.0):
    img_f = img.astype(np.float32) / 255.0
    scale = random.uniform(scale_low, scale_high)
    noisy = np.random.poisson(img_f * scale) / float(scale)
    noisy = np.clip(noisy, 0.0, 1.0)
    return (noisy * 255).astype(np.uint8)

def random_rotation(img, angle_range=20):
    h, w = img.shape[:2]
    angle = random.uniform(-angle_range, angle_range)
    M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1.0)
    rotated = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return rotated

def random_crop_resize(img, crop_scale=0.8, out_size=(256, 256)):
    h, w = img.shape[:2]
    ch, cw = int(h * crop_scale), int(w * crop_scale)
    y = random.randint(0, h - ch)
    x = random.randint(0, w - cw)
    crop = img[y:y+ch, x:x+cw]
    resized = cv2.resize(crop, out_size, interpolation=cv2.INTER_LINEAR)
    return resized

def gaussian_blur(img, sigma_range=(0.5, 2.0)):
    sigma = random.uniform(*sigma_range)
    ksize = int(2 * round(3*sigma) + 1)
    if ksize % 2 == 0:
        ksize += 1
    return cv2.GaussianBlur(img, (ksize, ksize), sigmaX=sigma)

def brightness_contrast_jitter(img, brightness=0.2, contrast=0.3):
    b = random.uniform(-brightness, brightness) * 255
    c = 1.0 + random.uniform(-contrast, contrast)
    jittered = img.astype(np.float32) * c + b
    jittered = np.clip(jittered, 0, 255)
    return jittered.astype(np.uint8)

def random_noise(img, noise_level=0.10):
    img_f = img.astype(np.float32) / 255.0
    noise = np.random.normal(0, noise_level, img_f.shape).astype(np.float32)
    noisy = np.clip(img_f + noise, 0.0, 1.0)
    return (noisy * 255).astype(np.uint8)


In [3]:
class PairedSolarDataset(Dataset):
    def __init__(self, lr_dir, hr_dir):
        self.lr_paths = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg'))])
        self.hr_paths = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg'))])
        self.transform = transforms.ToTensor()

    def __getitem__(self, idx):
        lr = cv2.imread(self.lr_paths[idx])
        hr = cv2.imread(self.hr_paths[idx])
        lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB)
        lr = cv2.resize(lr, (64, 64))
        hr = cv2.resize(hr, (256, 256))
        return self.transform(lr), self.transform(hr)

    def __len__(self):
        return min(len(self.lr_paths), len(self.hr_paths))


In [7]:
ckpt = torch.load("003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth", map_location=device)
ckpt

{'params_ema': OrderedDict([('conv_first.weight',
               tensor([[[[-0.0400,  0.0239, -0.0688],
                         [ 0.0345, -0.0567, -0.0223],
                         [ 0.0403,  0.2288,  0.0460]],
               
                        [[-0.1009,  0.0221, -0.0068],
                         [ 0.0635,  0.0337,  0.0825],
                         [ 0.1634, -0.0767, -0.0187]],
               
                        [[ 0.1008,  0.0605,  0.0887],
                         [-0.0641, -0.1189, -0.0769],
                         [-0.0783, -0.1061, -0.0453]]],
               
               
                       [[[-0.2702, -0.1981, -0.0328],
                         [ 0.1220,  0.3599,  0.1127],
                         [-0.0221,  0.0820, -0.1455]],
               
                        [[ 0.1062, -0.1305,  0.0597],
                         [-0.0110,  0.1761,  0.0845],
                         [-0.0729, -0.1616, -0.0334]],
               
                        [[ 0.0946,  0.

In [9]:


from swinir_model import SwinIR  # ensure swinir_model.py is in the same folder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinIR(
    upscale=4,
    in_chans=3,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6],
    mlp_ratio=2,
    upsampler='pixelshuffle'
).to(device)

ckpt = torch.load("003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth", map_location=device)
model.load_state_dict(ckpt['params_ema'], strict=False)


_IncompatibleKeys(missing_keys=['upsample.0.weight', 'upsample.0.bias', 'upsample.2.weight', 'upsample.2.bias'], unexpected_keys=['conv_up1.weight', 'conv_up1.bias', 'conv_up2.weight', 'conv_up2.bias', 'conv_hr.weight', 'conv_hr.bias'])

In [10]:
def psnr(pred, gt):
    mse = torch.mean((pred - gt) ** 2)
    if mse == 0:
        return 100
    return 10 * log10(1 / mse.item())


In [11]:
train_dataset = PairedSolarDataset(
    "new_dataset/training/low_res",
    "new_dataset/training/high_res"
)
val_dataset = PairedSolarDataset(
    "new_dataset/validation",
    "new_dataset/validation"
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)


In [12]:
criterion = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
best_ssim = -1.0

for epoch in range(10):
    model.train()
    total_loss = 0
    for lr_imgs, hr_imgs in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
        optimizer.zero_grad()
        sr = model(lr_imgs)
        loss = criterion(sr, hr_imgs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Train Loss: {total_loss / len(train_loader):.4f}")

    model.eval()
    avg_ssim, avg_psnr = 0, 0
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]"):
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            sr = model(lr_imgs)
            avg_ssim += ssim(sr, hr_imgs, data_range=1.0, size_average=True).item()
            avg_psnr += psnr(sr, hr_imgs)
    avg_ssim /= len(val_loader)
    avg_psnr /= len(val_loader)
    print(f"Val SSIM: {avg_ssim:.4f} | PSNR: {avg_psnr:.2f}")

    if avg_ssim > best_ssim:
        best_ssim = avg_ssim
        torch.save(model.state_dict(), "best_solar_swinir.pth")
        print(f"✅ Best Model Saved (SSIM={best_ssim:.4f})")


Epoch 1 [Train]:   0%|          | 0/1551 [00:04<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacty of 4.00 GiB of which 0 bytes is free. Of the allocated memory 3.39 GiB is allocated by PyTorch, and 64.61 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF