In [1]:
import json
import torch
import random
import re
import string
import nltk
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from bs4 import BeautifulSoup
from nltk.corpus import stopwords

nltk.download('stopwords')

import warnings
warnings.filterwarnings("ignore")


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/numaan.naeem/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Data Pre-processing

In [2]:
JSON_PATH = "/home/numaan.naeem/BEA_2025/mrbench_v3_devset.json"

In [3]:
with open(JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

records = []

for item in data:
    conv_id = item["conversation_id"]
    history = item["conversation_history"]
    for model, details in item["tutor_responses"].items():
        record = {
            "conversation_id": conv_id,
            "model": model,
            "conversation_history": history,
            "response": details["response"],
            "mistake_identification": details["annotation"]["Mistake_Identification"].lower()
        }
        records.append(record)

df = pd.DataFrame(records)

In [4]:
df.head(8)

Unnamed: 0,conversation_id,model,conversation_history,response,mistake_identification
0,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Sonnet,"Tutor: Hi, could you please provide a step-by-...","Great, you've correctly identified the cost of...",yes
1,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama318B,"Tutor: Hi, could you please provide a step-by-...",Now that we know the cost of 1 pound of meat i...,yes
2,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama31405B,"Tutor: Hi, could you please provide a step-by-...","You're close, but I notice that you calculated...",yes
3,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,GPT4,"Tutor: Hi, could you please provide a step-by-...","That's correct. So, if 1 pound of meat costs $...",yes
4,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Mistral,"Tutor: Hi, could you please provide a step-by-...",It seems like you've calculated the cost as if...,yes
5,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Expert,"Tutor: Hi, could you please provide a step-by-...","And if there are 5 sandwiches, what is the mea...",yes
6,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Gemini,"Tutor: Hi, could you please provide a step-by-...",You're absolutely right about the cost of meat...,yes
7,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Phi3,"Tutor: Hi, could you please provide a step-by-...","To find the area of a rectangle, multiply its ...",no


In [5]:
df.shape

(2476, 5)

In [6]:
df["mistake_identification"].value_counts()

mistake_identification
yes               1932
no                 370
to some extent     174
Name: count, dtype: int64

In [7]:
def clean_text(text):
    '''Clean emoji, Make text lowercase, remove text in square brackets,remove links,remove punctuation
    and remove words containing numbers.'''
    text = re.sub(r'\:(.*?)\:','',text)
    text = str(text).lower()    #Making Text Lowercase
    text = re.sub('\[.*?\]', '', text)
    #The next 2 lines remove html text
    text = BeautifulSoup(text, 'lxml').get_text()
    text = re.sub('https?://\S+|www\.\S+', '', text)
    text = re.sub('<.*?>+', '', text)
    text = re.sub('\n', '', text)
    text = re.sub('\w*\d\w*', '', text)
    # replacing everything with space except (a-z, A-Z, ".", "?", "!", ",", "'")
    text = re.sub(r"[^a-zA-Z?.!,¿']+", " ", text)
    return text


def clean_contractions(text):
    '''Clean contraction using contraction mapping'''    
    specials = ["’", "‘", "´", "`"]
    for s in specials:
        text = text.replace(s, "'")
    #Remove Punctuations
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    # creating a space between a word and the punctuation following it
    # eg: "he is a boy." => "he is a boy ."
    text = re.sub(r"([?.!,¿])", r" \1 ", text)
    text = re.sub(r'[" "]+', " ", text)
    return text

def remove_space(text):
    '''Removes awkward spaces'''   
    #Removes awkward spaces 
    text = text.strip()
    text = text.split()
    return " ".join(text)

def text_preprocessing_pipeline(text):
    '''Cleaning and parsing the text.'''
    text = clean_text(text)
    text = clean_contractions(text)
    text = remove_space(text)
    # text = remove_stopwords(text)  # Added stopword removal step
    return text

def remove_stopwords(text):
    '''Removes stopwords from the text.'''
    stop_words = set(stopwords.words('english'))
    text = text.split()
    text = [word for word in text if word not in stop_words]
    return " ".join(text)

df['conversation_history'] = df['conversation_history'].apply(text_preprocessing_pipeline)
df['response'] = df['response'].apply(text_preprocessing_pipeline)



df['conversation_history'] = df['conversation_history'].apply(remove_stopwords)
df['response'] = df['response'].apply(remove_stopwords)

In [8]:
df.head(8)

Unnamed: 0,conversation_id,model,conversation_history,response,mistake_identification
0,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Sonnet,tutor tyson decided make muffaletta sandwiches...,great youve correctly identified cost meat let...,yes
1,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama318B,tutor tyson decided make muffaletta sandwiches...,know cost pound meat use information find corr...,yes
2,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama31405B,tutor tyson decided make muffaletta sandwiches...,youre close notice calculated cost pounds meat...,yes
3,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,GPT4,tutor tyson decided make muffaletta sandwiches...,thats correct pound meat costs sandwiches need...,yes
4,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Mistral,tutor tyson decided make muffaletta sandwiches...,seems like youve calculated cost pounds meat r...,yes
5,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Expert,tutor tyson decided make muffaletta sandwiches...,sandwiches meat cost,yes
6,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Gemini,tutor tyson decided make muffaletta sandwiches...,youre absolutely right cost meat remember sand...,yes
7,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Phi3,tutor tyson decided make muffaletta sandwiches...,find area rectangle multiply length width note...,no


In [9]:
# Normalize the strings (remove case/space issues)
df['mistake_identification'] = df['mistake_identification'].str.strip().str.lower()

# Map values
label_map = {
    'no': 0,
    'yes': 1,
    'to some extent': 2
}

df['mistake_identification'] = df['mistake_identification'].map(label_map)


In [10]:
df

Unnamed: 0,conversation_id,model,conversation_history,response,mistake_identification
0,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Sonnet,tutor tyson decided make muffaletta sandwiches...,great youve correctly identified cost meat let...,1
1,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama318B,tutor tyson decided make muffaletta sandwiches...,know cost pound meat use information find corr...,1
2,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Llama31405B,tutor tyson decided make muffaletta sandwiches...,youre close notice calculated cost pounds meat...,1
3,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,GPT4,tutor tyson decided make muffaletta sandwiches...,thats correct pound meat costs sandwiches need...,1
4,221-362eb11a-f190-42a6-b2a4-985fafdcfa9e,Mistral,tutor tyson decided make muffaletta sandwiches...,seems like youve calculated cost pounds meat r...,1
...,...,...,...,...,...
2471,5910-25617a89-a4ae-47bb-8812-d6b39fa4e691,Mistral,tutor hector purchased container gumballs gave...,seems might misunderstanding calculating numbe...,1
2472,5910-25617a89-a4ae-47bb-8812-d6b39fa4e691,Phi3,tutor hector purchased container gumballs gave...,solve problem need add number apples person al...,0
2473,5910-25617a89-a4ae-47bb-8812-d6b39fa4e691,Sonnet,tutor hector purchased container gumballs gave...,thats great start like worked backwards lets t...,1
2474,5910-25617a89-a4ae-47bb-8812-d6b39fa4e691,Expert,tutor hector purchased container gumballs gave...,okay hector gave less four times many bobby gi...,1


In [11]:
df["mistake_identification"].value_counts()

mistake_identification
1    1932
0     370
2     174
Name: count, dtype: int64

In [12]:
X_train, X_val, y_train, y_val = train_test_split(df.drop(["mistake_identification"], axis=1), df["mistake_identification"], test_size=0.2, stratify=df["mistake_identification"], random_state=42)

In [13]:
X_train.shape, X_val.shape

((1980, 4), (496, 4))

## Modeling & Evaluation

In [14]:
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
# MODEL_NAME = "google-bert/bert-base-uncased"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is: ", DEVICE)

BATCH_SIZE = 4
EPOCHS = 10
LR = 1e-4

Device is:  cuda


In [15]:
class HistoryDataset(Dataset):
    def __init__(self, samples):
        # samples = list of (history_str, response_str, label_int)
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

In [16]:
train_samples = list(zip(
    X_train['conversation_history'],
    X_train['response'],
    y_train
))

test_samples = list(zip(
    X_val['conversation_history'],
    X_val['response'],
    y_val
))

In [17]:
train_ds = HistoryDataset(train_samples)
test_ds  = HistoryDataset(test_samples)

In [18]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
encoder   = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
encoder.eval()  # We'll freeze it, as in the paper's approach

MPNetModel(
  (embeddings): MPNetEmbeddings(
    (word_embeddings): Embedding(30527, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MPNetEncoder(
    (layer): ModuleList(
      (0-11): 12 x MPNetLayer(
        (attention): MPNetAttention(
          (attn): MPNetSelfAttention(
            (q): Linear(in_features=768, out_features=768, bias=True)
            (k): Linear(in_features=768, out_features=768, bias=True)
            (v): Linear(in_features=768, out_features=768, bias=True)
            (o): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (intermediate): MPNetIntermediate(
          (dense): Linear(in_

### CLS + Concat

In [19]:
# [CLS]
@torch.no_grad()
def get_sequence_embeddings(texts):
    """
    Return shape: [batch_size, hidden_dim]
    """
    enc = tokenizer(texts, return_tensors="pt", padding=True,
                    truncation=True, return_attention_mask=True).to(DEVICE)
    outputs = encoder(**enc)
    # CLS token is at position 0
    cls_embeddings = outputs.last_hidden_state[:, 0, :]  # [batch, hidden_dim]
    return cls_embeddings

def collate_fn(batch):
    """
    batch: list of (history_str, response_str, label_int)
    We'll embed them in a single pass for efficiency.
    Returns (hist_emb, resp_emb, labels).
    """
    hist_texts = [item[0] for item in batch]
    resp_texts = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    # shape => [B, hist_len, hidden_dim]
    hist_emb = get_sequence_embeddings(hist_texts)
    # shape => [B, resp_len, hidden_dim]
    resp_emb = get_sequence_embeddings(resp_texts)

    labels_t = torch.tensor(labels, dtype=torch.long)

    return hist_emb, resp_emb, labels_t

In [20]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn)

In [21]:
# MODEL WITH [CSL] OR POOLING
class SimpleHistoryBasedModel(nn.Module):
    def __init__(self, hidden_dim=768, num_classes=3):
        super().__init__()

        # Project both embeddings to a common space (optional but useful)
        self.history_proj = nn.Linear(hidden_dim, hidden_dim)
        self.response_proj = nn.Linear(hidden_dim, hidden_dim)

        # You can use element-wise interaction (e.g., concat, diff, dot)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, hist_emb, resp_emb):
        """
        hist_emb: [batch, hidden_dim]
        resp_emb: [batch, hidden_dim]
        Returns logits => [batch, num_classes]
        """

        # Optionally project inputs
        hist_proj = self.history_proj(hist_emb)  # [B, H]
        resp_proj = self.response_proj(resp_emb)  # [B, H]

        # Combine them — here we use concatenation
        combined = torch.cat([hist_proj, resp_proj], dim=-1)  # [B, 2*H]

        # Feedforward for classification
        logits = self.ff(combined)  # [B, num_classes]
        return logits

In [22]:
model = SimpleHistoryBasedModel(
    hidden_dim=768,
    num_classes=3
).to(DEVICE)

In [23]:
###############################################################################
# 5) TRAINING LOOP
###############################################################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for hist_emb, resp_emb, labels in train_loader:
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(hist_emb, resp_emb)  # [batch, num_classes]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_loss:.4f}")
###############################################################################
# 6) EVALUATION ON TEST SET
###############################################################################
model.eval()
all_preds = []
all_labels= []
with torch.no_grad():
    for hist_emb, resp_emb, labels in test_loader:
        labels = labels.to(DEVICE)
        logits = model(hist_emb, resp_emb)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds  = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

print("\nTEST RESULTS:")
print(classification_report(
    all_labels,
    all_preds,
    target_names=["No (0)", "Yes (1)", "To Some Extent (2)"]
))

Epoch 1/10, Train Loss: 0.5778
Epoch 2/10, Train Loss: 0.4313
Epoch 3/10, Train Loss: 0.3810
Epoch 4/10, Train Loss: 0.3296
Epoch 5/10, Train Loss: 0.2973
Epoch 6/10, Train Loss: 0.2753
Epoch 7/10, Train Loss: 0.2464
Epoch 8/10, Train Loss: 0.2222
Epoch 9/10, Train Loss: 0.2010
Epoch 10/10, Train Loss: 0.1815

TEST RESULTS:
                    precision    recall  f1-score   support

           Yes (0)       0.75      0.59      0.66        74
            No (1)       0.88      0.92      0.90       387
To Some Extent (2)       0.25      0.23      0.24        35

          accuracy                           0.82       496
         macro avg       0.62      0.58      0.60       496
      weighted avg       0.81      0.82      0.82       496



### MEAN POOLING & Concat

In [25]:
# MEAN POOLING
@torch.no_grad()
def get_sequence_embeddings(texts):
    enc = tokenizer(texts, return_tensors="pt", padding=True,
                    truncation=True, return_attention_mask=True).to(DEVICE)
    outputs = encoder(**enc)
    token_embeddings = outputs.last_hidden_state  # [batch, seq_len, hidden_dim]
    attention_mask = enc['attention_mask'].unsqueeze(-1)  # [batch, seq_len, 1]

    # Zero out pad tokens, then average
    sum_embeddings = (token_embeddings * attention_mask).sum(1)
    valid_token_count = attention_mask.sum(1)
    mean_embeddings = sum_embeddings / valid_token_count  # [batch, hidden_dim]
    return mean_embeddings

def collate_fn(batch):
    """
    batch: list of (history_str, response_str, label_int)
    We'll embed them in a single pass for efficiency.
    Returns (hist_emb, resp_emb, labels).
    """
    hist_texts = [item[0] for item in batch]
    resp_texts = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    # shape => [B, hist_len, hidden_dim]
    hist_emb = get_sequence_embeddings(hist_texts)
    # shape => [B, resp_len, hidden_dim]
    resp_emb = get_sequence_embeddings(resp_texts)

    labels_t = torch.tensor(labels, dtype=torch.long)

    return hist_emb, resp_emb, labels_t

In [26]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn)

In [27]:
# MODEL WITH [CSL] OR POOLING
class SimpleHistoryBasedModel(nn.Module):
    def __init__(self, hidden_dim=768, num_classes=3):
        super().__init__()

        # Project both embeddings to a common space (optional but useful)
        self.history_proj = nn.Linear(hidden_dim, hidden_dim)
        self.response_proj = nn.Linear(hidden_dim, hidden_dim)

        # You can use element-wise interaction (e.g., concat, diff, dot)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, hist_emb, resp_emb):
        """
        hist_emb: [batch, hidden_dim]
        resp_emb: [batch, hidden_dim]
        Returns logits => [batch, num_classes]
        """

        # Optionally project inputs
        hist_proj = self.history_proj(hist_emb)  # [B, H]
        resp_proj = self.response_proj(resp_emb)  # [B, H]

        # Combine them — here we use concatenation
        combined = torch.cat([hist_proj, resp_proj], dim=-1)  # [B, 2*H]

        # Feedforward for classification
        logits = self.ff(combined)  # [B, num_classes]
        return logits

In [28]:
model = SimpleHistoryBasedModel(
    hidden_dim=768,
    num_classes=3
).to(DEVICE)

In [29]:
###############################################################################
# 5) TRAINING LOOP
###############################################################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for hist_emb, resp_emb, labels in train_loader:
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(hist_emb, resp_emb)  # [batch, num_classes]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_loss:.4f}")
###############################################################################
# 6) EVALUATION ON TEST SET
###############################################################################
model.eval()
all_preds = []
all_labels= []
with torch.no_grad():
    for hist_emb, resp_emb, labels in test_loader:
        labels = labels.to(DEVICE)
        logits = model(hist_emb, resp_emb)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds  = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

print("\nTEST RESULTS:")
print(classification_report(
    all_labels,
    all_preds,
    target_names=["No (0)", "Yes (1)", "To Some Extent (2)"]
))

Epoch 1/10, Train Loss: 0.5747
Epoch 2/10, Train Loss: 0.4362
Epoch 3/10, Train Loss: 0.3829
Epoch 4/10, Train Loss: 0.3425
Epoch 5/10, Train Loss: 0.3022
Epoch 6/10, Train Loss: 0.2824
Epoch 7/10, Train Loss: 0.2628
Epoch 8/10, Train Loss: 0.2310
Epoch 9/10, Train Loss: 0.2064
Epoch 10/10, Train Loss: 0.1870

TEST RESULTS:
                    precision    recall  f1-score   support

           Yes (0)       0.74      0.57      0.64        74
            No (1)       0.87      0.94      0.90       387
To Some Extent (2)       0.30      0.17      0.22        35

          accuracy                           0.83       496
         macro avg       0.64      0.56      0.59       496
      weighted avg       0.81      0.83      0.82       496



### CLS/MEAN & Siamese

In [30]:
# [CLS]
@torch.no_grad()
def get_sequence_embeddings(texts):
    """
    Return shape: [batch_size, hidden_dim]
    """
    enc = tokenizer(texts, return_tensors="pt", padding=True,
                    truncation=True, return_attention_mask=True).to(DEVICE)
    outputs = encoder(**enc)
    # CLS token is at position 0
    cls_embeddings = outputs.last_hidden_state[:, 0, :]  # [batch, hidden_dim]
    return cls_embeddings

# # MEAN POOLING
# @torch.no_grad()
# def get_sequence_embeddings(texts):
#     enc = tokenizer(texts, return_tensors="pt", padding=True,
#                     truncation=True, return_attention_mask=True).to(DEVICE)
#     outputs = encoder(**enc)
#     token_embeddings = outputs.last_hidden_state  # [batch, seq_len, hidden_dim]
#     attention_mask = enc['attention_mask'].unsqueeze(-1)  # [batch, seq_len, 1]

#     # Zero out pad tokens, then average
#     sum_embeddings = (token_embeddings * attention_mask).sum(1)
#     valid_token_count = attention_mask.sum(1)
#     mean_embeddings = sum_embeddings / valid_token_count  # [batch, hidden_dim]
#     return mean_embeddings

def collate_fn(batch):
    """
    batch: list of (history_str, response_str, label_int)
    We'll embed them in a single pass for efficiency.
    Returns (hist_emb, resp_emb, labels).
    """
    hist_texts = [item[0] for item in batch]
    resp_texts = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    # shape => [B, hist_len, hidden_dim]
    hist_emb = get_sequence_embeddings(hist_texts)
    # shape => [B, resp_len, hidden_dim]
    resp_emb = get_sequence_embeddings(resp_texts)

    labels_t = torch.tensor(labels, dtype=torch.long)

    return hist_emb, resp_emb, labels_t

In [31]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn)

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseWithCosineClassifier(nn.Module):
    def __init__(self, hidden_dim=768, num_classes=3):
        super().__init__()

        self.proj = nn.Linear(hidden_dim, hidden_dim)

        # Classifier uses h, r, |h - r|, cos_sim
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim * 3 + 1, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, hist_emb, resp_emb):
        # Optional projection layer
        h = self.proj(hist_emb)  # [B, 768]
        r = self.proj(resp_emb)  # [B, 768]

        # Cosine similarity between h and r
        cos_sim = F.cosine_similarity(h, r, dim=1)  # [B]
        cos_sim = cos_sim.unsqueeze(1)              # [B, 1]

        # Concatenate: h, r, |h - r|, cosine sim
        combined = torch.cat([
            h, r, torch.abs(h - r), cos_sim
        ], dim=-1)  # [B, 3*768 + 1]

        logits = self.ff(combined)
        return logits

In [33]:
model = SiameseWithCosineClassifier(
    hidden_dim=768,
    num_classes=3
).to(DEVICE)

In [34]:
###############################################################################
# 5) TRAINING LOOP
###############################################################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for hist_emb, resp_emb, labels in train_loader:
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(hist_emb, resp_emb)  # [batch, num_classes]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_loss:.4f}")
###############################################################################
# 6) EVALUATION ON TEST SET
###############################################################################
model.eval()
all_preds = []
all_labels= []
with torch.no_grad():
    for hist_emb, resp_emb, labels in test_loader:
        labels = labels.to(DEVICE)
        logits = model(hist_emb, resp_emb)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds  = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

print("\nTEST RESULTS:")
print(classification_report(
    all_labels,
    all_preds,
    target_names=["No (0)", "Yes (1)", "To Some Extent (2)"]
))

Epoch 1/10, Train Loss: 0.5959
Epoch 2/10, Train Loss: 0.4395
Epoch 3/10, Train Loss: 0.3641
Epoch 4/10, Train Loss: 0.2876
Epoch 5/10, Train Loss: 0.2152
Epoch 6/10, Train Loss: 0.1414
Epoch 7/10, Train Loss: 0.0809
Epoch 8/10, Train Loss: 0.0485
Epoch 9/10, Train Loss: 0.0287
Epoch 10/10, Train Loss: 0.0251

TEST RESULTS:
                    precision    recall  f1-score   support

           Yes (0)       0.73      0.61      0.66        74
            No (1)       0.88      0.94      0.91       387
To Some Extent (2)       0.26      0.14      0.19        35

          accuracy                           0.83       496
         macro avg       0.62      0.56      0.58       496
      weighted avg       0.81      0.83      0.82       496



### TOKEN LEVEL & KQV

In [29]:
# TOKEN LEVEL EMBEDDING
@torch.no_grad()
def get_sequence_embeddings(texts):
    """
    texts: list of strings
    Return shape: [batch_size, seq_len, hidden_dim]
    We do *no pooling*, we keep the full token sequence for attention.
    """
    enc = tokenizer(texts, return_tensors="pt", padding=True,
                    truncation=True).to(DEVICE)
    outputs = encoder(**enc)
    return outputs.last_hidden_state  # [batch, seq_len, hidden_dim]

def collate_fn(batch):
    """
    batch: list of (history_str, response_str, label_int)
    We'll embed them in a single pass for efficiency.
    Returns (hist_emb, resp_emb, labels).
    """
    hist_texts = [item[0] for item in batch]
    resp_texts = [item[1] for item in batch]
    labels = [item[2] for item in batch]

    # shape => [B, hist_len, hidden_dim]
    hist_emb = get_sequence_embeddings(hist_texts)
    # shape => [B, resp_len, hidden_dim]
    resp_emb = get_sequence_embeddings(resp_texts)

    labels_t = torch.tensor(labels, dtype=torch.long)

    return hist_emb, resp_emb, labels_t

In [30]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn)

In [31]:
###############################################################################
# 4) MODEL DEFINITION: SIMPLE HISTORY-BASED
###############################################################################
# This follows the paper's architecture for the "Simple History-Based Model":
#  - K from "previous sentence" embeddings (the conversation_history)
#  - Q,V from "current sentence" (the tutor response).
#  - MultiHeadAttention (Q=resp, K=hist, V=resp).
#  - Then we pool the output and pass it through a small feed-forward to get 3-class logits.

class SimpleHistoryBasedModel(nn.Module):
    def __init__(self, hidden_dim=768, n_heads=8, num_classes=3):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=hidden_dim,
                                         num_heads=n_heads,
                                         batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),  # paper typically used some activation
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, hist_emb, resp_emb):
        """
        hist_emb: [batch, hist_len, hidden_dim] -> used as K
        resp_emb: [batch, resp_len, hidden_dim] -> used as Q & V
        Returns logits => [batch, num_classes]
        """
        # standard multi-head attention: Q=resp, K=hist, V=resp
        # attn_out => [batch, resp_len, hidden_dim]
        attn_out, _ = self.mha(query=resp_emb,
                               key=hist_emb,
                               value=hist_emb)

        # We can pool over resp_len dimension to get a single vector
        # The paper used a feed-forward on "the output of the attention mechanism"
        # We'll do a simple mean-pool:
        pooled = attn_out.mean(dim=1)  # => [batch, hidden_dim]
        logits = self.ff(pooled)       # => [batch, num_classes]
        return logits

In [32]:
model = SimpleHistoryBasedModel(
    hidden_dim=768,
    num_classes=3
).to(DEVICE)

In [33]:
###############################################################################
# 5) TRAINING LOOP
###############################################################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for hist_emb, resp_emb, labels in train_loader:
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(hist_emb, resp_emb)  # [batch, num_classes]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_loss:.4f}")
###############################################################################
# 6) EVALUATION ON TEST SET
###############################################################################
model.eval()
all_preds = []
all_labels= []
with torch.no_grad():
    for hist_emb, resp_emb, labels in test_loader:
        labels = labels.to(DEVICE)
        logits = model(hist_emb, resp_emb)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds  = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

print("\nTEST RESULTS:")
print(classification_report(
    all_labels,
    all_preds,
    target_names=["Yes (0)", "No (1)", "To Some Extent (2)"]
))

Epoch 1/10, Train Loss: 0.6388
Epoch 2/10, Train Loss: 0.4716
Epoch 3/10, Train Loss: 0.4111
Epoch 4/10, Train Loss: 0.3743
Epoch 5/10, Train Loss: 0.3396
Epoch 6/10, Train Loss: 0.3046
Epoch 7/10, Train Loss: 0.2742
Epoch 8/10, Train Loss: 0.2424
Epoch 9/10, Train Loss: 0.2216
Epoch 10/10, Train Loss: 0.2060

TEST RESULTS:
                    precision    recall  f1-score   support

           Yes (0)       0.75      0.66      0.71        74
            No (1)       0.88      0.95      0.91       387
To Some Extent (2)       0.27      0.09      0.13        35

          accuracy                           0.85       496
         macro avg       0.63      0.57      0.58       496
      weighted avg       0.82      0.85      0.83       496

