In [1]:
%load_ext autoreload
%autoreload 2
import os
# Switch the working directory to the project root before the `src.*` imports below
# (presumably so they and any relative paths resolve from there — confirm).
# The location can be overridden with the SR_PROJECT_ROOT environment variable so the
# notebook is not tied to one machine; the default is the original hard-coded path.
os.chdir(os.environ.get("SR_PROJECT_ROOT", "/home/jadli/Bureau/BDAI2/Satellite_Super_Resulotion0"))

import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
import src.utils.config
reload(src.utils.config)
from src.utils.config import CONFIG

from src.utils.data_loader import create_loaders
from src.models.models_architecture import SRCNN        
from src.utils.helper_functions import train_sr, test_sr, plot_sr_progress


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device : {device}")

# --- Data / loader settings read from CONFIG (loaded from YAML) ---
data_root      = CONFIG["paths"]["output_root"]
batch_size     = CONFIG["training"]["batch_size"]
num_workers    = CONFIG["training"]["num_workers"]
use_aug        = CONFIG["training"].get("use_augmentation", True)  # defaults to True when absent

# --- Optimisation hyperparameters read from CONFIG ---
lr              = CONFIG["training"]["lr"]
weight_decay    = CONFIG["training"]["weight_decay"]
num_epochs      = CONFIG["training"]["epochs"]
step_size       = CONFIG["training"]["scheduler_step_size"]
gamma           = CONFIG["training"]["scheduler_gamma"]


# === DataLoaders ===
# Forward the configured worker count and augmentation flag instead of hard-coded
# values, so that editing the config actually changes how the loaders are built.
train_loader, val_loader, test_loader = create_loaders(
    root=data_root,
    batch_size=batch_size,
    num_workers=num_workers,
    use_augmentation=use_aug,
)


device : cuda

📦 DATA LOADED:
  Train: 64800 samples
  Val:   8100 samples
  Test:  8100 samples


In [2]:
class ResidualBlockHR(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, scaled and added back to the input.

    Both convolutions map ``channels`` to ``channels`` with ``padding=1``, so the
    output has the same shape as the input.

    Args:
        channels: number of input/output channels of both convolutions.
        scale: factor the convolution branch is multiplied by before it is
            added to the input.
    """

    def __init__(self, channels=128, scale=0.1):
        super().__init__()
        # Shared settings for the two convolutions of the branch.
        conv_kwargs = dict(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )
        self.conv1 = nn.Conv2d(**conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(**conv_kwargs)
        self.scale = scale

    def forward(self, x):
        """Return ``x`` plus the scaled output of the convolution branch."""
        branch = self.conv2(self.relu(self.conv1(x)))
        return x + branch * self.scale


class EDSRPro(nn.Module):
    """
    Heavy variant: upsample first, then run the residual blocks at HR resolution.

    Input  : LR (B, 3, H, W)
    Output : HR (B, 3, scale_factor*H, scale_factor*W)

    Args:
        num_blocks: number of ResidualBlockHR modules applied after upsampling.
        channels: feature width used by every internal convolution.
        scale_factor: integer upscaling factor. Powers of two (1, 2, 4, 8, ...)
            are built from stacked x2 PixelShuffle stages (1 means no stage);
            3 is built from a single x3 PixelShuffle stage.

    Raises:
        ValueError: if ``scale_factor`` is not a positive power of two or 3.
            (Such values used to build an empty upsampler silently, producing
            an output at LR resolution.)
    """
    def __init__(self, num_blocks=32, channels=128, scale_factor=4):
        super().__init__()
        self.scale_factor = scale_factor

        # 1) Feature extraction at LR resolution
        self.conv_head = nn.Conv2d(3, channels, kernel_size=3, padding=1)

        # 2) Learned upsampling -> HR features
        self.upsample = nn.Sequential(*self._make_upsample_layers(channels, scale_factor))

        # 3) Residual blocks at HR resolution
        self.res_blocks = nn.Sequential(
            *[ResidualBlockHR(channels=channels) for _ in range(num_blocks)]
        )

        # 4) Convolution applied before the global skip connection (HR)
        self.conv_tail = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # 5) Final reconstruction to 3 output channels
        self.conv_last = nn.Conv2d(channels, 3, kernel_size=3, padding=1)

    @staticmethod
    def _make_upsample_layers(channels, scale_factor):
        """Return the list of conv / PixelShuffle / ReLU layers realising ``scale_factor``."""
        factor = int(scale_factor)
        if factor != scale_factor or factor < 1:
            raise ValueError(
                f"scale_factor must be a positive integer, got {scale_factor!r}"
            )

        if (factor & (factor - 1)) == 0:
            # Power of two: one x2 stage per factor of two (2 -> 1 stage, 4 -> 2 stages, ...).
            shuffle, num_stages = 2, factor.bit_length() - 1
        elif factor == 3:
            shuffle, num_stages = 3, 1
        else:
            raise ValueError(
                f"unsupported scale_factor {scale_factor!r}: expected a power of two or 3"
            )

        layers = []
        for _ in range(num_stages):
            layers += [
                # PixelShuffle(r) trades r*r channel groups for an r-times larger spatial grid,
                # so the convolution has to produce channels * r * r feature maps.
                nn.Conv2d(channels, channels * shuffle * shuffle, 3, padding=1),
                nn.PixelShuffle(shuffle),
                nn.ReLU(True),
            ]
        return layers

    def forward(self, x):
        # x: LR input
        x = self.conv_head(x)         # LR features
        x = self.upsample(x)          # HR features
        x_head_hr = x                 # kept for the global skip connection

        x_res = self.res_blocks(x)    # HR residual blocks
        x = self.conv_tail(x_res) + x_head_hr
        x = self.conv_last(x)
        return x


In [None]:
# === EDSR Pro hyperparameters ===
# These deliberately override the values read from CONFIG for this experiment.
scale_factor = 4
num_blocks   = 32         # heavy setting: 32; try 16 if you run out of memory
channels     = 128        # 128 is the recommended width for the Pro variant
batch_size   = 4          # keep <= 8, larger batches will probably OOM
num_epochs   = 10
lr           = 1e-4
sched_step_size = 20
sched_gamma     = 0.5
# Mixed precision is only requested when running on CUDA.
use_amp      = device == "cuda"

# The loaders built earlier captured the batch size from CONFIG, so assigning
# `batch_size` above would otherwise have no effect. Rebuild them with the
# batch size chosen for this model.
train_loader, val_loader, test_loader = create_loaders(
    root=data_root,
    batch_size=batch_size,
    num_workers=num_workers,
    use_augmentation=use_aug,
)

# === Model ===
model = EDSRPro(num_blocks=num_blocks, channels=channels, scale_factor=scale_factor).to(device)

criterion  = nn.L1Loss()
optimizer  = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): the scheduler is handed to train_sr, whose stepping cadence is not
# visible here. If it steps once per epoch, step_size=20 never triggers within
# 10 epochs; if it steps once per batch, the LR halves every 20 batches — confirm.
scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=sched_step_size, gamma=sched_gamma)

train_losses, val_losses = [], []
train_psnrs,  val_psnrs  = [], []

# Best-checkpoint bookkeeping. Start from -inf so the first epoch is always saved.
best_psnr = float("-inf")
best_model_path = CONFIG["model"]["best_EDSR_path"]
checkpoint_dir = os.path.dirname(best_model_path)
if checkpoint_dir:  # os.makedirs("") raises when the path is a bare filename
    os.makedirs(checkpoint_dir, exist_ok=True)

scaler = None  # AMP gradient scaler; train_sr returns it so it is reused across epochs

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss, train_psnr, scaler = train_sr(
        model=model,
        train_loader=train_loader,
        loss_fn=criterion,
        optimizer=optimizer,
        device=device,
        scale_factor=scale_factor,
        model_requires_upscale=False,   # IMPORTANT for EDSRPro: the model upsamples internally
        scheduler=scheduler,
        use_amp=use_amp,
        scaler=scaler
    )

    val_loss, val_psnr = test_sr(
        model=model,
        test_loader=val_loader,
        loss_fn=criterion,
        device=device,
        scale_factor=scale_factor,
        model_requires_upscale=False
    )

    # Keep the weights of the epoch with the best validation PSNR; without this
    # nothing from the training run is persisted.
    if val_psnr > best_psnr:
        best_psnr = val_psnr
        torch.save(model.state_dict(), best_model_path)
        print(f" New best model saved with Val PSNR = {best_psnr:.2f} dB")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_psnrs.append(train_psnr)
    val_psnrs.append(val_psnr)

    print(f"Train loss: {train_loss:.6f} | Train PSNR: {train_psnr:.2f} dB")
    print(f"Val   loss: {val_loss:.6f} | Val   PSNR: {val_psnr:.2f} dB")
    print(f"  ➤ LR: {optimizer.param_groups[0]['lr']:.8f}")

plot_sr_progress(train_losses, val_losses, train_psnrs, val_psnrs)


In [None]:
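# NOTE: everything in this cell is commented out. It is an earlier variant of the
# training loop from the previous cell: it calls train_sr without the use_amp/scaler
# arguments and saves the best checkpoint on every validation-PSNR improvement.
# train_losses, val_losses = [], []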
# train_psnrs,  val_psnrs  = [], []

# best_psnr = 0.0
# best_model_path = CONFIG["model"]["best_EDSR_path"]
# os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# for epoch in range(num_epochs):
#     print(f"\nEpoch {epoch+1}/{num_epochs}")

    
#     train_loss, train_psnr, scaler = train_sr(
#         model=model,
#         train_loader=train_loader,
#         loss_fn=criterion,
#         optimizer=optimizer,
#         device=device,
#         scale_factor=4,                 
#         model_requires_upscale=False,   
#         scheduler=scheduler
#     )

#     val_loss, val_psnr = test_sr(
#         model=model,
#         test_loader=val_loader,
#         loss_fn=criterion,
#         device=device,
#         scale_factor=4,
#         model_requires_upscale=False
#     )

#     if val_psnr > best_psnr:
#         best_psnr = val_psnr
#         torch.save(model.state_dict(), best_model_path)
#         print(f" New best model saved with Val PSNR = {best_psnr:.2f} dB")

#     train_losses.append(train_loss)
#     val_losses.append(val_loss)
#     train_psnrs.append(train_psnr)
#     val_psnrs.append(val_psnr)

#     print(f"Train loss: {train_loss:.6f} | Train PSNR: {train_psnr:.2f} dB")
#     print(f"Val   loss: {val_loss:.6f} | Val   PSNR: {val_psnr:.2f} dB")
#     print(f"  -> LR: {optimizer.param_groups[0]['lr']:.8f}")

# plot_sr_progress(train_losses, val_losses, train_psnrs, val_psnrs)
