In [23]:
import h5py
import numpy as np
from itertools import combinations

In [2]:
def load_h5_file(file, num_cost_matrices, key_template='cost_matrix_{}'):
    """Load numbered cost matrices from an HDF5 file into in-memory arrays.

    Parameters
    ----------
    file : str, os.PathLike, file-like object, or mapping
        Anything ``h5py.File(file, 'r')`` accepts (a path or file-like object),
        in which case the file is opened and closed here. Alternatively, an
        already-open container that supports ``container[key]`` as a Mapping
        (e.g. an ``h5py.File``/``h5py.Group`` or a plain ``dict`` of arrays);
        such a container is read directly and is NOT closed by this function.
    num_cost_matrices : int
        Number of datasets to read. The keys ``key_template.format(i)`` for
        ``i`` in ``range(num_cost_matrices)`` are loaded.
    key_template : str, optional
        ``str.format`` template that builds each dataset name from its index.
        Defaults to ``'cost_matrix_{}'`` (the naming used originally).

    Returns
    -------
    dict[int, numpy.ndarray]
        Maps each index ``i`` to the array stored under ``key_template.format(i)``.

    Notes
    -----
    A missing dataset propagates the lookup error of the underlying container
    (``KeyError`` for a ``dict``; presumably also ``KeyError`` for h5py groups -
    confirm against the installed h5py version).
    """
    from collections.abc import Mapping

    def _read_all(container):
        # np.array copies each dataset into memory, so the returned arrays stay
        # valid after an HDF5 file opened below has been closed.
        return {
            i: np.array(container[key_template.format(i)])
            for i in range(num_cost_matrices)
        }

    if isinstance(file, Mapping):
        # Already-open h5py File/Group (h5py groups are Mappings) or a plain dict.
        return _read_all(file)
    with h5py.File(file, 'r') as f:
        return _read_all(f)


In [4]:
len_dataset = 188
# Count of unordered index pairs (i, j) with i < j, i.e. C(len_dataset, 2);
# presumably one cost matrix was stored per pair of dataset items.
# Closed form instead of len(list(combinations(...))), which materialised
# every pair tuple just to count them (same result for len_dataset >= 0).
num_cost_matrices = len_dataset * (len_dataset - 1) // 2

print(f'Number of cost matrices: {num_cost_matrices}')

cost_matrices = load_h5_file('../data/cost_matrices.h5', num_cost_matrices)

Number of cost matrices: 17578


In [12]:
# Inspect the first loaded cost matrix (key 0 of the dict returned by load_h5_file).
cost_matrices[0]

array([[0.13845885, 0.13855484, 0.70526678, 1.01150449, 1.54351854,
        1.62207349, 1.62207708, 1.67934353, 1.84178903, 1.97717433,
        4.26150579, 4.33924013, 4.58241239, 4.76704108, 4.76704108,
        8.12694368, 8.33910386],
       [0.53724399, 0.71430031, 0.73315077, 0.73316307, 1.13961124,
        1.34748111, 1.36842527, 1.36843859, 2.10670704, 2.37059793,
        3.48328601, 3.57798544, 4.63460039, 5.01832703, 5.01832703,
        7.67941813, 7.90360381],
       [2.04664173, 2.10656624, 2.10657052, 2.11382245, 2.27469461,
        2.38405153, 2.45242473, 2.45243216, 2.81480159, 2.93104686,
        3.00860427, 3.24260571, 5.16500119, 5.23503958, 5.23503958,
        7.35370117, 7.58660316],
       [0.5387259 , 0.71541556, 0.73423739, 0.73424967, 1.1403106 ,
        1.34807263, 1.36098197, 1.36099536, 2.09980534, 2.36722034,
        3.48351488, 3.57820825, 4.6336976 , 5.01794937, 5.01794937,
        7.67952195, 7.90370468],
       [0.27700247, 0.27703624, 0.54436064, 1.127589

In [45]:
# Split the first cost matrix column-wise into a left and a right half.
# Floor division instead of round(n / 2): Python's round() rounds halves to the
# nearest even number, so for odd widths the extra column would land on the left
# for some n (15 -> 8) and on the right for others (17 -> 8). With n // 2 the
# right half always gets the extra column (17 columns -> 8 | 9, as before here).
left, right = np.hsplit(cost_matrices[0], [cost_matrices[0].shape[1] // 2])

In [46]:
# Split the left half row-wise into top and bottom parts.
# Floor division instead of round(n / 2) (which rounds halves to even), so an odd
# row count always gives the extra row to the bottom part (13 rows -> 6 | 7).
top_left, bottom_left = np.vsplit(left, [left.shape[0] // 2])

In [47]:
# Display the top rows of the left column-half of cost_matrices[0].
top_left

array([[0.13845885, 0.13855484, 0.70526678, 1.01150449, 1.54351854,
        1.62207349, 1.62207708, 1.67934353],
       [0.53724399, 0.71430031, 0.73315077, 0.73316307, 1.13961124,
        1.34748111, 1.36842527, 1.36843859],
       [2.04664173, 2.10656624, 2.10657052, 2.11382245, 2.27469461,
        2.38405153, 2.45242473, 2.45243216],
       [0.5387259 , 0.71541556, 0.73423739, 0.73424967, 1.1403106 ,
        1.34807263, 1.36098197, 1.36099536],
       [0.27700247, 0.27703624, 0.54436064, 1.12758955, 1.60346112,
        1.6925447 , 1.76448533, 1.76448758],
       [1.12847232, 1.26254331, 1.26254947, 1.29389919, 1.81451678,
        2.23287191, 2.28788746, 2.28788888]])

In [48]:
# Display the bottom rows of the left column-half of cost_matrices[0].
bottom_left

array([[2.5126827 , 3.80540993, 3.8571529 , 3.85728103, 3.91078517,
        4.06436983, 4.06436983, 4.06897398],
       [0.13842783, 1.66347551, 1.68685188, 1.68685811, 2.00529141,
        2.4354887 , 2.48604837, 2.48605054],
       [0.41038485, 1.60244598, 1.7414084 , 1.74142389, 2.17878516,
        2.59084472, 2.63838057, 2.63838522],
       [0.40010909, 1.04892411, 1.0489629 , 1.80267871, 1.91938257,
        2.40199692, 2.45317496, 2.4531826 ],
       [0.08263569, 3.74822884, 3.74822884, 4.59811514, 4.60078965,
        4.60090267, 4.63830696, 4.64374796],
       [0.04097388, 0.04097388, 3.7592952 , 4.80001732, 4.80277771,
        4.80288164, 4.87904504, 4.89554476],
       [0.04097388, 0.04097388, 3.7592952 , 4.80001732, 4.80277771,
        4.80288164, 4.87904504, 4.89554476]])

In [50]:
# Split the right half row-wise into top and bottom parts.
# Floor division instead of round(n / 2) (which rounds halves to even), so an odd
# row count always gives the extra row to the bottom part (13 rows -> 6 | 7).
top_right, bottom_right = np.vsplit(right, [right.shape[0] // 2])

In [51]:
# Display the top rows of the right column-half of cost_matrices[0].
top_right

array([[1.84178903, 1.97717433, 4.26150579, 4.33924013, 4.58241239,
        4.76704108, 4.76704108, 8.12694368, 8.33910386],
       [2.10670704, 2.37059793, 3.48328601, 3.57798544, 4.63460039,
        5.01832703, 5.01832703, 7.67941813, 7.90360381],
       [2.81480159, 2.93104686, 3.00860427, 3.24260571, 5.16500119,
        5.23503958, 5.23503958, 7.35370117, 7.58660316],
       [2.09980534, 2.36722034, 3.48351488, 3.57820825, 4.6336976 ,
        5.01794937, 5.01794937, 7.67952195, 7.90370468],
       [1.96837194, 2.09559409, 4.42680932, 4.50169187, 4.61384684,
        4.82365019, 4.82365019, 8.22705119, 8.43669469],
       [2.44857401, 2.55196771, 4.50919282, 4.5602035 , 4.5602035 ,
        4.70209051, 4.77265652, 8.38405579, 8.58986845]])

In [52]:
# Display the bottom rows of the right column-half of cost_matrices[0].
bottom_right

array([[4.25690318, 4.26175755, 4.26181226, 4.27279025, 4.31855516,
        5.96715666, 5.99395358, 8.86589131, 9.05244965],
       [2.63469225, 2.7310446 , 4.89454013, 4.9204718 , 4.96237753,
        5.04459557, 5.04459557, 8.50421691, 8.70719146],
       [2.77886018, 2.87038546, 4.94067311, 4.96448022, 4.96649263,
        4.96649263, 5.03135994, 8.54484146, 8.74686992],
       [2.60365637, 2.70112804, 4.74711006, 4.75523324, 4.75523324,
        4.86736664, 4.9355479 , 8.48722194, 8.69059557],
       [4.65445803, 4.65489949, 4.65495218, 4.71905684, 4.9471019 ,
        6.03570288, 6.05458009, 8.65349668, 8.83470472],
       [4.93200601, 4.98461426, 4.9846615 , 4.98596714, 5.07339182,
        5.94933366, 5.95108043, 8.19741178, 8.37226051],
       [4.93200601, 4.98461426, 4.9846615 , 4.98596714, 5.07339182,
        5.94933366, 5.95108043, 8.19741178, 8.37226051]])