# Sentiment Analysis with RNN and Transformer
This notebook trains two models (an RNN-based LSTM/GRU and a Transformer) from scratch on the IMDB movie review dataset to classify sentiment as positive or negative.

## 1. Environment Setup

In [None]:
# Mount Google Drive at /content/drive (only works inside Google Colab).
# NOTE(review): none of the visible cells read from /content/drive (the dataset
# is downloaded into the working directory), so this mount may be unnecessary.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install torch torchtext matplotlib --quiet
import torch
# Report which accelerator this session exposes (CUDA GPU if present, else CPU).
# The training and inference cells call get_device() themselves, which also
# considers Apple MPS; this variable is informational for the notebook user.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m86.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m92.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## 2. Data Download and Preparation

In [None]:
# Download (~80 MB) and extract the IMDB dataset. Extraction should produce the
# ./aclImdb folder that the training cells fall back to when torchtext's IMDB
# loader is unavailable.
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xzf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  44.9M      0  0:00:01  0:00:01 --:--:-- 44.9M


## 3. Utility Functions

In [None]:
import os
import re
import math
import json
import random
from collections import Counter
from typing import List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# -------------------------
# Device helper
# -------------------------
def get_device():
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU."""
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    # Probes run lazily and in priority order; the first available backend wins.
    for backend, is_available in candidates:
        if is_available():
            return torch.device(backend)
    return torch.device("cpu")

# -------------------------
# Simple tokenizer & vocab
# -------------------------
TOKEN_RE = re.compile(r"\w+|[^\w\s]", re.UNICODE)

def simple_tokenize(text: str) -> List[str]:
    """Lower-case *text* and split it into runs of word characters and single punctuation marks."""
    return [match.group(0) for match in TOKEN_RE.finditer(text.lower())]

class Vocab:
    """Frequency-based token <-> id vocabulary.

    Special tokens take the first ids in the order given (by default
    ``<pad>`` = 0 and ``<unk>`` = 1) and exist from construction on, so
    ``encode`` works even before ``build_from_texts`` is called. ``encode``
    expects the specials to include ``<pad>`` (padding) and ``<unk>``
    (out-of-vocabulary tokens).
    """

    def __init__(self, max_size: int = 20000, min_freq: int = 2, specials: List[str] = None):
        self.max_size = max_size  # target vocabulary size, specials included (specials are always kept)
        self.min_freq = min_freq  # tokens counted fewer times than this are left out (encoded as <unk>)
        self.freqs = Counter()
        self.specials = list(specials) if specials else ["<pad>", "<unk>"]
        for tok in self.specials:
            self.add_token(tok, count=0)  # reserve positions
        # Give the specials their ids right away; build_from_texts() rebuilds both maps.
        self.itos = list(self.specials)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def add_token(self, token, count=1):
        """Record *token* in the frequency table if absent (used internally to reserve specials)."""
        if token not in self.freqs:
            self.freqs[token] += count

    def build_from_texts(self, texts: List[str]):
        """Count tokens in *texts* and rebuild ``itos``/``stoi``.

        Non-special tokens with frequency >= ``min_freq`` are kept, most frequent
        first (ties broken alphabetically), up to ``max_size`` entries in total.
        Counts accumulate across calls.
        """
        for t in texts:
            self.freqs.update(simple_tokenize(t))

        # filter by min_freq
        items = [(tok, c) for tok, c in self.freqs.items() if c >= self.min_freq and tok not in self.specials]
        items.sort(key=lambda x: (-x[1], x[0]))
        # Clamp at 0: a negative slice bound would drop only the last few tokens
        # instead of keeping none when max_size <= len(specials).
        keep = max(self.max_size - len(self.specials), 0)
        items = items[:keep]
        self.itos = list(self.specials) + [t for t, _ in items]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def encode(self, tokens: List[str], max_len: int = None) -> List[int]:
        """Map *tokens* to ids, unknown tokens to ``<unk>``.

        If *max_len* is given, the result is truncated or right-padded with the
        ``<pad>`` id to exactly *max_len* ids.
        """
        unk_id = self.stoi.get("<unk>")  # looked up once, not per token
        ids = [self.stoi.get(t, unk_id) for t in tokens]
        if max_len is not None:
            if len(ids) >= max_len:
                ids = ids[:max_len]
            else:
                ids = ids + [self.stoi["<pad>"]] * (max_len - len(ids))
        return ids

    def save(self, path):
        """Write ``itos``, the size limits and the special tokens to *path* as JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "itos": self.itos,
                    "max_size": self.max_size,
                    "min_freq": self.min_freq,
                    "specials": self.specials,
                },
                f,
                indent=2,
            )

    @classmethod
    def load(cls, path):
        """Rebuild a Vocab written by ``save``; files without "specials" get the default specials."""
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        v = cls(
            max_size=data.get("max_size", 20000),
            min_freq=data.get("min_freq", 2),
            specials=data.get("specials"),
        )
        v.itos = data["itos"]
        v.stoi = {tok: i for i, tok in enumerate(v.itos)}
        return v

# -------------------------
# Dataset wrapper
# -------------------------
class TextDataset(Dataset):
    """Map-style dataset yielding ``(token_ids, label)`` tensor pairs.

    Each text is tokenized and encoded lazily in ``__getitem__`` via
    ``vocab.encode(..., max_len=max_len)``.
    """

    def __init__(self, texts: List[str], labels: List[int], vocab: "Vocab", max_len: int = 256):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return (ids, label): a 1-D long tensor (length max_len when set) and a scalar long tensor."""
        toks = simple_tokenize(self.texts[idx])
        ids = self.vocab.encode(toks, max_len=self.max_len)
        return torch.tensor(ids, dtype=torch.long), torch.tensor(self.labels[idx], dtype=torch.long)

def collate_batch(batch):
    """Stack ``(ids, label)`` pairs into a (batch, seq_len) id tensor and a (batch,) label tensor."""
    id_tensors, label_tensors = zip(*batch)
    return torch.stack(id_tensors), torch.stack(label_tensors)

# -------------------------
# Basic training utilities
# -------------------------
def save_checkpoint(state: dict, path: str):
    """Serialize *state* to *path* with ``torch.save``, creating parent directories if needed.

    A bare filename (no directory part) is written to the current working directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when the path names one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)

def load_checkpoint(path: str, device):
    """Load a checkpoint written by ``save_checkpoint``, mapping its tensors onto *device*.

    NOTE: ``torch.load`` is pickle-based; only load checkpoints from trusted sources.
    """
    return torch.load(path, map_location=device)

def calc_accuracy(preds: torch.Tensor, targets: torch.Tensor):
    """Fraction of rows whose arg-max logit equals the target class.

    preds: logits of shape (batch, classes); targets: class indices of shape (batch,).
    """
    hits = preds.argmax(dim=-1).eq(targets).sum().item()
    total = targets.size(0)
    return hits / total

def plot_metrics(history: dict, out_path: str):
    """Save a two-panel (loss | accuracy) plot of per-epoch training history.

    history: {"train_loss": [...], "val_loss": [...], "train_acc": [...], "val_acc": [...]},
        one value per epoch, in epoch order.
    out_path: destination image path, passed to ``Figure.savefig``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    panels = (("loss", "Loss"), ("acc", "Accuracy"))
    for ax, (metric, title) in zip(axes, panels):
        for split in ("train", "val"):
            key = f"{split}_{metric}"
            values = history[key]
            # The training loops number epochs from 1; plot against the same
            # numbering instead of 0-based list indices.
            ax.plot(range(1, len(values) + 1), values, label=key)
        ax.legend()
        ax.set_xlabel("epoch")
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path)
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

## 4. Model Definitions

In [None]:
import torch
import torch.nn as nn
import math

# -------------------------
# RNN (LSTM/GRU) classifier
# -------------------------
class RNNClassifier(nn.Module):
    """LSTM/GRU sentiment classifier with masked mean pooling over time.

    Token id 0 is padding (``padding_idx=0``). Inputs are assumed to be
    right-padded (padding only after the real tokens, as ``Vocab.encode``
    pads). The RNN runs only over each row's non-padding prefix, so the logits
    do not depend on how much padding a review receives.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, num_layers=1, rnn_type="lstm", bidirectional=True, num_classes=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn_type = rnn_type.lower()
        self.bidirectional = bidirectional
        # Inter-layer dropout only makes sense (and is only allowed without a warning) for stacked RNNs.
        if self.rnn_type == "lstm":
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional, dropout=dropout if num_layers>1 else 0.0)
        else:
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional, dropout=dropout if num_layers>1 else 0.0)
        self.pool = nn.AdaptiveAvgPool1d(1)  # not used by forward(); parameter-free, kept for attribute compatibility
        factor = 2 if bidirectional else 1
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * factor, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, num_classes) logits
        mask = x != 0  # True at real tokens (pad idx = 0)
        # Real-token count per row; all-padding rows still need length >= 1 to be
        # packed, and their outputs are masked out below anyway.
        lengths = mask.sum(dim=1).clamp(min=1)
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        # Pack so each row stops at its last real token. Without packing, the
        # reverse direction of a bidirectional RNN would start on the padding.
        # NOTE(review): packed RNNs on Apple MPS depend on the torch version — verify there.
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.rnn(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )  # (batch, seq_len, hidden*dirs), zeros past each row's length
        # average pooling across real tokens only
        mask_f = mask.unsqueeze(-1).to(outputs.dtype)
        summed = (outputs * mask_f).sum(dim=1)
        denom = mask_f.sum(dim=1).clamp(min=1.0)
        return self.fc(summed / denom)

# -------------------------
# Simple Transformer classifier
# -------------------------
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    Works for any ``d_model``, including odd values. The table is stored as the
    buffer ``pe`` of shape (1, max_len, d_model), so inputs may be at most
    ``max_len`` positions long.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) frequencies.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns; drop the last frequency
        # when d_model is odd so the shapes agree.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model) with seq_len <= max_len
        return x + self.pe[:, : x.size(1), :]

class TransformerClassifier(nn.Module):
    """Transformer-encoder sentiment classifier with masked mean pooling.

    Token id 0 is padding (``padding_idx=0``). Padding positions are excluded
    from self-attention (``src_key_padding_mask``) and from the pooled
    representation, so the logits do not depend on how much padding a review receives.
    """

    def __init__(self, vocab_size, embed_dim=128, nhead=4, num_encoder_layers=2,
                 dim_feedforward=256, max_len=256, num_classes=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_enc = PositionalEncoding(embed_dim, max_len=max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        # enable_nested_tensor=False keeps the encoder off the nested-tensor path
        # that a padding mask would otherwise trigger. NOTE(review): that path is
        # what the earlier version blamed for failures on Apple MPS — confirm
        # on MPS with the installed torch version.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False
        )
        self.pool = nn.AdaptiveAvgPool1d(1)  # not used by forward(); parameter-free, kept for attribute compatibility
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, num_classes)
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, num_classes) logits
        pad_mask = x == 0  # True at padding positions

        # A row made entirely of padding would have every key masked, and the
        # attention softmax would produce NaN for it. Leave position 0 visible
        # in such rows; their pooled output is zeroed below anyway.
        all_pad = pad_mask.all(dim=1)
        attn_pad_mask = pad_mask.clone()
        attn_pad_mask[:, 0] &= ~all_pad

        emb = self.pos_enc(self.embedding(x))
        out = self.transformer(emb, src_key_padding_mask=attn_pad_mask)  # (batch, seq_len, embed_dim)

        # Mean pool over non-pad tokens
        keep = (~pad_mask).unsqueeze(-1).to(out.dtype)
        summed = (out * keep).sum(dim=1)
        denom = keep.sum(dim=1).clamp(min=1.0)
        return self.classifier(summed / denom)


## 5. Training the RNN Model

In [None]:
#!/usr/bin/env python3
"""
Train an RNN-based sentiment classifier (LSTM or GRU) from scratch on the IMDB dataset.

Usage:
    python train_sentiment.py --model lstm --epochs 6 --batch_size 64
"""
import os
import argparse
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

#from utils import get_device, Vocab, TextDataset, collate_batch, save_checkpoint, plot_metrics, simple_tokenize
#from models import RNNClassifier
import glob

def load_imdb_fallback(data_dir="aclImdb", split="train"):
    """
    Loads IMDB dataset from a local folder in the Stanford 'aclImdb' format.
    data_dir: path to 'aclImdb' folder
    split: 'train' or 'test'
    Returns: (texts, labels) where labels are 0=neg, 1=pos. All negative reviews
        come first, then all positive ones, each group in sorted file-path order.
        A missing folder yields two empty lists.
    """
    texts, labels = [], []
    for label_name, label_val in [("neg", 0), ("pos", 1)]:
        path_pattern = os.path.join(data_dir, split, label_name, "*.txt")
        # glob order is filesystem-dependent; sort it so the seeded train/val
        # split downstream is reproducible across machines.
        for file_path in sorted(glob.glob(path_pattern)):
            with open(file_path, "r", encoding="utf-8") as f:
                texts.append(f.read())
            labels.append(label_val)
    return texts, labels

def load_imdb_locally(split="train"):
    """
    Load IMDB reviews for *split* as (texts, labels) with labels 0=neg, 1=pos.

    Tries torchtext.datasets.IMDB first. On any failure (torchtext missing or
    incompatible with the installed torch, download error, unrecognized label
    format) it prints the reason and falls back to the local ./aclImdb folder.

    Raises:
        RuntimeError: torchtext failed and ./aclImdb does not exist.
    """
    def to_binary_label(label):
        # Older torchtext releases yield "neg"/"pos" strings; newer ones yield
        # integers. NOTE(review): integers assumed to be 1=neg, 2=pos — confirm
        # for the installed torchtext. Anything else raises, which triggers the
        # local fallback instead of silently mislabeling reviews.
        if isinstance(label, str):
            return 0 if label.lower().startswith("neg") else 1
        if label in (1, 2):
            return label - 1
        raise ValueError(f"Unrecognized IMDB label: {label!r}")

    try:
        from torchtext.datasets import IMDB
        ds = list(IMDB(root=".data", split=split))
        texts = [t for _, t in ds]
        labels = [to_binary_label(label) for label, _ in ds]
        return texts, labels
    except Exception as e:  # deliberately broad: any torchtext problem -> local copy
        print("[WARN] torchtext IMDB not available, falling back to local dataset.")
        print(f"[WARN] reason: {type(e).__name__}: {e}")
        if not os.path.exists("aclImdb"):
            raise RuntimeError(
                "Local IMDB dataset not found. Please download from "
                "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz "
                "and extract into ./aclImdb"
            ) from e
        return load_imdb_fallback(data_dir="aclImdb", split=split)


def split_train_val(texts, labels, val_frac=0.1, seed=42):
    """Deterministically shuffle paired texts/labels and split them.

    The first ``int(n * (1 - val_frac))`` shuffled items go to training and the
    rest to validation. Returns (train_texts, train_labels, val_texts, val_labels).
    """
    order = list(range(len(texts)))
    random.Random(seed).shuffle(order)
    n_train = int(len(order)*(1-val_frac))
    train_ids, val_ids = order[:n_train], order[n_train:]

    def gather(seq, ids):
        return [seq[i] for i in ids]

    return (
        gather(texts, train_ids),
        gather(labels, train_ids),
        gather(texts, val_ids),
        gather(labels, val_ids),
    )

def train_epoch(model, loader, opt, criterion, device):
    """Run one optimisation pass over *loader*.

    Returns (mean loss, accuracy), both weighted by the number of samples in each batch.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        opt.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        opt.step()
        bs = inputs.size(0)
        loss_sum += loss.item() * bs
        correct += logits.argmax(dim=-1).eq(targets).sum().item()
        seen += bs
    return loss_sum / seen, correct / seen

def eval_epoch(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradients.

    Returns (mean loss, accuracy), both weighted by the number of samples in each batch.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            bs = inputs.size(0)
            loss_sum += criterion(logits, targets).item() * bs
            correct += logits.argmax(dim=-1).eq(targets).sum().item()
            seen += bs
    return loss_sum / seen, correct / seen

def main(args):
    """
    Train an RNNClassifier on the IMDB train split and save artifacts in args.out_dir.

    Reads from args: limit, val_frac, vocab_size, min_freq, out_dir, max_len,
    batch_size, embed_dim, hidden_dim, num_layers, rnn_type, bidirectional,
    dropout, lr, weight_decay, epochs, and optionally seed (default 42).

    Writes vocab.json, model_epoch{N}.pt for every epoch, best_model.pt for the
    epoch with the highest validation accuracy, and training_metrics.png. Every
    checkpoint also stores a "config" dict with the model's constructor
    arguments (plus max_len) so inference can rebuild it without guessing.
    """
    seed = getattr(args, "seed", None)
    seed = 42 if seed is None else seed
    torch.manual_seed(seed)
    device = get_device()
    print("Using device:", device)

    # Load IMDB
    texts, labels = load_imdb_locally(split="train")
    print("Loaded IMDB train size:", len(texts))

    # small subset option for quick experiments
    if args.limit and args.limit > 0:
        # load_imdb_fallback returns every negative review before any positive
        # one, so a plain prefix would be (almost) single-class. Take a seeded
        # random subset instead.
        picked = list(range(len(texts)))
        random.Random(seed).shuffle(picked)
        picked = picked[: args.limit]
        texts = [texts[i] for i in picked]
        labels = [labels[i] for i in picked]

    train_texts, train_labels, val_texts, val_labels = split_train_val(
        texts, labels, val_frac=args.val_frac, seed=seed
    )

    # build vocab
    print("Building vocab...")
    vocab = Vocab(max_size=args.vocab_size, min_freq=args.min_freq)
    vocab.build_from_texts(train_texts)
    print("Vocab size:", len(vocab))
    os.makedirs(args.out_dir, exist_ok=True)
    vocab.save(os.path.join(args.out_dir, "vocab.json"))

    # datasets and loaders
    train_ds = TextDataset(train_texts, train_labels, vocab, max_len=args.max_len)
    val_ds = TextDataset(val_texts, val_labels, vocab, max_len=args.max_len)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_batch)

    # Architecture settings saved with every checkpoint so inference can rebuild
    # the exact model (rnn_type, num_layers, ...) instead of guessing.
    config = {
        "model_type": "rnn",
        "embed_dim": args.embed_dim,
        "hidden_dim": args.hidden_dim,
        "num_layers": args.num_layers,
        "rnn_type": args.rnn_type,
        "bidirectional": args.bidirectional,
        "dropout": args.dropout,
        "max_len": args.max_len,
    }

    # model
    model = RNNClassifier(
        vocab_size=len(vocab),
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        rnn_type=args.rnn_type,
        bidirectional=args.bidirectional,
        num_classes=2,
        dropout=args.dropout,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = eval_epoch(model, val_loader, criterion, device)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        print(f"Epoch {epoch}/{args.epochs} -- train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
        # checkpoint (torch.save serializes the current contents, so one dict serves both files)
        ckpt = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "vocab": vocab.itos,
            "history": history,
            "config": config,
        }
        save_checkpoint(ckpt, os.path.join(args.out_dir, f"model_epoch{epoch}.pt"))
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(ckpt, os.path.join(args.out_dir, "best_model.pt"))
    # save final history plot
    plot_metrics(history, os.path.join(args.out_dir, "training_metrics.png"))
    print("Training complete. Best val acc:", best_val_acc)



## 6. Training the Transformer Model

In [None]:
#!/usr/bin/env python3
"""
Train a simple Transformer encoder classifier on IMDB dataset.
Safe for Apple Silicon MPS (no nested tensor mask bug).
"""

import os
import argparse
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

#from utils import get_device, Vocab, TextDataset, collate_batch, save_checkpoint, plot_metrics
#from models import TransformerClassifier
import glob

# -------------------------
# Local IMDB loader
# -------------------------
def load_imdb_fallback(data_dir="aclImdb", split="train"):
    """Load IMDB reviews from a local Stanford 'aclImdb' folder.

    Returns (texts, labels) with labels 0=neg, 1=pos. All negative reviews
    come first, then all positive ones, each group in sorted file-path order.
    A missing folder yields two empty lists.
    """
    texts, labels = [], []
    for label_name, label_val in [("neg", 0), ("pos", 1)]:
        path_pattern = os.path.join(data_dir, split, label_name, "*.txt")
        # glob order is filesystem-dependent; sort it so the seeded train/val
        # split downstream is reproducible across machines.
        for file_path in sorted(glob.glob(path_pattern)):
            with open(file_path, "r", encoding="utf-8") as f:
                texts.append(f.read())
            labels.append(label_val)
    return texts, labels

def load_imdb_locally(split="train"):
    """Load IMDB reviews for *split* as (texts, labels) with labels 0=neg, 1=pos.

    Tries torchtext.datasets.IMDB first. On any failure (torchtext missing or
    incompatible with the installed torch, download error, unrecognized label
    format) it prints the reason and falls back to the local ./aclImdb folder.

    Raises:
        RuntimeError: torchtext failed and ./aclImdb does not exist.
    """
    def to_binary_label(label):
        # Older torchtext releases yield "neg"/"pos" strings; newer ones yield
        # integers. NOTE(review): integers assumed to be 1=neg, 2=pos — confirm
        # for the installed torchtext. Anything else raises, which triggers the
        # local fallback instead of silently mislabeling reviews.
        if isinstance(label, str):
            return 0 if label.lower().startswith("neg") else 1
        if label in (1, 2):
            return label - 1
        raise ValueError(f"Unrecognized IMDB label: {label!r}")

    try:
        from torchtext.datasets import IMDB
        ds = list(IMDB(root=".data", split=split))
        texts = [t for _, t in ds]
        labels = [to_binary_label(label) for label, _ in ds]
        return texts, labels
    except Exception as e:  # deliberately broad: any torchtext problem -> local copy
        print("[WARN] torchtext IMDB not available, falling back to local dataset.")
        print(f"[WARN] reason: {type(e).__name__}: {e}")
        if not os.path.exists("aclImdb"):
            raise RuntimeError(
                "Local IMDB dataset not found. Please download from:\n"
                "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n"
                "and extract into ./aclImdb"
            ) from e
        return load_imdb_fallback(data_dir="aclImdb", split=split)

def split_train_val(texts, labels, val_frac=0.1, seed=42):
    """Deterministically shuffle paired texts/labels and split them.

    The first ``int(n * (1 - val_frac))`` shuffled items go to training and the
    rest to validation. Returns (train_texts, train_labels, val_texts, val_labels).
    """
    shuffled = list(range(len(texts)))
    rng = random.Random(seed)
    rng.shuffle(shuffled)
    boundary = int(len(shuffled) * (1 - val_frac))
    train_part = shuffled[:boundary]
    val_part = shuffled[boundary:]
    train_texts = [texts[i] for i in train_part]
    train_labels = [labels[i] for i in train_part]
    val_texts = [texts[i] for i in val_part]
    val_labels = [labels[i] for i in val_part]
    return train_texts, train_labels, val_texts, val_labels

# -------------------------
# Training / Eval loops
# -------------------------
def train_epoch(model, loader, opt, criterion, device):
    """Run one optimisation pass over *loader*.

    Returns (mean loss, accuracy), both weighted by the number of samples in each batch.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        opt.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        opt.step()
        bs = inputs.size(0)
        loss_sum += loss.item() * bs
        correct += logits.argmax(dim=-1).eq(targets).sum().item()
        seen += bs
    return loss_sum / seen, correct / seen

def eval_epoch(model, loader, criterion, device):
    """Evaluate *model* on *loader* without gradients.

    Returns (mean loss, accuracy), both weighted by the number of samples in each batch.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            bs = inputs.size(0)
            loss_sum += criterion(logits, targets).item() * bs
            correct += logits.argmax(dim=-1).eq(targets).sum().item()
            seen += bs
    return loss_sum / seen, correct / seen

# -------------------------
# Main
# -------------------------
def main(args):
    """Train a TransformerClassifier on the IMDB train split and save artifacts in args.out_dir.

    Reads from args: limit, val_frac, vocab_size, min_freq, out_dir, max_len,
    batch_size, embed_dim, nhead, num_layers, dim_feedforward, dropout, lr,
    weight_decay, epochs, and optionally seed (default 42).

    Writes vocab.json, transformer_epoch{N}.pt for every epoch,
    transformer_best.pt for the epoch with the highest validation accuracy, and
    transformer_training_metrics.png. Every checkpoint also stores a "config"
    dict with the model's constructor arguments so inference can rebuild it
    without guessing.
    """
    seed = getattr(args, "seed", None)
    seed = 42 if seed is None else seed
    torch.manual_seed(seed)
    device = get_device()
    print("Using device:", device)

    texts, labels = load_imdb_locally(split="train")
    print("Loaded IMDB train size:", len(texts))

    if args.limit and args.limit > 0:
        # load_imdb_fallback returns every negative review before any positive
        # one, so a plain prefix would be (almost) single-class. Take a seeded
        # random subset instead.
        picked = list(range(len(texts)))
        random.Random(seed).shuffle(picked)
        picked = picked[: args.limit]
        texts = [texts[i] for i in picked]
        labels = [labels[i] for i in picked]

    train_texts, train_labels, val_texts, val_labels = split_train_val(
        texts, labels, val_frac=args.val_frac, seed=seed
    )

    print("Building vocab...")
    vocab = Vocab(max_size=args.vocab_size, min_freq=args.min_freq)
    vocab.build_from_texts(train_texts)
    print("Vocab size:", len(vocab))
    os.makedirs(args.out_dir, exist_ok=True)
    vocab.save(os.path.join(args.out_dir, "vocab.json"))

    train_ds = TextDataset(train_texts, train_labels, vocab, max_len=args.max_len)
    val_ds = TextDataset(val_texts, val_labels, vocab, max_len=args.max_len)
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_batch
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_batch
    )

    # Architecture settings saved with every checkpoint; nhead in particular
    # cannot be recovered from the weights alone.
    config = {
        "model_type": "transformer",
        "embed_dim": args.embed_dim,
        "nhead": args.nhead,
        "num_encoder_layers": args.num_layers,
        "dim_feedforward": args.dim_feedforward,
        "dropout": args.dropout,
        "max_len": args.max_len,
    }

    model = TransformerClassifier(
        vocab_size=len(vocab),
        embed_dim=args.embed_dim,
        nhead=args.nhead,
        num_encoder_layers=args.num_layers,
        dim_feedforward=args.dim_feedforward,
        max_len=args.max_len,
        num_classes=2,
        dropout=args.dropout,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    best_val_acc = 0.0

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = eval_epoch(model, val_loader, criterion, device)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(
            f"Epoch {epoch}/{args.epochs} -- "
            f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
            f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
        )

        ckpt = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "vocab": vocab.itos,
            "history": history,
            "config": config,
        }
        save_checkpoint(ckpt, os.path.join(args.out_dir, f"transformer_epoch{epoch}.pt"))
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(ckpt, os.path.join(args.out_dir, "transformer_best.pt"))

    plot_metrics(history, os.path.join(args.out_dir, "transformer_training_metrics.png"))
    print("Training complete. Best val acc:", best_val_acc)

# -------------------------
# Entry point
# -------------------------


## 7. Inference Pipeline

In [None]:
#!/usr/bin/env python3
import argparse
import torch
import torch.nn.functional as F

from utils import get_device, Vocab, simple_tokenize
from models import RNNClassifier, TransformerClassifier

def detect_model_type(state_dict):
    """Classify a checkpoint's state dict as "rnn" or "transformer" from its parameter-name prefixes.

    "rnn" is checked first. Raises ValueError when neither prefix is present.
    """
    for prefix, kind in (("rnn.", "rnn"), ("transformer.", "transformer")):
        if any(name.startswith(prefix) for name in state_dict.keys()):
            return kind
    raise ValueError("Cannot detect model type from checkpoint keys")

def _build_model_from_checkpoint(model_type, state, config, vocab_size):
    """Construct an untrained classifier whose parameter shapes match *state*.

    Values in the checkpoint's ``config`` dict take precedence. Otherwise they
    are inferred from weight shapes. nhead and dropout cannot be inferred and
    fall back to 4 and 0.2 / 0.1; dropout is inactive in eval mode anyway.
    Returns (model, max_len), where max_len is the length to encode inputs to.
    """
    embed_dim = config.get("embed_dim", state["embedding.weight"].shape[1])
    if model_type == "rnn":
        # fc.0 is Linear(hidden_dim * directions, hidden_dim): weight is (hidden_dim, hidden_dim * directions).
        hidden_dim = config.get("hidden_dim", state["fc.0.weight"].shape[0])
        bidirectional = config.get("bidirectional", "rnn.weight_ih_l0_reverse" in state)
        # weight_ih_l0 stacks one hidden_dim-row block per gate: 4 for LSTM, 3 for GRU.
        gate_blocks = state["rnn.weight_ih_l0"].shape[0] // hidden_dim
        rnn_type = config.get("rnn_type", "lstm" if gate_blocks == 4 else "gru")
        num_layers = config.get(
            "num_layers",
            sum(1 for k in state if k.startswith("rnn.weight_ih_l") and not k.endswith("_reverse")),
        )
        model = RNNClassifier(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
            num_classes=2,
            dropout=config.get("dropout", 0.2),
        )
        max_len = config.get("max_len", 256)
    else:  # transformer
        layer_ids = {k.split(".")[2] for k in state if k.startswith("transformer.layers.")}
        # The positional-encoding buffer is sized to the training max_len.
        max_len = config.get("max_len", state["pos_enc.pe"].shape[1])
        model = TransformerClassifier(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            nhead=config.get("nhead", 4),
            num_encoder_layers=config.get("num_encoder_layers", len(layer_ids)),
            dim_feedforward=config.get(
                "dim_feedforward", state["transformer.layers.0.linear1.weight"].shape[0]
            ),
            max_len=max_len,
            num_classes=2,
            dropout=config.get("dropout", 0.1),
        )
    return model, max_len

def main(args):
    """Classify args.text with the checkpoint at args.ckpt and print the prediction and class probabilities."""
    device = get_device()
    # NOTE(security): torch.load is pickle-based; only load checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location=device)

    # rebuild vocab
    vocab = Vocab()
    vocab.itos = ckpt["vocab"]
    vocab.stoi = {tok: i for i, tok in enumerate(vocab.itos)}

    state = ckpt["model_state"]
    model_type = detect_model_type(state)
    print(f"[INFO] Detected model type: {model_type}")

    # Checkpoints from older training code have no "config"; shapes are inferred instead.
    config = ckpt.get("config") or {}
    model, max_len = _build_model_from_checkpoint(model_type, state, config, len(vocab))
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    # tokenize & encode to the same length the model was trained with
    toks = simple_tokenize(args.text)
    ids = vocab.encode(toks, max_len=max_len)
    inp = torch.tensor([ids], dtype=torch.long).to(device)

    # inference
    with torch.no_grad():
        logits = model(inp)
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
        pred = int(probs.argmax())

    label_map = {0: "negative", 1: "positive"}
    print(f"Input: {args.text}")
    print(f"Predicted: {label_map[pred]} (probs: {probs.tolist()})")

