In [1]:
device="cuda"

In [2]:
### MODEL
import torch
import torch.nn as nn
import torch.nn.functional as F

class RLFB(nn.Module):
    def __init__(self, channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.act1 = nn.SiLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.act2 = nn.SiLU()
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.act3 = nn.SiLU()
        # 1x1 conv to reduce/restore channels and connect to ESA per paper figure
        self.conv_reduce = nn.Conv2d(channels, channels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        out = self.conv1(x); out = self.act1(out)
        out = self.conv2(out); out = self.act2(out)
        out = self.conv3(out); out = self.act3(out)
        out = self.conv_reduce(out)
        return x + out

class ESA(nn.Module):
    def __init__(self, channels=64, reduction=4):
        super().__init__()
        mid = channels // reduction
        self.conv1 = nn.Conv2d(channels, mid, kernel_size=1, padding=0)
        # downsample block: large receptive pooling as paper: use maxpool(7, stride=3) as described
        self.pool = nn.MaxPool2d(kernel_size=7, stride=3, padding=3)  # padding to keep shapes reasonable
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(mid, mid, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(mid, channels, kernel_size=1, padding=0)
        self.sig = nn.Sigmoid()
        # small bilinear upsampling used in paper (interpolation)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: (B,C,H,W)
        c1 = self.conv1(x)          # reduce channels
        p = self.pool(c1)          # spatial reduce
        p = self.relu(self.conv2(p))
        p = self.relu(self.conv3(p))
        # upsample back to original spatial (use bilinear)
        p_up = F.interpolate(p, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
        p_up = self.conv4(p_up)
        att = self.sig(p_up)
        return x * att

class MESRGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, num_features=64, num_rlfb=12, scale=4):
        """
        in_channels: typically 1 for magnetograms (paper uses single-channel)
        out_channels: 1
        num_features: 64 (paper uses 64)
        num_rlfb: 12 stacked RLFB blocks as paper
        scale: 4 (paper performs 4x super-resolution)
        """
        super().__init__()
        self.scale = scale
        self.shallow = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)
        self.rlfb_blocks = nn.Sequential(*[RLFB(num_features) for _ in range(num_rlfb)])
        self.esa = ESA(num_features)
        self.fuse = nn.Conv2d(num_features, num_features, kernel_size=1, padding=0)
        ups = []
        s = scale
        while s > 1:
            ups += [nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
                    nn.PixelShuffle(2),
                    nn.ReLU(inplace=True)]
            s //= 2
        self.upsampler = nn.Sequential(*ups)
        # final reconstruction conv
        self.recon = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # x: LR image tensor (B, C, H, W) where H,W are LR spatial dims
        feat_shallow = self.shallow(x)           # shallow features
        feat = self.rlfb_blocks(feat_shallow)    # deep RLFB stack
        feat = self.esa(feat)                    # ESA attention
        feat = self.fuse(feat + feat_shallow)    # residual-style fusion (paper connects features)
        out = self.upsampler(feat)               # upsample to HR
        out = self.recon(out)
        return out

class SRGAN_D2_Discriminator(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        # Following the conv/layer pattern in user's class with similar channel counts/strides
        self.conv1 = nn.Conv2d(in_channels, 64, 3, stride=1, padding=1)  # act LeakyReLU
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

        self.conv2 = nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 128, 3, stride=2, padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv5 = nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=True)
        self.bn4 = nn.BatchNorm2d(256)

        self.conv6 = nn.Conv2d(256, 256, 3, stride=2, padding=1, bias=True)
        self.bn5 = nn.BatchNorm2d(256)

        self.conv7 = nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=True)
        self.bn6 = nn.BatchNorm2d(512)

        self.conv8 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=True)
        self.bn7 = nn.BatchNorm2d(512)

        # Flatten + dense head
        # We'll use AdaptiveAvgPool2d to produce fixed-size before linear head to match original flatten behavior
        self.pool = nn.AdaptiveAvgPool2d(1)  # output size 1x1 per channel
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x):
        # x expected (B, C, H, W)
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.bn1(x)
        x = self.lrelu(self.conv3(x))
        x = self.bn2(x)
        x = self.lrelu(self.conv4(x))
        x = self.bn3(x)
        x = self.lrelu(self.conv5(x))
        x = self.bn4(x)
        x = self.lrelu(self.conv6(x))
        x = self.bn5(x)
        x = self.lrelu(self.conv7(x))
        x = self.bn6(x)
        x = self.lrelu(self.conv8(x))
        x = self.bn7(x)

        x = self.pool(x)         # (B, 512, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 512)
        x = self.lrelu(self.fc1(x))
        logits = self.fc2(x)     # raw logits (B,1)
        prob = torch.sigmoid(logits)
        return prob, logits


In [3]:
### Pre-Processing
import cv2
import numpy as np
import pywt
import torch
from torchvision import transforms

def preprocess_solar_image(img_path, apply_clahe=False, wavelet='db2', levels=2):
    # ----------------------
    # 1. Load & normalize
    # ----------------------
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE).astype(np.float32)
    img /= img.max() + 1e-8

    # ----------------------
    # 2. Dynamic range control
    # ----------------------
    if apply_clahe:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        img_eq = clahe.apply((img*255).astype(np.uint8)) / 255.0
    else:
        img_eq = np.log1p(img) / np.log(2.0)  # log scaling for high dynamic range

    features = [img_eq]

    # ----------------------
    # 3. Multiscale decomposition (Laplacian pyramid)
    # ----------------------
    current = img_eq.copy()
    laplacian_levels = []
    for _ in range(levels):
        down = cv2.pyrDown(current)
        up = cv2.pyrUp(down, dstsize=current.shape[::-1])
        lap = current - up
        laplacian_levels.append(cv2.resize(lap, img_eq.shape[::-1]))
        # laplacian_levels.append(lap)
        current = down
    features.extend(laplacian_levels)

    # ----------------------
    # 4. Frequency / edge filters
    # ----------------------

    # Laplacian high-pass
    lap = cv2.Laplacian(img_eq, cv2.CV_32F, ksize=3)

    # Sobel edges
    sobelx = cv2.Sobel(img_eq, cv2.CV_32F, 1, 0, ksize=3)
    sobely = cv2.Sobel(img_eq, cv2.CV_32F, 0, 1, ksize=3)
    sobel_mag = np.sqrt(sobelx**2 + sobely**2)

    # DoG (Difference of Gaussians)
    g1 = cv2.GaussianBlur(img_eq, (3,3), 0.5)
    g2 = cv2.GaussianBlur(img_eq, (3,3), 1.5)
    dog = g1 - g2

    # Optional: Gabor filter for orientation-sensitive structures
    gabor_kernels = []
    for theta in [0, np.pi/4, np.pi/2, 3*np.pi/4]:
        kern = cv2.getGaborKernel((7,7), 2.0, theta, 5.0, 0.5, 0, ktype=cv2.CV_32F)
        gabor_resp = cv2.filter2D(img_eq, cv2.CV_32F, kern)
        gabor_kernels.append(gabor_resp)

    features.extend([lap, sobel_mag, dog] + gabor_kernels)

    # ----------------------
    # 5. Concatenate all features
    # ----------------------
    stacked = np.stack(features, axis=0)  # shape: [C, H, W]
    tensor = torch.tensor(stacked, dtype=torch.float32)

    return tensor  # ready for model input


In [4]:
### LOSS FUNCTION

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from pytorch_msssim import ms_ssim

def calc_psnr(pred, target, max_val=1.0):
    mse = F.mse_loss(pred, target)
    return 10 * torch.log10(max_val ** 2 / (mse + 1e-8))

def tv_loss(img):
    """Total Variation loss"""
    dh = torch.mean(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))
    dw = torch.mean(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1]))
    return dh + dw

class HaarDWT(nn.Module):
    def __init__(self):
        super().__init__()
        # fixed filters
        ll = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32) / 2
        lh = torch.tensor([[1, 1], [-1, -1]], dtype=torch.float32) / 2
        hl = torch.tensor([[1, -1], [1, -1]], dtype=torch.float32) / 2
        hh = torch.tensor([[1, -1], [-1, 1]], dtype=torch.float32) / 2
        self.register_buffer("filters", torch.stack([ll, lh, hl, hh], dim=0))

    def forward(self, x):
        B, C, H, W = x.shape
        filt = self.filters.to(x.device).unsqueeze(1)  # (4,1,2,2)
        filt = filt.repeat(1, C, 1, 1)                 # (4,C,2,2)
        y = F.conv2d(x, filt, stride=2, padding=0, groups=C)  # (B, 4*C, H/2, W/2)
        y = y.view(B, 4, C, H // 2, W // 2)
        return y[:, 1], y[:, 2], y[:, 3]  # LH, HL, HH

class VGGPerceptual(nn.Module):
    def __init__(self, layer_index=16):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice = nn.Sequential(*[vgg[i] for i in range(layer_index + 1)])
        for p in self.slice.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.slice(x)



class TotalLoss(nn.Module):
    def __init__(self, device, alpha=1.0, beta=0.5, gamma=0.2, delta=0.01, eps_tv=1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.eps_tv = eps_tv
        self.l1 = nn.L1Loss()
        self.vgg = VGGPerceptual(layer_index=16).to(device)
        self.wavelet = HaarDWT().to(device)

    def forward(self, sr, hr):
        # L1
        l1_loss = self.l1(sr, hr)

        # MS-SSIM (use [0,1] range)
        ms_loss = 1 - ms_ssim(sr, hr, data_range=1.0, size_average=True)

        # Wavelet L2 on high-frequency bands
        sr_LH, sr_HL, sr_HH = self.wavelet(sr)
        hr_LH, hr_HL, hr_HH = self.wavelet(hr)
        wave_l2 = F.mse_loss(sr_LH, hr_LH) + F.mse_loss(sr_HL, hr_HL) + F.mse_loss(sr_HH, hr_HH)

        # Perceptual L2 (VGG features)
        if sr.shape[1] == 1:
            sr = sr.repeat(1, 3, 1, 1)
            hr = hr.repeat(1, 3, 1, 1)
        feat_sr = self.vgg(sr)
        feat_hr = self.vgg(hr)
        perceptual = F.mse_loss(feat_sr, feat_hr)

        # TV
        tv = tv_loss(sr)

        total = (self.alpha * l1_loss +
                 self.beta * ms_loss +
                 self.gamma * wave_l2 +
                 self.delta * perceptual +
                 self.eps_tv * tv)
        return total, {
            "L1": l1_loss.item(),
            "MS-SSIM": ms_loss.item(),
            "Wavelet": wave_l2.item(),
            "Perceptual": perceptual.item(),
            "TV": tv.item(),
            "Total": total.item(),
        }


In [5]:
class SolarSRDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, apply_clahe=False):
        self.lr_files = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.tif'))])
        self.hr_files = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.tif'))])
        assert len(self.lr_files) == len(self.hr_files), "LR and HR folders must have same number of images"
        self.apply_clahe = apply_clahe

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_tensor = preprocess_solar_image(self.lr_files[idx], apply_clahe=self.apply_clahe)  # (10,H,W)
        hr_tensor = preprocess_solar_image(self.hr_files[idx], apply_clahe=self.apply_clahe)  # (10,H,W)
        # the HR target is grayscale single-channel — take only the equalized version
        hr_tensor = hr_tensor[0:1, :, :]   # keep only first channel as GT
        return lr_tensor, hr_tensor


In [6]:
def train_sr(
    train_lr_dir, train_hr_dir,
    valid_lr_dir, valid_hr_dir,
    epochs=50, batch_size=4,
    lr_gen=1e-4, lr_disc=1e-4,
    adv_weight=1e-3,
    save_dir="checkpoints",
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    os.makedirs(save_dir, exist_ok=True)

    # --- Create datasets & loaders ---
    train_ds = SolarSRDataset(train_lr_dir, train_hr_dir)
    val_ds = SolarSRDataset(valid_lr_dir, valid_hr_dir)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    # --- Model setup ---
    in_ch = 10  # from your preprocessing (10 stacked feature maps)
    gen = MESRGenerator(in_channels=in_ch, out_channels=1, num_features=64, num_rlfb=12, scale=4).to(device)
    disc = SRGAN_D2_Discriminator(in_channels=1).to(device)

    opt_g = torch.optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.9, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=lr_disc, betas=(0.9, 0.999))

    total_loss_fn = TotalLoss(device).to(device)
    bce_loss = nn.BCEWithLogitsLoss()

    best_psnr = 0.0

    # =========================
    #      Training Loop
    # =========================
    for epoch in range(1, epochs + 1):
        gen.train(); disc.train()
        loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}]")
        avg_g, avg_d = 0, 0

        for lr_imgs, hr_imgs in loop:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # --- Train Discriminator ---
            opt_d.zero_grad()
            with torch.no_grad():
                sr_fake = gen(lr_imgs)
            real_prob, real_logit = disc(hr_imgs)
            fake_prob, fake_logit = disc(sr_fake.detach())
            real_lbl = torch.ones_like(real_logit)
            fake_lbl = torch.zeros_like(fake_logit)
            d_loss = (bce_loss(real_logit, real_lbl) + bce_loss(fake_logit, fake_lbl)) * 0.5
            d_loss.backward()
            opt_d.step()

            # --- Train Generator ---
            opt_g.zero_grad()
            sr = gen(lr_imgs)
            fake_prob, fake_logit = disc(sr)
            adv_loss = bce_loss(fake_logit, torch.ones_like(fake_logit))
            total_loss, comps = total_loss_fn(sr, hr_imgs)
            g_loss = total_loss + adv_weight * adv_loss
            g_loss.backward()
            opt_g.step()

            avg_g += g_loss.item()
            avg_d += d_loss.item()
            loop.set_postfix({
                "G_total": f"{g_loss.item():.4f}",
                "D_loss": f"{d_loss.item():.4f}",
                "L1": f"{comps['L1']:.4f}",
                "MS-SSIM": f"{comps['MS-SSIM']:.4f}"
            })

        # --- Validation ---
        gen.eval()
        psnr_total, ssim_total = 0, 0
        with torch.no_grad():
            for lr_imgs, hr_imgs in val_loader:
                lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
                sr = gen(lr_imgs)
                psnr_total += calc_psnr(sr, hr_imgs).item()
                ssim_total += ms_ssim(sr, hr_imgs, data_range=1.0, size_average=True).item()

        psnr_mean = psnr_total / len(val_loader)
        ssim_mean = ssim_total / len(val_loader)
        print(f"Epoch {epoch}: PSNR={psnr_mean:.3f} dB | SSIM={ssim_mean:.4f}")

        if psnr_mean > best_psnr:
            best_psnr = psnr_mean
            torch.save(gen.state_dict(), os.path.join(save_dir, "best_generator.pth"))
            torch.save(disc.state_dict(), os.path.join(save_dir, "best_discriminator.pth"))
            print(f"✅ Best model saved (PSNR={psnr_mean:.2f})")

    print("Training finished ✅")

In [None]:
train_sr(
        train_lr_dir="/kaggle/input/solaresss/new_dataset/training/low_res",
        train_hr_dir="/kaggle/input/solaresss/new_dataset/training/high_res",
        valid_lr_dir="/kaggle/input/solaresss/new_dataset/validation/low_res",
        valid_hr_dir="/kaggle/input/solaresss/new_dataset/validation/high_res",
        epochs=20,
        batch_size=8,
    )


Epoch [1/20]:  12%|█▏        | 93/776 [03:58<30:03,  2.64s/it, G_total=0.2980, D_loss=0.0876, L1=0.0771, MS-SSIM=0.2437] 