In [None]:
import numpy as np
from MatrixVectorizer import *
from typing import Union
import torch
import networkx as nx

In [None]:
# Training data: load the low- and high-resolution CSVs and sanitize them.
lr_data_path = 'data/lr_train.csv'
hr_data_path = 'data/hr_train.csv'


def _load_clean_csv(csv_path):
    """Read a comma-separated float file and sanitize it in place.

    Negative entries are set to 0 first. NaN compares False against 0, so NaNs
    survive that step and are then replaced by ``np.nan_to_num`` (in place).
    ``np.nan_to_num`` also maps +inf to the largest finite float; -inf was
    already zeroed by the negativity mask.
    """
    matrix = np.loadtxt(csv_path, delimiter=',')
    matrix[matrix < 0] = 0
    np.nan_to_num(matrix, copy=False)
    return matrix


lr_train_data = _load_clean_csv(lr_data_path)
hr_train_data = _load_clean_csv(hr_data_path)

# Anti-vectorize every row of both datasets and stack the results into arrays.
# NOTE(review): 160 / 260 are presumably the LR / HR node counts passed to
# MatrixVectorizer.anti_vectorize — confirm against that helper.
# NOTE(review): despite the "_vectorized" suffix, these hold anti-vectorized rows;
# the names are kept because later cells may depend on them.
lr_train_data_vectorized = np.array(
    [MatrixVectorizer.anti_vectorize(sample, 160) for sample in lr_train_data]
)
hr_train_data_vectorized = np.array(
    [MatrixVectorizer.anti_vectorize(sample, 260) for sample in hr_train_data]
)

In [None]:
class TopologicalMeasures:
    """Node-level topological measures of a graph given as an adjacency matrix.

    The adjacency matrix is converted to an undirected ``networkx.Graph`` with
    ``nx.from_numpy_array``: nonzero entries become edges, and their values are
    stored in the ``'weight'`` edge attribute.
    """

    def __init__(self, graph: Union[np.ndarray, torch.Tensor]):
        """Build the networkx graph from an adjacency matrix.

        Args:
            graph: adjacency matrix as a numpy array or a torch tensor. Tensors
                are detached from autograd and copied to CPU before conversion.

        Raises:
            TypeError: if ``graph`` is neither a numpy array nor a torch tensor.
        """
        if isinstance(graph, torch.Tensor):
            graph = graph.detach().cpu().numpy()
        elif not isinstance(graph, np.ndarray):
            # Fail fast instead of leaving self.graph unset (which would only
            # surface later as an AttributeError).
            raise TypeError(
                f"graph must be a numpy.ndarray or torch.Tensor, got {type(graph).__name__}"
            )
        self.graph = nx.from_numpy_array(graph)

    def compute_measures(self):
        """Compute per-node topological measures of ``self.graph``.

        Returns:
            dict mapping each measure name ('degree', 'clustering', 'closeness',
            'betweenness', 'pagerank', 'eigenvector') to a dict of node -> value.
            The same dict is stored on ``self.measures``.

        Note:
            With networkx defaults, only ``nx.pagerank`` uses the ``'weight'``
            edge attribute. The other measures use the unweighted topology.
            ``nx.eigenvector_centrality`` can raise
            ``networkx.PowerIterationFailedConvergence`` if it does not converge.
        """
        measures = {
            'degree': dict(self.graph.degree()),
            'clustering': nx.clustering(self.graph),
            'closeness': nx.closeness_centrality(self.graph),
            'betweenness': nx.betweenness_centrality(self.graph),
            'pagerank': nx.pagerank(self.graph),
            'eigenvector': nx.eigenvector_centrality(self.graph),
        }
        self.measures = measures
        return measures

    @staticmethod
    def compute_topological_MAE_loss(graph1: Union[np.ndarray, torch.Tensor], graph2: Union[np.ndarray, torch.Tensor]):
        """Mean absolute error between the topological measures of two graphs.

        For every measure, the absolute per-node difference is averaged over
        nodes. The loss is the mean of these per-measure MAEs.

        Args:
            graph1, graph2: adjacency matrices (numpy arrays or torch tensors)
                over the same node set.

        Returns:
            float: the averaged MAE across all measures.

        Raises:
            TypeError: if either input is not a numpy array or torch tensor.
            ValueError: if the two graphs do not have the same node set.
        """
        measures1 = TopologicalMeasures(graph1).compute_measures()
        measures2 = TopologicalMeasures(graph2).compute_measures()
        return TopologicalMeasures._mean_absolute_error(measures1, measures2)

    @staticmethod
    def _mean_absolute_error(measures1, measures2):
        """Average over measures of the per-node MAE between two measure dicts.

        Args:
            measures1, measures2: dicts of measure name -> {node: value}, as
                returned by ``compute_measures``.

        Returns:
            float: mean over measures of mean_over_nodes(|v1 - v2|). A measure
            with no nodes contributes 0.0.

        Raises:
            ValueError: if there are no measures, the measure names differ, or a
                measure's node sets differ between the two inputs.
        """
        if not measures1:
            raise ValueError("no measures to compare")
        if measures1.keys() != measures2.keys():
            raise ValueError("the two measure dicts contain different measure names")

        per_measure_mae = []
        for name, values1 in measures1.items():
            values2 = measures2[name]
            if values1.keys() != values2.keys():
                raise ValueError(f"node sets differ for measure {name!r}")
            if not values1:
                per_measure_mae.append(0.0)
                continue
            # Align both dicts on the same node order before subtracting.
            nodes = list(values1)
            first = np.array([values1[node] for node in nodes], dtype=float)
            second = np.array([values2[node] for node in nodes], dtype=float)
            per_measure_mae.append(float(np.mean(np.abs(first - second))))

        return float(np.mean(per_measure_mae))