In [1]:
import os
import pandas as pd
import networkx as nx
import numpy as np
from networkx.utils import reverse_cuthill_mckee_ordering

In [2]:
# Number of vertices in every graph of the dataset; the CSV file name and the
# feature/target column split both depend on it.
NUMBER_NODES = 9

def load_data(number_nodes=None, data_dir='datasets'):
    """Load the test split of the bandwidth dataset as two int32 arrays.

    Each CSV row is laid out as:
      * the first ``n * (n - 1) // 2`` columns: the graph's strict upper-triangle
        adjacency entries (the input later fed to ``getGraph``);
      * the next column: the optimal bandwidth, which is skipped here;
      * all remaining columns: the target labels.

    Parameters
    ----------
    number_nodes : int, optional
        Graph size ``n``. Defaults to the module-level ``NUMBER_NODES``. It selects
        the file ``dataset_{n}_test.csv`` and the feature/target split.
    data_dir : str, optional
        Directory containing the CSV (default ``'datasets'``, relative to the CWD).

    Returns
    -------
    (x_test, y_test) : tuple of numpy.ndarray
        ``x_test`` has shape ``(rows, n * (n - 1) // 2)``. ``y_test`` holds the
        target columns. Both are int32. An empty CSV gives zero-row 2-D arrays.
    """
    if number_nodes is None:
        number_nodes = NUMBER_NODES
    csv_path = os.path.join(data_dir, f'dataset_{number_nodes}_test.csv')
    rows = pd.read_csv(csv_path).to_numpy()

    n_features = number_nodes * (number_nodes - 1) // 2
    x_test = rows[:, :n_features].astype('int32')
    # Column ``n_features`` holds the optimal bandwidth, not a target: skip it.
    y_test = rows[:, n_features + 1:].astype('int32')
    return x_test, y_test

In [3]:
# Materialise the held-out split once; the evaluation cell below reads these globals.
[x_test, y_test] = load_data()

In [4]:
def get_optimal_bandwidth_nodelist_adjacency_rcm(Graph):
    """Return the bandwidth of ``Graph`` when its nodes are laid out in RCM order.

    Despite the name, reverse Cuthill-McKee is a bandwidth-reduction heuristic.
    The value is the bandwidth of one particular ordering, not necessarily the
    optimal bandwidth.

    Parameters
    ----------
    Graph : networkx graph

    Returns
    -------
    int
        ``max |pos(u) - pos(v)|`` over edges, with positions taken from the RCM
        ordering. Returns 0 for a graph without edges (same convention as
        ``get_bandwidth``).
    """
    if not Graph.edges:
        # The Laplacian would have no off-diagonal non-zeros; avoid .max() on an empty array.
        return 0
    rcm = list(reverse_cuthill_mckee_ordering(Graph))
    # Rows/columns permuted into RCM order. Non-zero off-diagonal entries mark
    # weighted edges, so the largest row - col offset is the bandwidth.
    L = nx.laplacian_matrix(Graph, nodelist=rcm)
    rows, cols = np.nonzero(L)
    offsets = rows - cols
    # Guard the degenerate case where the only edges are self-loops (all-zero Laplacian).
    return offsets.max() if offsets.size else 0

In [5]:
def count_repeats(output):
    """Count entries of ``output`` whose rounded value repeats an earlier one.

    Computed as ``len(output) - number of distinct rounded values``. For example,
    ``[1, 1, 1, 1, 1, 1, 1]`` counts as 6, matching the summary message printed by
    the evaluation cell. The count is based on the input's own length, so it is
    correct for any graph size.

    Parameters
    ----------
    output : array-like of numbers

    Returns
    -------
    int
        Number of redundant (repeated) labels; 0 for an empty input.
    """
    rounded = np.round(np.asarray(output))
    return int(rounded.size - np.unique(rounded).size)

def get_valid_pred(pred):
    """Turn raw per-node scores into a valid permutation by ranking them.

    ``valid[j]`` is the rank of ``pred[j]`` (0 for the smallest score). Ties are
    broken by position: the earlier index gets the lower rank. If ``pred`` is
    already a permutation of ``0..n-1``, the result equals it (as floats).

    Note that the result maps node index -> rank (a labelling). That is the
    inverse of a node ordering in general.

    Parameters
    ----------
    pred : array-like of numbers
        One score per node. It is not modified.

    Returns
    -------
    numpy.ndarray of float64
        Ranks ``0..len(pred)-1``, one per entry of ``pred``.
    """
    scores = np.asarray(pred)
    size = scores.shape[0]
    # A stable sort keeps equal scores in index order, i.e. "first minimum wins".
    order = np.argsort(scores, kind='stable')
    valid = np.empty(size)
    valid[order] = np.arange(size)
    return valid
    
def getGraph(upperTriangleAdjMatrix, number_nodes=None):
    """Rebuild a dense symmetric adjacency matrix from its strict upper triangle.

    Parameters
    ----------
    upperTriangleAdjMatrix : 1-D array-like
        Entries above the diagonal in row-major order:
        (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1).
        Extra trailing entries are ignored.
    number_nodes : int, optional
        Matrix size ``n``. Defaults to the module-level ``NUMBER_NODES``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, n)``, dtype float64, zero diagonal, with ``A[j, i] == A[i, j]``.

    Raises
    ------
    ValueError
        If fewer than ``n * (n - 1) // 2`` entries are supplied.
    """
    n = NUMBER_NODES if number_nodes is None else number_nodes
    # triu_indices(k=1) walks the strict upper triangle in the same row-major
    # order as the nested i < j loop that produced the flat vector.
    rows, cols = np.triu_indices(n, k=1)
    values = np.asarray(upperTriangleAdjMatrix)
    if values.shape[0] < rows.size:
        raise ValueError(
            f'expected at least {rows.size} upper-triangle entries for '
            f'{n} nodes, got {values.shape[0]}')
    values = values[:rows.size]
    dense_adj = np.zeros((n, n))
    dense_adj[rows, cols] = values
    dense_adj[cols, rows] = values  # mirror into the lower triangle
    return dense_adj

def get_bandwidth(Graph, nodelist=None):
    """Return the bandwidth of ``Graph`` under a given node ordering.

    Parameters
    ----------
    Graph : networkx graph
    nodelist : sequence of nodes, None, or ``np.array(None)``, optional
        Ordering passed to ``nx.laplacian_matrix``: row/column ``k`` corresponds to
        ``nodelist[k]``. ``None``, or the 0-d ``np.array(None)`` that existing
        callers pass, means "use the graph's own node order".

    Returns
    -------
    int
        ``max |pos(u) - pos(v)|`` over edges with non-zero weight. Returns 0 when
        the graph has no edges.
    """
    if not Graph.edges:
        return 0
    # ``nodelist.all()`` is not a presence test: it crashes on None and is False
    # whenever node 0 is listed. Test explicitly for the "no ordering" sentinels.
    no_order = (np.ndim(nodelist) == 0
                and np.asarray(nodelist, dtype=object).item() is None)
    if no_order:
        L = nx.laplacian_matrix(Graph)
    else:
        L = nx.laplacian_matrix(Graph, nodelist=nodelist)
    rows, cols = np.nonzero(L)
    # For an undirected graph L is symmetric, so max(row - col) == max |row - col|.
    offsets = rows - cols
    # Only self-loops -> all-zero Laplacian -> nothing to take the max of.
    return offsets.max() if offsets.size else 0

In [6]:
# Score reverse Cuthill-McKee (RCM) orderings on every test graph.
# For each graph this prints the RCM ordering ("Pred"), the dataset target ("True"),
# the ranked ordering ("Pred valid"), then three bandwidths on separate lines:
# original node order, RCM ordering, dataset target.
sumTest_original = sumTest_pred = sumTest_true = 0

# count: total repeated labels; cases_with_repetition: outputs with any repeat.
count = cases_with_repetition = 0

for i in range(len(x_test)):
    graph = getGraph(x_test[i])
    Graph = nx.Graph(graph)
    output = np.array(list(reverse_cuthill_mckee_ordering(Graph)))

    quantity_repeated = count_repeats(np.round(output))

    print('Pred: ', output)
    print('True: ', y_test[i])

    if quantity_repeated:
        cases_with_repetition += 1

    # RCM yields a permutation of node ids, so ranking it leaves it unchanged.
    output = get_valid_pred(output)
    print('Pred valid: ', output)
    count += quantity_repeated

    # np.array(None) asks get_bandwidth for the graph's own node order.
    original_band = get_bandwidth(Graph, np.array(None))
    sumTest_original += original_band
    # NOTE(review): both `output` and y_test[i] are used as node *orderings* here.
    # get_valid_pred returns per-node ranks (a labelling). If y_test or future model
    # outputs are labellings, this measures the inverse permutation — confirm.
    pred_band = get_bandwidth(Graph, output)
    sumTest_pred += pred_band
    true_band = get_bandwidth(Graph, y_test[i])
    sumTest_true += true_band
    print("Bandwidth", original_band, pred_band, true_band, sep="\n")

# Number of repeated labels ([1, 1, 1, 1, 1, 1, 1] counts as 6).
print('Quantidade de rótulos repetidos, exemplo [1, 1, 1, 1, 1, 1, 1] conta como 6 - ', count)
# Number of outputs containing any repetition ([1, 1, 1, 1, 1, 1, 1] counts as 1).
print('Quantidade de saídas com repetição, exemplo [1, 1, 1, 1, 1, 1, 1] conta como 1 - ', cases_with_repetition)

test_length = x_test.shape[0]
print('Test length - ', test_length)
print("Bandwidth mean")
print(sumTest_original / test_length)
print("Pred bandwidth mean")
print(sumTest_pred / test_length)
print("True bandwidth mean")
print(sumTest_true / test_length)

Pred:  [6 5 4 3 2 8 1 0 7]
True:  [1 8 4 5 6 7 0 2 3]
Pred valid:  [6. 5. 4. 3. 2. 8. 1. 0. 7.]
Bandwidth
7
1
1
Pred:  [6 5 4 3 2 7 0 8 1]
True:  [1 8 0 7 4 6 5 2 3]
Pred valid:  [6. 5. 4. 3. 2. 7. 0. 8. 1.]
Bandwidth
8
2
2
Pred:  [6 5 4 3 7 0 2 8 1]
True:  [1 0 8 7 2 5 3 6 4]
Pred valid:  [6. 5. 4. 3. 7. 0. 2. 8. 1.]
Bandwidth
8
3
2
Pred:  [6 5 4 3 2 1 8 0 7]
True:  [6 5 7 4 3 0 8 1 2]
Pred valid:  [6. 5. 4. 3. 2. 1. 8. 0. 7.]
Bandwidth
8
4
3
Pred:  [6 5 4 7 1 0 3 8 2]
True:  [7 1 8 0 2 3 4 6 5]
Pred valid:  [6. 5. 4. 7. 1. 0. 3. 8. 2.]
Bandwidth
8
4
3
Pred:  [6 5 1 7 0 4 3 8 2]
True:  [7 1 0 8 3 4 2 6 5]
Pred valid:  [6. 5. 1. 7. 0. 4. 3. 8. 2.]
Bandwidth
8
4
3
Pred:  [7 1 0 6 5 4 3 8 2]
True:  [3 6 2 5 8 0 4 1 7]
Pred valid:  [7. 1. 0. 6. 5. 4. 3. 8. 2.]
Bandwidth
8
7
4
Pred:  [6 2 1 0 7 5 4 8 3]
True:  [0 1 7 4 2 8 3 5 6]
Pred valid:  [6. 2. 1. 0. 7. 5. 4. 8. 3.]
Bandwidth
7
3
3
Pred:  [6 7 2 1 0 5 4 8 3]
True:  [0 7 4 1 8 2 3 5 6]
Pred valid:  [6. 7. 2. 1. 0. 5. 4. 8. 3.]
Bandwidt