### Global Aggregations of Local Explanations

This notebook implements the various methods for generating global aggregates described in the paper. For more information, see the [paper](https://arxiv.org/abs/1907.03039).

To generate global aggregates we follow a two-step process. First, we generate local explanations for many instances of a dataset, sum the absolute weights of each feature per predicted class, and count how many times each feature appears. Then we calculate the LIME importance, the average importance and the homogeneity-weighted importance. An overview of these processes is given in the cells below. We chose to present the multimodal approach. Note that the first code cell calls `predict_fn`, which is defined in the second code cell, so run the second cell before the first.

In [None]:
# Switch the model to evaluation mode before any explanation is computed
# (presumably a torch nn.Module, given the torch.no_grad() usage below).
model.eval()
# TODO: `...` (Ellipsis) is a placeholder -- pass the real list of class names.
explainer = LimeMusicExplainer(class_names=...)

def default_dict_of_float():
    """Return an empty ``defaultdict`` whose missing entries start at ``0.0``.

    Used as the default factory of the outer ``feature_weights`` mapping. It is
    a named function rather than a lambda -- presumably so the nested
    defaultdict stays picklable.
    """
    inner = defaultdict(float)
    return inner

def default_dict_of_int():
    """Return an empty ``defaultdict`` whose missing entries start at ``0``.

    Used as the default factory of the outer ``feature_counts`` mapping. It is
    a named function rather than a lambda -- presumably so the nested
    defaultdict stays picklable.
    """
    inner = defaultdict(int)
    return inner

# Accumulators, keyed by predicted class and then by explanation component:
#   feature_weights[c][component] -> sum of |weight| over all explanations of class c
#   feature_counts[c][component]  -> number of explanations of class c containing it
feature_weights = defaultdict(default_dict_of_float)
feature_counts = defaultdict(default_dict_of_int)

# Only the first MAX_EXPLAINED_BATCHES batches of test_loader are explained.
MAX_EXPLAINED_BATCHES = 8
# Passed as `num_samples` to explain_instance (presumably the number of
# perturbed samples LIME draws per instance -- confirm in LimeMusicExplainer).
LIME_NUM_SAMPLES = 5000

# NOTE: `predict_fn` is defined in the next cell; run that cell before this one.
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm.tqdm(test_loader)):
        if batch_idx >= MAX_EXPLAINED_BATCHES:
            break

        specs = batch['spectogram'].to(device)
        audios = batch['audio'].cpu().numpy()
        lyrics = batch['lyrics']
        tokens = batch['ids'].to(device)
        masks = batch['mask'].to(device)

        # The model's prediction on the unperturbed instance selects the label
        # each explanation is computed for.
        outputs = model(input_values=specs, input_ids=tokens, attention_mask=masks)
        softmax_scores = torch.softmax(outputs, dim=1)
        # One host transfer per batch instead of one `.item()` call per instance.
        preds = torch.argmax(softmax_scores, dim=1).tolist()

        for audio, text, predicted_class in zip(audios, lyrics, preds):
            factorization = OpenumixFactorization(
                audio, temporal_segmentation_params=10, composition_fn=None
            )
            explanation = explainer.explain_instance(
                factorization,
                text,
                predict_fn,
                num_samples=LIME_NUM_SAMPLES,
                labels=(predicted_class,),
            )
            components, weights = explanation.get_sorted_components(
                label=predicted_class, negative_components=False
            )
            for component, weight in zip(components, weights):
                feature_weights[predicted_class][component] += abs(weight)
                feature_counts[predicted_class][component] += 1

In [None]:
# The predict function for the multimodal apprach. The dataset should take text and waveforms as inputs and give the corresponding token ids, attention masks and spectrograms.
def predict_fn(texts, waveforms, batch_size=256):

    all_probabilities = []
    dataset = MultimodalmDataset(texts, waveforms, feature_extractor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_masks = batch['attention_mask'].to(device)
        batch_specs = batch['spectrogram'].to(device)
        with torch.no_grad():
            outputs = model(input_values=batch_specs, input_ids=input_ids, attention_mask=attention_masks)
            batch_probabilities = F.softmax(outputs, dim=1).cpu().numpy()
        
        all_probabilities.extend(batch_probabilities)
        torch.cuda.empty_cache()

    return np.array(all_probabilities)

In [None]:
def calculate_feature_influences(feature_weights, feature_counts):
    """Turn accumulated local-explanation weights into global feature importances.

    Args:
        feature_weights: mapping ``class -> {feature -> sum of |weight|}`` over
            all local explanations of that class.
        feature_counts: mapping ``class -> {feature -> number of explanations
            the feature appeared in}``. Missing classes/features count as 0.
            The mapping is only read: lookups use ``.get``, so a caller's
            ``defaultdict`` is not grown.

    Returns:
        Tuple ``(Icj_LIME, Icj_AVG, Icj_H)``. Each maps
        ``class -> {feature -> importance}`` with the keys of ``feature_weights``:

        * ``Icj_LIME``: ``sqrt(sum |weight|)``.
        * ``Icj_AVG``: ``sum |weight| / count`` (0 when the count is 0 or missing).
        * ``Icj_H``: ``Icj_LIME`` scaled by ``1 - normalised entropy`` of the
          feature's LIME importance across classes. Features whose importance
          is concentrated in few classes therefore keep more of it. When every
          feature has the same entropy (e.g. only one class was ever
          predicted), min-max normalisation is undefined and the factor is 1,
          so ``Icj_H == Icj_LIME``.

        Empty input yields three empty dicts.
    """
    # Global LIME importance: square root of the summed absolute weights.
    Icj_LIME = {
        c: {j: math.sqrt(weight) for j, weight in features.items()}
        for c, features in feature_weights.items()
    }

    # Global average importance: summed absolute weight per appearance.
    Icj_AVG = {}
    for c, features in feature_weights.items():
        counts = feature_counts.get(c, {})
        Icj_AVG[c] = {}
        for j, weight in features.items():
            count = counts.get(j, 0)
            Icj_AVG[c][j] = weight / count if count > 0 else 0

    # Total LIME importance of each feature, summed over all classes.
    total_importance = {}
    for features in Icj_LIME.values():
        for j, importance in features.items():
            total_importance[j] = total_importance.get(j, 0) + importance

    # p_cj: share of feature j's total LIME importance that falls on class c.
    p_cj = {
        c: {
            j: importance / total_importance[j] if total_importance[j] > 0 else 0
            for j, importance in features.items()
        }
        for c, features in Icj_LIME.items()
    }

    # H_j: Shannon entropy (in bits) of feature j's importance distribution
    # over classes; zero-probability terms contribute nothing.
    H_j = {}
    for j in total_importance:
        entropy = 0.0
        for probs in p_cj.values():
            p = probs.get(j, 0)
            if p > 0:
                entropy -= p * math.log2(p)
        H_j[j] = entropy

    H_min = min(H_j.values(), default=0.0)
    H_max = max(H_j.values(), default=0.0)
    H_range = H_max - H_min

    # Homogeneity factor after min-max normalising the entropy: 1 for the most
    # class-specific feature, 0 for the most spread-out one. A zero range would
    # divide by zero, so every feature then keeps its full LIME importance.
    homogeneity = {
        j: (1 - (H - H_min) / H_range) if H_range > 0 else 1.0
        for j, H in H_j.items()
    }

    Icj_H = {
        c: {j: homogeneity[j] * importance for j, importance in features.items()}
        for c, features in Icj_LIME.items()
    }

    return Icj_LIME, Icj_AVG, Icj_H