In [2]:
%load_ext autoreload
%autoreload 2
import os 
os.chdir("/home/jadli/Bureau/BDAI2/Satellite_Super_Resulotion0")

import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from importlib import reload
import src.utils.config
reload(src.utils.config)
from src.utils.config import CONFIG

from src.utils.data_loader import create_loaders
from src.models.models_architecture import SRCNN        
from src.utils.helper_functions import train_sr, val_sr, plot_sr_progress

# --- Run configuration ------------------------------------------------------
# Everything a reader might want to tune is pulled out of CONFIG here, once.

# File the best EDSR checkpoint is written to / resumed from.
best_model_path = CONFIG["model"]["best_EDSR_path"]

# Use the GPU whenever PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"device : {device}")

# Data-loading settings (CONFIG comes from the project's YAML file).
data_root = CONFIG["paths"]["output_root"]
training_cfg = CONFIG["training"]
batch_size = training_cfg["batch_size"]
num_workers = training_cfg["num_workers"]
# Augmentation is on unless the config explicitly says otherwise.
use_aug = training_cfg.get("use_augmentation", True)

# Optimiser / schedule hyper-parameters.
lr = training_cfg["lr"]
weight_decay = training_cfg["weight_decay"]
num_epochs = training_cfg["epochs"]
step_size = training_cfg["scheduler_step_size"]
gamma = training_cfg["scheduler_gamma"]

# --- Data -------------------------------------------------------------------
# Build the train / validation / test loaders in one call.
loader_kwargs = {
    "root": data_root,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "use_augmentation": use_aug,
}
train_loader, val_loader, test_loader = create_loaders(**loader_kwargs)



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
device : cuda

 DATA LOADED:
  Train: 64800 samples
  Val:   8100 samples
  Test:  8100 samples


In [3]:
class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv branch added back onto its input with a fixed scale.

    Args:
        channels: number of input and output channels of both 3x3 convolutions.
        scale: constant the branch output is multiplied by before the
            skip-connection addition.
    """

    def __init__(self, channels=64, scale=0.1):
        super().__init__()
        self.scale = scale

        # Both convolutions use padding=1 with a 3x3 kernel, so the branch
        # keeps the spatial size and can be added to the input directly.
        first_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.block = nn.Sequential(first_conv, activation, second_conv)

    def forward(self, x):
        """Return ``x`` plus the scaled output of the convolutional branch."""
        branch = self.block(x)
        return x + branch * self.scale


class EDSR(nn.Module):
    """EDSR-style single-image super-resolution network.

    Pipeline: head convolution -> ``num_blocks`` residual blocks -> tail
    convolution with a global skip connection -> PixelShuffle upsampling ->
    reconstruction convolution. Input and output both have 3 channels; the
    output height and width are the input's multiplied by ``scale_factor``.

    Args:
        num_blocks: number of ``ResidualBlock`` modules in the body.
        scale_factor: spatial upscaling factor. Supported values are 3 and
            any power of two (1, 2, 4, 8, ...); 1 builds no upsampling stage.
        num_features: channel width used throughout the network.
        res_scale: scaling constant forwarded to every ``ResidualBlock``.

    Raises:
        ValueError: if ``scale_factor`` is not a supported value. Unsupported
            factors used to build an empty upsampler silently, which produced
            a network whose output had the same size as its input.
    """

    def __init__(self, num_blocks=32, scale_factor=4, num_features=64, res_scale=0.1):
        super().__init__()

        try:
            scale = int(scale_factor)
        except (TypeError, ValueError):
            scale = None
        if scale is None or scale != scale_factor or scale < 1:
            raise ValueError(
                f"scale_factor must be a positive integer, got {scale_factor!r}"
            )

        # Decide the PixelShuffle plan before any layer is created.
        if scale & (scale - 1) == 0:
            # Power of two: log2(scale) successive x2 stages (none for scale 1).
            shuffle_factors = [2] * (scale.bit_length() - 1)
        elif scale == 3:
            shuffle_factors = [3]
        else:
            raise ValueError(
                f"Unsupported scale_factor {scale}: expected 3 or a power of two"
            )
        self.scale_factor = scale

        # 1) Initial feature extractor
        self.conv_head = nn.Conv2d(3, num_features, kernel_size=3, padding=1)

        # 2) Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features, res_scale) for _ in range(num_blocks)]
        )

        # 3) Convolution applied before the global skip connection
        self.conv_tail = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        # 4) Upsampling: each stage expands the channels by factor**2 so that
        #    PixelShuffle(factor) folds them back into `num_features` channels
        #    at `factor` times the resolution. The conv / shuffle / ReLU order
        #    is kept so existing checkpoints (upsample.0, upsample.3, ...)
        #    still load.
        up_layers = []
        for factor in shuffle_factors:
            up_layers += [
                nn.Conv2d(num_features, num_features * factor * factor,
                          kernel_size=3, padding=1),
                nn.PixelShuffle(factor),
                nn.ReLU(True),
            ]
        self.upsample = nn.Sequential(*up_layers)

        # 5) Reconstruction back to 3 channels
        self.conv_last = nn.Conv2d(num_features, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a 3-channel image batch to its super-resolved counterpart."""
        x_head = self.conv_head(x)
        x_res = self.res_blocks(x_head)
        # Global skip: the body only has to learn a correction to the head features.
        x = self.conv_tail(x_res) + x_head
        x = self.upsample(x)
        x = self.conv_last(x)
        return x


In [4]:
# Train EDSR, resuming from the best checkpoint when one exists, and keep the
# weights with the highest validation PSNR on disk.

# Single source of truth for the upscaling factor: the model and both
# train/val helpers must agree on it.
SCALE_FACTOR = 4

model = EDSR(scale_factor=SCALE_FACTOR).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Per-epoch history for the final plot (only epochs run in this session).
train_losses, val_losses = [], []
train_psnrs,  val_psnrs  = [], []

best_psnr = 0.0
best_model_path = CONFIG["model"]["best_EDSR_path"]
# dirname is "" for a bare file name, and os.makedirs("") raises.
checkpoint_dir = os.path.dirname(best_model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# LOAD CHECKPOINT (IF EXISTS)
start_epoch = 0

if os.path.exists(best_model_path):
    print("Loading checkpoint:", best_model_path)
    # NOTE(review): torch.load unpickles the file -- only point this at
    # checkpoints produced by this project.
    checkpoint = torch.load(best_model_path, map_location=device)

    if "model" in checkpoint:
        # Full checkpoint as written by the save step in the loop below.
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_psnr = checkpoint["best_psnr"]
        start_epoch = checkpoint["epoch"] + 1
        # start_epoch is 0-based; report it 1-based like the "Epoch i/N" lines.
        print(f" Resuming training from epoch {start_epoch + 1} | Best PSNR = {best_psnr:.2f}")
    else:
        print(" Old checkpoint without optimizer/scheduler. Loading model only.")
        model.load_state_dict(checkpoint)
        # A bare state dict carries no PSNR. Measure the loaded weights so the
        # first trained epoch cannot overwrite a better checkpoint merely
        # because best_psnr was still 0.0.
        _, best_psnr = val_sr(
            model=model,
            val_loader=val_loader,
            loss_fn=criterion,
            device=device,
            scale_factor=SCALE_FACTOR,
            model_requires_upscale=False
        )
        print(f" Baseline PSNR of loaded weights = {best_psnr:.2f}")
else:
    print(" Training from scratch")


for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # TRAIN
    # NOTE(review): the scheduler is handed to train_sr; confirm it is stepped
    # once per epoch there, since StepLR's step_size is counted in step() calls.
    train_loss, train_psnr = train_sr(
        model=model,
        train_loader=train_loader,
        loss_fn=criterion,
        optimizer=optimizer,
        device=device,
        scale_factor=SCALE_FACTOR,
        model_requires_upscale=False,
        scheduler=scheduler
    )

    # VALIDATION
    val_loss, val_psnr = val_sr(
        model=model,
        val_loader=val_loader,
        loss_fn=criterion,
        device=device,
        scale_factor=SCALE_FACTOR,
        model_requires_upscale=False
    )

    # SAVE BEST MODEL
    if val_psnr > best_psnr:
        best_psnr = val_psnr

        # Store everything needed to resume: weights, optimiser/scheduler
        # state, the (0-based) epoch just finished and the score to beat.
        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "best_psnr": best_psnr,
        }, best_model_path)

        print(f" New BEST model saved at epoch {epoch+1} with PSNR = {best_psnr:.2f}")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_psnrs.append(train_psnr)
    val_psnrs.append(val_psnr)

    print(f"Train loss: {train_loss:.6f} | Train PSNR: {train_psnr:.2f} dB")
    print(f"Val   loss: {val_loss:.6f} | Val   PSNR: {val_psnr:.2f} dB")
    print(f"→ Current LR: {optimizer.param_groups[0]['lr']:.8f}")

# PLOT
if train_losses:
    plot_sr_progress(train_losses, val_losses, train_psnrs, val_psnrs)
else:
    # Happens when the checkpoint is already at or past num_epochs.
    print(" No epochs were run in this session; nothing to plot.")


 Training from scratch

Epoch 1/10


                                                                                   

KeyboardInterrupt: 

In [None]:
|