In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary

import sys
import os
sys.path.append(os.path.abspath(os.path.join("..", "src")))
import dataset
from models import FSRCNN_CA
import metrics

# %load_ext autoreload
# %autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the General100 dataset through the project helper and collect the
# "image" field of every record in its "train" split.
ds = dataset.load_general100_dataset()
img_arr = [record["image"] for record in ds["train"]]

In [None]:
# Step 1: take the first 100 images and convert each one with dataset.pil_to_cv2
# (the original note describes this as PIL -> BGR arrays in OpenCV layout).
img_sample = img_arr[:100]
img_sample_cv2 = list(map(dataset.pil_to_cv2, img_sample))

# Step 2: augmentation via the project's OpenCV helper; the two lists are
# forwarded positionally (presumably scale factors and rotation angles in
# degrees -- confirm against dataset.augment_data_cv2).
args_augment_fsrcnn = [
    [0.9, 0.8, 0.7, 0.6],
    [90, 180, 270],
]
augmented_images_cv2 = dataset.augment_data_cv2(img_sample_cv2, *args_augment_fsrcnn)

# Step 3: parameters for the lazy patch generation.
upsample_factor = 2
patch_size = 10
stride = 5
use_deconv = True

# Step 4: build the train/validation DataLoaders (80/20 split, fixed seed).
train_loader, val_loader = dataset.lazy_train_val_dataloaders_cv2(
    images=augmented_images_cv2,
    scale_factor=upsample_factor,
    patch_size=patch_size,
    stride=stride,
    use_deconv=use_deconv,
    batch_size=512,
    num_workers=6,
    seed=42,
    val_split=0.2,
)

In [None]:
from torch.amp import autocast, GradScaler
import numpy as np
def train_sr_model(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    scaler = GradScaler('cuda')
    train_loss = []
    valid_loss = []
    psnr_list = []
    ssim_list = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            lr_batch, hr_batch = batch
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

            optimizer.zero_grad()
            with autocast('cuda'):
                sr_batch = model(lr_batch)
                loss = criterion(sr_batch, hr_batch)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.6f}")
        train_loss.append(avg_train_loss)

        # VALIDACIÓN
        model.eval()
        val_loss = 0.0
        psnr_total, ssim_total = 0.0, 0.0
        with torch.no_grad():
            for val_batch in val_loader:
                lr_val, hr_val = val_batch
                lr_val, hr_val = lr_val.to(device), hr_val.to(device)
                with autocast('cuda'):
                    sr_val = model(lr_val)
                    loss = criterion(sr_val, hr_val)
                val_loss += loss.item()

                for i in range(sr_val.shape[0]):
                    sr_img = sr_val[i].squeeze().detach().cpu().numpy()
                    hr_img = hr_val[i].squeeze().detach().cpu().numpy()

                    sr_img = np.clip(sr_img, 0.0, 1.0)
                    hr_img = np.clip(hr_img, 0.0, 1.0)

                    psnr = metrics.psnr(hr_img, sr_img, data_range=1.0)
                    ssim = metrics.ssim(hr_img, sr_img, data_range=1.0)

                    psnr_total += psnr
                    ssim_total += ssim
        avg_psnr = psnr_total / len(val_loader)
        avg_ssim = ssim_total / len(val_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.6f}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}")
        valid_loss.append(avg_val_loss)
        psnr_list.append(avg_psnr)
        ssim_list.append(avg_ssim)

    return train_loss, valid_loss, psnr_list, ssim_list

In [6]:
# Repeated here so the cell can run on its own; the values mirror the ones
# used when building the data loaders.
upsample_factor = 2
patch_size = 10

# Build the network for that scale and print a layer summary for a single
# one-channel patch_size x patch_size input (also 10x10).
model = FSRCNN_CA(scale=upsample_factor)
summary(model, input_size=(1, 1, patch_size, patch_size))

Layer (type:depth-idx)                                  Output Shape              Param #
FSRCNN_CA                                               [1, 1, 20, 20]            --
├─Sequential: 1-1                                       [1, 56, 10, 10]           --
│    └─Conv2d: 2-1                                      [1, 56, 10, 10]           1,456
│    └─PReLU: 2-2                                       [1, 56, 10, 10]           56
├─Sequential: 1-2                                       [1, 12, 10, 10]           --
│    └─Conv2d: 2-3                                      [1, 12, 10, 10]           684
│    └─PReLU: 2-4                                       [1, 12, 10, 10]           12
├─RIR: 1-3                                              [1, 12, 10, 10]           --
│    └─Sequential: 2-5                                  [1, 12, 10, 10]           --
│    │    └─RCAB: 3-1                                   [1, 12, 10, 10]           2,653
│    │    └─RCAB: 3-2                                

In [None]:
# Train for 10 epochs and keep the per-epoch loss / metric histories for the plots below.
train_loss, val_loss, psnr_list, ssim_list = train_sr_model(
    model,
    train_loader,
    val_loader,
    num_epochs=10,
    lr=1e-4,
)
# TODO: persist the trained weights (original note read "HACER SAFE" --
# presumably "save the model"; confirm intent).
# torch.save(model.state_dict(), "fsrcnn_model.pth")

In [None]:
# Plot how the training and validation losses evolve across epochs.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_loss, label='Train', color='blue')
ax.plot(val_loss, label='Validation', color='orange')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epochs')
# Both curves are produced by nn.MSELoss in train_sr_model, hence "MSE".
ax.set_ylabel('MSE Loss')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Plot the validation PSNR and SSIM per epoch.
# SSIM is at most 1 while PSNR is typically an order of magnitude larger, so
# each metric gets its own y-axis; on a shared axis the SSIM curve would be
# squashed into a flat line.
fig, ax_psnr = plt.subplots(figsize=(10, 5))
ax_ssim = ax_psnr.twinx()

psnr_line, = ax_psnr.plot(psnr_list, label='PSNR', color='green')
ssim_line, = ax_ssim.plot(ssim_list, label='SSIM', color='red')

ax_psnr.set_title('PSNR and SSIM Evolution')
ax_psnr.set_xlabel('Epochs')
ax_psnr.set_ylabel('PSNR', color='green')
ax_ssim.set_ylabel('SSIM', color='red')

# A single legend that covers the lines drawn on both axes.
ax_psnr.legend(handles=[psnr_line, ssim_line])
ax_psnr.grid(True)
plt.show()
