# Analysis of the gradient conflict of Llama model and its K-neuron variants

 - gradient conflict
 - dominating gradient

目前发现使用topk neuron可能会让模型的gradient conflict增加（无论如何调整gradeint 的norm以及cos相似度大小阈值）

**Further analysis**
 - [ ] Analyze layer by layer, starting from the last layer first
 - [ ] Analyze neuron by neuron, starting from the neuron that receives max activation values (这个和考虑梯度norm的应该是很类似的，因为接收到更大激活值的neuron，按理来说应该有更大的梯度)
 - [ ] Analyze weight by weight, 目前是以neuron为单位进行分析的，去看一下 gradient surgery 那篇文章是以什么为单位进行讨论的
 - [ ] Analyze batch by batch, 目前是以sample为单位进行分析的，不过感觉batch_size=1也是一种情况，不应该被特殊处理，即batch_size为多少都应该
 - [ ] **第一优先级** 有一种可能是topk的选取方式让neuron在不同sample上获取了更加平均的梯度，即梯度不会被一个sample所dominate。因此即使梯度有冲突也可以很好的average之后进行很好的优化。

In [1]:

import os
import torch
import random 
import torch.nn as nn
from transformers import LlamaForCausalLM, AutoTokenizer
from topk_models import TopKLlamaForCausalLM, TopKLlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

# Configurations of this notobook
SEED=1
MODEL_PATH = 'models/Llama-2-7b-hf'
MODEL_TYPE = 'original'
DATA_PATH = './data/alpaca/'
N_SAMPLES = 50
CACHE_PATH = 'cache/original_llama/'
USE_GATE, USE_UP, USE_DOWN = True, True, True
if MODEL_TYPE == 'topk':
    from datetime import datetime
    currentTime = datetime.now().strftime("%H-%M-%S")
    CACHE_PATH = f'cache/{currentTime}'
    
    os.system(f"mkdir {CACHE_PATH} & cp topk_7b_configs.json {CACHE_PATH}")
    
# load the model
if MODEL_TYPE == 'original':
    model = LlamaForCausalLM.from_pretrained("./models/Llama-2-7b-hf/", device_map="auto")
else:
    # load the topk model
    config = TopKLlamaConfig.from_json_file('topk_7b_configs.json')
    model = TopKLlamaForCausalLM.from_pretrained(
        "./models/Llama-2-7b-hf/", 
        # device_map="auto",
        config=config
    )
tokenizer = AutoTokenizer.from_pretrained("./models/Llama-2-7b-hf/")
model.eval()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.07s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

Sample some data points from the instruction-tuning data, e.g., Alpaca52K

In [2]:
import random
from datasets import load_dataset
random.seed(10)
data = load_dataset(DATA_PATH)['train']
data_idx_sample = random.sample([_ for _ in range(len(data))], N_SAMPLES)
data_samples = [data[i] for i in data_idx_sample]

Register gradients of the model

In [3]:
IGNORE_INDEX = -100
from copy import deepcopy
from tqdm import tqdm
from utils import PROMPT_TEMPLATE_SINGLE as PROMPT_DICT
from torch import optim
prompt = PROMPT_DICT["prompt_input"]


n_layer = 32
grads = {
    "gate_proj": [[] for _ in range(n_layer)], "up_proj": [[] for _ in range(n_layer)], "down_proj": [[] for _ in range(n_layer)]
}
optimizer = optim.AdamW(model.parameters())
for d in tqdm(data_samples):
    inp = prompt.format_map(d)
    out = f"{d['output']}{tokenizer.eos_token}"
    text = inp + out
    # tokenize
    input_ids = tokenizer(text, return_tensors='pt').input_ids# .to(device)
    src_len = tokenizer(inp, return_tensors='pt').input_ids.ne(tokenizer.pad_token_id).sum().item()
    labels = deepcopy(input_ids)
    labels[:, : src_len] = IGNORE_INDEX
    labels = labels[:, 1:]
    
    # forward
    y = model(input_ids[:,:-1])
    loss = F.cross_entropy(input=y.logits.squeeze(), target=labels.squeeze())
    
    # backward
    loss.backward()
    
    # record gradients
    for n, p in model.named_parameters():
        if len(n.split(".")) < 5:
            continue
        layer_id = int(n.split(".")[2])
        name = n.split(".")[4]
        if name in grads:
            grads[name][layer_id].append(p.grad.data.detach().cpu())
            
    optimizer.zero_grad()


100%|██████████| 50/50 [07:12<00:00,  8.66s/it]


Collect the gradient and analyze their gradient conflicting, for now let's just focus on one neuron in the FFN layer.

The representation of neuron $i$ in Llama FFN layer could be defined as $[g_i, up_i, down_i]$, where $g_i$, $up_i$, $down_i$ represent the 

In [4]:
n_layers = model.config.num_hidden_layers
# for the ffn of each layer
# (neuron_num, sample, d_dim) @ (neuron_num, d_dim, sample)
norm_dominate = []
conflicts = []
epsilon = 0
gamma = 0
device='cpu'
from tqdm import trange
for layer_id in trange(n_layers):
    gate_grads = torch.stack(grads["gate_proj"][layer_id]).transpose(0, 1)
    up_grads = torch.stack(grads["up_proj"][layer_id]).transpose(0, 1)
    down_grads = torch.stack(grads["down_proj"][layer_id]).permute(2, 0, 1)
    neuron_grads = torch.cat([
        gate_grads if USE_GATE else torch.zeros_like(gate_grads), 
        up_grads if USE_UP else torch.zeros_like(up_grads), 
        down_grads if USE_DOWN else torch.zeros_like(down_grads)], dim=-1).to(device)
    
    grad_norm = torch.norm(neuron_grads, p=2, dim=-1, keepdim=True)
    grad_norm_ = torch.cat([grad_norm for _ in range(N_SAMPLES)], dim=-1)
    grad_norm_2 = grad_norm_ * grad_norm_
    tmp = torch.bmm(grad_norm, grad_norm.transpose(1, 2))
    gm_sim = 2 * torch.bmm(grad_norm, grad_norm.transpose(1, 2)) / (grad_norm_2 + grad_norm_2.transpose(1,2))
    norm_dominate.append(torch.mean(gm_sim).cpu().data)
    
    norm_mask = grad_norm > epsilon
    # print(torch.sum(norm_mask) / N_SAMPLES / 11008)
    norm_mask = norm_mask.float()
    batch_norm_mask = torch.bmm(norm_mask, norm_mask.transpose(1, 2)).cpu()
    
    _a = F.normalize(neuron_grads, p=2, dim=-1)
    _b = _a.transpose(1, 2)
    conflict = torch.bmm(_a, _b).cpu()
    conflict.masked_fill_(~batch_norm_mask.bool(), 0)
    conflict_rate = torch.sum(conflict < gamma) / torch.sum(batch_norm_mask)
    conflicts.append(conflict_rate.data)
    print(torch.mean(gm_sim))
    print(conflicts[-1])
    
# plt.bar([ _ for _ in range(2)], conflicts)

  3%|▎         | 1/32 [00:06<03:12,  6.21s/it]

tensor([[[1.0000, 0.9938, 0.9327,  ..., 0.9882, 1.0000, 0.9925],
         [0.9938, 1.0000, 0.9657,  ..., 0.9991, 0.9939, 0.9731],
         [0.9327, 0.9657, 1.0000,  ..., 0.9755, 0.9332, 0.8866],
         ...,
         [0.9882, 0.9991, 0.9755,  ..., 1.0000, 0.9884, 0.9628],
         [1.0000, 0.9939, 0.9332,  ..., 0.9884, 1.0000, 0.9923],
         [0.9925, 0.9731, 0.8866,  ..., 0.9628, 0.9923, 1.0000]],

        [[1.0000, 0.9030, 0.9982,  ..., 0.7508, 0.9946, 0.9643],
         [0.9030, 1.0000, 0.9253,  ..., 0.9466, 0.9401, 0.9826],
         [0.9982, 0.9253, 1.0000,  ..., 0.7804, 0.9990, 0.9781],
         ...,
         [0.7508, 0.9466, 0.7804,  ..., 1.0000, 0.8017, 0.8775],
         [0.9946, 0.9401, 0.9990,  ..., 0.8017, 1.0000, 0.9862],
         [0.9643, 0.9826, 0.9781,  ..., 0.8775, 0.9862, 1.0000]],

        [[1.0000, 0.9869, 0.9624,  ..., 0.9969, 0.9963, 0.9967],
         [0.9869, 1.0000, 0.9933,  ..., 0.9965, 0.9971, 0.9967],
         [0.9624, 0.9933, 1.0000,  ..., 0.9804, 0.9819, 0.

  6%|▋         | 2/32 [00:14<03:49,  7.66s/it]

tensor([[[1.0000, 0.9873, 0.9061,  ..., 0.7618, 0.9665, 0.9755],
         [0.9873, 1.0000, 0.9591,  ..., 0.8384, 0.9169, 0.9306],
         [0.9061, 0.9591, 1.0000,  ..., 0.9508, 0.7900, 0.8087],
         ...,
         [0.7618, 0.8384, 0.9508,  ..., 1.0000, 0.6313, 0.6505],
         [0.9665, 0.9169, 0.7900,  ..., 0.6313, 1.0000, 0.9992],
         [0.9755, 0.9306, 0.8087,  ..., 0.6505, 0.9992, 1.0000]],

        [[1.0000, 0.9711, 0.9999,  ..., 0.9994, 0.9883, 0.9744],
         [0.9711, 1.0000, 0.9736,  ..., 0.9622, 0.9260, 0.9999],
         [0.9999, 0.9736, 1.0000,  ..., 0.9989, 0.9866, 0.9768],
         ...,
         [0.9994, 0.9622, 0.9989,  ..., 1.0000, 0.9931, 0.9660],
         [0.9883, 0.9260, 0.9866,  ..., 0.9931, 1.0000, 0.9311],
         [0.9744, 0.9999, 0.9768,  ..., 0.9660, 0.9311, 1.0000]],

        [[1.0000, 0.9588, 0.9084,  ..., 0.9983, 0.9663, 0.9990],
         [0.9588, 1.0000, 0.9884,  ..., 0.9735, 0.9996, 0.9702],
         [0.9084, 0.9884, 1.0000,  ..., 0.9298, 0.9836, 0.

  9%|▉         | 3/32 [00:23<03:56,  8.17s/it]

tensor([[[1.0000, 0.9500, 0.8125,  ..., 0.6524, 0.7809, 0.7608],
         [0.9500, 1.0000, 0.9436,  ..., 0.8118, 0.9216, 0.9064],
         [0.8125, 0.9436, 1.0000,  ..., 0.9496, 0.9978, 0.9943],
         ...,
         [0.6524, 0.8118, 0.9496,  ..., 1.0000, 0.9675, 0.9768],
         [0.7809, 0.9216, 0.9978,  ..., 0.9675, 1.0000, 0.9992],
         [0.7608, 0.9064, 0.9943,  ..., 0.9768, 0.9992, 1.0000]],

        [[1.0000, 0.9816, 0.9872,  ..., 0.9149, 0.9948, 0.9884],
         [0.9816, 1.0000, 0.9403,  ..., 0.8338, 0.9579, 0.9429],
         [0.9872, 0.9403, 1.0000,  ..., 0.9654, 0.9983, 1.0000],
         ...,
         [0.9149, 0.8338, 0.9654,  ..., 1.0000, 0.9491, 0.9633],
         [0.9948, 0.9579, 0.9983,  ..., 0.9491, 1.0000, 0.9987],
         [0.9884, 0.9429, 1.0000,  ..., 0.9633, 0.9987, 1.0000]],

        [[1.0000, 0.9538, 0.9956,  ..., 0.9823, 0.9664, 0.5019],
         [0.9538, 1.0000, 0.9235,  ..., 0.8871, 0.8556, 0.6467],
         [0.9956, 0.9235, 1.0000,  ..., 0.9955, 0.9859, 0.

 12%|█▎        | 4/32 [00:32<03:57,  8.50s/it]

tensor([[[1.0000, 0.9897, 0.9691,  ..., 0.8136, 0.9880, 0.9487],
         [0.9897, 1.0000, 0.9942,  ..., 0.7434, 0.9999, 0.9834],
         [0.9691, 0.9942, 1.0000,  ..., 0.6896, 0.9954, 0.9971],
         ...,
         [0.8136, 0.7434, 0.6896,  ..., 1.0000, 0.7377, 0.6519],
         [0.9880, 0.9999, 0.9954,  ..., 0.7377, 1.0000, 0.9854],
         [0.9487, 0.9834, 0.9971,  ..., 0.6519, 0.9854, 1.0000]],

        [[1.0000, 0.9998, 0.9394,  ..., 0.9847, 0.9853, 0.9981],
         [0.9998, 1.0000, 0.9454,  ..., 0.9877, 0.9820, 0.9991],
         [0.9394, 0.9454, 1.0000,  ..., 0.9839, 0.8746, 0.9579],
         ...,
         [0.9847, 0.9877, 0.9839,  ..., 1.0000, 0.9422, 0.9935],
         [0.9853, 0.9820, 0.8746,  ..., 0.9422, 1.0000, 0.9733],
         [0.9981, 0.9991, 0.9579,  ..., 0.9935, 0.9733, 1.0000]],

        [[1.0000, 0.9351, 0.9940,  ..., 0.8019, 0.9988, 0.9949],
         [0.9351, 1.0000, 0.9670,  ..., 0.6188, 0.9505, 0.9648],
         [0.9940, 0.9670, 1.0000,  ..., 0.7481, 0.9981, 1.

 16%|█▌        | 5/32 [00:41<03:55,  8.72s/it]

tensor([[[1.0000, 0.9881, 1.0000,  ..., 0.9684, 0.9993, 0.9931],
         [0.9881, 1.0000, 0.9886,  ..., 0.9215, 0.9815, 0.9993],
         [1.0000, 0.9886, 1.0000,  ..., 0.9676, 0.9991, 0.9935],
         ...,
         [0.9684, 0.9215, 0.9676,  ..., 1.0000, 0.9771, 0.9344],
         [0.9993, 0.9815, 0.9991,  ..., 0.9771, 1.0000, 0.9879],
         [0.9931, 0.9993, 0.9935,  ..., 0.9344, 0.9879, 1.0000]],

        [[1.0000, 0.9950, 0.9486,  ..., 0.9391, 0.9896, 0.9692],
         [0.9950, 1.0000, 0.9149,  ..., 0.9677, 0.9706, 0.9887],
         [0.9486, 0.9149, 1.0000,  ..., 0.8035, 0.9836, 0.8530],
         ...,
         [0.9391, 0.9677, 0.8035,  ..., 1.0000, 0.8855, 0.9943],
         [0.9896, 0.9706, 0.9836,  ..., 0.8855, 1.0000, 0.9263],
         [0.9692, 0.9887, 0.8530,  ..., 0.9943, 0.9263, 1.0000]],

        [[1.0000, 0.9994, 0.9915,  ..., 0.9776, 0.9561, 0.4902],
         [0.9994, 1.0000, 0.9864,  ..., 0.9842, 0.9458, 0.5054],
         [0.9915, 0.9864, 1.0000,  ..., 0.9435, 0.9855, 0.

 19%|█▉        | 6/32 [00:51<03:53,  8.97s/it]

tensor([[[1.0000, 0.9575, 0.9185,  ..., 0.9958, 0.9632, 0.9992],
         [0.9575, 1.0000, 0.9927,  ..., 0.9288, 0.9998, 0.9681],
         [0.9185, 0.9927, 1.0000,  ..., 0.8826, 0.9899, 0.9329],
         ...,
         [0.9958, 0.9288, 0.8826,  ..., 1.0000, 0.9360, 0.9912],
         [0.9632, 0.9998, 0.9899,  ..., 0.9360, 1.0000, 0.9732],
         [0.9992, 0.9681, 0.9329,  ..., 0.9912, 0.9732, 1.0000]],

        [[1.0000, 0.9971, 0.9671,  ..., 0.9115, 0.7933, 0.9915],
         [0.9971, 1.0000, 0.9834,  ..., 0.8811, 0.8295, 0.9985],
         [0.9671, 0.9834, 1.0000,  ..., 0.7979, 0.9079, 0.9917],
         ...,
         [0.9115, 0.8811, 0.7979,  ..., 1.0000, 0.5782, 0.8578],
         [0.7933, 0.8295, 0.9079,  ..., 0.5782, 1.0000, 0.8543],
         [0.9915, 0.9985, 0.9917,  ..., 0.8578, 0.8543, 1.0000]],

        [[1.0000, 0.9934, 0.9903,  ..., 0.9353, 0.9873, 0.9916],
         [0.9934, 1.0000, 0.9997,  ..., 0.8929, 0.9633, 0.9707],
         [0.9903, 0.9997, 1.0000,  ..., 0.8827, 0.9566, 0.

 22%|██▏       | 7/32 [01:00<03:47,  9.11s/it]

tensor([[[1.0000, 0.9998, 0.9991,  ..., 0.7832, 0.9875, 0.9593],
         [0.9998, 1.0000, 0.9998,  ..., 0.7934, 0.9905, 0.9648],
         [0.9991, 0.9998, 1.0000,  ..., 0.8035, 0.9932, 0.9700],
         ...,
         [0.7832, 0.7934, 0.8035,  ..., 1.0000, 0.8575, 0.9114],
         [0.9875, 0.9905, 0.9932,  ..., 0.8575, 1.0000, 0.9914],
         [0.9593, 0.9648, 0.9700,  ..., 0.9114, 0.9914, 1.0000]],

        [[1.0000, 0.9989, 0.7837,  ..., 0.9812, 1.0000, 0.9978],
         [0.9989, 1.0000, 0.7605,  ..., 0.9891, 0.9991, 0.9935],
         [0.7837, 0.7605, 1.0000,  ..., 0.6866, 0.7813, 0.8157],
         ...,
         [0.9812, 0.9891, 0.6866,  ..., 1.0000, 0.9821, 0.9666],
         [1.0000, 0.9991, 0.7813,  ..., 0.9821, 1.0000, 0.9974],
         [0.9978, 0.9935, 0.8157,  ..., 0.9666, 0.9974, 1.0000]],

        [[1.0000, 0.9902, 0.9973,  ..., 0.8586, 0.9384, 0.9996],
         [0.9902, 1.0000, 0.9775,  ..., 0.9156, 0.9763, 0.9936],
         [0.9973, 0.9775, 1.0000,  ..., 0.8251, 0.9126, 0.

 25%|██▌       | 8/32 [01:10<03:43,  9.30s/it]

tensor([[[1.0000, 0.9952, 0.9973,  ..., 0.9835, 0.9928, 0.9673],
         [0.9952, 1.0000, 0.9855,  ..., 0.9616, 0.9765, 0.9391],
         [0.9973, 0.9855, 1.0000,  ..., 0.9940, 0.9989, 0.9829],
         ...,
         [0.9835, 0.9616, 0.9940,  ..., 1.0000, 0.9980, 0.9971],
         [0.9928, 0.9765, 0.9989,  ..., 0.9980, 1.0000, 0.9904],
         [0.9673, 0.9391, 0.9829,  ..., 0.9971, 0.9904, 1.0000]],

        [[1.0000, 0.9974, 0.9994,  ..., 0.9994, 0.9076, 0.9236],
         [0.9974, 1.0000, 0.9945,  ..., 0.9943, 0.8789, 0.8967],
         [0.9994, 0.9945, 1.0000,  ..., 1.0000, 0.9202, 0.9352],
         ...,
         [0.9994, 0.9943, 1.0000,  ..., 1.0000, 0.9208, 0.9358],
         [0.9076, 0.8789, 0.9202,  ..., 0.9208, 1.0000, 0.9991],
         [0.9236, 0.8967, 0.9352,  ..., 0.9358, 0.9991, 1.0000]],

        [[1.0000, 0.9645, 0.9729,  ..., 1.0000, 0.9812, 0.9962],
         [0.9645, 1.0000, 0.9994,  ..., 0.9630, 0.9972, 0.9835],
         [0.9729, 0.9994, 1.0000,  ..., 0.9715, 0.9992, 0.

 28%|██▊       | 9/32 [01:20<03:38,  9.49s/it]

tensor([[[1.0000, 0.8284, 0.9997,  ..., 0.9257, 0.9892, 1.0000],
         [0.8284, 1.0000, 0.8179,  ..., 0.6327, 0.7571, 0.8310],
         [0.9997, 0.8179, 1.0000,  ..., 0.9334, 0.9922, 0.9996],
         ...,
         [0.9257, 0.6327, 0.9334,  ..., 1.0000, 0.9695, 0.9236],
         [0.9892, 0.7571, 0.9922,  ..., 0.9695, 1.0000, 0.9883],
         [1.0000, 0.8310, 0.9996,  ..., 0.9236, 0.9883, 1.0000]],

        [[1.0000, 0.9169, 0.8901,  ..., 0.7669, 0.9253, 0.4727],
         [0.9169, 1.0000, 0.9976,  ..., 0.9453, 0.9997, 0.3207],
         [0.8901, 0.9976, 1.0000,  ..., 0.9648, 0.9957, 0.3002],
         ...,
         [0.7669, 0.9453, 0.9648,  ..., 1.0000, 0.9379, 0.2316],
         [0.9253, 0.9997, 0.9957,  ..., 0.9379, 1.0000, 0.3278],
         [0.4727, 0.3207, 0.3002,  ..., 0.2316, 0.3278, 1.0000]],

        [[1.0000, 0.9637, 0.9332,  ..., 0.6683, 0.9026, 0.9996],
         [0.9637, 1.0000, 0.9948,  ..., 0.8037, 0.9828, 0.9562],
         [0.9332, 0.9948, 1.0000,  ..., 0.8513, 0.9965, 0.

 31%|███▏      | 10/32 [01:30<03:34,  9.76s/it]

tensor([[[1.0000, 0.9775, 0.9557,  ..., 0.9938, 0.9849, 0.8559],
         [0.9775, 1.0000, 0.9960,  ..., 0.9492, 0.9287, 0.9391],
         [0.9557, 0.9960, 1.0000,  ..., 0.9198, 0.8955, 0.9649],
         ...,
         [0.9938, 0.9492, 0.9198,  ..., 1.0000, 0.9980, 0.8045],
         [0.9849, 0.9287, 0.8955,  ..., 0.9980, 1.0000, 0.7736],
         [0.8559, 0.9391, 0.9649,  ..., 0.8045, 0.7736, 1.0000]],

        [[1.0000, 0.9100, 0.9631,  ..., 0.9864, 0.9998, 1.0000],
         [0.9100, 1.0000, 0.9865,  ..., 0.8405, 0.9022, 0.9109],
         [0.9631, 0.9865, 1.0000,  ..., 0.9099, 0.9577, 0.9637],
         ...,
         [0.9864, 0.8405, 0.9099,  ..., 1.0000, 0.9895, 0.9861],
         [0.9998, 0.9022, 0.9577,  ..., 0.9895, 1.0000, 0.9997],
         [1.0000, 0.9109, 0.9637,  ..., 0.9861, 0.9997, 1.0000]],

        [[1.0000, 0.6766, 0.7211,  ..., 0.5929, 0.6025, 0.5950],
         [0.6766, 1.0000, 0.9961,  ..., 0.9855, 0.9887, 0.9863],
         [0.7211, 0.9961, 1.0000,  ..., 0.9671, 0.9719, 0.

 34%|███▍      | 11/32 [01:40<03:27,  9.90s/it]

tensor([[[1.0000, 0.9891, 0.8725,  ..., 0.9307, 0.9992, 0.9637],
         [0.9891, 1.0000, 0.9300,  ..., 0.8735, 0.9827, 0.9170],
         [0.8725, 0.9300, 1.0000,  ..., 0.6890, 0.8556, 0.7437],
         ...,
         [0.9307, 0.8735, 0.6890,  ..., 1.0000, 0.9434, 0.9940],
         [0.9992, 0.9827, 0.8556,  ..., 0.9434, 1.0000, 0.9730],
         [0.9637, 0.9170, 0.7437,  ..., 0.9940, 0.9730, 1.0000]],

        [[1.0000, 0.9091, 0.9994,  ..., 1.0000, 0.9083, 0.9611],
         [0.9091, 1.0000, 0.8958,  ..., 0.9101, 0.7032, 0.7836],
         [0.9994, 0.8958, 1.0000,  ..., 0.9993, 0.9209, 0.9697],
         ...,
         [1.0000, 0.9101, 0.9993,  ..., 1.0000, 0.9073, 0.9604],
         [0.9083, 0.7032, 0.9209,  ..., 0.9073, 1.0000, 0.9870],
         [0.9611, 0.7836, 0.9697,  ..., 0.9604, 0.9870, 1.0000]],

        [[1.0000, 0.9661, 0.8694,  ..., 0.9717, 0.4148, 0.9961],
         [0.9661, 1.0000, 0.9627,  ..., 0.9997, 0.5237, 0.9849],
         [0.8694, 0.9627, 1.0000,  ..., 0.9564, 0.6552, 0.

 38%|███▊      | 12/32 [01:51<03:22, 10.13s/it]

tensor([[[1.0000, 0.9816, 0.9719,  ..., 0.8794, 0.9567, 0.9902],
         [0.9816, 1.0000, 0.9989,  ..., 0.9495, 0.9944, 0.9468],
         [0.9719, 0.9989, 1.0000,  ..., 0.9626, 0.9982, 0.9318],
         ...,
         [0.8794, 0.9495, 0.9626,  ..., 1.0000, 0.9766, 0.8166],
         [0.9567, 0.9944, 0.9982,  ..., 0.9766, 1.0000, 0.9105],
         [0.9902, 0.9468, 0.9318,  ..., 0.8166, 0.9105, 1.0000]],

        [[1.0000, 0.9636, 0.8806,  ..., 0.9758, 0.9975, 0.9688],
         [0.9636, 1.0000, 0.9717,  ..., 0.8883, 0.9797, 0.8754],
         [0.8806, 0.9717, 1.0000,  ..., 0.7787, 0.9090, 0.7635],
         ...,
         [0.9758, 0.8883, 0.7787,  ..., 1.0000, 0.9585, 0.9995],
         [0.9975, 0.9797, 0.9090,  ..., 0.9585, 1.0000, 0.9496],
         [0.9688, 0.8754, 0.7635,  ..., 0.9995, 0.9496, 1.0000]],

        [[1.0000, 0.9848, 0.9919,  ..., 0.9819, 0.9982, 1.0000],
         [0.9848, 1.0000, 0.9988,  ..., 0.9999, 0.9933, 0.9862],
         [0.9919, 0.9988, 1.0000,  ..., 0.9979, 0.9977, 0.

 41%|████      | 13/32 [02:02<03:14, 10.25s/it]

tensor([[[1.0000, 0.9926, 0.9225,  ..., 0.9972, 0.9882, 0.9943],
         [0.9926, 1.0000, 0.9606,  ..., 0.9811, 0.9631, 0.9999],
         [0.9225, 0.9606, 1.0000,  ..., 0.8943, 0.8609, 0.9566],
         ...,
         [0.9972, 0.9811, 0.8943,  ..., 1.0000, 0.9968, 0.9838],
         [0.9882, 0.9631, 0.8609,  ..., 0.9968, 1.0000, 0.9669],
         [0.9943, 0.9999, 0.9566,  ..., 0.9838, 0.9669, 1.0000]],

        [[1.0000, 1.0000, 0.9955,  ..., 0.9827, 0.9904, 0.9722],
         [1.0000, 1.0000, 0.9957,  ..., 0.9832, 0.9900, 0.9729],
         [0.9955, 0.9957, 1.0000,  ..., 0.9958, 0.9731, 0.9898],
         ...,
         [0.9827, 0.9832, 0.9958,  ..., 1.0000, 0.9489, 0.9987],
         [0.9904, 0.9900, 0.9731,  ..., 0.9489, 1.0000, 0.9327],
         [0.9722, 0.9729, 0.9898,  ..., 0.9987, 0.9327, 1.0000]],

        [[1.0000, 0.9706, 0.8845,  ..., 0.8362, 0.9347, 0.9955],
         [0.9706, 1.0000, 0.9671,  ..., 0.7171, 0.8359, 0.9448],
         [0.8845, 0.9671, 1.0000,  ..., 0.5890, 0.7092, 0.

 44%|████▍     | 14/32 [02:12<03:05, 10.31s/it]

tensor([[[1.0000, 0.9918, 0.9617,  ..., 0.9251, 1.0000, 0.9995],
         [0.9918, 1.0000, 0.9216,  ..., 0.8751, 0.9905, 0.9953],
         [0.9617, 0.9216, 1.0000,  ..., 0.9930, 0.9642, 0.9532],
         ...,
         [0.9251, 0.8751, 0.9930,  ..., 1.0000, 0.9284, 0.9139],
         [1.0000, 0.9905, 0.9642,  ..., 0.9284, 1.0000, 0.9992],
         [0.9995, 0.9953, 0.9532,  ..., 0.9139, 0.9992, 1.0000]],

        [[1.0000, 0.8815, 0.9932,  ..., 0.9943, 0.9782, 0.9997],
         [0.8815, 1.0000, 0.9265,  ..., 0.9231, 0.9559, 0.8921],
         [0.9932, 0.9265, 1.0000,  ..., 1.0000, 0.9957, 0.9958],
         ...,
         [0.9943, 0.9231, 1.0000,  ..., 1.0000, 0.9947, 0.9967],
         [0.9782, 0.9559, 0.9957,  ..., 0.9947, 1.0000, 0.9832],
         [0.9997, 0.8921, 0.9958,  ..., 0.9967, 0.9832, 1.0000]],

        [[1.0000, 0.9999, 0.9706,  ..., 0.9989, 0.9996, 0.8898],
         [0.9999, 1.0000, 0.9737,  ..., 0.9981, 0.9999, 0.8953],
         [0.9706, 0.9737, 1.0000,  ..., 0.9586, 0.9766, 0.

 47%|████▋     | 15/32 [02:22<02:55, 10.34s/it]

tensor([[[1.0000, 0.6311, 0.6461,  ..., 0.8536, 0.9631, 0.9472],
         [0.6311, 1.0000, 0.9995,  ..., 0.3836, 0.5028, 0.7957],
         [0.6461, 0.9995, 1.0000,  ..., 0.3946, 0.5162, 0.8103],
         ...,
         [0.8536, 0.3836, 0.3946,  ..., 1.0000, 0.9562, 0.6927],
         [0.9631, 0.5028, 0.5162,  ..., 0.9562, 1.0000, 0.8397],
         [0.9472, 0.7957, 0.8103,  ..., 0.6927, 0.8397, 1.0000]],

        [[1.0000, 0.9898, 0.9741,  ..., 0.8408, 0.9190, 0.9901],
         [0.9898, 1.0000, 0.9340,  ..., 0.7725, 0.8611, 0.9606],
         [0.9741, 0.9340, 1.0000,  ..., 0.9333, 0.9828, 0.9961],
         ...,
         [0.8408, 0.7725, 0.9333,  ..., 1.0000, 0.9824, 0.9010],
         [0.9190, 0.8611, 0.9828,  ..., 0.9824, 1.0000, 0.9632],
         [0.9901, 0.9606, 0.9961,  ..., 0.9010, 0.9632, 1.0000]],

        [[1.0000, 0.9506, 0.9722,  ..., 0.9784, 0.9147, 0.9412],
         [0.9506, 1.0000, 0.9966,  ..., 0.9938, 0.9943, 0.9995],
         [0.9722, 0.9966, 1.0000,  ..., 0.9996, 0.9823, 0.

 50%|█████     | 16/32 [02:33<02:47, 10.48s/it]

tensor([[[1.0000, 0.9204, 0.9795,  ..., 0.9693, 0.9815, 0.9313],
         [0.9204, 1.0000, 0.9787,  ..., 0.9870, 0.9765, 0.9995],
         [0.9795, 0.9787, 1.0000,  ..., 0.9989, 0.9999, 0.9845],
         ...,
         [0.9693, 0.9870, 0.9989,  ..., 1.0000, 0.9984, 0.9915],
         [0.9815, 0.9765, 0.9999,  ..., 0.9984, 1.0000, 0.9826],
         [0.9313, 0.9995, 0.9845,  ..., 0.9915, 0.9826, 1.0000]],

        [[1.0000, 0.9426, 0.9958,  ..., 0.9833, 0.6794, 0.9637],
         [0.9426, 1.0000, 0.9681,  ..., 0.9868, 0.8482, 0.9973],
         [0.9958, 0.9681, 1.0000,  ..., 0.9957, 0.7250, 0.9836],
         ...,
         [0.9833, 0.9868, 0.9957,  ..., 1.0000, 0.7709, 0.9960],
         [0.6794, 0.8482, 0.7250,  ..., 0.7709, 1.0000, 0.8143],
         [0.9637, 0.9973, 0.9836,  ..., 0.9960, 0.8143, 1.0000]],

        [[1.0000, 0.4876, 0.4655,  ..., 0.9056, 0.6634, 0.8096],
         [0.4876, 1.0000, 0.9986,  ..., 0.7011, 0.9329, 0.8097],
         [0.4655, 0.9986, 1.0000,  ..., 0.6749, 0.9143, 0.

 53%|█████▎    | 17/32 [02:44<02:37, 10.53s/it]

tensor([[[1.0000, 0.9713, 0.9857,  ..., 0.8476, 0.7861, 0.7769],
         [0.9713, 1.0000, 0.9974,  ..., 0.9422, 0.8951, 0.8876],
         [0.9857, 0.9974, 1.0000,  ..., 0.9175, 0.8648, 0.8566],
         ...,
         [0.8476, 0.9422, 0.9175,  ..., 1.0000, 0.9915, 0.9889],
         [0.7861, 0.8951, 0.8648,  ..., 0.9915, 1.0000, 0.9998],
         [0.7769, 0.8876, 0.8566,  ..., 0.9889, 0.9998, 1.0000]],

        [[1.0000, 0.8996, 0.8208,  ..., 0.9972, 0.9995, 0.9857],
         [0.8996, 1.0000, 0.9838,  ..., 0.8687, 0.8863, 0.9571],
         [0.8208, 0.9838, 1.0000,  ..., 0.7850, 0.8051, 0.8951],
         ...,
         [0.9972, 0.8687, 0.7850,  ..., 1.0000, 0.9991, 0.9708],
         [0.9995, 0.8863, 0.8051,  ..., 0.9991, 1.0000, 0.9798],
         [0.9857, 0.9571, 0.8951,  ..., 0.9708, 0.9798, 1.0000]],

        [[1.0000, 0.8317, 0.8883,  ..., 0.7334, 0.9979, 0.9300],
         [0.8317, 1.0000, 0.9916,  ..., 0.9798, 0.8010, 0.9718],
         [0.8883, 0.9916, 1.0000,  ..., 0.9471, 0.8608, 0.

 56%|█████▋    | 18/32 [02:55<02:28, 10.58s/it]

tensor([[[1.0000, 0.9962, 0.9719,  ..., 0.9984, 0.9700, 0.9939],
         [0.9962, 1.0000, 0.9885,  ..., 0.9897, 0.9873, 0.9807],
         [0.9719, 0.9885, 1.0000,  ..., 0.9574, 1.0000, 0.9415],
         ...,
         [0.9984, 0.9897, 0.9574,  ..., 1.0000, 0.9552, 0.9986],
         [0.9700, 0.9873, 1.0000,  ..., 0.9552, 1.0000, 0.9390],
         [0.9939, 0.9807, 0.9415,  ..., 0.9986, 0.9390, 1.0000]],

        [[1.0000, 0.9183, 0.9844,  ..., 0.9928, 0.9393, 0.8548],
         [0.9183, 1.0000, 0.8450,  ..., 0.9572, 0.7594, 0.6512],
         [0.9844, 0.8450, 1.0000,  ..., 0.9570, 0.9841, 0.9261],
         ...,
         [0.9928, 0.9572, 0.9570,  ..., 1.0000, 0.8955, 0.7988],
         [0.9393, 0.7594, 0.9841,  ..., 0.8955, 1.0000, 0.9769],
         [0.8548, 0.6512, 0.9261,  ..., 0.7988, 0.9769, 1.0000]],

        [[1.0000, 0.9967, 0.9975,  ..., 0.8759, 0.9737, 0.9093],
         [0.9967, 1.0000, 0.9999,  ..., 0.9086, 0.9888, 0.9380],
         [0.9975, 0.9999, 1.0000,  ..., 0.9046, 0.9872, 0.

 59%|█████▉    | 19/32 [03:06<02:19, 10.72s/it]

tensor([[[1.0000, 0.9756, 1.0000,  ..., 0.7358, 0.8835, 0.8900],
         [0.9756, 1.0000, 0.9734,  ..., 0.8433, 0.7814, 0.9649],
         [1.0000, 0.9734, 1.0000,  ..., 0.7309, 0.8875, 0.8860],
         ...,
         [0.7358, 0.8433, 0.7309,  ..., 1.0000, 0.4935, 0.9474],
         [0.8835, 0.7814, 0.8875,  ..., 0.4935, 1.0000, 0.6479],
         [0.8900, 0.9649, 0.8860,  ..., 0.9474, 0.6479, 1.0000]],

        [[1.0000, 0.7908, 0.8041,  ..., 0.9411, 0.7788, 0.7944],
         [0.7908, 1.0000, 0.9996,  ..., 0.9385, 0.9997, 1.0000],
         [0.8041, 0.9996, 1.0000,  ..., 0.9472, 0.9986, 0.9998],
         ...,
         [0.9411, 0.9385, 0.9472,  ..., 1.0000, 0.9303, 0.9409],
         [0.7788, 0.9997, 0.9986,  ..., 0.9303, 1.0000, 0.9995],
         [0.7944, 1.0000, 0.9998,  ..., 0.9409, 0.9995, 1.0000]],

        [[1.0000, 0.9996, 0.9901,  ..., 0.7416, 0.9929, 0.9850],
         [0.9996, 1.0000, 0.9859,  ..., 0.7554, 0.9958, 0.9799],
         [0.9901, 0.9859, 1.0000,  ..., 0.6712, 0.9670, 0.

 62%|██████▎   | 20/32 [03:16<02:09, 10.77s/it]

tensor([[[1.0000, 0.9545, 0.9989,  ..., 0.9964, 0.9128, 0.9255],
         [0.9545, 1.0000, 0.9400,  ..., 0.9277, 0.9921, 0.7937],
         [0.9989, 0.9400, 1.0000,  ..., 0.9993, 0.8944, 0.9414],
         ...,
         [0.9964, 0.9277, 0.9993,  ..., 1.0000, 0.8792, 0.9526],
         [0.9128, 0.9921, 0.8944,  ..., 0.8792, 1.0000, 0.7316],
         [0.9255, 0.7937, 0.9414,  ..., 0.9526, 0.7316, 1.0000]],

        [[1.0000, 0.9601, 0.9938,  ..., 0.9897, 0.8928, 0.9992],
         [0.9601, 1.0000, 0.9848,  ..., 0.9137, 0.9807, 0.9699],
         [0.9938, 0.9848, 1.0000,  ..., 0.9682, 0.9341, 0.9973],
         ...,
         [0.9897, 0.9137, 0.9682,  ..., 1.0000, 0.8302, 0.9835],
         [0.8928, 0.9807, 0.9341,  ..., 0.8302, 1.0000, 0.9080],
         [0.9992, 0.9699, 0.9973,  ..., 0.9835, 0.9080, 1.0000]],

        [[1.0000, 0.9517, 0.9965,  ..., 0.9942, 0.9814, 0.9929],
         [0.9517, 1.0000, 0.9244,  ..., 0.9160, 0.9925, 0.9115],
         [0.9965, 0.9244, 1.0000,  ..., 0.9997, 0.9624, 0.

 66%|██████▌   | 21/32 [03:27<01:57, 10.69s/it]

tensor([[[1.0000, 0.9781, 0.9888,  ..., 0.9386, 0.9821, 0.8368],
         [0.9781, 1.0000, 0.9379,  ..., 0.9891, 0.9998, 0.9238],
         [0.9888, 0.9379, 1.0000,  ..., 0.8825, 0.9445, 0.7648],
         ...,
         [0.9386, 0.9891, 0.8825,  ..., 1.0000, 0.9858, 0.9684],
         [0.9821, 0.9998, 0.9445,  ..., 0.9858, 1.0000, 0.9163],
         [0.8368, 0.9238, 0.7648,  ..., 0.9684, 0.9163, 1.0000]],

        [[1.0000, 0.9841, 0.9977,  ..., 0.7993, 0.8408, 0.9944],
         [0.9841, 1.0000, 0.9703,  ..., 0.7108, 0.7549, 0.9605],
         [0.9977, 0.9703, 1.0000,  ..., 0.8310, 0.8705, 0.9992],
         ...,
         [0.7993, 0.7108, 0.8310,  ..., 1.0000, 0.9961, 0.8489],
         [0.8408, 0.7549, 0.8705,  ..., 0.9961, 1.0000, 0.8870],
         [0.9944, 0.9605, 0.9992,  ..., 0.8489, 0.8870, 1.0000]],

        [[1.0000, 0.9247, 0.9598,  ..., 0.7384, 0.9458, 0.9934],
         [0.9247, 1.0000, 0.9937,  ..., 0.5433, 0.9980, 0.8799],
         [0.9598, 0.9937, 1.0000,  ..., 0.5959, 0.9988, 0.

 69%|██████▉   | 22/32 [03:43<02:03, 12.39s/it]

tensor([[[1.0000, 0.9986, 0.9869,  ..., 0.7009, 0.9981, 0.9957],
         [0.9986, 1.0000, 0.9941,  ..., 0.6743, 0.9934, 0.9894],
         [0.9869, 0.9941, 1.0000,  ..., 0.6204, 0.9754, 0.9682],
         ...,
         [0.7009, 0.6743, 0.6204,  ..., 1.0000, 0.7316, 0.7474],
         [0.9981, 0.9934, 0.9754,  ..., 0.7316, 1.0000, 0.9995],
         [0.9957, 0.9894, 0.9682,  ..., 0.7474, 0.9995, 1.0000]],

        [[1.0000, 0.8709, 0.9505,  ..., 0.8655, 0.9896, 0.9922],
         [0.8709, 1.0000, 0.9770,  ..., 0.9999, 0.8047, 0.9204],
         [0.9505, 0.9770, 1.0000,  ..., 0.9743, 0.9002, 0.9810],
         ...,
         [0.8655, 0.9999, 0.9743,  ..., 1.0000, 0.7988, 0.9158],
         [0.9896, 0.8047, 0.9002,  ..., 0.7988, 1.0000, 0.9646],
         [0.9922, 0.9204, 0.9810,  ..., 0.9158, 0.9646, 1.0000]],

        [[1.0000, 0.9892, 0.9597,  ..., 0.6566, 0.9883, 0.9234],
         [0.9892, 1.0000, 0.9901,  ..., 0.7302, 0.9563, 0.9678],
         [0.9597, 0.9901, 1.0000,  ..., 0.7996, 0.9095, 0.

 72%|███████▏  | 23/32 [03:54<01:46, 11.82s/it]

tensor([[[1.0000, 0.9999, 1.0000,  ..., 0.9565, 0.9995, 0.9897],
         [0.9999, 1.0000, 0.9997,  ..., 0.9607, 0.9989, 0.9918],
         [1.0000, 0.9997, 1.0000,  ..., 0.9543, 0.9997, 0.9886],
         ...,
         [0.9565, 0.9607, 0.9543,  ..., 1.0000, 0.9472, 0.9879],
         [0.9995, 0.9989, 0.9997,  ..., 0.9472, 1.0000, 0.9847],
         [0.9897, 0.9918, 0.9886,  ..., 0.9879, 0.9847, 1.0000]],

        [[1.0000, 1.0000, 0.9904,  ..., 0.9247, 0.6958, 0.8161],
         [1.0000, 1.0000, 0.9901,  ..., 0.9255, 0.6970, 0.8172],
         [0.9904, 0.9901, 1.0000,  ..., 0.8702, 0.6271, 0.7487],
         ...,
         [0.9247, 0.9255, 0.8702,  ..., 1.0000, 0.8856, 0.9675],
         [0.6958, 0.6970, 0.6271,  ..., 0.8856, 1.0000, 0.9708],
         [0.8161, 0.8172, 0.7487,  ..., 0.9675, 0.9708, 1.0000]],

        [[1.0000, 0.6173, 0.5149,  ..., 0.7072, 0.9982, 0.9868],
         [0.6173, 1.0000, 0.9763,  ..., 0.9837, 0.6464, 0.5404],
         [0.5149, 0.9763, 1.0000,  ..., 0.9244, 0.5415, 0.

 75%|███████▌  | 24/32 [04:04<01:31, 11.43s/it]

tensor([[[1.0000, 0.8217, 0.9499,  ..., 0.9555, 0.9700, 0.8150],
         [0.8217, 1.0000, 0.9497,  ..., 0.6723, 0.9253, 0.5035],
         [0.9499, 0.9497, 1.0000,  ..., 0.8311, 0.9972, 0.6555],
         ...,
         [0.9555, 0.6723, 0.8311,  ..., 1.0000, 0.8649, 0.9393],
         [0.9700, 0.9253, 0.9972,  ..., 0.8649, 1.0000, 0.6930],
         [0.8150, 0.5035, 0.6555,  ..., 0.9393, 0.6930, 1.0000]],

        [[1.0000, 0.6875, 0.6892,  ..., 0.9984, 0.8788, 0.9584],
         [0.6875, 1.0000, 1.0000,  ..., 0.7157, 0.4486, 0.5458],
         [0.6892, 1.0000, 1.0000,  ..., 0.7174, 0.4500, 0.5473],
         ...,
         [0.9984, 0.7157, 0.7174,  ..., 1.0000, 0.8544, 0.9417],
         [0.8788, 0.4486, 0.4500,  ..., 0.8544, 1.0000, 0.9750],
         [0.9584, 0.5458, 0.5473,  ..., 0.9417, 0.9750, 1.0000]],

        [[1.0000, 0.2745, 0.2587,  ..., 0.7082, 0.8449, 0.5736],
         [0.2745, 1.0000, 0.9981,  ..., 0.6055, 0.4777, 0.7417],
         [0.2587, 0.9981, 1.0000,  ..., 0.5761, 0.4523, 0.

 78%|███████▊  | 25/32 [04:15<01:17, 11.10s/it]

tensor([[[1.0000, 0.9969, 0.9968,  ..., 0.9477, 1.0000, 0.9473],
         [0.9969, 1.0000, 0.9876,  ..., 0.9690, 0.9976, 0.9687],
         [0.9968, 0.9876, 1.0000,  ..., 0.9212, 0.9961, 0.9207],
         ...,
         [0.9477, 0.9690, 0.9212,  ..., 1.0000, 0.9504, 1.0000],
         [1.0000, 0.9976, 0.9961,  ..., 0.9504, 1.0000, 0.9499],
         [0.9473, 0.9687, 0.9207,  ..., 1.0000, 0.9499, 1.0000]],

        [[1.0000, 0.8062, 0.8469,  ..., 0.9539, 0.8946, 0.8127],
         [0.8062, 1.0000, 0.9962,  ..., 0.9350, 0.9804, 0.9999],
         [0.8469, 0.9962, 1.0000,  ..., 0.9612, 0.9938, 0.9973],
         ...,
         [0.9539, 0.9350, 0.9612,  ..., 1.0000, 0.9855, 0.9395],
         [0.8946, 0.9804, 0.9938,  ..., 0.9855, 1.0000, 0.9830],
         [0.8127, 0.9999, 0.9973,  ..., 0.9395, 0.9830, 1.0000]],

        [[1.0000, 0.9858, 0.9538,  ..., 0.9071, 0.8172, 0.9212],
         [0.9858, 1.0000, 0.9902,  ..., 0.9622, 0.7347, 0.8526],
         [0.9538, 0.9902, 1.0000,  ..., 0.9904, 0.6644, 0.

 81%|████████▏ | 26/32 [04:25<01:05, 10.92s/it]

tensor([[[1.0000, 0.9340, 0.9762,  ..., 0.9849, 0.7896, 0.9310],
         [0.9340, 1.0000, 0.8464,  ..., 0.8664, 0.9445, 0.7692],
         [0.9762, 0.8464, 1.0000,  ..., 0.9990, 0.6804, 0.9869],
         ...,
         [0.9849, 0.8664, 0.9990,  ..., 1.0000, 0.7031, 0.9787],
         [0.7896, 0.9445, 0.6804,  ..., 0.7031, 1.0000, 0.6005],
         [0.9310, 0.7692, 0.9869,  ..., 0.9787, 0.6005, 1.0000]],

        [[1.0000, 0.9610, 0.9983,  ..., 0.9927, 0.9994, 0.9850],
         [0.9610, 1.0000, 0.9440,  ..., 0.9232, 0.9698, 0.9940],
         [0.9983, 0.9440, 1.0000,  ..., 0.9981, 0.9956, 0.9734],
         ...,
         [0.9927, 0.9232, 0.9981,  ..., 1.0000, 0.9879, 0.9578],
         [0.9994, 0.9698, 0.9956,  ..., 0.9879, 1.0000, 0.9904],
         [0.9850, 0.9940, 0.9734,  ..., 0.9578, 0.9904, 1.0000]],

        [[1.0000, 0.8196, 0.8282,  ..., 0.9658, 0.9990, 0.8489],
         [0.8196, 1.0000, 0.9998,  ..., 0.9297, 0.7988, 0.5341],
         [0.8282, 0.9998, 1.0000,  ..., 0.9358, 0.8076, 0.

 84%|████████▍ | 27/32 [04:36<00:54, 10.89s/it]

tensor([[[1.0000, 0.8625, 0.9977,  ..., 0.8851, 0.6452, 0.9427],
         [0.8625, 1.0000, 0.8912,  ..., 0.9986, 0.9072, 0.9783],
         [0.9977, 0.8912, 1.0000,  ..., 0.9120, 0.6790, 0.9623],
         ...,
         [0.8851, 0.9986, 0.9120,  ..., 1.0000, 0.8861, 0.9878],
         [0.6452, 0.9072, 0.6790,  ..., 0.8861, 1.0000, 0.8164],
         [0.9427, 0.9783, 0.9623,  ..., 0.9878, 0.8164, 1.0000]],

        [[1.0000, 0.9960, 0.9939,  ..., 0.9784, 0.9981, 0.9993],
         [0.9960, 1.0000, 0.9803,  ..., 0.9568, 0.9887, 0.9921],
         [0.9939, 0.9803, 1.0000,  ..., 0.9951, 0.9988, 0.9972],
         ...,
         [0.9784, 0.9568, 0.9951,  ..., 1.0000, 0.9891, 0.9852],
         [0.9981, 0.9887, 0.9988,  ..., 0.9891, 1.0000, 0.9997],
         [0.9993, 0.9921, 0.9972,  ..., 0.9852, 0.9997, 1.0000]],

        [[1.0000, 0.8457, 0.7722,  ..., 0.7126, 0.7816, 0.9178],
         [0.8457, 1.0000, 0.9881,  ..., 0.9633, 0.9908, 0.6404],
         [0.7722, 0.9881, 1.0000,  ..., 0.9929, 0.9998, 0.

 88%|████████▊ | 28/32 [04:46<00:42, 10.73s/it]

tensor([[[1.0000, 0.7903, 0.7748,  ..., 0.5436, 0.9953, 0.4970],
         [0.7903, 1.0000, 0.9995,  ..., 0.8845, 0.8360, 0.8387],
         [0.7748, 0.9995, 1.0000,  ..., 0.8974, 0.8212, 0.8531],
         ...,
         [0.5436, 0.8845, 0.8974,  ..., 1.0000, 0.5887, 0.9945],
         [0.9953, 0.8360, 0.8212,  ..., 0.5887, 1.0000, 0.5399],
         [0.4970, 0.8387, 0.8531,  ..., 0.9945, 0.5399, 1.0000]],

        [[1.0000, 0.7182, 0.7398,  ..., 0.7712, 0.6964, 0.7859],
         [0.7182, 1.0000, 0.9991,  ..., 0.9943, 0.9990, 0.9907],
         [0.7398, 0.9991, 1.0000,  ..., 0.9980, 0.9962, 0.9956],
         ...,
         [0.7712, 0.9943, 0.9980,  ..., 1.0000, 0.9888, 0.9995],
         [0.6964, 0.9990, 0.9962,  ..., 0.9888, 1.0000, 0.9839],
         [0.7859, 0.9907, 0.9956,  ..., 0.9995, 0.9839, 1.0000]],

        [[1.0000, 0.9738, 0.9999,  ..., 0.9986, 0.9993, 1.0000],
         [0.9738, 1.0000, 0.9763,  ..., 0.9844, 0.9816, 0.9754],
         [0.9999, 0.9763, 1.0000,  ..., 0.9991, 0.9996, 1.

 91%|█████████ | 29/32 [04:57<00:32, 10.75s/it]

tensor([[[1.0000, 0.9437, 0.8702,  ..., 0.9899, 0.9484, 0.7240],
         [0.9437, 1.0000, 0.9811,  ..., 0.9801, 0.8100, 0.8853],
         [0.8702, 0.9811, 1.0000,  ..., 0.9260, 0.7138, 0.9544],
         ...,
         [0.9899, 0.9801, 0.9260,  ..., 1.0000, 0.8985, 0.7943],
         [0.9484, 0.8100, 0.7138,  ..., 0.8985, 1.0000, 0.5635],
         [0.7240, 0.8853, 0.9544,  ..., 0.7943, 0.5635, 1.0000]],

        [[1.0000, 0.8999, 0.8625,  ..., 0.8560, 0.9008, 0.9013],
         [0.8999, 1.0000, 0.9960,  ..., 0.6285, 1.0000, 0.6821],
         [0.8625, 0.9960, 1.0000,  ..., 0.5851, 0.9958, 0.6376],
         ...,
         [0.8560, 0.6285, 0.5851,  ..., 1.0000, 0.6297, 0.9941],
         [0.9008, 1.0000, 0.9958,  ..., 0.6297, 1.0000, 0.6833],
         [0.9013, 0.6821, 0.6376,  ..., 0.9941, 0.6833, 1.0000]],

        [[1.0000, 0.9885, 0.9984,  ..., 0.9944, 0.7397, 0.8956],
         [0.9885, 1.0000, 0.9954,  ..., 0.9989, 0.8141, 0.9492],
         [0.9984, 0.9954, 1.0000,  ..., 0.9988, 0.7677, 0.

 94%|█████████▍| 30/32 [05:08<00:21, 10.85s/it]

tensor([[[1.0000, 0.9943, 0.9986,  ..., 0.8449, 0.9644, 0.9908],
         [0.9943, 1.0000, 0.9986,  ..., 0.7950, 0.9328, 0.9712],
         [0.9986, 0.9986, 1.0000,  ..., 0.8205, 0.9498, 0.9823],
         ...,
         [0.8449, 0.7950, 0.8205,  ..., 1.0000, 0.9491, 0.9025],
         [0.9644, 0.9328, 0.9498,  ..., 0.9491, 1.0000, 0.9910],
         [0.9908, 0.9712, 0.9823,  ..., 0.9025, 0.9910, 1.0000]],

        [[1.0000, 0.9325, 0.9112,  ..., 0.8774, 0.5756, 0.9870],
         [0.9325, 1.0000, 0.9982,  ..., 0.9897, 0.7617, 0.9771],
         [0.9112, 0.9982, 1.0000,  ..., 0.9964, 0.7909, 0.9632],
         ...,
         [0.8774, 0.9897, 0.9964,  ..., 1.0000, 0.8310, 0.9385],
         [0.5756, 0.7617, 0.7909,  ..., 0.8310, 1.0000, 0.6542],
         [0.9870, 0.9771, 0.9632,  ..., 0.9385, 0.6542, 1.0000]],

        [[1.0000, 0.8877, 0.7627,  ..., 0.7209, 0.9219, 0.9127],
         [0.8877, 1.0000, 0.9642,  ..., 0.9399, 0.9960, 0.9980],
         [0.7627, 0.9642, 1.0000,  ..., 0.9965, 0.9382, 0.

 97%|█████████▋| 31/32 [05:19<00:10, 10.83s/it]

tensor([[[1.0000, 0.8635, 0.7988,  ..., 0.9767, 0.9758, 0.7713],
         [0.8635, 1.0000, 0.9902,  ..., 0.9458, 0.9471, 0.9809],
         [0.7988, 0.9902, 1.0000,  ..., 0.8960, 0.8976, 0.9984],
         ...,
         [0.9767, 0.9458, 0.8960,  ..., 1.0000, 1.0000, 0.8726],
         [0.9758, 0.9471, 0.8976,  ..., 1.0000, 1.0000, 0.8744],
         [0.7713, 0.9809, 0.9984,  ..., 0.8726, 0.8744, 1.0000]],

        [[1.0000, 0.9839, 0.9164,  ..., 0.9929, 0.8828, 0.4403],
         [0.9839, 1.0000, 0.9711,  ..., 0.9981, 0.9482, 0.3733],
         [0.9164, 0.9711, 1.0000,  ..., 0.9553, 0.9963, 0.2968],
         ...,
         [0.9929, 0.9981, 0.9553,  ..., 1.0000, 0.9283, 0.3951],
         [0.8828, 0.9482, 0.9963,  ..., 0.9283, 1.0000, 0.2734],
         [0.4403, 0.3733, 0.2968,  ..., 0.3951, 0.2734, 1.0000]],

        [[1.0000, 0.9947, 0.9873,  ..., 0.8062, 0.9226, 0.9998],
         [0.9947, 1.0000, 0.9984,  ..., 0.8537, 0.9555, 0.9967],
         [0.9873, 0.9984, 1.0000,  ..., 0.8785, 0.9703, 0.

100%|██████████| 32/32 [05:30<00:00, 10.33s/it]

tensor([[[1.0000, 0.9246, 0.7988,  ..., 0.9940, 0.9737, 0.9313],
         [0.9246, 1.0000, 0.9581,  ..., 0.8823, 0.9859, 0.9998],
         [0.7988, 0.9581, 1.0000,  ..., 0.7450, 0.9014, 0.9527],
         ...,
         [0.9940, 0.8823, 0.7450,  ..., 1.0000, 0.9443, 0.8903],
         [0.9737, 0.9859, 0.9014,  ..., 0.9443, 1.0000, 0.9889],
         [0.9313, 0.9998, 0.9527,  ..., 0.8903, 0.9889, 1.0000]],

        [[1.0000, 0.8732, 0.9466,  ..., 0.7000, 0.9838, 0.8167],
         [0.8732, 1.0000, 0.9807,  ..., 0.9376, 0.9412, 0.9922],
         [0.9466, 0.9807, 1.0000,  ..., 0.8609, 0.9883, 0.9499],
         ...,
         [0.7000, 0.9376, 0.8609,  ..., 1.0000, 0.7897, 0.9724],
         [0.9838, 0.9412, 0.9883,  ..., 0.7897, 1.0000, 0.8961],
         [0.8167, 0.9922, 0.9499,  ..., 0.9724, 0.8961, 1.0000]],

        [[1.0000, 0.9640, 0.9381,  ..., 0.9906, 0.9922, 1.0000],
         [0.9640, 1.0000, 0.8279,  ..., 0.9213, 0.9892, 0.9653],
         [0.9381, 0.8279, 1.0000,  ..., 0.9755, 0.8923, 0.




In [5]:
import os
import numpy as np
os.system(f"mkdir {CACHE_PATH}")
np.save(f'{CACHE_PATH}/conflict_{SEED}_{epsilon}_{gamma}.npy', conflicts)
np.save(f'{CACHE_PATH}/gm_sim_{SEED}_{epsilon}_{gamma}.npy', norm_dominate)

mkdir: cannot create directory ‘cache/original_llama/’: File exists


In [6]:
import numpy as np
res = np.load(f'{CACHE_PATH}/conflict_{epsilon}.npy')
res2 = np.load(f'cache/19-46-26/conflict_0.01.npy')
x = [ _ for _ in range(n_layers)]
plt.plot(x, res)
plt.plot(x, res2)

FileNotFoundError: [Errno 2] No such file or directory: 'cache/original_llama//conflict_0.npy'