# Import Statements

In [1]:
import os
import re
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Locations of the saved classifier checkpoint and the pickled vocabulary.
# Both paths are relative, so they resolve against the notebook's working directory.
STATE_PATH = "../models/checkpoints/emotion_classifier/tuning/emotion_classifier_last.pt"
VOCAB_PATH = "../data/processed/vocab.pkl"

In [3]:
# Lookup table from the classifier's numeric class index to a human-readable
# emotion name. The position of each name in the tuple is its class index.
id2label = dict(enumerate((
    "no emotion",
    "anger",
    "disgust",
    "fear",
    "happiness",
    "sadness",
    "surprise",
)))

# BiLSTM-based emotion classifier

In [4]:
# Emotion classifier: embedding -> bidirectional LSTM -> mean pooling -> linear head.
class BiLSTMEmotionClassifier(nn.Module):
    """Sequence classifier that returns per-class log-probabilities.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: Hidden size of the LSTM in each direction.
        output_dim: Number of output classes.
        pad_idx: Token id passed to the embedding layer as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(p=0.5)
        # Forward and backward LSTM outputs are concatenated, hence 2 * hidden_dim.
        self.fc = nn.Linear(in_features=2 * hidden_dim, out_features=output_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape (batch, output_dim).

        ``x`` is a batch-first tensor of token ids. The LSTM outputs are
        averaged over every position along dim 1 (no masking is applied, so
        padded positions contribute to the mean) before dropout and the
        linear layer.
        """
        sequence_states, _ = self.lstm(self.embedding(x))
        pooled = sequence_states.mean(dim=1)
        logits = self.fc(self.dropout(pooled))
        return self.softmax(logits)

In [5]:
# Load the vocabulary (token -> integer id mapping) from file.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point VOCAB_PATH at a file from a trusted source.
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

# Ids of the special padding / unknown tokens stored in the vocabulary.
PAD_IDX = vocab["<PAD>"]
UNK_IDX = vocab["<UNK>"]
# Used as the embedding table size.
# Assumes token ids are contiguous in [0, len(vocab)) - TODO confirm.
VOCAB_SIZE = len(vocab)

# Model hyperparameters. These must match the values used when the checkpoint
# was trained, otherwise load_state_dict below fails on a shape mismatch.
EMBED_DIM = 128
HIDDEN_DIM = 64
OUTPUT_DIM = 7  # one output per entry in id2label

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the BiLSTM-based emotion classifier
model = BiLSTMEmotionClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    pad_idx=PAD_IDX
).to(device)

# Load the pre-trained model weights.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs; it prevents arbitrary code execution from the
# checkpoint file and silences the torch.load warning about the unsafe default.
state = torch.load(STATE_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
# Switch to inference mode (disables dropout); as the last expression of the
# cell this also displays the model summary.
model.eval()

  state = torch.load(STATE_PATH, map_location=device)


BiLSTMEmotionClassifier(
  (embedding): Embedding(10948, 128, padding_idx=0)
  (lstm): LSTM(128, 64, batch_first=True, bidirectional=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=128, out_features=7, bias=True)
  (softmax): LogSoftmax(dim=1)
)

### Pre-processing text

In [6]:
# Cache for the NLTK objects used by preprocess_text; filled on the first
# successful load so later calls skip the download checks and corpus reads.
_NLTK_RESOURCES = {}


def _normalize_text(s):
    """Lowercase `s`, drop URLs, replace non-letters with spaces, collapse whitespace."""
    s = s.lower()
    s = re.sub(r"http\S+", "", s)
    s = re.sub(r"[^a-zA-Z\s]", " ", s)
    return re.sub(r"\s+", " ", s).strip()


def _load_nltk_resources():
    """Return (stop_words, lemmatizer, word_tokenize), loading them only once.

    Raises whatever NLTK raises (ImportError, LookupError, ...) when the
    library or its data is unavailable; nothing is cached in that case, so the
    next call tries again.
    """
    if not _NLTK_RESOURCES:
        import nltk
        from nltk.corpus import stopwords
        from nltk.stem import WordNetLemmatizer
        from nltk import word_tokenize
        try:
            # 'punkt_tab' is the tokenizer data newer NLTK releases look for;
            # 'punkt' is kept for older releases.
            for package in ("punkt", "punkt_tab", "stopwords", "wordnet"):
                nltk.download(package, quiet=True)
        except Exception:
            # Best effort: the data may already be installed locally (offline use).
            pass

        stop_words = set(stopwords.words('english'))
        lemmatizer = WordNetLemmatizer()
        # Force the lazy corpus loaders now, so a missing resource fails here
        # before anything is cached.
        word_tokenize("probe")
        lemmatizer.lemmatize("probe")

        _NLTK_RESOURCES.update(
            stop_words=stop_words,
            lemmatizer=lemmatizer,
            word_tokenize=word_tokenize,
        )
    return (
        _NLTK_RESOURCES["stop_words"],
        _NLTK_RESOURCES["lemmatizer"],
        _NLTK_RESOURCES["word_tokenize"],
    )


# Function to preprocess a text string into a list of tokens.
def preprocess_text(s):
    """Normalise `s` and return its tokens.

    The text is lowercased, URLs are removed, every non-letter character is
    replaced by a space and whitespace is collapsed. When NLTK and its data
    are available the result is tokenised with ``word_tokenize``, English stop
    words are dropped and the remaining tokens are lemmatised. Otherwise the
    normalised text is split on spaces and a RuntimeWarning is emitted,
    because that fallback keeps stop words and does not lemmatise.

    Args:
        s: Input string.

    Returns:
        List of token strings (possibly empty).
    """
    cleaned = _normalize_text(s)
    try:
        stop_words, lemmatizer, word_tokenize = _load_nltk_resources()
        toks = word_tokenize(cleaned)
        return [lemmatizer.lemmatize(t) for t in toks if t and t not in stop_words]
    except Exception as exc:
        import warnings
        warnings.warn(
            f"NLTK preprocessing unavailable ({exc!r}); falling back to plain "
            "whitespace tokenisation without stop-word removal or lemmatisation.",
            RuntimeWarning,
            stacklevel=2,
        )
        return [t for t in cleaned.split(" ") if t]

# Function to split a text string into sentences.
def sent_split(s):
    """Split `s` into a list of stripped, non-empty sentences.

    Uses NLTK's ``sent_tokenize`` when available. If the Punkt data is not
    installed yet, one download attempt is made and the call is retried, so the
    result does not depend on whether another function has already fetched the
    data. If NLTK still cannot be used, the text is split on whitespace that
    follows '.', '!' or '?'.

    Args:
        s: Input string.

    Returns:
        List of sentence strings (empty for blank input).
    """
    try:
        from nltk.tokenize import sent_tokenize
        try:
            sentences = sent_tokenize(s)
        except LookupError:
            # Tokenizer data missing: fetch it and retry once. 'punkt_tab' is
            # what newer NLTK releases look for; 'punkt' covers older ones.
            import nltk
            for package in ("punkt", "punkt_tab"):
                nltk.download(package, quiet=True)
            sentences = sent_tokenize(s)
        return [x.strip() for x in sentences if x.strip()]
    except Exception:
        # Best-effort fallback when NLTK or its data is unavailable.
        parts = re.split(r'(?<=[.!?])\s+', s.strip())
        return [x for x in parts if x]

In [7]:
# Maximum number of tokens per model input: encode() truncates or pads every
# token list to exactly this length, and chunks() uses it as its default chunk size.
MAX_LEN = 50

# Function to turn a token list into a tensor of exactly MAX_LEN token ids.
def encode(tokens):
    """Map `tokens` to vocabulary ids, truncated or right-padded to MAX_LEN.

    Tokens missing from `vocab` are mapped to UNK_IDX; positions after the last
    token are filled with PAD_IDX.

    Returns:
        1-D ``torch.long`` tensor of length MAX_LEN.
    """
    clipped = tokens[:MAX_LEN]
    token_ids = [vocab.get(tok, UNK_IDX) for tok in clipped]
    padding = [PAD_IDX] * (MAX_LEN - len(token_ids))
    return torch.tensor(token_ids + padding, dtype=torch.long)

# Function to cut a token list into consecutive pieces of at most `size` tokens.
def chunks(tokens, size=MAX_LEN):
    """Return consecutive slices of `tokens`, each holding at most `size` items.

    The last slice may be shorter. An empty `tokens` yields ``[[]]`` so callers
    always receive at least one chunk.
    """
    pieces = []
    for start in range(0, len(tokens), size):
        pieces.append(tokens[start:start + size])
    if pieces:
        return pieces
    return [[]]

### Predicting emotion

In [8]:
# Function to classify the emotion expressed in a piece of text with the loaded BiLSTM model.
@torch.no_grad()
def predict_emotion(text, agg="mean", return_breakdown=True):
    """Predict the dominant emotion of `text`.

    The text is split into sentences; each sentence is preprocessed, cut into
    chunks of at most MAX_LEN tokens and scored by the model. Chunk
    probabilities are combined into one vector per sentence, and sentence
    vectors into one final vector, using `agg`.

    Args:
        text: Input string; ``None`` or blank input yields a "no prediction" result.
        agg: ``"max"`` takes the element-wise maximum; any other value averages.
        return_breakdown: When truthy and the text has more than one sentence,
            include per-sentence predictions under ``"per_sentence"``.

    Returns:
        dict. On success it holds ``input``, ``pred_id``, ``pred``, ``prob``
        (rounded to 2 decimals), ``num_sentences``, ``agg`` and optionally
        ``per_sentence``. When nothing can be scored it holds ``input``,
        ``pred`` (None), ``prob`` (0.0) and an explanatory ``note``.
    """
    text = (text or "").strip()
    if not text:
        return {"input": text, "pred": None, "prob": 0.0, "note": "Empty input"}

    sentences = sent_split(text)
    if not sentences:
        return {"input": text, "pred": None, "prob": 0.0, "note": "No sentences found"}

    use_max = agg == "max"
    want_breakdown = return_breakdown and len(sentences) > 1

    sentence_probs = []
    breakdown = []

    for sentence in sentences:
        tokens = preprocess_text(sentence)
        if not tokens:
            # Nothing left after cleaning (e.g. only punctuation or stop words).
            continue

        encoded_chunks = [encode(piece) for piece in chunks(tokens, MAX_LEN)]
        batch = torch.stack(encoded_chunks, dim=0).to(device)

        # The model returns log-probabilities; exponentiate to get probabilities.
        chunk_probs = model(batch).exp()
        if use_max:
            sent_probs = chunk_probs.max(dim=0).values
        else:
            sent_probs = chunk_probs.mean(dim=0)

        sentence_probs.append(sent_probs.cpu().numpy())

        if want_breakdown:
            top_idx = int(sent_probs.argmax().item())
            breakdown.append({
                "text": sentence,
                "pred": id2label[top_idx],
                "prob": round(float(sent_probs[top_idx].item()), 2)
            })

    if not sentence_probs:
        return {"input": text, "pred": None, "prob": 0.0, "note": "No valid tokens after preprocessing"}

    stacked = np.stack(sentence_probs, axis=0)
    if use_max:
        combined = stacked.max(axis=0)
    else:
        combined = stacked.mean(axis=0)
    final_idx = int(combined.argmax())

    result = {
        "input": text,
        "pred_id": final_idx,
        "pred": id2label[final_idx],
        "prob": round(float(combined[final_idx]), 2),
        "num_sentences": len(sentences),
        "agg": agg
    }
    # `breakdown` is only ever filled when return_breakdown is truthy.
    if breakdown:
        result["per_sentence"] = breakdown
    return result