# Fine-Tuning BERT for Sentiment Analysis

This notebook demonstrates how to fine-tune a pre-trained BERT model for a text classification task using the IMDb movie review dataset.

## 1. Setup and Library Imports

In [None]:
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

## 2. Load and Prepare the Dataset

We use the `datasets` library from Hugging Face to easily load the IMDb dataset. We'll then create a smaller subset for faster training as a demonstration.

In [None]:
# Fetch the IMDb dataset through the Hugging Face `datasets` loader
dataset = load_dataset('imdb')

# Keep the demo fast: shuffle each split with a fixed seed (reproducible),
# then keep only the first N rows of the shuffled split.
shuffled_train = dataset['train'].shuffle(seed=42)
shuffled_test = dataset['test'].shuffle(seed=42)
train_dataset = shuffled_train.select(range(2000))  # 2000 examples for training
test_dataset = shuffled_test.select(range(500))     # 500 examples for testing

# Show one raw record so the structure of an example is visible
print("Training data sample:", train_dataset[0])

## 3. Tokenization

We need to tokenize the text data into a format that BERT can understand. We'll use the tokenizer corresponding to the 'bert-base-uncased' model, padding every review to the same fixed length and truncating reviews that are too long.

In [None]:
# Load the tokenizer for the same checkpoint name that the model cell uses,
# so the token ids line up with the pre-trained weights.
checkpoint_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint_name)

# Create a tokenization function
def tokenize_function(examples):
    """Tokenize the 'text' field of a group of examples.

    Args:
        examples: Mapping with a 'text' entry; it is applied through
            ``Dataset.map(..., batched=True)``, so it receives several
            reviews per call.

    Returns:
        The tokenizer output, with every sequence padded to a fixed maximum
        length and over-long sequences truncated.
    """
    review_texts = examples['text']
    encoded = tokenizer(review_texts, padding='max_length', truncation=True)
    return encoded

# Tokenize both splits the same way; batched=True hands tokenize_function
# groups of examples per call instead of one example at a time.
tokenized_train_dataset, tokenized_test_dataset = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, test_dataset)
)

## 4. Model Fine-Tuning

Now we load the pre-trained BERT model and set up the `Trainer` API to fine-tune it on our specific task.

In [None]:
# Load the pre-trained model with a 2-class classification head.
# The class names are stored in the model config, so the checkpoint saved
# later reports NEGATIVE/POSITIVE at inference time instead of the generic
# LABEL_0/LABEL_1. IMDb encodes 0 = negative and 1 = positive.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    id2label={0: 'NEGATIVE', 1: 'POSITIVE'},
    label2id={'NEGATIVE': 0, 'POSITIVE': 1},
)

# Define a function to compute metrics for evaluation
def compute_metrics(pred):
    """Turn a prediction object into a dictionary of evaluation metrics.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores; the highest-scoring class
            along the last axis is taken as the predicted label).

    Returns:
        dict with 'accuracy', 'f1', 'precision' and 'recall'. Precision,
        recall and F1 use ``average='binary'``.
    """
    true_labels = pred.label_ids
    predicted_labels = pred.predictions.argmax(-1)

    # The fourth element (support) is not reported, so it is discarded.
    prf_scores = precision_recall_fscore_support(
        true_labels, predicted_labels, average='binary'
    )
    precision_score, recall_score, f1_score = prf_scores[:3]
    accuracy = accuracy_score(true_labels, predicted_labels)

    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['f1'] = f1_score
    metrics['precision'] = precision_score
    metrics['recall'] = recall_score
    return metrics

# Collect the training options first, then hand them to TrainingArguments.
training_config = dict(
    output_dir='./results',          # where checkpoints are written
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # NOTE(review): with 2000 training examples, batch size 8 and 3 epochs a
    # single-device run has about 750 optimizer steps, so this warmup spans
    # most of training - confirm that is intended.
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,                # log every 10 steps
    # NOTE(review): newer transformers releases rename this option to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",     # evaluate once per epoch
)
training_args = TrainingArguments(**training_config)

# Wire the model, hyperparameters, data and metric callback into one Trainer
trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    # which data to train and evaluate on
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
)

# Run fine-tuning; kept as the final bare expression so the notebook cell
# displays whatever train() returns.
trainer.train()

## 5. Evaluation

Let's evaluate the final performance of our fine-tuned model on the 500-review test subset we selected earlier.

In [None]:
# Print a header, then evaluate on the eval_dataset that was passed to the Trainer.
print("Final evaluation on the test set:")
# Deliberately left as the cell's final bare expression: the notebook renders
# the value returned by evaluate() directly below the header.
trainer.evaluate()

## 6. Save the Model and Tokenizer

We'll save our fine-tuned model so we can use it for inference later.

In [None]:
# Single output directory shared by the model and the tokenizer, so the
# inference cell can load both from one path.
model_save_path = "../models/fine-tuned-bert"

# Save the model first, then the tokenizer, into that same directory.
for save_to in (trainer.save_model, tokenizer.save_pretrained):
    save_to(model_save_path)

print(f"Model saved to {model_save_path}")

## 7. Inference on New Text

Finally, let's use our saved model to predict the sentiment of two new movie reviews, one positive and one negative.

In [None]:
from transformers import pipeline

# Two hand-written reviews to sanity-check the classifier
positive_review = "This movie was absolutely fantastic! The acting was superb and the plot was gripping."
negative_review = "I was really disappointed with this film. It was boring and the ending was predictable."

# Build an inference pipeline from the directory written by the save step;
# the model and the tokenizer are both read from that same path.
sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model=model_save_path,
    tokenizer=model_save_path,
)

# Positive example
result_pos = sentiment_analyzer(positive_review)
print(f"Review: '{positive_review}'")
print(f"Prediction: {result_pos}")

# Negative example (the leading newline separates it from the output above)
result_neg = sentiment_analyzer(negative_review)
print(f"\nReview: '{negative_review}'")
print(f"Prediction: {result_neg}")