<a href="https://colab.research.google.com/github/Tiagofv/sentiment-analysis/blob/main/imdb_sentiment_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting requests (from transformers)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (

In [None]:
from datasets import load_dataset
from transformers import pipeline
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import torch
from tqdm import tqdm
# Pull the IMDB reviews dataset through the `datasets` library
dataset = load_dataset("imdb")

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier pipeline around the DistilBERT SST-2 checkpoint
sentiment_pipeline = pipeline(
    device=device,
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)

# Function to perform batch prediction
def batch_predict(texts, batch_size=100):
    """Classify ``texts`` with the module-level ``sentiment_pipeline``, chunk by chunk.

    The input is cut into consecutive slices of at most ``batch_size`` items.
    Each slice is passed to ``sentiment_pipeline`` with ``truncation=True`` and
    the returned results are concatenated in slice order. A tqdm progress bar
    advances once per slice.

    Args:
        texts: Sequence of inputs; must support ``len()`` and slice indexing.
        batch_size: Maximum number of items handed to the pipeline per call.
            Must be positive.

    Returns:
        list: The concatenation of whatever the pipeline returned for each
        slice (empty when ``texts`` is empty).

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    # Validate up front: a negative step would make ``range`` yield nothing and
    # silently return an empty result, and zero would fail with an opaque
    # ``range()`` error.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    all_results = []
    # NOTE(review): the chunk size only controls how many texts go into each
    # pipeline call; whether the pipeline batches them on the device depends on
    # its own `batch_size` argument -- confirm against the transformers docs.
    for start in tqdm(range(0, len(texts), batch_size), desc="Predicting", unit="batch"):
        chunk = texts[start:start + batch_size]
        all_results.extend(sentiment_pipeline(chunk, truncation=True))
    return all_results

from sklearn.metrics import confusion_matrix

# Classify every review text in the test split
test_texts = dataset["test"]["text"]
predictions = batch_predict(test_texts)

# Turn the pipeline's string labels into ints: 'POSITIVE' -> 1, anything else -> 0
predicted_labels = [int(pred['label'] == 'POSITIVE') for pred in predictions]

# Reference labels stored with the test split
true_labels = dataset["test"]["label"]

# Score the predictions against the reference labels
accuracy = accuracy_score(true_labels, predicted_labels)
precision, recall, f1, _ = precision_recall_fscore_support(
    true_labels,
    predicted_labels,
    average='binary',
)

# Report the metrics, one per line
print(
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
    sep="\n",
)

# Confusion matrix of true vs. predicted labels
cm = confusion_matrix(true_labels, predicted_labels)
print("Confusion Matrix:", cm, sep="\n")

Predicting: 100%|██████████| 250/250 [05:02<00:00,  1.21s/batch]

Accuracy: 0.8907
Precision: 0.9146
Recall: 0.8619
F1 Score: 0.8875
Confusion Matrix:
[[11494  1006]
 [ 1726 10774]]





In [None]:
# Preview only the first few predictions; displaying the whole list would
# write every entry into the notebook's saved output.
predictions[:10]

[{'label': 'NEGATIVE', 'score': 0.999616265296936},
 {'label': 'NEGATIVE', 'score': 0.6170608401298523},
 {'label': 'NEGATIVE', 'score': 0.9997100234031677},
 {'label': 'NEGATIVE', 'score': 0.995756208896637},
 {'label': 'POSITIVE', 'score': 0.996307373046875},
 {'label': 'NEGATIVE', 'score': 0.9966711401939392},
 {'label': 'NEGATIVE', 'score': 0.9584168791770935},
 {'label': 'NEGATIVE', 'score': 0.9994093179702759},
 {'label': 'NEGATIVE', 'score': 0.9996923208236694},
 {'label': 'NEGATIVE', 'score': 0.99977046251297},
 {'label': 'NEGATIVE', 'score': 0.9997914433479309},
 {'label': 'NEGATIVE', 'score': 0.9940081834793091},
 {'label': 'NEGATIVE', 'score': 0.9997391104698181},
 {'label': 'NEGATIVE', 'score': 0.9996050000190735},
 {'label': 'NEGATIVE', 'score': 0.9997557997703552},
 {'label': 'NEGATIVE', 'score': 0.9913800358772278},
 {'label': 'NEGATIVE', 'score': 0.9960663914680481},
 {'label': 'NEGATIVE', 'score': 0.9991810917854309},
 {'label': 'POSITIVE', 'score': 0.656829833984375},