Target Specification:
- OS: Ubuntu 22.04.4 LTS
- Python: 3.12
- Platform: Google Colab
- CPU: Intel Xeon E5-2699 v4, 1 Core / 2 Threads
- RAM: 13GB
- GPU: Nvidia T4 15GB (Optional)

Library Import

In [None]:
import os
import cv2
import time
import math
import glob
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import torch.optim as optim
from google.colab import drive
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.transforms import ToTensor, ToPILImage, Compose, ColorJitter

Pre Training Initialization

In [None]:
# Notebook setup cell: installs extra packages, mounts Google Drive, copies the
# datasets to local storage and selects the compute device / batch size.
# NOTE: `%pip` and `!` are IPython magics; this cell only runs inside a notebook.
%pip install pytorch-msssim onnxruntime-gpu
# `ssim` is used by CombinedLoss and by the validation loop further down.
from pytorch_msssim import ssim

# Mount Google Drive only when it is not already visible on the filesystem.
if not os.path.exists('/content/drive/MyDrive'):
    print("Mounting Google Drive...")
    drive.mount('/content/drive')
else:
    print("Google Drive is already mounted.")

# Copy the datasets from Drive to /content/datasets once; the copy is skipped
# when that directory already exists.
if not os.path.exists('/content/datasets'):
    print("Copying Datasets...")
    !sudo cp -rf /content/drive/MyDrive/Thesis/datasets /content/datasets
else:
    print("datasets already exist.")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if DEVICE == "cuda":
    # Optional alternative install (disabled): %pip install tensorrt onnxruntime-gpu
    import onnxruntime as ort
    # GPU runtime (T4): use the larger batch size.
    BATCH_SIZE = 128
    print("CUDA (T4 GPU) detected. Using BATCH_SIZE=128.")
else:
    # Optional alternative install (disabled): %pip install onnxruntime
    import onnxruntime as ort
    # CPU-only runtime (13GB shared RAM): use a more conservative batch size.
    BATCH_SIZE = 64
    print("No GPU detected. Using CPU-only with BATCH_SIZE=64.")

Training Config & Setting

In [None]:
# Super-resolution scale: LR inputs are 1/UPSCALE_FACTOR the size of HR targets.
UPSCALE_FACTOR = 2
# os.cpu_count() returns None when the core count cannot be determined; fall
# back to 1 so the later min(4, NUM_WORKER) never compares an int with None.
NUM_WORKER = os.cpu_count() or 1
# Total number of training epochs (also the cosine-annealing period).
EPOCHS = 150
# Side length, in pixels, of the square HR patches cropped for training.
PATCH_SIZE = 256
# Initial learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-4
# Weight of the MSE term in CombinedLoss; (1 - LOSS_ALPHA) weights (1 - SSIM).
LOSS_ALPHA = 0.84
# Persistent project folder on Google Drive (checkpoints, logs, exports).
BASE_DIR = Path("/content/drive/MyDrive/Thesis")
# Local copy of the datasets on the Colab VM.
DATA_DIR = Path("/content/datasets")
OUTPUTS_DIR = BASE_DIR / "outputs"
CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
# Best-model checkpoint written by the training loop and read by later cells.
CHECKPOINT_FILE = CHECKPOINTS_DIR / f"best_espcn_x{UPSCALE_FACTOR}.pth"

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # manual_seed_all seeds every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
    # Reproducibility over speed: no autotuned kernels, deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    if cuda_available:
        # The previous message claimed benchmark mode was being enabled, which
        # is the opposite of what the flags above do.
        print("CUDA detected. cuDNN benchmark mode disabled; deterministic mode enabled.")

def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current PyTorch initial seed.

    Passed to the training DataLoader as ``worker_init_fn``. ``worker_id`` is
    accepted for that callback signature but is not used.
    """
    # Reduce the 64-bit torch seed to the 32-bit range NumPy accepts.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

# Seed all RNGs with the default seed (42) as soon as this cell runs.
set_seed()

Environment Preparation

In [None]:
def prepare_environment_and_datasets(patch_size):
    """Create the output folders and collect training / validation image paths.

    Training images are kept only when both sides are at least ``patch_size``
    pixels. The filtered list is cached in
    ``OUTPUTS_DIR / "valid_train_paths.txt"`` so later runs can skip the scan.
    Validation images are listed on every call and are not size-checked.

    Args:
        patch_size: Minimum height and width (pixels) of a training image.

    Returns:
        Tuple ``(valid_train_paths, validation_image_paths)`` of path-string
        lists. Either list may be empty; a warning is printed in that case.
    """
    # parents=True: on a fresh Drive the parent project folder may not exist.
    CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the cache does not record the patch size it was built
    # with; delete the file after changing PATCH_SIZE.
    CACHE_FILE_TRAIN = OUTPUTS_DIR / "valid_train_paths.txt"

    def list_files(directory):
        # glob order depends on the filesystem; sort so runs are reproducible.
        return sorted(glob.glob(str(directory / '*.*')))

    def is_image_large_enough(image_path, min_size):
        try:
            img = cv2.imread(str(image_path))
            if img is None:
                print(f"Warning: Failed to load {image_path}. Skipping.")
                return False
            h, w = img.shape[:2]
            return h >= min_size and w >= min_size
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"Warning: Error reading {image_path}: {e}. Skipping.")
            return False

    # --- Validation paths (always scanned, it's fast) ---
    validation_image_paths = list_files(DATA_DIR / "DIV2K_valid_HR")

    # --- Training paths (check the cache first) ---
    valid_train_paths = []
    if CACHE_FILE_TRAIN.exists():
        print(f"Loading cached training image paths from {CACHE_FILE_TRAIN}...")
        try:
            with open(CACHE_FILE_TRAIN, 'r') as f:
                cached_paths = [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeDecodeError) as e:
            print(f"Warning: Could not read cache file {e}. Re-building...")
            cached_paths = []

        # Drop entries whose file has disappeared since the cache was written.
        valid_train_paths = [p for p in cached_paths if os.path.isfile(p)]
        missing = len(cached_paths) - len(valid_train_paths)
        if missing:
            print(f"Warning: {missing} cached paths no longer exist and were dropped.")
        print(f"Loaded {len(valid_train_paths)} paths from cache.")

    if not valid_train_paths:
        print("Cache file not found or was invalid. Building new image list...")
        personal_dir = DATA_DIR / "Personal_HR"

        train_image_paths = (
            list_files(DATA_DIR / "DIV2K_train_HR")
            + list_files(DATA_DIR / "Flickr2K_HR")
        )

        if personal_dir.exists():
            print("Personal dataset found. Adding to training set.")
            train_image_paths += list_files(personal_dir)
        else:
            print("Personal dataset not found. Proceeding without it.")

        print(f"Found {len(train_image_paths)} potential training images.")
        print(f"Verifying image dimensions against patch size {patch_size}...")

        valid_train_paths = [
            p for p in tqdm(train_image_paths, desc="Filtering train images")
            if is_image_large_enough(p, patch_size)
        ]

        print(f"Filtered training set: {len(valid_train_paths)} of {len(train_image_paths)} images remain.")

        # Persist the filtered list; a failure here only costs a re-scan later.
        try:
            print(f"Saving new cache file to {CACHE_FILE_TRAIN}...")
            with open(CACHE_FILE_TRAIN, 'w') as f:
                f.writelines(f"{path}\n" for path in valid_train_paths)
        except OSError as e:
            print(f"Warning: Could not write cache file: {e}")

    if not valid_train_paths or not validation_image_paths:
        print("\n--- VERIFICATION FAILED: No images found. Check your './datasets/' folder structure. ---")

    return valid_train_paths, validation_image_paths

Loss Function

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of MSE and (1 - SSIM).

    ``total = alpha * MSE + (1 - alpha) * (1 - SSIM)``; SSIM is computed with
    ``data_range=1.0``.

    Args:
        alpha: Weight of the MSE term; the SSIM term receives ``1 - alpha``.

    Raises:
        ImportError: If ``ssim`` from ``pytorch_msssim`` is not in scope.
    """

    def __init__(self, alpha=0.7):
        super().__init__()
        self.alpha = alpha
        self.mse_loss = nn.MSELoss()
        try:
            self.ssim_loss_fn = ssim
        except NameError as err:
            # A missing import shows up here as NameError (the name lookup
            # fails), so that is what must be caught to give the install hint.
            raise ImportError("Please install pytorch-msssim: pip install pytorch-msssim") from err

    def forward(self, output, target):
        """Return ``(total_loss, mse_loss, ssim_score)`` for a batch.

        Args:
            output: Model prediction.
            target: Ground-truth tensor compared against ``output``.
        """
        loss_mse = self.mse_loss(output, target)

        # SSIM is a similarity score, so the quantity minimised is 1 - SSIM.
        ssim_score = self.ssim_loss_fn(output, target, data_range=1.0, size_average=True)
        loss_ssim = 1 - ssim_score

        total_loss = self.alpha * loss_mse + (1 - self.alpha) * loss_ssim

        return total_loss, loss_mse, ssim_score

def psnr(mse):
    """Convert a mean squared error to PSNR in dB, assuming a peak value of 1.

    Returns ``float('inf')`` whenever ``mse`` is not strictly positive.
    """
    if not mse > 0:
        return float('inf')
    return 10 * math.log10(1 / mse)

Create Patches From HR Images

In [None]:
class TrainingSuperResolutionDataset(Dataset):
    """Random-crop training dataset that yields ``(lr_patch, hr_patch)`` pairs.

    Each item is a random square crop of one HR image, randomly flipped and
    rotated, colour-jittered, and then bicubically downscaled to produce the
    LR input. Unreadable or too-small images are skipped in favour of the next
    file in the list.

    Args:
        image_filenames: List of HR image paths.
        crop_size: Requested HR patch side; rounded down to a multiple of
            ``upscale_factor``.
        upscale_factor: Ratio between HR and LR patch sizes.
    """

    def __init__(self, image_filenames, crop_size, upscale_factor):
        super().__init__()
        self.image_filenames = image_filenames
        self.crop_size = crop_size - (crop_size % upscale_factor)
        self.upscale_factor = upscale_factor
        self.to_tensor = ToTensor()
        self.scale_factor = 1 / upscale_factor
        self.transform = Compose([
            ToPILImage(),
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ToTensor()
        ])

    def __len__(self):
        return len(self.image_filenames)

    def _load_hr_image(self, index):
        """Return the RGB image at ``index``, or ``None`` if it is unusable."""
        path = self.image_filenames[index]
        try:
            hr_image = cv2.imread(path)
            if hr_image is None:
                raise IOError(f"cv2.imread failed to load image: {path}")
            hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error loading image {path}: {e}. Skipping.")
            return None

        h, w = hr_image.shape[:2]
        if h < self.crop_size or w < self.crop_size:
            # random.randint would raise ValueError on a negative range below.
            print(f"Image {path} ({w}x{h}) is smaller than crop size {self.crop_size}. Skipping.")
            return None
        return hr_image

    def __getitem__(self, index):
        """Return ``(lr_patch_tensor, hr_patch_tensor)`` for ``index``.

        Raises:
            RuntimeError: If no image in the dataset can be loaded.
        """
        # Try each file at most once, starting at `index`, instead of
        # recursing without bound when many files are broken.
        hr_image = None
        for offset in range(len(self)):
            hr_image = self._load_hr_image((index + offset) % len(self))
            if hr_image is not None:
                break
        if hr_image is None:
            raise RuntimeError(f"No usable image found in dataset ({len(self)} files tried).")

        h, w = hr_image.shape[:2]

        i = random.randint(0, h - self.crop_size)
        j = random.randint(0, w - self.crop_size)
        hr_patch = hr_image[i:i+self.crop_size, j:j+self.crop_size, :]

        if torch.rand(1) > 0.5:
            hr_patch = cv2.flip(hr_patch, 1)

        if torch.rand(1) > 0.5:
            angle = random.choice([cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE])
            hr_patch = cv2.rotate(hr_patch, angle)

        # Colour jitter is applied to the HR patch first ...
        hr_patch_tensor = self.transform(hr_patch)

        # ... and the LR input is derived from that same jittered patch, so
        # input and target always share identical augmentation.
        lr_patch_tensor = F.interpolate(
            hr_patch_tensor.unsqueeze(0),
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False,
            antialias=True
        ).squeeze(0)

        return lr_patch_tensor, hr_patch_tensor

class ValidationSuperResolutionDataset(Dataset):
    """Full-image validation dataset that yields ``(lr_tensor, hr_tensor)``.

    Each HR image is trimmed so both sides are multiples of ``upscale_factor``
    and then bicubically downscaled to produce the LR input. Unreadable images
    are skipped in favour of the next file in the list.

    Args:
        image_filenames: List of HR image paths.
        upscale_factor: Ratio between HR and LR image sizes.
    """

    def __init__(self, image_filenames, upscale_factor):
        super().__init__()
        self.image_filenames = image_filenames
        self.upscale_factor = upscale_factor
        self.to_tensor = ToTensor()
        self.scale_factor = 1 / upscale_factor

    def __len__(self):
        return len(self.image_filenames)

    def _load_hr_image(self, index):
        """Return the RGB image at ``index``, or ``None`` if loading fails."""
        path = self.image_filenames[index]
        try:
            hr_image = cv2.imread(path)
            if hr_image is None:
                raise IOError(f"cv2.imread failed to load image: {path}")
            return cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error loading image {path}: {e}. Skipping.")
            return None

    def __getitem__(self, index):
        """Return ``(lr_tensor, hr_tensor)`` for ``index``.

        Raises:
            RuntimeError: If no image in the dataset can be loaded.
        """
        # Try each file at most once, starting at `index`, instead of
        # recursing without bound when many files are broken.
        hr_image = None
        for offset in range(len(self)):
            hr_image = self._load_hr_image((index + offset) % len(self))
            if hr_image is not None:
                break
        if hr_image is None:
            raise RuntimeError(f"No usable image found in dataset ({len(self)} files tried).")

        # Trim so the HR size is an exact multiple of the upscale factor.
        h, w = hr_image.shape[:2]
        w_new, h_new = w - (w % self.upscale_factor), h - (h % self.upscale_factor)
        hr_image = hr_image[:h_new, :w_new, :]

        hr_tensor = self.to_tensor(hr_image)

        # Bicubic downscaling produces the LR input from the HR target.
        lr_tensor = F.interpolate(
            hr_tensor.unsqueeze(0),
            scale_factor=self.scale_factor,
            mode='bicubic',
            align_corners=False,
            antialias=True
        ).squeeze(0)

        return lr_tensor, hr_tensor

Model Definition

In [None]:
class ESPCN(nn.Module):
    """Three-convolution network followed by a pixel-shuffle upsampler.

    Takes a 3-channel input and returns a 3-channel output whose height and
    width are ``upscale_factor`` times larger.

    Args:
        upscale_factor: Spatial enlargement applied by the final PixelShuffle.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # PixelShuffle folds upscale_factor**2 channels into each output
        # channel, so the last conv emits that many maps per RGB channel.
        shuffle_channels = 3 * (upscale_factor ** 2)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=shuffle_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upscale the batch ``x``; no activation follows the last conv."""
        features = self.conv1(x)
        features = self.relu(features)
        features = self.relu(self.conv2(features))
        sub_pixel_maps = self.conv3(features)
        return self.pixel_shuffle(sub_pixel_maps)

Execute

In [None]:
def _strip_wrapper_prefixes(state_dict):
    """Return a copy of ``state_dict`` whose keys carry no wrapper prefixes.

    Removes a leading ``_orig_mod.`` (added by torch.compile) and then a
    leading ``module.`` (added by DataParallel) from every key, so the weights
    can be loaded into a bare ``ESPCN`` instance.
    """
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in ('_orig_mod.', 'module.'):
            key = key.removeprefix(prefix)
        cleaned[key] = value
    return cleaned


# Training entry point: builds the data loaders, optionally resumes from the
# best checkpoint, trains with mixed precision on CUDA, validates every epoch,
# keeps the best-PSNR checkpoint and finally plots the validation curves.
if __name__ == "__main__":

    print(f"Using device: {DEVICE}")

    train_paths, val_paths = prepare_environment_and_datasets(PATCH_SIZE)

    if not train_paths or not val_paths:
        raise RuntimeError("Dataset paths not found. Halting execution.")

    train_dataset = TrainingSuperResolutionDataset(train_paths, crop_size=PATCH_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_dataset = ValidationSuperResolutionDataset(val_paths, upscale_factor=UPSCALE_FACTOR)

    # NUM_WORKER comes from os.cpu_count(), which may be None; persistent
    # workers additionally require at least one worker process.
    worker_count = min(4, NUM_WORKER or 1)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=worker_count,
        pin_memory=(DEVICE == "cuda"),
        persistent_workers=True,
        worker_init_fn=seed_worker
    )
    val_loader = DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=worker_count,
        pin_memory=(DEVICE == "cuda"),
        persistent_workers=True
    )

    print(f"Found {len(train_dataset)} training images and {len(val_dataset)} validation images.")

    # Keep a handle on the bare module. Its state_dict keys never carry
    # 'module.' / '_orig_mod.' prefixes, so checkpoints are loaded into it and
    # saved from it regardless of the wrappers applied below.
    base_model = ESPCN(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
    model = base_model
    if DEVICE == 'cuda' and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel.")
        model = torch.nn.DataParallel(base_model)

    criterion = CombinedLoss(alpha=LOSS_ALPHA).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
    scaler = GradScaler(enabled=(DEVICE == "cuda"))

    # --- Checkpoint Loading Logic ---
    start_epoch = 0
    best_psnr = 0.0
    log_file_path = OUTPUTS_DIR / 'training_log.csv'
    training_log = []

    if CHECKPOINT_FILE.exists():
        print(f"Resuming training from checkpoint: {CHECKPOINT_FILE}")
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=DEVICE)

        # Older checkpoints may carry wrapper prefixes on their keys.
        base_model.load_state_dict(_strip_wrapper_prefixes(checkpoint['model_state_dict']))
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_psnr = checkpoint.get('best_psnr', 0.0)
        print(f"Resumed from epoch {start_epoch}, with best PSNR of {best_psnr:.4f} dB.")

        # Load existing log if it exists
        if log_file_path.exists():
            try:
                print(f"Loading existing training log from {log_file_path}")
                previous_log = pd.read_csv(log_file_path)
                # Log rows use 1-based epochs while the checkpoint stores the
                # 0-based index, so the row equal to start_epoch belongs to the
                # checkpointed epoch and must be kept.
                previous_log = previous_log[previous_log['epoch'] <= start_epoch]
                training_log = previous_log.to_dict('records')
                print(f"Loaded {len(training_log)} previous log entries.")
            except Exception as e:
                print(f"Warning: Could not load or parse log file: {e}. Starting with an empty log.")
                training_log = []
        else:
            print("No existing log file found. Starting a new log.")
    else:
        print("No checkpoint found. Starting training from scratch.")
        if log_file_path.exists():
            print(f"Deleting old log file: {log_file_path}")
            log_file_path.unlink()

    # --- Now torch.compile (AFTER checkpoint loading) ---
    try:
        print(f"Compiling model with torch.compile for {DEVICE}...")
        model = torch.compile(model)
    except Exception as e:
        print(f"torch.compile() for {DEVICE} failed: {e}. Running in eager mode.")

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--- Training ESPCN Model ({num_params:,} parameters) ---")

    start_time = time.time()
    for epoch in range(start_epoch, EPOCHS):
        model.train()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for lr_images, hr_images in progress_bar:
            lr_images = lr_images.to(DEVICE, non_blocking=(DEVICE == "cuda"))
            hr_images = hr_images.to(DEVICE, non_blocking=(DEVICE == "cuda"))

            optimizer.zero_grad()

            if DEVICE == "cuda":
                # Forward pass in mixed precision; the loss is evaluated in
                # float32 outside the autocast region.
                with autocast(device_type=DEVICE):
                    outputs = model(lr_images)
                total_loss, mse, ssim_score = criterion(outputs.float(), hr_images.float())
                scaler.scale(total_loss).backward()
                # Unscale before clipping so the norm is measured on real gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(lr_images)
                total_loss, mse, ssim_score = criterion(outputs.float(), hr_images.float())
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            progress_bar.set_postfix(
                Loss=f"{total_loss.item():.4f}",
                MSE=f"{mse.item():.4f}",
                SSIM=f"{ssim_score.item():.4f}"
            )

        # validation
        model.eval()
        val_psnr, val_ssim = 0.0, 0.0
        with torch.no_grad():
            for i, (lr_images, hr_images) in enumerate(val_loader):
                lr_images = lr_images.to(DEVICE)
                hr_images = hr_images.to(DEVICE)
                output = model(lr_images)
                mse_val = nn.functional.mse_loss(output, hr_images)
                val_psnr += psnr(mse_val.item())
                val_ssim += ssim(output, hr_images, data_range=1.0, size_average=True).item()

                if i == 0 and (epoch + 1) % 5 == 0:
                    # The network output is unbounded. Clamp to [0, 1] before
                    # conversion so out-of-range values do not wrap around
                    # when the image is turned into 8-bit pixels.
                    sr_pil = ToPILImage()(output.squeeze(0).clamp(0.0, 1.0).cpu())
                    output_path = OUTPUTS_DIR / f'val_epoch_{epoch+1}.png'
                    sr_pil.save(output_path)
                    print(f"\nSaved validation sample to {output_path}")

        avg_val_psnr = val_psnr / len(val_loader)
        avg_val_ssim = val_ssim / len(val_loader)

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  - Validation PSNR: {avg_val_psnr:.4f} dB | Validation SSIM: {avg_val_ssim:.4f}")
        print(f"  - Current Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        training_log.append({
            'epoch': epoch + 1,
            'val_psnr': avg_val_psnr,
            'val_ssim': avg_val_ssim,
            'learning_rate': scheduler.get_last_lr()[0]
        })

        # Rewrite the CSV every epoch so an interrupted session keeps its log.
        pd.DataFrame(training_log).to_csv(log_file_path, index=False)
        scheduler.step()

        if avg_val_psnr > best_psnr:
            best_psnr = avg_val_psnr
            print(f"  - New best model found! PSNR: {best_psnr:.4f} dB. Saving checkpoint...")

            torch.save({
                'epoch': epoch,
                # The bare module shares its parameters with the wrapped /
                # compiled model and always yields prefix-free keys.
                'model_state_dict': base_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_psnr': best_psnr,
            }, CHECKPOINT_FILE)

    total_time = time.time() - start_time
    print(f"\nTraining complete in {total_time/3600:.2f} hours.")
    print(f"Best model saved to '{CHECKPOINT_FILE}' with a PSNR of {best_psnr:.4f} dB.")
    print(f"\nTraining log saved successfully to {log_file_path}")

    # Build the frame from the in-memory log so plotting also works when the
    # epoch loop did not run (training already finished before this session).
    log_df = pd.DataFrame(training_log)

    if log_df.empty:
        print("No training log entries available. Skipping graphs.")
    else:
        print("Generating training graphs...")

        # Plot PSNR vs. Epoch
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(log_df['epoch'], log_df['val_psnr'], marker='o', color='b')
        plt.title('Validation PSNR vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('PSNR (dB)')
        plt.grid(True)

        # Plot SSIM vs. Epoch
        plt.subplot(1, 2, 2)
        plt.plot(log_df['epoch'], log_df['val_ssim'], marker='o', color='g')
        plt.title('Validation SSIM vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('SSIM')
        plt.grid(True)
        plt.tight_layout()
        plt.show()


Convert To FP16

In [None]:
# Cast the best FP32 checkpoint to half precision and save its state_dict.
print("--- Applying FP16 Quantization ---")

quant_model = ESPCN(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
# Same file the training loop writes its best checkpoint to.
checkpoint_path = CHECKPOINTS_DIR / f"best_espcn_x{UPSCALE_FACTOR}.pth"
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

# Older checkpoints may carry '_orig_mod.' (torch.compile) and/or 'module.'
# (DataParallel) key prefixes; strip both, as the training loader does.
original_state_dict = checkpoint['model_state_dict']
new_state_dict = {}
for key, value in original_state_dict.items():
    for prefix in ('_orig_mod.', 'module.'):
        key = key.removeprefix(prefix)
    new_state_dict[key] = value

if new_state_dict.keys() != original_state_dict.keys():
    print("Detected wrapper prefix ('_orig_mod.' / 'module.'). Stripping prefix for FP16 conversion...")
else:
    print("No wrapper prefix detected. Loading state dict as is for FP16 conversion.")

quant_model.load_state_dict(new_state_dict)

quant_model.eval()
# NOTE: .half() targets GPU inference; on CPU, FP16 may be slow or unsupported.
# The ONNX export and benchmark cells must feed FP16 inputs to this model.
quant_model.half()

print("Model converted to FP16.")

quantized_model_path = CHECKPOINTS_DIR / f"best_espcn_x{UPSCALE_FACTOR}_fp16.pth"
# Only the bare state_dict is stored (no optimizer / scheduler state).
torch.save(quant_model.state_dict(), quantized_model_path)

print(f"FP16 quantized model saved to: {quantized_model_path}")

Export To ONNX

In [None]:
# Export the trained model to ONNX with dynamic height/width axes.
print("--- Exporting to ONNX ---")

# True exports the half-precision weights saved by the FP16 cell; False
# exports the FP32 training checkpoint.
EXPORT_FP16 = True

if EXPORT_FP16:
    print("Exporting FP16 model...")
    onnx_model = ESPCN(upscale_factor=UPSCALE_FACTOR)
    checkpoint_path_fp16 = CHECKPOINTS_DIR / f"best_espcn_x{UPSCALE_FACTOR}_fp16.pth"
    # map_location lets weights saved from a CUDA session load on a CPU-only runtime.
    onnx_model.load_state_dict(torch.load(checkpoint_path_fp16, map_location=DEVICE))
    onnx_model.eval().half().to(DEVICE)
    dummy_input = torch.randn(1, 3, 720, 1280, device=DEVICE).half()
    onnx_model_path = OUTPUTS_DIR / f"espcn_x{UPSCALE_FACTOR}_fp16.onnx"
else:
    print("Exporting FP32 model...")
    onnx_model = ESPCN(upscale_factor=UPSCALE_FACTOR).to(DEVICE)
    checkpoint_path = CHECKPOINTS_DIR / f"best_espcn_x{UPSCALE_FACTOR}.pth"
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Older checkpoints may carry '_orig_mod.' (torch.compile) and/or
    # 'module.' (DataParallel) key prefixes; strip both before loading.
    original_state_dict = checkpoint['model_state_dict']
    new_state_dict = {}
    for key, value in original_state_dict.items():
        for prefix in ('_orig_mod.', 'module.'):
            key = key.removeprefix(prefix)
        new_state_dict[key] = value

    if new_state_dict.keys() != original_state_dict.keys():
        print("Detected wrapper prefix ('_orig_mod.' / 'module.'). Stripping for ONNX export...")
    else:
        print("No wrapper prefix detected. Loading state dict as is for ONNX export.")

    onnx_model.load_state_dict(new_state_dict)

    onnx_model.eval()
    dummy_input = torch.randn(1, 3, 720, 1280, device=DEVICE)
    onnx_model_path = OUTPUTS_DIR / f"espcn_x{UPSCALE_FACTOR}.onnx"

# The 720x1280 dummy input is only a tracing example; height and width are
# declared dynamic so the exported graph accepts other resolutions.
torch.onnx.export(
    onnx_model,
    dummy_input,
    str(onnx_model_path),
    export_params=True,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamo=True,
    dynamic_axes={
        'input': {2: 'height', 3: 'width'},
        'output': {2: 'height_out', 3: 'width_out'}
    }
)

print(f"Model successfully exported to ONNX: {onnx_model_path}")

Benchmark

In [None]:
# Time 100 forward passes of the PyTorch model and of the exported ONNX model
# on the same random 720x1280 input, then report the relative difference.
print("\n--- Running Inference Benchmark ---")

# The input dtype must match the precision the model was exported with.
bench_dtype = np.float16 if EXPORT_FP16 else np.float32
test_image_np = np.random.randn(1, 3, 720, 1280).astype(bench_dtype)
test_image_torch = torch.from_numpy(test_image_np).to(DEVICE)

if DEVICE == "cuda":
    torch.cuda.synchronize()

with torch.no_grad():
    # One untimed warm-up call, then synchronise so the timed region holds
    # only the 100 measured iterations.
    _ = onnx_model(test_image_torch)
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(100):
        _ = onnx_model(test_image_torch)
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    pytorch_time = (time.perf_counter() - start_time) / 100

print(f"PyTorch Inference Time: {pytorch_time * 1000:.4f} ms per image")

# Reference priority list of ONNX Runtime execution providers. It is not
# passed to the session below, which only requests CUDA and CPU.
providers = [
    # ==> Tier 1: Highest-Performance, Hardware-Specific Providers (NVIDIA, AMD, Intel, Apple)
    'TensorrtExecutionProvider',      # NVIDIA's top-tier for speed
    'CUDAExecutionProvider',          # Standard NVIDIA GPU provider
    'MIGraphXExecutionProvider',      # AMD's high-performance graph compiler on Linux
    'ROCmExecutionProvider',          # Standard AMD GPU provider on Linux
    'OpenVINOExecutionProvider',      # Optimized for Intel GPUs and CPUs
    'CoreMLExecutionProvider',        # For Apple M-series chips (macOS, iOS)

    # ==> Tier 2: Specialized NPU/Edge/Mobile Providers
    'QNNExecutionProvider',           # Qualcomm AI Engine Direct (Snapdragon)
    'NNAPIExecutionProvider',         # Android Neural Networks API
    'CANNExecutionProvider',          # Huawei Ascend Chips
    'RockchipNpuExecutionProvider',   # Rockchip NPUs
    'VitisAIExecutionProvider',       # Xilinx FPGAs
    'ArmNNExecutionProvider',         # Arm NN SDK
    'ACLExecutionProvider',           # Arm Compute Library

    # ==> Tier 3: General-Purpose GPU Provider (Windows)
    'DmlExecutionProvider',           # DirectX 12 for NVIDIA, AMD, Intel GPUs on Windows

    # ==> Tier 4: Optimized CPU Providers
    'DnnlExecutionProvider',          # Intel's high-performance DNNL for CPUs
    'XnnpackExecutionProvider',       # Optimized for ARM CPUs

    # ==> Tier 5: Advanced Compiler & Cloud Providers
    'TvmExecutionProvider',           # Apache TVM
    'AzureExecutionProvider',         # For running on Azure

    # ==> Tier 6: Default Fallback
    'CPUExecutionProvider',           # The universal fallback that runs on any CPU
]

# Request CUDA then CPU, but only the ones this onnxruntime build offers, and
# report which providers the session actually ended up with.
available_providers = ort.get_available_providers()
session_providers = [
    name for name in ('CUDAExecutionProvider', 'CPUExecutionProvider')
    if name in available_providers
]
session = ort.InferenceSession(str(onnx_model_path), providers=session_providers)
print(f"ONNX Runtime providers in use: {session.get_providers()}")

ort_inputs = {session.get_inputs()[0].name: test_image_np}
_ = session.run(None, ort_inputs)  # untimed warm-up run
start_time = time.perf_counter()
for _ in range(100):
    _ = session.run(None, ort_inputs)
onnx_time = (time.perf_counter() - start_time) / 100

print(f"ONNX Runtime Inference Time: {onnx_time * 1000:.4f} ms per image")

if pytorch_time > 0:
    speed_increase = (pytorch_time - onnx_time) / pytorch_time * 100

    # A negative value means ONNX Runtime was slower; say so rather than
    # printing a negative "faster" percentage.
    if speed_increase >= 0:
        print(f"\nResult: ONNX Runtime is ~{speed_increase:.2f}% faster than PyTorch for this model.")
    else:
        print(f"\nResult: ONNX Runtime is ~{-speed_increase:.2f}% slower than PyTorch for this model.")