In [None]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

def prim_algo_simplified(adjacency_matrix):
    """Return the total weight of a minimum spanning tree (Prim's algorithm).

    The graph is the complete graph on ``n`` vertices whose edge weights are the
    entries of the square matrix ``adjacency_matrix``; vertex 0 is the root.
    Keys are relaxed along row ``v`` (``adjacency_matrix[v, u]``), so for a
    non-symmetric matrix every tree edge contributes the weight
    ``adjacency_matrix[parent, child]`` that was actually minimised.

    Args:
        adjacency_matrix: square ``(n, n)`` tensor of edge weights. Integer
            inputs are converted to the default floating dtype.

    Returns:
        0-d tensor (on the input's device) holding the sum of the ``n - 1``
        tree-edge weights; zero when ``n <= 1``.
    """
    weights = (
        adjacency_matrix
        if adjacency_matrix.is_floating_point()
        else adjacency_matrix.to(torch.get_default_dtype())
    )
    n = len(weights)
    total = weights.new_zeros(())
    if n <= 1:
        # No edges to add (and torch ops below would fail on an empty matrix).
        return total

    # Use a true +inf sentinel: the previous `max + 10` is absorbed by rounding
    # for large weights, letting already-visited vertices tie with real
    # candidates in argmin.
    inf = float("inf")
    # dst[u]: cheapest known edge weight from the visited tree to u, kept in
    # the same dtype as the weights so no precision is lost.
    dst = torch.full((n,), inf, dtype=weights.dtype, device=weights.device)
    visited = torch.zeros(n, dtype=torch.bool, device=weights.device)

    v = 0
    for _ in range(n - 1):
        visited[v] = True
        dst = torch.minimum(dst, weights[v])
        dst[visited] = inf  # never re-select a vertex already in the tree
        v = torch.argmin(dst)
        # dst[v] is exactly the weight of the edge just added to the tree.
        total = total + dst[v]

    return total

class RTD_simplified_summ_only(nn.Module):
    """Fast MST-sum approximation of RTD between two weight matrices.

    Each input is divided by its own ``quantile``-level quantile; the score is
    ``MST(r1) - MST(rmin) + MST(r2) - MST(rmin)`` with ``rmin = min(r1, r2)``
    element-wise and ``MST`` the spanning-tree weight sum computed by
    ``prim_algo_simplified``.
    """

    def __init__(self, quantile=0.9):
        """
        Args:
            quantile: level in [0, 1] of the per-matrix quantile each input is
                divided by; the default 0.9 matches the original behaviour.
        """
        super().__init__()
        self.quantile = quantile

    def forward(self, a1, a2):
        """Return the approximate RTD score for ``a1`` and ``a2``.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module still goes through ``nn.Module.__call__`` and its hooks.

        Args:
            a1, a2: square weight matrices; presumably of equal shape and a
                floating dtype (``torch.quantile`` requires float/double).
                NOTE(review): a zero quantile makes the scaling divide by zero.
        """
        r1 = a1 / torch.quantile(a1, self.quantile)
        r2 = a2 / torch.quantile(a2, self.quantile)
        rmin = torch.minimum(r1, r2)

        rmin_sum = prim_algo_simplified(rmin)
        r1_sum = prim_algo_simplified(r1)
        r2_sum = prim_algo_simplified(r2)

        return r1_sum - rmin_sum + r2_sum - rmin_sum

In [None]:
# Directory holding the pre-computed attention arrays (Colab sample_data layout).
DATA_DIR = Path("/content/sample_data")

# head idx is fixed, layer idx varies over 0..11 along dim 1
attention_12_head = np.load(DATA_DIR / "attention_train_12_head.npy")
# layer idx is fixed, head idx varies over 0..11 along dim 1
attention_12_layer = np.load(DATA_DIR / "attention_train_12_layer.npy")

In [None]:
# Single scorer instance, reused for every pair of attention matrices in the loop below.
rtd_approx = RTD_simplified_summ_only()

In [None]:
N = attention_12_head.shape[0]
N_GRID = 12  # matrices per sample along dim 1 of each attention array
N_stats = N_GRID * N_GRID  # 12 heads * 12 layers in attention "cross"
# Fall back to CPU so the notebook still runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# For each sample, score every (head j of the fixed layer) x (layer k of the
# fixed head) pair of attention matrices with the fast RTD approximation.
cross_barcodes_stats = np.zeros((N, N_stats))
for i in range(N):
  print("Sample number:", i)
  for j in range(N_GRID):
    # a1 depends only on (i, j): convert and transfer it once, not once per k.
    a1 = torch.as_tensor(attention_12_layer[i, j, :, :].astype(float), device=device)
    for k in range(N_GRID):
      a2 = torch.as_tensor(attention_12_head[i, k, :, :].astype(float), device=device)
      rtd_score = rtd_approx(a1, a2)
      # float() pulls the scalar off the device; numpy cannot store a CUDA tensor.
      cross_barcodes_stats[i, N_GRID * j + k] = float(rtd_score)