In [2]:
import numpy as np
from scipy.linalg import sqrtm 
from scipy.special import softmax
import networkx as nx
from networkx.algorithms.community.modularity_max import greedy_modularity_communities
import matplotlib.pyplot as plt
from matplotlib import animation
%matplotlib inline
from IPython.display import HTML

from gcn_code.GCNLayer import GCNLayer
from gcn_code.GCNNetwork import GCNNetwork
import pickle as pk

In [3]:
# City whose data is processed; the name is spliced into the
# 'data_process/<city>/...' paths used when reading and writing pickles.
city = "Perth"

In [6]:
# Load the pickled 'local' and 'global' adjacency data for the selected city.
# Each file is opened in a context manager so its handle is closed as soon as
# the object has been unpickled.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# these files must come from a trusted source.
with open('data_process/' + city + '/local', 'rb') as file:
    # presumably a dict of adjacency matrices (it is iterated and indexed by
    # key in later cells) — confirm against the producer of this file
    local_adj = pk.load(file)
with open('data_process/' + city + '/global', 'rb') as file:
    global_adj = pk.load(file)

In [7]:
# Sanity check: display the dimensions of the loaded global adjacency matrix.
global_adj.shape

(25, 25)

In [8]:
def create_A_HAT(A):
    """Return the symmetrically normalized adjacency matrix with self-loops.

    Computes ``D^(-1/2) @ (A + I) @ D^(-1/2)``, where ``I`` is the identity
    and ``D`` is the diagonal matrix holding the row sums of ``A + I``.

    Parameters
    ----------
    A : square 2-D matrix
        Adjacency matrix. Must expose ``.shape`` and support ``A + ndarray``
        and ``.sum(axis=1)``. It is not modified.

    Returns
    -------
    The normalized matrix, with the same shape as ``A``.

    Raises
    ------
    ValueError
        If ``A`` is not a square 2-D matrix, or if any row sum of ``A + I``
        is not strictly positive (its inverse square root would be undefined
        or complex).
    """
    # Reject non-square input explicitly: a shape such as (n, 1) would
    # otherwise broadcast silently against the identity below.
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(
            "adjacency matrix must be square, got shape %s" % (tuple(A.shape),)
        )

    A_mod = A + np.eye(A.shape[0])  # add self-connections

    # Row sums of A + I form the diagonal of the degree matrix D.
    # np.asarray(...).flatten() also handles inputs whose sum is 2-D.
    degrees = np.asarray(A_mod.sum(axis=1)).flatten()
    if not np.all(degrees > 0):
        raise ValueError(
            "every row sum of A + I must be strictly positive to normalize"
        )

    # D is diagonal, so D^(-1/2) is just the elementwise reciprocal square
    # root of its diagonal; no general matrix square root or inverse is needed.
    D_mod_invroot = np.diag(1.0 / np.sqrt(degrees))

    return D_mod_invroot @ A_mod @ D_mod_invroot

In [9]:
# Replace the raw global adjacency matrix with the output of create_A_HAT.
# NOTE: the result is bound to the same name as the input, so re-running this
# cell would pass an already-processed matrix through create_A_HAT again.
global_adj = create_A_HAT(global_adj)

In [10]:
# Run every entry of local_adj through create_A_HAT, storing the result back
# under the same key.
for key in local_adj:
    raw_local = local_adj[key]
    local_adj[key] = create_A_HAT(raw_local)

In [11]:
def glorot_init(nin, nout):
    """Draw an ``(nin, nout)`` weight matrix with Glorot/Xavier uniform init.

    Values are sampled from NumPy's global random state, uniformly within
    ``[-limit, limit)`` where ``limit = sqrt(6 / (nin + nout))``.

    Parameters
    ----------
    nin : int
        Number of rows of the returned matrix (fan-in).
    nout : int
        Number of columns of the returned matrix (fan-out).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nin, nout)``.
    """
    fan_total = nin + nout
    # Half-width of the sampling interval (a bound, not a standard deviation).
    limit = np.sqrt(6.0 / fan_total)
    return np.random.uniform(low=-limit, high=limit, size=(nin, nout))

In [12]:
# Build a GCNNetwork (n_layers=3) whose input, hidden and output sizes all
# equal the number of rows of global_adj, then replace global_adj with the
# result of gcn_model.embedding called with an identity feature matrix.
global_size = global_adj.shape[0]
gcn_model = GCNNetwork(
    n_inputs=global_size,
    n_outputs=global_size,
    n_layers=3,
    hidden_sizes=[global_size] * 3,
    activation=np.tanh,
    seed=100,
)
identity_features = np.eye(global_size)
global_adj = gcn_model.embedding(global_adj, identity_features)

In [13]:
# For every entry of local_adj, build a fresh GCNNetwork sized to that entry's
# matrix (hidden sizes fixed at 4, same seed each time) and overwrite the
# entry with the result of gcn_model.embedding on an identity feature matrix.
for key in local_adj:
    local_matrix = local_adj[key]
    local_size = local_matrix.shape[0]
    gcn_model = GCNNetwork(
        n_inputs=local_size,
        n_outputs=local_size,
        n_layers=3,
        hidden_sizes=[4] * 3,
        activation=np.tanh,
        seed=100,
    )
    local_adj[key] = gcn_model.embedding(local_matrix, np.eye(local_size))

In [14]:
# Point every entry of local_adj at the global embedding. No copy is made, so
# all keys end up referring to the same global_adj object.
# NOTE(review): this discards the per-key values computed by the earlier
# loops over local_adj — confirm that overwriting them is intended.
for key in local_adj:
    local_adj[key] = global_adj

In [15]:
# Persist local_adj as a pickle in the city's data_process directory.
adjacency_path = "data_process/" + city + "/adjacency"
with open(adjacency_path, "wb") as out_file:
    pk.dump(local_adj, out_file)