## I. Import Dependencies

In [1]:
#papermill_description=IMPORT_DEPENDENCIES
%env CUDA_LAUNCH_BLOCKING=1
%load_ext autoreload
%load_ext line_profiler
%autoreload 0
# %autoreload 3

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, LoraModel
from modules import OpenELMModel
from modules import LlamaTokenizer, LlamaModel, BertModel
from dep_model import DependencyParser, Hparam
from transformers import BertTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from gpustat import GPUStatCollection
from tqdm import tqdm, trange
from collections import defaultdict
from typing import Dict, Literal, Tuple, List, Union, Any
import os
from os import path as osp
import json
import math
from matplotlib import (
    pyplot as plt,
    markers
)
import numpy as np
from subprocess import Popen, PIPE
import pickle as pkl
from util import eisner
import re
import random as rd

env: CUDA_LAUNCH_BLOCKING=1
Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
using modified version of Apple OpenELM


## II. Extract Attention Features
### 1. Load Models and Tokenizers

In [2]:
LLAMA_MODEL_NAME: str = 'open_llama_7b'
TRAIN_DATA_DIR: str = './data/train.conllu'
DATA_PROPORTION = 1
TRANSPOSE = False # !!remember to change the cache file name `CACHE_DIR` if you change this option!!
CACHE_DIR = '<specify cache dir>'
FORCE_RECACHE = False # pos, height and label attention samples are cached in '${HOME}/attn-sample-cache.pt' to accelerate reading
# default behavior is directly jump to dependency tree reconstruction (assume attn_samples, kde_estims and mi_estims are already saved to disk)
LOAD_ATTN_SAMPLES = True # when set to True, will skip extracting attention weights and load from disk
LOAD_KDE_ESTIMS = True # when set to True, will skip calculating KDE estimations (f(a_i|y)) and load from disk
LOAD_MI = True # when set to True, will skip calculating MI estimations and load from disk
BREAK_AFTER_EXTRACT = False # when set to True, will only extract attention weights and save to disk, and then interrupt
BREAK_AFTER_KDE = False # when set to True, after calculating KDE estimations(f(a_i|t=t_i)) and save to disk, and then interrupt
BREAK_AFTER_MI = False # when set to True, after calculating MI estimations and save to disk, and then interrupt
SAVE_DURING_EXTRACTING = True
USE_MEMMAP = False # when extracting, save negative samples to disk using memmap
METADATA_ONLY = False # when extracting, don't save anything except matadata
SAVE_EVERY = 1024 # this option will only be valid if SAVE_DURING_EXTRACTING or USE_MEMMAP is set to True
KDE_NUM_SLICES = 8
LOAD_BASELINES = True

In [3]:
if not LOAD_ATTN_SAMPLES and SAVE_DURING_EXTRACTING and not BREAK_AFTER_EXTRACT:
    raise NotImplementedError(f'Error, if you process attention sample gathering and set `SAVE_DURING_EXTRACTING` to True, BREAK_AFTER_EXTRACT must be set to True, '
                              '\nand next time, execute with LOAD_ATTN_SAMPLES=True, '
                              'since the attention samples are saved during extracting and not in memory.')
assert int(SAVE_DURING_EXTRACTING) + int(METADATA_ONLY) + int(USE_MEMMAP) <= 1, 'Error, only one of `SAVE_DURING_EXTRACTING`, `METADATA_ONLY` and `USE_MEMMAP` can be set to True.'

In [None]:
#papermill_description=LOAD_MODEL
# warning: these code are dependent on the environment, please modify them according to your own environment
pretrained_path = '../pretrained-models/'
LLAMA_PATH = osp.join(pretrained_path, LLAMA_MODEL_NAME)
print(f'initializing LLAMA tokenizer and model from `{LLAMA_PATH}`...')
# config_dir_name = 'openllama' if 'llama' in LLAMA_MODEL_NAME else 'bert' if 'bert' in LLAMA_MODEL_NAME 
if 'llama' in LLAMA_MODEL_NAME:
    config_dir_name = 'openllama'
elif 'bert' in LLAMA_MODEL_NAME:
    config_dir_name = 'bert'
else:
    raise ValueError(f'Error, invalid model name {LLAMA_MODEL_NAME} to choose from `openllama` or `bert`')
dep_parser_llama = DependencyParser(
    json.load(open(f"./configs/{config_dir_name}/sample_labelmap.json")), 
    Hparam(**json.load(open(f"./configs/{config_dir_name}/sample_hpara.json"))), 
    LLAMA_PATH, device_to_place='cuda:0'
) 
model_llama, tok_llama = dep_parser_llama.llm if dep_parser_llama.llm is not None else dep_parser_llama.bert, dep_parser_llama.tokenizer
print(model_llama.device, model_llama.dtype)

### 2. Extract Attentions over the Dataset and Visualize

In [None]:
#papermill_description=LOAD_DATA
# data_instances = readfile('./data/train.conllu', 'train')
# lens = [len(each[0]) for each in data_instances[:10]]
train_data_llama = dep_parser_llama.load_data(TRAIN_DATA_DIR)
train_data_llama = train_data_llama[:round(len(train_data_llama) * DATA_PROPORTION)]
dl_llama = DataLoader(train_data_llama, collate_fn=lambda x: dep_parser_llama.feature2input('cuda:0', dep_parser_llama.convert_examples_to_features(x)))
def formalize_pickle_file_name(original_save_name: str, save_idx: int = None, num_saves: int = None):
    file_name, file_ext = original_save_name.rsplit('.', maxsplit=1)
    file_name = f"{file_name}_{str(DATA_PROPORTION).replace('.', '_')}{'_transpose' if TRANSPOSE else ''}"
    if save_idx is not None and num_saves is not None:
        # pad save_idx with 0, making it have the save length as num_saves
        file_name = f"{file_name}_{str(save_idx).zfill(len(str(num_saves)))}of{num_saves}"
    return f'{file_name}.{file_ext}'

def save_attn_samples(samples, path: str):
    print(f'saving to {path}...')
    torch.save(samples, path)

if 'dev' in TRAIN_DATA_DIR:
    kde_save_pth = osp.join(osp.split(osp.realpath(LLAMA_PATH))[0], 'kde', osp.split(LLAMA_PATH)[1] + '_dev',)
else:
    kde_save_pth = osp.join(osp.split(osp.realpath(LLAMA_PATH))[0], 'kde', osp.split(LLAMA_PATH)[1],)
os.makedirs(kde_save_pth, exist_ok=True)
print(f'When running, attention samples will be saved to `{kde_save_pth}`')

In [6]:
#papermill_description=EXTRACT_ATTN_FEATURES
from explaination import adj_matrix_to_heights, get_shape, print_listlike, FeatureList, concatenate_attn_features

# test model output dim
def extract_attn_features(
    tok: Union[LlamaTokenizer, BertTokenizer], 
    model: Union[LlamaModel, BertModel], 
    dataloader: DataLoader, 
    labelmap: Dict[str, int], 
    num_samples: int = 100, 
    silent: bool = True, 
    transpose: bool = False, 
    save_during_extracting: bool = False,
    save_every: int = -1,
    metadata_only: bool = False,
    use_memmap: bool = False,
) -> Tuple[
        Tuple[Union[List[Tensor], FeatureList], Union[List[Tensor], FeatureList]], Dict[str, Union[List[Tensor], FeatureList]]
    ]:
    """
    returns 
        (
            pos_features: List[Tensor[n_arcs_of_sample_i, feature_dim(n_layers * n_heads)]], 
            neg_features: List[Tensor[sample_i_seq_len ^ 2 - n_arcs_of_sample_i, feature_dim]]
        ), label_features: Dict[str(label_name of rel j), List[Tensor[samp_i_n_rel_j_arcs, feature_dim]]]
    """
    tok.pad_token = tok.eos_token
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    # if concat:
    #     pos_features, neg_features = FeatureList(), FeatureList()
    #     height_features, label_features = defaultdict(FeatureList), defaultdict(FeatureList)
    # else:
    pos_features, neg_features = [], []
    height_features = defaultdict(list)
    label_features = defaultdict(list)
    min_attn_score, max_attn_score = 1145141919810, -1145141919810
    if num_samples == -1:
        num_samples = len(dataloader)
    if save_during_extracting:
        save_idx = 0
        num_saves = math.ceil(num_samples / save_every)
        n_pos_samples, n_neg_samples = 0, 0
    if metadata_only:
        n_pos_samples, n_neg_samples = 0, 0
    if use_memmap:
        # load metadata
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json'))) as f:
            metadata = json.load(f)
            n_pos_samples, n_neg_samples = metadata['n_pos_samples'], metadata['n_neg_samples']
        # initialize memmap for neg_features_concat
        neg_features_shape=(n_neg_samples, model_llama.config.num_hidden_layers * model_llama.config.num_attention_heads)
        print(f'creating memmap file (shape: {neg_features_shape})...')
        neg_features_memmap = np.memmap(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl')), dtype='float32', mode='w+', shape=neg_features_shape)
        neg_features_memmap_ptr = 0
    
    id2label = {label_idx: label_name for label_name, label_idx in labelmap.items()}
    for data_idx, (input_ids, input_attention_mask, label_mask, eval_mask, \
        arcs, rels, word_ids, pos_ids, ngram_ids, \
        ngram_positions, segment_ids, valid_ids) in enumerate(tqdm(dataloader, total=num_samples, desc='extracting attentions', ncols=100)):

        echo = not silent and data_idx == 0

        if echo:
            print(input_ids.shape, label_mask.shape, valid_ids.shape, arcs.shape, rels.shape)
        B, S = input_ids.shape
        w2s = [] # the idx at position i is whole-word i's last subword's idx
        for subword_idx, each in enumerate(valid_ids[0].tolist()):
            if each == 1:
                w2s.append(subword_idx)

        arc_adj_matrix = torch.zeros(S, S).to('cuda:0') # 1 for `have arc`, 0 for `no arc` (marking at whole-word's last subword)
        label_adj_matrix = torch.zeros(S, S).to('cuda:0') # arc[i][j]'s relation type (marking at whole-word's last subword)
        # print_listlike([*enumerate(tok.convert_ids_to_tokens(input_ids[0]))])
        for word_idx, head_idx in enumerate(arcs[0]):
            if head_idx != -1:
                arc_adj_matrix[w2s[word_idx]][w2s[head_idx]] = 1
                label_adj_matrix[w2s[word_idx]][w2s[head_idx]] = rels[0][word_idx]
                # print(input_ids[0][w2s[word_idx]], input_ids[0][w2s[head_idx]])
                dependant_pos_id, head_pos_id = w2s[word_idx], w2s[head_idx]
                dependant_token_id, head_token_id = input_ids[0][dependant_pos_id].item(), input_ids[0][head_pos_id].item()
                # print(f"[{dependant_pos_id}]{tok.convert_ids_to_tokens(dependant_token_id)} -> [{head_pos_id}]{tok.convert_ids_to_tokens(head_token_id)} ({id2label[rels[0][word_idx].item()]})")
                # print([*enumerate(tok_llama.convert_ids_to_tokens(input_ids[0]))])

        height_adj_matrix = adj_matrix_to_heights(arc_adj_matrix)

        if metadata_only:
            n_pos_samples += arc_adj_matrix.sum().item()
            n_neg_samples += (1 - arc_adj_matrix).sum().item()
            if data_idx >= num_samples - 1:
                break

            continue

        if transpose:
            arc_adj_matrix, height_adj_matrix, label_adj_matrix = arc_adj_matrix.T, height_adj_matrix.T, label_adj_matrix.T
        if isinstance(model, LlamaModel):
            res = model.forward(
                input_ids=input_ids, attention_mask=input_attention_mask,
                output_hidden_states=True,
                output_attentions=False,
                output_attention_queries=True
            )
            key_values, queries = res.past_key_values, res.queries
            # kv: [num_layers, 2(k and v), batch_size, num_heads, sequence_length, head_dim], q: [num_layers, batch_size, num_heads, sequence_length, head_dim]
            if echo:
                print('past_key_values shape:', get_shape(res.past_key_values))
                # print(key_values)
                print('queries shape:', get_shape(res.queries))
                # print(queries)

            attn_scores = () # [num_layers, 1(batch_size), num_heads, seq_len, seq_len]
            for layer_idx in range(len(model.layers)):
                k = key_values[layer_idx][0]
                q = queries[layer_idx]
                if q.shape[-3] != k.shape[-3]: # num_q_heads == num_k_heads * n_groups
                    n_groups = q.shape[-3] / k.shape[-3]
                    assert n_groups == int(n_groups)
                    n_groups = int(n_groups)
                else:
                    n_groups = 1
                
                k = torch.repeat_interleave(k, n_groups, -3)
                this_attn_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(res.last_hidden_state.shape[-1]) 
                # scaled dot-product attention, this_attn_score: [batch_size(1), num_heads, sequence_length, sequence_length]
                attn_scores += (this_attn_score,)
            
            if echo:
                print('attn_scores shape:', get_shape(attn_scores))
            attn_scores = torch.cat(attn_scores, dim=1).squeeze(0) # [num_heads * num_layers, sequence_length, sequence_length], batch_size(1) is squeezed
        else:
            res = model.forward( 
                input_ids=input_ids, attention_mask=input_attention_mask,
                output_raw_attentions=True
            )
            attn_scores = torch.cat(res.attentions) # cat along the first dim (originally for batch size, now is `1`), resulting in [num_layers, num_heads, sequence_length, sequence_length]
            L, H, S, S = attn_scores.shape
            attn_scores = attn_scores.reshape(L * H, S, S)
            
        if echo:
            print('attn_score after squeezing:', get_shape(attn_scores))
        attn_features = attn_scores.permute(1, 2, 0) # [sequence_length, sequence_length, num_heads * num_layers]
        # if use_memmap:
        if attn_features.min() < min_attn_score:
            min_attn_score = attn_features.min().item()
        if attn_features.max() > max_attn_score:
            max_attn_score = attn_features.max().item()
        # add flatten attention features
        pos_features.append(attn_features[(arc_adj_matrix == 1).cpu()].cpu()) 
        # [[sentence1_num_arcs, num_heads * num_layers], [sentence2_num_arcs, num_heads * num_layers], ... ]
        neg_features.append(attn_features[(arc_adj_matrix == 0).cpu()].cpu()) 
        # [[sent1_seq_len * sent1_seq_len - sent1_num_arcs, num_heads * num_layers], [sent2_seq_len * sent2_seq_len - sent2_num_arcs, num_heads * num_layers]]
        for label_idx in labelmap.values():
            label_features[label_idx].append(attn_features[(label_adj_matrix == label_idx).cpu()].cpu())
            # {rel1_label_name: [[samp1_num_rel1_samples, num_heads * num_layers]], [[samp2_num_rel1_samples, num_heads * num_layers]]}
        
        for height in range(1, int(height_adj_matrix.max().item() + 1)):
            height_features[height].append(attn_features[(height_adj_matrix == height).cpu()].cpu())

        # save during
        if save_during_extracting and (data_idx != 0 and data_idx % save_every == 0 or data_idx == num_samples - 1):
            print(f'saving attention samples at {data_idx} (save idx {save_idx})...')
            (pos_features_concat, neg_features_concat), height_features_concat, label_features_concat, label_ids = concatenate_attn_features(
                (pos_features, neg_features), height_features, label_features, concat_across_labels=False)

            save_attn_samples(pos_features_concat, osp.join(kde_save_pth, formalize_pickle_file_name(f'pos_attn_samples.pkl', save_idx, num_saves)))
            save_attn_samples(neg_features_concat, osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl', save_idx, num_saves)))
            save_attn_samples(height_features_concat, osp.join(kde_save_pth, formalize_pickle_file_name('height_attn_samples.pkl', save_idx, num_saves)))
            save_attn_samples(label_features_concat, osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_samples.pkl', save_idx, num_saves)))
            n_pos_samples += len(pos_features_concat)
            n_neg_samples += len(neg_features_concat)
            pos_features, neg_features = [], []
            height_features = defaultdict(list)
            label_features = defaultdict(list)
            save_idx += 1
        
        if use_memmap and (data_idx != 0 and data_idx % save_every == 0 or data_idx == num_samples - 1):
            neg_features_concat = torch.cat(neg_features, dim=0)
            neg_features_memmap[neg_features_memmap_ptr: neg_features_memmap_ptr + neg_features_concat.shape[0]] = neg_features_concat.float().numpy()
            print(f"writing to memmap from {neg_features_memmap_ptr} to {neg_features_memmap_ptr + neg_features_concat.shape[0]}")
            neg_features_memmap_ptr += neg_features_concat.shape[0]
            neg_features = []

        if data_idx >= num_samples - 1:
            break
    # end for data_idx, (input_ids, input_attention_mask, label_mask, eval_mask, arcs, rels, word_ids, pos_ids, ngram_ids, ngram_positions, segment_ids, valid_ids) in enumerate(dataloader)

    if save_during_extracting:
        with open(osp.join(kde_save_pth, 'labelmap.json'), 'w') as f:
            json.dump(labelmap, f)
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'w') as f:
            json.dump({
                'n_pos_samples': n_pos_samples,
                'n_neg_samples': n_neg_samples,
                'min_attn_score': min_attn_score,
                'max_attn_score': max_attn_score,
            }, f)
        # with open(osp.join(kde_save_pth, 'label_id2name.json'), 'w') as f:
        #     json.dump({label_id: label_name for label_name, label_id in dep_parser_llama.labelmap.items()}, f)
        label_id2name = {label_id: label_name for label_name, label_id in labelmap.items()}
        with open(osp.join(kde_save_pth, 'label_names.txt'), 'w') as f:
            label_names = [label_id2name[label_id] for label_id in labelmap.values()]
            json.dump(label_names, f)
    
    if metadata_only:
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'w') as f:
            json.dump({
                'n_pos_samples': int(n_pos_samples),
                'n_neg_samples': int(n_neg_samples),
            }, f)
    
    if use_memmap:
        neg_features_memmap.flush()
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'w') as f:
            json.dump({
                'n_pos_samples': n_pos_samples,
                'n_neg_samples': n_neg_samples,
                'min_attn_score': min_attn_score,
                'max_attn_score': max_attn_score,
            }, f)

    # if save_during_extracting is set to True, the return value will be empty lists and dicts, if not, return the UNconcatenated (chunked by sentence) features
    return (pos_features, neg_features), height_features, label_features

if not LOAD_ATTN_SAMPLES:
    with torch.no_grad():
        arc_attn_features, height_attn_features, label_attn_features = extract_attn_features(
            tok=tok_llama, model=model_llama, dataloader=dl_llama, labelmap=dep_parser_llama.labelmap, num_samples=-1, silent=False, transpose=TRANSPOSE, save_during_extracting=SAVE_DURING_EXTRACTING, save_every=SAVE_EVERY, metadata_only=METADATA_ONLY, use_memmap=USE_MEMMAP
        )
# result_test(tok_llama, model_llama, dl_llama)

In [7]:
if not LOAD_ATTN_SAMPLES and not SAVE_DURING_EXTRACTING:
    print(sum([each.shape[0] for each in arc_attn_features[0]]), sum([each.shape[0] for each in arc_attn_features[1]]))
    print(sum([sum([each.shape[0] for each in height_features]) for height_features in height_attn_features.values()]))
    print(sum([sum([each.shape[0] for each in label_features]) for label_features in label_attn_features.values()]))
# print(sum([each.shape[0] for each in label_attn_features]))

## III. Make Attention Statistical Analysis

### 1. Do Kernel Density Estimation, Calculate MI and Infer Dependency Trees

#### 1.1 Save & Load attention features

In [8]:
#papermill_description=SAVE_CONCATENATED_FEATURES
import os.path as osp
import os


KDE_DEVICE='cuda:0'

if not LOAD_ATTN_SAMPLES and not METADATA_ONLY and not SAVE_DURING_EXTRACTING:
    print(f'concatenating features')
    # if CONCAT_DURING_EXTRACTING:
    #     pos_features, neg_features = arc_attn_features[0].item, arc_attn_features[1].item
    #     height_features = [each.item for each in height_attn_features.values()]
    #     label_features = [each.item for each in label_attn_features.values()]
    #     label_ids = [*label_attn_features.keys()]

    # else:

    (pos_features, neg_features), height_features, label_features, label_ids = concatenate_attn_features(arc_attn_features, height_attn_features, label_attn_features, concat_across_labels=False)
    # pos|neg_features: [num_(pos|neg)_arcs, num_features(num_layers * num_heads)], height_features: [heights(tuple), num_arcs_of_this_height, num_features]
    print(f'{pos_features.numel():,} pos attn-feature samples, {neg_features.numel():,} neg attn-feature samples')
    print(f'{sum([this_height_feature.numel() for this_height_feature in height_features]):,} height attn-feature samples')
    print(f'saving features...')
    # transpose = ''

    save_attn_samples(pos_features, osp.join(kde_save_pth, formalize_pickle_file_name(f'pos_attn_samples.pkl')))
    if not USE_MEMMAP:
        save_attn_samples(neg_features, osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl')))
    save_attn_samples(height_features, osp.join(kde_save_pth, formalize_pickle_file_name('height_attn_samples.pkl')))
    save_attn_samples(label_features, osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_samples.pkl')))
    with open(osp.join(kde_save_pth, 'labelmap.json'), 'w') as f:
        labelmap = dep_parser_llama.labelmap
        json.dump(labelmap, f)
    if not USE_MEMMAP:
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'w') as f:
            json.dump({
                'n_pos_samples': len(pos_features),
                'n_neg_samples': len(neg_features),
            }, f)
    # with open(osp.join(kde_save_pth, 'label_id2name.json'), 'w') as f:
    #     json.dump({label_id: label_name for label_name, label_id in dep_parser_llama.labelmap.items()}, f)
    label_id2name = {label_id: label_name for label_name, label_id in dep_parser_llama.labelmap.items()}
    with open(osp.join(kde_save_pth, 'label_names.txt'), 'w') as f:
        label_names = [label_id2name[label_id] for label_id in label_ids]
        json.dump(label_names, f)

    # torch.cuda.empty_cache()

if BREAK_AFTER_EXTRACT:
    raise KeyboardInterrupt

In [9]:
#papermill_description=LOAD_CONCATENATED_FEATURES
if LOAD_ATTN_SAMPLES and not SAVE_DURING_EXTRACTING:
    pos_features: Tensor = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('pos_attn_samples.pkl')), map_location='cpu')
    if not LOAD_KDE_ESTIMS: 
        # since neg_features is only used by KDE estimation of f(a_i|t=t_i) 
        # if loads both attention samples and kde estims, then `neg_features`, which is very large, will not be loaded
        if USE_MEMMAP: 
            with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json'))) as f:
                metadata = json.load(f)
                min_attn_feature_val, max_attn_feature_val = metadata['min_attn_score'], metadata['max_attn_score']
                total_n_neg_samples = metadata['n_neg_samples']

            neg_features: np.ndarray = np.memmap(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl')), shape=(total_n_neg_samples, pos_features.shape[1]), dtype='float32', mode='r', )
        else:
            neg_features: Tensor = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl')), map_location='cpu')
    height_features: Tuple[Tensor] = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('height_attn_samples.pkl')), map_location='cpu')
    label_features: Tuple[Tensor] = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_samples.pkl')), map_location='cpu')
    labelmap: Dict[str, int] = json.load(open(osp.join(kde_save_pth, 'labelmap.json')))
    # label_id2name = json.load(open(osp.join(kde_save_pth, 'label_id2name.json')))
    label_names: List[int] = json.load(open(osp.join(kde_save_pth, 'label_names.txt')))
    # if USE_MEMMAP:
if LOAD_ATTN_SAMPLES and SAVE_DURING_EXTRACTING:
    pickle_file_names = [*filter(lambda x: re.match(r'.*_attn_samples.*_(\d+)of(\d+).pkl', x), os.listdir(kde_save_pth))]
    possible_total_num_files = set([re.search(r'.*_(\d+)of(\d+)', each).group(2) for each in pickle_file_names])
    if len(possible_total_num_files) != 1:
        raise RuntimeError(f"found multiple possible total number of attn-sample save chunks: {', '.join(possible_total_num_files)}")

    labelmap: Dict[str, int] = json.load(open(osp.join(kde_save_pth, 'labelmap.json')))
    # label_id2name = json.load(open(osp.join(kde_save_pth, 'label_id2name.json')))
    label_names: List[int] = json.load(open(osp.join(kde_save_pth, 'label_names.txt')))
    def add_to_pos_features(features, this_features):
        return this_features if features is None else torch.cat([features, this_features], dim=0)
    # def add_to_preallocated_features(features, start_idx, this_features):
    #     features[start_idx: start_idx + this_features.shape[0]] = this_features
    def add_to_label_features(features, this_features):
        if len(features) == 0:
            return this_features
        else:
            return [torch.cat([feature, this_feature], dim=0) for feature, this_feature in zip(features, this_features)]
    pos_features, neg_features = None, None
    height_features, label_features = [], []
    total_num_files = int(possible_total_num_files.pop())
    if LOAD_KDE_ESTIMS and not FORCE_RECACHE and osp.exists(CACHE_DIR):
        pos_features, height_features, label_features = torch.load(CACHE_DIR)

    else:
        with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'r') as f:
            metadata = json.load(f)
            total_n_neg_samples = metadata['n_neg_samples']
            min_attn_feature_val, max_attn_feature_val = metadata['min_attn_score'], metadata['max_attn_score']
            current_neg_sample_ptr = 0

        for current_num_save in (chunked_load_pbar := trange(total_num_files, desc='loading attn sample chunks...')):
            # these lines are annotated since I want to load neg_features during ESTIMATE_KDE
            # if not LOAD_KDE_ESTIMS:
            #     # since neg_features is only used by KDE estimation of f(a_i|t=t_i) 
            #     # if both attention samples and kde estims are loaded, the time consumption will be incredibly high
            #     this_neg_features: Tensor = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl', current_num_save, total_num_files)), map_location='cpu')
            #     if neg_features is None:
            #         neg_features = torch.empty((total_n_neg_samples, this_neg_features.shape[1]))
            #         chunked_load_pbar.set_description('loading attention samples...')
            #     neg_features[current_neg_sample_ptr: current_neg_sample_ptr + this_neg_features.shape[0]] = this_neg_features
            #     current_neg_sample_ptr += this_neg_features.shape[0]
            
            this_pos_features: Tensor = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('pos_attn_samples.pkl', current_num_save, total_num_files)), map_location='cpu')
            pos_features: Tensor = add_to_pos_features(pos_features, this_pos_features)
            this_height_features: Tuple[Tensor] = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('height_attn_samples.pkl', current_num_save, total_num_files)), map_location='cpu')
            height_features = add_to_label_features(height_features, this_height_features)
            this_label_features: Tuple[Tensor] = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_samples.pkl', current_num_save, total_num_files)), map_location='cpu')
            label_features = add_to_label_features(label_features, this_label_features)
        torch.save([pos_features, height_features, label_features], CACHE_DIR)

#### 1.2 Estimate attention conditional probability ($f(a_i|y)$)

##### 1.2.1 mask out dependency labels with no attention samples corresponding to it

In [10]:
#papermill_description=MASK_ZERO_DIM_FEATURES
masked_label_indices = [] # NOTE: indices corresponding to the position in label_features (concatenated, Type: List[Tensor[num_samples, num_attn_features]])
for idx, label_name in enumerate(label_names):
    # n_samples = len(label_attn_features[label_idx])
    # print(n_samples)
    size = label_features[idx].shape
    # except IndexError:
    #     print(f'Error occurred when processing [{idx}]{label_idx}, features:')
    #     print(label_features[idx])
    if size[0] == 0:
        print(f'0-dim attn-vector occurred at [{idx}]{label_name}')
        # print(size, label_attn_features[labelmap[label_name]])
        masked_label_indices.append(idx)

0-dim attn-vector occurred at [0]<UNK>
0-dim attn-vector occurred at [46]<s>


##### 1.2.2 calculate attention conditional probabilities ($f(a_i|y=0), f(a_i|y=1), f(a_i|l=l_0)$) using KDE

In [11]:
#papermill_description=ESTIMATE_KDE
import math
import matplotlib.pyplot as plt
from tqdm import trange
from explaination import integral_torch_cuda, estimate_kde_torch
from copy import deepcopy

def estimate_kde_torch_sliced(x: Tensor, samples: Tensor, num_slices: int = 1):
    """split across x, and estimate KDE for each slice, then concatenate them, in order to save GPU memory"""
    x_slices = torch.split(x, math.ceil(x.shape[0] / num_slices), dim=0)
    for idx, x_slice in enumerate(x_slices):
        if idx == 0:
            kde_estim = estimate_kde_torch(x_slice, samples, normalize=False)
        else:
            kde_estim = torch.cat([kde_estim, estimate_kde_torch(x_slice, samples, normalize=False)])

    
    cdf_max = integral_torch_cuda(x, kde_estim)
    return kde_estim / cdf_max

if not LOAD_KDE_ESTIMS:
    model_llama = model_llama.cpu() 
    print(f'offload model_llama to cpu to save GPU memory before doing KDE estimations')
    n_feature_dims = pos_features.shape[-1]
    pos_kde_estims, neg_kde_estims = [], []
    label_kde_estims = defaultdict(list)

    # TODO: if neg_features is a memmap, then this operation will be very slow, implement max and min during extracting and save to metadata
    # if not USE_MEMMAP:
        # max_attn_feature_val = max(pos_features.max().item(), neg_features.max().item())
        # min_attn_feature_val = min(pos_features.min().item(), neg_features.min().item())

    print(f'feature range: [{min_attn_feature_val}, {max_attn_feature_val}]')

    x = torch.cat([torch.arange(round(min_attn_feature_val, 1) - 0.2, 0, 0.1), torch.arange(0, 1, 0.01), torch.arange(1, max_attn_feature_val + 0.2)]).to(KDE_DEVICE)
    neg_shard_size = 128
    neg_shard_starts = range(0, n_feature_dims, neg_shard_size)
    current_shard_start = 0
    for feat_idx in trange(n_feature_dims, desc=f"estimating kde for each head"): # iterate over features
        # pos_kde_estims: [f(a_1|y=1), f(a_2|y=1), f(a_3|y=1), ... ]
        pos_kde_estims.append(estimate_kde_torch_sliced(x, pos_features[:, feat_idx].view(-1).to(KDE_DEVICE), num_slices=KDE_NUM_SLICES).cpu())
        # neg_kde_estims: [f(a_1|y=0), f(a_2|y=0), f(a_3|y=0), ... ]
        # if not USE_MEMMAP:
        #     neg_kde_estims.append(estimate_kde_torch_sliced(x, neg_features[:, feat_idx].view(-1).to(KDE_DEVICE), num_slices=KDE_NUM_SLICES).cpu())
        if SAVE_DURING_EXTRACTING: 
            if feat_idx in neg_shard_starts:
                current_shard_start = feat_idx
                neg_shard = torch.empty((total_n_neg_samples, neg_shard_size), dtype=pos_features.dtype)
                current_neg_sample_ptr = 0
                for current_num_save in trange(total_num_files, desc=f'loading neg tensor ({current_shard_start} to {current_shard_start + neg_shard_size})...', leave=False):
                    this_neg_features: Tensor = torch.load(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_samples.pkl', current_num_save, total_num_files)), map_location='cpu')
                    neg_shard[current_neg_sample_ptr: current_neg_sample_ptr + this_neg_features.shape[0]] = this_neg_features[:, current_shard_start: current_shard_start + neg_shard_size]
                    current_neg_sample_ptr += this_neg_features.shape[0]
                    
            neg_kde_estims.append(estimate_kde_torch_sliced(x, neg_shard[:, feat_idx - current_shard_start].view(-1).to(KDE_DEVICE), num_slices=KDE_NUM_SLICES).cpu())
        else:
            raise NotImplementedError('Error, attn sample saving methods other than `SAVE_DURING_EXTRACTING` are not implemented yet')
        for idx, label_name in enumerate(label_names):
            if idx not in masked_label_indices:
                label_kde_estims[label_name].append(estimate_kde_torch_sliced(x, label_features[idx][:, feat_idx].view(-1).to(KDE_DEVICE), num_slices=KDE_NUM_SLICES).cpu())
                # label_kde_estims: {'label_name_1': [f(a_1|l=1), f(a_2|l=1), ...], 'label_name_m': [f(a_1|l=m), f(a_2|l=m), ...]}
    
    # if USE_MEMMAP:
    #     chunk_size = 128
    #     pbar = trange(n_feature_dims, desc='estimating KDE for neg_features', ncols=100)
    #     for feat_start in range(0, neg_features.shape[1], chunk_size):
    #         this_neg_features = torch.from_numpy(neg_features[:, feat_start: feat_start + chunk_size])
    #         for feat_idx in range(feat_start, min(feat_start + chunk_size, neg_features.shape[1])):
    #             neg_kde_estims.append(estimate_kde_torch_sliced(x.cpu(), this_neg_features[:, feat_idx - feat_start].view(-1), num_slices=KDE_NUM_SLICES).cpu())

number of attention heads: $h$, number of dependency labels: $m$

pos_kde_estims: $[f(a_1|y=1), f(a_2|y=1), ..., f(a_h|y=1)]$

neg_kde_estims: $[f(a_1|y=0), f(a_2|y=0), ..., f(a_h|y=0)]$

label_kde_estims: 
$$
\textrm{label}_1: [f(a_1|l=1), f(a_2|l=1), ..., f(a_h|l=1)],\\...\\
\textrm{label}_m: [f(a_1|l=m), f(a_2|l=m), ..., f(a_h|l=m)]\\
$$

##### 1.2.3 Save/Load KDE estimated probabilities

In [12]:
#papermill_description=SAVE_KDE_ESTIMS
# data_proportion = '_0_2'
# transpose = '_transpose'
if not LOAD_KDE_ESTIMS:
    pkl.dump([each.cpu() for each in pos_kde_estims], open(osp.join(kde_save_pth, formalize_pickle_file_name('pos_attn_conditional.pkl')), 'wb'))
    pkl.dump([each.cpu() for each in neg_kde_estims], open(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_conditional.pkl')), 'wb'))
    for label_name in label_kde_estims: # move 
        for idx in range(len(label_kde_estims[label_name])):
            label_kde_estims[label_name][idx] = label_kde_estims[label_name][idx].cpu()
    pkl.dump(label_kde_estims, open(osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_conditional.pkl')), 'wb'))
    pkl.dump(x.cpu(), open(osp.join(kde_save_pth, formalize_pickle_file_name('x.pkl')), 'wb'))
    if BREAK_AFTER_KDE:
        raise KeyboardInterrupt

In [13]:
#papermill_description=LOAD_KDE_ESTIMS
from os import path as osp

if LOAD_KDE_ESTIMS:
    pos_kde_estims: List[Tensor] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('pos_attn_conditional.pkl')), 'rb')) # f(a_i|y=1)
    neg_kde_estims: List[Tensor] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('neg_attn_conditional.pkl')), 'rb')) # f(a_i|y=0)
    label_kde_estims: Dict[str, List[Tensor]] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('label_attn_conditional.pkl')), 'rb')) # {'label_name(l)': [f(a_i|y=l)]}
    x: Tensor = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('x.pkl')), 'rb'))

In [14]:
# print(label_kde_estims)
# label_kde_estims.pop('<s>')
print(len(label_kde_estims)) 
print(label_kde_estims.keys()) 
print(masked_label_indices)
print([len(each) for each in label_kde_estims.values()])
print([each[0].device for each in label_kde_estims.values()])

45
dict_keys(['prep', 'det', 'nn', 'num', 'pobj', 'punct', 'poss', 'possessive', 'amod', 'nsubj', 'appos', 'dobj', 'dep', 'cc', 'conj', 'nsubjpass', 'partmod', 'auxpass', 'advmod', 'root', 'ccomp', 'aux', 'cop', 'xcomp', 'quantmod', 'tmod', 'neg', 'infmod', 'rcmod', 'pcomp', 'mark', 'advcl', 'predet', 'csubj', 'mwe', 'parataxis', 'npadvmod', 'number', 'acomp', 'prt', 'iobj', 'preconj', 'expl', 'discourse', 'csubjpass'])
[0, 46]
[1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
[device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu

#### 1.3 Calculate label joint probability ($f(a_i,y)$), marginal probability ($p(y=y_m)$) and label conditional probability ($f(y|a_i)$ or $f(l|a_i)$)

In [27]:
#papermill_description=EST_JOINT_CONDITIONAL_PROBAB
with open(osp.join(kde_save_pth, formalize_pickle_file_name('metadata.json')), 'r') as f:
    metadata = json.load(f)
n_pos_samples, n_neg_samples = metadata['n_pos_samples'], metadata['n_neg_samples']
n_samples = n_neg_samples + n_pos_samples
n_samples_per_label = [label_features[idx].shape[0] for idx in range(len(label_features)) if idx not in masked_label_indices]
assert sum(n_samples_per_label) == n_pos_samples, \
    "the sum of all samples of each label from `label_features` must be equal to the number of positive samples: "\
    f"{sum(n_samples_per_label)} != {n_pos_samples}"
n_samples_per_label += [n_neg_samples]
# print(n_pos_samples, n_neg_samples)
pos_probab, neg_probab = n_pos_samples / n_samples, n_neg_samples / n_samples # p(y=1), p(y=0)
label_probabilities = [n_samples / sum(n_samples_per_label) for n_samples in n_samples_per_label] # p(l=0,1,2,...,num_labels) with last probability for `no arc`
joint_probab_pos: List[Tensor] = [] # p(a_0, y=0), p(a_1, y=0), ..., p(a_h, y=0): List[n_heads], of Tensor[n_x]
joint_probab_neg: List[Tensor] = [] # p(a_0, y=1), p(a_1, y=1), ..., p(a_h, y=1): List[n_heads], of Tensor[n_x]
joint_probab_label: List[List[Tensor]] = [] # List[n_heads, n_labels] of Tensor[n_x]
# p(l,a_i): [
#   [p(l=1,a_0), p(l=2,a_0), p(l=3,a_0), ...], 
#   [p(l=1,a_1), p(l=2,a_1), p(l=3,a_1), ...], 
#  ]
# joint_probab_label_neg: List[List[Tensor]] = [] # Tensor[n_heads, n_labels, n_x]
# # p(^l,a_i): [
# #   [p(l!=1,a_0)[n_x], p(l!=2,a_0)[n_x], p(l!=3,a_0)[n_x], ...], 
# #   [p(l!=1,a_1)[n_x], p(l!=2,a_1)[n_x], p(l!=3,a_1)[n_x], ...], 
# #  ]
marginal_probab_attn: List[Tensor] = [] # p(a_i): List[n_heads] of Tensor[n_x]
marginal_probab_attn_from_label: List[Tensor] = [] # p(a_i): List[n_heads] of Tensor[n_x]
marginal_probab_attn_pos_from_label: List[Tensor] = [] # p(a_i): List[n_heads] of Tensor[n_x]
conditional_probab_pos = [] # p(y=0|a_0), p(y=0|a_1), ... p(y=0|a_h): List[n_heads], of Tensor[n_x]
conditional_probab_neg = [] # p(y=1|a_0), p(y=1|a_1), ... p(y=1|a_h): List[n_heads], of Tensor[n_x]
conditional_probab_label = [] # List[n_heads, n_labels] of Tensor[n_x]
# p(l|a_i): [
#   [p(l=1|a_0), p(l=2|a_0), p(l=3|a_0), ...], 
#   [p(l=1|a_1), p(l=2|a_1), p(l=3|a_1), ...], 
#  ]
# conditional_probab_label_neg = [] # List[n_heads, n_labels] of Tensor[n_x]
# # p(^l|a_i)

for attn_feat_idx in trange(len(pos_kde_estims)): # attn_feat_idx (i)
    joint_probab_pos.append(pos_probab * pos_kde_estims[attn_feat_idx]) # p(a_i,y=1) = p(a_i|y=1) * p(y=1)
    joint_probab_neg.append(neg_probab * neg_kde_estims[attn_feat_idx]) # p(a_i,y=0) = p(a_i|y=0) * p(y=0)
    joint_probab_label.append([])
    # appended: [p(a_i, l=0), p(a_i, l=1), ... p(a_i, l=m)] = 
    #       [p(a_i|l=0) * p(l=0), p(a_i|l=1) * p(l=1), ... p(a_i|l=m) * p(l=m)]
    for label_idx, (label_name, label_estim) in enumerate(label_kde_estims.items()):
        joint_probab_label[-1].append(label_probabilities[label_idx] * label_estim[attn_feat_idx])
    joint_probab_label[-1].append(label_probabilities[-1] * neg_kde_estims[attn_feat_idx])
    marginal_probab_attn.append(joint_probab_pos[-1] + joint_probab_neg[-1]) # p(a_i) = p(a_i,y=1) + p(a_i,y=0)
    # print(len(joint_probab_label[-1]), [type(each) for each in joint_probab_label[-1]])
    # print(sum(joint_probab_label[-1]))
    marginal_probab_attn_from_label.append(sum(joint_probab_label[-1])) # p(a_i) = p(a_i,l=0) + p(a_i,l=1) + ... + p(a_i,l=m) 
    marginal_probab_attn_pos_from_label.append(sum(joint_probab_label[-1][:-1])) # (excluding the negative label, WARNING: NOT corresponding to a valid probabalistic distribution, should be normalized)
    conditional_probab_pos.append((joint_probab_pos[-1] / marginal_probab_attn[-1]).masked_fill(marginal_probab_attn[-1] == 0, 0)) # p(y=1|a_i) = p(a_i,y=1) / p(a_i)
    conditional_probab_neg.append((joint_probab_neg[-1] / marginal_probab_attn[-1]).masked_fill(marginal_probab_attn[-1] == 0, 0)) # p(y=0|a_i) = p(a_i,y=0) / p(a_i)
    conditional_probab_label.append([])
    for label_idx, this_joint_probab_label in enumerate(joint_probab_label[-1]): # this_joint_probab_label: p(a_i, l=label_idx or no_arc)
        conditional_probab_label[-1].append(this_joint_probab_label / marginal_probab_attn_from_label[-1]) # appended p(l=label_idx or no_arc | a_i)

joint_probab_label_stacked = torch.stack([torch.stack(each) for each in joint_probab_label]) # p(a_i, l=label_idx) List[n_heads, n_labels] of Tensor[n_x] -> Tensor[n_heads, n_labels, n_x]
_, n_labels, _ = joint_probab_label_stacked.shape

# for attn_feat_idx in trange(len(pos_kde_estims)):
#     joint_probab_label_neg.append([])
#     conditional_probab_label_neg.append([])
#     for label_idx in range(joint_probab_label_stacked.shape[1]):
#         joint_probab_label_neg[-1].append(
#             joint_probab_label_stacked[attn_feat_idx, torch.arange(n_labels) != label_idx, :].sum(dim=0)
#         )
#         # print(joint_probab_label_stacked[attn_feat_idx, torch.arange(n_labels) != label_idx, :].shape)
#         # print(joint_probab_label_neg[-1][-1].shape , marginal_probab_attn_from_label[attn_feat_idx].shape)
#         conditional_probab_label_neg[-1].append(
#             joint_probab_label_neg[-1][-1] / marginal_probab_attn_from_label[attn_feat_idx]
#         )

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [00:00<00:00, 1154.13it/s]


In [16]:
# check if the marginal probabilities integrate to 1 (and )
# for this_marginal_probab_attn, this_marginal_probab_attn_from_label in zip(marginal_probab_attn, marginal_probab_attn_from_label):
#     print(integral_torch_cuda(x, this_marginal_probab_attn.cuda()))
#     print(integral_torch_cuda(x, this_marginal_probab_attn_from_label.cuda()))

# for attn_feat_idx in range(len(conditional_probab_label)):
#     # for label_idx in range(len(joint_probab_label_neg[attn_feat_idx])):
#     this_conditional_probab_label = torch.stack(conditional_probab_label[attn_feat_idx])
#     print(this_conditional_probab_label.sum(0))
# print(marginal_probab_attn[0].shape, marginal_probab_attn_from_label[0].shape)
# plt.plot(x.cpu(), marginal_probab_attn[0].cpu())
# plt.plot(x.cpu(), marginal_probab_attn_from_label[0].cpu())

joint_probab_pos: $p(a_1, y=1), p(a_2, y=1), \cdots, p(a_h, y=1) $

joint_probab_neg: $p(a_1, y=0), p(a_2, y=0), \cdots, p(a_h, y=0) $

joint_probab_label: 
$$ \begin{matrix}
& label_0 & label_1 && label_m \\
{\textrm{head}_1} & p(a_1, l=0)& p(a_1, l=1)& \cdots& p(a_1, l=m) \\&\vdots&\ddots&&\vdots\\
{\textrm{head}_h} & p(a_h, l=0)& p(a_h, l=1)& \cdots& p(a_h, l=m)
\end{matrix} $$

**marginal**_probab_attn, marginal_probab_attn_from_label: $f(a_1), f(a_2), \cdots, f(a_h)$

conditional_probab_pos: $p(y=1|a_1), (y=1|a_2), \cdots, p(y=1|a_h)$

conditional_probab_neg: $p(y=0|a_1), (y=0|a_2), \cdots, p(y=0|a_h)$

conditional_probab_label: 
$$ \begin{matrix}
& label_0 & label_1 && label_m \\
{\textrm{head}_1} & p( l=0 | a_1)& p( l=1|a_1)& \cdots& p( l=m|a_1) \\&\vdots&\ddots&&\vdots\\
{\textrm{head}_h} & p( l=0 | a_h)& p( l=1 | a_h)& \cdots& p( l=m | a_h)
\end{matrix} $$

In [17]:
#papermill_description=SAVE_JOINT_CONDITIONAL_PROBAB
pkl.dump([each.cpu() for each in conditional_probab_pos], open(osp.join(kde_save_pth, formalize_pickle_file_name('pos_arc_conditional.pkl')), 'wb'))
pkl.dump([each.cpu() for each in conditional_probab_neg], open(osp.join(kde_save_pth, formalize_pickle_file_name('neg_arc_conditional.pkl')), 'wb'))
pkl.dump([each.cpu() for each in joint_probab_pos], open(osp.join(kde_save_pth, formalize_pickle_file_name('joint_pos.pkl')), 'wb'))
pkl.dump([each.cpu() for each in joint_probab_neg], open(osp.join(kde_save_pth, formalize_pickle_file_name('joint_neg.pkl')), 'wb'))
for i in range(len(conditional_probab_label)):
    for j in range(len(conditional_probab_label[i])):
        conditional_probab_label[i][j] = conditional_probab_label[i][j].cpu()
        joint_probab_label[i][j] = joint_probab_label[i][j].cpu()
pkl.dump(conditional_probab_label, open(osp.join(kde_save_pth, formalize_pickle_file_name('label_conditional.pkl')), 'wb'))
pkl.dump(joint_probab_label, open(osp.join(kde_save_pth, formalize_pickle_file_name('joint_label.pkl')), 'wb'))

In [18]:
#papermill_description=SAVE_JOINT_CONDITIONAL_PROBAB
# conditional_probab_pos: List[Tensor] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('pos_arc_conditional.pkl')), 'rb'))
# conditional_probab_neg: List[Tensor] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('neg_arc_conditional.pkl')), 'rb'))
# conditional_probab_label: List[List[Tensor]] = pkl.load(open(osp.join(kde_save_pth, formalize_pickle_file_name('label_conditional.pkl')), 'rb'))

In [19]:
#papermill_description=MOVE_JOINT_PROBAB_TO_CPU
conditional_probab_pos = [each.cpu() for each in conditional_probab_pos]
conditional_probab_neg = [each.cpu() for each in conditional_probab_neg]
for i in range(len(conditional_probab_label)):
    for j in range(len(conditional_probab_label[i])):
        conditional_probab_label[i][j] = conditional_probab_label[i][j].cpu()

#### 1.4 Estimate Mutual Information
formula
$$
\sum_{m=0}^{n_{\rm{labels}}} {
    \int_{attn_{\rm{min}}}^{attn_{\rm{max}}} {
        f(a_i, y=y_m) (\textrm{joint\textunderscore probab\textunderscore pos|neg|label})
        \log{ \frac{f(a_i,y=y_m)}{f(a_i)p(y=y_m)} }
        \rm{d} a_i
    }
}
$$

In [20]:
%load_ext autoreload
%autoreload 3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
print("label_probabilities:", get_shape(label_probabilities))
print("marginal_probab_attn_from_label:", get_shape(marginal_probab_attn_from_label))
print("joint_probab_label:", get_shape(joint_probab_label))

label_probabilities: [46(A)]
marginal_probab_attn_from_label: [1024(A), [154](T)]
joint_probab_label: [1024(A), 46(A), [154](T)]


In [None]:
#papermill_description=ESTIMATE_MI
from explaination import integral_torch_cuda, inference_by_func, estimate_mi
# from debugpy import breakpoint as breakpoint_dbp

def calculate_entropy(probabilities):
    return sum([-each * math.log2(each) for each in probabilities])

arc_probabilities = [pos_probab, neg_probab] # p(y=y_m)
joint_probabilities_arc = [*zip(joint_probab_pos, joint_probab_neg)] # [[f(a_0, y=0), f(a_0, y=1)], [f(a_1, y=0), f(a_1, y=1)], ... ]

# raise KeyboardInterrupt
pos_scaler = sum(label_probabilities[:-1])
label_probabilities_pos = (torch.tensor(label_probabilities[:-1]) / pos_scaler).tolist() # [n_labels - 1]
pos_joint_probab_label_stacked = joint_probab_label_stacked[:, :-1, :] / pos_scaler # [n_attn_features, n_labels - 1, n_x]
marginal_probab_attn_pos_from_label_stacked = torch.stack(marginal_probab_attn_pos_from_label) / pos_scaler

if not LOAD_MI:
    binary_label_mi: List[List[float]] = [] # List[num_labels, num_heads]
    binary_label_mi_pos: List[List[float]] = [] # List[num_labels, num_heads]

    # joint_probab_label_stacked = torch.stack([torch.stack(each) for each in joint_probab_label])
    # print(joint_probab_label_stacked.shape)
    # print(joint_probab_label_stacked[:, 0, :].shape)
    # print(joint_probab_label_stacked[:, torch.arange(joint_probab_label_stacked.size(1)) != 0, :].shape)
    # print(get_shape(joint_probab_label_stacked))
    mi_intermediate_dir = osp.join(kde_save_pth, f"mi_intermediate_{str(DATA_PROPORTION).replace('.', '_')}{'_transpose' if TRANSPOSE else ''}")
    os.makedirs(mi_intermediate_dir, exist_ok=True)


    for label_idx in trange(len(label_probabilities), desc='estimating for each label type'):
        label_mi_intermediate_dir = osp.join(mi_intermediate_dir, f'{label_idx:02d}')
        os.makedirs(label_mi_intermediate_dir, exist_ok=True)
        n_attn_features, n_labels, n_x = joint_probab_label_stacked.shape

        binary_label_mi.append(
            estimate_mi(
                x, [label_probabilities[label_idx], 1 - label_probabilities[label_idx]], marginal_probab_attn_from_label, 
                torch.cat([
                    joint_probab_label_stacked[:, label_idx, :].view(n_attn_features, 1, n_x), 
                    joint_probab_label_stacked[:, torch.arange(n_labels) != label_idx, :].sum(1).view(n_attn_features, 1, n_x)
                ], dim=1), # Tensor[n_heads, 2, n_x]
                KDE_DEVICE, 
                # intermediate_save_dir=label_mi_intermediate_dir, 
            )
        )

        if label_idx != (len(label_probabilities) - 1):
            binary_label_mi_pos.append(
                estimate_mi(
                    x, [label_probabilities_pos[label_idx], 1 - label_probabilities_pos[label_idx]], marginal_probab_attn_pos_from_label_stacked, 
                    torch.cat([
                        _label_pos_joint_probab_label_stacked := pos_joint_probab_label_stacked[:, label_idx, :].view(n_attn_features, 1, n_x), 
                        _others_pos_joint_probab_label_stacked := pos_joint_probab_label_stacked[:, torch.arange(n_labels - 1) != label_idx, :].sum(1).view(n_attn_features, 1, n_x)
                    ], dim=1), # Tensor[n_heads, 2, n_x]
                    KDE_DEVICE, 
                    # intermediate_save_dir=label_mi_intermediate_dir, 
                )
            )
            mape = ((_label_pos_joint_probab_label_stacked + _others_pos_joint_probab_label_stacked).squeeze(1) - marginal_probab_attn_pos_from_label_stacked).abs() \
                / marginal_probab_attn_pos_from_label_stacked
            mape.masked_fill_(mape.isnan(), 0)
            this_label_entropy = calculate_entropy([label_probabilities_pos[label_idx], 1 - label_probabilities_pos[label_idx]])
            if ((this_binary_label_mi_pos_proportion := (Tensor(binary_label_mi_pos[-1]) / this_label_entropy)) > 1).any():
                breakpoint()
            else:
                print(f"avg. pos binary mi proportion: {this_binary_label_mi_pos_proportion.mean().item()}, max: {this_binary_label_mi_pos_proportion.max().item()}" )


    arc_mi = estimate_mi(x, arc_probabilities, marginal_probab_attn, joint_probabilities_arc, KDE_DEVICE)
    label_mi = estimate_mi(x, label_probabilities, marginal_probab_attn_from_label, joint_probab_label, KDE_DEVICE)
    label_mi_pos = estimate_mi(x, label_probabilities_pos, 
        marginal_probab_attn_pos_from_label_stacked, 
        pos_joint_probab_label_stacked, KDE_DEVICE
    )

    with open(osp.join(kde_save_pth, formalize_pickle_file_name('mi.json')), 'w') as f:
        json.dump({'arc_mi': arc_mi, 'label_mi': label_mi, 'binary_label_mi': binary_label_mi, 'pos_label_mi': label_mi_pos, 'pos_binary_label_mi': binary_label_mi_pos}, f)
else:
    with open(osp.join(kde_save_pth, formalize_pickle_file_name('mi.json')), 'r') as f:
        mi_json = json.load(f)
    arc_mi, label_mi, binary_label_mi, label_mi_pos, binary_label_mi_pos = \
        mi_json['arc_mi'], mi_json['label_mi'], mi_json['binary_label_mi'], mi_json['pos_label_mi'], mi_json['pos_binary_label_mi']

NameError: name 'pos_probab' is not defined

In [35]:
print(sum(label_mi_pos) / len(label_mi_pos))
print(TRANSPOSE)

0.2717388420678617
True


In [23]:
#papermill_description=CALCULATE_VAR_ENTROPY
unmasked_label_names = [each for idx, each in enumerate(label_names) if idx not in masked_label_indices]
arc_entropy = calculate_entropy(arc_probabilities)
arc_entropy_proportions = [each / arc_entropy for each in arc_mi]
label_entropy = calculate_entropy(label_probabilities)
label_entropy_proportions = [each / label_entropy for each in label_mi]
label_binary_entropies = [calculate_entropy([each, 1 - each]) for each in label_probabilities]
label_binary_entropies_pos = [calculate_entropy([each, 1 - each]) for each in label_probabilities_pos]
max_label_binary_mis = [max(each) for each in binary_label_mi]
max_label_binary_mis_pos = [max(each) for each in binary_label_mi_pos]
max_label_binary_head_indices = [np.argmax(each) for each in binary_label_mi]
max_label_binary_head_indices_pos = [np.argmax(each) for each in binary_label_mi_pos]
label_max_binary_entropy_proportions = []
for idx, (e1, e2) in enumerate(zip(label_binary_entropies, max_label_binary_mis)):
    label_max_binary_entropy_proportions.append(e2 / e1)
    # print(f'[{idx}]{unmasked_label_names[idx]}: {e2 / e1}')
    if e2 / e1 > 1:
        print(f"[{idx}]{unmasked_label_names[idx]} MI({e2}) greater than entropy({e1}) at head {max_label_binary_head_indices[idx]}")

label_max_binary_entropy_proportions_pos = []
for idx, (e1, e2) in enumerate(zip(label_binary_entropies_pos, max_label_binary_mis_pos)):
    label_max_binary_entropy_proportions_pos.append(e2 / e1)
    # print(f'[{idx}]{unmasked_label_names[idx]}: {e2 / e1}')
    if e2 / e1 > 1:
        print(f"[{idx}]{unmasked_label_names[idx]} MI({e2}) greater than entropy({e1}) at head {max_label_binary_head_indices[idx]}")

print(f'maximum arc entropy proportion: {max(arc_entropy_proportions)}, maximum label entropy proportion: {max(label_entropy_proportions)}')
print(f'avg. of max entropy proportion for each label {sum(label_max_binary_entropy_proportions) / len(label_max_binary_entropy_proportions)}')
print(f'avg. of max pos entropy proportion for each label {sum(label_max_binary_entropy_proportions_pos) / len(label_max_binary_entropy_proportions_pos)}')
# labels without no_arc 0.21166560621232114 0.19102649549393663 0.24943423602262535 
# labels with no_arc 0.21166560621232114 0.16670989650835305 0.23851857726097217

maximum arc entropy proportion: 0.7371160386615214, maximum label entropy proportion: 0.4858623397666946
avg. of max entropy proportion for each label 0.500809713974314


In [24]:
print(len(unmasked_label_names))
print([*enumerate(unmasked_label_names)])

45
[(0, 'prep'), (1, 'det'), (2, 'nn'), (3, 'num'), (4, 'pobj'), (5, 'punct'), (6, 'poss'), (7, 'possessive'), (8, 'amod'), (9, 'nsubj'), (10, 'appos'), (11, 'dobj'), (12, 'dep'), (13, 'cc'), (14, 'conj'), (15, 'nsubjpass'), (16, 'partmod'), (17, 'auxpass'), (18, 'advmod'), (19, 'root'), (20, 'ccomp'), (21, 'aux'), (22, 'cop'), (23, 'xcomp'), (24, 'quantmod'), (25, 'tmod'), (26, 'neg'), (27, 'infmod'), (28, 'rcmod'), (29, 'pcomp'), (30, 'mark'), (31, 'advcl'), (32, 'predet'), (33, 'csubj'), (34, 'mwe'), (35, 'parataxis'), (36, 'npadvmod'), (37, 'number'), (38, 'acomp'), (39, 'prt'), (40, 'iobj'), (41, 'preconj'), (42, 'expl'), (43, 'discourse'), (44, 'csubjpass')]


#### 1.5 Reconstruction Utilities

In [50]:
#papermill_description=GET_HIGH_MI_HEADS
def get_high_mi_heads_threshold(binary_label_mi: List[List[float]], label_binary_entropies: List[float], threshold: float = 0.2, topk: int = 5):
    label_high_mi_heads = []
    for label_idx, this_binary_label_mi in enumerate(binary_label_mi):
        sorted_mi_heads_and_mis = sorted(enumerate(this_binary_label_mi), key=lambda x: x[1], reverse=True)
        # print(f"label [{label_idx}]{unmasked_label_names[label_idx]}: {[(idx, round(each, 4), round(each / label_binary_entropies[label_idx], 4)) for idx, each in  high_mi_heads_and_mis]}")
        topk_mi_heads = set([idx for idx, mi in sorted_mi_heads_and_mis[:topk]])
        high_mi_proportion_heads = set([idx for idx, mi in sorted_mi_heads_and_mis if mi / label_binary_entropies[label_idx] > threshold])
        label_high_mi_heads.append(topk_mi_heads.union(high_mi_proportion_heads))
        # for each in high_mi_heads:
    return label_high_mi_heads

def get_high_mi_heads_threshold_mix(binary_label_mi: List[List[float]], label_binary_entropies: List[float], 
    binary_label_mi_pos: List[List[float]], label_binary_entropies_pos: List[float], alpha: float = 0.5, threshold: float = 0.2, last_threshold: float = 0.2):
    label_high_mi_heads = []
    binary_label_mi, label_binary_entropies, binary_label_mi_pos, label_binary_entropies_pos = \
        torch.Tensor(binary_label_mi), torch.Tensor(label_binary_entropies), torch.Tensor(binary_label_mi_pos), torch.Tensor(label_binary_entropies_pos)
    for label_idx in range(len(binary_label_mi)):
        if label_idx != (len(binary_label_mi) - 1):
            this_mixed_mi = binary_label_mi_pos[label_idx] * (1 - alpha) + binary_label_mi[label_idx] * alpha
            this_mixed_mi_proportions = (binary_label_mi_pos[label_idx] / label_binary_entropies_pos[label_idx]) * (1 - alpha) + (binary_label_mi[label_idx] / label_binary_entropies[label_idx]) * alpha
            # breakpoint()
        else:
            this_mixed_mi = binary_label_mi[label_idx]
            this_mixed_mi_proportions = binary_label_mi[label_idx] / label_binary_entropies[label_idx]
        # breakpoint()
        # print(f'label {label_idx} avg. mi proportion {this_mixed_mi_proportions.mean().item()}')
        sorted_heads_and_proportions = sorted(enumerate(this_mixed_mi_proportions.tolist()), key=lambda x: x[1], reverse=True)
        # print(f"label [{label_idx}]{unmasked_label_names[label_idx]}: {[(idx, round(each, 4), round(each / label_binary_entropies[label_idx], 4)) for idx, each in  high_mi_heads_and_mis]}")
        topk_mi_heads = set([idx for idx, mi in sorted_heads_and_proportions[:5]])
        high_mi_proportion_heads = set([idx for idx, mi_proportion in sorted_heads_and_proportions if mi_proportion > (threshold if label_idx != (len(binary_label_mi) - 1) else last_threshold)])
        label_high_mi_heads.append(topk_mi_heads.union(high_mi_proportion_heads))
        # for each in high_mi_heads:
    return label_high_mi_heads

def get_high_mi_heads_mass(binary_label_mi: List[List[float]], mass_proportion: float = 0.2):
    label_high_mi_heads = []
    for label_idx, this_binary_label_mi in enumerate(binary_label_mi):
        current_mass = 0
        mass_max = sum(this_binary_label_mi) * mass_proportion
        sorted_mi_heads_and_mis = sorted(enumerate(this_binary_label_mi), key=lambda x: x[1], reverse=True)
        this_label_high_mi_heads = []
        for idx, mi in sorted_mi_heads_and_mis:
            current_mass += mi
            this_label_high_mi_heads.append(idx)
            if current_mass > mass_max:
                break
        label_high_mi_heads.append(this_label_high_mi_heads)
    
    return label_high_mi_heads
        
# print_listlike([(label_name, len(each)) for label_name, each in zip(unmasked_label_names, label_high_mi_heads)])
# threshold:
# ('prep', 26) ('det', 20) ('nn', 6) ('num', 5) ('pobj', 61) ('punct', 5) ('poss', 5) ('possessive', 67) ('amod', 26) ('nsubj', 9)
# ('appos', 5) ('dobj', 39) ('dep', 5) ('cc', 18) ('conj', 5) ('nsubjpass', 5) ('partmod', 5) ('auxpass', 36) ('advmod', 5) ('root', 968)
# ('ccomp', 5) ('aux', 46) ('cop', 5) ('xcomp', 5) ('quantmod', 5) ('tmod', 5) ('neg', 10) ('infmod', 5) ('rcmod', 5) ('pcomp', 5)
# ('mark', 5) ('advcl', 5) ('predet', 5) ('csubj', 5) ('mwe', 5) ('parataxis', 5) ('npadvmod', 5) ('number', 17) ('acomp', 5) ('prt', 43)
# ('iobj', 22) ('preconj', 5) ('expl', 5) ('discourse', 5) ('csubjpass', 5)
# threshold = 0.46 if TRANSPOSE else 0.2
# threshold_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, label_binary_entropies, threshold)
# t1_threshold = (Tensor(binary_label_mi) / Tensor(label_binary_entropies).unsqueeze(-1)).flatten().sort(descending=True).values[1700].item()
# t2_threshold = (Tensor(binary_label_mi_pos) / Tensor(label_binary_entropies_pos).unsqueeze(-1)).flatten().sort(descending=True).values[1700].item()
# for alpha in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
# t1_threshold = 0.133 if TRANSPOSE else 0.2
# t2_threshold = 0.133 if TRANSPOSE else 0.2
# alpha = 0.5
# mixed_threshold = alpha * t1_threshold + (1 - alpha) * t2_threshold
# # mixed_threshold = 0.35 if TRANSPOSE else 0.16
# t2_threshold = 0.6 #if TRANSPOSE else 0.16
_SEARCH_MODE = 'posneg'# if '_SEARCH_MODE' not in os.environ else os.environ['SEARCH_MODE']# 'mixed' or 'posneg'

alpha = 0.5
threshold_l, threshold_r = 0.1, 0.2
threshold_m = (threshold_l + threshold_r) / 2
target_num_heads = 2000
total_num_high_mi_heads = -1
while not abs(total_num_high_mi_heads - target_num_heads) < 5:
    if _SEARCH_MODE == 'mixed':
        threshold_high_mi_heads = get_high_mi_heads_threshold_mix(binary_label_mi, label_binary_entropies, binary_label_mi_pos, label_binary_entropies_pos, alpha=alpha, threshold=threshold_m, last_threshold=threshold_m)
    elif _SEARCH_MODE == 'posneg':
        threshold_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, label_binary_entropies, threshold_m)
    # threshold_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, threshold)
    print(f"trying on threshold {threshold_m}, alpha {alpha}, search mode {_SEARCH_MODE}")
    # print_listlike([(label_name, len(each)) for label_name, each in zip(unmasked_label_names, threshold_high_mi_heads)])
    total_num_high_mi_heads = sum([len(each) for each in threshold_high_mi_heads])
    no_arc_high_mi_heads = len(threshold_high_mi_heads[-1])
    # print(f'heads responsible for no_arc: {(no_arc_high_mi_heads := len(threshold_high_mi_heads[-1]))}')
    print(f"total {total_num_high_mi_heads} ({total_num_high_mi_heads - no_arc_high_mi_heads} + {no_arc_high_mi_heads}) heads")
    if total_num_high_mi_heads < target_num_heads:
        threshold_r = threshold_m
    else:
        threshold_l = threshold_m
    
    threshold_m = (threshold_l + threshold_r) / 2

    # mass = 0.05
    # mass_high_mi_heads = get_high_mi_heads_mass(binary_label_mi, mass)
    # print(f"high MI heads for each label type with mass {mass}")
    # print_listlike([(label_name, len(each)) for label_name, each in zip(unmasked_label_names, mass_high_mi_heads)])
    # print(f'heads responsible for no_arc: {(no_arc_high_mi_heads := len(mass_high_mi_heads[-1]))}')
    # print(f"total {(total_num_high_mi_heads := sum([len(each) for each in mass_high_mi_heads]))} heads (including {total_num_high_mi_heads - no_arc_high_mi_heads} label heads)")
if BREAK_AFTER_MI:
    raise KeyboardInterrupt

high MI heads for each label type with threshold 0.4
('prep', 582) ('det', 636) ('nn', 395) ('num', 7) ('pobj', 575) ('punct', 690) ('poss', 5) ('possessive', 61) ('amod', 322) ('nsubj', 302)
('appos', 6) ('dobj', 216) ('dep', 5) ('cc', 23) ('conj', 92) ('nsubjpass', 5) ('partmod', 5) ('auxpass', 20) ('advmod', 6) ('root', 883)
('ccomp', 30) ('aux', 227) ('cop', 8) ('xcomp', 29) ('quantmod', 8) ('tmod', 5) ('neg', 6) ('infmod', 9) ('rcmod', 13) ('pcomp', 7)
('mark', 8) ('advcl', 5) ('predet', 5) ('csubj', 5) ('mwe', 5) ('parataxis', 5) ('npadvmod', 6) ('number', 5) ('acomp', 5) ('prt', 22)
('iobj', 5) ('preconj', 6) ('expl', 16) ('discourse', 5) ('csubjpass', 5)
heads responsible for no_arc: 1024
total 6310 heads (including 5286 label heads)


KeyboardInterrupt: 

In [59]:
_PLOT = False
if _PLOT:
    from matplotlib import pyplot as plt
    available_thresholds = torch.arange(0.4, 0.7, 0.01)
    n_heads = []
    for threshold in (pbar := tqdm(available_thresholds)):
        n_heads.append(sum([len(each) for each in get_high_mi_heads_threshold(binary_label_mi, threshold.item())]))
        pbar.set_description(f"threshold: {threshold.item()}, last appended value: {n_heads[-1]}")
    plt.plot(available_thresholds, n_heads)

threshold: 0.6899999976158142, last appended value: 313: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 52.16it/s]


In [69]:
# torch.cat([torch.arange(0.4, 0.46, 0.01), torch.arange(0.46, 0.5, 0.02), torch.arange(0.5, 0.7001, 0.04)]).tolist()

tensor([0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500, 0.4600, 0.4800, 0.5000,
        0.5400, 0.5800, 0.6200, 0.6600, 0.7000])

In [None]:
#papermill_description=LOAD_BASELINES
from explaination_baselines import SimpleMLP, IndependentMLP
def mask_baseline_matrix(baseline_matrix: Tensor, masked_label_indices: List[int]):
    baseline_num_labels = baseline_matrix.shape[0]
    baseline_unmasked_label_index = torch.ones(baseline_num_labels).bool()
    for idx in range(baseline_num_labels):
        if idx in masked_label_indices:
            baseline_unmasked_label_index[idx] = False
    return baseline_matrix[baseline_unmasked_label_index]

load_baselines = False
if LOAD_BASELINES:
    mean_values_matrix = torch.load(osp.join(kde_save_pth, 'baselines', 'total_mean_values.pt'))
    pl_matrix = torch.load(osp.join(kde_save_pth, 'baselines', 'probless_matrix.pt'))
    iou_matrix = torch.load(osp.join(kde_save_pth, 'baselines', 'iou_matrix_dynamic_0.995.pt'))
    v_information_matrix = torch.load(osp.join(kde_save_pth, 'baselines', 'v_information_h_2_4_leaky_relu_0.0_0.0_balanced.pt')).transpose(-2, -1)
    variational_family_state_dict = torch.load(osp.join(kde_save_pth, 'baselines', 'variational_family_mlp_h_2_4_leaky_relu_0.0_0.0_balanced.pt'))
    variational_family = IndependentMLP(pos_features.shape[-1], [2, 4], variational_family_state_dict['final_layer'].shape[-1], 'leaky_relu')
    variational_family.load_state_dict(variational_family_state_dict)
    v_information_matrix /= 47054848 
    v_information_matrix = torch.cat([v_information_matrix[2:], v_information_matrix[0].unsqueeze(0)], dim=0)
    bounded_v_information_matrix = (torch.ones_like(v_information_matrix) * v_information_matrix[v_information_matrix < 1e12].max() - v_information_matrix).to(pl_matrix.device)
    binary_entropies_matrix = torch.Tensor(label_binary_entropies).to(v_information_matrix.device).unsqueeze(-1).expand_as(v_information_matrix)
    # v_information_matrix = binary_entropies_matrix - v_information_matrix
    elasticnet_state_dict = torch.load(osp.join(kde_save_pth, 'baselines', 'mlp_1e-05_1e-05', 'best_model.pt'), map_location='cpu')
    elasticnet_matrix = elasticnet_state_dict['mlp.weight']
    # vinf_state_dict = torch.load(osp.join(kde_save_pth, 'baselines'))
    deepmlp_state_dict = torch.load(osp.join(kde_save_pth, 'baselines', 'mlp_h_512_128_0.0_0.0.pt'), map_location='cpu')
    deepmlp = SimpleMLP(pos_features.shape[-1], [512, 128], deepmlp_state_dict['mlp.2.weight'].shape[0])
    deepmlp.load_state_dict(deepmlp_state_dict)
    # elasticnet = SimpleMLP(pos_features.shape[-1], [], elasticnet_matrix.shape[0],).to(torch.bfloat16)
    # elasticnet.load_state_dict(elasticnet_state_dict)
    # vinf = IndependentMLP(pos_features.shape[-1], [], len(unmasked_label_names) + 1, 'leaky_relu')

    pl_matrix = mask_baseline_matrix(pl_matrix, masked_label_indices)
    iou_matrix = mask_baseline_matrix(iou_matrix, masked_label_indices)
    elasticnet_matrix = torch.cat([elasticnet_matrix[2:], elasticnet_matrix[0].unsqueeze(0)], dim=0)

In [None]:
#papermill_description=GET_BASELINE_HEADS
# print(Tensor(binary_label_mi).shape)
# print(pl_matrix.shape)
def get_num_heads(high_corr_head_list):
    return sum([len(each) for each in high_corr_head_list])

def select_baseline_heads_by_headnum(baseline_matrix: Tensor, high_mi_heads: List[List[int]]):
    baseline_high_mi_heads = []
    for label_idx, each in enumerate(high_mi_heads):
        this_num_heads = len(each)
        index_sorted_label_baseline_matrix = sorted([*enumerate(baseline_matrix[label_idx].tolist())], reverse=True, key=lambda x: x[1])
        baseline_high_mi_heads.append([each[0] for each in index_sorted_label_baseline_matrix[:this_num_heads]])
    return baseline_high_mi_heads

def select_baseline_heads_by_threshold(baseline_matrix: Tensor, threshold: float, force_positive: bool = True):
    # total_headnum = get_num_heads(high_mi_heads)
    # current_threshold = baseline_matrix.flatten().sort(descending=True)[total_headnum].item()
    # for _ in range(50): # maximum 50 iters
    target_heads = []
    for label_idx, this_baseline_corrs in enumerate(baseline_matrix):
        # label_high_mi_heads.append([])
        sorted_heads_and_corrs = sorted(enumerate(this_baseline_corrs), key=lambda x: x[1], reverse=True)
        # print(f"label [{label_idx}]{unmasked_label_names[label_idx]}: {[(idx, round(each, 4), round(each / label_binary_entropies[label_idx], 4)) for idx, each in  high_mi_heads_and_mis]}")
        if force_positive:
            topk_mi_heads = set([idx for idx, corr in sorted_heads_and_corrs[:5] if corr > 0])
            high_mi_proportion_heads = set([idx for idx, corr in sorted_heads_and_corrs if corr > threshold and corr > 0])
        else:
            topk_mi_heads = set([idx for idx, corr in sorted_heads_and_corrs[:5]])
            high_mi_proportion_heads = set([idx for idx, corr in sorted_heads_and_corrs if corr > threshold])
        
        target_heads.append(topk_mi_heads.union(high_mi_proportion_heads))
        # if get_num_heads(target_heads) < total_headnum:
    return target_heads

def get_baseline_grid(baseline_matrix: Tensor, head_proportion: float, num_grid_sets: int):
    assert num_grid_sets % 2 == 1, f'only odd number of grid sets are supported, got {num_grid_sets}'
    sorted_corr_values = sorted(baseline_matrix.flatten().tolist(), reverse=True)
    total_num_heads = baseline_matrix.numel()
    middle_threshold = sorted_corr_values[int(head_proportion * total_num_heads)]
    loose_threshold = sorted_corr_values[int(3 * head_proportion * total_num_heads)]
    max_value = max(sorted_corr_values)
    num_grids_per_side = (num_grid_sets - 1) // 2

    thresholds = torch.linspace(loose_threshold, middle_threshold, num_grids_per_side + 1).tolist()[:-1] \
        + [middle_threshold] + torch.linspace(middle_threshold, max_value, num_grids_per_side + 2).tolist()[1:-1]
    head_sets, num_heads = [], []
    for threshold in thresholds:
        head_sets.append(select_baseline_heads_by_threshold(baseline_matrix, threshold))
        num_heads.append(get_num_heads(head_sets[-1]))

    return thresholds, head_sets, num_heads
    # sets, thresholds, numheads = [initial_set], [start_threshold], [get_num_heads(initial_set)]
    # return [*zip(sets, thresholds, numheads)]
# print(get_num_heads(pl_high_mi_heads))
# get baselines grid test grids
matrix_to_try = bounded_v_information_matrix
_GET_THRESHOLD = False
if LOAD_BASELINES and _GET_THRESHOLD:
    threshold_l, threshold_r = -100, 100
    threshold_m = (threshold_l + threshold_r) / 2
    target_num_heads = 2000
    total_num_high_mi_heads = -1
    while not abs(total_num_high_mi_heads - target_num_heads) < 5:
        baseline_threshold_high_mi_heads = select_baseline_heads_by_threshold(matrix_to_try, threshold_m, force_positive=False)
        total_num_high_mi_heads = get_num_heads(baseline_threshold_high_mi_heads)
        print(f"trying on threshold {threshold_m}")
        print(f"total {total_num_high_mi_heads} heads (including {len(baseline_threshold_high_mi_heads[-1])} neg heads)")
        if total_num_high_mi_heads < target_num_heads:
            threshold_r = threshold_m
        else:
            threshold_l = threshold_m
        
        threshold_m = (threshold_l + threshold_r) / 2

    # num_high_mi_heads = get_num_heads(threshold_high_mi_heads)
    # start_head_proportion = num_high_mi_heads / pl_matrix.numel()
    # print(num_high_mi_heads, start_head_proportion)
    # pl_thresholds, pl_head_sets, pl_num_heads = get_baseline_grid(pl_matrix, start_head_proportion, 5)
    # iou_thresholds, iou_head_sets, iou_num_heads = get_baseline_grid(iou_matrix, start_head_proportion, 5)
    # elasticnet_thresholds, elasticnet_head_sets, elasticnet_num_heads = get_baseline_grid(elasticnet_matrix, start_head_proportion, 5)

# pl_high_mi_heads = select_baseline_heads_by_mi_heads(pl_matrix, threshold_high_mi_heads)
# iou_high_mi_heads = select_baseline_heads_by_mi_heads(iou_matrix, threshold_high_mi_heads)
# elasticnet_high_mi_heads = select_baseline_heads_by_mi_heads(elasticnet_matrix, threshold_high_mi_heads)
# random_high_mi_heads = select_baseline_heads_by_mi_heads(torch.rand_like(pl_matrix), threshold_high_mi_heads)
# pl_high_mi_heads = []
# for label_idx, each in enumerate(threshold_high_mi_heads):
#     this_num_heads = len(each)
#     index_sorted_label_pl_matrix = sorted([*enumerate(pl_matrix[label_idx].tolist())], reverse=True, key=lambda x: x[1])
#     pl_high_mi_heads.append([each[0] for each in index_sorted_label_pl_matrix[:this_num_heads]])

In [None]:
#papermill_description=INFERENCE_BY_ATTN_FEATURES
%autoreload 2
import pdb
from explaination import batched_inference_by_func
from explaination import move_to_device
# MODEL_DEVICE = model_llama.device
def get_least_used_gpu():
    result = GPUStatCollection.new_query()
    spare = []
    for idx, each in enumerate(result):
        spare += [each.memory_total - each.memory_used]
    
    return spare.index(max(spare))

def move_to_device(device: Union[int, str], *tensors):
    return [(each.to(device) if each is not None else each) for each in tensors]


MODEL_DEVICE = f'cuda:{get_least_used_gpu()}'

val_data_llama = dep_parser_llama.load_data('./data/dev.conllu')
val_data_llama = sorted(val_data_llama, key=lambda instance: len(instance.head), reverse=True)

dl_val_llama = DataLoader(val_data_llama, collate_fn=lambda x: dep_parser_llama.feature2input(MODEL_DEVICE, dep_parser_llama.convert_examples_to_features(x)), batch_size=1)

@torch.no_grad
def infer_by_attn_features(
    tok: Union[LlamaTokenizer, BertTokenizer], model: Union[LlamaModel, BertModel], dataloader: DataLoader, labelmap_convert: Dict[int, int],
    x: Tensor,
    conditional_probabs: Union[List[Tuple[Tensor, Tensor]], Tensor], 
    mutual_informations: List[float],
    label_conditional_probabs: Union[List[List[Tensor]], Tensor] = None,
    label_mutual_informations: List[List[float]] = None,
    num_samples: int = 100, with_labels: bool = False, silent: bool = True, use_weighted_mi: bool = True, high_mi_heads: List[List[int]] = None, 
    no_pbar: bool = False, transpose: bool = False,
    include_neg_possibilities: bool = False,
    infer_method: Literal['kde', 'score', 'mlp', 'variational_family'] = 'kde',
    infer_model: torch.nn.Module = None,
    
) -> Tuple[
        Tuple[List[Tensor], List[Tensor]], Dict[str, List[Tensor]]
    ]:
    """
    Args:
        tok: tokenizer
        model: model
        dataloader: DataLoader
        labelmap_convert (Dict[int, int]): labelmap converting the labelmap of `label_conditional_probabs` to the labelmap of the dataset
        x (Tensor[n_x]): the domain of defination of conditional_probabs
        conditional_probabs: `List[n_heads, 2] of Tensor[n_x]` or `Tensor[n_heads, 2, n_x]`:
            [head1: [pos_conditional_probab_on_head_1, neg_conditional_probab_on_head_1],
            ...
            [headh: [pos_conditional_probab_on_head_h, neg_conditional_probab_on_head_h],
        mutual_informations: `List[n_heads]` [mi_1, mi_2, ...]
        label_conditional_probabs: `List[n_heads, n_labels] of Tensor[n_x]` or `Tensor[n_heads, n_labels, n_x]`: 
        [
            head1: [label1_conditional_probab_on_head_1, label2_conditional_probab_on_head_1, ...], 
            ...,
            headh: [label1_conditional_probab_on_head_h, label2_conditional_probab_on_head_h, ...]
        ]
        label_mutual_informations: `List[n_labels, n_heads]`: [
        label1_mi_for_all_heads, 
        label2_mi_for_all_heads, 
        ...]
        num_samples: number of samples to extract
        with_labels: whether to extract label features
        silent: whether to print debug info
        mi_threshold: threshold for mutual information to consider, used when inferring by arcs
        high_mi_heads: `List[n_labels]` of List[(int)]
            heads with high mutual information to consider, used when inferring by labels (`with_labels=True`)
    returns 
    """
    tok.pad_token = tok.eos_token
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    num_heads = len(mutual_informations)
    # assert `conditional_probabs` have shape [n_heads, 2, n_x]
    if isinstance(conditional_probabs, Tensor):
        assert len(conditional_probabs.shape) == 3 and conditional_probabs.shape[1] == 2, f"conditional_probabs must be in shape [n_heads, 2, n_x]"
    else:
        assert len(conditional_probabs) == num_heads, f"conditional_probabs must have the same n_heads (=len(mutual_informations({num_heads})))"
        assert all([len(each) == 2 for each in conditional_probabs]), "conditional_probabs must be in shape [n_heads, 2, n_x]"
        for head_idx, (pos_conditional_probab, neg_conditional_probab) in enumerate(conditional_probabs):
            assert len(pos_conditional_probab.shape) == 1 and len(neg_conditional_probab.shape) == 1, f"pos_conditional_probab and neg_conditional_probab must be in shape [n_x]"
        # stack conditional_probabs to Tensor[n_heads, 2, n_x]
        conditional_probabs = torch.stack([torch.stack([pos_conditional_probab, neg_conditional_probab]) for pos_conditional_probab, neg_conditional_probab in conditional_probabs])
    # assert `label_conditional_probabs` has shape [n_heads, n_labels, n_x]
    if with_labels:
        assert label_mutual_informations is not None and label_conditional_probabs is not None, \
            "label_mutual_informations and label_conditional_probabs must be provided if with_labels is True"
        num_labels = len(label_mutual_informations)
        if isinstance(label_conditional_probabs, Tensor):
            assert len(label_conditional_probabs.shape) == 3 and label_conditional_probabs.shape[1] == num_labels, \
                f"label_conditional_probabs must be in shape [n_heads, {num_labels}(=len(label_mutual_informations), n_x]"
        else:
            assert len(label_conditional_probabs) == len(mutual_informations), "label_conditional_probabs must have the same length as mutual_informations"
            for head_idx, label_conditional_probab in enumerate(label_conditional_probabs):
                assert len(label_conditional_probab) == len(label_mutual_informations), f"each entry of label_conditional_probabs must have the same length as num_labels"
                for label_idx, label_probab in enumerate(label_conditional_probab):
                    assert len(label_probab.shape) == 1, f"label_probab must be in shape [n_x]"
            # stack label_conditional_probabs to Tensor[n_heads, num_labels, n_x]
            label_conditional_probabs = torch.stack([torch.stack(label_conditional_probab) for label_conditional_probab in label_conditional_probabs])
    
    if high_mi_heads is not None:
        if with_labels:
            num_labels = len(label_mutual_informations)
            assert len(high_mi_heads) == num_labels, "high_mi_heads must have the same length as label_mutual_informations"
            # print(f'high_mi_heads: {high_mi_heads} will supress mi_threshold ({mi_threshold})')
    
    # if infer_method == 'model' and not include_neg_possibilities:
    #     raise AssertionError("inference method of `model` requires include_neg_possibilities to be True")
    
    labelmap_reverse_convert = {v: k for k, v in labelmap_convert.items()} # convert the labelmap of the dataset to the labelmap of `label_conditional_probabs`

    if num_samples == -1:
        num_samples = len(dataloader)
    n_corrects, correct_labels, gt_correct_labels, totals = 0, 0, 0, 0
    label_corrects, label_totals = defaultdict(int), defaultdict(int)
    matrix_corrects, matrix_totals = 0, 0

    high_mi_heads_flatten = []
    label_indicator = []
    selected_conditional_probabs = []
    weights = [0] * (num_labels if include_neg_possibilities else (num_labels - 1))
    weights_flatten = []
    label_conditional_probabs_T = label_conditional_probabs.transpose(0, 1)
    for label_idx, label_high_mi_heads in enumerate(high_mi_heads if include_neg_possibilities else high_mi_heads[:-1]):
        high_mi_heads_flatten.extend(list(label_high_mi_heads))
        label_indicator.extend([label_idx] * len(label_high_mi_heads))
        weights[label_idx] += sum([label_mutual_informations[label_idx][head_idx] for head_idx in label_high_mi_heads])
        weights_flatten.extend([label_mutual_informations[label_idx][head_idx] for head_idx in label_high_mi_heads])
        selected_conditional_probabs.append(label_conditional_probabs_T[label_idx][torch.tensor(list(label_high_mi_heads))])

    high_mi_heads_flatten = torch.LongTensor(high_mi_heads_flatten) # [sum(label1_hi_mi_head, label2_hi_mi_head, ...)]
    label_indicator = torch.LongTensor(label_indicator) # [0, 0, 0, 0, 0, ... 1, 1, 1, 1, 1, ...]
    weights_flatten = torch.Tensor(weights_flatten) # [sum(label1_hi_mi_head, label2_hi_mi_head, ...)]
    
    selected_conditional_probabs = torch.cat(selected_conditional_probabs, dim=0).to(MODEL_DEVICE) # Tensor[sum(label1_hi_mi_head, label2_hi_mi_head, ...), n_x]

    num_high_mi_heads = selected_conditional_probabs.shape[0]
    expanded_x = x.unsqueeze(0).expand(num_high_mi_heads, -1).to(MODEL_DEVICE) # 


    model_cpu = deepcopy(model.cpu())
    model_cuda = model.to(MODEL_DEVICE)
    arcs_to_save, labels_to_save, results_to_save, probab_labels_to_save, attn_scores_to_save  = [], [], [], [], []

    for data_idx, (input_ids, input_attention_mask, label_mask, eval_mask, \
        arcs, rels, word_ids, pos_ids, ngram_ids, \
        ngram_positions, segment_ids, valid_ids) in enumerate(pbar := (dataloader if no_pbar else tqdm(dataloader, total=num_samples, desc='extracting attentions', mininterval=2))):

        echo = not silent and data_idx == 0
        if echo:
            print(input_ids.shape, label_mask.shape, valid_ids.shape, arcs.shape, rels.shape)
        if input_ids.shape[-1] > (128 if infer_method != 'variational_family' else 100): # move long samples to cpu to save GPU memory
            model = model_cpu
            input_ids, input_attention_mask, label_mask, eval_mask, arcs, rels, word_ids, pos_ids, ngram_ids, ngram_positions, segment_ids, valid_ids = \
                move_to_device('cpu', input_ids, input_attention_mask, label_mask, eval_mask, arcs, rels, word_ids, pos_ids, ngram_ids, ngram_positions, segment_ids, valid_ids)
        elif model.device != MODEL_DEVICE:
            model = model_cuda

        batch_size, S = input_ids.shape
        sequence_lengths = input_attention_mask.sum(1).tolist()
        text_lengths = [valid_ids[i, :sequence_lengths[i]].sum().item() for i in range(batch_size)]
        w2s = [] # the idx at position i is whole-word i's last subword's idx
        for sample_idx, valid_id in enumerate(valid_ids):
            w2s.append([])
            for subword_idx, each in enumerate(valid_id.tolist()):
                if each == 1:
                    w2s[-1].append(subword_idx)

        arc_adj_matrix = torch.zeros(batch_size, S, S).to(model.device) # 1 for `have arc`, 0 for `no arc` (marking at whole-word's last subword)
        label_adj_matrix = torch.zeros(batch_size, S, S).to(model.device) # arc[i][j]'s relation type (marking at whole-word's last subword)
        for sample_idx in range(batch_size):
            for word_idx, head_idx in enumerate(arcs[sample_idx]):
                if head_idx != -1:
                    arc_adj_matrix[sample_idx][w2s[sample_idx][word_idx]][w2s[sample_idx][head_idx]] = 1
                    label_adj_matrix[sample_idx][w2s[sample_idx][word_idx]][w2s[sample_idx][head_idx]] = rels[sample_idx][word_idx]

        if isinstance(model, LlamaModel):
            res = model.forward(
                input_ids=input_ids, attention_mask=input_attention_mask,
                output_hidden_states=True,
                output_attention_queries=True
            )
            key_values, queries = res.past_key_values, res.queries
            # kv: [num_layers, 2(k and v), batch_size, num_heads, sequence_length, head_dim], q: [num_layers, batch_size, num_heads, sequence_length, head_dim]
            if echo:
                print('past_key_values shape:', get_shape(res.past_key_values))
                print('queries shape:', get_shape(res.queries))

            attn_scores = () # [num_layers, 1(batch_size), num_heads, seq_len, seq_len]
            for layer_idx in range(len(model.layers)):
                k = key_values[layer_idx][0] # [batch_size, num_heads, sequence_length, head_dim]
                q = queries[layer_idx] # [batch_size, num_heads, sequence_length, head_dim]
                if q.shape[-3] != k.shape[-3]: # num_q_heads == num_k_heads * n_groups
                    n_groups = q.shape[-3] / k.shape[-3]
                    assert n_groups == int(n_groups)
                    n_groups = int(n_groups)
                else:
                    n_groups = 1
                
                k = torch.repeat_interleave(k, n_groups, -3) 
                # [batch_size, num_heads, sequence_length, head_dim] -> [batch_size, num_heads * num_groups, sequence_length, head_dim]
                this_attn_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(res.last_hidden_state.shape[-1]) # scaled dot-product attention
                # [batch_size, num_heads, sequence_length, sequence_length]
                if infer_method == 'score':
                    this_attn_score = this_attn_score.softmax(dim=-2 if transpose else -1)
                attn_scores += (this_attn_score.transpose(-2, -1) if transpose else this_attn_score,)
                # if infer_method == 'score':
                #     attn_scores[-1] = torch.nn.functional.softmax(attn_scores[-1], dim=-1)
                # breakpoint()
            attn_scores = torch.cat(attn_scores, dim=1) # [batch_size, num_heads * num_layers, sequence_length, sequence_length]

        else:
            res = model.forward(
                input_ids=input_ids, attention_mask=input_attention_mask,
                output_raw_attentions=True
            )
            attn_scores = torch.stack(res.attentions) # [num_layers, batch_size, num_heads, seq_len, seq_len]
            attn_scores = attn_scores.transpose(0, 1) # [batch_size, num_layers, num_heads, seq_len, seq_len]
            B, L, H, S, _ = attn_scores.shape
            attn_scores = attn_scores.reshape(B, L * H, S, S)
            if transpose:
                attn_scores = attn_scores.transpose(-2, -1)
        
        if echo:
            print('attn_scores shape:', get_shape(attn_scores))
        # mask attn scores
        for sample_idx in range(batch_size):
            attn_scores[sample_idx, :, sequence_lengths[sample_idx]:, sequence_lengths[sample_idx]:] = -torch.inf
        # if echo:
        #     print('attn_score after squeezing:', get_shape(attn_scores))
        # attn_features = attn_scores.permute(1, 2, 0) # [sequence_length, sequence_length, num_heads * num_layers]
        # attn_features = attn_scores
        # probab = torch.zeros_like(attn_scores).to(model.device)

        # if with_labels:
        #     probab_labels = torch.zeros(batch_size, S, S, num_labels - 1).to(model.device)

        B, H, S, _ = attn_scores.shape
        if infer_method in ['kde', 'score', 'variational_family']:
            if infer_method == 'variational_family':
                infer_model = infer_model.to(model.device)
                assert B == 1, f'only support batch_size == 1, got `{B}`'
                v_probab = infer_model(attn_scores.view(B * H, S * S).transpose(-2, -1).float()).squeeze(-2).permute(1, 0, 2).view(H, S, S, -1).sigmoid() # H, S, S, n_labels
                v_probab = torch.cat([v_probab[:, :, :, 2:], v_probab[:, :, :, 0].unsqueeze(-1)], dim=-1)
                # breakpoint()
                selected_v_all_probab = torch.index_select(v_probab, 0, high_mi_heads_flatten.to(model.device)).view(num_high_mi_heads, S * S, -1)
                selected_v_probab = torch.gather(selected_v_all_probab, -1, label_indicator.to(model.device).unsqueeze(-1).unsqueeze(-1).expand(num_high_mi_heads, S * S, 1)).squeeze(-1).unsqueeze(0)
            else:
                label_high_mi_head_attn_scores = torch.index_select(attn_scores, 1, high_mi_heads_flatten.to(model.device),).view(batch_size, num_high_mi_heads, S * S)

            if infer_method == 'kde':
                batched_conditional_probabs = torch.log2(batched_inference_by_func(
                    label_high_mi_head_attn_scores.to(model.device), 
                    expanded_x.to(model.device).unsqueeze(0).expand(batch_size, *expanded_x.shape), 
                    selected_conditional_probabs.to(model.device).unsqueeze(0).expand(batch_size, *selected_conditional_probabs.shape)))
            elif infer_method == 'score': # replace the original conditional probab with 
                batched_conditional_probabs = torch.log2(label_high_mi_head_attn_scores.float())
            elif infer_method == 'variational_family':
                batched_conditional_probabs = torch.log2(selected_v_probab.float())

        
            batched_conditional_probabs *= weights_flatten.to(model.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, S * S)

            probab_labels = torch.zeros(batch_size, num_labels if include_neg_possibilities else (num_labels - 1), S * S).to(model.device)

            # print(batched_conditional_probabs.shape, probab_labels.shape, max(label_indicator))
            probab_labels.scatter_reduce_(1, label_indicator.to(model.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, S * S), batched_conditional_probabs.to(model.device), 'sum')
            probab_labels = probab_labels.transpose(-1, -2).reshape(batch_size, S, S, num_labels if include_neg_possibilities else (num_labels - 1))
        elif infer_method == 'mlp':
            assert batch_size == 1, f'only support batch_size == 1, got `{batch_size}`'
            infer_model = infer_model.to(model.device).to(model.dtype)
            probab_labels = infer_model(attn_scores.view(H, S * S).transpose(-2, -1)).view(1, S, S, -1)
            probab_labels = torch.cat([probab_labels[:, :, :, 2:], probab_labels[:, :, :, 0].unsqueeze(-1)], dim=-1).softmax(dim=-1)
        else:
            raise NotImplementedError(f'`infer_method` {infer_method} is not implemented')

        # elif infer_method == 'score':
        #     probab_labels = torch.zeros(batch_size, num_labels if include_neg_possibilities else (num_labels - 1), S * S).to(model.device)
        #     label_high_mi_head_attn_scores = torch.index_select(attn_scores, 1, high_mi_heads_flatten.to(model.device),).view(batch_size, num_high_mi_heads, S * S)
        #     probab_labels.scatter_reduce_(1, label_indicator.to(model.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, S * S), label_high_mi_head_attn_scores.to(model.device), 'sum')
        #     probab_labels = probab_labels.transpose(-1, -2).reshape(batch_size, S, S, num_labels if include_neg_possibilities else (num_labels - 1))
        
        # elif infer_method == 'model':
        #     if next(infer_model.parameters()).device != attn_scores.device:
        #         infer_model = infer_model.to(attn_scores.device)
        #     probab_labels = infer_model(attn_scores.view(batch_size, S * S, -1)).view(batch_size, S, S, num_labels + 1)
        #     probab_labels = torch.cat([probab_labels[:, :, :, 2:], probab_labels[:, :, :, 0].unsqueeze(-1)], dim=-1)
        #     probab_labels = torch.nn.functional.softmax(probab_labels, dim=-1)

        # for feature_idx in range(attn_scores.shape[1]):
        #     if with_labels:
        #         for label_idx, label_probab in enumerate(label_conditional_probabs[feature_idx]):
        #             if label_idx == num_labels - 1: # the last label stands for no_arc
        #                 continue
        #             if feature_idx not in high_mi_heads[label_idx]:
        #                 continue
        #             weight = label_mutual_informations[label_idx][feature_idx] if use_weighted_mi else 1.0  # Weight by MI if specified
        #             for sample_idx in range(batch_size):
        #                 probab_labels[sample_idx, :, :, label_idx] += weight * torch.log2(
        #                     inference_by_func(
        #                         attn_scores[sample_idx, feature_idx].flatten().to(model.device), 
        #                         x.to(model.device), label_probab.flatten().to(model.device)
        #                     ).view(S, S), 
        #                 )
        #             weights[label_idx] += weight
        #             # neg_probab_labels[:, :, label_idx] += weight * torch.log2(
        #             #     inference_by_func(
        #             #         attn_scores[feature_idx].flatten().to(MODEL_DEVICE), x.to(MODEL_DEVICE), label_probab.flatten().to(MODEL_DEVICE)
        #             #     ).view(S, S)
        #             # )
        #         # normalize probab_labels
        #     else: # only condition on arcs
        #         raise NotImplementedError
        #         if mutual_informations[feature_idx] < mi_threshold:
        #             continue
        #         probab += mutual_informations[feature_idx] * torch.log2(
        #             inference_by_func(
        #                 attn_scores[feature_idx].flatten().to(model.device), x.to(model.device), conditional_probabs[feature_idx][0].flatten().to(model.device)
        #             ).view(*attn_scores.shape[1:])
        #         )
        # end for feature_idx
        label_mask[:, 0] = 0
        # print(valid_ids.device)
        if with_labels:
            for label_idx in range(num_labels if include_neg_possibilities else (num_labels - 1)):
                if infer_method != 'mlp':
                    probab_labels[:, :, :, label_idx] /= weights[label_idx]
            # probab_labels.exp_()
            # BEGIN: multiply with other negative probabilities
            if include_neg_possibilities and infer_method != 'mlp':
                probab_labels_neg = torch.log2(1 - probab_labels.exp())
                probab_labels_neg_summation = probab_labels_neg.sum(dim=-1, keepdim=True).expand_as(probab_labels_neg).clone()
                probab_labels_neg_summation -= probab_labels_neg
                probab_labels += probab_labels_neg_summation
            # END: multiply with other negative probabilities
            
            probab_labels_pooled, labels_with_max_probab = (probab_labels[:, :, :, :-1] if include_neg_possibilities else probab_labels).max(dim=-1, )
            # probab_labels_pooled.exp_() # [batch_size, n_tokens, n_tokens]
            for sample_idx in range(batch_size):
                length_mask = (torch.arange(S, device=input_ids.device) < sequence_lengths[sample_idx])
                this_probab_labels = probab_labels[sample_idx][(valid_ids[sample_idx] == 1) & length_mask][:, (valid_ids[sample_idx] == 1) & length_mask] # [n_words, n_words, n_labels - 1]
                this_probab_labels_pooled = probab_labels_pooled[sample_idx][(valid_ids[sample_idx] == 1) & length_mask][:, (valid_ids[sample_idx] == 1) & length_mask] # [n_words, n_words]
                # [n_tokens, n_tokens] -> [n_words, n_words]
                this_labels_with_max_probab = labels_with_max_probab[sample_idx][(valid_ids[sample_idx] == 1) & length_mask][:, (valid_ids[sample_idx] == 1) & length_mask]

                result = eisner(this_probab_labels_pooled.unsqueeze(0), label_mask[sample_idx].unsqueeze(0)) # [1, n_words]
                this_arcs = arcs[sample_idx][:text_lengths[sample_idx]]
                arcs_to_save.append(this_arcs.cpu())
                results_to_save.append(result[0].cpu())
                correct_mask = (result[0] == this_arcs)[eval_mask[sample_idx][:text_lengths[sample_idx]]]
                n_corrects += correct_mask.sum().item()
                totals += correct_mask.shape[0]
                # totals += eval_mask[sample_idx][:text_lengths[sample_idx]].sum()
                result_labels = Tensor([labelmap_convert[this_labels_with_max_probab[token_idx][head_idx].item()] \
                    for token_idx, head_idx in enumerate(result[0].tolist())]).to(result.device,) # [S]
                gt_labels = rels[sample_idx, :text_lengths[sample_idx]] # [S]
                correct_labels += (result_labels == gt_labels)[eval_mask[sample_idx][:text_lengths[sample_idx]]][correct_mask].sum().item()
                gt_labels_converted = [(labelmap_reverse_convert[label] if label in labelmap_reverse_convert else -1) for label in gt_labels.tolist()] # gt_labels' corresponding unmasked label indices
                labels_to_save.append(gt_labels_converted)
                gt_arcs_predicted_labels = Tensor([labelmap_convert[this_labels_with_max_probab[token_idx][head_idx].item()] \
                    for token_idx, head_idx in enumerate(this_arcs.tolist())]).to(result.device,) # [S]
                gt_arcs_probabilities = torch.stack([this_probab_labels[token_idx][head_idx] \
                    for token_idx, head_idx in enumerate(this_arcs.tolist())]).to(result.device,)
                probab_labels_to_save.append(this_probab_labels.cpu())
                gt_correct_label_mask = (gt_arcs_predicted_labels == gt_labels)
                gt_correct_labels += gt_correct_label_mask[eval_mask[sample_idx, :text_lengths[sample_idx]]].sum().item()
                for label_idx in range(gt_labels.max().item() + 1):
                    label_corrects[label_idx] += gt_correct_label_mask[gt_labels == label_idx].sum().item()
                    label_totals[label_idx] += (gt_labels == label_idx).sum().item()
                
                # matrix level (including negative samples) ! WARNING, those code are not guaranteed to run when `batch_size > 1`
                gt_matrix = torch.ones_like(this_labels_with_max_probab) * (label_conditional_probabs.shape[1] - 1)

                for idx in range(len(gt_labels_converted)):
                    if this_arcs[idx] != -1:
                        gt_matrix[idx][this_arcs[idx]] = gt_labels_converted[idx]
                gt_matrix_mask = (gt_matrix != -1)
                matrix_correct_mask = (gt_matrix == this_labels_with_max_probab)[gt_matrix_mask]
                matrix_corrects += matrix_correct_mask.sum().item()
                matrix_totals += matrix_correct_mask.numel()
                # print(result)
        else:
            raise NotImplementedError
            result = eisner(probab[valid_ids[0] == 1][:, valid_ids[0] == 1].unsqueeze(0), label_mask)
        # print(result, result.shape)
        # print(arcs, arcs.shape)
        # print(eval_mask, eval_mask.shape)
        if data_idx >= num_samples - 1:
            break

        current_result = {"metrics": {"UAS": round(n_corrects / totals * 100, 2), "LAS": round(correct_labels / totals * 100, 2), "GTLAS": round(gt_correct_labels / totals * 100, 2), "MLAS": round(matrix_corrects / matrix_totals * 100, 2)}, 
                          "label_corrects": label_corrects, "label_totals": label_totals, "arcs": arcs_to_save, "labels": labels_to_save, "results": results_to_save, "probab_labels": probab_labels_to_save}
        if not no_pbar:
            pbar.set_postfix({**current_result['metrics'], 'len': input_ids.shape[-1]})

    return current_result

# if "CUDA_LAUNCH_BLOCKING" in os.environ:
#     os.environ.pop("CUDA_LAUNCH_BLOCKING")
torch.cuda.empty_cache()
# results = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False)
# %lprun -f infer_by_attn_features 
_INFER_MODE = 'control'
_BASELINE_NAME = 'v_information'
print(f'{_INFER_MODE =}')
os.makedirs('./inference_results', exist_ok=True)
if _INFER_MODE == 'mixed_grid':
    # for mixed_threshold in [0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5] if TRANSPOSE else [0.1, 0.13, 0.16, ]:
    # mixed_threshold = 0.133 if TRANSPOSE else 0.15
    # last_threshold = 0.133 if TRANSPOSE else 0.15
    threshold = 0.133 if TRANSPOSE else 0.1432
    # alpha = 0.5
    results = []
    for alpha in [0.5]:#[0.8, 0.7]:#, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]:
        print(f'doing experiment on {alpha=}, {threshold=}')
        threshold_high_mi_heads = get_high_mi_heads_threshold_mix(binary_label_mi, label_binary_entropies, binary_label_mi_pos, label_binary_entropies_pos, alpha=alpha, threshold=threshold, last_threshold=threshold)
        print(f'total {sum([len(each) for each in threshold_high_mi_heads])} heads (including {len(threshold_high_mi_heads[-1])} neg heads)')
        inference_result = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, (torch.Tensor(binary_label_mi[:-1]) * alpha + torch.Tensor(binary_label_mi_pos) * (1 - alpha)).tolist() + [binary_label_mi[-1]], -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False, transpose=TRANSPOSE, include_neg_possibilities=True)
        with open(f"./inference_results/mixed_alpha_{alpha}_t1_{threshold}_t2_{threshold}{'_transpose' if TRANSPOSE else ''}.pkl", 'wb') as f:
            pkl.dump({k: v for k, v in inference_result.items() if k in ['metrics', 'label_corrects', 'label_totals']}, f)
        labelmap_pth = f"./inference_results/sample_labelmap.json"
        if not osp.exists(labelmap_pth):
            with open(labelmap_pth, 'w') as f:
                json.dump(dep_parser_llama.labelmap, f)
        results.append(inference_result)
elif _INFER_MODE == 'control':
    n_labels_from_mi, n_heads_from_mi = len(binary_label_mi), len(binary_label_mi[0])
    control_heads = [rd.choices(range(n_heads_from_mi), k=2000 // n_labels_from_mi)]
    inference_result = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, torch.ones_like(torch.Tensor(binary_label_mi)), -1, with_labels=True, high_mi_heads=control_heads, no_pbar=False, transpose=TRANSPOSE, include_neg_possibilities=True)
elif _INFER_MODE == 'posneg_single':
    threshold_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, label_binary_entropies, 0.15 if TRANSPOSE else 0.1734)
    print(f'total {sum([len(each) for each in threshold_high_mi_heads])} heads (including {len(threshold_high_mi_heads[-1])} neg heads)')
    inference_result = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False, transpose=TRANSPOSE, include_neg_possibilities=True)
# elif _INFER_MODE == 'posneg_threshold_grid':
#     metrics = []
#     threshold_grid = torch.cat([torch.arange(0.4, 0.46, 0.01), torch.arange(0.46, 0.5, 0.02), torch.arange(0.5, 0.7001, 0.04)]).tolist() if TRANSPOSE else [0.2]
#     for tidx, threshold in enumerate(threshold_grid):
#         for include_neg_possibilities in [False, True]:
#             print(f'testing on threshold: {threshold}, transpose: {TRANSPOSE}, include neg: {include_neg_possibilities} ({tidx * 2 + int(include_neg_possibilities) + 1} / {len(threshold_grid)})')
#             label_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, threshold)
#             # print(f'testing on mass: {0.05}, transpose: {TRANSPOSE}')
#             # label_high_mi_heads = get_high_mi_heads_mass(binary_label_mi, 0.05)
#             result_save_pth = osp.join(kde_save_pth, 'inference_results', f"{'transpose' if TRANSPOSE else 'original'}_{threshold}.pt")
#             print(f"total {sum([len(each) for each in label_high_mi_heads])} heads")
#             print(f"inference results will be saved to {result_save_pth}")
#             inference_result = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=label_high_mi_heads, no_pbar=False, transpose=TRANSPOSE, include_neg_possibilities=include_neg_possibilities)
#             print('result:', inference_result['metrics'])
#             metrics.append(inference_result['metrics'])
    # with open(result_save_pth, 'wb') as f:
    #     torch.save(inference_result, f)
# elif _INFER_MODE == 'posneg_mass_grid':
#     for mass in [0.01, 0.02, 0.03, 0.05, 0.08, 0.1, 0.2][::-1]:
#         print('testing on method `mass`, mass:', mass)
#         label_high_mi_heads = get_high_mi_heads_mass(binary_label_mi, mass)
#         print('result:', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=label_high_mi_heads, no_pbar=False))
elif _INFER_MODE == 'head_selection_baseline':
    baseline_name2matrix = {'probeless': pl_matrix, 'iou': iou_matrix, 'elasticnet': elasticnet_matrix, 'v_information': bounded_v_information_matrix}
    baseline_name2threshold = {'probeless': 20.7031, 'iou': 0.00488758, 'elasticnet': 0.062377844005823135, 'v_information': 4.406166076660156}
    print(f'doing baseline inference on {_BASELINE_NAME}, threshold: {baseline_name2threshold[_BASELINE_NAME]}')
    baseline_high_mi_heads = select_baseline_heads_by_threshold(baseline_name2matrix[_BASELINE_NAME], baseline_name2threshold[_BASELINE_NAME])
    print(f'total {sum([len(each) for each in baseline_high_mi_heads])} heads (including {len(baseline_high_mi_heads[-1])} neg heads)')
    inference_result = infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, baseline_name2matrix[_BASELINE_NAME], -1, with_labels=True, high_mi_heads=baseline_high_mi_heads, no_pbar=False, include_neg_possibilities=True)

elif _INFER_MODE == 'baseline_infer_by_mlp':
    infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False, include_neg_possibilities=True, infer_method='mlp', infer_model=deepmlp)
elif _INFER_MODE == 'baseline_infer_by_variational_family':
    infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False, include_neg_possibilities=True, infer_method='variational_family', infer_model=variational_family)
    
elif _INFER_MODE == 'baseline_infer_by_scores':
    for topk in [6, 7, 8]:
        threshold_high_mi_heads = get_high_mi_heads_threshold(binary_label_mi, label_binary_entropies, 114514, topk=topk)
        print(f'topk: {topk}, total {sum([len(each) for each in threshold_high_mi_heads])} heads (including {len(threshold_high_mi_heads[-1])} neg heads)')
        infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, binary_label_mi, -1, with_labels=True, high_mi_heads=threshold_high_mi_heads, no_pbar=False, include_neg_possibilities=True, infer_method='score')
    # print('result (pl):', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, pl_matrix, -1, with_labels=True, high_mi_heads=pl_high_mi_heads, no_pbar=False))
    # print('result (iou):', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, iou_matrix, -1, with_labels=True, high_mi_heads=iou_high_mi_heads, no_pbar=False))
    # print('result (elasticnet, equal_contribution):', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, torch.ones_like(elasticnet_matrix), -1, with_labels=True, high_mi_heads=elasticnet_high_mi_heads, no_pbar=False))
    # # elasticnet: UAS: 16.3, GTLAS=13.5, elasticnet (euqal): UAS=33.5, GTLAS=49.8
    # print('result (random):', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, torch.ones_like(elasticnet_matrix), -1, with_labels=True, high_mi_heads=random_high_mi_heads, no_pbar=False))
elif _INFER_MODE == 'baseline_grid':
    def baseline_grid_test(baseline_name: str, baseline_matrix: Tensor, baseline_thresholds: List[float], baseline_head_sets: List[List[int]], baseline_num_heads: List[int]):
        for threshold, head_set, num_heads in zip(baseline_thresholds, baseline_head_sets, baseline_num_heads):
            print(f'testing on method `{baseline_name}`, threshold: {threshold} ({num_heads} heads)')
            print('result:', infer_by_attn_features(tok_llama, model_llama.to(MODEL_DEVICE), dl_val_llama, {unmasked_label_idx: dep_parser_llama.labelmap[label_name] for unmasked_label_idx, label_name in enumerate(unmasked_label_names)}, x, [*zip(conditional_probab_pos, conditional_probab_neg)], arc_mi, conditional_probab_label, baseline_matrix, -1, with_labels=True, high_mi_heads=head_set, no_pbar=False))
    # threshold 0.01 UAS 36.1
    # threshold 0.02 UAS 39.4
    # threshold 0.025 UAS 36.5
    baseline_grid_test('pl', pl_matrix, pl_thresholds, pl_head_sets, pl_num_heads)
    baseline_grid_test('iou', iou_matrix, iou_thresholds, iou_head_sets, iou_num_heads)
    baseline_grid_test('elasticnet', elasticnet_matrix, elasticnet_thresholds, elasticnet_head_sets, elasticnet_num_heads)

# 3bv2 labels (baseline): UAS: 26.6, GTLAS: 59.5
# 3bv2 labels (conditional_probab_label includes no_arc, binary_label_mi includes no_arc): UAS 47.8 GTLAS 53.1
# 7b: UAS=48.5, GTLAS=52.1
# 13b: UAS=48.6, GTLAS=59.3

In [None]:
# print([len(each) for each in conditional_probab_label])
# binary_label_mi.append(
#     estimate_mi(
#         x, [label_probabilities[label_idx], 1 - label_probabilities[label_idx]], marginal_probab_attn_from_label, 
#         torch.cat([
#             joint_probab_label_stacked[:, label_idx, :].view(n_attn_features, 1, n_x), 
#             joint_probab_label_stacked[:, torch.arange(n_labels) != label_idx, :].sum(1).view(n_attn_features, 1, n_x)
#         ], dim=1), # Tensor[n_heads, 2, n_x]
#         KDE_DEVICE
#     )
# )
# print(len(label_probabilities), joint_probab_label_stacked.shape, marginal_probab_attn_from_label.shape)

# def estimate_mi_at_labelset(selected_abel_probabilities, ):
#     estimate_mi(
#         x, [label_probabilities[label_idx], 1 - label_probabilities[label_idx]], marginal_probab_attn_from_label, 
#         torch.cat([
#             joint_probab_label_stacked[:, label_idx, :].view(n_attn_features, 1, n_x), 
#             joint_probab_label_stacked[:, torch.arange(n_labels) != label_idx, :].sum(1).view(n_attn_features, 1, n_x)
#         ], dim=1), # Tensor[n_heads, 2, n_x]
#         KDE_DEVICE
#     )

In [None]:
# a = torch.arange(1, 10).view(3, 3)
# a[a != 2]

#### 1.6 result and mi analysis

In [None]:
#papermill_description=SET_DATASET_PRIORS
with open('./data/dataset_priors.pkl', 'rb') as f:
    dataset_priors = pkl.load(f)

label2avgdist, label2direction, label2direction_binary = dataset_priors['label2avgdist'], dataset_priors['label2direction'], dataset_priors['label2direction_binary']
populations = []
for idx, label_samples in enumerate(label_features):
    n_samples = label_samples.shape[0]
    if n_samples != 0:
        populations.append(n_samples)

sorted_avgdist = sorted(label2avgdist.values(), reverse=True)
label2distrank = {label_name: sorted_avgdist.index(label2avgdist[label_name]) for label_name in label2avgdist}

print(populations)

In [None]:
#papermill_description=POPULATIONS_AND_DIRECTIONS
for label_idx, population in sorted([*enumerate(populations)], key=lambda x: x[1], reverse=True):
    label_name = unmasked_label_names[label_idx]
    print(f"{label_name}({label2direction_binary[label_name]}): {population}")

[39832, 0]

##### 1.6.1 Correlation between MI and Acc, accopanied with number of samples

In [None]:
#papermill_description=ANALYZE_MI_RESULT_CORR
dataset_labelid_to_kde_labelid = {dep_parser_llama.labelmap[unmasked_label_name]: unmasked_label_idx for unmasked_label_idx, unmasked_label_name in enumerate(unmasked_label_names)}
kde_labelid_to_dataset_labelid = {unmasked_label_idx: dep_parser_llama.labelmap[unmasked_label_name]  for unmasked_label_idx, unmasked_label_name in enumerate(unmasked_label_names)}

def analyze_mi_result_correlationship(result_entry: Dict[str, Any], mi: List[List[int]], entropies: List[int]):
    accs, avg_mis, n_samples = [], [], []
    for dataset_labelid, kde_labelid in dataset_labelid_to_kde_labelid.items():
        accs.append(result_entry['label_corrects'][dataset_labelid] / result_entry['label_totals'][dataset_labelid] * 100)
        avg_mis.append(sum(mi[kde_labelid]) / len(mi[kde_labelid]) / entropies[kde_labelid])
        n_samples.append(result_entry['label_totals'][dataset_labelid])
    
    # print(unmasked_label_names[avg_mis.index(max(avg_mis))])
    accs, avg_mis, n_samples, label_names = [*zip(*sorted(zip(accs, avg_mis, n_samples, unmasked_label_names), key=lambda x: x[1]))]

    return accs, avg_mis, n_samples, label_names

accs_by_mi, avg_mis_by_mi, n_samples_by_mi, label_names_by_mi = analyze_mi_result_correlationship(results[1], binary_label_mi, label_binary_entropies)
plt.figure(figsize=(32, 32))
plt.scatter(avg_mis_by_mi[:-1], accs_by_mi[:-1])
for acc, avg_mi, n_sample, label_name in zip(accs_by_mi, avg_mis_by_mi, n_samples_by_mi, label_names_by_mi):
    plt.annotate(f'{label_name}\n({n_sample})', (avg_mi, acc))
plt.savefig('./plot.png', dpi=400)

In [None]:
#papermill_description=VISUALIZE_MI_STACKPLOT
accs_by_acc, label_names_by_acc = [*zip(*sorted(zip(accs_by_mi, label_names_by_mi), key=lambda x: x[0], reverse=True))]
high_acc_label_names = label_names_by_acc[:16]
for label_rank, label_name in enumerate(tqdm(high_acc_label_names, desc=f'drawing & saving stackplots')):
    direction = label2direction_binary[label_name]
    distrank = label2distrank[label_name]
    label_idx = unmasked_label_names.index(label_name)
    mi = binary_label_mi[label_idx]
    mi_pos = binary_label_mi_pos[label_idx]
    L, H = model_llama.config.num_hidden_layers, model_llama.config.num_attention_heads
    # fig = plt.figure(figsize=(8, 8))
    # plt.stackplot(range(L), [mi[i::H] for i in range(H)])
    # plt.title(f'mi stack plot for label {label_name} (accrank: {label_rank}, direction: {direction}, distrank: {distrank})')
    # plt.savefig(f"./visualization_results/mi_stackplot{'_transpose' if TRANSPOSE else ''}/{label_rank:02d}_{label_name}_{direction}_d{distrank:02d}.png", dpi=400)
    # plt.close(fig)

    fig = plt.figure(figsize=(8, 8))
    plt.stackplot(range(L), [mi_pos[i::H] for i in range(H)])
    plt.title(f'mi (pos) stack plot for label {label_name} (accrank: {label_rank}, direction: {direction}, distrank: {distrank})')
    plt.savefig(f"./visualization_results/mi_pos_stackplot{'_transpose' if TRANSPOSE else ''}/{label_rank:02d}_{label_name}_{direction}_d{distrank:02d}.png", dpi=400)
    plt.close(fig)

In [None]:
#papermill_description=VISUALIZE_PROBAB_DISTRIBUTIONS
for label_rank, label_name in enumerate(tqdm(high_acc_label_names, desc='drawing & saving probab distributions')):
    direction = label2direction_binary[label_name]
    label_idx = unmasked_label_names.index(label_name)
    mi = binary_label_mi[label_idx]
    high_mi_heads = [each[0] for each in sorted([*enumerate(mi)], key=lambda x: x[1], reverse=True)[:5]]
    for head_rank, head_idx in enumerate(high_mi_heads):
        fig = plt.figure(figsize=(8, 8))
        plt.subplot(2, 1, 1).plot(x.cpu(), joint_probab_label_stacked[head_idx][label_idx].cpu(), label='label')
        plt.subplot(2, 1, 2).plot(x.cpu(), marginal_probab_attn_from_label[head_idx].cpu(), label='marginal')
        plt.title(f'probab distribution for label {label_name} and head {head_idx} (No. {head_rank})')
        plt.legend()
        plt.savefig(f"./visualization_results/probab_distribution{'_transpose' if TRANSPOSE else ''}/{label_rank:02d}_{label_name}_{direction}_{head_rank:02d}_head{head_idx}.png", dpi=400)

In [None]:
with open(osp.join('inference_results', 'original_0.2.json'), 'r') as f_original, open(osp.join('inference_results', 'transpose_0.5.json'), 'r') as f_transpose:
    result_original = json.load(f_original)
    result_transpose = json.load(f_transpose)
with open(osp.join(kde_save_pth, 'mi_1.json')) as f_original, open(osp.join(kde_save_pth, 'mi_1_transpose.json')) as f_transpose:
    mi_original = json.load(f_original)
    mi_transpose = json.load(f_transpose)

In [None]:
def get_macro_acc(result):
    label_corrects, label_totals = [*result['label_corrects'].values()], [*result['label_totals'].values()]
    # label_totals is possibly zero
    individual_results = [label_correct / label_total for (label_correct, label_total) in zip(label_corrects, label_totals) if label_total > 0]
    return sum(individual_results) / len(individual_results)

print(get_macro_acc(result_original))
print(get_macro_acc(result_transpose))

In [None]:
print_listlike([*zip([max(each) for each in mi_original['binary_label_mi']], [max(each) for each in mi_transpose['binary_label_mi']])])

In [None]:
print_listlike([*zip([max(each) for each in mi_original['binary_label_mi']], [max(each) for each in mi_transpose['binary_label_mi']])])
print_listlike([*zip(mi_original['arc_mi'], mi_transpose['arc_mi'])])

(0.008509306237101555, 0.011922112666070461) (0.005480445921421051, 0.01151051465421915) (0.0038365633226931095, 0.00990423932671547) (0.0012449548812583089, 0.00282877404242754) (0.006173050031065941, 0.012894314713776112) (0.004695800598710775, 0.013605937361717224) (0.001123006222769618, 0.002053096890449524) (0.0014259052695706487, 0.0018166368827223778) (0.005726659670472145, 0.008281094953417778) (0.00371841574087739, 0.00848476029932499)
(0.00032986191217787564, 0.0008982069557532668) (0.0035700106527656317, 0.004983129911124706) (0.0003002947196364403, 0.0015047428896650672) (0.00217073573730886, 0.002882377477362752) (0.00163302943110466, 0.0029146289452910423) (0.00040211749728769064, 0.0009975847788155079) (0.00038957095239311457, 0.0007821473991498351) (0.0008883021655492485, 0.001717525301501155) (0.0012408897746354342, 0.0030396226793527603) (0.005120383109897375, 0.007830372080206871)
(0.0008684846106916666, 0.0017866557464003563) (0.003090638667345047, 0.004927610047161

: 

: 

: 

In [None]:
def get_weighted_average_mi(num_samples_for_each_label, binary_mi):
    total_weighted = 0
    for num_samples_for_label, binary_mi_for_label in zip(num_samples_for_each_label, binary_mi):
        total_weighted += (num_samples_for_label / sum(num_samples_for_each_label)) * (sum(binary_mi_for_label) / len(binary_mi_for_label))
    
    return total_weighted

label_num_samples = [v for k, v in result_original['label_totals'].items() if v != 0]
# print(len(mi_original['binary_label_mi']), len(label_num_samples))
# print(label_num_samples)
num_samples_for_each_label = [each.shape[0] for each in label_features if each.shape[0]] + [n_neg_samples]

print(get_weighted_average_mi(num_samples_for_each_label, mi_original['binary_label_mi']))
print(get_weighted_average_mi(num_samples_for_each_label, mi_transpose['binary_label_mi']))

0.008965459026784438
0.08830101225334051


: 

: 

: 

In [None]:
inverse_labelmap = {v: k for k, v in dep_parser_llama.labelmap.items()}
label_entries = []
for label_idx in results['label_corrects']:
    if label_idx in inverse_labelmap:
        if results['label_totals'][label_idx] == 0:
            label_entries.append((inverse_labelmap[label_idx], results['label_corrects'][label_idx], results['label_totals'][label_idx], 0))
        else:
            label_entries.append((inverse_labelmap[label_idx], results['label_corrects'][label_idx], results['label_totals'][label_idx], round(results['label_corrects'][label_idx] / results['label_totals'][label_idx] * 100, 2)))

large_quantity_label_entries = sorted(label_entries, key=lambda x: x[-2], reverse=True)[:len(label_entries) // 2]
print(*sorted(large_quantity_label_entries, key=lambda x: x[-1]), sep='\n')

('<s>', 0, 1700, 0.0)
('num', 1, 1087, 0.09)
('dep', 3, 756, 0.4)
('cop', 2, 359, 0.56)
('dobj', 14, 1620, 0.86)
('mark', 4, 422, 0.95)
('number', 8, 493, 1.62)
('poss', 29, 706, 4.11)
('advmod', 94, 1257, 7.48)
('amod', 316, 2477, 12.76)
('aux', 240, 1243, 19.31)
('cc', 314, 1001, 31.37)
('possessive', 140, 434, 32.26)
('xcomp', 142, 434, 32.72)
('conj', 572, 1004, 56.97)
('ccomp', 380, 560, 67.86)
('det', 2283, 3341, 68.33)
('prep', 3129, 3783, 82.71)
('nn', 2726, 3240, 84.14)
('nsubj', 2438, 2836, 85.97)
('punct', 4397, 4731, 92.94)
('pobj', 3563, 3744, 95.17)
('root', 1700, 1700, 100.0)


: 

: 

: 

In [None]:
with open('./label_entries.json', 'w') as f:
    json.dump(label_entries, f, indent=2)

: 

: 

: 

In [None]:
len(unmasked_label_names), len(dep_parser_llama.labelmap)

(45, 47)

: 

: 

: 

In [None]:
print(get_shape([*zip(conditional_probab_pos, conditional_probab_neg)]))
print(get_shape(conditional_probab_label))
1 / 0

[1024(A), 2(A), [168](T)]
[1024(A), 46(A), [168](T)]


ZeroDivisionError: division by zero

: 

: 

: 

In [None]:
fig = plt.figure()
plt.plot(torch.arange(-3.8, 2.9, 0.05), pos_kde_estims[1].cpu())
plt.plot(torch.arange(-3.8, 2.9, 0.05), neg_kde_estims[1].cpu())
plt.show()
plt.close(fig)

: 

: 

: 

### 2[Optional]. Calculate Corr Matrix

In [None]:
attn_feature_cov = torch.cov(arc_attn_features.T)

AttributeError: 'tuple' object has no attribute 'T'

: 

: 

: 

In [None]:
# print(arc_attn_features.shape)
attn_feature_cov_naiive = torch.matmul(arc_attn_features.T, arc_attn_features) / arc_attn_features.shape[0]

: 

: 

: 

In [None]:
from matplotlib import pyplot as plt

fig = plt.figure(figsize=(32, 32))
plt.imshow(attn_feature_cov_naiive.float())

<matplotlib.image.AxesImage at 0x7fc8940eebe0>

: 

: 

: 

### 3[Optional]. Estimate spearsman-corr of each dimension

: 

: 

: 

In [None]:
import pickle as pkl

for i in trange(len(arc_feature_with_y)):
    pkl.dump(arc_feature_with_y, open(f'/tmp/pickles/arc_feature_with_y_dim{i}.pkl', 'wb'))

: 

: 

: 

In [None]:
from multiprocessing import Pool

def estimate_mi_for_dimension(features2y: List[Tensor]):
    pass

def estimate_spearmanr_for_dimension(features2y: List[Tensor], num_features: int = -1):
    """
    Args:
        features2y: List[Tensor[total_num_features, num_rel_categories] ...(total `featurn_dim_size` Tensors)]
        num_features: the first n features to process
    returns:
        corrs: List[float], list of spearman relations
        pvalues: List[float], list of p-values on null hypothesis of `feature_i and y has no relationship`
    """
    if num_features == -1:
        num_features = len(features2y)
    corrs, pvalues = [], []
    for feature_idx, (feature_values, ys) in enumerate(tqdm(features2y, desc='calculating for each dimension...')):
        # print(f"feature_idx: {feature_idx}, spearman_r: {spearmanr(feature_values.numpy(), ys.numpy())}")
        corr, pvalue = spearmanr(feature_values.numpy(), ys.numpy())
        corrs.append(corr), pvalues.append(pvalue)

    # results = Pool(processes=20).imap(spearmanr, features2y)
    # corrs, pvalues = zip(*results)
    
    return corrs, pvalues

corrs, pvalues = estimate_spearmanr_for_dimension(arc_feature_with_y, 100)

: 

: 

: 

In [None]:
max_idx = np.argmax(np.abs(corrs))
def plot_single_head(arc_feature_with_y):
    fig = plt.figure(figsize=(6, 3))
    plt.scatter(*arc_feature_with_y, s=0.2)
    plt.show()
    plt.close(fig)

plot_single_head(arc_feature_with_y[max_idx])

: 

: 

: 

In [None]:
print(spearmanr(arc))

: 

: 

: 

In [None]:
fig = plt.figure(figsize=(10, 10))
plt.bar([*range(len(corrs))], sorted(corrs))
plt.show()
plt.close(fig)

: 

: 

: 

In [None]:
print(sorted(pvalues[:10], reverse=True))

[1.6917530078326845e-25, 4.324582566083494e-256, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


: 

: 

: 

In [None]:
fig = plt.figure(figsize=(10, 10))
plt.bar([*range(len(pvalues))], sorted(pvalues))
plt.show()

: 

: 

: 

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot([*range(len(corrs))], corrs)
# print(corrs, pvalues)

[<matplotlib.lines.Line2D at 0x7fcbc41c6c40>]

: 

: 

: 

In [None]:
from scipy.integrate import dblquad