# Test HuBERT few-shot ability

In [None]:
import os, sys
import random
# Change the working directory to the project root so that the relative
# data paths used in this notebook (e.g. "data/GTZAN") resolve against it.
# NOTE(review): the project-local `benchmark.*` imports presumably rely on
# this as well -- confirm.
# NOTE(review): this is a machine-specific absolute path; adjust it when
# running on another machine.
os.chdir("/home/yrb/code/MusicAudioPretrain/")

import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, confusion_matrix
from tqdm import tqdm
import librosa, librosa.display
import scipy, matplotlib.pyplot as plt

from benchmark.GTZAN.GTZAN_dataset import FeatureDataset as GTZAN_FeatureDataset
from benchmark.GS.GS_dataset import FeatureDataset as GS_FeatureDataset

# Seed Python's RNG so the random few-shot sampling below is reproducible.
random.seed(1234)

# Location of the pre-extracted features and of the GTZAN metadata,
# both relative to the project root.
feature_dir = "data/GTZAN/hubert_features/HF_model_HuBERT_base_MPD_train_1Kh_valid_300h_iter1_250k_vanilla_model_ncluster_500_feature_layer_all_reduce_mean"
metadata_dir = "data/GTZAN"
layer = 'all'

# Build the three splits with identical settings; the generator is consumed
# in order, so the datasets are constructed as train, valid, test.
train_dataset, valid_dataset, test_dataset = (
    GTZAN_FeatureDataset(feature_dir, metadata_dir, split, layer, return_audio_path=True)
    for split in ('train', 'valid', 'test')
)

In [None]:
def normalize(x):
    """Return ``x`` scaled to unit L2 norm along its last dimension.

    Args:
        x: A ``torch.Tensor``, or anything ``torch.tensor`` accepts (NumPy
            array, nested list, ...). A 1-D input is treated as a single
            vector and gains a leading dimension of size 1.

    Returns:
        A new tensor holding the normalized vector(s). The input is never
        modified. All-zero vectors are returned as zeros instead of NaN.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if x.ndim == 1:
        x = x.unsqueeze(0)
    # Integer inputs cannot be normalized as-is (norm/division need a
    # floating dtype), so promote them.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.float()
    # Clamp the norm away from zero so a zero vector does not yield NaN
    # (same guard torch.nn.functional.normalize applies with its eps).
    norm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    # Out-of-place division: ``unsqueeze`` returns a view, so an in-place
    # ``/=`` here would silently rescale the caller's tensor.
    return x / norm

In [None]:
# Group the training features by class: class_data[label] -> list of features.
dataset = train_dataset
class_data = dict()

for idx in range(len(dataset)):
    feature, label, audio_path = dataset[idx]
    # A feature whose first dimension is not 768 is flattened into a single
    # vector (presumably a stack of per-layer embeddings -- confirm).
    flat_feature = feature if feature.shape[0] == 768 else feature.reshape(-1)
    class_data.setdefault(label, []).append(flat_feature)

In [None]:
def get_class_centroids(class_data, num_shots):
    """Compute one centroid per class from at most ``num_shots`` examples.

    Args:
        class_data: Mapping from class label to a list of feature vectors
            of equal length.
        num_shots: Maximum number of examples used per class. Classes with
            more examples are subsampled without replacement using the
            global ``random`` module; smaller classes are used in full.

    Returns:
        A dict mapping each label to a tensor with a leading dimension of
        size 1: the mean of the L2-normalized examples selected for that
        class. The centroid itself is not re-normalized.
    """

    def _centroid(samples):
        stacked = torch.tensor(np.array(samples))
        available = stacked.shape[0]
        if available > num_shots:
            # Draw the few-shot support set for this class.
            chosen = random.sample(range(available), num_shots)
            stacked = stacked[chosen]
        unit_vectors = stacked / stacked.norm(dim=-1, keepdim=True)
        return unit_vectors.mean(dim=0, keepdim=True)

    return {label: _centroid(samples) for label, samples in class_data.items()}


all_acc = []
repeat_times = 100
num_shots = 10

# Fix the class order once. Row j of the centroid matrix belongs to
# class_ids[j], so predictions are mapped back through this list instead of
# assuming the labels are exactly 0..9.
class_ids = sorted(class_data)

# The test set is identical in every repetition, so read and normalize it
# once here instead of once per repetition.
labels, paths, test_features = [], [], []
for feature, label, audio_path in test_dataset:
    if feature.shape[0] != 768:  # cat all layers to one vector
        feature = feature.reshape(-1)
    test_features.append(normalize(feature))
    labels.append(label)
    paths.append(audio_path)
# normalize() returns one row per sample; stack them into a single matrix.
test_features = torch.cat(test_features)

# Repeat the experiment many times: each repetition samples a different
# support set, which leads to different centroids and different accuracy.
results = []
for _ in tqdm(range(repeat_times)):
    # compute class centroids
    class_centroids = get_class_centroids(class_data, num_shots)
    class_centroids = torch.cat([class_centroids[c] for c in class_ids])
    class_centroids = normalize(class_centroids)
    # Both sides are unit-norm, so the product is the cosine similarity.
    # Softmax is monotonic, so the argmax of the raw similarities selects
    # the same class as the argmax of the softmax probabilities.
    similarity = test_features @ class_centroids.T
    results = [class_ids[j] for j in similarity.argmax(dim=-1).tolist()]
    all_acc.append(accuracy_score(labels, results))
all_acc = np.array(all_acc)
print(f"Accuracy: {all_acc.mean():.4f} +- {all_acc.std():.4f}")

# Bad-case analysis: list the misclassified test clips

In [None]:
# Print every test clip whose predicted class differs from its label,
# using the predictions left over from the final repetition.
for audio_path, pred, label in zip(paths, results, labels):
    if pred != label:
        print(f"{audio_path}: pred = {test_dataset.id2class[pred]}, label = {test_dataset.id2class[label]}")