# Inference Procedure: Apply Pre-trained Model and Convert to LOS Magnetogram

This notebook:
1. Loads the pre-trained Stokes profile model from checkpoint
2. Applies it to the training set to check that its outputs are reasonable
3. Converts the predicted I and V profiles to a line-of-sight (LOS) magnetogram using an integration method

## Setup and Imports

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Use GPU 2

import sys
import torch
import numpy as np
import yaml
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Add paths for imports
sys.path.append("../../")
sys.path.append("../../Surya")

from surya.utils.data import build_scalers
from datasets.stokes_profile_dataset import StokesProfileDataset
from models.stokes_baseline import StokesBaselineModel
from lightning_modules.pl_stokes_baseline import StokesLightningModule
from metrics.stokes_metrics import StokesMetrics
from inference_stokes import load_trained_model, stokes_to_los_magnetogram, run_inference_on_training_set

torch.set_float32_matmul_precision('medium')

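## Find the Latest Checkpoint

First, locate a checkpoint from the training run. Checkpoints are saved under `runs/stokes_baseline/version_X/checkpoints/`. The cell below picks the most recently modified `.ckpt` file, which is not necessarily the one with the best validation score.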
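If the best checkpoint (rather than the latest) is wanted, one option is to read the monitored metric from the filename. The sketch below assumes a hypothetical `ModelCheckpoint` filename template that embeds the metric as `<metric>=<value>` (for example `epoch=3-val_loss=0.0123.ckpt`); `select_checkpoint` is an illustrative helper, not part of this repository. It falls back to the most recently modified file when no filename carries the metric.

```python
import re
from pathlib import Path


def select_checkpoint(paths, metric="val_loss", mode="min"):
    """Choose a checkpoint from an iterable of .ckpt paths.

    Assumption (hypothetical): filenames embed the monitored metric as
    "<metric>=<value>", e.g. "epoch=3-val_loss=0.0123.ckpt".
    If no filename carries the metric, fall back to the most recently
    modified file. Returns None when `paths` is empty.
    """
    paths = [Path(p) for p in paths]
    if not paths:
        return None
    # Match "<metric>=<number>", including signs and exponents such as 1.5e-03
    pattern = re.compile(rf"{re.escape(metric)}=([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")
    scored = []
    for p in paths:
        match = pattern.search(p.stem)
        if match:
            scored.append((float(match.group(1)), p))
    if not scored:
        # No metric in any filename: use modification time, like the cell below
        return max(paths, key=lambda p: p.stat().st_mtime)
    choose = min if mode == "min" else max
    return choose(scored, key=lambda item: item[0])[1]
```

Usage, under the same filename assumption: `best = select_checkpoint(Path("runs/stokes_baseline").glob("**/*.ckpt"))`, then `checkpoint_path = str(best) if best else None`. Use `mode="max"` for metrics where larger is better.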

In [None]:
# Find the most recently modified checkpoint

checkpoint_dir = Path("runs/stokes_baseline")
checkpoints = list(checkpoint_dir.glob("**/*.ckpt"))

if checkpoints:
    # Sort by modification time, get the latest
    latest_checkpoint = max(checkpoints, key=lambda p: p.stat().st_mtime)
    print(f"✅ Found checkpoint: {latest_checkpoint}")
    checkpoint_path = str(latest_checkpoint)
else:
    print("⚠️  No checkpoint found. Please train the model first using 2_stokes_baseline.ipynb")
    checkpoint_path = None

## Load Configuration and Dataset

## Load Trained Model

## Run Inference on Training Set

Apply the model to the training set and evaluate the results.

In [None]:
if checkpoint_path and trained_model:
    # Run inference using the script
    summary, all_los_pred, all_los_target = run_inference_on_training_set(
        checkpoint_path=checkpoint_path,
        config_path=config_path,
        hmi_b_dir=hmi_b_dir,
        output_dir="./inference_results",
        max_samples=1,  # Process 1 sample for testing (set to None for all)
        device='cuda'
    )
    
    print("\n" + "="*60)
    print("Inference Results Summary")
    print("="*60)
    for key, value in summary.items():
        if isinstance(value, dict):
            print(f"\n{key}:")
            for k, v in value.items():
                print(f"  {k}: {v:.6f}")
        else:
            print(f"{key}: {value}")
else:
    print("⚠️  Cannot run inference: model not loaded")

## Manual Inference Example

Alternatively, you can run inference manually on a single sample:

In [None]:
if checkpoint_path and trained_model:
    # Create a small dataset for manual testing
    test_dataset = StokesProfileDataset(
        index_path=config["data"]["train_data_path"],
        hmi_b_dir=hmi_b_dir,
        time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
        time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
        n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
        rollout_steps=config["rollout_steps"],
        channels=config["data"]["channels"],
        drop_hmi_probability=config["drop_hmi_probability"],
        use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
        scalers=scalers,
        phase="train",
        s3_use_simplecache=True,
        s3_cache_dir="/tmp/helio_s3_cache",
        wavelengths=wavelengths,
        pixel_batch_size=10000,
        device='cpu',
        max_number_of_samples=1,
    )
    
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
    
    # Get a sample
    batch = next(iter(test_loader))
    print(f"Input shape: {batch['stokes_input'].shape}")
    print(f"Target shape: {batch['forecast'].shape}")
    
    # Move to device
    batch['stokes_input'] = batch['stokes_input'].to(device)
    
    # Run inference
    with torch.no_grad():
        predictions = trained_model(batch)  # [B, 4, n_wavelengths, H, W]
    
    print(f"Prediction shape: {predictions.shape}")
    
    # Move back to CPU
    predictions = predictions.cpu()
    targets = batch['forecast'].cpu()
    
    # Extract I and V profiles (indices 0=I, 3=V)
    pred_I = predictions[0, 0, :, :, :].numpy()  # [n_wavelengths, H, W]
    pred_V = predictions[0, 3, :, :, :].numpy()  # [n_wavelengths, H, W]
    target_I = targets[0, 0, :, :, :].numpy()
    target_V = targets[0, 3, :, :, :].numpy()
    
    print("\n✅ Inference complete!")
    print(f"   Predicted I shape: {pred_I.shape}")
    print(f"   Predicted V shape: {pred_V.shape}")
else:
    print("⚠️  Cannot run inference: model not loaded")
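The cell above picks Stokes components by hard-coded channel index (0 for I, 3 for V). A small helper that names all four channels can make that mapping explicit. This is a sketch under the assumption that the model output is ordered I, Q, U, V along axis 1; only the I and V positions are confirmed by the comment in the cell, and `split_stokes` is a hypothetical helper.

```python
import numpy as np

# Assumption: channel order of the model output along axis 1 (only I=0 and V=3
# are confirmed by the notebook; Q=1 and U=2 follow the usual IQUV convention)
STOKES_ORDER = ("I", "Q", "U", "V")


def split_stokes(predictions, sample=0, order=STOKES_ORDER):
    """Return {name: array[n_wavelengths, H, W]} for one sample of a
    [B, 4, n_wavelengths, H, W] torch tensor or NumPy array."""
    if hasattr(predictions, "detach"):
        # torch.Tensor: drop autograd history and move to CPU before converting
        arr = predictions.detach().cpu().numpy()
    else:
        arr = np.asarray(predictions)
    if arr.ndim != 5 or arr.shape[1] != len(order):
        raise ValueError(f"expected shape [B, {len(order)}, n_wl, H, W], got {arr.shape}")
    return {name: arr[sample, k] for k, name in enumerate(order)}
```

For example, `stokes = split_stokes(predictions)` gives `stokes["I"]` and `stokes["V"]`, equivalent to `pred_I` and `pred_V` above when the ordering assumption holds.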

## Convert Stokes I and V to Line-of-Sight Magnetogram

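Convert the predicted and target I and V profiles to line-of-sight (LOS) magnetograms using the integration method in `stokes_to_los_magnetogram`, with the rest wavelength (6173.15 Å) and effective Landé factor (2.5) passed below.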

In [None]:
if checkpoint_path and trained_model:
    # Convert predicted Stokes I and V to LOS magnetogram
    los_magnetogram_pred = stokes_to_los_magnetogram(
        pred_I, pred_V, wavelengths,
        lambda_rest=6173.15,
        geff=2.5
    )
    
    # Convert target Stokes I and V to LOS magnetogram
    los_magnetogram_target = stokes_to_los_magnetogram(
        target_I, target_V, wavelengths,
        lambda_rest=6173.15,
        geff=2.5
    )
    
    print("✅ LOS magnetogram conversion complete!")
    print(f"   Predicted LOS shape: {los_magnetogram_pred.shape}")
    print(f"   Target LOS shape: {los_magnetogram_target.shape}")
    print(f"   Predicted LOS range: [{los_magnetogram_pred.min():.2f}, {los_magnetogram_pred.max():.2f}] G")
    print(f"   Target LOS range: [{los_magnetogram_target.min():.2f}, {los_magnetogram_target.max():.2f}] G")
    
    # Compute LOS magnetogram metrics
    los_mse = np.mean((los_magnetogram_pred - los_magnetogram_target) ** 2)
    los_mae = np.mean(np.abs(los_magnetogram_pred - los_magnetogram_target))
    
    print("\n📊 LOS Magnetogram Metrics:")
    print(f"   MSE: {los_mse:.6f} G²")
    print(f"   MAE: {los_mae:.6f} G")
else:
    print("⚠️  Cannot convert to LOS magnetogram: model not loaded")
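The cell above reports MSE and MAE over every pixel. When comparing magnetograms it can also help to report RMSE (in gauss, like MAE), a Pearson correlation coefficient, and to restrict the statistics to a mask, for example on-disk pixels. The sketch below is a hypothetical helper `los_metrics`, not part of `metrics.stokes_metrics`; the mask is an optional assumption supplied by the caller.

```python
import numpy as np


def los_metrics(pred, target, mask=None):
    """MSE (G^2), RMSE (G), MAE (G) and Pearson r between two LOS maps.

    Pixels that are non-finite in either map, or False in the optional
    boolean `mask`, are ignored.
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    valid = np.isfinite(pred) & np.isfinite(target)
    if mask is not None:
        valid &= np.asarray(mask, dtype=bool)
    p, t = pred[valid], target[valid]
    if p.size == 0:
        nan = float("nan")
        return {"mse": nan, "rmse": nan, "mae": nan, "pearson_r": nan, "n_pixels": 0}
    err = p - t
    mse = float(np.mean(err ** 2))
    # Pearson r is undefined when either map is constant over the valid pixels
    if p.size > 1 and p.std() > 0 and t.std() > 0:
        r = float(np.corrcoef(p, t)[0, 1])
    else:
        r = float("nan")
    return {
        "mse": mse,
        "rmse": mse ** 0.5,
        "mae": float(np.mean(np.abs(err))),
        "pearson_r": r,
        "n_pixels": int(valid.sum()),
    }
```

Usage: `los_metrics(los_magnetogram_pred, los_magnetogram_target)` reproduces the MSE and MAE above on fully finite maps and adds RMSE and the correlation.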

## Visualize Results

Visualize the predicted and target LOS magnetograms, and compare Stokes V profiles.
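LOS magnetograms are signed, so they are often shown on a diverging colormap whose midpoint sits at 0 G, which makes polarity readable from colour. Limits taken from a high percentile of |B| rather than the raw extremes also keep a few saturated pixels from compressing the scale. A minimal sketch; `symmetric_limits` is a hypothetical helper and the 99.5th percentile default is an arbitrary choice.

```python
import numpy as np


def symmetric_limits(*maps, percentile=99.5):
    """Return (-v, v) so a diverging colormap is centred on 0 G.

    v is the given percentile of |B| over all finite pixels of all maps.
    """
    values = np.concatenate([np.abs(np.asarray(m, dtype=float)).ravel() for m in maps])
    values = values[np.isfinite(values)]
    v = float(np.percentile(values, percentile)) if values.size else 1.0
    if v == 0.0:
        v = 1.0  # avoid a degenerate (0, 0) range for all-zero maps
    return -v, v
```

Usage: `vmin, vmax = symmetric_limits(los_magnetogram_pred, los_magnetogram_target)`, then pass `vmin`/`vmax` to both `imshow` calls so the two panels share one scale.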

In [None]:
if checkpoint_path and trained_model:
    # Create visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Plot LOS magnetograms on a symmetric range so 0 G maps to the colormap's midpoint
    vmax = max(np.abs(los_magnetogram_pred).max(), np.abs(los_magnetogram_target).max())
    vmin = -vmax
    
    im1 = axes[0, 0].imshow(los_magnetogram_pred, cmap='RdBu', vmin=vmin, vmax=vmax, origin='lower')
    axes[0, 0].set_title('Predicted LOS Magnetogram (G)', fontsize=14, fontweight='bold')
    axes[0, 0].axis('off')
    plt.colorbar(im1, ax=axes[0, 0], label='B_los (G)')
    
    im2 = axes[0, 1].imshow(los_magnetogram_target, cmap='RdBu', vmin=vmin, vmax=vmax, origin='lower')
    axes[0, 1].set_title('Target LOS Magnetogram (G)', fontsize=14, fontweight='bold')
    axes[0, 1].axis('off')
    plt.colorbar(im2, ax=axes[0, 1], label='B_los (G)')
    
    diff = los_magnetogram_pred - los_magnetogram_target
    dmax = np.abs(diff).max()  # symmetric range so zero error maps to the colormap's midpoint
    im3 = axes[0, 2].imshow(diff, cmap='RdBu', vmin=-dmax, vmax=dmax, origin='lower')
    axes[0, 2].set_title(f'Difference (Pred - Target)\nMAE: {los_mae:.2f} G', fontsize=14, fontweight='bold')
    axes[0, 2].axis('off')
    plt.colorbar(im3, ax=axes[0, 2], label='ΔB_los (G)')
    
    # Plot Stokes V profiles for selected pixels
    H, W = los_magnetogram_pred.shape
    center_h, center_w = H // 2, W // 2
    
    pixels_to_plot = [
        (center_h, center_w, 'Center'),
        (center_h + 50 if center_h + 50 < H else center_h, center_w, 'Offset 1'),
        (center_h, center_w + 50 if center_w + 50 < W else center_w, 'Offset 2'),
    ]
    
    for idx, (h, w, label) in enumerate(pixels_to_plot):
        if h < pred_V.shape[1] and w < pred_V.shape[2]:
            axes[1, idx].plot(wavelengths, pred_V[:, h, w], 'b-', label='Predicted V', linewidth=2, alpha=0.8)
            axes[1, idx].plot(wavelengths, target_V[:, h, w], 'r--', label='Target V', linewidth=2, alpha=0.8)
            axes[1, idx].set_xlabel('Wavelength (Å)', fontsize=12)
            axes[1, idx].set_ylabel('Stokes V', fontsize=12)
            axes[1, idx].set_title(f'Stokes V Profile - {label}\nLOS: Pred={los_magnetogram_pred[h, w]:.1f}G, Target={los_magnetogram_target[h, w]:.1f}G', 
                                  fontsize=12, fontweight='bold')
            axes[1, idx].legend(fontsize=10)
            axes[1, idx].grid(True, alpha=0.3)
            axes[1, idx].axvline(6173.15, color='gray', linestyle=':', alpha=0.5)
    
    plt.tight_layout()
    Path("./inference_results").mkdir(exist_ok=True)  # make sure the output folder exists
    plt.savefig("./inference_results/los_magnetogram_comparison.png", dpi=150, bbox_inches='tight')
    print("✅ Visualization saved to ./inference_results/los_magnetogram_comparison.png")
    plt.show()
else:
    print("⚠️  Cannot visualize: model not loaded")

## Save Results

Save the LOS magnetograms and comparison results.

In [None]:
if checkpoint_path and trained_model:
    output_dir = Path("./inference_results")
    output_dir.mkdir(exist_ok=True)
    
    # Save LOS magnetograms
    np.save(output_dir / "los_magnetogram_predicted.npy", los_magnetogram_pred)
    np.save(output_dir / "los_magnetogram_target.npy", los_magnetogram_target)
    np.save(output_dir / "los_magnetogram_difference.npy", diff)
    
    # Save Stokes profiles
    np.save(output_dir / "stokes_I_predicted.npy", pred_I)
    np.save(output_dir / "stokes_V_predicted.npy", pred_V)
    np.save(output_dir / "stokes_I_target.npy", target_I)
    np.save(output_dir / "stokes_V_target.npy", target_V)
    np.save(output_dir / "wavelengths.npy", wavelengths)
    
    print("✅ Results saved to ./inference_results/")
    print("   - LOS magnetograms (predicted, target, difference)")
    print("   - Stokes I and V profiles (predicted and target)")
    print("   - Wavelength array")
else:
    print("⚠️  Cannot save results: model not loaded")

## Summary

This notebook demonstrates:
1. ✅ Loading the pre-trained Stokes profile model
2. ✅ Running inference on the training set
3. ✅ Converting the predicted I and V profiles to a line-of-sight magnetogram using integration
4. ✅ Visualizing and comparing the results

The integration method converts Stokes I and V profiles to an LOS magnetogram using:
- The area under the V profile relative to the I profile
- A scaling based on the effective Landé factor and the rest wavelength

It is a simple approximation, suitable for initial testing.
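The body of `stokes_to_los_magnetogram` is not shown in this notebook, so the following is not its implementation. As an illustration of an integral-based estimator of the kind described above, here is a minimal sketch of the centre-of-gravity (COG) method. It computes the wavelength centroids of the line depressions of $I+V$ and $I-V$,
$\lambda_\pm = \int \lambda\,[I_c - (I \pm V)]\,d\lambda \,/\, \int [I_c - (I \pm V)]\,d\lambda$,
and converts their separation with $B_\mathrm{los} \approx (\lambda_+ - \lambda_-) / (2 \cdot 4.67\times10^{-13}\, g_\mathrm{eff}\, \lambda_0^2)$, with $\lambda$ in Å and $B$ in G. Assumptions: wavelengths are in Å, the continuum $I_c$ defaults to the brightest sample of each pixel, and $B_\mathrm{los} > 0$ when $I+V$ is redshifted relative to $I-V$ (sign conventions differ between instruments). `cog_blos` is a hypothetical name.

```python
import numpy as np

# Zeeman splitting constant e / (4 pi m_e c^2) in Å^-1 G^-1 (cgs), so that
# Δλ_B [Å] = ZEEMAN_CONST * g_eff * λ0[Å]^2 * B[G]
ZEEMAN_CONST = 4.6686e-13


def _trapezoid(f, x):
    """Trapezoidal integral of f along axis 0 (the wavelength axis)."""
    dx = np.diff(x).reshape((-1,) + (1,) * (f.ndim - 1))
    return np.sum(0.5 * (f[1:] + f[:-1]) * dx, axis=0)


def cog_blos(stokes_I, stokes_V, wavelengths, lambda_rest=6173.15, geff=2.5,
             I_cont=None, eps=1e-12):
    """Centre-of-gravity estimate of B_los (G) from I and V of shape [n_wl, ...].

    Assumptions: wavelengths in Å; I_cont defaults to the brightest sample per
    pixel; B_los > 0 when the I+V line is redshifted relative to I-V.
    Pixels with (near-)zero line depression return NaN.
    """
    wl = np.asarray(wavelengths, dtype=float)
    I = np.asarray(stokes_I, dtype=float)
    V = np.asarray(stokes_V, dtype=float)
    if I_cont is None:
        I_cont = I.max(axis=0)
    # Offsets from λ0, broadcast over pixels (better precision than raw λ ~ 6173 Å)
    dl = (wl - lambda_rest).reshape((-1,) + (1,) * (I.ndim - 1))

    shifts = []
    for sign in (+1.0, -1.0):
        depression = I_cont - (I + sign * V)          # absorption depth of I+V or I-V
        area = _trapezoid(depression, wl)
        safe_area = np.where(np.abs(area) > eps, area, np.nan)
        shifts.append(_trapezoid(dl * depression, wl) / safe_area)  # centroid - λ0
    shift_plus, shift_minus = shifts
    return (shift_plus - shift_minus) / (2.0 * ZEEMAN_CONST * geff * lambda_rest ** 2)
```

Usage: `los_cog = cog_blos(pred_I, pred_V, wavelengths)` gives an independent estimate to compare against `los_magnetogram_pred`. With only a handful of wavelength samples the trapezoidal integrals are coarse, so agreement in sign and morphology is a more useful check than exact values.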

**Note:** For production use, more sophisticated inversion methods (like the Milne-Eddington inversion) should be used for better accuracy.