# Calculate mean cluster radius and inter-cluster distances

In [2]:
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix, accuracy_score

import numpy as np
import torch

import librosa
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Sampler

import os
import tqdm.notebook as tqdm

import json

import time
import datetime
import math

from transforms import *
from loss_functions import *
from datasets import *
from models import *
from torchvision.transforms import Compose
from clustering_metrics import *

train_dataset_path = 'datasets/speech_commands/train'
valid_dataset_path = 'datasets/speech_commands/validation'
test_dataset_path = 'datasets/speech_commands/test'

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda', 0) if use_gpu else torch.device('cpu')

def create_model(model_description):
    """Build a model instance from a model-description mapping.

    Args:
        model_description: mapping with at least a ``'name'`` key. For
            ``'DSCNN'`` it must also contain ``'n_mels'``, ``'in_channels'``,
            ``'ds_cnn_number'``, ``'ds_cnn_size'`` and ``'is_classifier'``;
            ``'classes_number'`` is only read when ``is_classifier`` is truthy.

    Returns:
        The constructed model (currently only ``DSCNN`` is supported).

    Raises:
        ValueError: if ``'name'`` is missing or names an unsupported model.
        KeyError: if a field required by the selected model is missing.
    """
    if 'name' not in model_description:
        # Raise instead of returning an error string: callers immediately use
        # the return value as a model (e.g. ``model.to(device)``).
        raise ValueError('[ERROR]: corrupted model description')

    name = model_description['name']
    if name == 'DSCNN':
        n_mels = model_description['n_mels']
        # The second input dimension is fixed at 32 here (presumably the
        # number of spectrogram frames -- TODO confirm against DSCNN).
        in_shape = (n_mels, 32)
        in_channels = model_description['in_channels']
        ds_cnn_number = model_description['ds_cnn_number']
        ds_cnn_size = model_description['ds_cnn_size']
        is_classifier = model_description['is_classifier']
        classes_number = model_description['classes_number'] if is_classifier else 0

        return DSCNN(in_channels, in_shape, ds_cnn_number, ds_cnn_size, is_classifier, classes_number)

    # Previously fell through and returned None for unknown names.
    raise ValueError(f'[ERROR]: unknown model name {name!r}')

In [13]:
from sklearn.metrics import silhouette_score
from sklearn.metrics import calinski_harabasz_score
from sklearn.metrics import davies_bouldin_score
from clustering_metrics import *

def get_outs(model, dl, device, dataset):
    """Run ``model`` over every batch of ``dl`` and collect outputs with class names.

    Gradients are disabled for the forward passes. This function does not
    switch the model to eval mode; callers do that themselves.

    Args:
        model: callable applied to each batch's ``'input'`` tensor after a
            size-1 axis is inserted at dim 1.
        dl: iterable of batches; each batch is a mapping with ``'input'`` and
            ``'target'`` tensors.
        device: device the inputs are moved to before the forward pass.
        dataset: object whose ``get_class_from_idx(idx)`` maps each target
            index to its class label.

    Returns:
        ``(all_pred, all_text_labels)``: per-sample model outputs converted to
        nested Python lists, and the matching class labels, in batch order.
    """
    all_pred = []
    all_labels = []

    with torch.no_grad():
        for batch in dl:
            images = torch.unsqueeze(batch['input'].to(device), 1)
            net_out = model(images)

            all_pred += net_out.tolist()
            # Targets are only turned into Python ints, so there is no need
            # to copy them onto the model's device first.
            all_labels += batch['target'].tolist()

    all_text_labels = [dataset.get_class_from_idx(label) for label in all_labels]
    return all_pred, all_text_labels

def extract_number_from_filename(filename):
    """Return the integer found after the first underscore of ``filename``.

    The second ``'_'``-separated field is parsed with ``int``; if there is no
    such field, or it is not an integer, ``float('inf')`` is returned so that
    such names sort after every numbered one.
    """
    fields = filename.split('_')
    if len(fields) < 2:
        return float('inf')
    try:
        return int(fields[1])
    except ValueError:
        return float('inf')

def compute_mean_embeddings(all_embeds, all_labels):
    """Compute each label's mean embedding and mean distance to that mean.

    Args:
        all_embeds: 2-D array-like with one embedding per row (NumPy array or
            nested list).
        all_labels: 1-D array-like of labels aligned with the rows of
            ``all_embeds`` (NumPy array or list).

    Returns:
        ``(mean_embeds, cluster_radius)``: dicts keyed by ``str(label)``.
        ``mean_embeds`` maps to the mean embedding as a Python list, and
        ``cluster_radius`` maps to the mean Euclidean distance of that label's
        embeddings from the mean, as a float. Both are empty for empty input.
    """
    # Normalise inputs so ``all_labels == label`` yields an element-wise mask
    # even when plain lists are passed.
    all_embeds = np.asarray(all_embeds)
    all_labels = np.asarray(all_labels)

    mean_embeds = {}
    cluster_radius = {}

    for label in np.unique(all_labels):
        embeds_for_label = all_embeds[all_labels == label]

        mean_embed = np.mean(embeds_for_label, axis=0).tolist()
        mean_distance_to_mean_embed = np.mean(np.linalg.norm(embeds_for_label - mean_embed, axis=1))

        mean_embeds[str(label)] = mean_embed
        cluster_radius[str(label)] = float(mean_distance_to_mean_embed)

    return mean_embeds, cluster_radius

def _make_speech_commands_dataset(dataset_path, feature_transform):
    """Build a SpeechCommandsDataset that loads audio, fixes its length, then applies ``feature_transform``."""
    return SpeechCommandsDataset(dataset_path,
                                 Compose([LoadAudio(),
                                          FixAudioLength(),
                                          feature_transform]))


def _load_checkpoint_state_dict(checkpoint_fname, device):
    """Load ``checkpoint['state_dict']`` onto ``device`` with any leading ``'module.'`` stripped from keys."""
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only run.
    checkpoint = torch.load(checkpoint_fname, map_location=device)
    remove_prefix = 'module.'
    return {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v
            for k, v in checkpoint['state_dict'].items()}


def experiment_clusters_prototypes(experiment_folder, batch_size, device):
    """Compute per-class prototypes and radii for every evaluated epoch of an experiment.

    Reads ``experiment_settings.json`` and ``stats.json`` from
    ``experiment_folder``. It builds the model with ``is_classifier`` forced to
    False, which mutates only the in-memory settings. Then, for each epoch in
    ``stats['clustering_metrics']['fc']['epoch']``, it:

    * loads ``checkpoints/checkpoint_{epoch}``;
    * embeds the train, validation and test splits;
    * computes each class's mean embedding and mean distance to it.

    Args:
        experiment_folder: directory of one experiment.
        batch_size: batch size for the evaluation DataLoaders.
        device: torch device used for the model and the checkpoint tensors.

    Returns:
        dict with aligned lists ``"epoch"``, ``"mean_embed"`` and
        ``"cluster_radius"``. The same dict is also written to
        ``embeds.json`` inside ``experiment_folder``.
    """
    train_dataset_path = 'datasets/speech_commands/train'
    valid_dataset_path = 'datasets/speech_commands/validation'
    test_dataset_path = 'datasets/speech_commands/test'

    experiment_settings_path = os.path.join(experiment_folder, "experiment_settings.json")
    stats_path = os.path.join(experiment_folder, "stats.json")

    with open(experiment_settings_path, 'r') as fp:
        experiment_settings = json.load(fp)

    with open(stats_path, 'r') as fp:
        stats = json.load(fp)

    experiment_settings['model']['is_classifier'] = False
    model = create_model(experiment_settings['model'])
    model.to(device)
    model.eval()

    n_mels = experiment_settings['model']['n_mels']
    feature_transform = Compose([ToSTFT(), ToMelSpectrogramFromSTFT(n_mels=n_mels), ToTensor('mel_spectrogram', 'input')])

    datasets = [
        _make_speech_commands_dataset(train_dataset_path, feature_transform),
        _make_speech_commands_dataset(valid_dataset_path, feature_transform),
        _make_speech_commands_dataset(test_dataset_path, feature_transform),
    ]
    # Per-class means/radii do not depend on sample order, so shuffling only
    # made the float summation order (and thus the output) non-reproducible.
    loaders = [
        (DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, prefetch_factor=2), dataset)
        for dataset in datasets
    ]

    epochs = stats['clustering_metrics']['fc']['epoch']
    checkpoints_folder = os.path.join(experiment_folder, 'checkpoints')

    result = {
        "epoch": [],
        "mean_embed": [],
        "cluster_radius": []
    }

    for epoch in tqdm.tqdm(epochs):
        checkpoint_fname = os.path.join(checkpoints_folder, f"checkpoint_{epoch}")
        state_dict = _load_checkpoint_state_dict(checkpoint_fname, device)

        # strict=False ignores missing/unexpected keys; presumably this is for
        # classifier-head weights absent from the is_classifier=False model
        # -- TODO confirm, as it would also hide genuine mismatches.
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        all_embeds = []
        all_labels = []
        for dl, dataset in loaders:
            embeds, labels = get_outs(model, dl, device, dataset)
            all_embeds += embeds
            all_labels += labels

        mean_embeds, cluster_radius = compute_mean_embeddings(np.array(all_embeds), np.array(all_labels))

        result["epoch"].append(epoch)
        result["mean_embed"].append(mean_embeds)
        result["cluster_radius"].append(cluster_radius)

    # Persist the per-epoch prototypes and radii next to the experiment files.
    embeds_fpath = os.path.join(experiment_folder, 'embeds.json')
    with open(embeds_fpath, "w") as fp:
        json.dump(result, fp)

    return result

In [14]:
experiments_dir = './experiments'

# Every sub-directory of experiments_dir is one experiment to process, in name order.
items = os.listdir(experiments_dir)
to_do_list = sorted(
    name for name in items
    if os.path.isdir(os.path.join(experiments_dir, name))
)
to_do_list

['base_01',
 'base_test',
 'lifted_structured_01',
 'lifted_structured_02',
 'lifted_structured_03',
 'lifted_structured_test',
 'npair_01',
 'npair_02',
 'npair_03',
 'npair_test',
 'silhouette_01',
 'silhouette_margin_01',
 'triplet_br_01',
 'triplet_br_02',
 'triplet_br_03',
 'triplet_br_04',
 'triplet_br_05',
 'triplet_br_06',
 'triplet_br_07',
 'triplet_br_08',
 'triplet_br_09',
 'triplet_br_test']

In [15]:
# Extract prototypes/radii for each experiment folder, reporting wall-clock time per run.
for experiment_name in to_do_list:
    experiment_folder = os.path.join(experiments_dir, experiment_name)

    start_time = datetime.datetime.now()
    print(f"Start {experiment_name} -- {start_time}")

    result = experiment_clusters_prototypes(experiment_folder, 64, device)

    end_time = datetime.datetime.now()
    print(f"Finished {experiment_name} -- {end_time} -- {end_time - start_time}")
    

Start base_01 -- 2023-11-25 23:17:26.892422


  0%|          | 0/1 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2708269510><function _MultiProcessingDataLoaderIter.__del__ at 0x7f2708269510>

Traceback (most recent call last):
Exception ignored in: Traceback (most recent call last):
  File "/home/basil/Desktop/msu_4_coursework/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f2708269510>    
  File "/home/basil/Desktop/msu_4_coursework/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Traceback (most recent call last):
self._shutdown_workers()    
self._shutdown_workers()
  File "/home/basil/Desktop/msu_4_coursework/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/home/basil/Desktop/msu_4_coursework/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
  File "/home/bas

Finished base_01 -- 2023-11-25 23:18:00.992536 -- 0:00:34.100114
Start base_test -- 2023-11-25 23:18:00.992651


  0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: 