In [None]:
import json
import re
from collections import Counter

import nltk
import pandas as pd
import torch
import torch.nn as nn
from nltk.tokenize import word_tokenize
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from x_transformers import TransformerWrapper, Decoder

In [None]:
# NLTK >= 3.9 makes word_tokenize load the "punkt_tab" resource, while older releases
# use "punkt". Fetch both so tokenization works on either; for a package missing from
# the index, nltk.download returns False instead of raising.
for resource in ("punkt", "punkt_tab"):
    nltk.download(resource)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42          # seeds torch's global RNG (model init, DataLoader shuffling)
MAX_VOCAB = 20000  # vocabulary size, including the PAD and UNK entries
MAX_LEN = 256      # token ids per review after truncation / padding
BATCH_SIZE = 16
EPOCHS = 5
LR = 3e-4

PAD = "<pad>"  # id 0 in the vocabulary
UNK = "<unk>"  # id 1 in the vocabulary; used for out-of-vocabulary words

torch.manual_seed(SEED)

In [None]:
def load_imdb(path, text_col="review", label_col="sentiment"):
    """Read the IMDB reviews CSV and return parallel lists of texts and binary labels.

    Args:
        path: Anything ``pandas.read_csv`` accepts (file path or text buffer).
        text_col: Column holding the review text.
        label_col: Column holding the sentiment string ("positive" / "negative").

    Returns:
        (texts, labels): review strings (missing text becomes "") and ints,
        1 for "positive" and 0 for "negative".

    Raises:
        ValueError: if a label is neither "positive" nor "negative" after
            stripping whitespace and lower-casing. Such labels are rejected rather
            than silently counted as negative.
    """
    df = pd.read_csv(path)
    label_map = {"positive": 1, "negative": 0}

    normalized = df[label_col].astype(str).str.strip().str.lower()
    unknown = sorted(set(normalized) - set(label_map))
    if unknown:
        raise ValueError(f"Unexpected values in column {label_col!r}: {unknown[:5]}")

    # Missing reviews would be float NaN and break text.lower() during tokenization.
    texts = df[text_col].fillna("").astype(str).tolist()
    labels = normalized.map(label_map).astype(int).tolist()
    return texts, labels

# Location of the IMDB 50K movie reviews CSV in the Kaggle input directory.
data_path = (
    "/kaggle/input/imdb-dataset-of-50k-movie-reviews/"
    "IMDB Dataset.csv"
)
texts, labels = load_imdb(data_path)


In [None]:
# The IMDB reviews contain HTML line breaks such as "<br /><br />". Without stripping
# them, the tokenizer turns them into fragments ("<", "br", "/", ">") that occupy
# vocabulary slots and MAX_LEN positions without carrying sentiment.
_BR_TAG = re.compile(r"<br\s*/?>", flags=re.IGNORECASE)


def tokenize(text):
    """Split a review into lower-cased word tokens with NLTK's word_tokenize.

    HTML <br> tags are replaced by spaces before tokenizing.

    Returns:
        list[str]: the tokens of `text`.
    """
    return word_tokenize(_BR_TAG.sub(" ", text).lower())

# Build the vocabulary from the TRAINING partition only, so validation reviews do not
# decide which words get their own id (data leakage). This call uses the same arguments
# as the split in the training cell (test_size=0.2, random_state=42), so both produce
# identical partitions. Keep the two in sync if either changes.
vocab_train_texts, _, _, _ = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

counter = Counter()
for review in vocab_train_texts:
    counter.update(tokenize(review))

# Ids 0 and 1 are reserved for padding and unknown words; the MAX_VOCAB - 2 most
# frequent training tokens fill ids 2 .. MAX_VOCAB - 1.
vocab = {PAD: 0, UNK: 1}
for idx, (word, _) in enumerate(counter.most_common(MAX_VOCAB - 2), start=2):
    vocab[word] = idx

def encode(text, max_len=None):
    """Map a review to a fixed-length list of vocabulary ids.

    Tokens missing from `vocab` become the UNK id. The sequence is truncated to
    `max_len` tokens and right-padded with the PAD id.

    Args:
        text: Raw review string.
        max_len: Length of the returned list; defaults to MAX_LEN.

    Returns:
        list[int]: exactly `max_len` ids.
    """
    if max_len is None:
        max_len = MAX_LEN
    unk_id, pad_id = vocab[UNK], vocab[PAD]
    # Truncate before the lookup: tokens past max_len would be discarded anyway.
    ids = [vocab.get(tok, unk_id) for tok in tokenize(text)[:max_len]]
    return ids + [pad_id] * (max_len - len(ids))

In [None]:
class IMDBDataset(Dataset):
    """Map-style dataset yielding {"input_ids", "labels"} tensors for a DataLoader.

    Every review is encoded once, at construction, instead of on every
    __getitem__ call. Word tokenization therefore runs once per review rather than
    once per review per epoch. The vocabulary in effect at construction time is
    the one that is used.

    Args:
        texts: Review strings.
        labels: Integer class labels, parallel to `texts`.
        encode_fn: Callable mapping a string to a fixed-length list of ids;
            defaults to the notebook's `encode`.

    Raises:
        ValueError: if `texts` and `labels` differ in length.
    """

    def __init__(self, texts, labels, encode_fn=None):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        if encode_fn is None:
            encode_fn = encode
        self.texts = texts
        self.labels = labels
        # One row of ids per review; requires encode_fn to return equal-length lists.
        self.input_ids = torch.tensor([encode_fn(t) for t in texts], dtype=torch.long)
        self.label_ids = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.label_ids[idx],
        }


In [None]:
class XTransformerClassifier(nn.Module):
    """Sequence classifier: x-transformers stack -> masked mean pooling -> linear head.

    Args:
        vocab_size: Number of token ids the embedding table must cover.
        num_classes: Logits produced per example.
        dim: Hidden width of the transformer and input width of the head.
        depth: Number of transformer layers.
        heads: Attention heads per layer.
        max_seq_len: Longest supported input; defaults to MAX_LEN.
        pad_id: Padding token id. Those positions are masked out of attention
            and excluded from the pooled average.
    """

    def __init__(self, vocab_size, num_classes=2, dim=512, depth=6, heads=8,
                 max_seq_len=None, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        # NOTE(review): x-transformers' Decoder is causal (each position attends only
        # to earlier tokens). An Encoder would give every position the whole review.
        # Decoder is kept here to preserve the original architecture.
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=MAX_LEN if max_seq_len is None else max_seq_len,
            attn_layers=Decoder(dim=dim, depth=depth, heads=heads),
        )
        self.classifier = nn.Linear(dim, num_classes)

    @staticmethod
    def masked_mean_pool(hidden, mask):
        """Average `hidden` (batch, seq, dim) over positions where `mask` (batch, seq) is True.

        Rows with no True position yield zeros instead of dividing by zero.
        """
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, x):
        """Return (batch, num_classes) logits for `x`, a (batch, seq_len) LongTensor of ids."""
        mask = x != self.pad_id  # True at real tokens, False at padding
        # return_embeddings=True yields the final hidden states (batch, seq_len, dim).
        # The default call returns vocabulary logits (batch, seq_len, num_tokens),
        # which do not match the dim-wide classifier head.
        hidden = self.transformer(x, mask=mask, return_embeddings=True)
        return self.classifier(self.masked_mean_pool(hidden, mask))


In [None]:
def train_epoch(model, loader, optimizer, device=None, max_grad_norm=None):
    """Run one optimization pass over `loader`, then print and return the mean loss.

    Args:
        model: Module mapping a batch of token ids to (batch, classes) logits.
        loader: Iterable of dict batches with "input_ids" and "labels" tensors.
        optimizer: Optimizer over `model`'s parameters.
        device: Where each batch is moved; defaults to the notebook-wide DEVICE.
        max_grad_norm: If given, clip the global gradient norm to this value
            before every optimizer step.

    Returns:
        float: Sample-weighted mean cross-entropy over the epoch, or NaN if
        `loader` yielded no batches.
    """
    if device is None:
        device = DEVICE
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    total_loss, n_seen = 0.0, 0

    for batch in loader:
        x = batch["input_ids"].to(device)
        y = batch["labels"].to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # CrossEntropyLoss averages within a batch; weight by batch size so a short
        # final batch does not count as much as a full one.
        total_loss += loss.item() * y.size(0)
        n_seen += y.size(0)

    mean_loss = total_loss / n_seen if n_seen else float("nan")
    print(f"Train loss: {mean_loss:.4f}")
    return mean_loss

def evaluate(model, loader, device=None):
    """Score `model` on `loader`, print accuracy / F1 / a per-class report, return (acc, f1).

    Args:
        model: Module mapping a batch of token ids to (batch, classes) logits.
        loader: Iterable of dict batches with "input_ids" and "labels" tensors.
        device: Where inputs are moved; defaults to the notebook-wide DEVICE.

    Returns:
        (accuracy, f1): F1 uses average="binary", i.e. it is computed for label 1.
    """
    if device is None:
        device = DEVICE
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in loader:
            logits = model(batch["input_ids"].to(device))
            # Labels are only compared on the host, so they never need to visit `device`.
            y_true.extend(batch["labels"].tolist())
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="binary")

    print("Accuracy:", acc)
    print("F1-score:", f1)
    print("\nClassification Report:")
    # Pin both labels so the report always lists both classes (and target_names
    # always matches), even when the model currently predicts only one class.
    print(classification_report(
        y_true, y_pred, labels=[0, 1], target_names=["negative", "positive"]
    ))

    return acc, f1

In [None]:
# Hold out 20% of the reviews for validation; random_state=42 makes the split repeatable.
(
    train_texts, val_texts,
    train_labels, val_labels,
) = train_test_split(texts, labels, test_size=0.2, random_state=42)

train_ds = IMDBDataset(train_texts, train_labels)
val_ds = IMDBDataset(val_texts, val_labels)

# Only the training loader shuffles; evaluation order does not change the metrics.
train_loader, val_loader = (
    DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE),
    DataLoader(val_ds, batch_size=BATCH_SIZE),
)

model = XTransformerClassifier(len(vocab)).to(DEVICE)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

# Each epoch: one full training pass, then a validation report.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    train_epoch(model, train_loader, optimizer)
    evaluate(model, val_loader)

In [None]:
# Persist the weights together with the token -> id mapping. A reloaded model can
# only encode new text consistently with the exact vocabulary it was trained on.
torch.save(model.state_dict(), "xtransformer_imdb.pth")
with open("xtransformer_imdb_vocab.json", "w", encoding="utf-8") as vocab_file:
    json.dump(vocab, vocab_file, ensure_ascii=False)

In [None]:
def predict(text):
    """Label one review as "positive" or "negative" with the trained global `model`.

    Side effect: leaves `model` in eval mode.
    """
    model.eval()
    token_ids = torch.tensor([encode(text)], device=DEVICE)  # a batch of one sequence
    with torch.no_grad():
        logits = model(token_ids)
    class_id = logits.argmax(dim=1).item()
    if class_id == 1:
        return "positive"
    return "negative"

# Smoke test: classify a single hand-written review.
demo_review = "This movie was absolutely fantastic and emotional."
print("\nTest prediction:")
print(predict(demo_review))