# Case Study: Model Predictions Visualization

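This notebook visualizes 24-hour forecast comparisons for different nodes and days.
It loads prediction data from `.npy` files and draws a 2x4 grid of subplots, followed by a 4x4 grid of predictions rescaled to the Ground Truth statistics and a 2x4 selection of panels taken from that grid.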

## Configuration Parameters

Adjust these parameters to change data loading and slicing behavior.

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path

# =============================================================================
# Configuration Parameters
# =============================================================================

# Folder holding the saved arrays, resolved against the current working directory
RECORD_DIR = Path.cwd().joinpath("record")

# Legend label -> file-name prefix. Insertion order is the order in which the
# models are loaded and drawn.
MODELS = {
    "Ground Truth": "TRUE",
    "MODE": "MODE",
    "S_Mamba": "SMAM",
    "iTransformer": "ITRANS",
    "PatchTST": "PATST",
}

# Line colour for each model
COLORS = {
    "Ground Truth": "#A0A0A0",   # Gray
    "MODE": "#2D22FF",           # Blue
    "S_Mamba": "#00AAEE",        # Light blue
    "iTransformer": "#3FB704",   # Green
    "PatchTST": "#EFBF00",       # Yellow
}

# Ground Truth is dashed; every model prediction is a solid line
LINE_STYLES = {
    name: ("--" if name == "Ground Truth" else "-")
    for name in MODELS
}

# All curves share the same width
LINE_WIDTHS = dict.fromkeys(MODELS, 1.5)

# Z-order for plotting (higher value = drawn on top): listed from top to bottom
Z_ORDERS = {
    name: 5 - position
    for position, name in enumerate(
        ("MODE", "Ground Truth", "S_Mamba", "iTransformer", "PatchTST")
    )
}

# Data slicing configuration
# Day segmentation for different prediction lengths
# Each day is 96 time steps (24 hours at 15-minute intervals)
# pred_len=96: 1 day available
# pred_len=192: 2 days available (0,96) and (96,192)
# pred_len=384: 4 days available (0,96), (96,192), (192,288), (288,384)

def get_time_window(pred_len, day):
    '''Return the slice bounds of one 24-hour day inside a forecast horizon.

    A day spans 96 time steps, so a horizon of 96, 192 or 384 steps holds
    1, 2 or 4 consecutive days respectively.

    Args:
        pred_len: The prediction length (96, 192, or 384)
        day: Day index counted from the start of the sample
             (0 for Day 1, 1 for Day 2, 2 for Day 3, 3 for Day 4)

    Returns:
        start_idx, end_idx: The start and end indices for slicing

    Raises:
        ValueError: If pred_len is not supported, or the requested day lies
            outside the horizon.
    '''
    # (horizon length, number of whole days it contains)
    horizons = ((96, 1), (192, 2), (384, 4))
    # Slice bounds of Day 1 .. Day 4, identical for every horizon
    day_windows = ((0, 96), (96, 192), (192, 288), (288, 384))

    for horizon, n_days in horizons:
        if pred_len == horizon:
            break
    else:
        raise ValueError(f'Unsupported pred_len: {pred_len}')

    for index in range(n_days):
        if day == index:
            return day_windows[index]

    unit = 'day' if n_days == 1 else 'days'
    raise ValueError(f'Day {day+1} not available in pred_len={pred_len} (only {n_days} {unit})')

## Define Experiments

Each experiment specifies which sample, day, and node to visualize.

### Data Sampling Logic:
- All pred_len files (96, 192, 384) start from the same position in the test set
- TRUE_96[sample] contains data for Day N (24 hours)
- TRUE_192[sample, :96] contains data for Day N (first 24 hours)
- TRUE_192[sample, 96:192] contains data for Day N+1 (next 24 hours)
- TRUE_384[sample, :96] contains data for Day N (first 24 hours)
- TRUE_384[sample, 96:192] contains data for Day N+1 (second 24 hours)
- TRUE_384[sample, 192:288] contains data for Day N+2 (third 24 hours)
- TRUE_384[sample, 288:384] contains data for Day N+3 (fourth 24 hours)

### Experiment Design:
We test 4 nodes (variables), each with 2 days for comparison.
- Node 1 (HUFL): Day 1 and Day 3
- Node 2 (HULL): Day 2 and Day 3
- Node 3 (MUFL): Day 2 and Day 3
- Node 4 (MULL): Day 1 and Day 4

For each pair, we use the same sample_idx to ensure temporal alignment.

In [None]:
# Experiments for the 2x4 grid, listed row by row (4 columns per row).
# Each entry is (sample_idx, day_offset, node_index):
# - sample_idx: Starting position in test set (all files share same start);
#   it is 0 for every experiment below
# - day_offset: 0=Day 1, 1=Day 2, 2=Day 3, 3=Day 4
# - node_index: 0=HUFL, 1=HULL, 2=MUFL, 3=MULL
#
# The pred_len noted beside each pair is the file the plotting helper reads
# for that day offset (0 -> 96, 1 -> 192, 2 or 3 -> 384).
EXPERIMENTS = [
    (0, day_offset, node_index)
    for day_offset, node_index in (
        # Row 0: Experiments 0-3
        (0, 0),   # Exp 0: Day 1, Node 1 (pred_len=96)
        (2, 0),   # Exp 1: Day 3, Node 1 (pred_len=384)
        (1, 1),   # Exp 2: Day 2, Node 2 (pred_len=192)
        (2, 1),   # Exp 3: Day 3, Node 2 (pred_len=384)

        # Row 1: Experiments 4-7
        (1, 2),   # Exp 4: Day 2, Node 3 (pred_len=192)
        (2, 2),   # Exp 5: Day 3, Node 3 (pred_len=384)
        (0, 3),   # Exp 6: Day 1, Node 4 (pred_len=96)
        (3, 3),   # Exp 7: Day 4, Node 4 (pred_len=384)
    )
]

## Helper Functions

Define functions for loading data and creating plots.

In [None]:
def get_time_window(pred_len, day):
    '''Locate one day's block of time steps within a pred_len-step forecast.

    Every file starts at the same test-set position, so Day N always occupies
    the first 96 steps, Day N+1 the next 96, and so on up to the horizon.

    Args:
        pred_len: The prediction length (96, 192, or 384)
        day: Day index (0 for Day 1, 1 for Day 2, 2 for Day 3, 3 for Day 4),
             i.e. the offset from the sample's starting position.

    Returns:
        start_idx, end_idx: The start and end indices for slicing

    Raises:
        ValueError: If pred_len is not one of the supported lengths, or if
            the horizon does not contain the requested day.
    '''
    steps_per_day = 96
    supported = (96, 192, 384)

    matching = [length for length in supported if pred_len == length]
    if not matching:
        raise ValueError(f'Unsupported pred_len: {pred_len}')

    # Whole days contained in the matched horizon (1, 2 or 4)
    day_count = matching[0] // steps_per_day

    for offset in range(day_count):
        if day == offset:
            return offset * steps_per_day, (offset + 1) * steps_per_day

    noun = 'days' if day_count > 1 else 'day'
    raise ValueError(f'Day {day+1} not available in pred_len={pred_len} (only {day_count} {noun})')

In [None]:
def load_predictions(pred_len):
    '''Read the saved array of every configured model for one prediction length.

    Ground Truth lives in ``<prefix>_<pred_len>.npy``; every other model in
    ``<prefix>_<pred_len>_pred.npy``. Both are looked up inside RECORD_DIR.

    Args:
        pred_len: Prediction length used in the file names.

    Returns:
        Dict mapping each model name in MODELS to its loaded array, or to
        None when the file does not exist (a warning line is printed).
    '''
    loaded = {}

    for model_name, model_prefix in MODELS.items():
        is_ground_truth = model_name == 'Ground Truth'
        filename = (
            RECORD_DIR / f'{model_prefix}_{pred_len}.npy'
            if is_ground_truth
            else RECORD_DIR / f'{model_prefix}_{pred_len}_pred.npy'
        )

        # Missing files are reported and recorded as None instead of raising
        if not filename.exists():
            print(f'⚠ File not found: {filename}')
            loaded[model_name] = None
            continue

        loaded[model_name] = np.load(filename)
        print(f'✓ Loaded {model_name} predictions from {filename}')

    return loaded

In [None]:
def create_subplot(ax, sample_idx, day, node, predictions_all):
    '''Create a single subplot showing predictions for a specific day and node.

    Args:
        ax: Matplotlib axes to draw on.
        sample_idx: Index along the first axis of each loaded array. It is
            clamped to the last available entry of each array.
        day: Day offset (0=Day 1, 1=Day 2, 2=Day 3, 3=Day 4). Day 1 is read
            from the pred_len=96 file, Day 2 from pred_len=192 and later days
            from pred_len=384.
        node: Index along the last axis of each loaded array.
        predictions_all: Dict mapping pred_len to a dict of
            {model name: array or None}.

    Returns:
        List of model names that have data for the chosen pred_len. The list
        is empty when nothing was loaded; the axes then only shows a
        'No data available' note.
    '''

    # Determine which pred_len file to use based on day
    if day == 0:
        pred_len = 96
        title_day = 1
    elif day == 1:
        pred_len = 192
        title_day = 2
    else:
        pred_len = 384
        title_day = day + 1

    # Get correct predictions for this pred_len
    predictions = predictions_all.get(pred_len, {})

    # Find available models for this pred_len
    available_models = [m for m, data in predictions.items() if data is not None]

    if not available_models:
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center', transform=ax.transAxes, fontsize=5, color='#404040')
        # Return the (empty) list rather than None so every path yields a list
        return available_models

    # Get time window for this day
    start_idx, end_idx = get_time_window(pred_len, day)
    time_steps = end_idx - start_idx

    # Create time axis (0:00 to 24:00 hours)
    hours = np.linspace(0, 24, time_steps)

    # Plot each model
    for model_name in available_models:
        data = predictions[model_name]

        # Get values for the specified node and time window; an out-of-range
        # sample index falls back to the last sample of this array
        actual_sample_idx = min(sample_idx, data.shape[0] - 1)
        values = data[actual_sample_idx, start_idx:end_idx, node]

        ax.plot(hours, values,
                color=COLORS[model_name],
                linestyle=LINE_STYLES[model_name],
                label=model_name,
                linewidth=1.2 if model_name == 'Ground Truth' else 1.0,
                zorder=Z_ORDERS.get(model_name, 1))

    # Subplot title with node/day/pred_len
    ax.set_title(f'Node {node + 1} | Day {title_day} | Pred Len {pred_len}', fontsize=7, color='#404040', pad=6)

    # Spines
    for spine in ax.spines.values():
        spine.set_color('#404040')
        spine.set_linewidth(0.6)

    ax.tick_params(axis='both', which='both', colors='#404040', labelsize=5, width=0.6)

    ax.grid(True, color='#d0d0d0', linestyle='-', alpha=0.35)

    ax.set_xlim(0, 24)
    ax.set_xticks(np.arange(0, 25, 4))

    return available_models

## Load Data

Load prediction data for all required prediction lengths (96, 192, 384).

In [None]:
# Load the arrays of every forecast horizon used by the figures below
print('Loading prediction data for different time horizons...')
required_pred_lens = {96, 192, 384}
predictions_all = {}

for pred_len in required_pred_lens:
    predictions_all[pred_len] = predictions = load_predictions(pred_len)
    # Report which models were actually found for this horizon
    available_models = [name for name in predictions if predictions[name] is not None]
    print(f'  Pred_len {pred_len}: Available models = {available_models}')

In [None]:
# Print the array shape of every model that was found, grouped by horizon
for pred_len, preds in predictions_all.items():
    print(f"Pred Len: {pred_len}")
    for model, data in preds.items():
        if data is None:
            # File was missing for this model; nothing to report
            continue
        print(f"  {model}: {data.shape}")
## Create Visualization

Generate the 2x4 subplot grid with predictions for each node and day.

In [None]:
# Create figure with 2x4 subplots (2 rows, 4 columns)
fig, axes = plt.subplots(2, 4, figsize=(12, 4.5), dpi=300)

legend_models = None
for exp_idx, (sample_idx, day, node) in enumerate(EXPERIMENTS):
    # Experiments are laid out row-major, four per row
    row, col = divmod(exp_idx, 4)
    ax = axes[row, col]

    available_models = create_subplot(ax, sample_idx, day, node, predictions_all)

    # Keep every subplot free of its own axis labels: the figure-level labels
    # added below are the single occurrence. (Labelling the bottom-left
    # subplot as well would draw 'Time (Hours)' twice.)
    ax.set_ylabel('')
    ax.set_xlabel('')

    # The first subplot that has data decides which models the legend lists
    if legend_models is None and available_models:
        legend_models = available_models

# Shared axis labels (single occurrence for clarity)
fig.supylabel('Value', fontsize=7, color='#404040', x=0.02)
fig.supxlabel('Time (Hours)', fontsize=7, color='#404040', y=0.04)

# Add legend
if legend_models:
    handles, labels = [], []
    # Legend order: Ground Truth first, then MODE, then the remaining models
    sorted_models = []
    if 'Ground Truth' in legend_models:
        sorted_models.append('Ground Truth')
    if 'MODE' in legend_models:
        sorted_models.append('MODE')
    for m in legend_models:
        if m not in sorted_models:
            sorted_models.append(m)

    # Proxy lines that mirror the styling used inside the subplots
    for model_name in sorted_models:
        line = plt.Line2D([0], [0], color=COLORS[model_name],
                        linestyle=LINE_STYLES[model_name],
                        linewidth=1.2 if model_name == 'Ground Truth' else 1.0)
        handles.append(line)
        labels.append(model_name)

    leg = fig.legend(handles, labels, loc='lower center', ncol=len(sorted_models),
               bbox_to_anchor=(0.5, -0.06), fontsize=6, frameon=False,
               columnspacing=1.0, handlelength=2.0)

    # Bold 'MODE' in legend
    for text in leg.get_texts():
        if text.get_text() == 'MODE':
            text.set_weight('bold')

plt.tight_layout()
# Applied after tight_layout so the bottom margin reserved here is what is used
plt.subplots_adjust(bottom=0.22, wspace=0.12, hspace=0.25)

# Display the plot (instead of saving to file)
plt.show()

## 4x4 Grid Visualization with Scaled Predictions

This visualization shows predictions scaled using Ground Truth statistics for each day-node combination.

**Scaling method**: For each subplot (day d, node n):
1. Calculate GT mean and std for that specific day-node
2. Calculate prediction mean and std for that specific day-node  
3. Apply transformation: `scaled_pred = pred × scale + shift`
   - `scale = std(GT) / std(pred)`
   - `shift = mean(GT) - scale × mean(pred)`

This ensures that each subplot's prediction curve has the same mean and standard deviation as its Ground Truth, making them directly comparable.

In [None]:
# Experiments for the 4x4 grid as (sample_idx, day, node) tuples.
# Layout: 4 rows x 4 columns (Days x Nodes), generated row by row:
#   rows    -> day 0..3  (Day 1 .. Day 4)
#   columns -> node 0..3 (Node 1 .. Node 4)
# Every experiment reads sample 0, so the panel at (row, col) is stored at
# list position row * 4 + col.
EXPERIMENTS_4X4 = [
    (0, day, node)
    for day in range(4)
    for node in range(4)
]

def get_scaled_predictions(sample_idx, day, node, pred_slice, gt_slice):
    '''Rescale a prediction window to the mean and spread of its Ground Truth.

    The transformation is ``scaled = pred * scale + shift`` with
    ``scale = std(GT) / std(pred)`` and ``shift = mean(GT) - scale * mean(pred)``
    (formula taken from assets/9_案例分析.md, the case-analysis notes).

    Args:
        sample_idx: Sample index in the data (not used by the computation)
        day: Day offset (0=Day 1, 1=Day 2, 2=Day 3, 3=Day 4) (not used)
        node: Node index (0-3) (not used)
        pred_slice: Model predictions for this time window
        gt_slice: Ground Truth for this time window

    Returns:
        scaled_pred: Scaled predictions
        scale: Scaling factor used
        shift: Shift value used
    '''
    target_mean, target_std = np.mean(gt_slice), np.std(gt_slice)
    source_mean, source_std = np.mean(pred_slice), np.std(pred_slice)

    if not source_std > 1e-10:
        # (Nearly) constant prediction: dividing by its spread is not
        # possible, so only move it onto the Ground Truth mean.
        offset = target_mean - source_mean
        return pred_slice + offset, 1.0, offset

    factor = target_std / source_std
    offset = target_mean - factor * source_mean
    return pred_slice * factor + offset, factor, offset

def create_subplot_scaled(ax, sample_idx, day, node, predictions_all):
    '''Create a single subplot with SCALED predictions for a specific day and node.

    Every model other than Ground Truth is rescaled with
    get_scaled_predictions against the Ground Truth slice of the same sample,
    day and node. Ground Truth itself is drawn unchanged.

    Args:
        ax: Matplotlib axes to draw on.
        sample_idx: Index along the first axis of each loaded array. It is
            clamped to the last available entry of each array.
        day: Day offset (0=Day 1, 1=Day 2, 2=Day 3, 3=Day 4). Day 1 is read
            from the pred_len=96 file, Day 2 from pred_len=192 and later days
            from pred_len=384.
        node: Index along the last axis of each loaded array.
        predictions_all: Dict mapping pred_len to a dict of
            {model name: array or None}.

    Returns:
        List of model names that have data for the chosen pred_len. Nothing
        is plotted when the list is empty or when Ground Truth is missing.
    '''

    # Determine which pred_len file to use based on day
    if day == 0:
        pred_len = 96
        title_day = 1
    elif day == 1:
        pred_len = 192
        title_day = 2
    else:
        pred_len = 384
        title_day = day + 1

    # Get correct predictions for this pred_len
    predictions = predictions_all.get(pred_len, {})

    # Find available models for this pred_len
    available_models = [m for m, data in predictions.items() if data is not None]

    if not available_models:
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=5, color='#404040')
        return available_models

    # Get time window for this day
    start_idx, end_idx = get_time_window(pred_len, day)
    time_steps = end_idx - start_idx

    # Create time axis (0:00 to 24:00 hours)
    hours = np.linspace(0, 24, time_steps)

    # Scaling needs a reference curve; without Ground Truth nothing is drawn
    if 'Ground Truth' not in available_models:
        print(f"Warning: No Ground Truth for pred_len={pred_len}")
        return available_models

    # Clamp the sample index exactly as is done for the plotted curves below,
    # so the scaling reference is the same Ground Truth slice that is drawn.
    gt_data = predictions['Ground Truth']
    gt_sample_idx = min(sample_idx, gt_data.shape[0] - 1)
    gt_slice = gt_data[gt_sample_idx, start_idx:end_idx, node]

    # Plot each model (including Ground Truth)
    for model_name in available_models:
        data = predictions[model_name]

        # Get values for the specified node and time window
        actual_sample_idx = min(sample_idx, data.shape[0] - 1)
        values = data[actual_sample_idx, start_idx:end_idx, node]

        # Apply scaling if not Ground Truth
        if model_name == 'Ground Truth':
            plot_values = values
        else:
            plot_values, _, _ = get_scaled_predictions(
                sample_idx, day, node, values, gt_slice
            )

        ax.plot(hours, plot_values,
                color=COLORS[model_name],
                linestyle=LINE_STYLES[model_name],
                label=model_name,
                linewidth=LINE_WIDTHS.get(model_name, 1.0),
                zorder=Z_ORDERS.get(model_name, 1))

    # Subplot title with node/day/pred_len
    ax.set_title(f'Node {node + 1} | Day {title_day} | Pred Len {pred_len}',
                 fontsize=7, color='#404040', pad=6)

    # Spines
    for spine in ax.spines.values():
        spine.set_color('#404040')
        spine.set_linewidth(0.6)

    ax.tick_params(axis='both', which='both', colors='#404040', labelsize=5, width=0.6)

    ax.grid(True, color='#d0d0d0', linestyle='-', alpha=0.35)

    ax.set_xlim(0, 24)
    ax.set_xticks(np.arange(0, 25, 4))

    # Only show y-label on leftmost column
    if node == 0:
        ax.set_ylabel('Value', fontsize=6, color='#404040')

    # Only show x-label on bottom row
    if day == 3:  # Only bottom row
        ax.set_xlabel('Time (Hours)', fontsize=6, color='#404040')

    return available_models

# 4x4 canvas: one row per day, one column per node
fig, axes = plt.subplots(4, 4, figsize=(14, 12), dpi=300)

legend_models = None

for sample_idx, day, node in EXPERIMENTS_4X4:
    # The day offset selects the row and the node index selects the column
    ax = axes[day, node]
    available_models = create_subplot_scaled(ax, sample_idx, day, node, predictions_all)

    # The first panel that has data decides which models the legend lists
    if legend_models is None and available_models:
        legend_models = available_models

# Add shared labels
fig.supylabel('Value (Scaled)', fontsize=8, color='#404040', x=0.02)
fig.supxlabel('Time (Hours)', fontsize=8, color='#404040', y=0.04)

# Add legend
if legend_models:
    # Ground Truth first, then MODE, then the remaining models in the order
    # they were found; dict.fromkeys keeps the first occurrence of each name
    leading = [name for name in ('Ground Truth', 'MODE') if name in legend_models]
    sorted_models = list(dict.fromkeys(leading + list(legend_models)))

    # Proxy lines that mirror the styling used inside the subplots
    handles = [
        plt.Line2D([0], [0],
                   color=COLORS[model_name],
                   linestyle=LINE_STYLES[model_name],
                   linewidth=LINE_WIDTHS.get(model_name, 1.0))
        for model_name in sorted_models
    ]
    labels = list(sorted_models)

    leg = fig.legend(handles, labels, loc='lower center', ncol=min(len(sorted_models), 5),
                     bbox_to_anchor=(0.5, -0.03), fontsize=6, frameon=False,
                     columnspacing=1.0, handlelength=2.0)

    # Bold 'MODE' in legend
    for text in leg.get_texts():
        if text.get_text() == 'MODE':
            text.set_weight('bold')

plt.tight_layout()
# Applied after tight_layout so these margins are the ones that take effect
plt.subplots_adjust(bottom=0.08, wspace=0.15, hspace=0.25)

# Display the plot
plt.show()

In [None]:
# Extract selected subplots and redraw as 2x4 grid

# Based on the previous 4x4 scaled visualization
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Panels of the 4x4 grid to redraw, written as 1-indexed (row, column)
# coordinates and shifted to the 0-indexed form used for array access.
# Rows are days and columns are nodes.
SELECTED_SUBPLOTS = [
    (grid_row - 1, grid_col - 1)
    for grid_row, grid_col in (
        (1, 1),  # Day 1, Node 1
        (1, 3),  # Day 1, Node 3
        (2, 1),  # Day 2, Node 1
        (2, 3),  # Day 2, Node 3
        (3, 1),  # Day 3, Node 1
        (3, 2),  # Day 3, Node 2
        (3, 3),  # Day 3, Node 3
        (4, 4),  # Day 4, Node 4
    )
]

# Compact 2x4 canvas for the selected panels (height kept small on purpose)
fig, axes = plt.subplots(2, 4, figsize=(14, 5), dpi=300)

legend_models = None


def _format_tick(value, _pos):
    '''Label whole numbers without decimals and everything else with one.'''
    return f'{value:.1f}' if value % 1 != 0 else f'{int(value)}'


# Fill the new grid in reading order with the selected panels
for idx, ((orig_row, orig_col), ax) in enumerate(zip(SELECTED_SUBPLOTS, axes.flat)):
    # The 4x4 experiment list is stored row-major, four panels per row
    exp_idx = orig_row * 4 + orig_col
    sample_idx, day, node = EXPERIMENTS_4X4[exp_idx]

    # Day 1 comes from the 96-step file, Day 2 from the 192-step file and
    # later days from the 384-step file
    pred_len = {0: 96, 1: 192}.get(day, 384)
    title_day = day + 1

    predictions = predictions_all.get(pred_len, {})
    available_models = [name for name in predictions if predictions[name] is not None]

    if not available_models:
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=7, color='#404040')
        continue

    # Slice bounds of the requested day and a matching 0-24 h time axis
    start_idx, end_idx = get_time_window(pred_len, day)
    hours = np.linspace(0, 24, end_idx - start_idx)

    # Ground Truth is the reference for scaling; skip the panel without it
    if 'Ground Truth' not in available_models:
        print(f"Warning: No Ground Truth for pred_len={pred_len}")
        continue

    gt_data = predictions['Ground Truth']
    gt_slice = gt_data[min(sample_idx, gt_data.shape[0] - 1), start_idx:end_idx, node]

    for model_name in available_models:
        data = predictions[model_name]
        # An out-of-range sample index falls back to the array's last sample
        curve = data[min(sample_idx, data.shape[0] - 1), start_idx:end_idx, node]

        # Ground Truth is drawn as-is; predictions are rescaled to match it
        if model_name != 'Ground Truth':
            curve, _, _ = get_scaled_predictions(sample_idx, day, node, curve, gt_slice)

        ax.plot(hours, curve,
                color=COLORS[model_name],
                linestyle=LINE_STYLES[model_name],
                label=model_name,
                linewidth=LINE_WIDTHS.get(model_name, 1.5),
                zorder=Z_ORDERS.get(model_name, 1))

    ax.set_title(f'Day {title_day} | Node {node + 1} | Pred Len {pred_len}',
                 fontsize=9, color='#404040', pad=6)

    # Frame, ticks and grid styling
    for spine in ax.spines.values():
        spine.set_color('#404040')
        spine.set_linewidth(0.8)

    ax.tick_params(axis='both', which='both', colors='#404040', labelsize=7, width=0.8)
    ax.grid(True, color='#d0d0d0', linestyle='-', alpha=0.35)
    ax.set_xlim(0, 24)
    ax.set_xticks(np.arange(0, 25, 4))

    # --- Y-axis: exactly four ticks on multiples of 0.5 ---
    # Read the limits matplotlib chose after plotting
    ymin, ymax = ax.get_ylim()

    # A third of the range, rounded UP to a multiple of 0.5 (0.5 if the range is empty)
    step = np.ceil((ymax - ymin) / 3.0 * 2) / 2.0
    if step == 0:
        step = 0.5

    # First tick: the multiple of 0.5 at or below ymin
    start = np.floor(ymin * 2) / 2.0

    # Widen the spacing until the fourth tick reaches ymax
    while start + 3 * step < ymax:
        step += 0.5

    y_ticks = [start, start + step, start + 2 * step, start + 3 * step]
    ax.set_yticks(y_ticks)
    ax.set_ylim(y_ticks[0], y_ticks[-1])
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(_format_tick))
    # -------------------------------------------------------

    # Position of this panel inside the 2x4 grid
    grid_row, grid_col = divmod(idx, 4)

    # y-label only on the leftmost column
    if grid_col == 0:
        ax.set_ylabel('Value', fontsize=8, color='#404040')

    # x-label only on the bottom row
    if grid_row >= 1:
        ax.set_xlabel('Time (Hours)', fontsize=8, color='#404040')

    # The first fully drawn panel decides which models the legend lists
    if legend_models is None:
        legend_models = available_models

# Figure-level legend shared by all panels
if legend_models:
    # Ground Truth first, then MODE, then the remaining models in the order
    # they were found; dict.fromkeys keeps the first occurrence of each name
    leading = [name for name in ('Ground Truth', 'MODE') if name in legend_models]
    sorted_models = list(dict.fromkeys(leading + list(legend_models)))

    # Proxy lines that mirror the styling used inside the subplots
    handles = [
        plt.Line2D([0], [0],
                   color=COLORS[model_name],
                   linestyle=LINE_STYLES[model_name],
                   linewidth=LINE_WIDTHS.get(model_name, 1.5))
        for model_name in sorted_models
    ]
    labels = list(sorted_models)

    # Anchored just below the axes so the legend sits close to the plots
    leg = fig.legend(handles, labels, loc='lower center', ncol=min(len(sorted_models), 5),
                     bbox_to_anchor=(0.5, -0.02), fontsize=8, frameon=False,
                     columnspacing=1.0, handlelength=2.0)

    # Bold 'MODE' in legend
    for text in leg.get_texts():
        if text.get_text() == 'MODE':
            text.set_weight('bold')

plt.tight_layout()
# Applied after tight_layout; the bottom margin leaves room for the legend
# while keeping the gap between legend and subplots small
plt.subplots_adjust(bottom=0.12, wspace=0.15, hspace=0.25)

# Display the plot
plt.show()

In [None]:
# Panels of the 4x4 grid to redraw, as 0-indexed (row, column) pairs.
# Rows are days and columns are nodes; the 1-indexed (day, node) reading is
# given beside each row.
SELECTED_SUBPLOTS = [
    (day_row, node_col)
    for day_row, node_cols in (
        (0, (0, 2)),     # Day 1: Node 1, Node 3
        (1, (0, 2)),     # Day 2: Node 1, Node 3
        (2, (0, 1, 2)),  # Day 3: Node 1, Node 2, Node 3
        (3, (3,)),       # Day 4: Node 4
    )
    for node_col in node_cols
]

# Compact 2x4 canvas for the selected panels (height kept small on purpose)
fig, axes = plt.subplots(2, 4, figsize=(14, 5), dpi=300)

legend_models = None

# Fill the new grid in reading order with the selected panels
for idx, ((orig_row, orig_col), ax) in enumerate(zip(SELECTED_SUBPLOTS, axes.flat)):
    # The 4x4 experiment list is stored row-major, four panels per row
    exp_idx = orig_row * 4 + orig_col
    sample_idx, day, node = EXPERIMENTS_4X4[exp_idx]

    # Day 1 comes from the 96-step file, Day 2 from the 192-step file and
    # later days from the 384-step file
    pred_len = 96 if day == 0 else 192 if day == 1 else 384
    title_day = day + 1

    predictions = predictions_all.get(pred_len, {})
    available_models = [name for name, arr in predictions.items() if arr is not None]

    if not available_models:
        ax.text(0.5, 0.5, 'No data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=7, color='#404040')
        continue

    # Slice bounds of the requested day and a matching 0-24 h time axis
    start_idx, end_idx = get_time_window(pred_len, day)
    window = slice(start_idx, end_idx)
    hours = np.linspace(0, 24, end_idx - start_idx)

    # Ground Truth is the reference for scaling; skip the panel without it
    if 'Ground Truth' not in available_models:
        print(f"Warning: No Ground Truth for pred_len={pred_len}")
        continue

    gt_data = predictions['Ground Truth']
    gt_row = min(sample_idx, gt_data.shape[0] - 1)
    gt_slice = gt_data[gt_row, window, node]

    for model_name in available_models:
        data = predictions[model_name]
        # An out-of-range sample index falls back to the array's last sample
        row_in_data = min(sample_idx, data.shape[0] - 1)
        plot_values = data[row_in_data, window, node]

        # Ground Truth is drawn as-is; predictions are rescaled to match it
        if model_name != 'Ground Truth':
            plot_values = get_scaled_predictions(
                sample_idx, day, node, plot_values, gt_slice
            )[0]

        ax.plot(hours, plot_values,
                color=COLORS[model_name],
                linestyle=LINE_STYLES[model_name],
                label=model_name,
                linewidth=LINE_WIDTHS.get(model_name, 1.5),
                zorder=Z_ORDERS.get(model_name, 1))

    ax.set_title(f'Day {title_day} | Node {node + 1} | Pred Len {pred_len}',
                 fontsize=9, color='#404040', pad=6)

    # Frame, ticks and grid styling
    for spine in ax.spines.values():
        spine.set_color('#404040')
        spine.set_linewidth(0.8)

    ax.tick_params(axis='both', which='both', colors='#404040', labelsize=7, width=0.8)
    ax.grid(True, color='#d0d0d0', linestyle='-', alpha=0.35)
    ax.set_xlim(0, 24)
    ax.set_xticks(np.arange(0, 25, 4))

    # --- Y-axis: exactly four ticks on multiples of 0.5 ---
    # Read the limits matplotlib chose after plotting
    ymin, ymax = ax.get_ylim()
    target_step = (ymax - ymin) / 3.0

    # Round the spacing UP to a multiple of 0.5 (0.5 if the range is empty)
    step = np.ceil(target_step * 2) / 2.0
    if step == 0:
        step = 0.5

    # First tick: the multiple of 0.5 at or below ymin
    start = np.floor(ymin * 2) / 2.0

    # Widen the spacing until the fourth tick reaches ymax
    while start + 3 * step < ymax:
        step += 0.5

    y_ticks = [start, start + step, start + 2 * step, start + 3 * step]
    ax.set_yticks(y_ticks)
    ax.set_ylim(y_ticks[0], y_ticks[-1])

    # Whole numbers are labelled without decimals, everything else with one
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda x, _: f'{x:.1f}' if x % 1 != 0 else f'{int(x)}')
    )
    # -------------------------------------------------------

    # y-label only on the leftmost column of the 2x4 grid
    if not idx % 4:
        ax.set_ylabel('Value', fontsize=8, color='#404040')

    # x-label only on the bottom row of the 2x4 grid
    if not idx < 4:
        ax.set_xlabel('Time (Hours)', fontsize=8, color='#404040')

    # The first fully drawn panel decides which models the legend lists
    if legend_models is None:
        legend_models = available_models

# Figure-level legend shared by all panels
if legend_models:
    # Ground Truth first, then MODE, then every other model in found order
    priority = ('Ground Truth', 'MODE')
    sorted_models = [name for name in priority if name in legend_models]
    for name in legend_models:
        if name in sorted_models:
            continue
        sorted_models.append(name)

    # One proxy line per model, styled like the curves in the subplots
    handles = []
    labels = []
    for model_name in sorted_models:
        handles.append(
            plt.Line2D([0], [0],
                       color=COLORS[model_name],
                       linestyle=LINE_STYLES[model_name],
                       linewidth=LINE_WIDTHS.get(model_name, 1.5))
        )
        labels.append(model_name)

    # Anchored just below the axes so the legend sits close to the plots
    leg = fig.legend(handles, labels, loc='lower center', ncol=min(len(sorted_models), 5),
                     bbox_to_anchor=(0.5, -0.02), fontsize=8, frameon=False,
                     columnspacing=1.0, handlelength=2.0)

    # Bold 'MODE' in legend
    bold_entries = [entry for entry in leg.get_texts() if entry.get_text() == 'MODE']
    for entry in bold_entries:
        entry.set_weight('bold')

plt.tight_layout()
# Applied after tight_layout; the bottom margin leaves room for the legend
# while keeping the gap between legend and subplots small
plt.subplots_adjust(bottom=0.12, wspace=0.15, hspace=0.25)

# Display the plot
plt.show()