# BERT with LoRA: Sentiment, Paraphrase, and Semantic Similarity

This notebook fine-tunes `bert-base-uncased` with LoRA adapters on three tasks:
- SST-5 sentiment analysis (5-way classification)
- Quora paraphrase detection (binary: duplicate vs. not duplicate)
- STS-B semantic textual similarity, binarized so that pairs with a similarity score ≥ 4.0 count as similar

All three tasks use Hugging Face's `transformers`, `datasets`, and `peft` libraries.


## 1. Install dependencies

In [None]:
!pip install torch transformers datasets peft

## 2. Sentiment Analysis (SST-5)

In [None]:
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType

# Fetch the SST-5 splits from the Hub and the uncased BERT tokenizer used throughout this section.
sst_dataset = load_dataset(path="SetFit/sst5")
sst_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

def sst_preprocess(example):
    """Encode the ``text`` field of an SST-5 example with ``sst_tokenizer``.

    Every sequence is truncated or padded to exactly 128 tokens, so all encoded rows
    share the same length.
    """
    fixed_length = dict(truncation=True, padding="max_length", max_length=128)
    return sst_tokenizer(example["text"], **fixed_length)

# Tokenize in batches (the tokenizer accepts lists of strings, which is much faster than
# per-row calls) and keep only the model inputs plus the label column.
sst_encoded = (
    sst_dataset
    .map(
        sst_preprocess,
        batched=True,
        remove_columns=[col for col in sst_dataset["train"].column_names if col != "label"],
    )
    .rename_column("label", "labels")
)

# Seed before building the model so the freshly initialised classifier head and LoRA
# weights are reproducible across Restart & Run All.
torch.manual_seed(42)
sst_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
sst_peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
sst_model = get_peft_model(sst_model, sst_peft_config)

sst_training_args = TrainingArguments(
    output_dir="./results_sst5_nb",
    eval_strategy="epoch",   # replaces the deprecated/removed `evaluation_strategy`
    save_strategy="epoch",   # load_best_model_at_end requires save and eval strategies to match
    learning_rate=2e-5,      # NOTE(review): LoRA runs often use a larger LR than full fine-tuning; tune if needed
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir="./logs_sst5_nb",
    load_best_model_at_end=True,
    report_to="none",        # avoid interactive experiment-tracker logins inside the notebook
)


def sst_compute_metrics(eval_pred):
    """Return validation accuracy for the Trainer's evaluation loop.

    ``eval_pred.predictions`` holds the logits and ``eval_pred.label_ids`` the gold labels
    (NumPy arrays per the ``EvalPrediction`` contract).
    """
    predicted = eval_pred.predictions.argmax(axis=-1)
    return {"accuracy": float((predicted == eval_pred.label_ids).mean())}


sst_trainer = Trainer(
    model=sst_model,
    args=sst_training_args,
    train_dataset=sst_encoded["train"],
    eval_dataset=sst_encoded["validation"],
    processing_class=sst_tokenizer,  # replaces the deprecated `tokenizer=` argument
    compute_metrics=sst_compute_metrics,
)

sst_trainer.train()
sst_results = sst_trainer.evaluate()
print("SST-5 Validation:", sst_results)

def predict_sst(text):
    inputs = sst_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    outputs = sst_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    return pred, probs.detach().numpy()

# Smoke-test the fine-tuned SST-5 model on a sample sentence: prints (class index, probabilities).
print(predict_sst(text="A wonderful, emotional film!"))

## 3. Paraphrase Detection (Quora)

In [None]:
# Fetch the Quora question-pair data and a tokenizer for the same BERT checkpoint.
quora_dataset = load_dataset(path='quora')
quora_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

def quora_preprocess(example):
    """Encode one Quora question pair as a single BERT sentence-pair input.

    ``example['questions']['text']`` holds the question strings; the first two are passed
    as the first and second segment, truncated/padded to exactly 128 tokens.
    """
    texts = example['questions']['text']
    first_question, second_question = texts[0], texts[1]
    return quora_tokenizer(
        first_question,
        second_question,
        max_length=128,
        padding='max_length',
        truncation=True,
    )

def quora_encode_with_labels(example):
    """Tokenize one question pair and attach its duplicate flag as an integer ``labels`` field.

    ``is_duplicate`` is cast to ``int`` so the batched label tensor is integer-typed:
    BertForSequenceClassification only uses single-label cross-entropy for integer labels.
    """
    encoded = quora_preprocess(example)
    encoded['labels'] = int(example['is_duplicate'])
    return encoded


# Remove whatever raw columns the split actually has (`remove_columns` raises on unknown names).
quora_encoded = quora_dataset.map(
    quora_encode_with_labels,
    remove_columns=quora_dataset['train'].column_names,
)

# The Hub `quora` dataset appears to publish only a train split; carve out a seeded 10%
# validation split when none is provided.
if 'validation' not in quora_encoded:
    quora_encoded = quora_encoded['train'].train_test_split(test_size=0.1, seed=42)
    quora_encoded['validation'] = quora_encoded.pop('test')

# Seed before building the model so the new classifier head and LoRA weights are reproducible.
torch.manual_seed(42)
quora_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
quora_peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
quora_model = get_peft_model(quora_model, quora_peft_config)

quora_training_args = TrainingArguments(
    output_dir='./results_paraphrase_nb',
    eval_strategy='epoch',   # replaces the deprecated/removed `evaluation_strategy`
    save_strategy='epoch',   # load_best_model_at_end requires save and eval strategies to match
    learning_rate=2e-5,      # NOTE(review): LoRA runs often use a larger LR than full fine-tuning; tune if needed
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs_paraphrase_nb',
    load_best_model_at_end=True,
    report_to='none',        # avoid interactive experiment-tracker logins inside the notebook
)


def quora_compute_metrics(eval_pred):
    """Return validation accuracy from the Trainer's logits (``predictions``) and ``label_ids``."""
    predicted = eval_pred.predictions.argmax(axis=-1)
    return {'accuracy': float((predicted == eval_pred.label_ids).mean())}


quora_trainer = Trainer(
    model=quora_model,
    args=quora_training_args,
    train_dataset=quora_encoded['train'],
    eval_dataset=quora_encoded['validation'],
    processing_class=quora_tokenizer,  # replaces the deprecated `tokenizer=` argument
    compute_metrics=quora_compute_metrics,
)

quora_trainer.train()
quora_results = quora_trainer.evaluate()
print("Quora Validation:", quora_results)

def predict_quora(q1, q2):
    inputs = quora_tokenizer(q1, q2, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    outputs = quora_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    return pred, probs.detach().numpy()

# Smoke-test the paraphrase model on a sample question pair: prints (class index, probabilities).
print(predict_quora(q1="How to learn coding?", q2="What should I do to learn programming?"))

## 4. Semantic Textual Similarity (STS-B, binary)

In [None]:
# Fetch the STS-B configuration of the GLUE benchmark from the Hub.
sts_dataset = load_dataset(path="glue", name="stsb")

def sts_binarize(example, threshold=4.0):
    """Turn an STS-B similarity score into a binary label (1 = similar, 0 = not similar).

    Args:
        example: Row dict with a numeric ``"label"`` score; updated in place and returned.
        threshold: Scores ``>= threshold`` become 1 (default 4.0).

    Returns:
        The same ``example`` dict with ``"label"`` replaced by 0 or 1. Negative scores are
        mapped to the sentinel -1 instead: GLUE's STS-B test split is unlabeled (stored as -1),
        and binarizing it would silently mark every test pair as "not similar".
    """
    score = example["label"]
    if score < 0:
        example["label"] = -1  # keep the "unlabeled" marker
    else:
        example["label"] = int(score >= threshold)
    return example

# Replace every similarity score with its binary label, then load the BERT tokenizer for this task.
sts_dataset = sts_dataset.map(function=sts_binarize)
sts_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

def sts_preprocess(example):
    """Encode an STS-B sentence pair as one BERT sentence-pair input.

    ``sentence1`` becomes the first segment and ``sentence2`` the second; the pair is
    truncated/padded to exactly 128 tokens.
    """
    pair = (example["sentence1"], example["sentence2"])
    return sts_tokenizer(*pair, max_length=128, padding="max_length", truncation=True)

# Tokenize in batches (the tokenizer accepts lists of sentence pairs) and drop the raw text/index columns.
sts_encoded = sts_dataset.map(
    sts_preprocess,
    batched=True,
    remove_columns=['sentence1', 'sentence2', 'idx'],
)
# Store the binarized targets as integers in a fresh "labels" column. Overwriting "label" inside
# `map` can leave the column's original float32 type in place, and float labels make
# BertForSequenceClassification switch to a multi-label (BCE) loss that does not fit 2-way logits.
sts_encoded = sts_encoded.map(
    lambda example: {"labels": int(example["label"])},
    remove_columns=["label"],
)

# Seed before building the model so the new classifier head and LoRA weights are reproducible.
torch.manual_seed(42)
sts_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
sts_peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
sts_model = get_peft_model(sts_model, sts_peft_config)

sts_training_args = TrainingArguments(
    output_dir="./results_sts_binary_nb",
    eval_strategy="epoch",   # replaces the deprecated/removed `evaluation_strategy`
    save_strategy="epoch",   # load_best_model_at_end requires save and eval strategies to match
    learning_rate=2e-5,      # NOTE(review): LoRA runs often use a larger LR than full fine-tuning; tune if needed
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir="./logs_sts_binary_nb",
    load_best_model_at_end=True,
    report_to="none",        # avoid interactive experiment-tracker logins inside the notebook
)


def sts_compute_metrics(eval_pred):
    """Return validation accuracy from the Trainer's logits (``predictions``) and ``label_ids``."""
    predicted = eval_pred.predictions.argmax(axis=-1)
    return {"accuracy": float((predicted == eval_pred.label_ids).mean())}


sts_trainer = Trainer(
    model=sts_model,
    args=sts_training_args,
    train_dataset=sts_encoded["train"],
    eval_dataset=sts_encoded["validation"],
    processing_class=sts_tokenizer,  # replaces the deprecated `tokenizer=` argument
    compute_metrics=sts_compute_metrics,
)

sts_trainer.train()
sts_results = sts_trainer.evaluate()
print("STS-B Validation:", sts_results)

def predict_sts(s1, s2):
    inputs = sts_tokenizer(s1, s2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    outputs = sts_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    return pred, probs.detach().numpy()

# Smoke-test the binary similarity model on a sample sentence pair: prints (class index, probabilities).
print(predict_sts(s1="A man is playing a guitar.", s2="A person plays an instrument."))