# Notebook 2: HMM Training, Analysis, and Visualization

**Objective:** Load the logged inference trajectories, train an HMM surrogate, analyze its learned states, and visualize its behavior on test reviews.

In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

# Make the parent directory of the working directory importable so that the
# `src` package used by the imports below can be resolved.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.config import (
    LOG_FILE_PATH, HMM_MODEL_PATH, TARGET_SENTIMENT,
    NUM_TEST_SAMPLES, MAX_TOKENS, MODEL_NAME, DEVICE
)
from src.hmm_surrogate import HMMSurrogate
from src.data_utils import load_imdb_data, preprocess_data_for_inference_logging
from src.black_box_model import BlackBoxSentimentClassifier, log_inference_trajectories
from src.visualization_utils import plot_state_timeline, plot_hmm_transition_matrix, plot_avg_probabilities_per_state

%matplotlib inline

## 1. Load Logged Inference Trajectories

In [None]:
# Load the logged inference trajectories from LOG_FILE_PATH into a list of
# arrays. The archive is opened through a context manager so its underlying
# file handle is closed as soon as every array has been read into memory.
try:
    with np.load(LOG_FILE_PATH) as loaded_data:
        # Look arrays up by key (arr_0, arr_1, ...) so that list index i always
        # corresponds to archive entry arr_i.
        train_trajectories = [loaded_data[f'arr_{i}'] for i in range(len(loaded_data.files))]
except FileNotFoundError:
    # Leave an empty list behind so later cells can detect the missing data
    # and skip their work instead of failing on an undefined name.
    print(f"Error: Log file {LOG_FILE_PATH} not found. Please run Notebook 1 first.")
    train_trajectories = []
else:
    print(f"Loaded {len(train_trajectories)} trajectories from {LOG_FILE_PATH}.")
    if train_trajectories:
        print(f"Example trajectory 0 shape: {train_trajectories[0].shape}")

## 2. Train HMM Surrogate

In [None]:
# Train the HMM surrogate on the loaded trajectories and persist it to
# HMM_MODEL_PATH. `hmm_surrogate_model` stays None when there is nothing to
# train on, which later cells use as the "no model in memory" signal.
hmm_surrogate_model = None
if train_trajectories:
    hmm_surrogate_model = HMMSurrogate()  # Uses N_HMM_STATES from config
    hmm_surrogate_model.train(train_trajectories)

    # Create the directory the model is actually saved into (rather than a
    # hard-coded name), and tolerate it already existing.
    model_dir = os.path.dirname(HMM_MODEL_PATH)
    if model_dir:  # empty when the path has no directory component
        os.makedirs(model_dir, exist_ok=True)
    hmm_surrogate_model.save_model(HMM_MODEL_PATH)
else:
    print("Skipping HMM training as no trajectories were loaded.")

## 3. Analyze HMM Parameters and States

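If the training cell above was skipped, or no trained model is currently in memory, this cell first tries to load a previously saved HMM from `HMM_MODEL_PATH` before running the analysis.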

In [None]:
# Fall back to a model saved on disk when no trained model is in memory.
if hmm_surrogate_model is None or not hmm_surrogate_model.is_trained:
    print(f"Attempting to load pre-trained HMM from {HMM_MODEL_PATH}")
    try:
        hmm_surrogate_model = HMMSurrogate()
        hmm_surrogate_model.load_model(HMM_MODEL_PATH)
    except FileNotFoundError:
        print(f"Error: Pre-trained HMM model not found at {HMM_MODEL_PATH}. Please train one first.")
        hmm_surrogate_model = None # Ensure it's None if loading fails

# Stays None when the analysis cannot run; the visualization cell checks it.
state_analysis_results = None
if hmm_surrogate_model and hmm_surrogate_model.is_trained and train_trajectories:
    # Drop empty trajectories BEFORE decoding so that each decoded state
    # sequence is paired with the trajectory it was computed from. Zipping the
    # decoded list against the unfiltered trajectory list would shift every
    # pair that follows an empty trajectory.
    non_empty_trajectories = [traj for traj in train_trajectories if traj.shape[0] > 0]
    decoded_train_states = [hmm_surrogate_model.decode_sequence(traj) for traj in non_empty_trajectories]

    # Keep only the pairs whose decoded state sequence is non-empty, filtering
    # both lists together so they remain index-aligned.
    aligned_pairs = [
        (traj, states)
        for traj, states in zip(non_empty_trajectories, decoded_train_states)
        if states.size > 0
    ]
    valid_train_trajectories = [traj for traj, _ in aligned_pairs]
    valid_decoded_train_states = [states for _, states in aligned_pairs]

    state_analysis_results = hmm_surrogate_model.analyze_states(
        valid_train_trajectories,
        valid_decoded_train_states,
        target_class_idx=TARGET_SENTIMENT
    )

    # Visualize HMM parameters
    plot_hmm_transition_matrix(hmm_surrogate_model.model, state_names=state_analysis_results.get('state_names'))
    plot_avg_probabilities_per_state(state_analysis_results, target_class_idx=TARGET_SENTIMENT)

## 4. Visualize HMM Behavior on Test Reviews

Load a few test reviews, get their black-box probability trajectories, decode HMM states, and plot.

In [None]:
# Plot, for each test review, the black-box probability trajectory together
# with the HMM state sequence decoded from it. Requires a trained HMM and the
# state analysis results from the previous cell; otherwise the cell is a no-op.
if not (hmm_surrogate_model and hmm_surrogate_model.is_trained and state_analysis_results):
    print("Skipping visualization as HMM model is not available or not trained, or state analysis is missing.")
else:
    print("\n--- Visualizing HMM on Test Reviews ---")

    # Black-box classifier (and its tokenizer) used to produce test trajectories
    bb_model_vis = BlackBoxSentimentClassifier(model_name=MODEL_NAME, device=DEVICE)
    tokenizer_vis = bb_model_vis.tokenizer

    # Test reviews, preprocessed the same way as for inference logging
    imdb_test_raw = load_imdb_data(split='test', num_samples=NUM_TEST_SAMPLES, shuffle=False)
    processed_test_data = preprocess_data_for_inference_logging(imdb_test_raw, tokenizer_vis)

    # One probability trajectory per processed test sample
    test_trajectories_vis = log_inference_trajectories(processed_test_data, bb_model_vis, max_len=MAX_TOKENS)

    # Token strings for each sample, capped at MAX_TOKENS. When the tokenizer
    # defines a [CLS] token and the sample's input_ids do not already start
    # with it, the [CLS] string is prepended to the token list.
    test_tokens_vis = []
    for sample in processed_test_data:
        cls_id = tokenizer_vis.cls_token_id
        sample_tokens = sample['tokens']
        if cls_id is not None and not (sample['input_ids'][0] == cls_id):
            sample_tokens = [tokenizer_vis.cls_token] + sample_tokens
        test_tokens_vis.append(sample_tokens[:MAX_TOKENS])

    paired_samples = zip(test_trajectories_vis, test_tokens_vis)
    for sample_no, (trajectory, token_list) in enumerate(paired_samples, start=1):
        # Nothing to decode or plot for an empty trajectory
        if trajectory.shape[0] == 0:
            continue

        state_path = hmm_surrogate_model.decode_sequence(trajectory)

        # Trim the token labels to the number of decoded steps for plotting
        labels_for_plot = token_list[:len(state_path)]

        print(f"\nVisualizing Test Sample {sample_no}:")
        print(f"Text: {processed_test_data[sample_no - 1]['text'][:200]}...")

        plt.figure(figsize=(18, 6))
        timeline_axes = plt.gca()  # axes of the figure just created
        names_for_states = state_analysis_results.get('state_names')
        plot_state_timeline(
            labels_for_plot,
            trajectory,
            state_path,
            state_names=names_for_states,
            target_class_idx=TARGET_SENTIMENT,
            ax=timeline_axes
        )
        plt.show()