In [None]:
import torch

from tqdm.auto import tqdm
from dataloader import Dataloader
from whisper_wrapper import WhisperWrapper
from feature_density_estimator import FeatureDensityEstimator

In [None]:
# Load the fine-tune and test partitions for a single id.
# NOTE(review): `id` shadows the Python builtin; the name is kept because a
# later cell reads it to build the checkpoint name.
id = 1
partition_bounds = (id, id + 1)
# Only the second returned element is kept for the fine-tune split.
_, finetune_audios = Dataloader.load_uq_partitions("fine-tune", *partition_bounds)
test_ds, test_audios = Dataloader.load_uq_partitions("test", *partition_bounds)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint name is derived from the partition id chosen in the data cell.
model_name = f"danrdz/whisper-finetuned-es-modelo_{id:02d}"
model_wrapper = WhisperWrapper(model_name, device=dev)
fde = FeatureDensityEstimator(model_wrapper)

# Number of layers to keep when ranking layers.
top_k = 1


def aggregation_fn(x):
    """Concatenate a list of tensors along dim 1 and drop all size-1 dims."""
    return torch.cat(x, dim=1).squeeze()


def reduction_fn(x):
    """Flatten a tensor to one dimension."""
    return torch.flatten(x)


# Keyword arguments forwarded to every model call: only hidden states are requested.
gen_kwargs = dict(
    return_dict_in_generate=True,
    output_scores=False,
    output_hidden_states=True,
    output_attentions=False,
)

In [None]:
@torch.no_grad()
def dataset_embedding_extraction(audios: list,
                                 aggregation_fn,
                                 gen_kwargs: dict,
                                 model=None) -> dict:
    """Run the model over ``audios`` and gather per-layer hidden states.

    Args:
        audios: Inputs passed one at a time to the model.
        aggregation_fn: Callable that receives, for one layer, the list of
            hidden-state tensors (moved to CPU) collected over all audios and
            returns the aggregated embedding for that layer.
        gen_kwargs: Keyword arguments forwarded to every model call.
        model: Optional callable used instead of the notebook-level
            ``model_wrapper``. Its output must expose
            ``decoder_hidden_states`` (indexed ``[token][layer]``) and
            ``encoder_hidden_states`` (indexed ``[layer]``); either attribute
            may be ``None``.

    Returns:
        Dict with up to two keys, ``"decoder_hidden_states"`` and
        ``"encoder_hidden_states"``, each mapping a layer index to the
        aggregated embedding of that layer. A key is omitted when the model
        did not return that kind of hidden state, and the dict is empty when
        ``audios`` is empty.
    """
    if model is None:
        # Default to the wrapper created in the notebook's setup cell.
        model = model_wrapper

    decoder_outputs = []
    encoder_outputs = []

    # Extract the raw hidden states for all audios
    for audio in tqdm(audios, leave=False, desc="Extracting embeddings", total=len(audios)):
        output = model(audio, **gen_kwargs)
        decoder_outputs.append(output.decoder_hidden_states)
        encoder_outputs.append(output.encoder_hidden_states)

    embeddings = {}
    if not decoder_outputs:
        # No audios were given, so there is nothing to aggregate.
        return embeddings

    # The first audio decides whether a kind of hidden state is available.
    first_decoder = decoder_outputs[0]
    if first_decoder is not None and len(first_decoder) > 0:
        # Per layer, gather the hidden state of every token of every audio.
        embeddings["decoder_hidden_states"] = {
            layer: aggregation_fn([
                token_states[layer].cpu()
                for audio_states in decoder_outputs
                for token_states in audio_states
            ])
            for layer in range(len(first_decoder[0]))
        }

    first_encoder = encoder_outputs[0]
    if first_encoder is not None:
        # Per layer, gather the encoder hidden state of every audio.
        embeddings["encoder_hidden_states"] = {
            layer: aggregation_fn([
                audio_states[layer].cpu() for audio_states in encoder_outputs
            ])
            for layer in range(len(first_encoder))
        }

    return embeddings

In [None]:
# Extract per-layer embeddings from the first 20 items of finetune_audios[0].
embeddings = dataset_embedding_extraction(
    audios=finetune_audios[0][:20],
    aggregation_fn=aggregation_fn,
    gen_kwargs=gen_kwargs,
)

In [None]:
# Cosine similarity computed along dim 1 of the two inputs.
CosineSimilarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)


@torch.no_grad()
def block_influence_layer_selector(embeddings: dict, top_k: int = 1):
    """Rank layers by Block Influence (BI) and keep the ``top_k`` highest.

    The BI of layer ``l`` is the mean of ``1 - cosine_similarity`` between the
    embeddings of layer ``l`` and layer ``l + 1``, so the last layer of each
    hidden-state group has no score.

    Args:
        embeddings: Maps a hidden-state name to a dict ``{layer_index: tensor}``
            whose keys are the consecutive integers ``0 .. n_layers - 1``.
        top_k: Maximum number of layers to keep per hidden-state name.

    Returns:
        Dict mapping each hidden-state name to ``{layer_index: bi_score}`` for
        the ``top_k`` layers with the highest BI, ordered from highest to
        lowest score. Note the values are BI scores, not embedding tensors.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    selected_embeddings = {}
    for hstate_name, layers in embeddings.items():
        # BI between every pair of consecutive layers.
        bi_scores = {
            layer: torch.mean(1 - CosineSimilarity(layers[layer], layers[layer + 1])).item()
            for layer in range(len(layers) - 1)
        }
        # Highest influence first; slicing actually enforces the top_k limit.
        ranked = sorted(bi_scores.items(), key=lambda item: item[1], reverse=True)
        selected_embeddings[hstate_name] = dict(ranked[:top_k])

    return selected_embeddings

In [None]:
# Rank the extracted layers and keep the configured number of them.
bi_embeddings = block_influence_layer_selector(embeddings=embeddings, top_k=top_k)