# Setup

In [2]:
pip install opencv-python natsort matplotlib thop torchsummary tensorboardX colorlog pytorch_msssim

Collecting opencv-python
  Using cached opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Using cached opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (67.0 MB)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.12.0.88
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
!pip install opencv-python torchvision natsort matplotlib thop torchsummary tensorboardX colorlog pytorch_msssim natsort

Collecting torchvision
  Using cached torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting natsort
  Using cached natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting thop
  Using cached thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Collecting tensorboardX
  Using cached tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Collecting colorlog
  Using cached colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting pytorch_msssim
  Using cached pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting torch==2.8.0 (from torchvision)
  Using cached torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting filelock (from torch==2.8.0->torchvision)
  Using cached filelock-3.19.1-py3

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import cv2

import natsort

# Report the installed PyTorch build.
print(f"PyTorch version: {torch.__version__}")


PyTorch version: 2.8.0+cu129


In [7]:
# Report the active CUDA device. torch.cuda.current_device() raises when no
# CUDA device is usable, so guard it to keep "Run All" working on CPU-only hosts.
if torch.cuda.is_available():
    current_device = torch.cuda.current_device()
    print("Current device:", current_device)
    print("Device name:", torch.cuda.get_device_name(current_device))
else:
    print("GPU is not available.")


Current device: 0
Device name: NVIDIA GeForce RTX 5090


## Config

In [8]:
import torch

# Check whether a GPU is available and, if so, describe every visible device.
if not torch.cuda.is_available():
    print("GPU is not available.")
else:
    print("GPU is available!")
    print("Number of GPUs:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        # Caching-allocator counters for device i, converted from bytes to MiB.
        allocated_mb = torch.cuda.memory_allocated(i) / 1024**2
        reserved_mb = torch.cuda.memory_reserved(i) / 1024**2
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"Memory Allocated: {allocated_mb:.2f} MB")
        print(f"Memory Reserved: {reserved_mb:.2f} MB")


GPU is available!
Number of GPUs: 1
GPU 0: NVIDIA GeForce RTX 5090
Memory Allocated: 0.00 MB
Memory Reserved: 0.00 MB


In [9]:
import multiprocessing
# Number of logical CPU cores visible to this process' host.
print(f"CPU cores: {multiprocessing.cpu_count()}")


CPU cores: 192


In [10]:
import torch

# Check for a GPU and summarise the memory of device 0.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    total_mem = torch.cuda.get_device_properties(device).total_memory
    print(f"GPU: {torch.cuda.get_device_name(device)}")
    print(f"T·ªïng b·ªô nh·ªõ GPU: {total_mem / 1e9:.2f} GB")

    # Current usage, in GB. Memory handed out to tensors ("allocated") is part
    # of the pool held by the caching allocator ("reserved"), so only the
    # reserved amount is subtracted; subtracting both would count it twice.
    reserved = torch.cuda.memory_reserved(device) / 1e9
    allocated = torch.cuda.memory_allocated(device) / 1e9
    free_mem = total_mem / 1e9 - reserved
    print(f"ƒê√£ c·∫•p ph√°t: {allocated:.2f} GB, Reserved: {reserved:.2f} GB, Free: {free_mem:.2f} GB")
else:
    print("Kh√¥ng c√≥ GPU")


GPU: NVIDIA GeForce RTX 5090
T·ªïng b·ªô nh·ªõ GPU: 33.67 GB
ƒê√£ c·∫•p ph√°t: 0.00 GB, Reserved: 0.00 GB, Free: 33.67 GB


In [None]:
# -*- coding: utf-8 -*-
# @Time    : 2018/6/11 15:54
# @Author  : zhoujun
import torch
import torch.utils.data as Data
from torchvision import transforms
from invoice_dataset import ImageData
from docunet_model import TinyDocUnet
import time
import config
from tensorboardX import SummaryWriter
from docunet_loss import DocUnetLoss_DL_batch as DocUnetLoss
import os
import shutil
import json
from collections import defaultdict
from torch.amp import GradScaler, autocast
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from eval_scores import evaluate_batch

# Turn on cuDNN benchmark mode for the whole process (set once at cell load).
torch.backends.cudnn.benchmark = True

def save_checkpoint(checkpoint_path, model, optimizer, epoch, scaler=None, metrics=None):
    """Serialize the training state to ``checkpoint_path`` with ``torch.save``.

    The saved dict holds the model and optimizer state dicts, the epoch
    number and a metrics dict (empty when ``metrics`` is falsy). A
    ``'scaler'`` entry is added only when a scaler is supplied.
    """
    payload = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
        metrics=metrics if metrics else {},
    )
    if scaler is not None:
        payload['scaler'] = scaler.state_dict()
    torch.save(payload, checkpoint_path)

def load_checkpoint(checkpoint_path, model, optimizer, scaler=None):
    """Restore model (and optionally scaler) state from a checkpoint file.

    ``optimizer`` is accepted but its state is not restored here.

    Returns:
        ``(epoch, metrics)`` as stored in the checkpoint, or ``(0, {})`` when
        anything goes wrong while loading (the error is printed, not raised).
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        resume_epoch = checkpoint['epoch']
        saved_metrics = checkpoint.get('metrics', {})
        has_scaler_state = 'scaler' in checkpoint
        if scaler is not None and has_scaler_state:
            scaler.load_state_dict(checkpoint['scaler'])
        print(f'Loaded checkpoint from epoch {resume_epoch}')
        return resume_epoch, saved_metrics
    except Exception as err:
        print(f'Error loading checkpoint: {err}')
        return 0, {}

def validate_epoch(net, val_loader, criterion, device, use_amp=True):
    """Run one validation pass and return averaged loss and scores.

    Args:
        net: model; called as ``net(images)`` and expected to return a pair
            whose second element is the prediction.
        val_loader: iterable of ``(images, targets)`` batches.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar tensor.
        device: device the batches are moved to.
        use_amp: run the forward pass under CUDA autocast when CUDA is available.

    Returns:
        ``(avg_val_loss, avg_ms_ssim, avg_ad)`` averaged over the batches whose
        loss and scores are all finite, or ``(inf, 0.0, inf)`` when no batch
        qualifies.
    """
    net.eval()
    total_ms_ssim = 0
    total_ad = 0
    total_val_loss = 0
    num_batches = 0
    skipped_batches = 0
    last_error = None
    amp_enabled = use_amp and torch.cuda.is_available()

    with torch.no_grad():
        for images, targets in val_loader:
            try:
                images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)

                if amp_enabled:
                    with autocast(device_type='cuda'):
                        _, outputs = net(images)
                        val_loss = criterion(outputs, targets)
                    # Back to float32 before scoring.
                    outputs = outputs.float()
                    val_loss = val_loss.float()
                else:
                    _, outputs = net(images)
                    val_loss = criterion(outputs, targets)

                targets = targets.float()

                if not torch.isfinite(val_loss):
                    skipped_batches += 1
                    continue

                # Evaluate batch
                ms_ssim_score, ad_score = evaluate_batch(outputs, targets)

                if not (np.isfinite(ms_ssim_score) and np.isfinite(ad_score)):
                    skipped_batches += 1
                    continue

                # Accumulate only once the batch is fully valid, so every
                # total is divided by the same batch count below.
                total_val_loss += val_loss.item()
                total_ms_ssim += ms_ssim_score
                total_ad += ad_score
                num_batches += 1

            except Exception as e:
                # Best-effort: skip the failing batch, but let
                # KeyboardInterrupt/SystemExit propagate.
                skipped_batches += 1
                last_error = e
                continue

    if skipped_batches:
        detail = f' (last error: {last_error})' if last_error is not None else ''
        print(f'Validation: skipped {skipped_batches} batch(es){detail}')

    if num_batches == 0:
        return float('inf'), 0.0, float('inf')

    return total_val_loss / num_batches, total_ms_ssim / num_batches, total_ad / num_batches

class TrainingMetrics:
    """Track per-epoch metrics, remember the best values, and export them.

    ``metrics`` maps a metric name to the list of finite values recorded for
    it (plus ``'epoch'``, which gets one entry per ``update`` call).
    ``metric_epochs`` maps the same names to the epochs at which those values
    were recorded, so sparsely recorded metrics stay aligned with their epoch.
    """

    # (kwarg name, key in best_metrics, True when lower values are better)
    _BEST_RULES = (
        ('train_loss', 'best_train_loss', True),
        ('val_loss', 'best_val_loss', True),
        ('val_ms_ssim', 'best_ms_ssim', False),
        ('val_ad', 'best_ad', True),
    )

    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.metrics = defaultdict(list)
        self.metric_epochs = defaultdict(list)
        self.best_metrics = {
            'best_train_loss': float('inf'),
            'best_train_loss_epoch': 0,
            'best_val_loss': float('inf'),
            'best_val_loss_epoch': 0,
            'best_ms_ssim': 0.0,
            'best_ms_ssim_epoch': 0,
            'best_ad': float('inf'),
            'best_ad_epoch': 0,
        }

    def update(self, epoch, **kwargs):
        """Record the metrics of one epoch and refresh the best values.

        ``None`` and non-finite values are not stored. Stored values are
        converted to plain ``float`` so they remain JSON-serializable.
        """
        self.metrics['epoch'].append(epoch)
        for key, value in kwargs.items():
            if value is not None and np.isfinite(value):
                self.metrics[key].append(float(value))
                self.metric_epochs[key].append(epoch)

        # Update best metrics (NaN never compares as an improvement).
        for key, best_key, lower_is_better in self._BEST_RULES:
            value = kwargs.get(key)
            if value is None:
                continue
            best = self.best_metrics[best_key]
            improved = value < best if lower_is_better else value > best
            if improved:
                self.best_metrics[best_key] = float(value)
                self.best_metrics[best_key + '_epoch'] = epoch

    def _series(self, key):
        """Return ``(epochs, values)`` for one metric, paired entry by entry."""
        pairs = [
            (e, v)
            for e, v in zip(self.metric_epochs.get(key, []), self.metrics.get(key, []))
            if v is not None
        ]
        return [e for e, _ in pairs], [v for _, v in pairs]

    def save_metrics(self):
        """Write history, per-metric epochs and best values to training_metrics.json."""
        metrics_file = os.path.join(self.output_dir, 'training_metrics.json')
        all_metrics = {
            'training_history': dict(self.metrics),
            'metric_epochs': dict(self.metric_epochs),
            'best_metrics': self.best_metrics,
            'timestamp': datetime.now().isoformat()
        }
        with open(metrics_file, 'w') as f:
            json.dump(all_metrics, f, indent=2)

    def plot_training_curves(self):
        """Save a 2x2 figure (loss, learning rate, MS-SSIM, AD) to training_curves.png.

        Plotting is best-effort: a failure is reported on stdout instead of
        interrupting training, and the figure is always closed.
        """
        fig = None
        try:
            fig, axes = plt.subplots(2, 2, figsize=(12, 10))

            # Training vs validation loss
            train_epochs, train_values = self._series('train_loss')
            if train_values:
                axes[0,0].plot(train_epochs, train_values, label='Train', color='blue')
                val_epochs, val_values = self._series('val_loss')
                if val_values:
                    axes[0,0].plot(val_epochs, val_values, label='Val', color='red')
                axes[0,0].set_title('Loss')
                axes[0,0].set_xlabel('Epoch')
                axes[0,0].set_ylabel('Loss')
                axes[0,0].legend()
                axes[0,0].grid(True)

            # Learning rate
            lr_epochs, lr_values = self._series('learning_rate')
            if lr_values:
                axes[0,1].plot(lr_epochs, lr_values)
                axes[0,1].set_title('Learning Rate')
                axes[0,1].set_xlabel('Epoch')
                axes[0,1].set_yscale('log')
                axes[0,1].grid(True)

            # MS-SSIM
            epochs, values = self._series('val_ms_ssim')
            if values:
                axes[1,0].plot(epochs, values, color='green')
                axes[1,0].set_title('MS-SSIM')
                axes[1,0].set_xlabel('Epoch')
                axes[1,0].grid(True)

            # AD
            epochs, values = self._series('val_ad')
            if values:
                axes[1,1].plot(epochs, values, color='orange')
                axes[1,1].set_title('AD')
                axes[1,1].set_xlabel('Epoch')
                axes[1,1].grid(True)

            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, 'training_curves.png'), dpi=150)
        except Exception as e:
            print(f'Could not plot training curves: {e}')
        finally:
            if fig is not None:
                plt.close(fig)

def train():
    """Train TinyDocUnet according to the module-level ``config``.

    Builds the data loaders, model, loss, AdamW optimizer and a
    CosineAnnealingWarmRestarts scheduler, optionally resumes from
    ``config.checkpoint``, then trains up to ``config.epochs``. Validation
    runs every 5 epochs; checkpoints, metrics JSON and plots are written
    every 10 epochs and at the end. On interruption or error a checkpoint
    and the metrics are saved before returning / re-raising.
    """
    VAL_EVERY = 5     # epochs between validation passes
    SAVE_EVERY = 10   # epochs between periodic checkpoints / metric dumps

    # Only restrict the visible GPUs when one is configured; otherwise the
    # literal string 'None' would be written into the environment.
    if config.gpu_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
    if config.output_dir is None:
        config.output_dir = 'output'
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    # makedirs also handles nested output paths.
    os.makedirs(config.output_dir, exist_ok=True)

    metrics_tracker = TrainingMetrics(config.output_dir)

    # Minimal system info
    print("=== TRAINING START ===")
    print(f"PyTorch: {torch.__version__} | CUDA: {torch.cuda.is_available()}")
    print(f"Epochs: {config.epochs} | Batch: {config.train_batch_size} | LR: {config.lr}")

    # Device setup
    use_cuda = config.gpu_id is not None and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    torch.manual_seed(config.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(config.seed)

    # Data loading
    train_data = ImageData(config.trainroot, transform=transforms.ToTensor(), t_transform=transforms.ToTensor())
    train_loader = Data.DataLoader(
        dataset=train_data,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=int(config.workers),
        pin_memory=use_cuda,
        persistent_workers=int(config.workers) > 0
    )

    test_data = ImageData(config.testroot, transform=transforms.ToTensor(), t_transform=transforms.ToTensor())
    test_loader = Data.DataLoader(
        dataset=test_data,
        batch_size=1,
        shuffle=False,
        num_workers=3,
        pin_memory=use_cuda
    )

    print(f"Train: {len(train_data)} | Val: {len(test_data)}")

    # Model setup
    writer = SummaryWriter(config.output_dir)
    net = TinyDocUnet(input_channels=3, n_classes=2).to(device)
    criterion = DocUnetLoss(reduction='mean')
    optimizer = torch.optim.AdamW(net.parameters(), lr=config.lr, weight_decay=1e-4)

    # Mixed precision (only meaningful when training on the CUDA device)
    use_amp = getattr(config, 'use_amp', True) and use_cuda
    scaler = GradScaler() if use_amp else None
    grad_clip = getattr(config, 'grad_clip', 1.0)
    accumulation_steps = getattr(config, 'accumulation_steps', 1)

    # Resume point
    if config.checkpoint != '' and not config.restart_training:
        start_epoch, _ = load_checkpoint(config.checkpoint, net, optimizer, scaler)
    else:
        start_epoch = config.start_epoch

    # Build the scheduler fresh and jump to the resume epoch explicitly.
    # Passing last_epoch != -1 to the constructor requires 'initial_lr' in the
    # optimizer's param groups, which a newly created optimizer does not have.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=1e-6
    )
    if start_epoch > 0:
        scheduler.step(start_epoch)

    # Defined up front so the handlers and final save work even if the loop
    # body never runs (e.g. start_epoch >= config.epochs).
    epoch = start_epoch - 1
    training_start = time.time()

    try:
        for epoch in range(start_epoch, config.epochs):
            net.train()
            train_loss = 0.
            epoch_start = time.time()

            optimizer.zero_grad()

            for i, (images, labels) in enumerate(train_loader):
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                # Forward (loss is pre-divided for gradient accumulation)
                if use_amp:
                    with autocast(device_type='cuda'):
                        _, y = net(images)
                        loss = criterion(y, labels) / accumulation_steps
                else:
                    _, y = net(images)
                    loss = criterion(y, labels) / accumulation_steps

                # Backward
                if use_amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                # Undo the accumulation scaling for reporting.
                train_loss += loss.item() * accumulation_steps

                # Update weights at the end of each accumulation window and on the last batch
                if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                    if use_amp:
                        if grad_clip > 0:
                            scaler.unscale_(optimizer)
                            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        if grad_clip > 0:
                            torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
                        optimizer.step()
                    optimizer.zero_grad()

            # Read the LR that was in effect for this epoch before advancing the schedule.
            current_lr = scheduler.get_last_lr()[0]
            scheduler.step()

            # Epoch summary
            epoch_time = time.time() - epoch_start
            avg_train_loss = train_loss / max(1, len(train_loader))

            # Validation - every VAL_EVERY epochs
            val_loss, val_ms_ssim, val_ad = None, None, None
            run_validation = (epoch + 1) % VAL_EVERY == 0
            if run_validation:
                val_loss, val_ms_ssim, val_ad = validate_epoch(net, test_loader, criterion, device, use_amp)
                snapshot = {'train_loss': avg_train_loss, 'val_loss': val_loss, 'ms_ssim': val_ms_ssim, 'ad': val_ad}

                # Save best models (compared against the bests before this epoch is recorded)
                if val_loss < metrics_tracker.best_metrics['best_val_loss']:
                    save_checkpoint(f'{config.output_dir}/best_val_loss.pth', net, optimizer, epoch + 1, scaler, snapshot)

                if val_ms_ssim > metrics_tracker.best_metrics['best_ms_ssim']:
                    save_checkpoint(f'{config.output_dir}/best_ms_ssim.pth', net, optimizer, epoch + 1, scaler, snapshot)

                if val_ad < metrics_tracker.best_metrics['best_ad']:
                    save_checkpoint(f'{config.output_dir}/best_ad.pth', net, optimizer, epoch + 1, scaler, snapshot)

            # Record this epoch before printing so the "Best" line includes it.
            metrics_tracker.update(
                epoch=epoch + 1,
                train_loss=avg_train_loss,
                val_loss=val_loss,
                learning_rate=current_lr,
                val_ms_ssim=val_ms_ssim,
                val_ad=val_ad
            )

            if run_validation:
                # Print summary on validation epochs
                print(f"\n{'='*60}")
                print(f"Epoch [{epoch+1}/{config.epochs}] - Time: {epoch_time:.1f}s")
                print(f"Train Loss: {avg_train_loss:.6f} | LR: {current_lr:.2e}")
                print(f"Val Loss: {val_loss:.6f} | MS-SSIM: {val_ms_ssim:.4f} | AD: {val_ad:.4f}")
                print(f"Best - ValLoss: {metrics_tracker.best_metrics['best_val_loss']:.6f} "
                      f"(E{metrics_tracker.best_metrics['best_val_loss_epoch']}) | "
                      f"MS-SSIM: {metrics_tracker.best_metrics['best_ms_ssim']:.4f} "
                      f"(E{metrics_tracker.best_metrics['best_ms_ssim_epoch']})")
                print(f"{'='*60}\n")

                # TensorBoard - minimal logging
                writer.add_scalar('Val/Loss', val_loss, epoch)
                writer.add_scalar('Val/MS-SSIM', val_ms_ssim, epoch)
                writer.add_scalar('Val/AD', val_ad, epoch)
            else:
                # Brief single-line progress only
                print(f"E{epoch+1:03d}/{config.epochs} | Loss: {avg_train_loss:.6f} | Time: {epoch_time:.1f}s", end='\r')

            # TensorBoard - minimal
            writer.add_scalar('Train/loss', avg_train_loss, epoch)
            writer.add_scalar('Train/lr', current_lr, epoch)

            # Periodic checkpoint, metrics dump and plots
            if (epoch + 1) % SAVE_EVERY == 0:
                save_checkpoint(f'{config.output_dir}/checkpoint_e{epoch+1:03d}.pth',
                              net, optimizer, epoch + 1, scaler,
                              {'train_loss': avg_train_loss, 'val_loss': val_loss})
                metrics_tracker.save_metrics()
                metrics_tracker.plot_training_curves()

            # Memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Training completed
        total_time = time.time() - training_start
        epochs_run = max(1, config.epochs - start_epoch)
        print(f"\n{'='*60}")
        print("üéâ TRAINING COMPLETED!")
        print(f"Total time: {total_time/3600:.2f}h | Avg/epoch: {total_time/epochs_run:.1f}s")
        print(f"Best Val Loss: {metrics_tracker.best_metrics['best_val_loss']:.6f} (E{metrics_tracker.best_metrics['best_val_loss_epoch']})")
        print(f"Best MS-SSIM: {metrics_tracker.best_metrics['best_ms_ssim']:.4f} (E{metrics_tracker.best_metrics['best_ms_ssim_epoch']})")
        print(f"{'='*60}")

        metrics_tracker.save_metrics()
        metrics_tracker.plot_training_curves()
        save_checkpoint(f'{config.output_dir}/final_model.pth', net, optimizer, epoch + 1, scaler)

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Training interrupted")
        save_checkpoint(f'{config.output_dir}/interrupted_e{epoch+1}.pth', net, optimizer, epoch + 1, scaler)
        metrics_tracker.save_metrics()

    except Exception as e:
        print(f"\n‚ùå Error: {e}")
        save_checkpoint(f'{config.output_dir}/error_e{epoch+1}.pth', net, optimizer, epoch + 1, scaler)
        metrics_tracker.save_metrics()
        raise

    finally:
        # Close the TensorBoard writer on every exit path.
        writer.close()

# Start training when this cell/module is executed as the main program.
if __name__ == '__main__':
    train()

=== TRAINING START ===
PyTorch: 2.8.0+cu129 | CUDA: True
Epochs: 100 | Batch: 8 | LR: 0.0001
Checking folder path: train_gen/images
Checking folder path: test_gen/images
Train: 5000 | Val: 1002


# Train

In [1]:
# -*- coding: utf-8 -*-
# @Time    : 2018/6/11 15:54
# @Author  : zhoujun
import torch
import torch.utils.data as Data
from torchvision import transforms
from invoice_dataset import ImageData
from docunet_model_c import TinyDocUnet
import time
import config
from tensorboardX import SummaryWriter
from docunet_loss import DocUnetLoss_DL_batch as DocUnetLoss
import os
import shutil
import json
from collections import defaultdict
from torch.amp import GradScaler, autocast
import logging
from colorlog import ColoredFormatter
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from eval_scores import evaluate_batch

# Turn on cuDNN benchmark mode for the whole process (set once at cell load).
torch.backends.cudnn.benchmark = True
# NOTE(review): sets the debug flag on the private torch._inductor config
# module; no torch.compile call is visible in this part of the notebook, so
# this looks like leftover debugging — confirm it is still needed.
torch._inductor.config.debug = 1

def setup_logger(log_file_path: str = None):
    """Configure and return the ``'project'`` logger with console and optional file output.

    Any handlers already attached to the logger are removed and closed first,
    so calling this again (e.g. re-running the notebook cell) neither
    duplicates output nor leaks the previous log file handle.

    Args:
        log_file_path: when truthy, also log to this file in plain text.

    Returns:
        The configured logger, set to DEBUG level.
    """
    logger = logging.getLogger('project')
    # Detach and close existing handlers; clear() alone would leave file
    # handlers open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # File handler
    if log_file_path:
        file_handler = logging.FileHandler(log_file_path)
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)-8s %(filename)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    # Console handler with colours
    console_formatter = ColoredFormatter(
        "%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s: %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S',
        reset=True,
        log_colors={
            'DEBUG': 'blue',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red',
        }
    )
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    logger.setLevel(logging.DEBUG)
    logger.info('Logger initialized successfully')
    return logger

def save_checkpoint(checkpoint_path, model, optimizer, epoch, scaler=None, metrics=None):
    """Save a checkpoint with training state plus descriptive metadata.

    The saved dict holds the model and optimizer state dicts, the epoch, a
    metrics dict (empty when ``metrics`` is falsy), an ISO timestamp and the
    PyTorch version string. A ``'scaler'`` entry is added only when a scaler
    is supplied.
    """
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'metrics': metrics or {},
        'timestamp': datetime.now().isoformat(),
        # torch.__version__ is a str subclass; store a plain str so the
        # checkpoint contains no custom class that a restricted unpickler
        # (torch.load's weights_only mode) might refuse — TODO confirm against
        # the installed PyTorch.
        'pytorch_version': str(torch.__version__)
    }
    if scaler is not None:
        state['scaler'] = scaler.state_dict()

    torch.save(state, checkpoint_path)
    print(f'Model saved to {checkpoint_path}')

def load_checkpoint(checkpoint_path, model, optimizer, scaler=None):
    """Restore model (and optionally scaler) state, tolerating load failures.

    Returns:
        ``(epoch, metrics)`` from the checkpoint, or ``(0, {})`` when loading
        fails for any reason (the error is printed, not raised).
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        # The optimizer state is intentionally not restored, to avoid conflicts.
        # optimizer.load_state_dict(checkpoint['optimizer'])
        resume_epoch = checkpoint['epoch']
        saved_metrics = checkpoint.get('metrics', {})

        has_scaler_state = 'scaler' in checkpoint
        if scaler is not None and has_scaler_state:
            scaler.load_state_dict(checkpoint['scaler'])

        saved_at = checkpoint.get("timestamp", "Unknown")
        print(f'Model loaded from {checkpoint_path}')
        print(f'Checkpoint timestamp: {saved_at}')
        return resume_epoch, saved_metrics
    except Exception as err:
        print(f'Error loading checkpoint: {err}')
        return 0, {}

def validate_epoch(net, val_loader, criterion, device, logger, use_amp=True):
    """Run one validation pass with optional AMP and per-batch error handling.

    Args:
        net: model; called as ``net(images)`` and expected to return a pair
            whose second element is the prediction.
        val_loader: sized iterable of ``(images, targets)`` batches.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar tensor.
        device: device the batches are moved to.
        logger: logger used for progress, warnings and errors.
        use_amp: run the forward pass under CUDA autocast when CUDA is available.

    Returns:
        ``(avg_val_loss, avg_ms_ssim, avg_ad)`` averaged over the batches whose
        loss and scores are all finite, or ``(inf, 0.0, inf)`` when no batch
        qualifies.
    """
    net.eval()
    total_ms_ssim = 0
    total_ad = 0
    total_val_loss = 0
    num_batches = 0
    val_start_time = time.time()

    validation_errors = 0

    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(val_loader):
            try:
                # torch.compiler.cudagraph_mark_step_begin()
                images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)

                # Forward pass with mixed precision
                if use_amp and torch.cuda.is_available():
                    with autocast(device_type='cuda'):
                        _, outputs = net(images)
                        # Calculate validation loss
                        val_loss = criterion(outputs, targets)
                    # IMPORTANT: convert from half precision back to float32 for evaluation
                    outputs = outputs.float()
                    val_loss = val_loss.float()
                else:
                    _, outputs = net(images)
                    val_loss = criterion(outputs, targets)

                # Make sure targets are float32 as well
                targets = targets.float()

                # Validate loss
                if not torch.isfinite(val_loss):
                    logger.warning(f'Invalid loss at validation batch {batch_idx}: {val_loss.item()}')
                    validation_errors += 1
                    continue

                # Evaluate batch
                ms_ssim_score, ad_score = evaluate_batch(outputs, targets)

                # Validate scores
                if not (np.isfinite(ms_ssim_score) and np.isfinite(ad_score)):
                    logger.warning(f'Invalid scores at batch {batch_idx}: MS-SSIM={ms_ssim_score}, AD={ad_score}')
                    validation_errors += 1
                    continue

                # Accumulate only once the batch is fully valid, so every
                # total is divided by the same batch count below.
                total_val_loss += val_loss.item()
                total_ms_ssim += ms_ssim_score
                total_ad += ad_score
                num_batches += 1

                # Log progress
                if (batch_idx + 1) % 20 == 0:
                    logger.info(f'Validation [{batch_idx + 1}/{len(val_loader)}] - '
                              f'Loss: {val_loss.item():.4f}, MS-SSIM: {ms_ssim_score:.4f}, AD: {ad_score:.4f}')

            except Exception as e:
                logger.error(f'Error in validation batch {batch_idx}: {str(e)}')
                validation_errors += 1
                continue

    if num_batches == 0:
        logger.error("No valid batches in validation!")
        return float('inf'), 0.0, float('inf')

    avg_val_loss = total_val_loss / num_batches
    avg_ms_ssim = total_ms_ssim / num_batches
    avg_ad = total_ad / num_batches
    val_time = time.time() - val_start_time

    logger.info(f'Validation completed in {val_time:.2f}s - '
                f'Avg Loss: {avg_val_loss:.4f}, Avg MS-SSIM: {avg_ms_ssim:.4f}, Avg AD: {avg_ad:.4f}')
    if validation_errors > 0:
        logger.warning(f'Validation errors: {validation_errors}/{len(val_loader)}')

    return avg_val_loss, avg_ms_ssim, avg_ad

class TrainingMetrics:
    """Track per-epoch training/validation metrics, best values, and plots.

    ``metrics`` maps a metric name to a list with exactly one entry per call
    to :meth:`update`, so every list stays index-aligned with
    ``metrics['epoch']``. An entry is ``None`` when the metric was not
    reported for that epoch (e.g. validation metrics on epochs where
    validation did not run) or when the reported value was NaN/inf.

    ``best_metrics`` holds the best value seen so far, and the epoch at which
    it was seen, for train loss, validation loss, MS-SSIM and AD.
    """

    # (kwarg name, best-value key, best-epoch key, lower_is_better)
    _BEST_RULES = (
        ('train_loss', 'best_train_loss', 'best_train_loss_epoch', True),
        ('val_loss', 'best_val_loss', 'best_val_loss_epoch', True),
        ('val_ms_ssim', 'best_ms_ssim', 'best_ms_ssim_epoch', False),
        ('val_ad', 'best_ad', 'best_ad_epoch', True),
    )

    def __init__(self, output_dir):
        """Create an empty tracker that writes its artifacts to ``output_dir``."""
        self.output_dir = output_dir
        self.metrics = defaultdict(list)
        self.best_metrics = {
            'best_train_loss': float('inf'),
            'best_train_loss_epoch': 0,
            'best_val_loss': float('inf'),
            'best_val_loss_epoch': 0,
            'best_ms_ssim': 0.0,
            'best_ms_ssim_epoch': 0,
            'best_ad': float('inf'),
            'best_ad_epoch': 0,
        }
        self.start_time = time.time()

    @staticmethod
    def _as_finite_float(value):
        """Return ``value`` as a built-in float, or ``None`` if missing/NaN/inf.

        Converting to a plain float keeps NumPy scalar types (e.g.
        ``np.float32``) out of the history so it stays JSON-serializable.
        """
        if value is None:
            return None
        value = float(value)
        return value if np.isfinite(value) else None

    def _series(self, key):
        """Return ``(epochs, values)`` for ``key``, skipping ``None`` entries."""
        epochs = self.metrics.get('epoch', [])
        points = [(e, v) for e, v in zip(epochs, self.metrics.get(key, []))
                  if v is not None]
        return [e for e, _ in points], [v for _, v in points]

    def update(self, epoch, **kwargs):
        """Record the metrics of one epoch and refresh the best values.

        Args:
            epoch: Epoch label stored in ``metrics['epoch']``.
            **kwargs: Metric name -> numeric value. ``None`` and non-finite
                values are stored as ``None`` placeholders so that every
                series keeps the same length as ``metrics['epoch']``.
        """
        previous_len = len(self.metrics['epoch'])
        self.metrics['epoch'].append(epoch)

        cleaned = {key: self._as_finite_float(value) for key, value in kwargs.items()}

        # Visit previously seen metrics too, so a metric omitted from this
        # call still receives a placeholder for this epoch.
        known_keys = [key for key in self.metrics if key != 'epoch']
        new_keys = [key for key in cleaned if key not in self.metrics]
        for key in known_keys + new_keys:
            series = self.metrics[key]
            if len(series) < previous_len:
                # Metric first reported late: back-fill the earlier epochs.
                series.extend([None] * (previous_len - len(series)))
            series.append(cleaned.get(key))

        # Update best metrics (missing or non-finite values never qualify).
        for metric, best_key, epoch_key, lower_is_better in self._BEST_RULES:
            value = cleaned.get(metric)
            if value is None:
                continue
            best = self.best_metrics[best_key]
            improved = value < best if lower_is_better else value > best
            if improved:
                self.best_metrics[best_key] = value
                self.best_metrics[epoch_key] = epoch

    def save_metrics(self):
        """Write history, best values, elapsed time and a timestamp to
        ``<output_dir>/training_metrics.json``."""
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        metrics_file = os.path.join(self.output_dir, 'training_metrics.json')
        all_metrics = {
            'training_history': dict(self.metrics),
            'best_metrics': self.best_metrics,
            'total_training_time': time.time() - self.start_time,
            'timestamp': datetime.now().isoformat()
        }
        with open(metrics_file, 'w') as f:
            json.dump(all_metrics, f, indent=2)

    @staticmethod
    def _decorate(ax, title, ylabel, legend=False):
        """Apply the shared title/axis-label/grid styling to one subplot."""
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        if legend:
            ax.legend()

    def plot_training_curves(self):
        """Render a 2x3 grid of curves to ``<output_dir>/training_curves.png``.

        Plotting is best-effort: any failure is reported with ``print`` and
        does not propagate. The figure is always closed, even on failure.
        """
        fig = None
        try:
            fig, axes = plt.subplots(2, 3, figsize=(18, 10))

            train_epochs, train_values = self._series('train_loss')
            val_epochs, val_values = self._series('val_loss')
            lr_epochs, lr_values = self._series('learning_rate')
            ssim_epochs, ssim_values = self._series('val_ms_ssim')
            ad_epochs, ad_values = self._series('val_ad')

            # Training loss
            if train_values:
                axes[0, 0].plot(train_epochs, train_values, label='Train Loss', color='blue')
                self._decorate(axes[0, 0], 'Training Loss', 'Loss', legend=True)

            # Validation loss
            if val_values:
                axes[0, 1].plot(val_epochs, val_values, label='Val Loss', color='red')
                self._decorate(axes[0, 1], 'Validation Loss', 'Loss', legend=True)

            # Combined losses
            if train_values and val_values:
                axes[0, 2].plot(train_epochs, train_values, label='Train Loss', color='blue')
                axes[0, 2].plot(val_epochs, val_values, label='Val Loss', color='red')
                self._decorate(axes[0, 2], 'Training vs Validation Loss', 'Loss', legend=True)

            # Learning rate
            if lr_values:
                axes[1, 0].plot(lr_epochs, lr_values)
                self._decorate(axes[1, 0], 'Learning Rate', 'LR')
                axes[1, 0].set_yscale('log')

            # Validation MS-SSIM
            if ssim_values:
                axes[1, 1].plot(ssim_epochs, ssim_values, color='green')
                self._decorate(axes[1, 1], 'Validation MS-SSIM', 'MS-SSIM')

            # Validation AD
            if ad_values:
                axes[1, 2].plot(ad_epochs, ad_values, color='orange')
                self._decorate(axes[1, 2], 'Validation AD', 'AD')

            fig.tight_layout()
            plot_path = os.path.join(self.output_dir, 'training_curves.png')
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

        except Exception as e:
            print(f"Error creating plots: {e}")
        finally:
            # Close explicitly so repeated calls do not accumulate open figures.
            if fig is not None:
                plt.close(fig)

    def get_best_info(self):
        """Return a one-line summary of the best metrics and their epochs."""
        return (f"Best Train Loss: {self.best_metrics['best_train_loss']:.4f} (epoch {self.best_metrics['best_train_loss_epoch']}), "
                f"Best Val Loss: {self.best_metrics['best_val_loss']:.4f} (epoch {self.best_metrics['best_val_loss_epoch']}), "
                f"Best MS-SSIM: {self.best_metrics['best_ms_ssim']:.4f} (epoch {self.best_metrics['best_ms_ssim_epoch']}), "
                f"Best AD: {self.best_metrics['best_ad']:.4f} (epoch {self.best_metrics['best_ad_epoch']})")


In [None]:
def train():
    """Train ``TinyDocUnet`` using the settings on the module-level ``config``.

    Reads from ``config``: ``gpu_id``, ``output_dir``, ``restart_training``,
    ``epochs``, ``train_batch_size``, ``lr``, ``seed``, ``trainroot``,
    ``testroot``, ``workers``, ``checkpoint``, ``start_epoch`` and
    ``display_interval``, plus the optional ``accumulation_steps`` (default 1),
    ``use_amp`` (default True), ``grad_clip`` (default 1.0), ``eval_interval``
    (default 5; 0 disables validation) and ``use_compile`` (default True).

    Side effects: may set ``CUDA_VISIBLE_DEVICES``, deletes ``output_dir``
    when ``restart_training`` is set, and writes a log file, TensorBoard
    events, metrics/plots and checkpoints into ``output_dir``.

    Raises:
        Exception: any failure during data loading or training is logged and
            re-raised. ``KeyboardInterrupt`` is handled (a checkpoint is
            saved) and not re-raised.
    """
    # NOTE(review): this only affects device visibility if CUDA has not been
    # initialised yet in this process — confirm when running inside a notebook.
    if config.gpu_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)
    if config.output_dir is None:
        config.output_dir = 'output'
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    # makedirs also handles nested output paths and an already existing directory.
    os.makedirs(config.output_dir, exist_ok=True)

    # Initialize logging and metrics
    log_file = os.path.join(config.output_dir, f'train_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    logger = setup_logger(log_file)
    metrics_tracker = TrainingMetrics(config.output_dir)

    # One flag drives device choice, AMP and pinned memory so they cannot disagree.
    use_cuda = config.gpu_id is not None and torch.cuda.is_available()
    # Guard against 0/negative values, which would divide by zero below.
    accumulation_steps = max(1, int(getattr(config, 'accumulation_steps', 1)))
    grad_clip = getattr(config, 'grad_clip', 1.0)
    eval_interval = getattr(config, 'eval_interval', 5)

    # Log system info
    logger.info("=== SYSTEM INFORMATION ===")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU name: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Log training configuration
    logger.info("=== TRAINING CONFIGURATION ===")
    logger.info(f"Epochs: {config.epochs}")
    logger.info(f"Batch size: {config.train_batch_size}")
    logger.info(f"Accumulation steps: {accumulation_steps}")
    logger.info(f"Learning rate: {config.lr}")
    logger.info(f"GPU ID: {config.gpu_id}")
    logger.info(f"Mixed Precision: {getattr(config, 'use_amp', True)}")
    logger.info(f"Gradient Clipping: {grad_clip}")
    logger.info(f"Eval interval: {eval_interval}")
    logger.info("==============================")

    # Set seeds
    torch.manual_seed(config.seed)
    if use_cuda:
        logger.info(f'Training with GPU {config.gpu_id} and PyTorch {torch.__version__}')
        device = torch.device("cuda:0")
        torch.cuda.manual_seed(config.seed)
        torch.cuda.manual_seed_all(config.seed)
    else:
        logger.info(f'Training with CPU and PyTorch {torch.__version__}')
        device = torch.device("cpu")

    # Data loading with error handling
    try:
        num_workers = int(config.workers)
        # prefetch_factor / persistent_workers are only valid with worker
        # processes; DataLoader raises ValueError if prefetch_factor is given
        # while num_workers == 0, so pass them only when workers are used.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs = {
                'persistent_workers': True,
                'prefetch_factor': 2,  # batches preloaded per worker
            }

        train_data = ImageData(config.trainroot, transform=transforms.ToTensor(), t_transform=transforms.ToTensor())
        train_loader = Data.DataLoader(
            dataset=train_data,
            batch_size=config.train_batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=use_cuda,
            **worker_kwargs
        )

        test_data = ImageData(config.testroot, transform=transforms.ToTensor(), t_transform=transforms.ToTensor())
        test_loader = Data.DataLoader(
            dataset=test_data,
            batch_size=1,
            shuffle=False,
            num_workers=3,
            pin_memory=use_cuda
        )

        logger.info(f"Training samples: {len(train_data)}")
        logger.info(f"Validation samples: {len(test_data)}")

    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise

    # Model setup
    writer = SummaryWriter(config.output_dir)
    net = TinyDocUnet(input_channels=3, n_classes=2)
    net = net.to(device)

    # Compile model (CUDA Graphs disabled)
    if getattr(config, 'use_compile', True):
        net = torch.compile(net, options={"triton.cudagraphs": False})

    # Log model info
    total_params = sum(p.numel() for p in net.parameters())
    trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")

    criterion = DocUnetLoss(reduction='mean')
    optimizer = torch.optim.AdamW(net.parameters(), lr=config.lr, weight_decay=1e-4)

    # Mixed precision setup
    use_amp = getattr(config, 'use_amp', True) and use_cuda
    scaler = GradScaler() if use_amp else None

    if use_amp:
        logger.info("Using Automatic Mixed Precision (AMP)")
    if grad_clip > 0:
        logger.info(f"Using gradient clipping with max norm: {grad_clip}")

    # Gradient accumulation
    effective_batch_size = config.train_batch_size * accumulation_steps
    logger.info(f"Effective batch size: {effective_batch_size}")

    # Load checkpoint if exists
    if config.checkpoint != '' and not config.restart_training:
        start_epoch, _ = load_checkpoint(config.checkpoint, net, optimizer, scaler)
    else:
        start_epoch = config.start_epoch

    # Build the scheduler from scratch, then fast-forward it. Passing
    # last_epoch >= 0 to the constructor needs 'initial_lr' in every param
    # group and takes one extra step during construction; step(epoch)
    # positions the cosine cycle explicitly (including the wrap at T_0).
    for group in optimizer.param_groups:
        # Keep a base LR restored from a checkpoint; otherwise use config.lr
        # rather than whatever annealed 'lr' the optimizer currently holds.
        group.setdefault('initial_lr', config.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=1e-6
    )
    if start_epoch > 0:
        scheduler.step(start_epoch)

    all_step = len(train_loader)
    global_step = start_epoch * all_step
    # So that `epoch + 1` below equals start_epoch if the loop never runs.
    epoch = start_epoch - 1

    try:
        logger.info("üöÄ Starting training...")
        training_start_time = time.time()

        for epoch in range(start_epoch, config.epochs):
            net.train()
            train_loss = 0.
            accumulated_loss = 0.
            epoch_start_time = time.time()

            # The scheduler steps once per epoch, so the LR is constant for
            # all batches of this epoch; read it before scheduler.step().
            epoch_lr = scheduler.get_last_lr()[0]

            # Statistics tracking
            batch_times = []
            losses = []

            # Reset gradients
            optimizer.zero_grad()

            for i, (images, labels) in enumerate(train_loader):
                batch_start_time = time.time()

                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                # Forward pass (the model returns a pair; only the second output feeds the loss)
                if use_amp:
                    with autocast(device_type='cuda'):
                        _, y = net(images)
                        loss = criterion(y, labels)
                else:
                    _, y = net(images)
                    loss = criterion(y, labels)

                # Scale loss for accumulation
                loss = loss / accumulation_steps

                # Backward pass
                if use_amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                scaled_loss_value = loss.item()
                batch_loss_value = scaled_loss_value * accumulation_steps
                accumulated_loss += scaled_loss_value
                train_loss += batch_loss_value
                losses.append(batch_loss_value)

                # Update weights every `accumulation_steps` batches and on the last batch
                if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                    if use_amp:
                        if grad_clip > 0:
                            # Gradients must be unscaled before their norm is clipped.
                            scaler.unscale_(optimizer)
                            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)

                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        if grad_clip > 0:
                            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
                        optimizer.step()

                    optimizer.zero_grad()

                    # Log accumulated loss
                    writer.add_scalar('Train/accumulated_loss', accumulated_loss, global_step // accumulation_steps)

                    # Log gradient norm
                    if grad_clip > 0 and (global_step // accumulation_steps) % 50 == 0:
                        writer.add_scalar('Train/grad_norm', grad_norm, global_step // accumulation_steps)

                    accumulated_loss = 0.

                batch_time = time.time() - batch_start_time
                batch_times.append(batch_time)

                # Periodic logging
                if (i + 1) % config.display_interval == 0:
                    avg_batch_time = np.mean(batch_times[-50:])  # Last 50 batches
                    avg_loss = np.mean(losses[-50:])  # Last 50 batches

                    logger.info(f'[{epoch + 1:3d}/{config.epochs}] [{i + 1:4d}/{all_step}] '
                              f'Loss: {batch_loss_value:.4f} (avg: {avg_loss:.4f}) '
                              f'Time: {batch_time:.2f}s (avg: {avg_batch_time:.2f}s) '
                              f'LR: {epoch_lr:.2e}')

                # TensorBoard logging
                writer.add_scalar('Train/batch_loss', batch_loss_value, global_step)
                writer.add_scalar('Train/lr', epoch_lr, global_step)
                writer.add_scalar('Train/batch_time', batch_time, global_step)
                global_step += 1

            # Step scheduler
            scheduler.step()

            # Epoch statistics
            epoch_time = time.time() - epoch_start_time
            avg_train_loss = train_loss / len(train_loader)
            # Report the LR this epoch actually trained with, not the one
            # scheduler.step() has just set for the next epoch.
            current_lr = epoch_lr
            avg_batch_time = np.mean(batch_times)
            throughput = len(train_data) / epoch_time  # samples/second

            logger.info(f'=== EPOCH {epoch + 1:3d}/{config.epochs} SUMMARY ===')
            logger.info(f'Train loss: {avg_train_loss:.6f}')
            logger.info(f'Epoch time: {epoch_time:.2f}s')
            logger.info(f'Avg batch time: {avg_batch_time:.3f}s')
            logger.info(f'Throughput: {throughput:.1f} samples/sec')
            logger.info(f'Learning rate: {current_lr:.2e}')

            # Validation (skipped entirely when eval_interval is 0)
            val_loss, val_ms_ssim, val_ad = None, None, None
            if eval_interval and (epoch + 1) % eval_interval == 0:
                logger.info("üîç Running validation...")
                val_loss, val_ms_ssim, val_ad = validate_epoch(net, test_loader, criterion, device, logger, use_amp)

                # TensorBoard logging
                writer.add_scalar('Val/Loss', val_loss, epoch)
                writer.add_scalar('Val/MS-SSIM', val_ms_ssim, epoch)
                writer.add_scalar('Val/AD', val_ad, epoch)

                # Save best models. Compare against the tracker BEFORE it is
                # updated with this epoch's values further below.
                best_so_far = metrics_tracker.best_metrics
                improvements = (
                    (val_loss < best_so_far['best_val_loss'],
                     f"üèÜ New best validation loss: {val_loss:.6f}", 'best_val_loss.pth'),
                    (val_ms_ssim > best_so_far['best_ms_ssim'],
                     f"üèÜ New best MS-SSIM: {val_ms_ssim:.6f}", 'best_ms_ssim.pth'),
                    (val_ad < best_so_far['best_ad'],
                     f"üèÜ New best AD: {val_ad:.6f}", 'best_ad.pth'),
                )
                for improved, message, filename in improvements:
                    if improved:
                        logger.info(message)
                        save_checkpoint(
                            f'{config.output_dir}/{filename}',
                            net, optimizer, epoch + 1, scaler,
                            {'train_loss': avg_train_loss, 'val_loss': val_loss, 'ms_ssim': val_ms_ssim, 'ad': val_ad}
                        )

            # Update metrics
            metrics_tracker.update(
                epoch=epoch + 1,
                train_loss=avg_train_loss,
                val_loss=val_loss,
                learning_rate=current_lr,
                epoch_time=epoch_time,
                batch_time=avg_batch_time,
                throughput=throughput,
                val_ms_ssim=val_ms_ssim,
                val_ad=val_ad
            )

            # TensorBoard epoch summary
            writer.add_scalar('Train/epoch_loss', avg_train_loss, epoch)
            writer.add_scalar('Train/epoch_time', epoch_time, epoch)
            writer.add_scalar('Train/throughput', throughput, epoch)

            # Save regular checkpoint
            if (epoch + 1) % 10 == 0:
                checkpoint_name = f'checkpoint_epoch_{epoch + 1:03d}.pth'
                save_checkpoint(
                    f'{config.output_dir}/{checkpoint_name}',
                    net, optimizer, epoch + 1, scaler,
                    {'train_loss': avg_train_loss, 'val_loss': val_loss, 'ms_ssim': val_ms_ssim, 'ad': val_ad}
                )

            # Save metrics and create plots
            if (epoch + 1) % 5 == 0:
                metrics_tracker.save_metrics()
                metrics_tracker.plot_training_curves()

            logger.info(f'üìä {metrics_tracker.get_best_info()}')
            logger.info("=" * 60)

            # Memory cleanup
            if use_cuda:
                torch.cuda.empty_cache()

        # Training completed
        total_training_time = time.time() - training_start_time
        # Average over the epochs run in THIS call (matters when resuming);
        # max(1, ...) avoids dividing by zero when no epoch ran.
        epochs_run = max(1, config.epochs - start_epoch)
        logger.info("üéâ TRAINING COMPLETED SUCCESSFULLY! üéâ")
        logger.info(f"Total training time: {total_training_time / 3600:.2f} hours")
        logger.info(f"Average time per epoch: {total_training_time / epochs_run:.2f}s")
        logger.info(f"Final metrics: {metrics_tracker.get_best_info()}")

        # Save final artifacts
        metrics_tracker.save_metrics()
        metrics_tracker.plot_training_curves()
        save_checkpoint(f'{config.output_dir}/final_model.pth', net, optimizer, epoch + 1, scaler)

    except KeyboardInterrupt:
        logger.warning("‚ö†Ô∏è Training interrupted by user")
        save_checkpoint(f'{config.output_dir}/interrupted_model_epoch_{epoch + 1}.pth', net, optimizer, epoch + 1, scaler)
        metrics_tracker.save_metrics()
        metrics_tracker.plot_training_curves()

    except Exception as e:
        logger.error(f"‚ùå Training failed with error: {str(e)}")
        import traceback
        logger.error(f"Full traceback:\n{traceback.format_exc()}")
        # Best-effort emergency save: a failure here must not hide the
        # original training error, which is re-raised below.
        try:
            save_checkpoint(f'{config.output_dir}/error_model_epoch_{epoch + 1}.pth', net, optimizer, epoch + 1, scaler)
            metrics_tracker.save_metrics()
        except Exception as save_error:
            logger.error(f"Could not save emergency checkpoint/metrics: {save_error}")
        raise

    finally:
        # Flush and close TensorBoard on every exit path.
        writer.close()

# Entry point: start training when this cell/module is executed directly.
if __name__ == "__main__":
    train()

2025-09-30 10:17:37 [32mINFO     [0m 492157266.py: Logger initialized successfully[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: === SYSTEM INFORMATION ===[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: PyTorch version: 2.8.0+cu129[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: CUDA available: True[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: CUDA version: 12.9[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: GPU name: NVIDIA GeForce RTX 5090[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: GPU memory: 33.7 GB[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: === TRAINING CONFIGURATION ===[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: Epochs: 100[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: Batch size: 8[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: Accumulation steps: 4[0m
2025-09-30 10:17:37 [32mINFO     [0m 2906090605.py: Learning rate: 0.001[0m
2025-09-30 10:17:37 [3

Model saved to output//best_val_loss.pth
Model saved to output//best_ms_ssim.pth
Model saved to output//best_ad.pth


2025-09-30 10:39:12 [32mINFO     [0m 2906090605.py: üìä Best Train Loss: 57414.4472 (epoch 5), Best Val Loss: 54280.5612 (epoch 5), Best MS-SSIM: 0.5864 (epoch 5), Best AD: 0.8803 (epoch 5)[0m
2025-09-30 10:39:17 [32mINFO     [0m 2906090605.py: [  6/100] [  10/600] Loss: 54103.1289 (avg: 53869.3555) Time: 0.38s (avg: 0.39s) LR: 5.01e-04[0m
2025-09-30 10:39:21 [32mINFO     [0m 2906090605.py: [  6/100] [  20/600] Loss: 53314.7188 (avg: 54019.5053) Time: 0.39s (avg: 0.39s) LR: 5.01e-04[0m
2025-09-30 10:39:25 [32mINFO     [0m 2906090605.py: [  6/100] [  30/600] Loss: 53631.5039 (avg: 53930.8203) Time: 0.38s (avg: 0.39s) LR: 5.01e-04[0m
2025-09-30 10:39:29 [32mINFO     [0m 2906090605.py: [  6/100] [  40/600] Loss: 53775.8086 (avg: 53913.9609) Time: 0.39s (avg: 0.39s) LR: 5.01e-04[0m
2025-09-30 10:39:32 [32mINFO     [0m 2906090605.py: [  6/100] [  50/600] Loss: 53603.5859 (avg: 53849.9024) Time: 0.38s (avg: 0.38s) LR: 5.01e-04[0m
2025-09-30 10:39:36 [32mINFO     [0m 2906