In [None]:
import pandas as pd
import os, re, json, shutil, datetime, zipfile
from pathlib import Path
from convokit import Corpus, Speaker, Utterance, Conversation
from collections import Counter
import re

In [None]:
# Load the ConvoKit corpus and report its size
corpus = Corpus(filename="./dialog_corpus")

# Count each entity type by walking its iterator once (no intermediate list kept)
for label, iter_fn in (
    ("conversations", corpus.iter_conversations),
    ("utterances", corpus.iter_utterances),
    ("speakers", corpus.iter_speakers),
):
    print(f"Number of {label}: {sum(1 for _ in iter_fn())}")

# Show corpus-level metadata
print("\n=== Corpus Metadata ===")
print(corpus.meta)

In [None]:
# Randomly output several conversations and utterances to show data structure.
# A dedicated, seeded RNG makes the sample identical on every "Restart & Run All"
# without changing the global `random` module state.
import random

SAMPLE_SEED = 42
sample_rng = random.Random(SAMPLE_SEED)

# Print structure of several conversations
print("\n=== Sample Conversations ===")
all_convs = list(corpus.iter_conversations())
sample_convs = sample_rng.sample(all_convs, min(3, len(all_convs)))
for conv in sample_convs:
    conv_utt_ids = conv.get_utterance_ids()
    print(f"Conversation ID: {conv.id}")
    print(f"Meta: {conv.meta}")
    print(f"Number of Utterances: {len(conv_utt_ids)}")
    print("Utterance IDs:", conv_utt_ids[:3], "...")
    print()

# Print structure of several utterances
print("\n=== Sample Utterances ===")
all_utts = list(corpus.iter_utterances())
sample_utts = sample_rng.sample(all_utts, min(3, len(all_utts)))
for utt in sample_utts:
    text = utt.text or ""  # guard against utterances whose text is None
    print(f"Utterance ID: {utt.id}")
    print(f"Speaker: {utt.speaker.id if utt.speaker else None}")
    print(f"Text: {text[:60]}" + ("..." if len(text) > 60 else ""))
    print(f"Conversation ID: {utt.conversation_id}")
    print(f"Reply to: {utt.reply_to}")
    print(f"Meta: {utt.meta}")
    print()


In [None]:
import json
import os

# Read all JSON files from the GPT folder and re-key each item's id from "seeker" to "utterance"
# (e.g. "seeker_0_2" -> "utterance_0_2"); values are the raw SB scores.
sb_data = {}
gpt_dir = "./GPT"
duplicate_ids = []

print("Reading SB scores from GPT files...")
# os.listdir order is arbitrary; sorting makes the read order (and therefore which value
# wins when an id appears in several files) deterministic.
for filename in sorted(os.listdir(gpt_dir)):
    if not filename.endswith('.json'):
        continue
    filepath = os.path.join(gpt_dir, filename)
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for item in data:
        utt_id = item['id'].replace('seeker', 'utterance')
        if utt_id in sb_data:
            duplicate_ids.append(utt_id)
        sb_data[utt_id] = item['SB']

if duplicate_ids:
    print(f"Warning: {len(duplicate_ids)} duplicate ids were overwritten (first 5): {duplicate_ids[:5]}")

# Preview a few entries instead of dumping the entire dict into the notebook output
print(f"Loaded SB scores for {len(sb_data)} utterances")
print(json.dumps(dict(list(sb_data.items())[:5]), ensure_ascii=False, indent=2))


In [None]:
import os
import json

gpt_dir = "./GPT"
file_count = 0
file_item_counts = {}

# Count the items stored in each JSON file of the GPT folder; files that fail to parse
# are reported and left out of the per-file counts.
for filename in os.listdir(gpt_dir):
    if not filename.endswith('.json'):
        continue
    file_count += 1
    with open(os.path.join(gpt_dir, filename), "r") as f:
        try:
            file_item_counts[filename] = len(json.load(f))
        except Exception as e:
            print(f"Error loading {filename}: {e}")

total_item_count = sum(file_item_counts.values())

print(f"Total JSON files in GPT folder: {file_count}")
for json_name, n_items in file_item_counts.items():
    print(f"{json_name} contains {n_items} items (seeker_x_y)")
print(f"Total items across all JSON files: {total_item_count}")


In [None]:
# Read intent data from utterances_intent.jsonl (one JSON object per line),
# keyed by utterance id.
intent_data = {}

print("Loading intent data from utterances_intent.jsonl...")
with open('./utterances_intent.jsonl', 'r', encoding='utf-8') as f:
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            continue  # tolerate blank/trailing lines, which json.loads would reject
        record = json.loads(line)
        intent_data[record['id']] = {
            'intent': record['intent'],
            'intent_confidence': record['confidence'],
        }

print(f"Loaded intent data for {len(intent_data)} utterances")
print(f"Sample intent data: {list(intent_data.items())[:1]}")


In [None]:
# Compare how many utterances each label source covers
n_sb, n_intent = len(sb_data), len(intent_data)
print(f"Number of SB data entries: {n_sb}")
print(f"Number of intent data entries: {n_intent}")
relation = "the same" if n_sb == n_intent else "DIFFERENT"
print(f"sb_data and intent_data have {relation} number of entries.")


In [None]:
# Diagnose why the key sets (utterance_ids) of sb_data and intent_data differ
from collections import defaultdict

sb_keys, intent_keys = set(sb_data), set(intent_data)

# ids present on one side only
sb_only = sb_keys.difference(intent_keys)
intent_only = intent_keys.difference(sb_keys)

for missing_ids, where in (
    (sb_only, "in sb_data but not in intent_data"),
    (intent_only, "in intent_data but not in sb_data"),
):
    print(f"Number of utterance_ids {where}: {len(missing_ids)}")
    if missing_ids:
        print(f"Examples (first 10): {list(missing_ids)[:10]}")

# Analyze id suffixes, conversation_id and local_utt_id
def analyze_ids(id_set, label):
    """Report per-conversation gaps in the local utterance index of ``id_set``.

    Ids are split on ``_``; field 1 is taken as the conversation id and field 2 as
    the integer local utterance index (e.g. ``"utterance_0_2"`` -> conv ``"0"``, local 2).
    Ids that cannot be parsed this way are reported and skipped. A gap is any
    integer strictly between a conversation's smallest and largest local index
    that is absent from the set.

    Args:
        id_set: iterable of utterance id strings.
        label: prefix used in the printed report.

    Returns:
        dict mapping conversation id (str) -> list of missing local indices, ascending,
        containing only conversations that have at least one gap.
    """
    # Sets give O(1) membership checks in the gap scan below
    conv2utts = defaultdict(set)
    for uid in id_set:
        try:
            parts = uid.split('_')
            conv = parts[1]
            local = int(parts[2])
        except (AttributeError, IndexError, ValueError) as e:
            print(f"{label}: Cannot parse utterance_id: {uid}, error: {e}")
            continue
        conv2utts[conv].add(local)

    # Look for gaps between the smallest and largest local id of each conversation
    summary = {}
    for conv, present in conv2utts.items():
        missing = [i for i in range(min(present), max(present)) if i not in present]
        if missing:
            summary[conv] = missing

    print(f"{label} - gaps in local_utt_id within conversations (showing only those with gaps):")
    if summary:
        for conv, missing in list(summary.items())[:5]:  # Show at most first 5
            print(f"  conversation {conv} missing local_utt_ids: {missing}")
    else:
        print("  No obvious gaps found.")
    return summary

# Run the gap analysis on both key sets
for keys, source_label in ((sb_keys, "sb_data"), (intent_keys, "intent_data")):
    print(f"==== Analysis of {source_label} keys (utterance_id) ====")
    analyze_ids(keys, source_label)

# Compare distributions of sb_only and intent_only to see if some conversations are completely skipped
def count_conversation(ids):
    """Tally ids per conversation.

    The conversation id is the second ``_``-separated field of each id (e.g. ``"0"``
    for ``"utterance_0_2"``); ids that cannot be split that way are skipped.

    Returns:
        ``defaultdict(int)`` mapping conversation id (str) -> number of ids.
    """
    tally = defaultdict(int)
    for uid in ids:
        try:
            conv_key = uid.split('_')[1]
        except Exception:
            continue  # unparseable id: ignore it
        tally[conv_key] += 1
    return tally

# Per-conversation breakdown of the ids that exist on only one side
for source_label, only_ids in (("sb_data", sb_only), ("intent_data", intent_only)):
    print(f"\n{source_label} only conversation distribution: (showing at most first 10)")
    for conv, count in list(count_conversation(only_ids).items())[:10]:
        print(f"  conversation {conv} : {count} utterances")



In [None]:
# Spot-check one raw SB score; .get returns None instead of raising KeyError
# (which would abort Restart & Run All) when this id is absent.
sb_data.get('utterance_0_18')

In [None]:
# Map SB continuous values to binary labels: 1 if score >= SB_THRESHOLD else 0.
# Scores that are NaN or not comparable to a number (None, strings, ...) map to None,
# and the integration step below drops them.
SB_THRESHOLD = 0.7

sb_binary = {}
for k, v in sb_data.items():
    try:
        if v != v:
            # NaN compares False with everything, so without this it would silently become 0
            sb_binary[k] = None
        else:
            sb_binary[k] = 1 if v >= SB_THRESHOLD else 0
    except (TypeError, ValueError):
        sb_binary[k] = None


In [None]:
# Spot-check one binarized label; .get avoids a KeyError if this id is absent
sb_binary.get('utterance_0_8')

In [None]:
# Number of utterances that got a binarized SB entry (None entries included)
len(sb_binary.keys())

In [None]:
# Keep only IDs present in both sb_binary and intent_data and merge their labels:
# each kept ID gets intent, intent_confidence and sb_binary.
common_ids = set(sb_binary.keys()) & set(intent_data.keys())
print(f"Number of IDs in both sb_binary and intent_data: {len(common_ids)}")

integrated_labels = {}
# Iterate in sorted order: set iteration order of strings changes between interpreter
# runs (hash randomization), which made the "first 5" preview below non-reproducible.
for uid in sorted(common_ids):
    sb_val = sb_binary[uid]  # 0, 1, or None for an unusable score
    intent_info = intent_data[uid]
    if sb_val is None or intent_info is None:
        continue

    intent_val = intent_info.get("intent")
    confidence_val = intent_info.get("intent_confidence")  # Note: key is 'intent_confidence' not 'confidence'
    if intent_val is None or confidence_val is None:
        continue

    integrated_labels[uid] = {
        "utterance_id": uid,
        "intent": intent_val,
        "intent_confidence": confidence_val,
        "sb_binary": sb_val,
    }

print(f"\nNumber of successfully integrated data points: {len(integrated_labels)}")

# Show first few examples
if integrated_labels:
    print("\nFirst 5 example data points:")
    for i, (uid, data) in enumerate(list(integrated_labels.items())[:5]):
        print(f"{i+1}. {data}")
else:
    print("No utterance_ids found that exist in both sb_binary and intent_data")


In [None]:
# Convert integrated_labels to a DataFrame and split utterance_id into conv_id / turn_id
import pandas as pd

df = pd.DataFrame(list(integrated_labels.values()))

# utterance_id format: "utterance_{conv_id}_{turn_id}"
ids = df["utterance_id"].str.extract(r"utterance_(\d+)_(\d+)")
unparsed = df.loc[ids.isna().any(axis=1), "utterance_id"]
if not unparsed.empty:
    # Report the offending ids instead of failing later with a bare NaN->int cast error
    raise ValueError(
        f"{len(unparsed)} utterance_id values do not match 'utterance_<conv>_<turn>': "
        f"{unparsed.head().tolist()}"
    )
df["conv_id"] = ids[0].astype(int)
df["turn_id"] = ids[1].astype(int)

# Preview the first rows
df.head()


In [None]:
# (number of rows, number of columns) of the integrated label table
(len(df.index), len(df.columns))

In [None]:
# Encode intent labels as integer ids (sorted -> stable mapping). The one-hot
# expansion happens later, when the sequence features are built.
intent2id = {intent: i for i, intent in enumerate(sorted(df["intent"].unique()))}
df["intent_id"] = df["intent"].map(intent2id)

num_intents = len(intent2id)
print("Num intents:", num_intents)

# Keep only the columns we need. "intent" is kept so this cell can be re-run
# (intent2id is rebuilt from it) and the readable label stays next to its id.
df = df[["utterance_id", "conv_id", "turn_id", "intent", "intent_id",
         "intent_confidence", "sb_binary"]]

# Sort by conversation and turn order
df = df.sort_values(["conv_id", "turn_id"]).reset_index(drop=True)

In [None]:
# First 20 rows of the encoded, sorted table
df.iloc[:20]

In [None]:
from sklearn.model_selection import train_test_split

# Get all conversation IDs
all_convs = df["conv_id"].unique()

# Split by conversation so no conversation appears in more than one split:
# 20% of conversations for test, then 10% of the remainder for validation.
train_convs, test_convs = train_test_split(all_convs, test_size=0.2, random_state=42)
train_convs, val_convs = train_test_split(train_convs, test_size=0.1, random_state=42)

# .copy() makes independent frames so adding columns later (e.g. turn_id_norm)
# does not trigger SettingWithCopyWarning.
train_df = df[df["conv_id"].isin(train_convs)].copy()
val_df = df[df["conv_id"].isin(val_convs)].copy()
test_df = df[df["conv_id"].isin(test_convs)].copy()

# Persist the test split; the checkpoint-evaluation cell at the end of the notebook reloads it.
os.makedirs("final_project_data", exist_ok=True)
test_df.to_csv("final_project_data/test_split.csv", index=False)

len(train_df), len(val_df), len(test_df)


In [None]:
def show_balance(name, df):
    """Print the absolute and relative ``sb_binary`` label counts of ``df`` under a ``=== name ===`` header."""
    print(f"=== {name} ===")
    labels = df["sb_binary"]
    for normalize in (False, True):
        print(labels.value_counts(normalize=normalize))

# Label balance for each split
for split_name, split_df in (("TRAIN", train_df), ("VAL", val_df), ("TEST", test_df)):
    show_balance(split_name, split_df)


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score, accuracy_score
from sklearn.preprocessing import StandardScaler

# ===== 1. Select baseline input features =====
# NOTE: intent_id enters this baseline as a single ordinal number (not one-hot).
baseline_features = ["intent_id", "intent_confidence", "turn_id"]  # If you have turn_id_norm, change it to that

# ===== 2. Feature standardization (especially turn_id) =====
# The scaler is fit on TRAIN only and reused unchanged for VAL / TEST.
scaler = StandardScaler()
X_train = scaler.fit_transform(train_df[baseline_features].values)
X_val = scaler.transform(val_df[baseline_features].values)
X_test = scaler.transform(test_df[baseline_features].values)

y_train, y_val, y_test = (split["sb_binary"].values for split in (train_df, val_df, test_df))

# ===== 3. Logistic Regression; class_weight="balanced" up-weights the minority class (SB=1) =====
clf = LogisticRegression(class_weight="balanced", max_iter=500, solver="lbfgs")

# ===== 4. Train =====
clf.fit(X_train, y_train)

# Filled by eval_split(), keyed by split name
baseline_results = {}
# ===== 5. Evaluate on train / val / test =====
def eval_split(X, y, name):
    """Score the fitted global ``clf`` on one split, print a report and record the metrics.

    Args:
        X: (scaled) feature matrix of the split.
        y: gold ``sb_binary`` labels.
        name: split name; printed as a header and used as key in ``baseline_results``.

    Returns:
        dict with ``accuracy``, ``macro_f1`` and ``pos_f1`` (F1 of class 1); the same
        dict object is stored in ``baseline_results[name]``.
    """
    y_pred = clf.predict(X)
    print(f"\n=== {name} ===")
    print(classification_report(y, y_pred, digits=3))

    metrics = dict(
        accuracy=accuracy_score(y, y_pred),
        macro_f1=f1_score(y, y_pred, average="macro"),
        pos_f1=f1_score(y, y_pred, pos_label=1),
    )
    for caption, key in (
        ("Accuracy:", "accuracy"),
        ("Macro F1:", "macro_f1"),
        ("Pos-class (SB=1) F1:", "pos_f1"),
    ):
        print(caption, metrics[key])

    baseline_results[name] = metrics
    return metrics

eval_split(X_train, y_train, "TRAIN")
eval_split(X_val, y_val, "VAL")
# Run evaluation first
test_metrics = eval_split(X_test, y_test, "TEST")

# Convert results to a table: one row per split, split name in lower case
import pandas as pd
result_table = pd.DataFrame(
    [{"split": key.lower(), **baseline_results[key]} for key in ("TRAIN", "VAL", "TEST")]
)

print("\nBaseline results table:")
display(result_table)


In [None]:
import copy
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, accuracy_score

# ---------- First ensure turn_id_norm exists ----------
# turn_id_norm = turn_id / (max turn_id in the same conversation), or 0 when that max is 0.
# Rebind the three names to new frames instead of writing a column into each one in a
# loop: the splits may be slices of `df`, and in-place assignment then raises
# SettingWithCopyWarning.
train_df, val_df, test_df = (
    split.assign(
        turn_id_norm=split.groupby("conv_id")["turn_id"].transform(
            lambda x: x / x.max() if x.max() > 0 else 0
        )
    )
    for split in (train_df, val_df, test_df)
)

# Range of all intents (one-hot width shared by every split)
num_intents = int(max(
    train_df["intent_id"].max(),
    val_df["intent_id"].max(),
    test_df["intent_id"].max()
)) + 1
print("num_intents:", num_intents)
print("num_intents:", num_intents)


def one_hot_intent(intent_ids, num_intents):
    """Turn a 1-D array of integer intent ids into a float32 one-hot matrix.

    Args:
        intent_ids: 1-D numpy array of ids (cast to int).
        num_intents: number of columns of the result.

    Returns:
        float32 array of shape (len(intent_ids), num_intents) with a single 1.0 per row.
    """
    # Row i of the identity matrix is exactly the one-hot vector of intent i
    identity = np.eye(num_intents, dtype=np.float32)
    return identity[intent_ids.astype(int)]


def build_sequence_data(split_df, K=5, num_intents=None):
    """
    For a split (train/val/test), construct fixed-length history windows:
      X: (num_samples, seq_len=K+1, feature_dim)  float32
      y: (num_samples,)                           float32, sb_binary of the current utterance
    One sample is produced per utterance, in (conv_id, turn_id) order.

    Features for each timestep:
      [ sb_binary, intent_onehot..., intent_confidence, turn_id_norm ]
    so feature_dim = num_intents + 3.
    For the current utterance t the sb_binary slot is 0 (placeholder, avoids label
    leakage); history turns keep their true sb_binary. Windows shorter than K+1
    (early turns of a conversation) are left-padded with all-zero rows.

    Args:
        split_df: DataFrame with columns conv_id, turn_id, intent_id,
            intent_confidence, sb_binary and turn_id_norm.
        K: number of previous turns to include (must be >= 0).
        num_intents: width of the intent one-hot. Defaults to the notebook-level
            ``num_intents`` so every split gets the same feature width.

    Returns:
        (X, y) numpy arrays as described above; both have 0 samples if split_df is empty.

    Raises:
        ValueError: if K is negative.
    """
    if K < 0:
        raise ValueError(f"K must be >= 0, got {K}")
    if num_intents is None:
        num_intents = globals()["num_intents"]

    seq_len = K + 1
    feat_dim = num_intents + 3
    intent_eye = np.eye(num_intents, dtype=np.float32)  # row i = one-hot of intent i

    all_X = []
    all_y = []

    # Group by conversation; a history window never crosses conversations
    for _, g in split_df.groupby("conv_id"):
        g = g.sort_values("turn_id")

        sbs = g["sb_binary"].to_numpy(dtype=np.float32)
        # Per-turn feature rows (N, feat_dim); column 0 holds the true SB label
        feats = np.concatenate([
            sbs[:, None],
            intent_eye[g["intent_id"].to_numpy().astype(int)],
            g["intent_confidence"].to_numpy(dtype=np.float32)[:, None],
            g["turn_id_norm"].to_numpy(dtype=np.float32)[:, None],
        ], axis=1)

        for t in range(len(g)):
            # Window [t-K, ..., t]; copy so hiding the label does not modify `feats`
            window = feats[max(0, t - K): t + 1].copy()
            window[-1, 0] = 0.0  # current utterance: SB placeholder (no label leakage)

            X_t = np.zeros((seq_len, feat_dim), dtype=np.float32)
            X_t[seq_len - len(window):] = window  # left zero padding
            all_X.append(X_t)
            all_y.append(sbs[t])

    if all_X:
        X = np.stack(all_X, axis=0)
        y = np.asarray(all_y, dtype=np.float32)
        sb_rate = y.mean()
    else:
        # np.stack([]) raises; return correctly shaped empty arrays instead
        X = np.zeros((0, seq_len, feat_dim), dtype=np.float32)
        y = np.zeros((0,), dtype=np.float32)
        sb_rate = float("nan")

    print("build_sequence_data: X.shape =", X.shape, " y.mean(SB=1 rate) =", sb_rate)
    return X, y





In [None]:
# History length: each sample sees the K previous turns plus the current turn.
# K = 0 means only the current turn is used (no history).
K = 0
X_train, y_train = build_sequence_data(train_df, K=K)
X_val,   y_val   = build_sequence_data(val_df,   K=K)
X_test,  y_test  = build_sequence_data(test_df,  K=K)

seq_len = X_train.shape[1]
feat_dim = X_train.shape[2]
print("seq_len =", seq_len, " feat_dim =", feat_dim)

In [None]:
class SBSeqDataset(Dataset):
    """Map-style dataset over pre-built feature sequences and SB labels.

    Both inputs are copied once into float32 tensors; indexing returns the
    ``(x, y)`` pair of a single sample.
    """

    def __init__(self, X, y):
        self.X, self.y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        return sample, target


batch_size = 64  # Can be adjusted smaller/larger

# Only the training loader shuffles; VAL / TEST keep sample order fixed.
train_loader, val_loader, test_loader = (
    DataLoader(SBSeqDataset(feats, labels), batch_size=batch_size, shuffle=shuffle)
    for feats, labels, shuffle in (
        (X_train, y_train, True),
        (X_val, y_val, False),
        (X_test, y_test, False),
    )
)


In [None]:
class GRUSBModel(nn.Module):
    """GRU that reads a (batch, seq_len, input_dim) sequence and emits one logit per sample.

    The final hidden state of the top GRU layer is projected to a single logit, so the
    output pairs with ``nn.BCEWithLogitsLoss``. The submodule names ``gru`` and ``fc``
    are the state_dict keys of saved checkpoints and must stay unchanged.
    """

    def __init__(self, input_dim, hidden_dim=32, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(input_size=input_dim, hidden_size=hidden_dim,
                          num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # gru(x) returns (outputs, h_n); h_n is (num_layers, batch, hidden_dim)
        h_n = self.gru(x)[1]
        top_hidden = h_n[-1]  # (batch, hidden_dim)
        return self.fc(top_hidden).squeeze(-1)  # (batch,)


In [None]:
# Seed torch before building the model so weight initialisation (and DataLoader
# shuffling, which draws from the global torch RNG when no generator is given)
# is reproducible across runs.
torch.manual_seed(42)

# Calculate pos_weight = (#neg / #pos) so BCEWithLogitsLoss up-weights SB=1 samples
num_pos = (y_train == 1).sum()
num_neg = (y_train == 0).sum()
if num_pos == 0:
    # Without positives pos_weight would be inf and the loss would be unusable.
    raise ValueError("No positive (SB=1) samples in y_train; cannot compute pos_weight.")
pos_weight = torch.tensor([num_neg / num_pos], dtype=torch.float32)
print("num_pos:", num_pos, " num_neg:", num_neg, " pos_weight:", pos_weight.item())

device = torch.device("cpu")  # You can change to "cuda" if you have GPU
model = GRUSBModel(input_dim=feat_dim, hidden_dim=32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
def eval_model(model, data_loader, device, split_name="VAL", threshold=0.7):
    """Run ``model`` over ``data_loader`` without gradients and report classification metrics.

    Probabilities are ``sigmoid(logits)``; a sample is predicted SB=1 when its
    probability is >= ``threshold``. The model is left in eval mode.

    Returns:
        (labels, probs, preds, metrics): concatenated numpy arrays over all batches
        plus a dict with ``accuracy``, ``macro_f1`` and ``pos_f1``.
    """
    model.eval()
    label_chunks, prob_chunks = [], []
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            probs = torch.sigmoid(model(X_batch.to(device)))
            label_chunks.append(y_batch.to(device).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks)
    all_preds = (all_probs >= threshold).astype(int)

    metrics = {
        "accuracy": accuracy_score(all_labels, all_preds),
        "macro_f1": f1_score(all_labels, all_preds, average="macro"),
        "pos_f1": f1_score(all_labels, all_preds, pos_label=1),
    }

    print(f"\n=== {split_name} ===")
    print(classification_report(all_labels, all_preds, digits=3))
    print("Accuracy:", metrics["accuracy"])
    print("Macro F1:", metrics["macro_f1"])
    print("Pos-class (SB=1) F1:", metrics["pos_f1"])

    return all_labels, all_probs, all_preds, metrics


# ---- Training loop: train for num_epochs and keep the weights with the best VAL macro F1 ----
# NOTE(review): `model`, `optimizer`, `criterion`, the loaders and `K` come from earlier
# cells; re-running only this cell continues training from the current weights.
num_epochs = 10  # Can try 10 epochs first, add more if needed

# Best-so-far tracking (selection metric: VAL macro F1 at eval_model's default threshold)
best_macro_f1 = 0.0
best_epoch = 0
best_model_state = None
best_val_metrics = None
training_history = []  # one dict per epoch: train loss + VAL metrics
best_model_path = f"best_grusb_model_K{K}.pth"  # checkpoint file name encodes the history length K

for epoch in range(1, num_epochs+1):
    # eval_model() puts the model in eval mode every epoch, so switch back to train mode here
    model.train()
    total_loss = 0
    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(X_batch)
        loss = criterion(logits, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Scale the batch loss by batch size so avg_loss below is a per-sample average
        total_loss += loss.item() * X_batch.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}/{num_epochs} - Train Loss: {avg_loss:.4f}")

    # Check on VAL each epoch
    _, _, _, val_metrics = eval_model(model, val_loader, device, split_name="VAL")

    training_history.append({
        "epoch": epoch,
        "train_loss": avg_loss,
        "val_accuracy": val_metrics["accuracy"],
        "val_macro_f1": val_metrics["macro_f1"],
        "val_pos_f1": val_metrics["pos_f1"],
    })

    # Strictly better VAL macro F1 -> snapshot the weights (kept in memory; written to disk after the loop)
    if val_metrics["macro_f1"] > best_macro_f1:
        best_macro_f1 = val_metrics["macro_f1"]
        best_epoch = epoch
        best_model_state = copy.deepcopy(model.state_dict())
        best_val_metrics = val_metrics.copy()
        print(f"[BEST] Saved best model at epoch {epoch} (macro F1={best_macro_f1:.4f})")

# Save best model to a file after training and restore those weights into `model`
if best_model_state is not None:
    torch.save(best_model_state, best_model_path)
    model.load_state_dict(best_model_state)
    print(
        f"Best model saved from epoch {best_epoch} with macro F1={best_macro_f1:.4f} to '{best_model_path}'"
    )
else:
    # Only reached if VAL macro F1 never exceeded 0.0
    print("Warning: best model state was not set; using final epoch weights.")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

training_history_df = pd.DataFrame(training_history)
history_csv_path = f"final_project_data/training_history_K{K}.csv"
# to_csv does not create missing folders, so make sure the output directory exists
os.makedirs(os.path.dirname(history_csv_path), exist_ok=True)
training_history_df.to_csv(history_csv_path, index=False)

print(f"Training history saved to {history_csv_path}")
display(training_history_df)

# Left: training loss per epoch; right: validation metrics per epoch
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(training_history_df["epoch"], training_history_df["train_loss"], marker='o', label='Train Loss')
ax1.set_title('Training Loss per Epoch')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.grid(True, alpha=0.3)
ax1.legend()

ax2.plot(training_history_df["epoch"], training_history_df["val_accuracy"], marker='o', label='Val Accuracy')
ax2.plot(training_history_df["epoch"], training_history_df["val_macro_f1"], marker='o', label='Val Macro F1')
ax2.plot(training_history_df["epoch"], training_history_df["val_pos_f1"], marker='o', label='Val Pos-class F1')
ax2.set_title('Validation Metrics per Epoch')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Score')
ax2.set_ylim(0.45, 0.8)  # fixed range: scores outside it fall off the plot
ax2.grid(True, alpha=0.3)
ax2.legend()

plt.tight_layout()
plt.show()



In [None]:
print("\n========== FINAL TEST EVAL ==========")
# Keep only the metrics dict; returning the raw label/probability arrays as the last
# expression would echo them as the cell output.
_, _, _, gru_test_metrics = eval_model(model, test_loader, device, split_name="TEST")
gru_test_metrics


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Load the persisted test split; stop with a clear error if it has not been written yet
test_split_path = "final_project_data/test_split.csv"
if os.path.exists(test_split_path):
    test_df_saved = pd.read_csv(test_split_path)
else:
    raise FileNotFoundError(f"Saved test split not found at {test_split_path}. Run the split-saving cell first.")

# Recompute turn_id_norm (turn_id / max turn_id of the conversation, 0 if that max is 0)
# with the same formula used at training time
saved_turn_norm = test_df_saved.groupby("conv_id")["turn_id"].transform(
    lambda turns: turns / turns.max() if turns.max() > 0 else 0
)
test_df_saved["turn_id_norm"] = saved_turn_norm

def evaluate_saved_models(k_values, test_df_source, batch_size=64, hidden_dim=32, threshold=0.7):
    """Evaluate each saved ``best_grusb_model_K{k}.pth`` checkpoint on a test split.

    For every ``k`` the test sequences are rebuilt with history length ``k`` (so the
    input matches what that checkpoint was trained on), a ``GRUSBModel`` with
    ``hidden_dim`` is created and the checkpoint weights are loaded. Missing
    checkpoints are reported and skipped.

    Args:
        k_values: iterable of history lengths to evaluate.
        test_df_source: test split DataFrame (copied before use).
        batch_size: evaluation batch size.
        hidden_dim: GRU hidden size; must match the saved checkpoints.
        threshold: probability cutoff for predicting SB=1.

    Returns:
        DataFrame with columns ``K, accuracy, macro_f1, pos_f1`` sorted by ``K``.
        It is empty but keeps these columns when no checkpoint was evaluated.
    """
    result_columns = ["K", "accuracy", "macro_f1", "pos_f1"]
    results = []

    for k in k_values:
        model_path = f"best_grusb_model_K{k}.pth"
        print(f"\n--- Evaluating saved model: K={k} ({model_path}) ---")
        if not os.path.exists(model_path):
            print(f"[SKIP] Checkpoint not found: {model_path}")
            continue

        # Build test data for this K (based on persisted test split)
        X_test_k, y_test_k = build_sequence_data(test_df_source.copy(), K=k)
        test_loader_k = DataLoader(SBSeqDataset(X_test_k, y_test_k), batch_size=batch_size)

        # Create model and load weights
        model_k = GRUSBModel(input_dim=X_test_k.shape[2], hidden_dim=hidden_dim).to(device)
        state_dict = torch.load(model_path, map_location=device)
        model_k.load_state_dict(state_dict)

        # Evaluate on test set
        _, _, _, metrics = eval_model(
            model_k,
            test_loader_k,
            device,
            split_name=f"TEST (K={k})",
            threshold=threshold,
        )

        results.append({
            "K": k,
            "accuracy": metrics["accuracy"],
            "macro_f1": metrics["macro_f1"],
            "pos_f1": metrics["pos_f1"],
        })

    # Passing explicit columns keeps "K" present even when nothing was evaluated;
    # pd.DataFrame([]) has no columns and sort_values("K") would raise KeyError.
    return (
        pd.DataFrame(results, columns=result_columns)
        .sort_values("K")
        .reset_index(drop=True)
    )


k_range = list(range(0, 9))  # K=0..8: one checkpoint per history length
eval_batch_size = 64
gru_eval_df = evaluate_saved_models(
    k_range,
    test_df_source=test_df_saved,
    batch_size=eval_batch_size,
    hidden_dim=32,
    threshold=0.7,
)

if gru_eval_df.empty:
    print("No saved checkpoints were evaluated. Please ensure best_grusb_model_K*.pth files exist.")
else:
    csv_path = "final_project_data/gru_eval_across_K.csv"
    gru_eval_df.to_csv(csv_path, index=False)
    print(f"\nSaved evaluation summary to {csv_path}")
    display(gru_eval_df)

    # One line per metric, plotted against the history length K
    fig, ax = plt.subplots(figsize=(10, 5))
    for metric in ["accuracy", "macro_f1", "pos_f1"]:
        ax.plot(gru_eval_df["K"], gru_eval_df[metric], marker='o', label=metric)
    ax.set_title("GRU Performance vs K (best checkpoints)")
    ax.set_xlabel("K (history length)")
    ax.set_ylabel("Score")
    ax.set_ylim(0.45, 0.85)
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()
