In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd

from transformers import BertTokenizer, BertModel
from x_transformers import Decoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG once, up front, so that weight initialisation and
# DataLoader shuffling are repeatable across "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

BERT_NAME = "bert-base-uncased"  # checkpoint name passed to from_pretrained()
MAX_LEN = 256                    # every review is padded / truncated to this many tokens
BATCH_SIZE = 16
EPOCHS = 4
LR = 2e-4                        # learning rate handed to the optimizer
NUM_CLASSES = 2                  # size of the classifier's output layer

In [None]:
def load_imdb(path):
    """Read the reviews CSV and return parallel text / label lists.

    Args:
        path: Location of a CSV file that has a ``review`` column and a
            ``sentiment`` column.

    Returns:
        A ``(texts, labels)`` tuple. ``texts`` is the ``review`` column as a
        list; ``labels`` holds ``1`` for every row whose ``sentiment`` equals
        ``"positive"`` and ``0`` for every other value.
    """
    frame = pd.read_csv(path)
    reviews = frame["review"].tolist()
    sentiments = frame["sentiment"].tolist()
    # Anything that is not exactly "positive" is mapped to class 0.
    return reviews, [int(value == "positive") for value in sentiments]

# Location of the reviews CSV. The /kaggle/input/... prefix suggests a Kaggle
# notebook environment — change this path when running elsewhere.
data_path = "/kaggle/input/imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv"
# Parallel lists: the raw review strings and their 0/1 sentiment labels.
texts, labels = load_imdb(data_path)

In [None]:
# Module-level tokenizer loaded from the BERT_NAME checkpoint; IMDBDataset and
# predict() both read this global.
tokenizer = BertTokenizer.from_pretrained(BERT_NAME)

class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes one review per ``__getitem__`` call.

    Args:
        texts: Sequence of raw review strings.
        labels: Sequence of integer class ids, aligned index-for-index with
            ``texts``.
        tokenizer: Optional callable with the Hugging Face tokenizer call
            signature (``text, padding=, truncation=, max_length=,
            return_tensors=``). When omitted, the module-level ``tokenizer``
            is looked up at item-access time.
        max_len: Optional length to pad / truncate each review to. When
            omitted, the module-level ``MAX_LEN`` is used.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer=None, max_len=None):
        # Fail at construction time rather than with an IndexError (or silently
        # dropped labels) somewhere in the middle of an epoch.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self._tokenizer = tokenizer
        self._max_len = max_len

    def __len__(self):
        """Return the number of reviews."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize review ``idx`` and return a dict of tensors.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (the tokenizer's
            outputs with the leading batch dimension of 1 squeezed away) and
            ``labels`` (a scalar ``torch.long`` tensor).
        """
        # Resolve the fallbacks lazily so the globals only need to exist when
        # no explicit tokenizer / max_len was supplied.
        tokenize = self._tokenizer if self._tokenizer is not None else tokenizer
        max_len = self._max_len if self._max_len is not None else MAX_LEN

        enc = tokenize(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )
        return {
            # return_tensors="pt" yields a batch of one; drop that dimension so
            # the DataLoader's collation adds the real batch dimension.
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }


In [None]:
class BertXTransformerClassifier(nn.Module):
    """BERT encoder followed by an x-transformers ``Decoder`` stack and a linear head.

    Pipeline: token ids -> BERT hidden states -> ``Decoder`` -> mean over the
    non-padding positions -> ``nn.Linear`` producing ``NUM_CLASSES`` logits.

    Args:
        freeze_bert: When true (the default) every BERT parameter gets
            ``requires_grad = False`` so only the decoder and the classifier
            head are trained.
    """

    def __init__(self, freeze_bert=True):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_NAME)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

        # Width of BERT's hidden states; the decoder and head are sized to match.
        self.hidden_dim = self.bert.config.hidden_size

        self.decoder = Decoder(
            dim=self.hidden_dim,
            depth=4,
            heads=8
        )

        self.classifier = nn.Linear(self.hidden_dim, NUM_CLASSES)

    def encode_with_bert(self, input_ids, attention_mask):
        """Run BERT and return its per-token ``last_hidden_state``."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return outputs.last_hidden_state

    def decode_with_xtransformer(self, hidden_states):
        """Pass the BERT hidden states through the x-transformers decoder stack."""
        return self.decoder(hidden_states)

    def forward(self, input_ids, attention_mask):
        """Compute class logits for a batch.

        Args:
            input_ids: Token-id tensor as produced by the tokenizer.
            attention_mask: Tensor of the same leading shape as ``input_ids``
                with non-zero entries at real tokens and zeros at padding.

        Returns:
            The classifier output for the pooled sequence representation,
            one row of logits per batch element.
        """
        encoded = self.encode_with_bert(input_ids, attention_mask)
        decoded = self.decode_with_xtransformer(encoded)

        # Average only over real tokens. A plain mean over the sequence axis
        # would mix in the padding positions, so the pooled vector would
        # depend on how much padding a review happens to carry.
        mask = attention_mask.unsqueeze(-1).to(decoded.dtype)
        token_count = mask.sum(dim=1).clamp(min=1)  # guard all-padding rows
        pooled = (decoded * mask).sum(dim=1) / token_count
        return self.classifier(pooled)

In [None]:
def train_epoch(model, loader, optimizer):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning class logits.
        loader: Iterable of batches; each batch is a dict with
            ``input_ids``, ``attention_mask`` and ``labels`` tensors.
        optimizer: Optimizer that updates the model's trainable parameters.

    Returns:
        The mean of the per-batch cross-entropy losses as a ``float``, or
        ``nan`` when ``loader`` yields no batches.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    # Send batches to wherever the model actually lives, so the function keeps
    # working if the model was moved to a device other than the global DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    running_loss = 0.0
    num_batches = 0

    for batch in loader:
        optimizer.zero_grad()
        logits = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device)
        )
        loss = loss_fn(logits, batch["labels"].to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else float("nan")

In [None]:
def evaluate(model, loader):
    """Score ``model`` on ``loader`` and print a short report.

    Puts the model in eval mode, runs every batch without gradient tracking,
    takes the arg-max class per example, then prints accuracy, binary F1 and
    the scikit-learn classification report.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.

    Returns:
        ``(accuracy, f1)`` as computed by scikit-learn.
    """
    model.eval()
    targets = []
    predictions = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            batch_preds = torch.argmax(model(input_ids, attention_mask), dim=1)

            # Labels are read straight from the batch; predictions are moved
            # back to the CPU before conversion to NumPy.
            targets.extend(batch["labels"].numpy())
            predictions.extend(batch_preds.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    f1_binary = f1_score(targets, predictions, average="binary")

    print("Accuracy:", accuracy)
    print("F1-score:", f1_binary)
    print("\nClassification Report:")
    print(classification_report(targets, predictions))
    return accuracy, f1_binary

In [None]:
# Hold out 20% of the reviews for validation; the fixed random_state keeps the
# split identical between runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

# Only the training loader shuffles; validation order does not matter.
train_loader = DataLoader(
    IMDBDataset(train_texts, train_labels),
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_loader = DataLoader(
    IMDBDataset(val_texts, val_labels),
    batch_size=BATCH_SIZE
)

model = BertXTransformerClassifier(freeze_bert=True).to(DEVICE)

# Hand the optimizer only the parameters that still require gradients, i.e.
# everything except the frozen BERT weights.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    train_loss = train_epoch(model, train_loader, optimizer)
    # Report the epoch's training loss whenever train_epoch provides one, so
    # divergence or a stalled run is visible next to the validation metrics.
    if train_loss is not None:
        print(f"Train loss: {train_loss:.4f}")
    evaluate(model, val_loader)

In [None]:
# Persist only the learned parameters (the state_dict), written to a relative
# path; reloading requires constructing the model class first.
torch.save(model.state_dict(), "bert_xtransformer_split.pth")

In [None]:
def predict(text):
    """Classify one review with the module-level ``model`` and ``tokenizer``.

    Args:
        text: Raw review text, tokenized with the same padding / truncation
            settings as the training data.

    Returns:
        ``"positive"`` when the arg-max class is 1, otherwise ``"negative"``.
    """
    model.eval()
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt"
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    predicted_class = torch.argmax(logits, dim=1).item()
    if predicted_class == 1:
        return "positive"
    return "negative"

# Smoke test: push one hand-written review through the trained model.
print("\nTest prediction:")
print(predict("This movie was absolutely fantastic and emotional."))