In [None]:
# =====================
# 0. Imports
# =====================
import torch
import matplotlib.pyplot as plt
import gradio as gr
from transformers import GPT2Tokenizer, GPT2Model
from openai import OpenAI

# Import HSMM utilities
from models.hsmm_module import HSMM_LDS_Torch, temporal_pool, zscore_torch
from models.sae import SparseAutoencoder


In [None]:
# =====================
# 1. OpenAI API
# =====================
# The API key is deliberately NOT written in the notebook. When `api_key` is
# omitted, the OpenAI SDK reads it from the OPENAI_API_KEY environment variable,
# so export that variable before starting the kernel.
client = OpenAI()


In [None]:
# =====================
# 2. Load models
# =====================
# Pick the compute device once: GPU when available, otherwise CPU, so this cell
# also runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# GPT-2 tokenizer and base model (inference only, hence .eval()).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2Model.from_pretrained("gpt2").to(DEVICE).eval()

# Pretrained sparse autoencoder; SparseAutoencoder is imported in the imports cell.
# input_dim=768 matches the hidden size requested from GPT-2 "gpt2" above.
sae = SparseAutoencoder(input_dim=768, hidden_dim=256, sparsity=1e-3).to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust
# (consider `weights_only=True` if the installed torch version supports it).
sae.load_state_dict(torch.load("models/sae_model.pt", map_location=DEVICE))
sae.eval()

K = 10  # upper bound on the number of HSMM states; set according to your pretrained model

In [None]:
# =====================
# 3. Analysis function
# =====================
def analyze_and_explain(text):
    """Segment ``text`` into latent "modes" and ask an LLM to interpret them.

    Steps (all models/clients come from earlier cells):
      1. Tokenize ``text`` (truncated to 256 tokens) and take GPT-2's
         ``last_hidden_state``.
      2. Pass the hidden states through ``sae`` and keep the second returned
         value as the latent sequence; pool it with ``temporal_pool(w=5)`` and
         normalise with ``zscore_torch``.
      3. Fit a fresh ``HSMM_LDS_Torch`` on that single sequence (5 iterations)
         and Viterbi-decode it into a state sequence.
      4. Plot the first 50 decoded states as a colour strip and send the same
         50 states plus the text to the OpenAI chat-completions API.

    Args:
        text: Raw input string from the UI.

    Returns:
        A ``(figure, explanation)`` pair: a matplotlib figure and the model's
        explanation text. For empty / whitespace-only input, returns
        ``(None, message)`` without running any model.
    """
    # Guard: with no tokens the GPT-2 forward pass cannot run, so bail out early.
    if not text or not text.strip():
        return None, "Please enter some text to analyze."

    max_shown = 50  # number of decoded steps plotted and sent to the LLM

    # Follow whatever device the GPT-2 model was loaded on instead of assuming CUDA.
    device = next(gpt2_model.parameters()).device.type

    # --- GPT-2 hidden states ---
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
    with torch.no_grad():
        hidden = gpt2_model(**inputs).last_hidden_state.squeeze(0)
        _, z_seq = sae(hidden)

    # --- SAE latent sequence ---
    z_seq = temporal_pool(z_seq.cpu(), w=5)
    Y_art = zscore_torch(z_seq).to(device)

    # --- HSMM segmentation ---
    # The state count is capped by K, by 10, and by half the sequence length
    # (but never below 2) so short inputs do not get more states than data.
    hsmm = HSMM_LDS_Torch(
        K=min(10, K, max(2, Y_art.shape[0] // 2)),
        Dmax=10,
        obs_dim=Y_art.shape[1],
        init_mean_dur=20.0,
        allow_self_transition=True,
        duration_model="negbin",
        device=device,
        Y_init=Y_art
    )
    _ = hsmm.fit(Y_art, n_iter=5, verbose=False)
    z_example = hsmm.viterbi_decode(Y_art).cpu().numpy()
    shown = z_example[:max_shown]

    # --- Visualization ---
    # Explicit Figure/Axes instead of pyplot's implicit "current figure": the
    # handler may run concurrently for several requests.
    fig, ax = plt.subplots(figsize=(14, 2))
    # Fixed vmin/vmax keeps a given state on the same tab10 colour for every
    # input. Assumes decoded labels are integers in 0..9 (the state count is
    # capped at 10 above) -- TODO confirm against HSMM_LDS_Torch.viterbi_decode.
    ax.imshow(
        shown.reshape(1, -1),
        aspect="auto",
        cmap="tab10",
        vmin=0,
        vmax=9,
        interpolation="nearest",
    )
    ax.axis("off")
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself is still returned for rendering.
    plt.close(fig)

    # --- OpenAI explanation ---
    prompt = f"""
    The following text has been segmented into computational modes:
    Text: {text}
    Mode sequence: {shown.tolist()}

    Please explain how each segment might correspond to syntactic or semantic boundaries.
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    # `content` can be None (e.g. refusals); always hand a string to the UI.
    explanation = response.choices[0].message.content or ""
    return fig, explanation

In [None]:
# =====================
# 4. Gradio Interface
# =====================
# Page layout, top to bottom: title, text input, segmentation plot,
# explanation panel, and the button that triggers the analysis.
with gr.Blocks() as demo:
    gr.Markdown("# MoDeLM-Analyzer: Interactive Demo")

    inp = gr.Textbox(label="Input Text", lines=3)
    out_fig = gr.Plot(label="Mode Visualization")
    out_txt = gr.Markdown(label="Explanation")
    run_btn = gr.Button("Run")

    # On click, pass the textbox value to analyze_and_explain and route its two
    # return values to the plot and the markdown panel, in that order.
    run_btn.click(
        fn=analyze_and_explain,
        inputs=inp,
        outputs=[out_fig, out_txt],
    )

# Start the local web server for the demo.
demo.launch()