## Get aligned MusicGen predictions

In [13]:
%cd /home/DAVIDSON/dutuller/Workspace/DRI1/MusicGen/

from sklearn.metrics import f1_score, recall_score, precision_score, confusion_matrix
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import yaml
from embeddings.h5_processor import H5DataProcessor, DatasetConfig, ProcessedDataset
import pandas as pd
import re

/home/DAVIDSON/dutuller/Workspace/DRI1/MusicGen


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


Infrastructure for loading and splitting the embedding data from storage.

In [14]:
with open("universal_music/NHS_full.yaml", 'r') as f:
    config = yaml.safe_load(f)


def combine_datasets(parts, name="combined"):
    """Concatenate a list of ProcessedDataset splits into a single ProcessedDataset.

    Embeddings are stacked row-wise with np.vstack; labels and filenames are
    concatenated in the same order as the parts, and num_samples is summed.

    Raises:
        ValueError: if ``parts`` is empty (np.vstack cannot stack nothing).
    """
    if not parts:
        raise ValueError("combine_datasets() needs at least one dataset split")
    return ProcessedDataset(
        embeddings=np.vstack([part.embeddings for part in parts]),
        labels=[label for part in parts for label in part.labels],
        filenames=[fname for part in parts for fname in part.filenames],
        name=name,
        num_samples=sum(part.num_samples for part in parts),
    )


# Process each dataset listed in the YAML with H5DataProcessor and split it.
processor = H5DataProcessor()
all_train_data = []
all_test_data = []
class_names = set()

for dataset_config in config['datasets']:
    # Build the config object once; it is needed for both the path lookup and processing.
    ds_config = DatasetConfig(**dataset_config)
    dataset = processor.process_h5_file(processor.get_embedding_path(ds_config), ds_config)

    # 80/20 split per dataset with a fixed random_seed.
    train_split, test_split = processor.get_train_test_split(
        dataset,
        test_ratio=0.2,
        random_seed=42,
    )

    all_train_data.append(train_split)
    all_test_data.append(test_split)
    class_names.update(dataset.labels)

# Merge the per-dataset splits into one train set and one test set.
train_data = combine_datasets(all_train_data)
test_data = combine_datasets(all_test_data)

# Create and configure the model (it is fitted in the next cell).
model = LogisticRegression(max_iter=1000)

Train the classifier on the full songs from the train set.

In [15]:
# Integer-encode the string class labels. Sorting the union of train and test
# labels gives one deterministic label -> index mapping shared by both splits.
unique_labels = sorted(set(train_data.labels) | set(test_data.labels))
label_to_idx = dict(zip(unique_labels, range(len(unique_labels))))

X_train = train_data.embeddings
y_train = np.array([label_to_idx[lbl] for lbl in train_data.labels])

# Standardize features (fitted on the train split only), then fit the classifier.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
model.fit(X_train_scaled, y_train);

Identify which song ids are present in the test set, so that we can find them among the sample audio.

In [16]:
# Song id for every test clip, parsed from names like "Discography-<id>_<n>.wav".
# The dot is escaped so only a literal ".wav" matches.
_discography_re = re.compile(r"Discography-(\d+)_\d+\.wav")


def _test_song_id(filename):
    """Return the integer song id in a test filename; raise ValueError if absent."""
    match = _discography_re.search(filename)
    if match is None:
        raise ValueError(f"Cannot parse song id from test filename: {filename!r}")
    return int(match.group(1))


test_song_ids = [_test_song_id(filename) for filename in test_data.filenames]
test_unique_song_ids = np.unique(test_song_ids)

Now load the audio samples (14 s clips) that the human survey participants actually listened to.

In [17]:
# Load the embeddings of the sample clips (14s) that survey participants heard.
with open("universal_music/NHS_samples.yaml", 'r') as f:
    config = yaml.safe_load(f)

# Only the first dataset entry in the samples YAML is used.
dataset_config = DatasetConfig(**config['datasets'][0])
embedding_filename = processor.get_embedding_path(dataset_config)
dataset = processor.process_h5_file(embedding_filename, dataset_config)

# Keep all sample clips here; the next cell selects those from test-set songs.
sample_embeddings, sample_filenames = dataset.embeddings, dataset.filenames

Select the sample clips whose songs are in the test set, and scale their embeddings with the training scaler for model evaluation.

In [18]:
# Keep only the sample clips whose song id is in the test split, preserving the
# order of `sample_filenames`; predictions below follow this order.
test_id_set = set(test_unique_song_ids.tolist())  # O(1) membership instead of scanning the array

sample_test_embeddings = []
sample_song_ids = []
sample_test_labels = []  # NOTE(review): never populated or read in this notebook

for i, filename in enumerate(sample_filenames):
    match = re.search(r"NAIV-(\d+)\.wav", filename)
    if match is None:
        raise ValueError(f"Cannot parse song id from sample filename: {filename!r}")
    # `song_id`, not `id`: a module-level `id` would shadow the builtin in every later cell.
    song_id = int(match.group(1))
    if song_id in test_id_set:
        sample_test_embeddings.append(sample_embeddings[i])
        sample_song_ids.append(song_id)

if not sample_test_embeddings:
    raise ValueError("No sample clips match the test-set song ids")

X_test_sample = np.array(sample_test_embeddings)
# Reuse the scaler fitted on the training embeddings.
X_test_scaled = scaler.transform(X_test_sample)

Get the ground-truth function label of each full test recording, then aggregate so that there is one label per song. This step will not be necessary when comparing directly to human ratings.

In [19]:
def find_id(filename):
    """Extract the integer song id from a filename like "Discography-<id>_<n>.wav".

    Raises:
        ValueError: if the filename does not contain that pattern.
    """
    match = re.search(r"Discography-(\d+)_\d+\.wav", filename)
    if match is None:
        raise ValueError(f"Cannot parse song id from filename: {filename!r}")
    return int(match.group(1))


test_data_info = pd.DataFrame({'labels': test_data.labels, 'filenames': test_data.filenames})
test_data_info['id'] = test_data_info.filenames.apply(find_id)

# .first() below silently keeps one label per song, so check they agree.
assert (test_data_info.groupby('id')['labels'].nunique() == 1).all(), \
    "some test songs have clips with different labels"
# One row per song id; the index is the song id (sorted ascending by groupby).
labels = test_data_info.groupby('id').first()

# Align ground truth with sample_song_ids (the order of X_test_scaled / y_pred_sample),
# not with the sorted groupby index, so y_true_sample[k] and y_pred_sample[k] are the same song.
y_true_sample = np.array([label_to_idx[label] for label in labels.loc[sample_song_ids, 'labels']])

In [20]:
# Predicted class index for each selected sample clip (rows follow sample_song_ids).
y_pred_sample = model.predict(X_test_scaled)
# Song id -> predicted class index. The name is kept as-is (including its
# spelling) because later cells may reference it.
smaple_pred_lookup = {song_id: pred for song_id, pred in zip(sample_song_ids, y_pred_sample)}

Load the human ratings and keep only the web-study responses for songs in the sample test set.

In [None]:
df = pd.read_csv('universal_music/FFfull.csv', low_memory=False)

# Keep web-study responses for songs in the sample test set. Combining both
# conditions in a single mask and taking .copy() yields an independent frame,
# so adding columns later (e.g. web_df['generous'] = ...) does not trigger
# SettingWithCopyWarning, as the two-step filter did.
is_web = df['study'] == 'web'
in_sample = df['song'].isin(sample_song_ids)
web_df = df.loc[is_web & in_sample].copy()

In [22]:
# Tie-breaking policies for turning per-row scores/ratings into a single class.
def policy_argmax_per_row(arr, true_labels, policy='random', rng=None):
    """Row-wise argmax with a configurable tie-breaking policy.

    Policies:
    - 'random': pick uniformly at random among the tied maxima.
    - 'generous': pick the true label whenever it is among the tied maxima
      (best case); otherwise break ties at random.
    - 'strict': pick a tied maximum other than the true label whenever one
      exists (worst case); the true label is chosen only when it is the
      unique maximum.

    Args:
        arr: array-like of shape (n_samples, n_classes).
        true_labels: array-like of integer class indices, shape (n_samples,).
            Not used by the 'random' policy.
        policy: 'random', 'generous', or 'strict'.
        rng: optional ``numpy.random.Generator``. When None, the global
            ``np.random`` state is used (so results follow ``np.random.seed``).

    Returns:
        np.ndarray of shape (n_samples,) with the chosen column index per row.

    Raises:
        ValueError: for an unknown policy, or when ``true_labels`` does not
            have exactly one entry per row of ``arr``.

    Note: a row containing NaN has no element equal to its max, so its mask is
    all False and every policy returns index 0 for that row.
    """
    arr = np.asarray(arr)
    draw = np.random.random if rng is None else rng.random

    # Boolean mask of the maximal entries in each row (ties included).
    mask = arr == np.max(arr, axis=1, keepdims=True)

    if policy == 'random':
        # Random positive weights on the maxima, zero elsewhere -> uniform tie-break.
        return np.argmax(draw(arr.shape) * mask, axis=1)

    if policy not in ('generous', 'strict'):
        raise ValueError(f"Unknown policy {policy!r}; expected 'random', 'generous' or 'strict'")

    true_labels = np.asarray(true_labels, dtype=int)
    if true_labels.shape != (arr.shape[0],):
        raise ValueError(
            f"true_labels must have shape ({arr.shape[0]},), got {true_labels.shape}"
        )

    row_indices = np.arange(arr.shape[0])
    true_label_is_max = mask[row_indices, true_labels]

    if policy == 'generous':
        result = np.zeros(arr.shape[0], dtype=int)
        # Rows where the true label is tied for max: take it.
        result[true_label_is_max] = true_labels[true_label_is_max]
        # Remaining rows: random tie-break among the maxima.
        non_match_rows = ~true_label_is_max
        if np.any(non_match_rows):
            random_subset = draw(arr[non_match_rows].shape) * mask[non_match_rows]
            result[non_match_rows] = np.argmax(random_subset, axis=1)
        return result

    # 'strict': maxima score in [1, 1.1) (the random part breaks ties among them).
    # The true label, when it is a maximum, scores a fixed 0.5: any other maximum
    # beats it, yet it still beats the non-maxima (0) when it is the unique max.
    selection_values = np.where(mask, 1.0 + 0.1 * draw(arr.shape), 0.0)
    selection_values[row_indices, true_labels] = np.where(true_label_is_max, 0.5, 0.0)
    return np.argmax(selection_values, axis=1)

In [23]:
# Two options for attaching ground truth to the human ratings:
# 1. align the numpy 'ratings' files with song ids
# 2. compute the ground truth here and run policy_argmax_per_row in this notebook
#    - this still requires the ground truth to be attached to each row of web_df
# Per-song ground truth ordered like sample_song_ids (and hence like y_pred_sample),
# rather than by the sorted song-id index of `labels`.
y_true_sample = np.array([label_to_idx[label] for label in labels.loc[sample_song_ids, 'labels']])

Option 2: compute the ground truth here and apply the tie-breaking policy in this notebook.

In [None]:
# Human ratings, one column per song function; rows follow web_df.
# NOTE(review): argmax indices are compared with label_to_idx (sorted label
# names), so this column order must match that mapping — confirm.
arr = web_df[["danc", "heal", "baby", "love"]].to_numpy()

# Ground truth per rating row: map each row's song id to its song's label index.
# y_true_sample holds one entry per song, not per row, so it cannot be passed here.
song_to_true_idx = labels['labels'].map(label_to_idx)
row_true_labels = web_df['song'].map(song_to_true_idx)
assert row_true_labels.notna().all(), "some rated songs have no ground-truth label"

web_df['generous'] = policy_argmax_per_row(arr, row_true_labels.to_numpy(dtype=int), policy='generous')

In [None]:
# Broadcast each song's model prediction onto every human-rating row of that song.
# (The previous version referenced undefined names many_df / id_column / id_to_pred.)
broadcasted_preds = web_df['song'].map(smaple_pred_lookup)