In [1]:
import numpy as np
from scipy.linalg import sqrtm 
from IPython.display import HTML
from gcn_code.GCNLayer import GCNLayer
from gcn_code.GCNNetwork import GCNNetwork
import pickle as pk
#%matplotlib inline

In [2]:
# Name of the city sub-directory under data_process/ whose graphs are processed.
city = "London"

In [3]:
# Load the precomputed local and global adjacency structures for the chosen city.
# `with` guarantees each handle is closed (the previous version leaked both).
# NOTE(review): pickle.load can execute arbitrary code — only load files produced
# by this project's own data_process step, never untrusted input.
with open('data_process/' + city + '/local', 'rb') as file:
    # presumably a dict of key -> adjacency matrix (it is iterated/indexed by key below)
    local_adj = pk.load(file)
with open('data_process/' + city + '/global', 'rb') as file:
    global_adj = pk.load(file)

In [4]:
# The GCN normalisation and the identity feature matrix used below both assume a
# square (n_nodes x n_nodes) adjacency matrix, so fail fast if that does not hold.
assert global_adj.shape[0] == global_adj.shape[1], f"expected a square adjacency matrix, got {global_adj.shape}"
global_adj.shape

(30, 30)

In [5]:
def create_A_HAT(A):
    """Return the symmetrically normalised adjacency matrix D^{-1/2} (A + I) D^{-1/2}.

    A self-loop is added to every node, and D is the diagonal degree matrix of
    ``A + I`` (its row sums).

    D is diagonal, so its inverse square root is just the element-wise reciprocal
    square root of the degree vector. That avoids a general matrix square root
    (``scipy.linalg.sqrtm``) followed by a general inverse.

    Parameters
    ----------
    A : square 2-D array-like (ndarray or np.matrix)
        Adjacency (or weighted adjacency) matrix.

    Returns
    -------
    The normalised matrix, same shape as ``A``.

    Raises
    ------
    ValueError
        If ``A`` is not square, or if a node's degree after adding its self-loop
        is not strictly positive (D^{-1/2} would be undefined).
    """
    if A.shape[0] != A.shape[1]:
        raise ValueError(f"adjacency matrix must be square, got shape {A.shape}")

    A_mod = A + np.eye(A.shape[0])  # add self-connections

    # np.asarray(...).flatten() also handles np.matrix, whose sum(axis=1) is (n, 1).
    degrees = np.asarray(A_mod.sum(axis=1)).flatten()
    if np.any(degrees <= 0):
        raise ValueError("every node must have a strictly positive degree after adding self-loops")

    D_mod_invroot = np.diag(1.0 / np.sqrt(degrees))

    A_hat = D_mod_invroot @ A_mod @ D_mod_invroot

    return A_hat

In [6]:
# Replace the raw global adjacency with its normalised form D^{-1/2}(A+I)D^{-1/2}.
global_adj = create_A_HAT(A=global_adj)

In [7]:
# Normalise every local adjacency matrix, updating the container key by key.
for key in list(local_adj):
    local_adj[key] = create_A_HAT(local_adj[key])

In [8]:
n_global = global_adj.shape[0]  # number of nodes in the global graph
gcn_model = GCNNetwork(
    n_inputs=n_global,
    n_outputs=n_global,
    n_layers=3,
    hidden_sizes=[n_global] * 3,
    activation=np.tanh,
    seed=100,
)
# Identity matrix passed as the second argument — presumably one-hot node features.
global_adj = gcn_model.embedding(global_adj, np.eye(n_global))

In [9]:
for item in local_adj:
    # Size each model from its own local graph rather than from global_adj.
    # This keeps the cell independent of whatever global_adj currently holds
    # (by this point it has already been replaced by its embedding).
    n_nodes = local_adj[item].shape[0]
    gcn_model = GCNNetwork(
        n_inputs=n_nodes,
        n_outputs=n_nodes,
        n_layers=3,
        hidden_sizes=[n_nodes] * 3,
        activation=np.tanh,
        # Same seed as the global model — presumably yields identical initial
        # weights for equally sized graphs; confirm in GCNNetwork.
        seed=100,
    )
    local_adj[item] = gcn_model.embedding(local_adj[item], np.eye(n_nodes))

In [10]:
# Inspect row 0 of the global embedding, rounded to 10 decimals.
global_adj[0].round(decimals=10)

array([ 0.03535053, -0.11362884, -0.07773147, -0.0011891 , -0.03285874,
        0.08827935,  0.00450359,  0.07151944, -0.02380898,  0.012924  ,
       -0.00748185, -0.00772258, -0.05376971, -0.04685076,  0.06786373,
       -0.05353854, -0.00685832, -0.00964043,  0.04323872,  0.01046531,
       -0.02538465, -0.06116321, -0.03242151,  0.0486479 ,  0.0457434 ,
       -0.02604534, -0.07488217, -0.03873201,  0.00222072,  0.05750129])

In [11]:
# Fuse the global embedding into every local embedding by element-wise addition.
# NOTE: not idempotent — re-running this cell adds global_adj a second time.
for key in list(local_adj):
    local_adj[key] = local_adj[key] + global_adj

In [12]:
adjacency_path = "data_process/" + city + "/adjacency"
# Persist the fused per-key embeddings next to the input files.
with open(adjacency_path, "wb") as handle:
    pk.dump(local_adj, handle)