In [1]:
import os
os.chdir("..")

from scripts import ripser_count, stats_count

In [30]:
import numpy as np
import json
from pathlib import Path

from multiprocessing import Pool
from tqdm import trange, tqdm

np.random.seed(42)

In [3]:
attn_mx_filename = lambda i: f"assets/attention_maps/qa/pt_{i}/attn_matrices.npz"
ntokens_filename = lambda i: f"assets/attention_maps/qa/pt_{i}/tokens_count.json"

In [31]:
save_path = Path("assets/tda_features")
save_path.mkdir(exist_ok=True)

### Statistical features calculation

In [4]:
stats_name = "s_e_v_c_b0b1"
stats_cap = 500

thresholds = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75] 

In [5]:
def split_data(adj_matrices: list, ntokens_array: list, num_of_workers: int = 20):
    split_adj_matricies = np.array_split(adj_matrices, num_of_workers)
    split_ntokens = np.array_split(ntokens_array, num_of_workers)
    assert all([len(m)==len(n) for m, n in zip(split_adj_matricies, split_ntokens)]), "Split is not valid!"
    return zip(split_adj_matricies, split_ntokens)

In [6]:
num_of_workers = 20
pool = Pool(num_of_workers)

In [7]:
stats_features, keys = [], []
for i in trange(1):
    attn_matrices = np.load(attn_mx_filename(i))

    with open(ntokens_filename(i), "r") as f:
        ntokens = json.load(f)
    
    mx_list, ntokens_list = [], []
    for key in attn_matrices.keys():
        mx_list.append(attn_matrices[key])
        ntokens_list.append(ntokens[key])
        keys.append(key)

    split = split_data(np.asarray(mx_list), np.asarray(ntokens_list), num_of_workers=num_of_workers)
    args = [(mxs, thresholds, tokens, stats_name.split("_"), stats_cap) for mxs, tokens in split]
    stats_features_ = pool.starmap(
        stats_count.count_top_stats, args
    )
    stats_features.append(np.concatenate([_ for _ in stats_features_], axis=3))

100%|██████████| 1/1 [11:33<00:00, 693.28s/it]


100%|██████████| 32/32 [09:34<00:00, 17.94s/it]
100%|██████████| 32/32 [10:00<00:00, 18.77s/it]
100%|██████████| 32/32 [10:01<00:00, 18.81s/it]
100%|██████████| 32/32 [10:29<00:00, 19.67s/it]
100%|██████████| 32/32 [10:29<00:00, 19.69s/it]
100%|██████████| 32/32 [10:38<00:00, 19.96s/it]
100%|██████████| 32/32 [10:45<00:00, 20.16s/it]
100%|██████████| 32/32 [10:51<00:00, 20.37s/it]
100%|██████████| 32/32 [10:54<00:00, 20.47s/it]
100%|██████████| 32/32 [10:54<00:00, 20.45s/it]
100%|██████████| 32/32 [10:55<00:00, 20.49s/it]
100%|██████████| 32/32 [10:50<00:00, 20.31s/it]
100%|██████████| 32/32 [10:57<00:00, 20.55s/it]
100%|██████████| 32/32 [10:56<00:00, 20.52s/it]
100%|██████████| 32/32 [10:54<00:00, 20.45s/it]
100%|██████████| 32/32 [10:56<00:00, 20.52s/it]
100%|██████████| 32/32 [10:56<00:00, 20.52s/it]
100%|██████████| 32/32 [10:59<00:00, 20.60s/it]
100%|██████████| 32/32 [10:59<00:00, 20.60s/it]
100%|██████████| 32/32 [11:00<00:00, 20.64s/it]


In [18]:
stats_features = np.concatenate(stats_features, axis=3)
stats_features_dict = dict(zip(keys, stats_features.transpose(3, 0, 1, 2, 4)))

NameError: name 'stats_features' is not defined

In [None]:
np.savez_compressed(f"{save_path}/stats_features", **stats_features_dict)

### Barcodes calculation

In [23]:
dim = 1
lower_bound = 1e-3

In [33]:
from multiprocessing import Process, Queue

def subprocess_wrap(queue, function, args):
    queue.put(function(*args))
    queue.close()
    exit()

In [34]:
from itertools import product
from collections import defaultdict

def get_only_barcodes(adj_matricies, ntokens_array, dim, lower_bound):
    """Get barcodes from adj matricies for each layer, head"""
    barcodes = {}
    layers, heads = range(adj_matricies.shape[1]), range(adj_matricies.shape[2])
    for (layer, head) in product(layers, heads):
        matricies = adj_matricies[:, layer, head, :, :]
        barcodes[(layer, head)] = ripser_count.get_barcodes(matricies, ntokens_array, dim, lower_bound, (layer, head))
    return barcodes

def format_barcodes(barcodes):
    """Reformat barcodes to json-compatible format"""
    return [{d: b[d].tolist() for d in b} for b in barcodes]

def save_barcodes(barcodes, filename):
    """Save barcodes to file"""
    formatted_barcodes = defaultdict(dict)
    for layer, head in barcodes:
        formatted_barcodes[layer][head] = format_barcodes(barcodes[(layer, head)])

    with open(filename, 'w') as f:
        json.dump(formatted_barcodes, f)
    
def unite_barcodes(barcodes, barcodes_part):
    """Unite 2 barcodes"""
    for (layer, head) in barcodes_part:
        barcodes[(layer, head)].extend(barcodes_part[(layer, head)])
    return barcodes

In [35]:
queue = Queue()
number_of_splits = 2
keys = []
for i in trange(1):
    attn_matrices = np.load(attn_mx_filename(i))

    with open(ntokens_filename(i), "r") as f:
        ntokens = json.load(f)

    mx_list, ntokens_list = [], []
    for key in attn_matrices.keys():
        mx_list.append(attn_matrices[key])
        ntokens_list.append(ntokens[key])
        keys.append(key)
    
    barcodes = defaultdict(list)

    split = split_data(mx_list, ntokens_list, number_of_splits)
    for matrices, ntokens in tqdm(split, leave=False):
        p = Process(
            target=subprocess_wrap,
            args=(
                queue,
                get_only_barcodes,
                (matrices, ntokens, dim, lower_bound)
            )
        )
        p.start()
        barcodes_part = queue.get()
        p.join()
        p.close()
        
        barcodes = unite_barcodes(barcodes, barcodes_part)
        
    save_barcodes(barcodes, save_path / f"barcodes_{i}.json")

  0%|          | 0/1 [00:00<?, ?it/s]Process Process-21:
Traceback (most recent call last):
  File "/home/llm-factuality/miniconda/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/llm-factuality/miniconda/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_102683/2182458175.py", line 4, in subprocess_wrap
    queue.put(function(*args))
  File "/tmp/ipykernel_102683/2137997521.py", line 10, in get_only_barcodes
    barcodes[(layer, head)] = ripser_count.get_barcodes(matricies, ntokens_array, dim, lower_bound, (layer, head))
  File "/app/scripts/ripser_count.py", line 145, in get_barcodes
    matrix = matrix_to_ripser(matrix, ntokens_array[i], lower_bound)
  File "/app/scripts/ripser_count.py", line 127, in matrix_to_ripser
    matrix = (matrix > lower_bound).astype(np.int) * matrix
  File "/home/llm-factuality/miniconda/lib/python3.9/site-packages/numpy/__init__.py

KeyboardInterrupt: 