### Metrics Logging

In [None]:
import json
import pandas as pd

def save_training_logs(log_dict, output_path="./outputs/training_log.json"):
    """
    Save the training history to a JSON file, plus a CSV copy next to it.

    Args:
        log_dict: Dictionary containing training metrics for each epoch
                  Format: {epoch_num: {'loss': ..., 'psnr': ..., 'val_loss': ..., 'val_psnr': ...}}
        output_path: Where to save the JSON file. The CSV copy is written to
                     the same location with the file extension swapped for
                     '.csv' (or with '.csv' appended if output_path itself
                     already ends in '.csv').

    Returns:
        pandas DataFrame with one row per epoch (index named 'epoch') and one
        column per metric. This is the same table that is written to the CSV.

    Example of what gets saved:
    {
        "1": {"loss": 0.123, "psnr": 25.4, "val_loss": 0.145, "val_psnr": 24.1},
        "2": {"loss": 0.098, "psnr": 27.2, "val_loss": 0.112, "val_psnr": 26.5},
        ...
    }
    """
    # Make sure the directory exists. dirname() is '' for a bare filename such
    # as "log.json", and os.makedirs('') raises, so only create real directories.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save as JSON (human-readable format)
    with open(output_path, 'w') as f:
        json.dump(log_dict, f, indent=2)

    print(f"Training logs saved to: {output_path}")

    # Also save as CSV for easy viewing in Excel.
    # Swap only the file extension: a plain str.replace('.json', '.csv') would
    # leave the path unchanged when there is no '.json' in it (so the CSV would
    # overwrite the JSON) and would also rewrite '.json' inside directory names.
    base_path, extension = os.path.splitext(output_path)
    if extension.lower() == '.csv':
        # The JSON file already occupies "<base>.csv"; do not clobber it.
        csv_path = output_path + '.csv'
    else:
        csv_path = base_path + '.csv'

    # Convert the dictionary to a pandas DataFrame
    # Each row is one epoch, columns are the metrics
    df = pd.DataFrame.from_dict(log_dict, orient='index')
    df.index.name = 'epoch'
    df.to_csv(csv_path)

    print(f"Training logs also saved as CSV: {csv_path}")

    return df  # Return the dataframe in case we want to use it

In [None]:
def plot_training_history(log_dict, save_path="./outputs/training_curves.png"):
    """
    Create graphs showing how the model improved during training.

    This creates two graphs side by side:
    1. Loss over time (training, plus validation if 'val_loss' was logged)
    2. PSNR over time (training, plus validation if 'val_psnr' was logged)

    The figure is saved to save_path, shown with plt.show(), and a short
    best-epoch summary is printed.

    Args:
        log_dict: Dictionary with training history, keyed by epoch number
                  (keys must be convertible to int). Every epoch entry must
                  provide 'loss' and 'psnr'; 'val_loss' and 'val_psnr' are
                  optional.
        save_path: Where to save the graph image

    Raises:
        KeyError: If the 'loss' or 'psnr' column is missing from the history.
    """
    # Convert dictionary to pandas DataFrame for easier plotting
    df = pd.DataFrame.from_dict(log_dict, orient='index')
    df.index = df.index.astype(int)  # Make sure epoch numbers are integers
    df = df.sort_index()  # Sort by epoch number

    # Each validation curve is optional on its own, so a history that logged
    # only one of the two validation metrics still gets that metric plotted.
    has_val_loss = 'val_loss' in df.columns
    has_val_psnr = 'val_psnr' in df.columns

    # Create a figure with 2 subplots side by side
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    loss_ax, psnr_ax = axes

    # Plot 1: Loss over epochs
    loss_ax.plot(df.index, df['loss'], label='Training Loss', marker='o', linewidth=2)
    if has_val_loss:
        loss_ax.plot(df.index, df['val_loss'], label='Validation Loss', marker='s', linewidth=2)
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title('Loss over Training', fontsize=14, fontweight='bold')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Plot 2: PSNR over epochs
    psnr_ax.plot(df.index, df['psnr'], label='Training PSNR', marker='o', linewidth=2)
    if has_val_psnr:
        psnr_ax.plot(df.index, df['val_psnr'], label='Validation PSNR', marker='s', linewidth=2)
    psnr_ax.set_xlabel('Epoch', fontsize=12)
    psnr_ax.set_ylabel('PSNR (dB)', fontsize=12)
    psnr_ax.set_title('PSNR over Training', fontsize=14, fontweight='bold')
    psnr_ax.legend()
    psnr_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Save the figure. dirname() is '' for a bare filename and os.makedirs('')
    # raises, so only create a directory when the path actually has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Training curves saved to: {save_path}")

    # Display the plot in the notebook
    plt.show()

    # Print summary statistics (lower is better for loss, higher for PSNR)
    print("\n=== Training Summary ===")
    print(f"Best Training Loss: {df['loss'].min():.6f} (Epoch {df['loss'].idxmin()})")
    if has_val_loss:
        print(f"Best Validation Loss: {df['val_loss'].min():.6f} (Epoch {df['val_loss'].idxmin()})")
    print(f"Best Training PSNR: {df['psnr'].max():.2f} dB (Epoch {df['psnr'].idxmax()})")
    if has_val_psnr:
        print(f"Best Validation PSNR: {df['val_psnr'].max():.2f} dB (Epoch {df['val_psnr'].idxmax()})")

### Testing Code

In [None]:
def test_model(model, test_loader, device, checkpoint_path="./outputs/best.pt"):
    """
    Evaluate a checkpointed model on the test set and keep a few sample outputs.

    The checkpoint is expected to be a dict holding the weights under 'model'
    and the epoch number under 'epoch'. The weights are loaded into `model`,
    the model is switched to eval mode and moved to `device`, and every batch
    of `test_loader` is scored with MSE loss and PSNR.

    Args:
        model: The autoencoder model
        test_loader: DataLoader yielding (noisy, clean) image batches
        device: 'cuda' or 'cpu'
        checkpoint_path: Path to the saved best model

    Returns:
        Dictionary with:
            'test_loss': average MSE reported by the loss meter
            'test_psnr': average PSNR reported by the PSNR meter
            'example_noisy' / 'example_denoised' / 'example_clean':
                up to 8 images from the first batch, moved to CPU
                (empty lists if the loader produced no batches)
    """
    rule = "=" * 60
    print(rule)
    print("TESTING THE MODEL")
    print(rule)

    # Restore the best weights before evaluating
    print(f"\nLoading best model from: {checkpoint_path}")
    saved_state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(saved_state['model'])
    print(f"Model loaded from epoch {saved_state['epoch']}")

    # Evaluation mode disables training-only behaviour such as dropout
    model.eval()
    model.to(device)

    # Running averages over the whole test set
    # NOTE(review): assumes AverageMeter.update(value, n) keeps a weighted mean
    # exposed as .avg -- confirm against its definition.
    loss_tracker = AverageMeter()
    psnr_tracker = AverageMeter()

    # Sample images for later visualisation; filled from the first batch only
    samples = {
        'example_noisy': [],
        'example_denoised': [],
        'example_clean': [],
    }

    print("\nRunning model on test images...")

    # No gradients are needed for evaluation; this saves memory and time
    with torch.no_grad():
        for step, (noisy, clean) in enumerate(test_loader):
            noisy, clean = noisy.to(device), clean.to(device)
            batch_size = clean.size(0)

            # Forward pass: noisy input -> denoised reconstruction
            denoised = model(noisy)

            # Batch-level MSE against the clean target
            loss = F.mse_loss(denoised, clean)

            # Per-image MSE (mean over all non-batch dims), then mean PSNR of the batch
            per_image_mse = (
                F.mse_loss(denoised, clean, reduction='none')
                .view(batch_size, -1)
                .mean(dim=1)
            )
            mean_psnr = psnr_from_mse(per_image_mse).mean().item()

            # Weight both meters by the number of images in this batch
            loss_tracker.update(loss.item(), n=batch_size)
            psnr_tracker.update(mean_psnr, n=batch_size)

            if step != 0:
                continue

            # First batch only: keep at most 8 images as examples
            keep = min(8, noisy.size(0))
            samples['example_noisy'] = noisy[:keep].cpu()
            samples['example_denoised'] = denoised[:keep].cpu()
            samples['example_clean'] = clean[:keep].cpu()

    # Report the final numbers
    print("\n" + rule)
    print("TEST RESULTS")
    print(rule)
    print(f"Test Loss (MSE): {loss_tracker.avg:.6f}")
    print(f"Test PSNR: {psnr_tracker.avg:.2f} dB")
    print(rule)

    # Metrics first, then the example tensors
    return {
        'test_loss': loss_tracker.avg,
        'test_psnr': psnr_tracker.avg,
        **samples,
    }

In [None]:
def visualize_results(example_noisy, example_denoised, example_clean, save_path="./outputs/test_results.png"):
    """
    Create a visual comparison of noisy, denoised, and clean images.

    This creates a grid showing:
    - Row 1: Noisy input images
    - Row 2: Model's denoised output
    - Row 3: Ground truth (actual clean images)

    The figure is saved to save_path and then shown with plt.show().

    Args:
        example_noisy: Tensor of noisy images; the first dimension indexes the
                       images and each image is laid out channels-first.
                       Pixel values are assumed to be normalized to [-1, 1].
        example_denoised: Tensor of denoised images from model (same layout)
        example_clean: Tensor of ground truth clean images (same layout)
        save_path: Where to save the comparison image

    Raises:
        ValueError: If there are no example images to show.
    """

    def to_display(img):
        """Convert one channels-first image from [-1, 1] to an HWC array in [0, 1]."""
        # detach()/cpu() are no-ops for plain CPU tensors but make .numpy()
        # work for tensors that still track gradients or live on the GPU.
        hwc = ((img + 1) / 2).detach().cpu().permute(1, 2, 0).numpy()
        # Clip to valid range [0, 1]
        return np.clip(hwc, 0, 1)

    # len() works both for tensors and for the empty lists that are produced
    # when no test batch was ever seen.
    num_examples = len(example_noisy)
    if num_examples == 0:
        raise ValueError("visualize_results needs at least one example image")

    # Create figure with 3 rows (noisy, denoised, clean) and one column per
    # example. squeeze=False keeps axes 2-D even when there is a single column.
    fig, axes = plt.subplots(
        3, num_examples, figsize=(num_examples * 2, 6), squeeze=False
    )

    # One (images, title) pair per row, top to bottom
    rows = (
        (example_noisy, 'Noisy Input'),
        (example_denoised, 'Model Output\n(Denoised)'),
        (example_clean, 'Ground Truth\n(Clean)'),
    )

    for row, (images, title) in enumerate(rows):
        for col in range(num_examples):
            ax = axes[row, col]
            ax.imshow(to_display(images[col]))
            ax.axis('off')
            # Only the first column carries the row's title
            if col == 0:
                ax.set_title(title, fontsize=12, fontweight='bold')

    plt.tight_layout()

    # Save the figure. dirname() is '' for a bare filename and os.makedirs('')
    # raises, so only create a directory when the path actually has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\nTest result visualization saved to: {save_path}")

    # Display in notebook
    plt.show()

In [None]:
def save_test_metrics(test_results, output_path="./outputs/test_metrics.json"):
    """
    Save the final test metrics to a JSON file.

    Only the scalar metrics are written; any other entries in test_results
    (such as the example image tensors) are left out because they are too big.

    Args:
        test_results: Dictionary with at least 'test_loss' and 'test_psnr'
        output_path: Where to save the metrics

    Raises:
        KeyError: If 'test_loss' or 'test_psnr' is missing from test_results.
    """
    # Coerce to plain Python floats so NumPy scalars or 0-d tensors, which the
    # json module cannot serialize, are written just like ordinary floats.
    metrics = {
        'test_loss': float(test_results['test_loss']),
        'test_psnr': float(test_results['test_psnr']),
    }

    # dirname() is '' for a bare filename and os.makedirs('') raises, so only
    # create a directory when the path actually has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w') as f:
        json.dump(metrics, f, indent=2)

    print(f"Test metrics saved to: {output_path}")

### Example Usage

In [None]:
# STEP 1: log_dict must already exist in the notebook at this point.
# It is the second value returned by training: best_loss, log_dict = fit(...)

# STEP 2: Write the history to disk (JSON + CSV), then plot the curves
print("Saving training logs...")
training_df = save_training_logs(log_dict=log_dict)

print("\nCreating training visualizations...")
plot_training_history(log_dict=log_dict)

In [None]:
# STEP 3: Evaluate the best checkpoint on the held-out test set
print("\n" + "="*60, "TESTING PHASE", "="*60, sep="\n")

test_results = test_model(
    model,
    test_NIDS_loader,
    device,
    checkpoint_path="./outputs/best.pt",  # This is where fit() saved the best model
)

# STEP 4: Persist the scalar test metrics
save_test_metrics(test_results=test_results)

# STEP 5: Draw the noisy / denoised / clean comparison grid
print("\nCreating result visualizations...")
visualize_results(
    example_noisy=test_results['example_noisy'],
    example_denoised=test_results['example_denoised'],
    example_clean=test_results['example_clean'],
)

# Closing banner and a reminder of the files the workflow is expected to leave behind
print("\n" + "="*60, "ALL DONE!", "="*60, sep="\n")
print("\n".join([
    "\nCheck the ./outputs folder for:",
    "  - training_log.json & .csv (training history)",
    "  - training_curves.png (loss and PSNR graphs)",
    "  - test_metrics.json (final test performance)",
    "  - test_results.png (visual comparison of images)",
    "  - best.pt (best model checkpoint)",
    "  - last.pt (last epoch checkpoint)",
]))