In [49]:
import os
import sys
import argparse
import time
import math
import random
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

# Add the current directory to path to import local modules
sys.path.append(os.getcwd())

from model.architecture import IMDN
from data.custom_dataset import ThermalDataset

In [50]:
# Configuration - set these parameters directly
SCALE = 2
DATASET_DIR = f'/home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x{SCALE}'  # Auto-update based on scale
PRETRAINED_MODEL_DIR = f'/home/kronbii/repos/thermal-super-resolution/checkpoints/pretrained/IMDN_x{SCALE}.pth'

# Training parameters
EPOCHS = 60
BATCH_SIZE = 16
LR = 2e-5
WEIGHT_DECAY = 1e-4

# Training strategy
GRADUAL_UNFREEZE = True
FREEZE_EPOCHS = 10

# Loss function weights
L1_WEIGHT = 1.0
GRADIENT_WEIGHT = 0.08
THERMAL_WEIGHT = 0.03

# System settings
NUM_WORKERS = 6
DEVICE = 'cuda'
MIXED_PRECISION = True

# Memory optimization settings
PATCH_SIZE = 192
GRADIENT_ACCUMULATION_STEPS = 3

# Output settings
CHECKPOINT_DIR = f'checkpoints/_x{SCALE}'  # Auto-update based on scale
LOG_INTERVAL = 50
VAL_INTERVAL = 5

# Other settings
SEED = 42

print(f"🎯 Configuration for {SCALE}x upscaling:")
print(f"   📁 Dataset: {DATASET_DIR}")
print(f"   🏋️ Pretrained: {PRETRAINED_MODEL_DIR}")
print(f"   💾 Checkpoints: {CHECKPOINT_DIR}")
print(f"   🧠 Memory optimized: Batch={BATCH_SIZE}, Patch={PATCH_SIZE}, Accumulation={GRADIENT_ACCUMULATION_STEPS}")
print()

🎯 Configuration for 2x upscaling:
   📁 Dataset: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2
   🏋️ Pretrained: /home/kronbii/repos/thermal-super-resolution/checkpoints/pretrained/IMDN_x2.pth
   💾 Checkpoints: checkpoints/_x2
   🧠 Memory optimized: Batch=16, Patch=192, Accumulation=3



In [51]:
# Replace your current set_random_seed function with this optimized version
def set_random_seed(seed=42):
    """Set random seed with performance optimizations"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Speed optimizations
    torch.backends.cudnn.deterministic = False  # Much faster
    torch.backends.cudnn.benchmark = True       # Auto-optimize kernels
    torch.backends.cuda.matmul.allow_tf32 = True    # RTX 3070 acceleration
    torch.backends.cudnn.allow_tf32 = True          # RTX 3070 acceleration

In [52]:
# Optimize your ThermalLoss class by pre-registering kernels
class ThermalLoss(nn.Module):
    def __init__(self, l1_weight=1.0, gradient_weight=0.1, thermal_weight=0.05):
        super(ThermalLoss, self).__init__()
        self.l1_weight = l1_weight
        self.gradient_weight = gradient_weight
        self.thermal_weight = thermal_weight
        self.l1_loss = nn.L1Loss()
        
        # Pre-register Sobel kernels as buffers (major speed boost)
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', sobel_x)
        self.register_buffer('sobel_y', sobel_y)
        
        # Pre-register averaging kernel
        kernel = torch.ones(1, 1, 3, 3) / 9.0
        self.register_buffer('avg_kernel', kernel)
    
    def gradient_loss(self, pred, target):
        """Calculate gradient loss using pre-registered kernels"""
        pred_grad_x = F.conv2d(pred, self.sobel_x, padding=1)
        pred_grad_y = F.conv2d(pred, self.sobel_y, padding=1)
        target_grad_x = F.conv2d(target, self.sobel_x, padding=1)
        target_grad_y = F.conv2d(target, self.sobel_y, padding=1)
        
        grad_loss = self.l1_loss(pred_grad_x, target_grad_x) + self.l1_loss(pred_grad_y, target_grad_y)
        return grad_loss
    
    def thermal_contrast_loss(self, pred, target):
        """Loss using pre-registered averaging kernel"""
        pred_mean = F.conv2d(pred, self.avg_kernel, padding=1)
        target_mean = F.conv2d(target, self.avg_kernel, padding=1)
        
        pred_var = F.conv2d((pred - pred_mean)**2, self.avg_kernel, padding=1)
        target_var = F.conv2d((target - target_mean)**2, self.avg_kernel, padding=1)
        
        contrast_loss = self.l1_loss(pred_var, target_var)
        return contrast_loss
    
    def forward(self, pred, target):
        l1 = self.l1_loss(pred, target)
        grad = self.gradient_loss(pred, target)
        thermal = self.thermal_contrast_loss(pred, target)
        
        total_loss = (self.l1_weight * l1 + 
                     self.gradient_weight * grad + 
                     self.thermal_weight * thermal)
        
        return total_loss, {'l1': l1.item(), 'gradient': grad.item(), 'thermal': thermal.item()}

In [53]:
def freeze_layers(model, freeze_backbone=True):
  """Freeze/unfreeze model layers for gradual training"""
  for name, param in model.named_parameters():
    if freeze_backbone and not any(layer in name.lower() for layer in ['upsampler', 'lr_conv', 'fea_conv']):
      param.requires_grad = False
    else:
      param.requires_grad = True

In [54]:
def calculate_psnr(img1, img2, max_val=1.0):
  """Calculate PSNR between two images"""
  mse = torch.mean((img1 - img2)**2)
  if mse == 0:
    return float('inf')
  return 20 * torch.log10(max_val / torch.sqrt(mse))

In [55]:
def validate_model(model, val_loader, criterion, device, max_batches=50):
  """Validate the model on validation set"""
  model.eval()
  total_loss = 0
  total_psnr = 0
  num_batches = 0
  
  with torch.no_grad():
    for batch_idx, (lr, hr) in enumerate(val_loader):
      if batch_idx >= max_batches:
        break
        
      lr, hr = lr.to(device), hr.to(device)
      
      # Forward pass
      with autocast():
        sr = model(lr)
        loss, loss_components = criterion(sr, hr)
      
      # Calculate PSNR
      psnr = calculate_psnr(sr, hr)
      
      total_loss += loss.item()
      total_psnr += psnr.item()
      num_batches += 1
  
  avg_loss = total_loss / num_batches
  avg_psnr = total_psnr / num_batches
  
  return avg_loss, avg_psnr

In [56]:
def save_checkpoint(model, optimizer, epoch, loss, psnr, checkpoint_dir, is_best=False):
  """Save model checkpoint"""
  checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'psnr': psnr,
  }
  
  # Save regular checkpoint
  checkpoint_path = os.path.join(checkpoint_dir, f'thermal_epoch_{epoch}.pth')
  torch.save(checkpoint, checkpoint_path)
  
  # Save best model
  if is_best:
    best_path = os.path.join(checkpoint_dir, 'thermal_best.pth')
    torch.save(checkpoint, best_path)
    print(f"💫 New best model saved! PSNR: {psnr:.2f}")

In [57]:
def print_model_info(model, sample_input):
  """Print model information"""
  model.eval()
  with torch.no_grad():
    output = model(sample_input)
  
  total_params = sum(p.numel() for p in model.parameters())
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  
  print(f"📊 Model Information:")
  print(f"   • Total parameters: {total_params:,}")
  print(f"   • Trainable parameters: {trainable_params:,}")
  print(f"   • Input shape: {sample_input.shape}")
  print(f"   • Output shape: {output.shape}")
  print()

In [58]:
# Debug dataset structure before running training
print("🔍 Debugging dataset paths...")
print(f"DATASET_DIR: {DATASET_DIR}")
print()

# Check expected paths
expected_paths = [
    os.path.join(DATASET_DIR, 'train', 'HR'),
    os.path.join(DATASET_DIR, 'train', f'LR_bicubic', f'X{SCALE}'),
    os.path.join(DATASET_DIR, 'val', 'HR'),
    os.path.join(DATASET_DIR, 'val', f'LR_bicubic', f'X{SCALE}')
]

print("📁 Expected directory structure:")
for path in expected_paths:
    exists = os.path.exists(path)
    if exists:
        file_count = len([f for f in os.listdir(path) if f.endswith('.png')])
        print(f"   ✅ {path} - {file_count} files")
    else:
        print(f"   ❌ {path} - NOT FOUND")

print()

# Show actual directory structure
if os.path.exists(DATASET_DIR):
    print("📂 Actual directory structure:")
    for root, dirs, files in os.walk(DATASET_DIR):
        level = root.replace(DATASET_DIR, '').count(os.sep)
        indent = '   ' * level
        print(f"{indent}{os.path.basename(root)}/")
        
        # Show image files count
        image_files = [f for f in files if f.endswith(('.png', '.jpg'))]
        if image_files:
            subindent = '   ' * (level + 1)
            print(f"{subindent}({len(image_files)} image files)")
        
        if level >= 3:  # Don't go too deep
            break
else:
    print(f"❌ Dataset directory doesn't exist: {DATASET_DIR}")

🔍 Debugging dataset paths...
DATASET_DIR: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2

📁 Expected directory structure:
   ✅ /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/HR - 10697 files
   ✅ /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/LR_bicubic/X2 - 10697 files
   ✅ /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/val/HR - 1189 files
   ✅ /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/val/LR_bicubic/X2 - 1189 files

📂 Actual directory structure:
flir_thermal_x2/
   val/
      HR/
         (1189 image files)
      LR_bicubic/
         X2/
            (1189 image files)


In [59]:
# Debug ThermalDataset loading issue
print("🔧 Debugging ThermalDataset loading...")

# Test dataset initialization step by step
train_hr_dir = os.path.join(DATASET_DIR, 'train', 'HR')
train_lr_dir = os.path.join(DATASET_DIR, 'train', f'LR_bicubic', f'X{SCALE}')

print(f"Train HR dir: {train_hr_dir}")
print(f"Train LR dir: {train_lr_dir}")
print(f"Directories exist: HR={os.path.exists(train_hr_dir)}, LR={os.path.exists(train_lr_dir)}")

# Test TrainOpt class
class TestTrainOpt:
    def __init__(self):
        self.scale = SCALE
        self.phase = 'train'
        self.hr_dir = train_hr_dir
        self.lr_dir = train_lr_dir
        self.ext = '.png'
        self.augment = True
        self.thermal_augment = True
        self.patch_size = PATCH_SIZE
        self.n_colors = 1  
        self.rgb_range = 1
        self.batch_size = BATCH_SIZE
        self.test_every = 1000

print("🧪 Testing dataset creation...")
try:
    test_opt = TestTrainOpt()
    print(f"   ✅ TestTrainOpt created successfully")
    print(f"   • hr_dir: {test_opt.hr_dir}")
    print(f"   • lr_dir: {test_opt.lr_dir}")
    print(f"   • scale: {test_opt.scale}")
    print(f"   • phase: {test_opt.phase}")
    
    # Try creating dataset
    test_dataset = ThermalDataset(test_opt)
    print(f"   ✅ ThermalDataset created successfully")
    print(f"   • Dataset length: {len(test_dataset)}")
    
    # Check if dataset has images_hr attribute
    if hasattr(test_dataset, 'images_hr'):
        print(f"   • images_hr length: {len(test_dataset.images_hr)}")
        if test_dataset.images_hr:
            print(f"   • Sample HR files: {test_dataset.images_hr[:3]}")
    
    if hasattr(test_dataset, 'images_lr'):
        print(f"   • images_lr length: {len(test_dataset.images_lr) if test_dataset.images_lr else 'None'}")
        if test_dataset.images_lr:
            print(f"   • Sample LR files: {test_dataset.images_lr[:3]}")
    
except Exception as e:
    print(f"   ❌ Error creating ThermalDataset: {e}")
    import traceback
    traceback.print_exc()

🔧 Debugging ThermalDataset loading...
Train HR dir: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/HR
Train LR dir: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/LR_bicubic/X2
Directories exist: HR=True, LR=True
🧪 Testing dataset creation...
   ✅ TestTrainOpt created successfully
   • hr_dir: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/HR
   • lr_dir: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/LR_bicubic/X2
   • scale: 2
   • phase: train
Loaded 10697 thermal images for training
   ✅ ThermalDataset created successfully
   • Dataset length: 10697
   • images_hr length: 10697
   • Sample HR files: ['/home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/HR/video-24ysbPEGoEKKDvRt6-frame-000000-4C4FHWxwNaMyohLZt.png', '/home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2/train/HR/video-24ysbPEGoEKKDvRt6-frame-000015-ceXK8kdaSPB6oj

In [None]:
# Set random seed
set_random_seed(SEED)

# Setup device
if DEVICE == 'auto':
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
	device = torch.device(DEVICE)

print("=" * 60)
# Check if CUDA setup was successful
if device.type == 'cuda':
	print(f"[INFO] CUDA device selected: {torch.cuda.get_device_name(device)}")
	print(f"[INFO] CUDA memory available: {torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB")
else:
	print("[WARNING] Using CPU - training will be significantly slower")

print("[INFO] STARTING OPTIMIZATION PROCESS")
print(f"[INFO] Scale factor: {SCALE}x")
print(f"[INFO] Dataset: {DATASET_DIR}")
print(f"[INFO] Pretrained model: {PRETRAINED_MODEL_DIR}")
print("=" * 60)

# Create checkpoint directory
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Setup datasets
print("[INFO] Loading thermal dataset...")

# Training dataset
train_hr_dir = os.path.join(DATASET_DIR, 'train', 'HR')
train_lr_dir = os.path.join(DATASET_DIR, 'train', f'LR_bicubic', f'X{SCALE}')

# Create options object for ThermalDataset with FIXED repeat calculation
class TrainOpt:
	def __init__(self):
		self.scale = SCALE
		self.phase = 'train'
		self.hr_dir = train_hr_dir
		self.lr_dir = train_lr_dir
		self.ext = '.png'
		self.augment = True
		self.thermal_augment = True
		self.patch_size = PATCH_SIZE
		self.n_colors = 1
		self.rgb_range = 1
		self.batch_size = BATCH_SIZE
		self.test_every = max(1000, len(os.listdir(train_hr_dir)))  # Fix repeat calculation

train_dataset = ThermalDataset(TrainOpt())

# Validation dataset  
val_hr_dir = os.path.join(DATASET_DIR, 'val', 'HR')
val_lr_dir = os.path.join(DATASET_DIR, 'val', f'LR_bicubic', f'X{SCALE}')

# Create options object for validation dataset
class ValOpt:
	def __init__(self):
		self.scale = SCALE
		self.phase = 'val'
		self.hr_dir = val_hr_dir
		self.lr_dir = val_lr_dir
		self.ext = '.png'
		self.augment = False
		self.thermal_augment = False
		self.patch_size = PATCH_SIZE
		self.n_colors = 1
		self.rgb_range = 1
		self.batch_size = BATCH_SIZE
		self.test_every = 1000

val_dataset = ThermalDataset(ValOpt())

print(f"   • Training samples: {len(train_dataset)}")
print(f"   • Validation samples: {len(val_dataset)}")

# Quick fix if training samples still 0
if len(train_dataset) == 0:
    print("[ERROR] Training dataset length is 0, fixing repeat calculation...")
    # Manually set repeat to 1 for training dataset
    train_dataset.repeat = 1
    print(f"   • Fixed training samples: {len(train_dataset)}")

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True,    # Keeps workers alive between epochs
    prefetch_factor=3,          # More aggressive prefetching
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2
)
print("=" * 60)

# Setup model
print("[INFO] Setting up IMDN model...")
model = IMDN(upscale=SCALE, in_nc=1, out_nc=1)  # Single channel for thermal

# Load pretrained weights
if os.path.exists(PRETRAINED_MODEL_DIR):
	print(f"[INFO] Loading pretrained weights from {PRETRAINED_MODEL_DIR}")
	try:
		checkpoint = torch.load(PRETRAINED_MODEL_DIR, map_location='cpu', weights_only=True)
	except:
		# Fallback for older PyTorch versions
		checkpoint = torch.load(PRETRAINED_MODEL_DIR, map_location='cpu')

	# Extract state dict if it's wrapped in a checkpoint
	if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
		state_dict = checkpoint['model_state_dict']
	else:
		state_dict = checkpoint

	# Remove 'module.' prefix if present (from DataParallel training)
	if any(key.startswith('module.') for key in state_dict.keys()):
		state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()}

	# Universal weight adaptation for any scale factor
	adapted_state_dict = {}
	target_upsampler_channels = SCALE * SCALE  # For thermal: 1 channel * scale^2
	
	for name, param in state_dict.items():
		if name == 'fea_conv.weight' and param.shape[1] == 3:  
			# RGB input layer -> thermal input layer (3->1 channel)
			adapted_param = param.mean(dim=1, keepdim=True)
			adapted_state_dict[name] = adapted_param
			print(f"   • Adapted {name}: {param.shape} -> {adapted_param.shape}")
		
		elif name.startswith('upsampler.') and 'weight' in name:
			# Handle upsampler weight - adapt from RGB to thermal
			source_channels = param.shape[0]  # RGB: 3 * source_scale^2
			
			if source_channels % 3 == 0:  # Confirm it's RGB (divisible by 3)
				source_scale_sq = source_channels // 3  # Get source scale^2
				
				if source_scale_sq == target_upsampler_channels:
					# Same scale: just average RGB channels to get thermal
					param_reshaped = param.view(3, target_upsampler_channels, param.shape[1], param.shape[2], param.shape[3])
					adapted_param = param_reshaped.mean(dim=0)
				else:
					# Different scales: need to adapt the scale
					if target_upsampler_channels <= source_scale_sq:
						# Target scale is smaller: downsample
						param_reshaped = param.view(3, source_scale_sq, param.shape[1], param.shape[2], param.shape[3])
						rgb_avg = param_reshaped.mean(dim=0)  # Average RGB channels first
						adapted_param = rgb_avg[:target_upsampler_channels]
					else:
						# Target scale is larger: upsample by repeating
						param_reshaped = param.view(3, source_scale_sq, param.shape[1], param.shape[2], param.shape[3])
						rgb_avg = param_reshaped.mean(dim=0)  # Average RGB channels first
						repeat_factor = target_upsampler_channels // source_scale_sq
						remainder = target_upsampler_channels % source_scale_sq
						adapted_param = rgb_avg.repeat(repeat_factor, 1, 1, 1)
						if remainder > 0:
							adapted_param = torch.cat([adapted_param, rgb_avg[:remainder]], dim=0)
				
				adapted_state_dict[name] = adapted_param
				print(f"   • Adapted {name}: {param.shape} -> {adapted_param.shape}")
			else:
				adapted_state_dict[name] = param
		
		elif name.startswith('upsampler.') and 'bias' in name:
			# Handle upsampler bias - adapt from RGB to thermal
			source_channels = param.shape[0]  # RGB: 3 * source_scale^2
			
			if source_channels % 3 == 0:  # Confirm it's RGB (divisible by 3)
				source_scale_sq = source_channels // 3  # Get source scale^2
				
				if source_scale_sq == target_upsampler_channels:
					# Same scale: just average RGB channels to get thermal
					param_reshaped = param.view(3, target_upsampler_channels)
					adapted_param = param_reshaped.mean(dim=0)
				else:
					# Different scales: need to adapt the scale
					if target_upsampler_channels <= source_scale_sq:
						# Target scale is smaller: downsample
						param_reshaped = param.view(3, source_scale_sq)
						rgb_avg = param_reshaped.mean(dim=0)  # Average RGB channels first
						adapted_param = rgb_avg[:target_upsampler_channels]
					else:
						# Target scale is larger: upsample by repeating
						param_reshaped = param.view(3, source_scale_sq)
						rgb_avg = param_reshaped.mean(dim=0)  # Average RGB channels first
						repeat_factor = target_upsampler_channels // source_scale_sq
						remainder = target_upsampler_channels % source_scale_sq
						adapted_param = rgb_avg.repeat(repeat_factor)
						if remainder > 0:
							adapted_param = torch.cat([adapted_param, rgb_avg[:remainder]], dim=0)
				
				adapted_state_dict[name] = adapted_param
				print(f"   • Adapted {name}: {param.shape} -> {adapted_param.shape}")
			else:
				adapted_state_dict[name] = param
		
		else:
			# All other layers: keep as is
			adapted_state_dict[name] = param

	# Load adapted weights
	missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
	if missing_keys:
		print(f"[ERROR] Missing keys: {missing_keys}")
	if unexpected_keys:
		print(f"[ERROR] Unexpected keys: {unexpected_keys}")
	
	print(f"[INFO] Successfully adapted pretrained model from RGB to thermal with {SCALE}x scaling")
else:
	print(f"[ERROR] Pretrained model not found at {PRETRAINED_MODEL_DIR}")
	print("[ERROR] Training from scratch (this will take much longer)")
print("=" * 60)

model = model.to(device)
model = model.to(memory_format=torch.channels_last)  # Memory layout optimization

# Print model info
sample_input = torch.randn(1, 1, 64, 64).to(device)
print_model_info(model, sample_input)

# Setup loss function
criterion = ThermalLoss(
	l1_weight=L1_WEIGHT,
	gradient_weight=GRADIENT_WEIGHT,
	thermal_weight=THERMAL_WEIGHT
).to(device)

# Setup optimizer
optimizer = optim.AdamW(
	model.parameters(),
	lr=LR,
	weight_decay=WEIGHT_DECAY,
	betas=(0.9, 0.999)
)

# Setup learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(
	optimizer,
	T_max=EPOCHS,
	eta_min=LR * 0.01
)

# Setup mixed precision training
scaler = GradScaler() if MIXED_PRECISION and device.type == 'cuda' else None

# Gradual unfreezing setup
if GRADUAL_UNFREEZE:
	print("[INFO] Starting with frozen backbone (gradual unfreezing enabled)")
	freeze_layers(model, freeze_backbone=True)

# Memory optimization
print("[INFO] Memory optimization enabled:")
print(f"   • Batch size: {BATCH_SIZE}")
print(f"   • Patch size: {PATCH_SIZE}")
print(f"   • Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"   • Mixed precision: {MIXED_PRECISION}")

# Training loop with gradient accumulation
print("[INFO] Starting training...")
print()

best_psnr = 0
start_time = time.time()

for epoch in range(EPOCHS):
	# Gradual unfreezing
	if GRADUAL_UNFREEZE and epoch == FREEZE_EPOCHS:
		print("[INFO] Unfreezing backbone layers")
		freeze_layers(model, freeze_backbone=False)
		# Reduce learning rate when unfreezing
		for param_group in optimizer.param_groups:
			param_group['lr'] *= 0.5

	model.train()
	epoch_loss = 0
	epoch_l1 = 0
	epoch_gradient = 0
	epoch_thermal = 0
	
	# Gradient accumulation setup
	optimizer.zero_grad()

	for batch_idx, (lr, hr) in enumerate(train_loader):
		lr, hr = lr.to(device), hr.to(device)
		
		# Forward pass with mixed precision
		if scaler is not None:
			with autocast():
				sr = model(lr)
				loss, loss_components = criterion(sr, hr)
				loss = loss / GRADIENT_ACCUMULATION_STEPS  # Scale loss for accumulation
		
			# Backward pass
			scaler.scale(loss).backward()
		else:
			sr = model(lr)
			loss, loss_components = criterion(sr, hr)
			loss = loss / GRADIENT_ACCUMULATION_STEPS  # Scale loss for accumulation
			loss.backward()
		
		# Accumulate losses (scale back for logging)
		epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
		epoch_l1 += loss_components['l1']
		epoch_gradient += loss_components['gradient']
		epoch_thermal += loss_components['thermal']
		
		# Update weights every GRADIENT_ACCUMULATION_STEPS
		if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
			if scaler is not None:
				scaler.step(optimizer)
				scaler.update()
			else:
				optimizer.step()
			optimizer.zero_grad()
			
			# Clear cache periodically
			if batch_idx % 20 == 0:
				torch.cuda.empty_cache()
		
		# Logging
		if batch_idx % LOG_INTERVAL == 0:
			progress = 100.0 * batch_idx / len(train_loader)
			current_lr = optimizer.param_groups[0]['lr']
			print(f"Epoch {epoch:3d} [{batch_idx:4d}/{len(train_loader)} ({progress:5.1f}%)] "
				  f"Loss: {loss.item() * GRADIENT_ACCUMULATION_STEPS:.6f} L1: {loss_components['l1']:.6f} "
				  f"Grad: {loss_components['gradient']:.6f} Thermal: {loss_components['thermal']:.6f} "
				  f"LR: {current_lr:.2e}")

	# Update learning rate
	scheduler.step()

	# Calculate epoch averages
	avg_loss = epoch_loss / len(train_loader)
	avg_l1 = epoch_l1 / len(train_loader)
	avg_gradient = epoch_gradient / len(train_loader)
	avg_thermal = epoch_thermal / len(train_loader)

	# Validation
	val_loss, val_psnr = 0, 0
	if epoch % VAL_INTERVAL == 0:
		val_loss, val_psnr = validate_model(model, val_loader, criterion, device)
		
		# Save checkpoint if best
		is_best = val_psnr > best_psnr
		if is_best:
			best_psnr = val_psnr
		
		save_checkpoint(model, optimizer, epoch, val_loss, val_psnr, CHECKPOINT_DIR, is_best)

	# Epoch summary
	elapsed = time.time() - start_time
	print(f"Epoch {epoch:3d} Summary: Loss={avg_loss:.6f} (L1:{avg_l1:.4f}, Grad:{avg_gradient:.4f}, Thermal:{avg_thermal:.4f}) "
		  f"Val_PSNR={val_psnr:.2f}dB Best={best_psnr:.2f}dB Time={elapsed/60:.1f}min")
	print("-" * 100)

# Final summary
total_time = time.time() - start_time
print()
print("[INFO] Training completed!")
print("=" * 60)
print(f"[INFO] Best validation PSNR: {best_psnr:.2f} dB")
print(f"[INFO] Total training time: {total_time/3600:.2f} hours")
print(f"[INFO] Best model saved at: {os.path.join(CHECKPOINT_DIR, 'thermal_best.pth')}")
print()
print("[INFO] Your thermal super-resolution model is ready!")

[INFO] CUDA device selected: NVIDIA GeForce RTX 3070 Laptop GPU
[INFO] CUDA memory available: 8.2 GB
[INFO] STARTING OPTIMIZATION PROCESS
[INFO] Scale factor: 2x
[INFO] Dataset: /home/kronbii/repos/thermal-super-resolution/datasets/flir_thermal_x2
[INFO] Pretrained model: /home/kronbii/repos/thermal-super-resolution/checkpoints/pretrained/IMDN_x2.pth
[INFO] Loading thermal dataset...
Loaded 10697 thermal images for training
Loaded 1189 thermal images for testing
   • Training samples: 171152
   • Validation samples: 1189
[INFO] Setting up IMDN model...
Loaded 10697 thermal images for training
Loaded 1189 thermal images for testing
   • Training samples: 171152
   • Validation samples: 1189
[INFO] Setting up IMDN model...
[INFO] Loading pretrained weights from /home/kronbii/repos/thermal-super-resolution/checkpoints/pretrained/IMDN_x2.pth
   • Adapted fea_conv.weight: torch.Size([64, 3, 3, 3]) -> torch.Size([64, 1, 3, 3])
   • Adapted upsampler.0.weight: torch.Size([12, 64, 3, 3]) -> to

  scaler = GradScaler() if MIXED_PRECISION and device.type == 'cuda' else None
  with autocast():
  with autocast():


