In [4]:
#IMPORTS 
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

#LOAD DATA 
real_embeddings = torch.load("real_surprise_embeddings.pt")
fake_embeddings = torch.load("fake_surprise_embeddings.pt")

#PREPARE DATA 
dataset = load_dataset("dair-ai/emotion")
real_labels = dataset["train"]["label"][:len(real_embeddings)]  # Original labels
synth_labels = [5] * len(fake_embeddings)  # Synthetic = "surprise"
balanced_labels = np.concatenate([real_labels, synth_labels])

# Convert to numpy arrays
X = torch.cat([
    real_embeddings.detach(),  # Add .detach() here
    fake_embeddings.detach()   # And here
]).numpy()  # Now safe to convert to numpy
y = balanced_labels

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=42
)

# Define emotion names
emotion_names = ["sadness", "joy", "love", "anger", "fear", "surprise"]

#TRAIN MODEL
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)

#EVALUATE
y_pred = clf.predict(X_test)

print("Class distribution in test set:", np.unique(y_test, return_counts=True))
print("\nClassification Report:")
print(classification_report(
    y_test, y_pred,
    target_names=emotion_names,  
    zero_division=0
))

print("\nConfusion Matrix:")
print(confusion_matrix(y_test, y_pred, labels=range(6)))

Class distribution in test set: (array([0, 1, 2, 3, 4, 5]), array([ 36,  41,  11,  17,  11, 199], dtype=int64))

Classification Report:
              precision    recall  f1-score   support

     sadness       0.26      0.19      0.22        36
         joy       0.34      0.78      0.47        41
        love       0.00      0.00      0.00        11
       anger       0.00      0.00      0.00        17
        fear       0.00      0.00      0.00        11
    surprise       1.00      0.96      0.98       199

    accuracy                           0.73       315
   macro avg       0.27      0.32      0.28       315
weighted avg       0.71      0.73      0.71       315


Confusion Matrix:
[[  7  28   0   1   0   0]
 [  9  32   0   0   0   0]
 [  5   6   0   0   0   0]
 [  5  12   0   0   0   0]
 [  0  11   0   0   0   0]
 [  1   6   0   0   0 192]]
