In [None]:
!pip install -q transformers datasets torchaudio soundfile torchcodec

In [None]:
import torch
import torchaudio
import torchcodec

In [None]:
from datasets import load_dataset


# Load the IEMOCAP "train" split from the Hugging Face Hub.
# NOTE(review): the split cell below reloads the same split as `full_ds`; `train_ds` itself is not used by the cells shown.
train_ds = load_dataset(path="AbstractTTS/IEMOCAP", split="train")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Audio
from transformers import Wav2Vec2FeatureExtractor

import torch
from torch.utils.data import Dataset

import torch
from torch.utils.data import Dataset

class IEMOCAPAudioDataset(Dataset):
    """Wrap a Hugging Face IEMOCAP split as (waveform, soft-label) pairs.

    The raw per-emotion columns of each row are summed into six coarser groups
    (``output_classes``). Group scores below ``threshold`` are zeroed, a few
    groups are re-scaled, and the vector is renormalized to sum to 1 (or made
    uniform if every group ended up at zero).

    Args:
        hf_ds: indexable dataset whose rows carry ``row["audio"]["array"]`` and
            one numeric column per raw emotion name used in ``target_groups``.
        threshold: group scores strictly below this value are zeroed.
        disgust_boost: multiplier applied to the "disgust" group.
        fear_surprise_boost: multiplier applied to the "fear_surprise" group.
        anger_shrink: multiplier applied to the "angry" group.
    """

    def __init__(
        self,
        hf_ds,
        threshold=0.05,
        disgust_boost=1.0,
        fear_surprise_boost=1.0,
        anger_shrink=0.9
    ):
        self.ds = hf_ds
        self.threshold = threshold
        self.disgust_boost = disgust_boost
        self.fear_surprise_boost = fear_surprise_boost
        self.anger_shrink = anger_shrink

        # Raw emotion column names (informational; rows are read through target_groups).
        self.emotion_list = [
            "frustrated", "angry", "sad", "disgust",
            "excited", "fear", "neutral", "surprise", "happy",
        ]

        # Output group -> raw columns summed into it. Dict insertion order fixes the label order.
        self.target_groups = {
            "angry": ["angry", "frustrated"],
            "happy": ["happy", "excited"],
            "sad": ["sad"],
            "neutral": ["neutral"],
            "fear_surprise": ["fear", "surprise"],
            "disgust": ["disgust"],
        }
        self.output_classes = list(self.target_groups)
        self.emotion_to_index = {name: position for position, name in enumerate(self.output_classes)}

    def __len__(self):
        """Number of rows in the wrapped split."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(waveform, label_distribution)`` as float32 tensors for row ``idx``."""
        sample = self.ds[idx]
        waveform = torch.tensor(sample["audio"]["array"], dtype=torch.float32)

        # One summed score per output group, in output_classes order.
        scores = torch.tensor(
            [
                sum(float(sample[column]) for column in self.target_groups[group])
                for group in self.output_classes
            ],
            dtype=torch.float32,
        )

        # Zero out weak groups.
        scores = scores * (scores >= self.threshold).float()

        # Re-scale selected groups of the label distribution.
        rescaling = (
            ("angry", self.anger_shrink),
            ("disgust", self.disgust_boost),
            ("fear_surprise", self.fear_surprise_boost),
        )
        for group, factor in rescaling:
            scores[self.emotion_to_index[group]] *= factor

        # Renormalize; fall back to a uniform distribution when nothing survived.
        total = scores.sum()
        if total > 0:
            return waveform, scores / total
        return waveform, torch.ones_like(scores) / len(scores)




# wav2vec2-large feature extractor used by collate_fn to pad/prepare batches (cached under ./hf_cache).
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large", cache_dir="./hf_cache")



def collate_fn(batch):
    """
    Turn a list of (audio_tensor, label_tensor) pairs into one padded batch.

    Waveforms are converted to numpy and passed through the module-level
    feature extractor `processor` at 16 kHz with `padding=True`. Labels are
    stacked along a new leading dimension. Only `input_values` is returned
    (no attention mask).

    Returns:
        (input_values, labels) tensors.
    """
    clips = [audio.numpy() for audio, _ in batch]
    features = processor(
        clips,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )
    targets = torch.stack([target for _, target in batch])
    return features["input_values"], targets


# Reload IEMOCAP and carve out a reproducible 80/20 train/validation split.
full_ds = load_dataset("AbstractTTS/IEMOCAP", split="train")

split_ds = full_ds.train_test_split(test_size=0.2, seed=42)
train_ds_split = split_ds["train"]
val_ds_split = split_ds["test"]

# Soft-label datasets over each half.
dataset = IEMOCAPAudioDataset(train_ds_split)
val_dataset = IEMOCAPAudioDataset(val_ds_split)

batch_size = 16

train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    # Seeded generator so the shuffle order is reproducible across kernel restarts.
    generator=torch.Generator().manual_seed(42),
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,  # No need to shuffle validation data
    collate_fn=collate_fn,
)

print("✓ Dataset + DataLoader ready for split.")

In [None]:


from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
# Disabled reference snippet, kept as an unused string literal so it never runs.
# It shows how to extract one mean-pooled wav2vec2 embedding for a single
# waveform, with no training involved.
# NOTE(review): it references `waveform` and `sr`, which no cell defines.
# Wav2Vec2EmbeddingExtractor.get_embedding implements the same steps.
'''
# Load Wav2Vec2
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large", cache_dir="./hf_cache")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large", cache_dir="./hf_cache")
model.eval()# delete if we were training it to specifically emotional

# Somewhere here we would do stuff like add freeze layers, add layers, etc. etc. etc.
# so we would in a training/fine tuning, we would edit model/processor


# Forward loop
inputs = processor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
inputs = {k: v for k, v in inputs.items()}

with torch.no_grad():
    hidden = model(**inputs).last_hidden_state


embedding = hidden.mean(dim=1).squeeze(0)
print("Embedding shape:", embedding.shape)
'''


In [None]:
# Classifier head: an MLP that maps a mean-pooled wav2vec2 embedding to emotion scores.

In [None]:
import torch.nn.functional
import torch.nn as nn

class EmotionClassifierMLP(nn.Module):
    """Two-hidden-layer MLP head that maps an utterance embedding to emotion scores.

    Layout: (Linear -> ReLU -> Dropout) twice, then a final Linear producing
    ``num_classes`` outputs. ``forward`` returns log-probabilities by default.

    NOTE(review): the default ``num_classes=9`` does not match the 6 merged
    classes produced by IEMOCAPAudioDataset; the callers in this notebook pass
    ``num_classes`` explicitly.
    """

    def __init__(self, emb_dim=1024, hidden_dim1=256, hidden_dim2=128, num_classes=9, dropout=0.1):
        super().__init__()
        layers = []
        for fan_in, fan_out in ((emb_dim, hidden_dim1), (hidden_dim1, hidden_dim2)):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(hidden_dim2, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, embeddings, return_probs=False):
        """Score a batch of embeddings.

        Returns softmax probabilities over the last dimension when
        ``return_probs`` is true, otherwise log-softmax (log-probabilities).
        """
        scores = self.net(embeddings)
        activation = torch.nn.functional.softmax if return_probs else torch.nn.functional.log_softmax
        return activation(scores, dim=-1)

In [None]:
import torch
"""
# Wav2Vec2 model parameters
model_name = "facebook/wav2vec2-large"
cache_dir = "./hf_cache"
num_labels = 6  # Number of emotions, corrected from 8 to 9
# Wav2Vec2
freeze_feature_extractor = True
number_unfrozen = 6
backbone_learning_rate = 2e-5
backbone_weight_decay = 0.0001

# Classifier head
hidden_dim1 = 256
hidden_dim2 = 128
dropout = 0.1
classifier_learning_rate = 1e-4
classifier_weight_decay = 0.0

# Training
epochs = 15
batch_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
normalization = False

# Schedulers
backbone_schedule = "cosine"
classifier_schedule = "cosine"
"""

In [None]:
import random
import torch
import torch.nn as nn
"""
Below is my code for grid searching

param_space = {
    'number_unfrozen': [0, 3, 6, 9, 12],
    'backbone_learning_rate': [1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5],
    'classifier_learning_rate': [1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3],
    'hidden_dim1': [128, 256, 512],
    'hidden_dim2': [64, 128, 256],
    'dropout': [0.1, 0.2, 0.3],
    'normalization': [True, False]
}

num_trials = 10
epochs_for_grid_search = 7
patience_for_grid_search = 3
num_labels=9

best_accuracy = -1.0
best_params = None
all_results = []

print(f"Starting Random Grid Search for {num_trials} trials...")

for i in range(num_trials):
    print(f"\n--- Trial {i+1}/{num_trials} ---")

    # Sample a random combination of hyperparameters
    current_params = {
        'number_unfrozen': random.choice(param_space['number_unfrozen']),
        'backbone_learning_rate': random.choice(param_space['backbone_learning_rate']),
        'classifier_learning_rate': random.choice(param_space['classifier_learning_rate']),
        'hidden_dim1': random.choice(param_space['hidden_dim1']),
        'hidden_dim2': random.choice(param_space['hidden_dim2']),
        'dropout': random.choice(param_space['dropout']),
        'normalization': random.choice(param_space['normalization'])
    }
    print(f"Current parameters: {current_params}")

    # Instantiate the EmotionClassifierMLP with current hyperparameters
    classifier_model_trial = EmotionClassifierMLP(
        emb_dim=1024,
        num_classes=num_labels,
        hidden_dim1=current_params['hidden_dim1'],
        hidden_dim2=current_params['hidden_dim2'],
        dropout=current_params['dropout']
    )

    # Re-initialize Wav2Vec2EmbeddingExtractor to ensure fresh backbone weights for each trial
    wav2vec2_trainer_trial = Wav2Vec2EmbeddingExtractor(
        model_name=model_name,
        cache_dir=cache_dir
    )

    val_accuracy = wav2vec2_trainer_trial.train_model(
        classification_model=classifier_model_trial,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        backbone_lr=current_params['backbone_learning_rate'],
        classifier_lr=current_params['classifier_learning_rate'],
        backbone_weight_decay=backbone_weight_decay,
        classifier_weight_decay=classifier_weight_decay,
        backbone_schedule=backbone_schedule,
        classifier_schedule=classifier_schedule,
        number_unfrozen=current_params['number_unfrozen'],
        normalization=current_params['normalization'],
        epochs=epochs_for_grid_search, # Using the potentially reduced epochs for grid search
        device=device,
        patience=patience_for_grid_search
    )


    all_results.append({
        'trial': i + 1,
        'params': current_params,
        'val_accuracy': val_accuracy
    })

    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_params = current_params

    print(f"Trial {i+1} completed. Validation Accuracy: {val_accuracy:.4f}")

print("\n--- Random Grid Search Summary ---")
print(f"Best Validation Accuracy: {best_accuracy:.4f}")
print(f"Best Parameters: {best_params}")

print("\nAll trial results:")
for res in all_results:
    print(f"Trial {res['trial']}: Acc={res['val_accuracy']:.4f}, Params={res['params']}")

"""

In [None]:
# wav2vec2 + MLP: best hyperparameters found by the random search

import torch
import torchcodec

# --- Backbone (wav2vec2) ---
model_name = "facebook/wav2vec2-large"
cache_dir = "./hf_cache"
# NOTE(review): not passed to train_model; freezing is controlled by number_unfrozen.
freeze_feature_extractor = True
number_unfrozen = 9  # how many of the top encoder layers stay trainable
backbone_learning_rate = 1e-5
backbone_weight_decay = 0.0001

# --- Classifier head ---
# Output size: the 6 merged emotion groups (angry, happy, sad, neutral, fear_surprise, disgust).
num_labels = 6
hidden_dim1, hidden_dim2 = 256, 64
dropout = 0.2
classifier_learning_rate = 3e-3
classifier_weight_decay = 0.0

# --- Training ---
epochs = 15
# NOTE: the DataLoaders are built in an earlier cell, so changing this value here does not resize their batches.
batch_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
normalization = False

# --- Learning-rate schedules ---
backbone_schedule = classifier_schedule = "cosine"

In [None]:
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import torch.nn.functional as F
import numpy as np


def smooth_inverse_frequency(class_freq, alpha=1.3):
    """Class weights that decrease smoothly with class frequency.

    Computes ``1 / log(alpha + class_freq)`` and rescales the result to mean 1.
    Rarer classes get larger weights, but less aggressively than plain 1/freq.

    Args:
        class_freq: per-class frequencies (tensor or sequence), typically in [0, 1].
        alpha: smoothing offset; ``alpha + class_freq`` must stay above 1 so the
            log is positive.

    Returns:
        Tensor of weights with the same shape (and device) as ``class_freq``.

    Raises:
        ValueError: if any frequency is non-finite (e.g. NaN from an all-zero
            count), or if ``alpha + class_freq <= 1`` for some class (that would
            give infinite or negative weights).
    """
    class_freq = torch.as_tensor(class_freq)
    if not torch.isfinite(class_freq).all():
        raise ValueError("class_freq must be finite; got %r" % (class_freq,))
    shifted = alpha + class_freq
    if (shifted <= 1).any():
        raise ValueError(
            "alpha + class_freq must be > 1 for every class (alpha=%r)" % (alpha,)
        )
    smoothed = 1.0 / torch.log(shifted)
    return smoothed / smoothed.mean()


class Wav2Vec2EmbeddingExtractor:
    """Pretrained wav2vec2 backbone plus helpers to fine-tune and evaluate a classifier head on it.

    An utterance embedding is the mean of the backbone's ``last_hidden_state``
    over dim 1 (the time axis). ``train_model`` can optionally standardize those
    embeddings. The flag it was last called with is remembered, so
    ``evaluate_model`` and ``get_predictions`` apply the same transform by
    default and train and eval see the same features.

    NOTE(review): batches built by this notebook's ``collate_fn`` are zero-padded
    and carry no attention mask, so the time-average also covers padded frames.
    """

    def __init__(self, model_name="facebook/wav2vec2-large", cache_dir="./hf_cache"):
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = Wav2Vec2Model.from_pretrained(model_name, cache_dir=cache_dir)
        # Normalization flag used by the most recent train_model() call.
        self._normalization = False

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize(embeddings):
        """Standardize each row to zero mean / unit (sample) std across its features."""
        return (
            (embeddings - embeddings.mean(dim=1, keepdim=True)) /
            (embeddings.std(dim=1, keepdim=True) + 1e-7)
        )

    def _resolve_normalization(self, normalization):
        """Return ``normalization``, or the flag from the last training run when it is None."""
        if normalization is None:
            # getattr keeps instances created without __init__ working.
            return getattr(self, "_normalization", False)
        return normalization

    def _embed(self, input_values, normalization=False):
        """Run the backbone on a batch and mean-pool its hidden states over time."""
        hidden_states = self.model(input_values).last_hidden_state
        embeddings = hidden_states.mean(dim=1)
        return self._normalize(embeddings) if normalization else embeddings

    @staticmethod
    def _count_top2_hits(logits, true_major):
        """Count rows whose true class index is among the two highest-scoring classes."""
        top2 = torch.topk(logits, k=2, dim=1).indices
        return int((top2 == true_major.unsqueeze(1)).any(dim=1).sum().item())

    @staticmethod
    def _build_schedulers(optimizer, backbone_schedule, classifier_schedule, epochs):
        """Build at most one epoch-level scheduler, shared by both parameter groups.

        "cosine" on either side wins over "linear". Any other value means no scheduler.
        """
        requested = (backbone_schedule, classifier_schedule)
        if "cosine" in requested:
            return [torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)]
        if "linear" in requested:
            # NOTE(review): "linear" is implemented as a step decay (halve every 5 epochs).
            return [torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)]
        return []

    def _class_weights(self, train_dataloader, device):
        """Smoothed inverse-frequency class weights from the summed soft labels, shape (1, C)."""
        num_classes = len(train_dataloader.dataset.output_classes)
        class_counts = torch.zeros(num_classes)
        # NOTE: this is a full pass over the training loader (audio decoding and
        # padding included) just to read the labels.
        for _, labels in train_dataloader:
            class_counts += labels.sum(dim=0)

        class_freq = (class_counts / class_counts.sum()).to(device)
        class_weights = smooth_inverse_frequency(class_freq).to(device)

        print("Class frequencies:", class_freq)
        print("Smoothed class weights:", class_weights)
        return class_weights.unsqueeze(0)

    def _train_one_epoch(self, classification_model, train_dataloader, optimizer, criterion,
                         class_weights, number_unfrozen, normalization, device):
        """Run one optimization pass. Returns (mean loss, major accuracy, top-2 accuracy)."""
        self.model.train()
        classification_model.train()

        epoch_loss = 0.0
        total_major_correct = 0
        total_top2_correct = 0
        total_samples = 0

        for input_values, labels in train_dataloader:
            input_values = input_values.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if number_unfrozen == 0:
                # Fully frozen backbone: skip building its autograd graph.
                with torch.no_grad():
                    batch_embeddings = self._embed(input_values, normalization)
            else:
                batch_embeddings = self._embed(input_values, normalization)

            logits = classification_model(batch_embeddings)
            # log_softmax is idempotent, so this is correct whether the head
            # returns raw logits or log-probabilities.
            log_probs = torch.log_softmax(logits, dim=1)

            # Scaling the target by the class weights makes KLDivLoss act as a
            # class-weighted soft cross-entropy (plus a term that does not depend on the model).
            loss = criterion(log_probs, labels * class_weights)
            loss.backward()
            optimizer.step()

            batch_size_actual = input_values.size(0)
            epoch_loss += loss.item() * batch_size_actual

            true_major = torch.argmax(labels, dim=1)
            total_major_correct += (torch.argmax(logits, dim=1) == true_major).sum().item()
            total_top2_correct += self._count_top2_hits(logits, true_major)
            total_samples += batch_size_actual

        if total_samples == 0:
            raise ValueError("train_dataloader yielded no samples")
        return (
            epoch_loss / total_samples,
            total_major_correct / total_samples,
            total_top2_correct / total_samples,
        )

    # --------------------------------------------------------------- public API

    def get_embedding(self, waveform, sampling_rate):
        """Return the time-averaged backbone embedding for a single raw waveform (no training).

        No embedding normalization is applied here.
        """
        self.model.eval()
        inputs = self.processor(waveform, sampling_rate=sampling_rate,
                                return_tensors="pt", padding=True)
        # Keep inputs on the backbone's device; train_model may have moved it to a GPU.
        model_device = next(self.model.parameters()).device
        inputs = {key: value.to(model_device) for key, value in inputs.items()}
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        return hidden.mean(dim=1).squeeze(0)

    def freezing_layers(self, number_unfrozen):
        """Freeze every backbone parameter, then unfreeze the last ``number_unfrozen`` encoder layers.

        Values above the number of encoder layers are clamped. Values <= 0 leave
        the whole backbone frozen.
        """
        for param in self.model.parameters():
            param.requires_grad = False

        total_layers = len(self.model.encoder.layers)
        number_unfrozen = min(number_unfrozen, total_layers)

        if number_unfrozen > 0:
            for layer in self.model.encoder.layers[-number_unfrozen:]:
                for param in layer.parameters():
                    param.requires_grad = True

    def evaluate_model(self, classification_model, dataloader, device, normalization=None):
        """Compute (major accuracy, top-2 accuracy) of the head on ``dataloader``.

        "Major" compares the argmax of the prediction with the argmax of the
        soft label. "Top-2" checks whether the label's argmax is among the two
        highest scores.

        Args:
            normalization: standardize embeddings before the head. None (default)
                reuses the setting from the last ``train_model`` call.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        normalization = self._resolve_normalization(normalization)
        self.model.eval()
        classification_model.eval()
        total_major_correct = 0
        total_top2_correct = 0
        total_samples = 0

        with torch.no_grad():
            for input_values, labels in dataloader:
                input_values = input_values.to(device)
                labels = labels.to(device)

                logits = classification_model(self._embed(input_values, normalization))
                true_major = torch.argmax(labels, dim=1)

                total_major_correct += (torch.argmax(logits, dim=1) == true_major).sum().item()
                total_top2_correct += self._count_top2_hits(logits, true_major)
                total_samples += input_values.size(0)

        if total_samples == 0:
            raise ValueError("dataloader yielded no samples")
        return total_major_correct / total_samples, total_top2_correct / total_samples

    def get_predictions(self, classification_model, dataloader, device, normalization=None):
        """Collect true and predicted label distributions for every sample in ``dataloader``.

        Args:
            normalization: standardize embeddings before the head. None (default)
                reuses the setting from the last ``train_model`` call.

        Returns:
            dict with numpy arrays ``"true_distributions"`` and
            ``"predicted_distributions"``, one row per sample.
        """
        normalization = self._resolve_normalization(normalization)
        self.model.eval()
        classification_model.eval()

        true_distributions = []
        predicted_distributions = []

        with torch.no_grad():
            for input_values, labels in dataloader:
                input_values = input_values.to(device)

                logits = classification_model(self._embed(input_values, normalization))
                # softmax gives the same probabilities whether the head returns raw
                # logits or log-probabilities.
                probs = F.softmax(logits, dim=1)

                predicted_distributions.extend(probs.cpu().tolist())
                true_distributions.extend(labels.cpu().tolist())

        return {
            "true_distributions": np.array(true_distributions),
            "predicted_distributions": np.array(predicted_distributions),
        }

    def train_model(
        self,
        classification_model,
        train_dataloader,
        val_dataloader=None,
        backbone_lr=2e-5,
        classifier_lr=1e-4,
        backbone_weight_decay=0.01,
        classifier_weight_decay=0.0,
        backbone_schedule=None,
        classifier_schedule=None,
        number_unfrozen=0,
        normalization=False,
        epochs=10,
        device="cpu",
        patience=3,
        return_info=False
    ):
        """Train the head (and the top ``number_unfrozen`` backbone layers) on soft labels.

        Soft-label targets are re-weighted per class by ``smooth_inverse_frequency``
        and fit with ``KLDivLoss(reduction="batchmean")`` using AdamW, with separate
        learning rates and weight decays for backbone and head. When
        ``val_dataloader`` is given, the model is evaluated after every epoch.
        Training stops once validation major accuracy has not improved for
        ``patience`` consecutive epochs. Validation applies the same
        ``normalization`` as training.

        Returns:
            ``(classification_model, best_val_major_acc)``. With ``return_info=True``
            this pair is followed by the per-epoch lists train_losses,
            train_major_accuracies, train_top2_accuracies, val_major_accuracies
            and val_top2_accuracies. ``best_val_major_acc`` stays -1.0 without a
            validation loader. NOTE: the returned head holds the weights of the
            last epoch run, not necessarily those of the best validation epoch.
        """
        self.model.to(device)
        classification_model.to(device)

        self.freezing_layers(number_unfrozen)
        # Remember the feature transform so later evaluation matches training.
        self._normalization = normalization

        class_weights = self._class_weights(train_dataloader, device)

        criterion = torch.nn.KLDivLoss(reduction="batchmean")

        optimizer = torch.optim.AdamW([
            {"params": filter(lambda p: p.requires_grad, self.model.parameters()),
             "lr": backbone_lr, "weight_decay": backbone_weight_decay},
            {"params": filter(lambda p: p.requires_grad, classification_model.parameters()),
             "lr": classifier_lr, "weight_decay": classifier_weight_decay}
        ])

        schedulers = self._build_schedulers(optimizer, backbone_schedule, classifier_schedule, epochs)

        best_val_major_acc = -1.0
        epochs_no_improve = 0

        train_losses = []
        train_major_accuracies = []
        train_top2_accuracies = []
        val_major_accuracies = []
        val_top2_accuracies = []

        for epoch in range(epochs):
            avg_loss, train_major_acc, train_top2_acc = self._train_one_epoch(
                classification_model, train_dataloader, optimizer, criterion,
                class_weights, number_unfrozen, normalization, device,
            )

            # Schedules are per epoch.
            for scheduler in schedulers:
                scheduler.step()

            train_losses.append(avg_loss)
            train_major_accuracies.append(train_major_acc)
            train_top2_accuracies.append(train_top2_acc)

            print(f"Epoch {epoch+1}/{epochs} | Loss={avg_loss:.4f} | "
                  f"Major={train_major_acc:.4f} | Top-2={train_top2_acc:.4f}")

            if val_dataloader is not None:
                val_major_acc, val_top2_acc = self.evaluate_model(
                    classification_model, val_dataloader, device, normalization=normalization
                )

                val_major_accuracies.append(val_major_acc)
                val_top2_accuracies.append(val_top2_acc)

                if val_major_acc > best_val_major_acc:
                    best_val_major_acc = val_major_acc
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

                if epochs_no_improve >= patience:
                    break

        if return_info:
            return (
                classification_model,
                best_val_major_acc,
                train_losses,
                train_major_accuracies,
                train_top2_accuracies,
                val_major_accuracies,
                val_top2_accuracies
            )
        return classification_model, best_val_major_acc


In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np


# Train the chosen configuration end to end, then collect per-sample
# predictions for the analysis cells that follow.
classifier_model = EmotionClassifierMLP(
    emb_dim=1024,
    num_classes=num_labels,
    hidden_dim1=hidden_dim1,
    hidden_dim2=hidden_dim2,
    dropout=dropout,
)

wav2vec2_trainer = Wav2Vec2EmbeddingExtractor(model_name=model_name, cache_dir=cache_dir)

print("Starting model training...")

print_graphs = True

(
    trained_classifier_model,
    best_val_major_acc,
    train_losses,
    train_major_accuracies,
    train_top2_accuracies,
    val_major_accuracies,
    val_top2_accuracies,
) = wav2vec2_trainer.train_model(
    classification_model=classifier_model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    backbone_lr=backbone_learning_rate,
    classifier_lr=classifier_learning_rate,
    backbone_weight_decay=backbone_weight_decay,
    classifier_weight_decay=classifier_weight_decay,
    backbone_schedule=backbone_schedule,
    classifier_schedule=classifier_schedule,
    number_unfrozen=number_unfrozen,
    normalization=normalization,
    epochs=epochs,
    device=device,
    patience=epochs,  # patience == epochs, so early stopping never triggers
    return_info=True,
)

print("Model training complete.")
print(f"Best validation major accuracy: {best_val_major_acc:.4f}")

# Per-sample distributions for the confusion-matrix / ranking analysis below.
prediction_and_true = wav2vec2_trainer.get_predictions(trained_classifier_model, val_loader, device)
# NOTE(review): a full extra pass over the training audio; not used by any later cell shown here.
prediction_and_true_train = wav2vec2_trainer.get_predictions(trained_classifier_model, train_loader, device)

if print_graphs:
    epoch_axis = range(1, len(train_losses) + 1)
    # (title, y-label, training series, optional validation series)
    panels = (
        ('Loss over Epochs', 'Loss',
         (train_losses, 'Train Loss'), None),
        ('Major Accuracy over Epochs', 'Accuracy',
         (train_major_accuracies, 'Train Major Accuracy'),
         (val_major_accuracies, 'Val Major Accuracy')),
        ('Top-2 Accuracy over Epochs', 'Accuracy',
         (train_top2_accuracies, 'Train Top-2 Accuracy'),
         (val_top2_accuracies, 'Val Top-2 Accuracy')),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (title, ylabel, (train_values, train_label), val_series) in zip(axes, panels):
        ax.plot(epoch_axis, train_values, label=train_label)
        # Validation curves are empty when no validation loader was used.
        if val_series is not None and val_series[0]:
            ax.plot(epoch_axis, val_series[0], label=val_series[1])
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

true_distributions = prediction_and_true["true_distributions"]
pred_distributions = prediction_and_true["predicted_distributions"]

# Hard labels: the dominant emotion of each soft distribution.
true_labels = np.argmax(true_distributions, axis=1)
predicted_labels = np.argmax(pred_distributions, axis=1)


emotion_labels = [
    "angry", "happy", "sad", "neutral", "fear_surprise", "disgust"
]

# Pass `labels` explicitly so the matrix is always len(emotion_labels) square.
# Without it, sklearn drops any class that never occurs as a true or predicted
# label (e.g. a rare class). The 6 display labels would then no longer match
# the matrix size and plotting would fail.
cm = confusion_matrix(true_labels, predicted_labels, labels=np.arange(len(emotion_labels)))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=emotion_labels)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title("Confusion Matrix with Emotion Labels")
plt.show()


In [None]:
# Top-2 view: a sample counts as correct (placed on the diagonal) when its true
# dominant class is among the model's two most probable classes. Otherwise its
# top-1 prediction is used.
top2_preds = np.argsort(pred_distributions, axis=1)[:, -2:]
top2_hit = (top2_preds == true_labels[:, None]).any(axis=1)
pred_labels_for_cm = np.where(top2_hit, true_labels, np.argmax(pred_distributions, axis=1))

# Fix the label set so the matrix always matches the display labels, even if a
# class never occurs in true_labels or pred_labels_for_cm.
cm = confusion_matrix(true_labels, pred_labels_for_cm, labels=np.arange(len(emotion_labels)))


disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=emotion_labels)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title("Top-2 Accuracy Confusion Matrix")
plt.show()


In [None]:
# note idea grabbed from https://library.virginia.edu/data/articles/correlation-pearson-spearman-and-kendalls-tau
import numpy as np
from scipy.stats import kendalltau
import matplotlib.pyplot as plt

threshold = 0.2  # a true emotion share above this counts as "significant"
taus = []
undefined_taus = 0
print(f"{len(true_distributions)}")
for true_row, pred_row in zip(true_distributions, pred_distributions):
    significant = np.where(true_row > threshold)[0]

    # A rank correlation needs at least two significant emotions to compare.
    if len(significant) < 2:
        continue

    tau, _ = kendalltau(true_row[significant], pred_row[significant])
    # kendalltau returns NaN when a ranking is constant (e.g. a 50/50 true split).
    # Such samples carry no ranking information and would turn the mean/std into NaN.
    if np.isnan(tau):
        undefined_taus += 1
        continue
    taus.append(tau)

if undefined_taus:
    print(f"Skipped {undefined_taus} samples whose Kendall tau is undefined (tied ranking).")

if taus:
    plt.figure(figsize=(8,5))
    plt.hist(taus, bins=20, color='skyblue', edgecolor='black')
    plt.xlabel("Kendall Tau Rank Correlation")
    plt.ylabel("Number of samples")
    plt.title("Rank Correlation for Significant Emotions")
    plt.show()

    print(f"Average Kendall Tau: {np.mean(taus):.3f}")
    print(f"Std Kendall Tau: {np.std(taus):.3f}")
else:
    print("No samples with two or more significant emotions and a defined Kendall tau.")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Threshold for significant emotions
threshold = 0.20

num_emotions = true_distributions.shape[1]
emotion_labels = ["angry", "happy", "sad", "neutral", "fear_surprise", "disgust"]
assert num_emotions == len(emotion_labels), "emotion_labels must have one entry per distribution column"

# Boolean (n_samples, n_emotions) masks of which emotions clear the threshold.
true_significant = true_distributions > threshold
pred_significant = pred_distributions > threshold

# Exact match: the predicted set of significant emotions equals the true set.
exact_matches = (true_significant == pred_significant).all(axis=1)

# Per emotion: significant in truth but not predicted (missed), and the reverse (extra).
missed_counts = (true_significant & ~pred_significant).sum(axis=0).astype(float)
extra_counts = (pred_significant & ~true_significant).sum(axis=0).astype(float)

exact_match_percent = exact_matches.mean() * 100
print(f"Exact set match (threshold={threshold}): {exact_match_percent:.2f}%")

# Bars are overlaid: "extra" is drawn semi-transparent on top of "missed".
plt.figure(figsize=(10,5))
sns.barplot(x=emotion_labels, y=missed_counts, color="red", label="Missed")
sns.barplot(x=emotion_labels, y=extra_counts, color="blue", alpha=0.5, label="Extra")
plt.ylabel("Count")
plt.title("Missed vs Extra Predicted Emotions per Class")
plt.legend()
plt.show()
