**Low-Rank RoBERTa: Achieving High Accuracy Under 1M Parameters**

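This notebook fine-tunes a pretrained RoBERTa (`roberta-base`) model for four-class news-topic classification on the AG News dataset using Low-Rank Adaptation (LoRA), which keeps the number of trainable parameters under one million.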

In [None]:
!pip install datasets
import nltk
nltk.download('punkt_tab')


In [None]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset, Dataset, concatenate_datasets
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import Trainer, TrainingArguments
import torch
import json, re, pickle, random, string
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import seaborn as sns
from transformers import DataCollatorWithPadding

# Import nltk for augmentation functions
import nltk
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

# Download any NLTK resource the augmentation helpers need that is not installed yet.
# `punkt_tab` is what `word_tokenize` loads in recent NLTK versions; `punkt` covers
# older ones. Each resource is checked separately so only the missing ones are fetched.
NLTK_RESOURCES = {
    'corpora/wordnet': 'wordnet',
    'corpora/stopwords': 'stopwords',
    'tokenizers/punkt': 'punkt',
    'tokenizers/punkt_tab': 'punkt_tab',
}
for resource_path, package_name in NLTK_RESOURCES.items():
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package_name)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


In [25]:
# Text augmentation functions
stop_words = set(stopwords.words('english'))

def get_synonyms(word):
    """Return the WordNet synonyms of `word`, excluding the word itself.

    Underscores in multi-word lemma names are replaced by spaces. The comparison
    with `word` is case-insensitive.

    The result is sorted so that a caller picking from it with a seeded `random`
    gets the same choice on every run. The iteration order of a `set` of strings
    depends on the per-process hash seed, so `list(set)` is not reproducible.
    """
    syns = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            name = lemma.name().replace('_', ' ')
            if name.lower() != word.lower():
                syns.add(name)
    return sorted(syns)

def synonym_replacement_rate(text, rate=0.01):
    """
    Replace approximately `rate` fraction of non-stopwords in the text with synonyms.

    Eligible words are alphabetic tokens (from `word_tokenize`) whose lower-cased
    form is not in `stop_words`. At most ``int(len(eligible) * rate)`` of them,
    visited in random order, are swapped for a random WordNet synonym.

    NOTE: because the count is truncated with `int`, the default ``rate=0.01``
    leaves any text with fewer than 100 eligible words unchanged.

    Args:
        text: Input text.
        rate: Target fraction of eligible words to replace.

    Returns:
        `text` unchanged if nothing was replaced. Otherwise, the tokens re-joined
        with single spaces, so spacing around punctuation is normalised.
    """
    words = word_tokenize(text)
    # Indices of words eligible for replacement.
    candidates = [i for i, w in enumerate(words)
                  if w.isalpha() and w.lower() not in stop_words]
    # How many to replace.
    n_replace = int(len(candidates) * rate)
    if n_replace < 1:
        return text  # rate too low → no changes

    random.shuffle(candidates)
    replaced = 0
    for idx in candidates:
        syns = get_synonyms(words[idx])
        if syns:
            words[idx] = random.choice(syns)
            replaced += 1
        if replaced >= n_replace:
            break

    if replaced == 0:
        # No synonym was found for any candidate. Return the input as-is rather than
        # a re-tokenized copy whose spacing and quote characters differ from it.
        return text
    return ' '.join(words)

def aug_html_entities(text: str, p=0.1) -> str:
    """Randomly HTML-escape ampersands and quote characters.

    Each of ``&``, ``'`` and ``"`` is independently selected with probability
    `p`. Every occurrence of a selected character is replaced by its entity
    (``&amp;``, ``&#39;``, ``&quot;``).

    ``&`` is processed first. Otherwise the ``&`` inside freshly inserted
    ``&#39;`` / ``&quot;`` entities would be escaped a second time
    (e.g. ``&amp;#39;``).

    Args:
        text: Input text.
        p: Per-character-class probability of escaping.

    Returns:
        The possibly escaped text.
    """
    entities = {"&": "&amp;", "'": "&#39;", '"': "&quot;"}
    for ch, ent in entities.items():
        if random.random() < p:
            text = text.replace(ch, ent)
    return text

def aug_word_dup(text: str, p=0.05) -> str:
    """With probability `p`, duplicate one randomly chosen word in place.

    The text is split on whitespace and re-joined with single spaces, so
    whitespace runs are collapsed even when nothing is duplicated. Empty or
    all-whitespace input yields "" without consuming any randomness.
    """
    tokens = text.split()
    should_duplicate = bool(tokens) and random.random() < p
    if should_duplicate:
        pos = random.randrange(len(tokens))
        tokens[pos:pos] = [tokens[pos]]
    return " ".join(tokens)

def aug_case_swap(text: str, p=0.1) -> str:
    """Lower-case `text`, upper-casing each character independently with probability `p`.

    Despite the name, case is not inverted: every character that is not
    selected is lower-cased.
    """
    pieces = []
    for ch in text:
        if random.random() < p:
            pieces.append(ch.upper())
        else:
            pieces.append(ch.lower())
    return "".join(pieces)

def aug_punct_space(text: str, p=0.05) -> str:
    """Inject random punctuation and extra spaces into `text`.

    Pass 1: after each alphanumeric character, with probability `p`, append one
    random character from `string.punctuation`.
    Pass 2: each literal space in the result is, with probability `p`, doubled.
    """
    with_punct = []
    for ch in text:
        with_punct.append(ch)
        if ch.isalnum() and random.random() < p:
            with_punct.append(random.choice(string.punctuation))

    spaced = []
    for ch in "".join(with_punct):
        if ch == " ":
            spaced.append("  " if random.random() < p else " ")
        else:
            spaced.append(ch)
    return "".join(spaced)

def aug_truncate(text: str, p=0.1) -> str:
    """With probability `p`, keep only the first 70-90% of a text longer than 20 characters.

    One random draw is always consumed, even for short texts.
    """
    draw = random.random()
    if draw >= p or len(text) <= 20:
        return text
    keep = int(len(text) * random.uniform(0.7, 0.9))
    return text[:keep]

def aug_char_swap(text: str, p=0.02) -> str:
    """Randomly swap adjacent characters.

    Positions are scanned left to right. At each position i (except the last),
    the characters at i and i+1 are swapped with probability `p`. Because the
    scan continues at i+1, one character can be carried several places along.
    """
    buf = list(text)
    pos = 0
    while pos < len(buf) - 1:
        if random.random() < p:
            buf[pos], buf[pos + 1] = buf[pos + 1], buf[pos]
        pos += 1
    return "".join(buf)

def augment_text(text: str) -> str:
    """Apply between 1 and 5 distinct, randomly chosen augmentations to `text`.

    The augmentations are sampled without replacement from the seven helpers
    defined above. They are applied one after another in the sampled order,
    each receiving the previous one's output.
    """
    candidates = (
        aug_html_entities,
        aug_word_dup,
        aug_case_swap,
        aug_punct_space,
        aug_truncate,
        aug_char_swap,
        synonym_replacement_rate,
    )
    how_many = random.randint(1, 5)
    for augment in random.sample(candidates, how_many):
        text = augment(text)
    return text


In [None]:
# Load AG News and derive the id <-> label-name maps from the dataset metadata.
agnews = load_dataset("ag_news")
class_names = agnews["train"].features["label"].names
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}

print("Classes:", class_names)

# Lower-case company names that `mask_text` replaces (as whole words) with [COMPANY].
COMPANY_LIST = (
    "google apple microsoft amazon facebook tesla "
    "oracle ibm intel nvidia qualcomm sap "
    "salesforce uber airbnb twitter meta snap "
    "zoom palantir"
).split()

# text preprocessing functions
def preprocess(text):
    """Lower-case `text`, collapse each whitespace run to one space, and trim the ends."""
    collapsed = re.sub(r'\s+', ' ', text.lower())
    return collapsed.strip()

def mask_text(text: str) -> str:
    """Normalise `text` and mask digits and known company names.

    Newlines become spaces, and the result is stripped and lower-cased. Every
    character for which `str.isdigit()` is true becomes its own ``[NUM]``
    placeholder, so "2004" yields four placeholders. Each whole-word occurrence
    of a name in `COMPANY_LIST` becomes ``[COMPANY]``.
    """
    normalised = text.replace("\n", " ").strip().lower()
    pieces = []
    for ch in normalised:
        pieces.append('[NUM]' if ch.isdigit() else ch)
    masked = ''.join(pieces)
    for company in COMPANY_LIST:
        masked = re.sub(rf"\b{company}\b", '[COMPANY]', masked)
    return masked


print("creating a new augmented dataset...")
random.seed(42)
# Reuse the training split loaded above (120K examples) instead of reloading it.
orig = agnews["train"]

# Size of the class-balanced subset that gets perturbed, and number of
# untouched original examples kept alongside it.
N_AUGMENTED = 40_000
N_ORIGINAL = 80_000

# Create a class-balanced subset of N_AUGMENTED examples for augmentation.
all_labels = orig["label"]
labels = sorted(set(all_labels))
n_per = N_AUGMENTED // len(labels)
indices_by_label = {lab: [] for lab in labels}
for i, lab in enumerate(all_labels):
    indices_by_label[lab].append(i)

sampled_idxs = []
for lab in labels:
    sampled_idxs += random.sample(indices_by_label[lab], n_per)
aug_subset = orig.select(sampled_idxs)

# Take N_ORIGINAL unmodified examples from a seeded shuffle of the full split.
# NOTE: drawn independently of `aug_subset`, so some examples can appear both
# clean and perturbed in the final training set.
new_orig = orig.shuffle(seed=42).select(range(N_ORIGINAL))

def perturb(ex):
    """Return a replacement "text" field produced by `augment_text`."""
    return {"text": augment_text(ex["text"])}

perturbed_subset = aug_subset.map(perturb)
train_dataset = concatenate_datasets([new_orig, perturbed_subset])
print(f"Created augmented dataset with {len(train_dataset)} examples")

# Test split of the dataset loaded above.
test_dataset = agnews["test"]
print(f"Test dataset has {len(test_dataset)} examples")


In [None]:
# Load model and tokenizer
# The tokenizer here and the classifier created later share this checkpoint.
model_id = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

# Tokenization function
def tokenize_function(examples, max_length=256, padding="max_length"):
    """Clean and tokenize a batch of examples (for ``Dataset.map(batched=True)``).

    Each text goes through `preprocess` (lower-case and whitespace collapse,
    which already turns newlines into spaces) and then `mask_text` (digit and
    company masking) before being tokenized.

    Args:
        examples: Batch dict with a "text" list; other columns pass through.
        max_length: Maximum sequence length in tokens; longer texts are truncated.
        padding: Forwarded to the tokenizer. The default pads every example to
            `max_length`. Pass ``False`` (e.g. via ``map(..., fn_kwargs=...)``)
            to leave padding to a collator such as `DataCollatorWithPadding`.

    Returns:
        The batch dict with "text" replaced by the cleaned strings and
        "input_ids" / "attention_mask" added as lists.
    """
    cleaned = [mask_text(preprocess(text)) for text in examples["text"]]

    # No return_tensors: `datasets` stores plain lists, and tensors would
    # require equal lengths, i.e. rule out padding=False.
    encoded = tokenizer(
        cleaned,
        truncation=True,
        max_length=max_length,
        padding=padding,
    )
    examples["text"] = cleaned
    examples["input_ids"] = encoded["input_ids"]
    examples["attention_mask"] = encoded["attention_mask"]
    return examples

# Tokenize both splits with identical settings; the raw "text" column is dropped.
def _tokenize_split(split):
    return split.map(tokenize_function, batched=True, remove_columns=["text"])

train_dataset = _tokenize_split(train_dataset)
test_dataset = _tokenize_split(test_dataset)

print("Dataset preparation complete")


In [None]:
# Load the pretrained encoder with a new classification head sized to the
# dataset's classes. Both label maps are stored in the model config.
model = RobertaForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(class_names),
    id2label=id2label,
    label2id=label2id,
).to(device)

# Configure LoRA: rank-2 adapters on the self-attention query/key/value layers.
lora_config = LoraConfig(
    r=2,
    lora_alpha=4,
    target_modules=["value", "query", "key"],
    lora_dropout=0.1,
    bias='none',
    task_type=TaskType.SEQ_CLS,
)

# Apply LoRA to model
lora_model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints its summary itself and returns None,
# so wrapping it in print() would also print a stray "None".
lora_model.print_trainable_parameters()

# Collator used by the Trainer: pads each batch and renames "label" to "labels".
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

# Training parameters
train_batch_size = 64
test_batch_size = 32

# NOTE: the Trainer builds its own data loaders from the datasets; these two
# loaders are not used by the training/evaluation cells below.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=train_batch_size,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
)

print(f"Training on {len(train_dataset)} examples with batch size {train_batch_size}")
print(f"Testing on {len(test_dataset)} examples with batch size {test_batch_size}")


In [30]:
# Define metrics computation function
def compute_metrics(pred):
    """Return ``{'accuracy': ...}`` for a Trainer evaluation result.

    The predicted class is the argmax of `pred.predictions` over its last axis,
    compared against `pred.label_ids`.
    """
    predicted = pred.predictions.argmax(-1)
    return {'accuracy': accuracy_score(pred.label_ids, predicted)}

# Define training arguments
trainer_args = TrainingArguments(
    output_dir="output",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=test_batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    optim="adamw_torch",
    # DataCollatorWithPadding renames the dataset's "label" column to "labels",
    # so that is the batch key the Trainer must look up during evaluation.
    # With "label", no labels are found, and neither eval_loss nor
    # eval_accuracy is computed.
    label_names=["labels"],
    logging_dir="./logs",
    report_to="none",
)

# Create trainer
trainer = Trainer(
    model=lora_model,
    args=trainer_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)


In [None]:
# Train the model
print("Starting training...")

# Call trainer.train() exactly once. Each call already runs all
# `trainer_args.num_train_epochs` epochs, so calling it inside a per-epoch loop
# retrains everything, or (with resume_from_checkpoint=True from the final
# checkpoint) trains nothing. Per-epoch metrics are collected instead by the
# callback below plus the Trainer's own eval_strategy="epoch" test evaluation.
from transformers import TrainerCallback


class TrainSetEvalCallback(TrainerCallback):
    """Evaluate the model on a fixed dataset at the end of every training epoch.

    Metrics are logged with the ``trainset_`` prefix. This keeps them separate
    from the ``eval_*`` test metrics and from the ``train_loss`` run summary the
    Trainer logs when training finishes.
    """

    def __init__(self, trainer, dataset):
        self.trainer = trainer
        self.dataset = dataset

    def on_epoch_end(self, args, state, control, **kwargs):
        # Evaluates on the whole dataset given (the full training set below),
        # which is slow; pass a subset to speed this up.
        self.trainer.evaluate(eval_dataset=self.dataset, metric_key_prefix="trainset")


# NOTE: re-running this cell without re-creating `trainer` registers a second callback.
trainer.add_callback(TrainSetEvalCallback(trainer, train_dataset))
trainer.train()

# Save the model
output_dir = "final_model"
trainer.save_model(output_dir)
print(f"Model saved to {output_dir}")

# Per-epoch metrics, in logging order, from the Trainer's log history.
history = trainer.state.log_history
train_logs = [entry for entry in history if "trainset_loss" in entry]
test_logs = [entry for entry in history if "eval_loss" in entry]

train_losses = [entry["trainset_loss"] for entry in train_logs]
train_accuracies = [entry.get("trainset_accuracy", float("nan")) for entry in train_logs]
test_losses = [entry["eval_loss"] for entry in test_logs]
test_accuracies = [entry.get("eval_accuracy", float("nan")) for entry in test_logs]

if not train_logs or not test_logs:
    print("No per-epoch loss was logged; check that TrainingArguments.label_names "
          "matches the label key produced by the data collator ('labels').")
else:
    for epoch, (tr_loss, tr_acc, te_loss, te_acc) in enumerate(
        zip(train_losses, train_accuracies, test_losses, test_accuracies), start=1
    ):
        print(f"Epoch {epoch}/{trainer_args.num_train_epochs} — "
              f"Train Loss: {tr_loss:.4f}, "
              f"Train Acc: {tr_acc:.4f} — "
              f"Test Loss: {te_loss:.4f}, "
              f"Test Acc: {te_acc:.4f}")

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss plot
    ax_loss.plot(range(1, len(train_losses) + 1), train_losses, marker="o", label="Train Loss")
    ax_loss.plot(range(1, len(test_losses) + 1), test_losses, marker="o", label="Test Loss")
    ax_loss.set(xlabel="Epoch", ylabel="Loss", title="Loss per epoch")
    ax_loss.legend()

    # Accuracy plot
    ax_acc.plot(range(1, len(train_accuracies) + 1), train_accuracies, marker="o", label="Train Acc")
    ax_acc.plot(range(1, len(test_accuracies) + 1), test_accuracies, marker="o", label="Test Acc")
    ax_acc.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy per epoch")
    ax_acc.legend()

    fig.tight_layout()
    plt.show()


In [None]:
# Score the trained model on the labelled AG News test split.
predictions = trainer.predict(test_dataset)
preds = np.argmax(predictions.predictions, axis=-1)
accuracy = accuracy_score(test_dataset["label"], preds)
print(f"Final Test Accuracy: {accuracy * 100:.2f}%")



In [None]:
# Optional: Run inference on unlabeled test data if available.
# Only the file load is guarded, so a FileNotFoundError raised later
# (e.g. while writing the CSV) is not misreported as "no data".
try:
    with open("test_unlabelled.pkl", "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        test_unlabelled = pickle.load(f)
except FileNotFoundError:
    print("No unlabeled test data found, skipping inference")
else:
    # Assumes the pickle holds a Hugging Face `Dataset` with a "text" column; TODO confirm.
    test_unlabelled = test_unlabelled.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # Make predictions
    print("Making predictions on unlabeled test data...")
    predictions = trainer.predict(test_unlabelled)

    # IDs are the row positions in the unlabeled dataset; labels are argmax class ids.
    out = {
        "ID": list(range(len(predictions.predictions))),
        "Label": predictions.predictions.argmax(-1)
    }

    # Save predictions
    output_file = "predictions.csv"
    df = pd.DataFrame(out)
    df.to_csv(output_file, index=False)
    print(f"Predictions saved to {output_file}")
