In [None]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
# Load the evaluation set. Later cells read its "sentiment" and "review"
# columns.
yelp_csv_path = "../../data/yelp_datasets/balanced_yelp_dataset.csv"
df = pd.read_csv(yelp_csv_path)

In [None]:
# Turn the "sentiment" column into integer class ids and keep the fitted
# encoder so predictions can be mapped back to the original values later.
# NOTE(review): the ids come from the encoder, not from the model's config;
# confirm they line up with the fine-tuned model's output indices.
label_encoder = LabelEncoder()
encoded_sentiment = label_encoder.fit_transform(df["sentiment"])
df["label"] = encoded_sentiment

# Plain Python list of ground-truth ids, used by the accuracy cell.
true_labels = df["label"].tolist()

In [None]:
# The classifier weights and the tokenizer are loaded from the same local
# directory, so the path is named once.
model_dir = "../../models/fine_tuned_bert_imdb"
model = BertForSequenceClassification.from_pretrained(model_dir).to(device)
tokenizer = BertTokenizer.from_pretrained(model_dir)

# Switch to evaluation mode for inference. `eval()` returns the module
# itself; assigning it back keeps the call from being the cell's trailing
# expression, which would otherwise render the full module repr as output.
model = model.eval()

In [None]:
# Batched inference over every review. The accumulators are (re)created in
# this cell, so re-running it starts from empty lists.
texts = df["review"].tolist()
batch_size = 32  # Tune this depending on your system
all_predictions, all_confidences = [], []

batch_starts = range(0, len(texts), batch_size)
for start in tqdm(batch_starts, desc="Running inference"):
    chunk = texts[start:start + batch_size]

    # Tokenize the batch, padding to a common length and truncating at 512
    # tokens, then move the tensors to the model's device.
    inputs = tokenizer(
        chunk,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    # No gradients are needed for evaluation.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=1)
        preds = probs.argmax(dim=1)            # index of the largest probability
        confs = probs.max(dim=1).values        # the largest probability itself

    # Bring results back to the CPU as plain Python values.
    all_predictions += preds.cpu().tolist()
    all_confidences += confs.cpu().tolist()

In [None]:
# Map predicted class ids back to the original "sentiment" values using the
# fitted encoder, then attach predictions and confidences to the frame.
predicted_labels = label_encoder.inverse_transform(all_predictions)

# Insertion order is kept: "predicted_sentiment" first, then "confidence".
new_columns = {
    "predicted_sentiment": predicted_labels,
    "confidence": all_confidences,
}
for column_name, column_values in new_columns.items():
    df[column_name] = column_values

In [None]:
# Write the frame to disk without the index column.
predictions_csv_path = "../../data/yelp_datasets/yelp_predictions_with_confidence.csv"
df.to_csv(predictions_csv_path, index=False)


In [None]:
def _indices_with_label(labels, target):
    """Return the positions in ``labels`` whose value equals ``target``."""
    return [i for i, label in enumerate(labels) if label == target]


def _subset_accuracy(y_true, y_pred, indices):
    """Return the accuracy restricted to the samples at ``indices``.

    Args:
        y_true: Sequence of ground-truth labels.
        y_pred: Sequence of predicted labels, aligned with ``y_true``.
        indices: Positions of the samples to score.

    Returns:
        The ``accuracy_score`` of the selected samples, or ``float("nan")``
        when ``indices`` is empty, so that empty inputs are never handed to
        ``accuracy_score``.
    """
    if not indices:
        return float("nan")
    return accuracy_score(
        [y_true[i] for i in indices],
        [y_pred[i] for i in indices],
    )


# Accuracy over the whole dataset.
overall_acc = accuracy_score(true_labels, all_predictions)

# Per-class accuracy: score only the samples whose true label is that class.
# NOTE(review): the printed names assume id 0 is the negative class and id 1
# the positive class; confirm against the fitted label encoder.
neg_indices = _indices_with_label(true_labels, 0)
neg_acc = _subset_accuracy(true_labels, all_predictions, neg_indices)

pos_indices = _indices_with_label(true_labels, 1)
pos_acc = _subset_accuracy(true_labels, all_predictions, pos_indices)

print(f"Overall Accuracy: {overall_acc:.4f}")
print(f"Negative Class Accuracy (label=0): {neg_acc:.4f}")
print(f"Positive Class Accuracy (label=1): {pos_acc:.4f}")