In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [3]:
import os
import numpy as np
from sklearn.metrics import classification_report
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm

# Directory that receives the evaluation artifacts. It is created up front
# (exist_ok=True makes re-running this cell harmless) so the classification
# report can be written into it later.
drive_results_dir = "./results/t5-flan"
os.makedirs(drive_results_dir, exist_ok=True)

# The same checkpoint name is passed to both from_pretrained calls so that the
# tokenizer and the model are loaded from the same checkpoint.
model_name = "google/flan-t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Load the "rotten_tomatoes" dataset and keep a handle on its "test" split,
# which is what the evaluation loop iterates over.
dataset = load_dataset("rotten_tomatoes")
test_data = dataset["test"]

def classify_sentiment(example, model, tokenizer):
    """Ask the model for the sentiment of one example and return its answer.

    Args:
        example: Mapping with a "text" entry holding the text to classify.
        model: Object exposing ``generate(**inputs)``; only the first
            returned sequence is decoded.
        tokenizer: Callable that encodes the prompt (called with
            ``return_tensors="pt"``, ``truncation=True``, ``max_length=128``)
            and exposes ``decode(ids, skip_special_tokens=True)``.

    Returns:
        The decoded generation with surrounding whitespace removed and
        lower-cased. The text is free-form: it is whatever the model
        generated, not necessarily "positive" or "negative".
    """
    prompt = f"Classify the sentiment: {example['text']}"
    # truncation=True with max_length=128 caps the encoded prompt length.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
    outputs = model.generate(**inputs)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Normalise case as well as whitespace so that an answer such as
    # "Positive" still matches lower-case label keys used by callers.
    return prediction.strip().lower()

# Parallel lists: predictions[i] and true_labels[i] belong to the same example.
predictions = []
true_labels = []
# A few scored examples kept for the qualitative printout at the end.
example_queries = []
# Maps the text the model is expected to generate to the integer labels that
# are compared against example["label"].
label_map = {"negative": 0, "positive": 1}
# Only examples at positions below this index are eligible for the printout.
num_sample_predictions = 5
# Number of examples dropped because the generated text is not in label_map.
skipped_count = 0

for idx, example in enumerate(tqdm(test_data)):
    true_label = example["label"]
    predicted_label = classify_sentiment(example, model, tokenizer)

    if predicted_label not in label_map:
        # tqdm.write prints the message without breaking the progress bar,
        # which a plain print() inside the loop does.
        tqdm.write(f"Unexpected label: {predicted_label}, skipping example.")
        skipped_count += 1
        continue

    # Both lists are appended together so they always stay aligned.
    predictions.append(label_map[predicted_label])
    true_labels.append(true_label)

    # idx is the position in the test split, so a skipped example among the
    # first few simply does not appear in the samples.
    if idx < num_sample_predictions:
        example_queries.append({
            "text": example["text"],
            "true_label": "positive" if true_label == 1 else "negative",
            "predicted_label": predicted_label
        })

# Skipped examples are excluded from the metrics; say so explicitly so the
# reduced support in the report is not mistaken for the full test set.
if skipped_count:
    print(
        f"Skipped {skipped_count} of {skipped_count + len(predictions)} examples "
        "with unexpected labels; they are excluded from the report."
    )

print("Generating detailed evaluation report...")
# Passing labels explicitly keeps target_names aligned with the label ids even
# if one class never occurs among the scored examples.
report = classification_report(
    true_labels,
    predictions,
    labels=list(label_map.values()),
    target_names=list(label_map.keys()),
    digits=4,
)

classification_report_file = os.path.join(drive_results_dir, "classification_report.txt")
# Explicit encoding so the written file does not depend on the platform default.
with open(classification_report_file, "w", encoding="utf-8") as f:
    f.write(report)

print(f"Classification report saved at: {classification_report_file}")

print("Classification Report:\n", report)

print("\nSample predictions:")
for example in example_queries:
    print(f"Text: {example['text']}")
    print(f"True Label: {example['true_label']}, Predicted Label: {example['predicted_label']}")
    print("-" * 50)


  3%|▎         | 30/1066 [00:02<01:32, 11.15it/s]

Unexpected label: warm, skipping example.


 71%|███████   | 758/1066 [01:18<00:33,  9.14it/s]

Unexpected label: less, skipping example.


100%|██████████| 1066/1066 [01:47<00:00,  9.90it/s]

Generating detailed evaluation report...
Classification report saved at: ./results/t5-flan/classification_report.txt
Classification Report:
               precision    recall  f1-score   support

    negative     0.8000    0.8571    0.8276       532
    positive     0.8462    0.7857    0.8148       532

    accuracy                         0.8214      1064
   macro avg     0.8231    0.8214    0.8212      1064
weighted avg     0.8231    0.8214    0.8212      1064


Sample predictions:
Text: lovingly photographed in the manner of a golden book sprung to life , stuart little 2 manages sweetness largely without stickiness .
True Label: positive, Predicted Label: positive
--------------------------------------------------
Text: consistently clever and suspenseful .
True Label: positive, Predicted Label: positive
--------------------------------------------------
Text: it's like a " big chill " reunion of the baader-meinhof gang , only these guys are more harmless pranksters than political a


