In [None]:
import sys
import os
import subprocess
import importlib

def _import_name_from_requirement(package):
    """Return the bare name at the front of a pip requirement string.

    Cuts the string at the first version specifier (``==``, ``>=``, ``<=``,
    ``~=``, ``!=``, ``<``, ``>``), extras bracket (``[``) or environment
    marker (``;``), so ``"protobuf<3.21"`` yields ``"protobuf"``.
    """
    name = package
    for separator in ("==", ">=", "<=", "~=", "!=", "<", ">", "[", ";"):
        name = name.split(separator)[0]
    return name.strip()


def install_if_missing(package, import_name=None):
    """Install package if it's not already installed.

    Args:
        package: pip requirement string, e.g. ``"numpy==1.22.3"``.
        import_name: module name used to test for presence. When ``None`` it
            is derived from ``package`` by stripping the version specifier.

    Returns:
        ``True`` if the module imported, or if pip exited successfully;
        ``False`` if pip exited with an error.
    """
    if import_name is None:
        import_name = _import_name_from_requirement(package)

    # wandb is imported below, so pin protobuf first to avoid version conflicts.
    if import_name == "wandb":
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "protobuf<=3.20.3", "--quiet", "--user"])
        except (subprocess.CalledProcessError, OSError) as e:
            # Best-effort: the import attempt below decides what happens next.
            print(f"Note: Could not pin protobuf before importing wandb: {e}")

    try:
        importlib.import_module(import_name)
        print(f"✓ {package} already installed")
        return True
    except ImportError:
        print(f"Installing {package}...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet", "--user"])
            # Let the import system notice modules installed while running.
            importlib.invalidate_caches()
            print(f"✓ Installed {package}")
            return True
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to install {package}: {e}")
            return False
    except Exception as e:
        # The module exists but fails while importing: try a clean reinstall.
        print(f"✗ Error importing {import_name}: {e}")
        print(f"  Attempting to reinstall...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--force-reinstall", "--quiet", "--user"])
            importlib.invalidate_caches()
            print(f"✓ Reinstalled {package}")
            return True
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to reinstall {package}: {e}")
            return False

# First, ensure protobuf is at a compatible version to avoid wandb issues
print("Setting up temporary environment...")
print("=" * 60)
print("Ensuring compatible protobuf version...")

try:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "protobuf<=3.20.3", "--quiet", "--user"])
    print("✓ Set protobuf to compatible version")
except (subprocess.CalledProcessError, OSError):
    # Best-effort step: pip failing (or not being runnable) must not abort the
    # cell, but Ctrl-C and other non-pip errors are no longer swallowed.
    print("Note: Could not adjust protobuf version, continuing...")

print("-" * 60)

# Install required packages with specific versions to avoid conflicts.
# Each entry is (pip requirement, module name used to test for presence).
packages = [
    ("protobuf<=3.20.3", "google.protobuf"),  # Force compatible version first
    ("torch", "torch"),
    ("numpy==1.22.3", "numpy"),
    ("pandas==1.4.2", "pandas"),
    ("matplotlib==3.5.2", "matplotlib"),
    ("seaborn==0.11.2", "seaborn"),
    ("tqdm==4.64.0", "tqdm"),
    ("munch==2.5.0", "munch"),
    ("pyyaml==6.0", "yaml"),
    ("transformers==4.17.0", "transformers"),
    ("wandb==0.12.11", "wandb"),
    ("scikit-learn==1.0.2", "sklearn"),
]

# Install packages in order, remembering which ones reported success
successful_installs = []
for package, import_name in packages:
    if install_if_missing(package, import_name):
        successful_installs.append(import_name)

print("-" * 60)

# Try a workaround for the protobuf issue
print("Attempting to fix protobuf descriptor issue...")
try:
    # Drop cached protobuf/wandb modules so a later import loads them afresh.
    # Iterate over a snapshot because entries are deleted inside the loop.
    for module in list(sys.modules):
        if 'google.protobuf' in module or 'wandb' in module:
            del sys.modules[module]

    # Set environment variable as suggested in error message
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
    print("✓ Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python")
except Exception as e:
    print(f"Note: Could not apply workaround: {e}")

print("-" * 60)

# Locate the project's src/ directory under the user's home directory.
USER_HOME = os.path.expanduser("~")
PROJECT_DIR = os.path.join(USER_HOME, "icl-time-series")
SRC_DIR = os.path.join(PROJECT_DIR, "src")

# Put src/ at the front of sys.path unless it is already listed.
if SRC_DIR in sys.path:
    print(f"✓ src directory already in path")
else:
    sys.path.insert(0, SRC_DIR)
    print(f"✓ Added src directory: {SRC_DIR}")

# Verify imports with error handling.
# Maps each module to the global name it is bound to in this notebook.
# NOTE(review): 'Munch' and 'tqdm' end up bound to the *modules* munch and
# tqdm, not to the Munch class or the tqdm function.
print("\nVerifying imports...")
required_imports = {
    'torch': 'torch',
    'numpy': 'np',
    'pandas': 'pd',
    'matplotlib.pyplot': 'plt',
    'seaborn': 'sns',
    'munch': 'Munch',
    'yaml': 'yaml',
    'tqdm': 'tqdm',
    'sklearn': 'sklearn'
}

successful_imports = []
for module, alias in required_imports.items():
    try:
        # import_module loads and returns the named module itself, including
        # dotted submodules such as matplotlib.pyplot; importing only the
        # parent package does not guarantee the submodule attribute exists.
        globals()[alias] = importlib.import_module(module)
    except ImportError as e:
        print(f"✗ Failed to import {module}: {e}")
    except Exception as e:
        print(f"⚠ Unexpected error importing {module}: {e}")
    else:
        successful_imports.append(module)
        print(f"✓ {module} imported successfully")

# Final report: package/import counts followed by manual recovery hints.
# One print call with a newline separator emits the same lines as printing
# each of them individually.
print(
    "\n" + "=" * 60,
    "Setup Summary:",
    f"- Attempted to install: {len(packages)} packages",
    f"- Successfully imported: {len(successful_imports)}/{len(required_imports)} core modules",
    "\nIf you continue to have wandb/protobuf issues:",
    "1. Restart the kernel and run only this cell",
    "2. Or try: pip install wandb==0.12.11 protobuf==3.20.3",
    "3. Or downgrade: pip install protobuf==3.20.3",
    "=" * 60,
    sep="\n",
)

In [None]:
import sys
import os

# Resolve ~/icl-time-series/src from the user's home directory
# (works for ORCD cluster).
USER_HOME = os.path.expanduser("~")
PROJECT_DIR = os.path.join(USER_HOME, "icl-time-series")
SRC_DIR = os.path.join(PROJECT_DIR, "src")

# Make the project's modules importable by putting src/ first on sys.path.
if SRC_DIR in sys.path:
    print(f"✓ src directory already in path: {SRC_DIR}")
else:
    sys.path.insert(0, SRC_DIR)
    print(f"✓ Added src directory to path: {SRC_DIR}")

# Report whether a conda environment is active (CONDA_PREFIX is set and non-empty).
CONDA_PREFIX = os.environ.get('CONDA_PREFIX')
if not CONDA_PREFIX:
    print("⚠ Warning: Not in a conda environment. Make sure to activate 'in-context-learning' first.")
    print(f"  Current Python: {sys.executable}")
else:
    print(f"✓ Running in conda environment: {CONDA_PREFIX}")
    print(f"  Python version: {sys.version}")
    print(f"  Python executable: {sys.executable}")

# Smoke-test the path setup by importing a project module.
# NOTE: on success this binds the name `eval` to the project module, which
# shadows the builtin eval() in this namespace.
try:
    import eval
except ImportError as e:
    print(f"✗ Failed to import eval module: {e}")
    print(f"  Make sure you're in the conda environment and src/ directory is correct")
else:
    print("✓ Successfully imported eval module")

In [None]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Plot styling for the notebook: seaborn's "notebook" context on the
# "darkgrid" style, plus the colorblind-friendly palette kept for later use.
sns.set_theme(context='notebook', style='darkgrid')
palette = sns.color_palette(palette='colorblind')

  import pkg_resources


In [None]:
# Run configuration: edit the values below to point at your own training run.
USER_HOME = os.path.expanduser("~")
PROJECT_DIR = os.path.join(USER_HOME, "icl-time-series")

# Directory that holds the saved runs for this task: ~/models/group_mixture_linear.
# Note that it is resolved against the home directory, not against PROJECT_DIR;
# point it elsewhere if your models are stored in another location.
run_dir = os.path.join(USER_HOME, "models", "group_mixture_linear")

# Task name and the run to analyse; run_id is the identifier printed in the
# training output.
task = "group_mixture_linear"
run_id = "0092300d-259b-46b0-bfa7-2117e360f05d"

run_path = os.path.join(run_dir, run_id)
recompute_metrics = False

print(f"Run directory: {run_dir}")
print(f"Run path: {run_path}")
print(f"Task: {task}")

# Metrics are normally precomputed at the end of training, so recomputing
# them is opt-in.
if not recompute_metrics:
    print("Using precomputed metrics (set recompute_metrics=True to recompute)")
else:
    get_run_metrics(run_path)

# Interactive setup

We will now directly load the model and measure its in-context learning ability on a batch of random inputs. (In the paper we average over multiple such batches to obtain better estimates.)

In [None]:
from samplers import get_data_sampler
from tasks import get_task_sampler

In [None]:
# Load the trained model together with the config it was trained with.
model, conf = get_model_from_run(run_path)
n_dims = conf.model.n_dims
batch_size = conf.training.batch_size


# Convert task_kwargs to a plain dict that can be splatted with **.
# Mappings are checked first: a dict subclass keeps its entries as items and
# its instance __dict__ is typically empty, so reading __dict__ from it would
# silently drop every configured kwarg.
# NOTE(review): conf is presumably a Munch (a dict subclass) -- confirm against
# get_model_from_run.
task_kwargs = {}
raw_task_kwargs = getattr(conf.training, 'task_kwargs', None)
if raw_task_kwargs:
    if isinstance(raw_task_kwargs, dict):
        task_kwargs = dict(raw_task_kwargs)
    elif hasattr(raw_task_kwargs, '__dict__'):
        # Plain attribute container: copy so the config object is not aliased.
        task_kwargs = dict(vars(raw_task_kwargs))
    else:
        task_kwargs = raw_task_kwargs

print("Task kwargs:", task_kwargs)

# Override predict mode for eval: None = use run config; True = one position; False = all positions
EVAL_PREDICT_TARGET_ONLY = None
if EVAL_PREDICT_TARGET_ONLY is not None:
    task_kwargs = dict(task_kwargs)
    task_kwargs['predict_target_only'] = EVAL_PREDICT_TARGET_ONLY

# Build the input sampler and the task sampler from the run's training config.
data_sampler = get_data_sampler(conf.training.data, n_dims=n_dims, **task_kwargs)

task_sampler = get_task_sampler(
    conf.training.task,
    n_dims,
    batch_size,
    # num_tasks is optional in the config; fall back to None when absent.
    num_tasks=getattr(conf.training, 'num_tasks', None),
    **task_kwargs,
)

# Samplers that expose get_sequence_structure() describe the sequence layout
# and the positions to predict; for all others both values stay None.
sequence_structure = None
predict_inds = None
if hasattr(data_sampler, 'get_sequence_structure'):
    sequence_structure = data_sampler.get_sequence_structure()
    predict_inds = sequence_structure.get('predict_inds', [])
    print(f"Multi-context structure: {sequence_structure}")
    # Long index lists are abbreviated to their first ten entries.
    if len(predict_inds) > 10:
        print(f"Will predict {len(predict_inds)} indices: {predict_inds[:10]}...")
    else:
        print(f"Will predict indices: {predict_inds}")

    # Predict mode: one position (target only) or all positions — from run config or EVAL_PREDICT_TARGET_ONLY
    if conf.training.task == "group_mixture_linear":
        print(f"\nTask: {conf.training.task}")
        print(f"  - Total sequence length: {sequence_structure['total_length']}")
        if len(predict_inds) == 1:
            mode_str = "1 position (target only)"
        else:
            mode_str = f"all {len(predict_inds)} positions"
        print(f"  - Predicting {mode_str}")
        if hasattr(data_sampler, 'n_components'):
            K = data_sampler.n_components
            C = data_sampler.contexts_per_component
            T_target = data_sampler.target_cluster_context_points
            # The K context clusters of C points each come first in the sequence.
            context_length = K * C
            print(f"  - Structure: {K} context clusters × {C} points + target cluster ({T_target} context + 1 prediction)")
            print(f"  - Context clusters: positions 0-{context_length-1}")
            print(f"  - Target cluster: positions {context_length}-{sequence_structure['total_length']-1}")

In [None]:
# Sequence length: taken from the sampler's structure when it provides one,
# otherwise from the end of the training curriculum.
n_points = (
    sequence_structure['total_length']
    if sequence_structure is not None
    else conf.training.curriculum.points.end
)

task_sampler_args = {}
if conf.training.task != "group_mixture_linear":
    task = task_sampler(**task_sampler_args)
    xs = data_sampler.sample_xs(b_size=batch_size, n_points=n_points)
    ys = task.evaluate(xs)
else:
    # For group_mixture_linear the order matters: sample_xs() has to run first
    # because it sets current_components and component_assignments on the
    # sampler, and the task is then built from those two attributes.
    xs = data_sampler.sample_xs(b_size=batch_size, n_points=n_points)
    assert getattr(data_sampler, "current_components", None) is not None, \
        "Sampler must set current_components in sample_xs()"
    assert getattr(data_sampler, "component_assignments", None) is not None, \
        "Sampler must set component_assignments in sample_xs()"
    task = task_sampler(
        components=data_sampler.current_components,
        component_assignments=data_sampler.component_assignments,
        **task_sampler_args,
    )
    ys = task.evaluate(xs)

# Quick sanity report on the sampled batch.
print(f"Data shapes - xs: {xs.shape}, ys: {ys.shape}")

print(f"xs stats: mean={xs.mean():.4f}, std={xs.std():.4f}")
print(f"ys stats: mean={ys.mean():.4f}, std={ys.std():.4f}")

print("\nFirst training example (first few points):")
print(f"xs[0, :5]: {xs[0, :5]}")
print(f"ys[0, :5]: {ys[0, :5]}")

# For group_mixture_linear, show which component each position of the first
# example was assigned to, split into context clusters and target cluster.
if conf.training.task == "group_mixture_linear" and hasattr(data_sampler, 'component_assignments'):
    K = data_sampler.n_components
    C = data_sampler.contexts_per_component
    T_target = data_sampler.target_cluster_context_points
    context_length = K * C
    print(f"\nSequence structure (first example):")
    print(f"  Component assignments: {data_sampler.component_assignments[0].cpu().tolist()}")
    print(f"  Context clusters (0-{context_length-1}): components {data_sampler.component_assignments[0, :context_length].cpu().tolist()}")
    print(f"  Target cluster ({context_length}-{n_points-1}): components {data_sampler.component_assignments[0, context_length:].cpu().tolist()}")

In [None]:
# Decide once how predictions are laid out, so the model call and the
# printout below cannot disagree about the shape of `pred`.
use_predict_inds = (
    predict_inds is not None and len(predict_inds) > 0 and sequence_structure is not None
)
# "All positions" means predict_inds covers exactly positions 0..T-1.
predict_all_positions = (
    use_predict_inds
    and len(predict_inds) == ys.shape[1]
    and set(predict_inds) == set(range(ys.shape[1]))
)

with torch.no_grad():
    if not use_predict_inds:
        pred = model(xs, ys)
        print(f"Standard prediction shape: {pred.shape}")
    elif predict_all_positions:
        # Predicting all positions: use standard autoregressive call
        pred = model(xs, ys)
        print(f"Autoregressive prediction (all positions) shape: {pred.shape}")
    else:
        # Predicting specific positions
        pred = model(xs, ys, inds=predict_inds, sequence_structure=sequence_structure)
        print(f"Multi-context prediction shape: {pred.shape}")

print("\nPrediction examples:")
# Tensors are moved to the CPU before .numpy(), which fails on other devices.
if predict_all_positions:
    # All positions predicted: pred is compared with ys position by position
    for i in range(min(3, batch_size)):
        print(f"Example {i} (showing first 10 positions):")
        print(f"  Actual: {ys[i, :10].cpu().numpy()}")
        print(f"  Pred:   {pred[i, :10].cpu().numpy()}")
        print(f"  Error:  {(pred[i, :10] - ys[i, :10]).abs().cpu().numpy()}")
elif use_predict_inds:
    # Specific positions predicted: compare against ys at those indices
    for i in range(min(3, batch_size)):
        actual_targets = ys[i, predict_inds]
        predictions = pred[i]
        print(f"Example {i}:")
        print(f"  Actual: {actual_targets[:5].cpu().numpy()}")
        print(f"  Pred:   {predictions[:5].cpu().numpy()}")
else:
    # No explicit prediction indices: report the last position only
    for i in range(min(3, batch_size)):
        print(f"Example {i}: actual={ys[i, -1]:.4f}, pred={pred[i, -1]:.4f}")

  import pkg_resources


In [None]:
metric = task.get_metric()

# Compute the loss against the targets that `pred` actually covers. The
# "all positions" test is the same one used when `pred` was computed
# (matching length AND predict_inds being exactly 0..T-1), so the two cells
# cannot disagree about the shape of `pred`.
has_predict_inds = predict_inds is not None and len(predict_inds) > 0
covers_all_positions = (
    has_predict_inds
    and len(predict_inds) == ys.shape[1]
    and set(predict_inds) == set(range(ys.shape[1]))
)

# .cpu() precedes .numpy() because .numpy() fails for tensors on other devices.
if not has_predict_inds:
    loss = metric(pred, ys).cpu().numpy()
elif covers_all_positions:
    # Predicting all positions: pred is compared with ys position by position
    loss = metric(pred, ys).cpu().numpy()
    print(f"Loss computed on all {len(predict_inds)} positions")
else:
    # Predicting specific positions: compare against ys at those indices
    loss = metric(pred, ys[:, predict_inds]).cpu().numpy()
    print(f"Loss computed on {len(predict_inds)} prediction positions")

print(f"Loss shape: {loss.shape}")
print(f"Mean loss: {loss.mean():.4f}")

# Visualize performance across all positions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: MSE by position
if len(loss.shape) == 2 and loss.shape[1] > 1:
    # Loss is 2-D with several columns: average over the first axis to get
    # one value per column
    mse_by_position = loss.mean(axis=0)
    positions = range(len(mse_by_position))

    axes[0].plot(positions, mse_by_position, 'o-', linewidth=2, markersize=4)
    axes[0].set_xlabel("Position in Sequence", fontsize=12)
    axes[0].set_ylabel("Mean Squared Error", fontsize=12)
    axes[0].set_title(f'MSE by Position (Mean: {mse_by_position.mean():.4f})', fontsize=13)
    axes[0].grid(True, alpha=0.3)

    # Add vertical line to separate context clusters from target cluster
    if conf.training.task == "group_mixture_linear" and hasattr(data_sampler, 'contexts_per_component'):
        K = data_sampler.n_components
        C = data_sampler.contexts_per_component
        context_length = K * C
        if context_length < len(positions):
            axes[0].axvline(x=context_length-0.5, color='r', linestyle='--', alpha=0.5, label='Context/Target boundary')
            axes[0].legend()

    # Plot 2: Loss distribution histogram
    axes[1].hist(loss.flatten(), bins=30, alpha=0.7, edgecolor='black')
    axes[1].set_xlabel("Squared Error", fontsize=12)
    axes[1].set_ylabel("Frequency", fontsize=12)
    axes[1].set_title(f'Error Distribution (Mean: {loss.mean():.4f})', fontsize=13)
    axes[1].grid(True, alpha=0.3)
else:
    # Single position or aggregated loss: only a histogram is meaningful,
    # so the second panel is hidden
    axes[0].hist(loss.flatten(), bins=20, alpha=0.7, edgecolor='black')
    axes[0].set_xlabel("Squared Error", fontsize=12)
    axes[0].set_ylabel("Frequency", fontsize=12)
    axes[0].set_title(f'Loss Distribution (Mean MSE: {loss.mean():.4f})', fontsize=13)
    axes[0].grid(True, alpha=0.3)
    axes[1].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Component-specific analysis for group_mixture_linear
if conf.training.task == "group_mixture_linear" and hasattr(data_sampler, 'component_assignments'):
    print("=" * 50)
    print("COMPONENT-SPECIFIC ANALYSIS")
    print("=" * 50)

    K = data_sampler.n_components
    C = data_sampler.contexts_per_component
    T_target = data_sampler.target_cluster_context_points
    context_length = K * C

    # Work out which sequence positions the columns of `pred` correspond to.
    # With no explicit predict_inds, or when they are exactly 0..T-1, `pred`
    # is compared with ys position by position; otherwise column j of `pred`
    # belongs to position predict_inds[j].
    n_positions = ys.shape[1]
    predicts_all = (
        predict_inds is None
        or len(predict_inds) == 0
        or (len(predict_inds) == n_positions and set(predict_inds) == set(range(n_positions)))
    )
    if predicts_all:
        eval_positions = list(range(n_positions))
        e = (pred - ys)**2
    else:
        eval_positions = list(predict_inds)
        e = (pred - ys[:, predict_inds])**2

    # Restrict the assignments to the evaluated positions so that masks built
    # from them have the same shape as `e` (the full assignment tensor would
    # not match `e` when only a subset of positions is predicted).
    eval_assignments = data_sampler.component_assignments[:, eval_positions].to(e.device)

    # Analyze by component
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: MSE by component (averaged across all evaluated positions where that component is used)
    component_mse = {}
    for comp_id in range(K):
        # Find all evaluated positions where this component is used
        comp_mask = eval_assignments == comp_id
        if comp_mask.any():
            # Average error across all positions using this component
            comp_errors = e[comp_mask].mean().item()
            component_mse[comp_id] = comp_errors

    if component_mse:
        comp_ids = list(component_mse.keys())
        comp_mses = [component_mse[cid] for cid in comp_ids]
        axes[0].bar(comp_ids, comp_mses, alpha=0.7, edgecolor='black')
        axes[0].set_xlabel("Component ID", fontsize=12)
        axes[0].set_ylabel("Mean Squared Error", fontsize=12)
        axes[0].set_title(f'MSE by Component (Averaged Across All Positions)', fontsize=13)
        axes[0].grid(True, alpha=0.3, axis='y')
        for cid, mse in zip(comp_ids, comp_mses):
            axes[0].text(cid, mse, f'{mse:.4f}', ha='center', va='bottom', fontsize=10)

    # Plot 2: MSE by position, colored by component
    if len(e.shape) == 2 and e.shape[1] > 1:
        mse_by_pos = e.mean(dim=0).cpu().numpy()
        # x-axis uses the real sequence positions of the evaluated columns
        positions = eval_positions

        # Color points by which component is used (for first example)
        comp_assignments_ex0 = eval_assignments[0].cpu().numpy()
        colors = plt.cm.tab10(range(K))

        for comp_id in range(K):
            comp_columns = [j for j in range(len(positions)) if comp_assignments_ex0[j] == comp_id]
            if comp_columns:
                comp_positions = [positions[j] for j in comp_columns]
                comp_mses = [mse_by_pos[j] for j in comp_columns]
                axes[1].scatter(comp_positions, comp_mses,
                               c=[colors[comp_id]], label=f'Component {comp_id}',
                               s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        axes[1].plot(positions, mse_by_pos, 'k--', alpha=0.3, linewidth=1, label='Overall')
        axes[1].set_xlabel("Position in Sequence", fontsize=12)
        axes[1].set_ylabel("Mean Squared Error", fontsize=12)
        axes[1].set_title('MSE by Position (Colored by Component)', fontsize=13)
        axes[1].legend(loc='best', fontsize=10)
        axes[1].grid(True, alpha=0.3)

        # Add boundary line when the plot reaches into the target cluster
        if context_length <= max(positions):
            axes[1].axvline(x=context_length-0.5, color='r', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()

    # Print statistics (usage is counted over the full sequence, not only
    # over the evaluated positions)
    print(f"\nComponent usage statistics:")
    total_positions = data_sampler.component_assignments.numel()
    for comp_id in range(K):
        comp_count = (data_sampler.component_assignments == comp_id).sum().item()
        pct = 100 * comp_count / total_positions
        print(f"  Component {comp_id}: used in {comp_count}/{total_positions} positions ({pct:.1f}%)")
        if comp_id in component_mse:
            print(f"    Average MSE: {component_mse[comp_id]:.6f}")

elif hasattr(task, 'get_mixture_info'):
    # Legacy support for other mixture tasks
    mix_info = task.get_mixture_info()
    print(f"Mixture info:")
    print(f"  Number of components: {mix_info['n_components']}")
    print(f"  Number of contexts: {mix_info['n_contexts']}")
    print(f"  Context assignments: {mix_info['context_assignments']}")
    print(f"  Target assignment: {mix_info['target_assignment']}")

    plt.figure(figsize=(6, 4))
    for comp_id in range(mix_info['n_components']):
        comp_mask = mix_info['target_assignment'] == comp_id
        if comp_mask.any():
            comp_loss = loss[comp_mask].mean(axis=0)

            plt.plot([0], comp_loss, marker='o', linestyle='', markersize=10,
                     label=f'Component {comp_id} (MSE: {comp_loss[0]:.4f})')

    plt.xticks([0], [f"Position {predict_inds[0] if predict_inds else 0}"])
    plt.xlabel("Prediction position")
    plt.ylabel("Squared error")
    plt.legend()
    plt.title('Error by Mixture Component (Single Point)')
    plt.show()

# Test: Context (component 0, component 1) → Target component 0 vs 1 (predict all)

Compare loss when we **fix context clusters** to show component 0 then component 1, and **vary only the target cluster**: ask for component 0 in all targets vs component 1 in all targets. Uses **predicting-all** mode so we get predictions (and loss) at every position.

In [None]:
# Fixed context [0,1], target component 0 vs 1, predict-all mode.
# Runs only for group_mixture_linear; every other task is skipped.
if conf.training.task != "group_mixture_linear":
    print("Skipped (task is not group_mixture_linear).")
else:
    # Separate sampler forced into predict-all mode for this test
    test_task_kwargs = {**task_kwargs, "predict_target_only": False}
    data_sampler_test = get_data_sampler(conf.training.data, n_dims=n_dims, **test_task_kwargs)
    seq_struct = data_sampler_test.get_sequence_structure()
    predict_inds_all = seq_struct["predict_inds"]
    n_pts = seq_struct["total_length"]
    K = data_sampler_test.n_components
    C = data_sampler_test.contexts_per_component
    # Positions before context_length are context clusters; the rest is the target cluster
    context_length = K * C
    metric_fn = task.get_metric()

    results = {}
    for target_comp in [0, 1]:
        # Same fixed context assignment each time; only the target component changes
        xs = data_sampler_test.sample_xs(
            n_points=n_pts,
            b_size=batch_size,
            fixed_cluster_assignments=[0, 1],
            fixed_target_component=target_comp,
        )
        assert getattr(data_sampler_test, "current_components", None) is not None
        task_inst = task_sampler(
            components=data_sampler_test.current_components,
            component_assignments=data_sampler_test.component_assignments,
        )
        ys = task_inst.evaluate(xs)
        with torch.no_grad():
            pred = model(xs, ys)

        # Squared error per element, then averaged overall, per position,
        # and separately over the context and target-cluster positions
        sq_err = (pred - ys) ** 2
        mse_all = sq_err.mean().item()
        mse_by_pos = sq_err.mean(dim=0).cpu().numpy()
        mse_context = sq_err[:, :context_length].mean().item()
        mse_target_cluster = sq_err[:, context_length:].mean().item()
        results[target_comp] = dict(
            mse_all=mse_all,
            mse_by_pos=mse_by_pos,
            mse_context=mse_context,
            mse_target_cluster=mse_target_cluster,
        )
        print(f"Target component {target_comp}: MSE (all) = {mse_all:.6f}, MSE (context) = {mse_context:.6f}, MSE (target cluster) = {mse_target_cluster:.6f}")

    print("\nDifference (target 1 − target 0):")
    print(f"  MSE (all):           {results[1]['mse_all'] - results[0]['mse_all']:+.6f}")
    print(f"  MSE (context):      {results[1]['mse_context'] - results[0]['mse_context']:+.6f}")
    print(f"  MSE (target cluster): {results[1]['mse_target_cluster'] - results[0]['mse_target_cluster']:+.6f}")

    # Per-position MSE curves for both target components on one axis
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.plot(results[0]["mse_by_pos"], "o-", label="Target component 0", markersize=3)
    ax.plot(results[1]["mse_by_pos"], "s-", label="Target component 1", markersize=3)
    ax.axvline(x=context_length - 0.5, color="gray", linestyle="--", alpha=0.7, label="Context end")
    ax.set_xlabel("Position")
    ax.set_ylabel("MSE")
    ax.set_title("MSE by position: context [0,1] → target 0 vs target 1 (predict all)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Probing with special inputs

Feed the model **controlled inputs** to see what information the outputs reflect. All probes are **eval-only** (no training), and only for `group_mixture_linear`.

- **Probe 1 – Same context & query x, different target component:** Use the same sequence of (x,y) in the context and the same query x at the end, but change *which component* is used for the target cluster (and thus the true y at the query). Compare the model’s prediction at the last position to the two possible true values. If the model has learned to use the target-cluster context to select the component, its prediction should track the correct component’s true y.
- **Probe 2 – Same context & target component, vary query x:** Keep context and target component fixed; sweep the **query x** over several values (e.g. random or basis-aligned). For each query x, set the last position to that x and the true y to w^T x. Plot model prediction vs true y; if the model has learned the linear map, points should lie near the line y=x.

In [None]:
# Probing with special inputs (group_mixture_linear only; eval-only, no training changes)
# numpy is imported here because the main import cell does not import it;
# `np` would otherwise exist only if the optional setup cell had been run.
import numpy as np

if conf.training.task != "group_mixture_linear":
    print("Skipped (task is not group_mixture_linear).")
else:
    # Sampler with predict-target-only for clean last-position readout
    probe_task_kwargs = dict(task_kwargs)
    probe_task_kwargs["predict_target_only"] = True
    probe_sampler = get_data_sampler(conf.training.data, n_dims=n_dims, **probe_task_kwargs)
    probe_struct = probe_sampler.get_sequence_structure()
    probe_inds = probe_struct["predict_inds"]
    n_pts_probe = probe_struct["total_length"]
    K_probe = probe_sampler.n_components
    C_probe = probe_sampler.contexts_per_component
    ctx_len = K_probe * C_probe
    # NOTE(review): 0.5 is only a fallback when the sampler has no `scale`
    # attribute; confirm it matches the scaling used by the task.
    scale_probe = getattr(probe_sampler, "scale", 0.5)

    # ----- Probe 1: Same xs, different target component (via different ys) -----
    # One batch: fixed context [0,1], target=0
    xs_one = probe_sampler.sample_xs(
        n_points=n_pts_probe, b_size=batch_size,
        fixed_cluster_assignments=[0, 1], fixed_target_component=0,
    )
    comps = probe_sampler.current_components
    assign_0 = probe_sampler.component_assignments.clone()
    # Second assignment: identical context, but every target-cluster position uses component 1
    assign_1 = assign_0.clone()
    assign_1[:, ctx_len:] = 1
    task_0 = task_sampler(components=comps, component_assignments=assign_0)
    task_1 = task_sampler(components=comps, component_assignments=assign_1)
    ys_0 = task_0.evaluate(xs_one)
    ys_1 = task_1.evaluate(xs_one)
    with torch.no_grad():
        pred_0 = model(xs_one, ys_0, inds=probe_inds, sequence_structure=probe_struct)
        pred_1 = model(xs_one, ys_1, inds=probe_inds, sequence_structure=probe_struct)
    # First predicted index is read as the query prediction; the true value
    # is the last position of ys
    pred_0_last = pred_0[:, 0]
    pred_1_last = pred_1[:, 0]
    true_0_last = ys_0[:, -1]
    true_1_last = ys_1[:, -1]

    print("Probe 1: Same xs & query x, different target component (different ys in target cluster)")
    print(f"  True y (target=0): mean={true_0_last.mean().item():.4f}, std={true_0_last.std().item():.4f}")
    print(f"  True y (target=1): mean={true_1_last.mean().item():.4f}, std={true_1_last.std().item():.4f}")
    print(f"  Pred when fed ys_0: mean={pred_0_last.mean().item():.4f}, std={pred_0_last.std().item():.4f}")
    print(f"  Pred when fed ys_1: mean={pred_1_last.mean().item():.4f}, std={pred_1_last.std().item():.4f}")
    print(f"  Correlation pred_0 vs true_0: {np.corrcoef(pred_0_last.cpu().numpy(), true_0_last.cpu().numpy())[0,1]:.4f}")
    print(f"  Correlation pred_1 vs true_1: {np.corrcoef(pred_1_last.cpu().numpy(), true_1_last.cpu().numpy())[0,1]:.4f}")

    fig1, ax1 = plt.subplots(1, 1, figsize=(5, 5))
    ax1.scatter(true_0_last.cpu().numpy(), pred_0_last.cpu().numpy(), alpha=0.6, label="Fed ys_0 (target=0)")
    ax1.scatter(true_1_last.cpu().numpy(), pred_1_last.cpu().numpy(), alpha=0.6, label="Fed ys_1 (target=1)")
    mn = min(true_0_last.min().item(), true_1_last.min().item(), pred_0_last.min().item(), pred_1_last.min().item())
    mx = max(true_0_last.max().item(), true_1_last.max().item(), pred_0_last.max().item(), pred_1_last.max().item())
    ax1.plot([mn, mx], [mn, mx], "k--", alpha=0.5, label="y=x")
    ax1.set_xlabel("True y at query"); ax1.set_ylabel("Model prediction at query")
    ax1.set_title("Probe 1: Prediction vs true y (same xs, different target component)")
    ax1.legend(); ax1.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

    # ----- Probe 2: Same context & target component, vary query x -----
    n_queries = 20
    # One fixed context, then expand to n_queries and only vary last-position x
    xs_one_row = probe_sampler.sample_xs(
        n_points=n_pts_probe, b_size=1,
        fixed_cluster_assignments=[0, 1], fixed_target_component=0,
    )
    comps2 = probe_sampler.current_components.repeat(n_queries, 1, 1, 1)
    assign2 = probe_sampler.component_assignments.repeat(n_queries, 1)
    task2 = task_sampler(components=comps2, component_assignments=assign2)
    xs_probe = xs_one_row.repeat(n_queries, 1, 1).clone()
    # Create the query points on the same device as the batch they are written into
    query_xs = torch.randn(n_queries, n_dims).to(device=xs_probe.device)
    xs_probe[:, -1, :] = query_xs
    # Weight vector of the component assigned to the last position of each row
    w_target = comps2[torch.arange(n_queries), assign2[:, -1], :, 0]
    true_y_query = scale_probe * (query_xs * w_target).sum(dim=1)
    ys_probe = task2.evaluate(xs_probe)
    ys_probe[:, -1] = true_y_query.to(ys_probe.device)
    with torch.no_grad():
        pred_probe = model(xs_probe, ys_probe, inds=probe_inds, sequence_structure=probe_struct)[:, 0]
    print("\nProbe 2: Same context & target comp, vary query x")
    print(f"  Correlation pred vs true y at query: {np.corrcoef(pred_probe.cpu().numpy(), true_y_query.cpu().numpy())[0,1]:.4f}")

    fig2, ax2 = plt.subplots(1, 1, figsize=(5, 5))
    ax2.scatter(true_y_query.cpu().numpy(), pred_probe.cpu().numpy(), alpha=0.8)
    lims = [min(true_y_query.min().item(), pred_probe.min().item()), max(true_y_query.max().item(), pred_probe.max().item())]
    ax2.plot(lims, lims, "k--", alpha=0.5, label="y=x")
    ax2.set_xlabel("True y = w^T x_query"); ax2.set_ylabel("Model prediction")
    ax2.set_title("Probe 2: Vary query x (same context & target component)")
    ax2.legend(); ax2.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

# Test: Vary number of contexts per component (1 to 10) with zero-padding

Model is trained on **10 contexts per component** (10-10-6 format). Here we feed it **i** contexts per component for i = 1, 2, ..., 10: we keep the same 10-10-6 sequence length, but for each context cluster we use only the first **i** positions as real (x,y) and **pad the rest with zeros** (x=0, y=0). The **target cluster (6 context + 1 query) is left unchanged** so the model always has full target-context signal. We measure MSE at the query position to see how performance depends on context size.

In [None]:
# Zero-padding sweep: feed the model only the first i real (x, y) pairs of every
# context cluster (i = 1..C_full), zero the remaining slots, and leave the target
# cluster untouched. Errors are measured against the padded ys, so a padded slot
# has target y = 0.
#
# Names defined here and read by the following plotting cell:
#   C_full, context_length, positions, mse_by_pos_per_i
if conf.training.task != "group_mixture_linear":
    print("Skipped (task is not group_mixture_linear).")
else:
    K = data_sampler.n_components
    C_full = data_sampler.contexts_per_component  # 10 in the configuration described above
    T_target = data_sampler.target_cluster_context_points  # 6 in the configuration described above
    context_length = K * C_full  # all context clusters laid out back to back
    target_start = context_length
    total_len = sequence_structure["total_length"]

    # One batch of full (unpadded) data; every value of i is derived from this same batch.
    xs_full = data_sampler.sample_xs(n_points=total_len, b_size=batch_size)
    if hasattr(data_sampler, "current_components") and data_sampler.current_components is not None:
        task_ctx = task_sampler(
            components=data_sampler.current_components,
            component_assignments=data_sampler.component_assignments,
        )
    else:
        task_ctx = task_sampler()
    ys_full = task_ctx.evaluate(xs_full)

    # i = contexts per component (1..C_full); target cluster unchanged. Error = pred vs ys_pad (padded version).
    mse_by_pos_per_i = []
    for i in range(1, C_full + 1):
        xs_pad = xs_full.clone()
        ys_pad = ys_full.clone()
        # For every context cluster k keep its first i slots and zero the rest.
        # Looping over K (instead of hard-coding two clusters) keeps the slices
        # inside each cluster when the sampler has more than two components.
        for k in range(K):
            pad_start = k * C_full + i
            pad_end = (k + 1) * C_full
            xs_pad[:, pad_start:pad_end, :] = 0.0
            ys_pad[:, pad_start:pad_end] = 0.0
        # Target cluster (target_start..total_len-1) unchanged
        with torch.no_grad():
            # Deliberately NOT named `pred`: later analysis cells compare the
            # notebook-level `pred` from the main evaluation against `ys`.
            pred_pad = model(xs_pad, ys_pad)
        sq_err = (pred_pad - ys_pad.to(pred_pad.device)) ** 2
        mse_by_pos = sq_err.mean(dim=0).cpu().numpy()
        mse_by_pos_per_i.append(mse_by_pos)
        print(f"  i={i}: MSE at query (vs padded) = {mse_by_pos[-1]:.6f}")

    # Overlay all curves in a single axes; the dashed line marks the end of the context clusters.
    positions = range(total_len)
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    for i in range(1, C_full + 1):
        ax.plot(positions, mse_by_pos_per_i[i - 1], "o-", markersize=3, label=f"i={i}")
    ax.axvline(x=context_length - 0.5, color="gray", linestyle="--", alpha=0.7, label="Context end")
    ax.set_xlabel("Position")
    ax.set_ylabel("MSE (prediction error)")
    ax.set_title("Prediction error by position (i=1..10; compare pred vs padded ys)")
    ax.legend(loc="best", ncol=2, fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Small-multiples view of the zero-padding sweep: one subplot per context size i,
# with every other curve drawn faintly behind it for comparison.
# Reads C_full, context_length, positions and mse_by_pos_per_i from the sweep cell,
# which only defines them for the group_mixture_linear task -- hence the same guard.
if conf.training.task != "group_mixture_linear":
    print("Skipped (task is not group_mixture_linear).")
else:
    cols = 3
    rows = (C_full + cols - 1) // cols  # ceil(C_full / cols)

    fig, axes = plt.subplots(rows, cols, figsize=(15, rows * 3), sharex=True, sharey=True)
    axes = axes.flatten()

    for i in range(1, C_full + 1):
        ax = axes[i - 1]
        # Plot ALL lines in light gray first (the "ghost" effect)
        for other_mse in mse_by_pos_per_i:
            ax.plot(positions, other_mse, color="gray", alpha=0.1, lw=1)

        # Plot the specific line for this subplot in color
        ax.plot(positions, mse_by_pos_per_i[i - 1], "o-", markersize=3, color="blue", label=f"i={i}")
        ax.axvline(x=context_length - 0.5, color="red", linestyle="--", alpha=0.5)
        ax.set_title(f"Context i={i}")

    # Remove the grid cells that received no curve. Indexed by C_full explicitly
    # rather than by the loop variable left over from the loop above.
    for unused_ax in axes[C_full:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Per-position error analysis of the notebook-level `pred` against `ys`:
# a 2x2 figure (MSE curve, error histogram, per-example heatmap, box plots)
# followed by summary statistics. Reads pred, ys, predict_inds, loss, conf and
# data_sampler from earlier cells.
print("=" * 50)
print("DETAILED ERROR ANALYSIS BY POSITION")
print("=" * 50)

# Compute squared errors. `pred` is compared against the predicted columns of
# `ys` only when predict_inds selects a strict, non-empty subset of positions;
# otherwise pred and ys are compared directly.
predicts_subset = (
    predict_inds is not None
    and len(predict_inds) > 0
    and len(predict_inds) != ys.shape[1]
)
if predicts_subset:
    e = (pred - ys[:, predict_inds])**2
else:
    e = (pred - ys)**2

mse_by_pos = e.mean(dim=0)  # Average across batch
print(f"MSE by position: {mse_by_pos.tolist()}")

# Context/target boundary for group_mixture_linear, computed once and reused by
# every plot and by the region statistics below. `boundary` stays None for
# other tasks so the notebook-level `context_length` is not overwritten.
boundary = None
if conf.training.task == "group_mixture_linear" and hasattr(data_sampler, 'contexts_per_component'):
    K = data_sampler.n_components
    C = data_sampler.contexts_per_component
    context_length = K * C
    boundary = context_length

# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: MSE by position (line plot)
if len(mse_by_pos.shape) == 0:
    # Single value
    axes[0, 0].bar([0], [mse_by_pos.item()])
    axes[0, 0].set_xlabel("Position")
    axes[0, 0].set_ylabel("MSE")
    axes[0, 0].set_title(f'MSE at Position (Value: {mse_by_pos.item():.6f})')
else:
    positions = range(len(mse_by_pos))
    axes[0, 0].plot(positions, mse_by_pos.cpu().numpy(), 'o-', linewidth=2, markersize=5)
    axes[0, 0].set_xlabel("Position in Sequence", fontsize=11)
    axes[0, 0].set_ylabel("Mean Squared Error", fontsize=11)
    axes[0, 0].set_title(f'MSE by Position (Mean: {mse_by_pos.mean():.6f})', fontsize=12)
    axes[0, 0].grid(True, alpha=0.3)

    # Add boundary line for group_mixture_linear
    if boundary is not None and boundary < len(positions):
        axes[0, 0].axvline(x=boundary-0.5, color='r', linestyle='--', alpha=0.5)
        # .item() hands matplotlib a plain Python float instead of a tensor.
        axes[0, 0].text(boundary, mse_by_pos.max().item() * 0.9, 'Context/Target\nboundary',
                       ha='center', fontsize=9, color='r')

# Plot 2: Error distribution histogram
# NOTE(review): this uses the notebook-level `loss`, not `e` -- confirm they
# describe the same predictions.
axes[0, 1].hist(loss.flatten(), bins=30, alpha=0.7, edgecolor='black')
axes[0, 1].set_xlabel("Squared Error", fontsize=11)
axes[0, 1].set_ylabel("Frequency", fontsize=11)
axes[0, 1].set_title(f'Error Distribution (Mean: {loss.mean():.4f})', fontsize=12)
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Heatmap of errors by example and position (if multiple positions)
if len(e.shape) == 2 and e.shape[1] > 1:
    # Show first 20 examples to avoid overcrowding
    n_examples_show = min(20, e.shape[0])
    im = axes[1, 0].imshow(e[:n_examples_show].cpu().numpy(), aspect='auto', cmap='YlOrRd', interpolation='nearest')
    axes[1, 0].set_xlabel("Position in Sequence", fontsize=11)
    axes[1, 0].set_ylabel("Example Index", fontsize=11)
    axes[1, 0].set_title(f'Squared Error Heatmap (First {n_examples_show} Examples)', fontsize=12)
    plt.colorbar(im, ax=axes[1, 0], label='Squared Error')

    # Add boundary line
    if boundary is not None and boundary < e.shape[1]:
        axes[1, 0].axvline(x=boundary-0.5, color='cyan', linestyle='--', linewidth=2, alpha=0.7)
else:
    axes[1, 0].axis('off')

# Plot 4: Box plot of errors by position (if multiple positions)
if len(mse_by_pos.shape) > 0 and len(mse_by_pos) > 1:
    n_positions = len(mse_by_pos)
    max_boxes = 20
    if n_positions > max_boxes:
        # Sample positions for readability. Ceiling division guarantees at most
        # max_boxes boxes; floor division gave step 1 (no sampling at all) for
        # 21..39 positions.
        step = -(-n_positions // max_boxes)
        axes[1, 1].set_xlabel("Position (sampled)", fontsize=11)
    else:
        step = 1
        axes[1, 1].set_xlabel("Position", fontsize=11)
    shown_positions = list(range(0, n_positions, step))
    errors_by_pos = [e[:, pos].cpu().numpy() for pos in shown_positions]
    axes[1, 1].boxplot(errors_by_pos, labels=[f'P{p}' for p in shown_positions])
    axes[1, 1].set_ylabel("Squared Error", fontsize=11)
    axes[1, 1].set_title('Error Distribution by Position', fontsize=12)
    axes[1, 1].tick_params(axis='x', rotation=45)
    axes[1, 1].grid(True, alpha=0.3)
else:
    axes[1, 1].axis('off')

plt.tight_layout()
plt.show()

print(f"\nOverall statistics:")
if len(mse_by_pos.shape) == 0:
    print(f"  MSE: {mse_by_pos.item():.6f}")
else:
    print(f"  Mean MSE across all positions: {mse_by_pos.mean():.6f}")
    print(f"  Std MSE across positions: {mse_by_pos.std():.6f}")
    print(f"  Min MSE (best position): {mse_by_pos.min():.6f} at position {mse_by_pos.argmin().item()}")
    print(f"  Max MSE (worst position): {mse_by_pos.max():.6f} at position {mse_by_pos.argmax().item()}")

    # Show performance in context vs target regions
    if boundary is not None and boundary < len(mse_by_pos):
        context_mse = mse_by_pos[:boundary].mean()
        target_mse = mse_by_pos[boundary:].mean()
        print(f"\nPerformance by region:")
        print(f"  Context clusters (0-{boundary-1}): Mean MSE = {context_mse:.6f}")
        print(f"  Target cluster ({boundary}-{len(mse_by_pos)-1}): Mean MSE = {target_mse:.6f}")

  state = torch.load(state_path, map_location='cpu')


In [None]:
# ============================================================================
# COMPREHENSIVE VISUALIZATION: Performance at Every Index
# ============================================================================
# Builds a 2x3 figure (MSE with std, log-scale MSE, relative MSE, error heatmap,
# binned box plots) and prints summary statistics for the notebook-level `pred`
# against `ys`. Reads pred, ys, predict_inds, conf and data_sampler from earlier
# cells.
#
# Two coordinate systems are kept apart on purpose:
#   * `positions[j]` -- the sequence position shown on line-plot x axes
#   * column index j -- the index into `e` / `mse_by_pos` (and the heatmap x axis)
# They coincide only when every position is predicted.

print("=" * 60)
print("PERFORMANCE AT EVERY INDEX - DETAILED VISUALIZATION")
print("=" * 60)

# Compute errors for all positions. A missing or empty predict_inds is treated
# as "all positions", matching the guard used by the previous analysis cell.
if predict_inds is None or len(predict_inds) == 0 or len(predict_inds) == ys.shape[1]:
    # All positions predicted
    e = (pred - ys)**2  # (B, T)
    positions = list(range(ys.shape[1]))
else:
    # Specific positions
    e = (pred - ys[:, predict_inds])**2
    # Plain list of ints so list operations work whatever container
    # predict_inds is (list, array or tensor).
    positions = [int(p) for p in predict_inds]

n_cols = len(positions)
mse_by_pos = e.mean(dim=0).cpu().numpy()  # one value per predicted column
std_by_pos = e.std(dim=0).cpu().numpy()   # one value per predicted column

# Context/target boundary, computed once for every plot below.
has_layout = conf.training.task == "group_mixture_linear" and hasattr(data_sampler, 'contexts_per_component')
n_context_cols = 0
if has_layout:
    K = data_sampler.n_components
    C = data_sampler.contexts_per_component
    context_length = K * C
    # Number of predicted columns that fall inside the context clusters.
    n_context_cols = sum(1 for p in positions if p < context_length)
# Draw the boundary only when predictions exist on both sides of it.
show_boundary = has_layout and 0 < n_context_cols < n_cols

# Create comprehensive figure
fig = plt.figure(figsize=(16, 10))

# Plot 1: MSE by position with error bars (top left)
ax1 = plt.subplot(2, 3, 1)
ax1.errorbar(positions, mse_by_pos, yerr=std_by_pos,
            fmt='o-', linewidth=2, markersize=5, capsize=3, capthick=1.5)
ax1.set_xlabel("Position Index", fontsize=11)
ax1.set_ylabel("Mean Squared Error", fontsize=11)
ax1.set_title('MSE by Position (with Std Dev)', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Add context/target boundary if applicable
if show_boundary:
    ax1.axvline(x=context_length-0.5, color='r', linestyle='--', linewidth=2, alpha=0.6)
    ax1.text(context_length, mse_by_pos.max() * 0.95, 'Context/Target\nBoundary',
            ha='center', fontsize=9, color='r', fontweight='bold')

# Plot 2: Log-scale MSE (top middle)
ax2 = plt.subplot(2, 3, 2)
ax2.semilogy(positions, mse_by_pos, 'o-', linewidth=2, markersize=5)
ax2.set_xlabel("Position Index", fontsize=11)
ax2.set_ylabel("MSE (log scale)", fontsize=11)
ax2.set_title('MSE by Position (Log Scale)', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3, which='both')
if show_boundary:
    ax2.axvline(x=context_length-0.5, color='r', linestyle='--', linewidth=2, alpha=0.6)

# Plot 3: Relative error (MSE normalized by mean) (top right)
ax3 = plt.subplot(2, 3, 3)
mean_mse = mse_by_pos.mean()
relative_error = mse_by_pos / mean_mse
ax3.plot(positions, relative_error, 'o-', linewidth=2, markersize=5, color='green')
ax3.axhline(y=1.0, color='k', linestyle='--', alpha=0.5, label='Mean')
ax3.set_xlabel("Position Index", fontsize=11)
ax3.set_ylabel("Relative MSE (vs Mean)", fontsize=11)
ax3.set_title('Relative Performance by Position', fontsize=12, fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)
if show_boundary:
    ax3.axvline(x=context_length-0.5, color='r', linestyle='--', linewidth=2, alpha=0.6)

# Plot 4: Error heatmap (bottom left, larger)
ax4 = plt.subplot(2, 3, (4, 5))
n_examples_show = min(30, e.shape[0])
im = ax4.imshow(e[:n_examples_show].cpu().numpy(), aspect='auto',
                cmap='YlOrRd', interpolation='nearest', vmin=0, vmax=e.max().item())
ax4.set_xlabel("Position Index", fontsize=11)
ax4.set_ylabel("Example Index", fontsize=11)
ax4.set_title(f'Error Heatmap (First {n_examples_show} Examples)', fontsize=12, fontweight='bold')
cbar = plt.colorbar(im, ax=ax4)
cbar.set_label('Squared Error', fontsize=10)

# Add boundary line. The heatmap x axis is the COLUMN index, so the line goes
# between the last context column and the first target column.
# Assumes predict_inds is in ascending order -- TODO confirm.
if show_boundary:
    ax4.axvline(x=n_context_cols-0.5,
               color='cyan', linestyle='--', linewidth=2, alpha=0.8)

# Plot 5: Box plot by position groups (bottom right)
ax5 = plt.subplot(2, 3, 6)
if n_cols > 10:
    # Pool the errors of neighbouring columns into 10 near-equal bins and draw
    # ONE box per bin, labelled with the first and last position it covers.
    n_bins = 10
    box_data = []
    box_labels = []
    for bin_idx in range(n_bins):
        start_col = bin_idx * n_cols // n_bins
        stop_col = (bin_idx + 1) * n_cols // n_bins
        box_data.append(e[:, start_col:stop_col].reshape(-1).cpu().numpy())
        box_labels.append(f'{positions[start_col]}-{positions[stop_col - 1]}')
else:
    # Show all positions; `e` is indexed by column, labels use the position.
    box_data = [e[:, col].cpu().numpy() for col in range(n_cols)]
    box_labels = [f'P{p}' for p in positions]

bp = ax5.boxplot(box_data, labels=box_labels, patch_artist=True)
for patch in bp['boxes']:
    patch.set_facecolor('lightblue')
    patch.set_alpha(0.7)

ax5.set_xlabel("Position Range", fontsize=11)
ax5.set_ylabel("Squared Error", fontsize=11)
ax5.set_title('Error Distribution by Position Group', fontsize=12, fontweight='bold')
ax5.tick_params(axis='x', rotation=45)
ax5.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Print summary statistics
print(f"\n{'='*60}")
print("SUMMARY STATISTICS")
print(f"{'='*60}")
print(f"Total positions evaluated: {len(positions)}")
print(f"Mean MSE across all positions: {mse_by_pos.mean():.6f}")
print(f"Std MSE across positions: {mse_by_pos.std():.6f}")
print(f"Min MSE: {mse_by_pos.min():.6f} at position {positions[np.argmin(mse_by_pos)]}")
print(f"Max MSE: {mse_by_pos.max():.6f} at position {positions[np.argmax(mse_by_pos)]}")
print(f"Median MSE: {np.median(mse_by_pos):.6f}")

# Performance by region (if applicable)
if has_layout:
    context_cols = [col for col, p in enumerate(positions) if p < context_length]
    target_cols = [col for col, p in enumerate(positions) if p >= context_length]

    # Region statistics are only meaningful when some target position was predicted.
    if target_cols:
        if context_cols:
            context_mse = mse_by_pos[context_cols].mean()
            print(f"\nContext Clusters (positions 0-{context_length-1}):")
            print(f"  Mean MSE: {context_mse:.6f}")
            print(f"  Positions: {len(context_cols)}")

        target_mse = mse_by_pos[target_cols].mean()
        last_target_position = max(positions[col] for col in target_cols)
        print(f"\nTarget Cluster (positions {context_length}-{last_target_position}):")
        print(f"  Mean MSE: {target_mse:.6f}")
        print(f"  Positions: {len(target_cols)}")
        # The ratio needs a context MSE; skip the line entirely otherwise.
        if context_cols:
            print(f"  Ratio (target/context): {target_mse/context_mse:.3f}")

print(f"\n{'='*60}")

In [None]:
# Prints a textual description of what the model has to do for the configured
# task, plus batch-level bookkeeping (component usage, whether the target
# component was present in the context). Output only; no tensors are modified.
print("=" * 50)
print("MODEL CAPABILITY ANALYSIS")
print("=" * 50)

if conf.training.task == "group_mixture_linear":
    print("On-the-fly Mixture Linear Regression Analysis:")
    K = data_sampler.n_components
    C = data_sampler.contexts_per_component
    T_target = data_sampler.target_cluster_context_points
    context_length = K * C
    total_length = sequence_structure['total_length'] if sequence_structure else n_points

    print(f"  - Task structure:")
    print(f"    * {K} mixture components")
    print(f"    * {K} context clusters, each with {C} points")
    print(f"    * 1 target cluster with {T_target} context points + 1 prediction point")
    print(f"    * Total sequence length: {total_length}")
    print(f"  - Model must:")
    print(f"    1. Learn {K} weight vectors from context clusters (positions 0-{context_length-1})")
    print(f"    2. Infer which component is used in target cluster from first {T_target} target points")
    print(f"    3. Predict ALL {total_length} positions using the learned components")
    print(f"  - Clusters are in fixed order (no randomization)")
    print(f"  - All components guaranteed to appear in context clusters")

    # Analyze component usage. getattr(..., None) also covers a sampler whose
    # attribute exists but has not been populated yet (is still None).
    assignments = getattr(data_sampler, 'component_assignments', None)
    if assignments is not None:
        print(f"\n  - Component usage in this batch:")
        for comp_id in range(K):
            comp_count = (assignments == comp_id).sum().item()
            print(f"    Component {comp_id}: {comp_count} positions")

        # Check if target components appeared in context
        target_components = getattr(data_sampler, 'target_components', None)
        if target_components is not None:
            target_in_context = 0
            for b in range(batch_size):
                target_comp = target_components[b].item()
                context_comps = set(assignments[b, :context_length].cpu().tolist())
                if target_comp in context_comps:
                    target_in_context += 1
            print(f"  - Target component appeared in context for {target_in_context}/{batch_size} examples")

elif predict_inds is not None and sequence_structure is not None:
    print("Multi-context learning analysis:")
    if 'n_contexts' in sequence_structure:
        print(f"  - Model sees {sequence_structure['n_contexts']} context series")
        print(f"  - Each context has {sequence_structure['context_length']} points") 
        print(f"  - Predicts {sequence_structure['predict_length']} future points")
    else:
        print(f"  - Sequence structure: {sequence_structure}")
        print(f"  - Predicting {len(predict_inds)} positions")

    if hasattr(task, 'get_mixture_info'):
        mix_info = task.get_mixture_info()
        print(f"  - Must identify correct component from {mix_info['n_components']} possibilities")

        # Count the examples whose target component also occurs among their
        # context assignments.
        total = 0
        for b in range(batch_size):
            target_comp = mix_info['target_assignment'][b]
            context_comps = mix_info['context_assignments'][b]

            if target_comp in context_comps:
                total += 1

        if total > 0:
            print(f"  - Target component appeared in context for {total}/{batch_size} examples")

else:
    print("Standard in-context learning analysis:")
    print(f"  - Model learns from increasing context (up to {n_points} points)")
    print(f"  - Task: {conf.training.task}")
    print(f"  - Data: {conf.training.data}")
    if predict_inds is not None:
        print(f"  - Predicting {len(predict_inds)} positions")
    else:
        print(f"  - Predicting all positions (autoregressive)")

xs stats: mean=-0.0017, std=0.5688, min=-3.4024, max=3.4349
ys stats: mean=-0.0052, std=0.1817, min=-0.8882, max=0.7165
coefficients[0]: tensor([ 0.0216,  0.0845, -0.0595,  0.0055, -0.0728,  0.1485,  0.0060, -0.0074,
        -0.0203,  0.0035])


(tensor([[-0.7556, -0.0793,  0.7708, -0.8922, -0.4725,  0.5542, -1.1830,  0.5065,
           0.2305, -0.2648],
         [-0.2118, -0.7556, -0.0793,  0.7708, -0.8922, -0.4725,  0.5542, -1.1830,
           0.5065,  0.2305],
         [ 0.1393, -0.2118, -0.7556, -0.0793,  0.7708, -0.8922, -0.4725,  0.5542,
          -1.1830,  0.5065],
         [-0.3126,  0.1393, -0.2118, -0.7556, -0.0793,  0.7708, -0.8922, -0.4725,
           0.5542, -1.1830],
         [ 0.1181, -0.3126,  0.1393, -0.2118, -0.7556, -0.0793,  0.7708, -0.8922,
          -0.4725,  0.5542],
         [ 0.3356,  0.1181, -0.3126,  0.1393, -0.2118, -0.7556, -0.0793,  0.7708,
          -0.8922, -0.4725],
         [-0.2420,  0.3356,  0.1181, -0.3126,  0.1393, -0.2118, -0.7556, -0.0793,
           0.7708, -0.8922],
         [-0.1926, -0.2420,  0.3356,  0.1181, -0.3126,  0.1393, -0.2118, -0.7556,
          -0.0793,  0.7708],
         [-0.0583, -0.1926, -0.2420,  0.3356,  0.1181, -0.3126,  0.1393, -0.2118,
          -0.7556, -0.0793],
 

# ✅ FIX APPLIED

## The Bug and the Fix

**The Bug:** The targets (ys) were **noiseless predictions** instead of actual noisy values.
- Before: `ys = task.evaluate(xs)` computed ys as dot products (no noise!)
- This allowed MSE < noise_variance (0.04) because targets had no noise

**The Fix:**
1. **Modified `ARWarmupSampler.sample_xs()`** to store actual noisy next values in `self.current_ys`
2. **Modified training loop** to use `data_sampler.current_ys` instead of `task.evaluate(xs)` for AR tasks
3. **Still required (manual step, not yet applied): update eval cell 6** - replace `ys = task.evaluate(xs)` with:
   ```python
   if hasattr(data_sampler, 'current_ys'):
       ys = data_sampler.current_ys
   else:
       ys = task.evaluate(xs)
   ```

**After the fix:**
- ys will contain the actual noisy values from the AR sequence
- Theoretical minimum MSE will correctly be 0.04 (noise variance)
- The model will be learning to predict noisy targets, which is the correct formulation

**To apply:** 
1. Restart the kernel
2. Re-run cells from the beginning
3. You should now see MSE values that make sense relative to the noise floor!


# Manual Update Required for Cell 6

**Before running the evaluation, update the cell where `ys = task.evaluate(xs)` appears:**

### Find this code (around cell 5 or 6):
```python
xs = data_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end)

task_sampler_args = {"coefficients": data_sampler.current_coefficients}
task = task_sampler(**task_sampler_args)
ys = task.evaluate(xs)  # ← OLD (noiseless)
```

### Replace with:
```python
xs = data_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end)

task_sampler_args = {"coefficients": data_sampler.current_coefficients}
task = task_sampler(**task_sampler_args)

# Use actual noisy targets (includes noise from AR generation)
if hasattr(data_sampler, 'current_ys'):
    ys = data_sampler.current_ys
    print("✓ Using actual noisy targets")
else:
    ys = task.evaluate(xs)
    print("⚠️ Falling back to noiseless predictions")
```

After this change, **restart the kernel and re-run from the beginning!**
