In [1]:
import h5py
import numpy as np
from itertools import combinations
from sklearn.cluster import KMeans

In [2]:
def load_h5_file(file, num_cost_matrices):
    """Read a numbered series of cost matrices from an HDF5 file.

    Parameters
    ----------
    file
        Path (or file-like object) handed to ``h5py.File``; opened read-only.
    num_cost_matrices : int
        How many datasets to read. They are looked up by the names
        ``cost_matrix_0`` ... ``cost_matrix_{num_cost_matrices - 1}``.

    Returns
    -------
    dict
        Maps each integer index ``i`` to the contents of dataset
        ``cost_matrix_{i}`` converted with ``np.array``.
    """
    cost_matrices = {}
    with h5py.File(file, 'r') as h5_handle:
        for index in range(num_cost_matrices):
            # Convert inside the ``with`` block, while the file is still open.
            dataset = h5_handle[f'cost_matrix_{index}']
            cost_matrices[index] = np.array(dataset)
    return cost_matrices


In [3]:
len_dataset = 188
# Number of unordered pairs that can be formed from ``len_dataset`` items,
# C(n, 2) = n * (n - 1) / 2. Computed in closed form rather than by
# materialising every pair in a list just to count them.
num_cost_matrices = len_dataset * (len_dataset - 1) // 2

print(f'Number of cost matrices: {num_cost_matrices}')

# Load datasets cost_matrix_0 ... cost_matrix_{num_cost_matrices - 1} into a dict.
cost_matrices = load_h5_file('../data/cost_matrices.h5', num_cost_matrices)

Number of cost matrices: 17578


In [17]:
# Spot-check: display the cost matrix stored under index 1000.
cost_matrices[1000]

array([[1.41663042e+00, 1.41663042e+00, 4.97740474e+00, 1.71907095e+00,
        1.30047794e+00, 1.71907096e+00, 4.97740478e+00, 1.78666707e+01,
        1.78666708e+01, 4.97740521e+00, 1.71907095e+00, 1.30047794e+00,
        1.71907095e+00, 4.97740502e+00, 1.41663042e+00, 1.41663042e+00,
        4.53574837e+00, 4.75825614e+00, 4.75825614e+00, 4.53574836e+00,
        4.75825614e+00, 4.75825614e+00, 4.53574836e+00, 4.75825614e+00,
        4.75825614e+00, 4.53574837e+00, 4.75825613e+00, 4.75825613e+00],
       [1.41659105e+00, 1.41659105e+00, 4.97741619e+00, 1.71903850e+00,
        1.30054248e+00, 1.71903851e+00, 4.97741623e+00, 1.78666856e+01,
        1.78666857e+01, 4.97741666e+00, 1.71903851e+00, 1.30054248e+00,
        1.71903851e+00, 4.97741647e+00, 1.41659105e+00, 1.41659105e+00,
        4.53545617e+00, 4.75798866e+00, 4.75798866e+00, 4.53545616e+00,
        4.75798866e+00, 4.75798866e+00, 4.53545616e+00, 4.75798866e+00,
        4.75798866e+00, 4.53545617e+00, 4.75798865e+00, 4.75798

In [5]:
def sub_matrice_generator_k_means(matrix, row_num_clusters, col_num_clusters, random_state=None):
    """Split a matrix into blocks by k-means clustering its rows and its columns.

    The rows of ``matrix`` are clustered into ``row_num_clusters`` groups and,
    independently, its columns (the rows of ``matrix.T``) into
    ``col_num_clusters`` groups. One sub-matrix is returned for every
    (row cluster, column cluster) pair, holding the entries of ``matrix`` at
    the intersection of that row group and that column group.

    Parameters
    ----------
    matrix : np.ndarray
        2-D array to partition.
    row_num_clusters : int
        Number of k-means clusters for the rows.
    col_num_clusters : int
        Number of k-means clusters for the columns.
    random_state : optional
        Forwarded to both ``KMeans`` estimators. The default ``None`` keeps
        the previous, unseeded behaviour; pass an int for reproducible
        clusterings.

    Returns
    -------
    list of np.ndarray
        ``row_num_clusters * col_num_clusters`` sub-matrices, ordered with the
        row-cluster id as the outer index and the column-cluster id as the
        inner index: (0, 0), (0, 1), ..., (1, 0), (1, 1), ...
    """
    kmeans_row = KMeans(n_clusters=row_num_clusters, random_state=random_state)
    kmeans_col = KMeans(n_clusters=col_num_clusters, random_state=random_state)
    # Fit rows first, then columns, as before, so the unseeded code path
    # consumes random numbers in the same order.
    kmeans_row.fit(matrix)
    kmeans_col.fit(matrix.T)

    # Member indices of every cluster, computed once per cluster instead of
    # once per (row cluster, column cluster) pair.
    row_groups = [np.flatnonzero(kmeans_row.labels_ == row_id)
                  for row_id in range(row_num_clusters)]
    col_groups = [np.flatnonzero(kmeans_col.labels_ == col_id)
                  for col_id in range(col_num_clusters)]

    # np.ix_ selects the cross product rows x cols in one indexing step.
    return [matrix[np.ix_(row_indices, col_indices)]
            for row_indices in row_groups
            for col_indices in col_groups]

In [6]:
# For every cost matrix, cluster rows and columns into 2 groups each and keep
# the resulting list of sub-matrices (one list per cost matrix, in index order).
sub_matices = [
    sub_matrice_generator_k_means(cost_matrices[i], 2, 2)
    for i in range(num_cost_matrices)
]

In [7]:
# Number of sub-matrices produced for the first cost matrix
# (one per row-cluster / column-cluster pair).
len(sub_matices[0])

4

In [18]:
# Cost matrix 1000: sub-matrix for row cluster 0 x column cluster 0.
sub_matices[1000][0]

array([[13.95036671, 13.95036671,  9.99204276, 14.45948671, 14.54284645,
        14.45948671,  9.9920427 ,  9.99204211, 14.45948672, 14.54284644,
        14.45948671,  9.99204238, 13.95036671, 13.95036671, 14.39727805,
        13.93980805, 13.93980805, 14.39727813, 13.93980801, 13.93980801,
        14.39727813, 13.93980801, 13.93980801, 14.39727805, 13.93980807,
        13.93980807],
       [17.49802161, 17.49802161, 13.61188346, 17.9122841 , 17.97925788,
        17.9122841 , 13.6118834 , 13.6118828 , 17.9122841 , 17.97925788,
        17.9122841 , 13.61188307, 17.49802161, 17.49802161, 17.74810509,
        17.2813133 , 17.2813133 , 17.74810517, 17.28131327, 17.28131327,
        17.74810517, 17.28131327, 17.28131327, 17.74810509, 17.28131333,
        17.28131333],
       [22.43221923, 22.43221923, 18.61581513, 22.76183495, 22.81412635,
        22.76183495, 18.61581508, 18.61581446, 22.76183495, 22.81412634,
        22.76183495, 18.61581474, 22.43221923, 22.43221923, 22.51536044,
       

In [19]:
# Cost matrix 1000: sub-matrix for row cluster 0 x column cluster 1.
sub_matices[1000][1]

array([[4.47982227, 4.47982234],
       [2.61423563, 2.61423568],
       [5.85331358, 5.85331356],
       [4.82604689, 4.82604683],
       [2.98162991, 2.98163005]])

In [20]:
# Cost matrix 1000: sub-matrix for row cluster 1 x column cluster 0.
sub_matices[1000][2]

array([[1.41663042e+00, 1.41663042e+00, 4.97740474e+00, 1.71907095e+00,
        1.30047794e+00, 1.71907096e+00, 4.97740478e+00, 4.97740521e+00,
        1.71907095e+00, 1.30047794e+00, 1.71907095e+00, 4.97740502e+00,
        1.41663042e+00, 1.41663042e+00, 4.53574837e+00, 4.75825614e+00,
        4.75825614e+00, 4.53574836e+00, 4.75825614e+00, 4.75825614e+00,
        4.53574836e+00, 4.75825614e+00, 4.75825614e+00, 4.53574837e+00,
        4.75825613e+00, 4.75825613e+00],
       [1.41659105e+00, 1.41659105e+00, 4.97741619e+00, 1.71903850e+00,
        1.30054248e+00, 1.71903851e+00, 4.97741623e+00, 4.97741666e+00,
        1.71903851e+00, 1.30054248e+00, 1.71903851e+00, 4.97741647e+00,
        1.41659105e+00, 1.41659105e+00, 4.53545617e+00, 4.75798866e+00,
        4.75798866e+00, 4.53545616e+00, 4.75798866e+00, 4.75798866e+00,
        4.53545616e+00, 4.75798866e+00, 4.75798866e+00, 4.53545617e+00,
        4.75798865e+00, 4.75798865e+00],
       [9.67878446e-01, 9.67878446e-01, 3.65277318e+00

In [21]:
# Cost matrix 1000: sub-matrix for row cluster 1 x column cluster 1.
sub_matices[1000][3]

array([[17.86667065, 17.86667078],
       [17.86668557, 17.8666857 ],
       [16.72449001, 16.72449013],
       [15.39832726, 15.39832738],
       [13.2470704 , 13.24707054],
       [18.01752854, 18.01752867],
       [18.02867267, 18.02867279],
       [13.77471154, 13.77471165],
       [17.18935174, 17.18935187],
       [18.06315461, 18.06315474],
       [17.76072644, 17.76072657],
       [17.7004127 , 17.70041285],
       [17.18518757, 17.18518771],
       [17.18518757, 17.18518771],
       [17.80129295, 17.8012931 ],
       [17.26798621, 17.26798635],
       [17.26798621, 17.26798635],
       [17.83057301, 17.83057316],
       [17.31278647, 17.31278661],
       [17.31278647, 17.31278661]])