In [49]:
import os
import os.path as osp
from typing import Tuple, List

import torch
from torch import nn, Tensor, tensor
from tqdm import trange

from explaination_baselines import get_possible_num_chunks, formalize_pickle_file_name

In [50]:
# Root directory holding the chunked attention samples; results are written to its 'baselines/' subdir.
# Set the KDE_SAVE_PTH environment variable to point elsewhere instead of editing this absolute path.
KDE_SAVE_PTH = os.environ.get('KDE_SAVE_PTH', '/home/user/pretrained-models/kde/open_llama_7b/')

In [51]:
def probeless_step1(
    train_data_dir: str,
):
    """Compute, for every saved chunk, the mean feature vector and sample count of each label.

    Each chunk in ``train_data_dir`` consists of ``neg_attn_samples.pkl`` (one tensor of negative
    samples) and ``label_attn_samples.pkl`` (a sequence with one tensor per label). Means are taken
    over dim 0 (the sample axis). The negatives are appended as the last group of every chunk.

    Args:
        train_data_dir: directory containing the chunked attention-sample pickles.

    Returns:
        ``(mean_values, num_samples)`` where, for chunk ``c``,
        ``mean_values[c] = [label_0_mean, ..., label_{K-1}_mean, neg_mean]`` (float32 ``Tensor[n_features]`` each)
        and ``num_samples[c] = [label_0_count, ..., neg_count]`` (ints). A group with no samples gets an
        all-zero mean and a count of 0. The check that every chunk has the same number of labels is
        done in ``probeless_step2``.
    """
    def _mean_or_zeros(features: Tensor) -> Tensor:
        # The mean over zero rows is NaN, so empty groups (labels AND negatives) get a zero placeholder;
        # their count of 0 gives them zero weight when chunks are merged.
        if features.shape[0] == 0:
            return torch.zeros(features.shape[-1])
        # Average in float32 in case the samples are stored in reduced precision.
        return features.float().mean(dim=0)

    total_num_chunks = get_possible_num_chunks(train_data_dir)
    mean_values: List[List[Tensor]] = []
    num_samples: List[List[int]] = []
    for current_num_save in trange(total_num_chunks, desc='loading attention samples and calculating mean values...'):
        this_neg_features: Tensor = torch.load(osp.join(train_data_dir, formalize_pickle_file_name('neg_attn_samples.pkl', 1, False, current_num_save, total_num_chunks)), map_location='cpu')
        this_label_features: Tuple[Tensor] = torch.load(osp.join(train_data_dir, formalize_pickle_file_name('label_attn_samples.pkl', 1, False, current_num_save, total_num_chunks)), map_location='cpu')
        # Labels first, negatives last: [label_0, ..., label_{K-1}, neg]
        chunk_groups = [*this_label_features, this_neg_features]
        mean_values.append([_mean_or_zeros(group) for group in chunk_groups])
        num_samples.append([group.shape[0] for group in chunk_groups])

    return mean_values, num_samples

    
# Reads every chunk under KDE_SAVE_PTH from disk; this can take a long time (the recorded run of this cell was interrupted).
mean_values, num_samples = probeless_step1(train_data_dir=KDE_SAVE_PTH)

loading attention samples and calculating mean values...:   0%|                                                                                                                          | 0/39 [00:13<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def probeless_step2(
    mean_values: List[List[Tensor]], num_samples: List[List[int]]
):
    """Merge per-chunk means into one sample-weighted mean vector per label.

    Args:
        mean_values: ``mean_values[c][z]`` is the mean feature vector (``Tensor[n_features]``) of label
            ``z`` in chunk ``c``, laid out as returned by ``probeless_step1`` (negatives last).
        num_samples: ``num_samples[c][z]`` is the number of samples behind ``mean_values[c][z]``.

    Returns:
        float32 ``Tensor[n_labels, n_features]`` whose row ``z`` is
        ``sum_c mean_values[c][z] * num_samples[c][z] / sum_c num_samples[c][z]``.
        Labels with no samples in any chunk keep an all-zero row.

    Raises:
        AssertionError: if chunks disagree on the number of labels, or means disagree on the number
            of features (also raised when ``mean_values`` is empty).
    """
    assert len(possible_num_labels := set([len(each) for each in mean_values])) == 1, f'error, different number of labels in different chunks: {possible_num_labels}'
    assert len(possible_num_features := set([each.shape[-1] for chunk_means in mean_values for each in chunk_means])) == 1, f'error, different number of features in different chunks: {possible_num_features}'
    num_labels = next(iter(possible_num_labels))
    num_features = next(iter(possible_num_features))
    # Sum counts as exact Python ints: `torch.Tensor(num_samples)` is float32, which cannot represent
    # every integer above 2**24.
    total_num_samples = [sum(chunk_counts[label_idx] for chunk_counts in num_samples) for label_idx in range(num_labels)]
    total_mean_values = torch.zeros(num_labels, num_features)
    for this_mean_values, this_num_samples in zip(mean_values, num_samples):  # iterate over chunks
        for label_idx in range(num_labels):
            label_count = this_num_samples[label_idx]
            # An empty group has zero weight; skipping it also keeps a NaN placeholder mean
            # from poisoning the sum (0 * NaN == NaN).
            if label_count == 0:
                continue
            # label_count > 0 implies total_num_samples[label_idx] > 0, so the division is safe.
            total_mean_values[label_idx] += this_mean_values[label_idx].float() * (label_count / total_num_samples[label_idx])

    return total_mean_values

# Merge the per-chunk statistics into one sample-weighted mean per label -> Tensor[n_labels, n_features]
total_mean_values = probeless_step2(mean_values=mean_values, num_samples=num_samples)

In [None]:
# torch.save does not create missing directories, so make sure 'baselines/' exists before writing.
baselines_dir = osp.join(KDE_SAVE_PTH, 'baselines')
os.makedirs(baselines_dir, exist_ok=True)
torch.save(total_mean_values, osp.join(baselines_dir, 'total_mean_values.pt'))

In [None]:
num_labels = total_mean_values.shape[0]
# Probeless score: pl_matrix[z, f] = sum over all labels z2 of |mean[z2, f] - mean[z, f]|
# (the z2 == z term is 0). One broadcast subtraction per label replaces the inner Python loop over z2.
# NOTE(review): in the recorded print of pl_matrix, rows 0 and -2 are identical; presumably those labels
# had no samples and kept the all-zero mean from probeless_step2 -- consider masking such labels out.
pl_matrix = torch.zeros_like(total_mean_values)
for z in range(num_labels):
    pl_matrix[z] = (total_mean_values - total_mean_values[z]).abs().sum(dim=0)

In [None]:
print(pl_matrix)
# Save under its own file name. The previous name, 'total_mean_values.pt', overwrote the per-label means
# saved earlier; if this cell was already run with that name, re-run the total_mean_values save cell.
torch.save(pl_matrix, osp.join(KDE_SAVE_PTH, 'baselines', 'pl_matrix.pt'))

tensor([[12.0614, 12.9103, 13.9199,  ..., 48.1243,  4.3099, 40.9697],
        [ 6.2106,  3.9450,  4.4875,  ...,  9.6447,  3.3624,  6.7478],
        [ 3.0707,  3.4532,  4.1131,  ..., 14.7476,  5.0805,  9.8415],
        ...,
        [ 3.1218,  3.3797,  4.0568,  ..., 15.3775,  5.7540, 12.2089],
        [12.0614, 12.9103, 13.9199,  ..., 48.1243,  4.3099, 40.9697],
        [ 3.7964,  3.2735,  4.2227,  ..., 20.1973,  6.1718, 16.6401]])


In [None]:
# Sanity check for `iou`: with keepdim=True, Tensor.quantile keeps the reduced dimension with size 1,
# so a quantile taken over one axis broadcasts back against the original tensor.
quantile_demo = torch.arange(10).reshape(2, 5).float()
quantile_demo_median = quantile_demo.quantile(0.5, dim=-1, keepdim=True)
assert quantile_demo_median.shape == (2, 1)
assert torch.equal(quantile_demo_median, torch.tensor([[2.], [7.]]))
quantile_demo_median

tensor([[2.],
        [7.]])

In [None]:
def iou(train_data_dir: str, threshold: float = 0.5, dynamic_threshold: bool = False, dynamic_threshold_percentile: float = 0.995):
    """Per-label, per-feature IoU between the "sample has label z" mask and the "feature f is high" mask.

    Counts are pooled over every chunk in ``train_data_dir``; the negative samples act as the last label:
        intersection[z, f] = #samples of label z whose feature f exceeds the threshold
        union[z, f]        = #samples of label z + #samples of the other labels whose feature f exceeds it

    Args:
        train_data_dir: directory holding the chunked ``neg_attn_samples.pkl`` / ``label_attn_samples.pkl`` files.
        threshold: static threshold, used when ``dynamic_threshold`` is False.
        dynamic_threshold: if True, feature f's threshold is the ``dynamic_threshold_percentile`` quantile of
            that feature over all samples of the current chunk.
        dynamic_threshold_percentile: quantile in [0, 1] used when ``dynamic_threshold`` is True.

    Returns:
        float32 ``Tensor[K + 1, num_features]`` (K labels, negatives in the last row). Entries with an empty
        union (0/0) are 0.

    Raises:
        AssertionError: if chunks disagree on the number of labels or features.
        ValueError: if ``train_data_dir`` contains no chunks.
    """
    total_num_chunks = get_possible_num_chunks(train_data_dir)
    num_labels, num_features = -1, -1
    # int64 counters, created once the shape is known: float32 counts stop being exact above 2**24.
    iou_intersection, iou_union = None, None
    for current_num_save in trange(total_num_chunks, desc='loading attention samples and calculating iou_matrix...'):
        neg_features: Tensor = torch.load(osp.join(train_data_dir, formalize_pickle_file_name('neg_attn_samples.pkl', 1, False, current_num_save, total_num_chunks)), map_location='cpu')
        label_features: Tuple[Tensor] = torch.load(osp.join(train_data_dir, formalize_pickle_file_name('label_attn_samples.pkl', 1, False, current_num_save, total_num_chunks)), map_location='cpu')
        # Negatives become the last label; unpacking works whether the pickle holds a tuple or a list.
        all_features = (*label_features, neg_features)
        if num_labels == -1:
            num_labels = len(all_features)
        else:
            assert num_labels == len(all_features), f'error, different number of labels in different chunks: {num_labels} vs {len(all_features)}'
        chunk_num_features = set([each.shape[-1] for each in all_features])
        assert len(chunk_num_features) == 1, f'error, different number of features across labels in chunk {current_num_save}: {chunk_num_features}'
        if num_features == -1:
            num_features = neg_features.shape[-1]
            iou_intersection = torch.zeros(num_labels, num_features, dtype=torch.long)
            iou_union = torch.zeros(num_labels, num_features, dtype=torch.long)
        else:
            assert num_features == neg_features.shape[-1], f'error, different number of features in different chunks: {num_features} vs {neg_features.shape[-1]}'
        if dynamic_threshold:
            # keepdim=True -> shape [1, num_features], which broadcasts against [num_samples, num_features].
            # NOTE(review): recomputed per chunk, so this is a per-chunk percentile, not a global one.
            threshold = torch.cat(all_features, dim=0).float().quantile(dynamic_threshold_percentile, dim=0, keepdim=True)
        # high_counts[z, f] = #samples of label z with feature f above threshold (zeros for empty labels)
        high_counts = torch.stack([(features > threshold).sum(dim=0) for features in all_features])  # [num_labels, num_features]
        sample_counts = torch.tensor([features.shape[0] for features in all_features], dtype=torch.long)  # [num_labels]
        iou_intersection += high_counts
        # union[z] = |label z| + high samples of every other label z2 != z
        iou_union += sample_counts[:, None] + (high_counts.sum(dim=0, keepdim=True) - high_counts)

    if iou_intersection is None:
        raise ValueError(f'no attention-sample chunks found in {train_data_dir}')
    # intersection <= union everywhere, so clamping the denominator to 1 only changes 0/0 entries,
    # which become 0 instead of NaN.
    return (iou_intersection.double() / iou_union.clamp(min=1).double()).float()

# Dynamic per-feature threshold at the 0.995 quantile; keep this in sync with the file name used when saving.
iou_matrix = iou(
    train_data_dir=KDE_SAVE_PTH,
    dynamic_threshold=True,
    dynamic_threshold_percentile=0.995,
)

loading attention samples and calculating iou_matrix...:  59%|███████████████████████████████████████████████████████████████████▏                                              | 23/39 [35:41<25:21, 95.11s/it]

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.8358e-03, 8.9633e-03, 0.0000e+00,  ..., 4.0604e-05, 1.5527e-05,
         9.0804e-03],
        [0.0000e+00, 0.0000e+00, 6.8045e-05,  ..., 5.5433e-03, 5.8366e-05,
         2.2208e-02],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.8371e-03, 4.8854e-03, 4.7513e-03,  ..., 4.6057e-03, 4.9403e-03,
         4.2393e-03]])


In [None]:
# The file name records the threshold mode and percentile used to build `iou_matrix`.
iou_matrix_path = osp.join(KDE_SAVE_PTH, 'baselines', 'iou_matrix_dynamic_0.995.pt')
torch.save(iou_matrix, iou_matrix_path)

In [None]:
# Compare the dynamic-threshold IoU matrix with the static-threshold (0.5) one.
# NOTE(review): no cell in this notebook writes 'iou_matrix_static_0.5.pt'; it presumably comes from an
# earlier run of `iou(KDE_SAVE_PTH)` with the default static threshold -- confirm before relying on it.
iou_dir = osp.join(KDE_SAVE_PTH, 'baselines')
iou1 = torch.load(osp.join(iou_dir, 'iou_matrix_dynamic_0.995.pt'))
iou2 = torch.load(osp.join(iou_dir, 'iou_matrix_static_0.5.pt'))
for loaded_iou in (iou1, iou2):
    print(loaded_iou)