In [1]:
import h5py
import numpy as np
from itertools import combinations
from sklearn.cluster import KMeans

In [2]:
def load_h5_file(file, num_cost_matrices):
    """Load cost matrices from an HDF5 file into a dictionary.

    The file is expected to contain datasets named ``cost_matrix_0``,
    ``cost_matrix_1``, ... up to ``cost_matrix_{num_cost_matrices - 1}``.

    Parameters
    ----------
    file : path-like
        Passed straight to ``h5py.File`` and opened read-only.
    num_cost_matrices : int
        Number of consecutive ``cost_matrix_{i}`` datasets to read,
        starting at ``i = 0``.

    Returns
    -------
    dict
        Maps each integer index ``i`` to the array built from dataset
        ``cost_matrix_{i}``.

    Raises
    ------
    KeyError
        Propagated from the lookup if one of the expected dataset names is
        missing from the file.
    """
    with h5py.File(file, 'r') as f:
        # np.array(...) is applied while the file is still open, so the
        # returned dict holds arrays rather than open dataset handles.
        cost_matrices = {
            i: np.array(f[f'cost_matrix_{i}'])
            for i in range(num_cost_matrices)
        }
    return cost_matrices


In [3]:
len_dataset = 188
# The count is the number of unordered pairs drawn from len_dataset items:
# C(n, 2) = n * (n - 1) / 2. Computed in closed form instead of building the
# full list of pairs just to take its length.
num_cost_matrices = len_dataset * (len_dataset - 1) // 2

print(f'Number of cost matrices: {num_cost_matrices}')

cost_matrices = load_h5_file('../data/cost_matrices.h5', num_cost_matrices)

Number of cost matrices: 17578


In [4]:
# Spot-check: display the cost matrix stored under index 1000.
cost_matrices[1000]

array([[1.41663042e+00, 1.41663042e+00, 4.97740474e+00, 1.71907095e+00,
        1.30047794e+00, 1.71907096e+00, 4.97740478e+00, 1.78666707e+01,
        1.78666708e+01, 4.97740521e+00, 1.71907095e+00, 1.30047794e+00,
        1.71907095e+00, 4.97740502e+00, 1.41663042e+00, 1.41663042e+00,
        4.53574837e+00, 4.75825614e+00, 4.75825614e+00, 4.53574836e+00,
        4.75825614e+00, 4.75825614e+00, 4.53574836e+00, 4.75825614e+00,
        4.75825614e+00, 4.53574837e+00, 4.75825613e+00, 4.75825613e+00],
       [1.41659105e+00, 1.41659105e+00, 4.97741619e+00, 1.71903850e+00,
        1.30054248e+00, 1.71903851e+00, 4.97741623e+00, 1.78666856e+01,
        1.78666857e+01, 4.97741666e+00, 1.71903851e+00, 1.30054248e+00,
        1.71903851e+00, 4.97741647e+00, 1.41659105e+00, 1.41659105e+00,
        4.53545617e+00, 4.75798866e+00, 4.75798866e+00, 4.53545616e+00,
        4.75798866e+00, 4.75798866e+00, 4.53545616e+00, 4.75798866e+00,
        4.75798866e+00, 4.53545617e+00, 4.75798865e+00, 4.75798

In [5]:
def sub_matrice_generator_k_means(matrix, row_num_clusters, col_num_clusters, random_state=None):
    """Split a 2-D matrix into blocks by k-means clustering its rows and columns.

    Rows are clustered by fitting ``KMeans`` on ``matrix`` (each row is one
    sample); columns are clustered independently by fitting ``KMeans`` on
    ``matrix.T``. One sub-matrix is then extracted for every
    (row cluster, column cluster) pair.

    Parameters
    ----------
    matrix : array-like, 2-D
        Matrix to partition. Converted with ``np.asarray``.
    row_num_clusters : int
        Number of k-means clusters for the rows.
    col_num_clusters : int
        Number of k-means clusters for the columns.
    random_state : int, RandomState instance or None, optional
        Forwarded to both ``KMeans`` estimators. The default ``None`` keeps
        the previous unseeded behaviour; pass an integer to make the
        partition reproducible between runs.

    Returns
    -------
    list of numpy.ndarray
        ``row_num_clusters * col_num_clusters`` sub-matrices, ordered with the
        row cluster as the outer index: entry ``r * col_num_clusters + c``
        holds the rows labelled ``r`` and the columns labelled ``c``, each in
        their original order.
    """
    matrix = np.asarray(matrix)

    kmeans_row = KMeans(n_clusters=row_num_clusters, random_state=random_state)
    kmeans_col = KMeans(n_clusters=col_num_clusters, random_state=random_state)
    kmeans_row.fit(matrix)
    kmeans_col.fit(matrix.T)
    row_labels = kmeans_row.labels_
    col_labels = kmeans_col.labels_

    # Index sets are computed once per cluster instead of once per
    # (row cluster, column cluster) pair.
    row_groups = [np.flatnonzero(row_labels == r) for r in range(row_num_clusters)]
    col_groups = [np.flatnonzero(col_labels == c) for c in range(col_num_clusters)]

    # np.ix_ selects the rows x columns cross product in a single indexing step.
    return [
        matrix[np.ix_(rows, cols)]
        for rows in row_groups
        for cols in col_groups
    ]

In [17]:
# Partition every loaded cost matrix into a 4 x 4 grid of k-means blocks;
# entry i of the result is the list of blocks for cost_matrices[i].
sub_matices = [
    sub_matrice_generator_k_means(cost_matrices[idx], 4, 4)
    for idx in range(num_cost_matrices)
]

In [18]:
# Sanity check: 4 row clusters x 4 column clusters should give 16 blocks.
len(sub_matices[0])

16

In [19]:
# Block 0 of cost matrix 1000: row cluster 0, column cluster 0
# (blocks are stored with flat index = row_cluster * 4 + col_cluster).
sub_matices[1000][0]

array([[22.43221923, 22.43221923, 22.76183495, 22.81412635, 22.76183495,
        22.76183495, 22.81412634, 22.76183495, 22.43221923, 22.43221923],
       [21.58119661, 21.58119661, 21.96430453, 22.01900331, 21.96430453,
        21.96430454, 22.0190033 , 21.96430453, 21.58119661, 21.58119661]])

In [20]:
# Block 1 of cost matrix 1000: row cluster 0, column cluster 1.
sub_matices[1000][1]

array([[5.85331358, 5.85331356],
       [4.82604689, 4.82604683]])

In [21]:
# Block 2 of cost matrix 1000: row cluster 0, column cluster 2.
sub_matices[1000][2]

array([[22.51536044, 22.04611664, 22.04611664, 22.51536053, 22.0461166 ,
        22.0461166 , 22.51536053, 22.0461166 , 22.0461166 , 22.51536044,
        22.04611666, 22.04611666],
       [21.70758281, 21.22031491, 21.22031491, 21.7075829 , 21.22031487,
        21.22031487, 21.7075829 , 21.22031487, 21.22031487, 21.70758281,
        21.22031493, 21.22031493]])

In [22]:
# Block 3 of cost matrix 1000: row cluster 0, column cluster 3.
sub_matices[1000][3]

array([[18.61581513, 18.61581508, 18.61581446, 18.61581474],
       [17.71834033, 17.71834028, 17.71833967, 17.71833995]])

In [23]:
# Block 4 of cost matrix 1000: row cluster 1, column cluster 0.
sub_matices[1000][4]

array([[4.6380783 , 4.6380783 , 4.93669821, 4.88141967, 4.93669821,
        4.93669823, 4.88141967, 4.93669821, 4.6380783 , 4.6380783 ],
       [4.99516528, 4.99516528, 5.24802268, 4.7953176 , 5.24802269,
        5.2480227 , 4.7953176 , 5.24802269, 4.99516528, 4.99516528],
       [4.99516528, 4.99516528, 5.24802268, 4.7953176 , 5.24802269,
        5.2480227 , 4.7953176 , 5.24802269, 4.99516528, 4.99516528],
       [4.58679805, 4.58679805, 4.88899149, 4.84541088, 4.8889915 ,
        4.88899151, 4.84541088, 4.88899149, 4.58679805, 4.58679805],
       [4.96643519, 4.96643519, 5.22068428, 4.76538279, 5.22068429,
        5.2206843 , 4.76538279, 5.22068428, 4.96643519, 4.96643519],
       [4.96643519, 4.96643519, 5.22068428, 4.76538279, 5.22068429,
        5.2206843 , 4.76538279, 5.22068428, 4.96643519, 4.96643519],
       [4.56993177, 4.56993177, 4.87325471, 4.83185067, 4.87325471,
        4.87325473, 4.83185067, 4.87325471, 4.56993177, 4.56993177],
       [4.95201236, 4.95201236, 5.2069657

In [24]:
# Block 5 of cost matrix 1000: row cluster 1, column cluster 1.
sub_matices[1000][5]

array([[17.7004127 , 17.70041285],
       [17.18518757, 17.18518771],
       [17.18518757, 17.18518771],
       [17.80129295, 17.8012931 ],
       [17.26798621, 17.26798635],
       [17.26798621, 17.26798635],
       [17.83057301, 17.83057316],
       [17.31278647, 17.31278661],
       [17.31278647, 17.31278661]])

In [25]:
# Block 6 of cost matrix 1000: row cluster 1, column cluster 2.
sub_matices[1000][6]

array([[2.80026997e-01, 3.73138452e+00, 3.73138452e+00, 2.80027117e-01,
        3.73138454e+00, 3.73138454e+00, 2.80027117e-01, 3.73138454e+00,
        3.73138454e+00, 2.80026997e-01, 3.73138454e+00, 3.73138454e+00],
       [3.77832758e+00, 1.78414230e-01, 1.78414230e-01, 3.77832762e+00,
        1.78414170e-01, 1.78414170e-01, 3.77832762e+00, 1.78414170e-01,
        1.78414170e-01, 3.77832760e+00, 1.78414257e-01, 1.78414257e-01],
       [3.77832758e+00, 1.78414230e-01, 1.78414230e-01, 3.77832762e+00,
        1.78414170e-01, 1.78414170e-01, 3.77832762e+00, 1.78414170e-01,
        1.78414170e-01, 3.77832760e+00, 1.78414257e-01, 1.78414257e-01],
       [7.22297560e-02, 3.74508739e+00, 3.74508739e+00, 7.22298235e-02,
        3.74508742e+00, 3.74508742e+00, 7.22298235e-02, 3.74508742e+00,
        3.74508742e+00, 7.22297560e-02, 3.74508741e+00, 3.74508741e+00],
       [3.75671382e+00, 6.24950831e-02, 6.24950831e-02, 3.75671386e+00,
        6.24950246e-02, 6.24950246e-02, 3.75671386e+00, 6.24

In [26]:
# Block 7 of cost matrix 1000: row cluster 1, column cluster 3.
sub_matices[1000][7]

array([[6.27178017, 6.27178021, 6.27178047, 6.27178034],
       [6.03217119, 6.03217125, 6.03217145, 6.03217135],
       [6.03217119, 6.03217125, 6.03217145, 6.03217135],
       [6.28818549, 6.28818553, 6.2881858 , 6.28818567],
       [6.05912658, 6.05912664, 6.05912685, 6.05912675],
       [6.05912658, 6.05912664, 6.05912685, 6.05912675],
       [6.29059772, 6.29059776, 6.29059803, 6.2905979 ],
       [6.07456587, 6.07456593, 6.07456615, 6.07456605],
       [6.07456587, 6.07456593, 6.07456615, 6.07456605]])

In [27]:
# Block 8 of cost matrix 1000: row cluster 2, column cluster 0.
sub_matices[1000][8]

array([[1.41663042, 1.41663042, 1.71907095, 1.30047794, 1.71907096,
        1.71907095, 1.30047794, 1.71907095, 1.41663042, 1.41663042],
       [1.41659105, 1.41659105, 1.7190385 , 1.30054248, 1.71903851,
        1.71903851, 1.30054248, 1.71903851, 1.41659105, 1.41659105],
       [0.96787845, 0.96787845, 2.47133409, 2.87613993, 2.47133409,
        2.4713341 , 2.87613993, 2.47133409, 0.96787845, 0.96787845],
       [2.36625744, 2.36625744, 3.28174919, 3.59498525, 3.28174919,
        3.2817492 , 3.59498524, 3.28174919, 2.36625744, 2.36625744],
       [4.89034803, 4.89034803, 5.38779528, 5.58263766, 5.38779528,
        5.38779529, 5.58263766, 5.38779528, 4.89034803, 4.89034803],
       [2.10395921, 2.10395921, 1.84610251, 0.70065837, 1.84610252,
        1.84610252, 0.70065837, 1.84610252, 2.10395921, 2.10395921],
       [2.18763647, 2.18763647, 0.0062373 , 2.32058934, 0.00623733,
        0.0062373 , 2.32058934, 0.00623733, 2.18763647, 2.18763647],
       [4.48399643, 4.48399643, 5.4719995

In [28]:
# Block 9 of cost matrix 1000: row cluster 2, column cluster 1.
sub_matices[1000][9]

array([[17.86667065, 17.86667078],
       [17.86668557, 17.8666857 ],
       [16.72449001, 16.72449013],
       [15.39832726, 15.39832738],
       [13.2470704 , 13.24707054],
       [18.01752854, 18.01752867],
       [18.02867267, 18.02867279],
       [13.77471154, 13.77471165],
       [17.18935174, 17.18935187],
       [18.06315461, 18.06315474],
       [17.76072644, 17.76072657]])

In [29]:
# Block 10 of cost matrix 1000: row cluster 2, column cluster 2.
sub_matices[1000][10]

array([[4.53574837, 4.75825614, 4.75825614, 4.53574836, 4.75825614,
        4.75825614, 4.53574836, 4.75825614, 4.75825614, 4.53574837,
        4.75825613, 4.75825613],
       [4.53545617, 4.75798866, 4.75798866, 4.53545616, 4.75798866,
        4.75798866, 4.53545616, 4.75798866, 4.75798866, 4.53545617,
        4.75798865, 4.75798865],
       [4.57830632, 4.86496271, 4.86496271, 4.57830632, 4.86496271,
        4.86496271, 4.57830632, 4.86496271, 4.86496271, 4.57830631,
        4.86496271, 4.86496271],
       [4.92721613, 5.08704458, 5.08704458, 4.92721615, 5.08704457,
        5.08704457, 4.92721615, 5.08704457, 5.08704457, 4.92721612,
        5.08704459, 5.08704459],
       [6.34864403, 6.32695104, 6.32695104, 6.34864409, 6.32695102,
        6.32695102, 6.34864409, 6.32695102, 6.32695102, 6.34864403,
        6.32695105, 6.32695105],
       [4.65640954, 4.77637995, 4.77637995, 4.65640954, 4.77637996,
        4.77637996, 4.65640954, 4.77637996, 4.77637996, 4.65640955,
        4.77637995,

In [30]:
# Block 11 of cost matrix 1000: row cluster 2, column cluster 3.
sub_matices[1000][11]

array([[4.97740474, 4.97740478, 4.97740521, 4.97740502],
       [4.97741619, 4.97741623, 4.97741666, 4.97741647],
       [3.65277318, 3.65277321, 3.65277368, 3.65277346],
       [2.7253479 , 2.72534791, 2.7253482 , 2.72534808],
       [2.86064278, 2.86064273, 2.86064247, 2.86064261],
       [5.35508563, 5.35508567, 5.35508606, 5.35508589],
       [5.41656094, 5.41656097, 5.41656137, 5.41656119],
       [0.63769017, 0.63769012, 0.63769012, 0.63769023],
       [4.8555335 , 4.85553352, 4.85553384, 4.85553371],
       [5.50195983, 5.50195987, 5.50196026, 5.50196008],
       [4.7230165 , 4.72301653, 4.72301699, 4.72301679]])

In [31]:
# Block 12 of cost matrix 1000: row cluster 3, column cluster 0.
sub_matices[1000][12]

array([[13.95036671, 13.95036671, 14.45948671, 14.54284645, 14.45948671,
        14.45948672, 14.54284644, 14.45948671, 13.95036671, 13.95036671],
       [17.49802161, 17.49802161, 17.9122841 , 17.97925788, 17.9122841 ,
        17.9122841 , 17.97925788, 17.9122841 , 17.49802161, 17.49802161],
       [16.95534871, 16.95534871, 17.33025175, 17.39435385, 17.33025175,
        17.33025175, 17.39435385, 17.33025175, 16.95534871, 16.95534871]])

In [32]:
# Block 13 of cost matrix 1000: row cluster 3, column cluster 1.
sub_matices[1000][13]

array([[4.47982227, 4.47982234],
       [2.61423563, 2.61423568],
       [2.98162991, 2.98163005]])

In [33]:
# Block 14 of cost matrix 1000: row cluster 3, column cluster 2.
sub_matices[1000][14]

array([[14.39727805, 13.93980805, 13.93980805, 14.39727813, 13.93980801,
        13.93980801, 14.39727813, 13.93980801, 13.93980801, 14.39727805,
        13.93980807, 13.93980807],
       [17.74810509, 17.2813133 , 17.2813133 , 17.74810517, 17.28131327,
        17.28131327, 17.74810517, 17.28131327, 17.28131327, 17.74810509,
        17.28131333, 17.28131333],
       [17.20411095, 16.753876  , 16.753876  , 17.20411103, 16.75387596,
        16.75387596, 17.20411103, 16.75387596, 16.75387596, 17.20411095,
        16.75387602, 16.75387602]])

In [34]:
# Block 15 of cost matrix 1000: row cluster 3, column cluster 3.
sub_matices[1000][15]

array([[ 9.99204276,  9.9920427 ,  9.99204211,  9.99204238],
       [13.61188346, 13.6118834 , 13.6118828 , 13.61188307],
       [13.11376467, 13.11376462, 13.113764  , 13.11376428]])