In [1]:
import torch
import os
import numpy as np
import pandas as pd
import json

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Sampler

import tqdm.notebook as tqdm

from sklearn.metrics import confusion_matrix, accuracy_score

from transforms import *
from loss_functions import *
from datasets import *
from models import *
from torchvision.transforms import Compose
from clustering_metrics import *

In [2]:
device = torch.device('cpu')
use_gpu = False
if torch.cuda.is_available():
    use_gpu = True
    device = torch.device('cuda', 0)
    
print(type(device), device)

# may benefit if network size/input/output is stable
if use_gpu:
    torch.backends.cudnn.benchmark = True

<class 'torch.device'> cuda:0


In [3]:
def create_model(model_description):
        if 'name' not in model_description:
                return '[ERROR]: corrupted model description'

        if model_description['name'] == 'DSCNN':
                n_mels = model_description['n_mels']
                in_shape = (n_mels, 32)
                in_channels = model_description['in_channels']
                ds_cnn_number = model_description['ds_cnn_number']
                ds_cnn_size = model_description['ds_cnn_size']
                is_classifier = model_description['is_classifier']
                classes_number = 0 if not is_classifier else model_description['classes_number']

                return DSCNN(in_channels, in_shape, ds_cnn_number, ds_cnn_size, is_classifier, classes_number)

In [4]:
class ClassifierClosestKnown():
    def __init__(self, embedding_block, embedding_size, device):
        self.device = device
        self.embedding_block = embedding_block
        self.embedding_size = embedding_size
        self.freeze(self.embedding_block)

        self.class_ids = torch.empty((0, ), dtype=torch.int32)
        self.class_names = []
        self.class_embeddings = torch.empty((0, self.embedding_size), dtype=torch.float32).to(self.device)

        self.current_id = 0
    
    def freeze(self, block) -> None:
        for p in block.parameters():
            p.requires_grad = False
        block.eval()

    def add_class(self, input, label):
        embed = self.embedding_block(input.to(self.device))
        mean_embed = torch.mean(embed, dim=0).unsqueeze(0)

        self.class_embeddings = torch.cat((self.class_embeddings, mean_embed), dim=0)
        self.class_names.append(label)
        self.class_ids = torch.cat((self.class_ids, torch.tensor([self.current_id], dtype=torch.int32)))
        self.current_id += 1
    
    def get_classes(self):
        return self.class_names

    def classify(self, inputs):
        input_embed = self.embedding_block(inputs.to(self.device))
        distances = torch.cdist(input_embed, self.class_embeddings)  # Compute distances between input embeddings and class embeddings

        closest_ids = torch.argmin(distances, dim=1).cpu().tolist()  # Find the index of the closest known embedding for each input
        labels = [self.class_names[id] for id in closest_ids]  # Get the corresponding class labels

        return labels

In [5]:
def form_class_batch(dataset, class_indices, samples_number, class_idx):
    indexes = np.random.choice(class_indices[class_idx], samples_number, replace=False)

    batch = []

    for i in indexes:
        item = dataset.__getitem__(i)
        batch.append(item['input'])
    
    batch = torch.stack(batch, dim=0)
    return batch

def eval_experiment_closest_known(experiments_folder, experiment_name, epochs_to_test=[], fsl_examples=5, random_seed=42):
    experiment_folder = os.path.join(experiments_folder, experiment_name)

    experiment_settings_path = os.path.join(experiment_folder, "experiment_settings.json")
    experiment_stats_path = os.path.join(experiment_folder, "stats.json")

    experiment_stats = {}
    if os.path.isfile(experiment_stats_path):
        with open(experiment_stats_path, 'r') as fp:
            experiment_stats = json.load(fp)

    with open(experiment_settings_path, 'r') as fp:
        experiment_settings = json.load(fp)


    ''' create model '''
    experiment_settings['model']['is_classifier'] = False
    model = create_model(experiment_settings['model'])

    n_mels = experiment_settings['model']['n_mels']
    embedding_size = experiment_settings['model']['ds_cnn_size']

    ''' prepare datasets '''
    test_dataset_path = 'datasets/speech_commands/test'

    batch_size = 128

    feature_transform = Compose([ToSTFT(), ToMelSpectrogramFromSTFT(n_mels=n_mels), ToTensor('mel_spectrogram', 'input')])


    test_dataset = SpeechCommandsDataset(test_dataset_path,
                                    Compose([LoadAudio(),
                                            FixAudioLength(),
                                            feature_transform]))

    dl_test = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=16, prefetch_factor=2)
    test_indices = test_dataset.get_class_indices()
    test_classes = test_dataset.classes

    ''' Prepare train data '''
    np.random.seed(random_seed)
    n_samples = fsl_examples
    classifier_train_batches = []
    classifier_train_labels = []
        
    for class_name in test_classes:
        class_idx = test_dataset.get_idx_from_class(class_name)
        classifier_train_batches.append(form_class_batch(test_dataset, test_indices, n_samples, class_idx))
        classifier_train_labels.append(class_name)

    all_class_names = test_classes
    class_name_to_idx = {class_name: i for i, class_name in enumerate(all_class_names)} 

    ''' get epochs to test '''
    if not epochs_to_test:
        for key in experiment_stats['clustering_metrics']:
            epochs_to_test.append(experiment_stats['clustering_metrics'][key]['best_train_epoch'])
            epochs_to_test.append(experiment_stats['clustering_metrics'][key]['best_valid_epoch'])
            epochs_to_test.append(experiment_stats['clustering_metrics'][key]['best_test_epoch'])
        
        epochs_to_test.append(experiment_stats['loss']['best_train_epoch'])

        if ('best_valid_epoch' in experiment_stats['loss'].keys()) and experiment_stats['loss']['best_valid_epoch']:
            epochs_to_test.append(experiment_stats['loss']['best_valid_epoch'])

    epochs_to_test = set(epochs_to_test)
    print(f"epochs to test = {epochs_to_test}")

    ''' actual testing '''
    print(f"Epochs to test: {epochs_to_test}")

    for epoch in epochs_to_test:
        print(f"{epoch} ", end="")
        checkpoint_fname = os.path.join(experiment_folder, 'checkpoints', f'checkpoint_{epoch}')

        checkpoint = torch.load(checkpoint_fname)
        state_dict = checkpoint['state_dict']
        cur_epoch = checkpoint['epoch']
        remove_prefix = 'module.'
        state_dict = {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict, strict=False)
        model.to(device)
        model.eval()
        model.freeze()

        em_classifier = ClassifierClosestKnown(model, embedding_size, device)

        ''' Actual Training of classifier'''
        for i, batch in enumerate(classifier_train_batches):
            batch = torch.unsqueeze(batch, 1)
            em_classifier.add_class(batch, classifier_train_labels[i])

        ''' test classifier on test'''
        test_true_labels = []
        test_predictions = []
        with torch.no_grad():
            for batch in dl_test:
                input = torch.unsqueeze(batch['input'].to(device), 1)

                target = [test_dataset.get_class_from_idx(item.item()) for item in batch['target']]
                target_idx = [class_name_to_idx[t] for t in target]

                prediction = em_classifier.classify(input)
                prediction_idx = [class_name_to_idx[p] for p in prediction]

                test_predictions += prediction_idx
                test_true_labels += target_idx
        
        test_accuracy = accuracy_score(test_true_labels, test_predictions)

        fsl_key = f'closest_known_epoch{epoch}_shots{fsl_examples}'

        if 'fsl_only_test' not in experiment_stats.keys():
            experiment_stats['fsl_only_test'] = {}

        if fsl_key in experiment_stats['fsl_only_test'].keys():
            if not isinstance(experiment_stats['fsl_only_test'][fsl_key]['accuracy'], list):
                experiment_stats['fsl_only_test'][fsl_key]['accuracy'] = [experiment_stats['fsl_only_test'][fsl_key]['accuracy']]
            experiment_stats['fsl_only_test'][fsl_key]['accuracy'].append(test_accuracy)
        else:
            experiment_stats['fsl_only_test'][fsl_key] = {
                "epoch": epoch,
                "shots": fsl_examples,
                "accuracy": test_accuracy
            }
    print()
    with open(experiment_stats_path, "w") as fp:
        json.dump(experiment_stats, fp)

In [6]:
import os
experiments_folder = './experiments'

items = os.listdir(experiments_folder)
to_do_list = [item for item in items if os.path.isdir(os.path.join(experiments_folder, item))]
to_do_list = sorted(to_do_list)

In [7]:
new_to_list = []

for exp_name in to_do_list:
    if 'test' in exp_name:
        new_to_list.append(exp_name)

to_do_list = new_to_list
to_do_list = ['lifted_structured_test', 'triplet_br_test']
to_do_list

['lifted_structured_test', 'triplet_br_test']

In [8]:
np.random.seed(42)
seeds = np.random.randint(0, 1000000, 10)


for experiment_name in to_do_list:
    print(f"-----{experiment_name}-----")
    experiment_folder = os.path.join(experiments_folder, experiment_name)

    for i in tqdm.tqdm(range(10)):
        print(f"seed = {seeds[i]}:")
        eval_experiment_closest_known(experiments_folder, experiment_name, epochs_to_test=[], fsl_examples=5, random_seed=seeds[i])

-----lifted_structured_test-----


  0%|          | 0/10 [00:00<?, ?it/s]

seed = 121958:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 671155:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 131932:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 365838:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 259178:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 644167:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 
seed = 110268:
epochs to test = {198, 199, 40, 105, 185, 60, 190}
Epochs to test: {198, 199, 40, 105, 185, 60, 190}
198 199 40 105 185 60 190 

  0%|          | 0/10 [00:00<?, ?it/s]

seed = 121958:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 671155:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 131932:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 365838:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 259178:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 644167:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 
seed = 110268:
epochs to test = {35, 195, 175, 180, 155, 60, 190}
Epochs to test: {35, 195, 175, 180, 155, 60, 190}
35 195 175 180 155 60 190 