In [1]:
import h5py
import numpy as np
from itertools import combinations

In [2]:
def load_h5_file(file, num_cost_matrices: int) -> dict:
    """Load consecutively indexed cost matrices from an HDF5 file.

    Opens ``file`` read-only and reads the datasets named
    ``cost_matrix_0`` ... ``cost_matrix_{num_cost_matrices - 1}`` fully into
    memory as NumPy arrays.

    Parameters
    ----------
    file : str or path-like
        Path to the HDF5 file.
    num_cost_matrices : int
        How many ``cost_matrix_{i}`` datasets to read, starting at ``i = 0``.

    Returns
    -------
    dict[int, numpy.ndarray]
        Maps each index ``i`` to the contents of dataset ``cost_matrix_{i}``.

    Raises
    ------
    KeyError
        If one of the requested ``cost_matrix_{i}`` datasets is missing
        (raised by h5py when the object does not exist).
    """
    with h5py.File(file, 'r') as f:
        # np.array(...) forces the read while the file is still open; a bare
        # h5py Dataset would become unusable once the `with` block closes it.
        cost_matrices = {i: np.array(f[f'cost_matrix_{i}']) for i in range(num_cost_matrices)}
    return cost_matrices


In [3]:
len_dataset = 188
# One cost matrix is loaded per unordered pair (i, j), i < j, of dataset items,
# i.e. C(n, 2) = n * (n - 1) / 2 of them. The closed form avoids materializing
# every pair with itertools.combinations just to count them.
num_cost_matrices = len_dataset * (len_dataset - 1) // 2

print(f'Number of cost matrices: {num_cost_matrices}')

cost_matrices = load_h5_file('../data/cost_matrices.h5', num_cost_matrices)

Number of cost matrices: 17578


In [4]:
# Display the cost matrix stored under key 0 of the dict returned by load_h5_file
# (read from the dataset `cost_matrix_0`).
cost_matrices[0]

array([[0.13855484, 0.13845885, 1.84178903, 8.33910386, 4.33924013,
        1.54351854, 1.62207349, 1.62207708, 4.26150579, 8.12694368,
        1.97717433, 0.70526678, 1.67934353, 1.01150449, 4.58241239,
        4.76704108, 4.76704108],
       [1.36842527, 1.36843859, 1.13961124, 7.90360381, 3.57798544,
        0.53724399, 0.73316307, 0.73315077, 3.48328601, 7.67941813,
        1.34748111, 2.10670704, 2.37059793, 0.71430031, 4.63460039,
        5.01832703, 5.01832703],
       [2.45242473, 2.45243216, 2.27469461, 7.58660316, 2.93104686,
        2.04664173, 2.10657052, 2.10656624, 2.81480159, 7.35370117,
        2.38405153, 3.00860427, 3.24260571, 2.11382245, 5.16500119,
        5.23503958, 5.23503958],
       [1.36098197, 1.36099536, 1.1403106 , 7.90370468, 3.57820825,
        0.5387259 , 0.73424967, 0.73423739, 3.48351488, 7.67952195,
        1.34807263, 2.09980534, 2.36722034, 0.71541556, 4.6336976 ,
        5.01794937, 5.01794937],
       [0.27703624, 0.27700247, 1.96837194, 8.436694

In [16]:
def _cut_indices(length, parts=3):
    """Return the ``parts - 1`` interior cut points splitting ``length`` items
    into ``parts`` consecutive, near-equal chunks.

    The k-th cut is ``round(k * length / parts)``; e.g. length 13 -> [4, 9],
    length 17 -> [6, 11].
    """
    return [round(k * length / parts) for k in range(1, parts)]


# Split the first cost matrix into 3 row bands, then each band into 3 column
# blocks, giving a 3x3 grid of sub-matrices.
first_cost_matrix = cost_matrices[0]
# Cut points come from the matrix shape (not len(band[0])), so an empty row band,
# possible when the matrix has fewer than 3 rows, no longer raises IndexError.
row_cuts = _cut_indices(first_cost_matrix.shape[0])
col_cuts = _cut_indices(first_cost_matrix.shape[1])

top, middle, bottom = np.vsplit(first_cost_matrix, row_cuts)
top_left, top_middle, top_right = np.hsplit(top, col_cuts)
middle_left, middle_middle, middle_right = np.hsplit(middle, col_cuts)
bottom_left, bottom_middle, bottom_right = np.hsplit(bottom, col_cuts)

In [17]:
# Top-left block of the 3x3 split of cost_matrices[0]: first row band, first column band.
top_left

array([[0.13855484, 0.13845885, 1.84178903, 8.33910386, 4.33924013,
        1.54351854],
       [1.36842527, 1.36843859, 1.13961124, 7.90360381, 3.57798544,
        0.53724399],
       [2.45242473, 2.45243216, 2.27469461, 7.58660316, 2.93104686,
        2.04664173],
       [1.36098197, 1.36099536, 1.1403106 , 7.90370468, 3.57820825,
        0.5387259 ]])

In [18]:
# Top-middle block of the 3x3 split of cost_matrices[0]: first row band, second column band.
top_middle

array([[1.62207349, 1.62207708, 4.26150579, 8.12694368, 1.97717433],
       [0.73316307, 0.73315077, 3.48328601, 7.67941813, 1.34748111],
       [2.10657052, 2.10656624, 2.81480159, 7.35370117, 2.38405153],
       [0.73424967, 0.73423739, 3.48351488, 7.67952195, 1.34807263]])

In [19]:
# Top-right block of the 3x3 split of cost_matrices[0]: first row band, last column band.
top_right

array([[0.70526678, 1.67934353, 1.01150449, 4.58241239, 4.76704108,
        4.76704108],
       [2.10670704, 2.37059793, 0.71430031, 4.63460039, 5.01832703,
        5.01832703],
       [3.00860427, 3.24260571, 2.11382245, 5.16500119, 5.23503958,
        5.23503958],
       [2.09980534, 2.36722034, 0.71541556, 4.6336976 , 5.01794937,
        5.01794937]])

In [20]:
# Middle-left block of the 3x3 split of cost_matrices[0]: second row band, first column band.
middle_left

array([[0.27703624, 0.27700247, 1.96837194, 8.43669469, 4.50169187,
        1.6925447 ],
       [1.26254947, 1.26254331, 2.44857401, 8.58986845, 4.77265652,
        2.23287191],
       [3.8571529 , 3.85728103, 4.27279025, 9.05244965, 5.99395358,
        4.25690318],
       [1.68685188, 1.68685811, 2.63469225, 8.70719146, 4.96237753,
        2.4354887 ],
       [1.74142389, 1.7414084 , 2.77886018, 8.74686992, 5.03135994,
        2.59084472]])

In [21]:
# Central block of the 3x3 split of cost_matrices[0]: second row band, second column band.
middle_middle

array([[1.76448533, 1.76448758, 4.42680932, 8.22705119, 2.09559409],
       [2.28788746, 2.28788888, 4.70209051, 8.38405579, 2.55196771],
       [4.26181226, 4.26175755, 5.96715666, 8.86589131, 4.31855516],
       [2.48605054, 2.48604837, 4.89454013, 8.50421691, 2.7310446 ],
       [2.63838057, 2.63838522, 4.96448022, 8.54484146, 2.87038546]])

In [22]:
# Middle-right block of the 3x3 split of cost_matrices[0]: second row band, last column band.
middle_right

array([[0.54436064, 1.60346112, 1.12758955, 4.61384684, 4.82365019,
        4.82365019],
       [1.12847232, 1.29389919, 1.81451678, 4.50919282, 4.5602035 ,
        4.5602035 ],
       [3.80540993, 3.91078517, 4.06897398, 2.5126827 , 4.06436983,
        4.06436983],
       [1.66347551, 0.13842783, 2.00529141, 4.9204718 , 5.04459557,
        5.04459557],
       [1.60244598, 0.41038485, 2.17878516, 4.94067311, 4.96649263,
        4.96649263]])

In [23]:
# Bottom-left block of the 3x3 split of cost_matrices[0]: last row band, first column band.
bottom_left

array([[1.0489629 , 1.04892411, 2.60365637, 8.69059557, 4.9355479 ,
        2.40199692],
       [4.60078965, 4.60090267, 4.63830696, 8.83470472, 6.05458009,
        4.64374796],
       [4.80277771, 4.80288164, 4.89554476, 8.37226051, 5.94933366,
        4.98596714],
       [4.80277771, 4.80288164, 4.89554476, 8.37226051, 5.94933366,
        4.98596714]])

In [24]:
# Bottom-middle block of the 3x3 split of cost_matrices[0]: last row band, second column band.
bottom_middle

array([[2.45317496, 2.4531826 , 4.86736664, 8.48722194, 2.70112804],
       [4.65495218, 4.65489949, 6.03570288, 8.65349668, 4.65445803],
       [4.9846615 , 4.98461426, 5.95108043, 8.19741178, 4.87904504],
       [4.9846615 , 4.98461426, 5.95108043, 8.19741178, 4.87904504]])

In [25]:
# Bottom-right block of the 3x3 split of cost_matrices[0]: last row band, last column band.
bottom_right

array([[0.40010909, 1.80267871, 1.91938257, 4.74711006, 4.75523324,
        4.75523324],
       [4.71905684, 4.9471019 , 4.59811514, 0.08263569, 3.74822884,
        3.74822884],
       [4.80001732, 5.07339182, 4.93200601, 3.7592952 , 0.04097388,
        0.04097388],
       [4.80001732, 5.07339182, 4.93200601, 3.7592952 , 0.04097388,
        0.04097388]])