# Ordinal Text Classification using BERT + CORAL Loss (Amazon Reviews)




In [1]:
!pip install -q transformers torch datasets accelerate scikit-learn

## 1. Problem Setup & Dataset

This notebook explores **ordinal classification** using the Amazon Polarity dataset. We simulate **1–5 star ratings**:
- Label 0 → randomly mapped to 1 or 2 stars  
- Label 1 → mapped to 4 or 5 stars  
- A subset is explicitly assigned a **neutral rating (3)** to enrich the distribution.

This transforms the binary sentiment task into an ordinal one, enabling the use of **ordinal-specific loss functions**.

In [36]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import cohen_kappa_score, mean_squared_error
import numpy as np

# Config
MAX_LEN = 128
BATCH_SIZE = 16
EPOCHS = 3
NUM_CLASSES = 5  # ordinal classes 1..5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [37]:
# Load Amazon Reviews dataset (small subset)
dataset = load_dataset("amazon_polarity", split="train[:10000]")

# Simulate ordinal star ratings from polarity: 0 -> 1-2 stars, 1 -> 4-5 stars, and add some 3 stars
def simulate_ordinal(label):
    if label == 0:
        return np.random.choice([1, 2])
    else:
        return np.random.choice([4, 5])

labels = [simulate_ordinal(l) for l in dataset['label']]
texts = dataset['content']

In [38]:
# Add some neutral middle class (3)
for i in range(300):
    labels[i] = 3

# Train/val split
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)

In [None]:
# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Dataset class
class OrdinalDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize
        encoding = tokenizer(text, padding='max_length', truncation=True, max_length=MAX_LEN, return_tensors='pt')

        # Threshold encoding (k-1 binary labels)
        ordinal_label = torch.tensor([1 if label > k else 0 for k in range(NUM_CLASSES - 1)], dtype=torch.float)

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': ordinal_label,
            'orig_label': label
        }

# Create DataLoader
train_dataset = OrdinalDataset(train_texts, train_labels)
val_dataset = OrdinalDataset(val_texts, val_labels)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

## 2. Model: BERT with CORAL Output Head

We use a pre-trained **BERT (base-uncased)** as the backbone. A simple linear classifier with `k-1` outputs (for 5 classes, 4 thresholds) is applied on top.

Key points:
- Each output neuron represents a **threshold boundary** (e.g., 1 vs 2+, 2 vs 3+, …).
- Outputs are passed through **sigmoid**, interpreted as cumulative probabilities.
- The final class is derived by **counting thresholds passed**.

In [None]:
# CORAL Loss function
class CoralLoss(nn.Module):
    def __init__(self):
        super(CoralLoss, self).__init__()
        self.bce = nn.BCELoss()

    def forward(self, logits, labels):
        return self.bce(logits, labels)

# Model with BERT + CORAL head
class BertCoralModel(nn.Module):
    def __init__(self, num_classes):
        super(BertCoralModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes - 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        probs = self.sigmoid(logits)
        return probs

model = BertCoralModel(NUM_CLASSES).to(device)
criterion = CoralLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

## 3. Training with CORAL Loss

- **CORAL loss** (Cumulative Ordinal Regression with Absolute Logits) is applied using **binary cross-entropy** across thresholds.
- Training is conducted for 3 epochs using Adam optimizer.
- A custom `OrdinalDataset` handles tokenization and ordinal label encoding.


In [41]:
from tqdm import tqdm

# Train function with tqdm
def train_epoch(model, dataloader):
    model.train()
    total_loss = 0
    loop = tqdm(dataloader, desc="Training", leave=False)
    for batch in loop:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    return total_loss / len(dataloader)

# Eval function with tqdm
def eval_model(model, dataloader):
    model.eval()
    preds = []
    true_labels = []

    loop = tqdm(dataloader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            orig_labels = batch['orig_label']

            outputs = model(input_ids, attention_mask)
            pred_labels = torch.sum(outputs > 0.5, dim=1) + 1  # convert to class labels

            preds.extend(pred_labels.cpu().numpy())
            true_labels.extend(orig_labels.numpy())

    qwk = cohen_kappa_score(true_labels, preds, weights='quadratic')
    spearman = np.corrcoef(true_labels, preds)[0, 1]

    return qwk, spearman, preds, true_labels

In [42]:
# Training loop with tqdm progress bars
for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader)
    qwk, spearman, _, _ = eval_model(model, val_loader)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val QWK={qwk:.4f}, Spearman={spearman:.4f}")



Epoch 1: Train Loss=0.2966, Val QWK=0.6013, Spearman=0.7520




Epoch 2: Train Loss=0.1907, Val QWK=0.7354, Spearman=0.7837


                                                             

Epoch 3: Train Loss=0.1556, Val QWK=0.6140, Spearman=0.7600




## 4. Evaluation & Prediction

We evaluate using:
- **Quadratic Weighted Kappa (QWK):** Measures agreement accounting for ordinal nature.
- **Spearman Correlation:** Rank-based correlation between predictions and true ratings.

Final prediction is decoded by counting threshold activations > 0.5.

In [43]:
# Final evaluation
qwk, spearman, preds, true_labels = eval_model(model, val_loader)
print(f"\nFinal QWK: {qwk:.4f}")
print(f"Final Spearman Correlation: {spearman:.4f}")

                                                             


Final QWK: 0.6140
Final Spearman Correlation: 0.7600




In [44]:
def predict(model, dataloader):
    model.eval()
    preds = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask)  # sigmoids, shape: (batch_size, k-1)

            # Threshold the outputs at 0.5 to get binary vector
            threshold_preds = (outputs > 0.5).int()

            # Sum over threshold preds + 1 to get final ordinal class label
            # (because label = number of thresholds passed + 1)
            ordinal_preds = torch.sum(threshold_preds, dim=1) + 1

            preds.extend(ordinal_preds.cpu().numpy())
    return preds

# Example usage on validation set:
predicted_labels = predict(model, val_loader)
print(predicted_labels[:10])  # print first 10 predictions

[np.int64(3), np.int64(3), np.int64(3), np.int64(3), np.int64(5), np.int64(5), np.int64(5), np.int64(3), np.int64(3), np.int64(3)]
