# Visualization of the learned MTGNN adjacency matrix

In [None]:
import torch
import pandas as pd
from country_dataset import Dataset
import networkx as nx
from torch_geometric_temporal.nn import MTGNN
from config import CONFIG
from pathlib import Path
import numpy as np
import contextily as cx
import matplotlib.pyplot as plt
from pyproj import Proj

## Data & Weights

In [None]:
# Build the dataset restricted to the countries listed in the config.
# NOTE(review): the window passed is one shorter than the model's
# seq_length; the reason is not visible here — confirm against Dataset.
history_len = CONFIG.model_config.seq_length - 1
dataset = Dataset(history_len, countries=CONFIG.countries)

In [None]:
# Tab-separated country metadata (name + coordinates).
country_df = pd.read_csv('../country.csv', sep='\t')
# The coordinate columns are addressed with a trailing space, matching the
# CSV header as this notebook reads it.
coord_columns = ['latitude ', 'longitude ']
# Nested dict: {'latitude ': {name: lat}, 'longitude ': {name: lon}}.
pos_dict = country_df.set_index('name').loc[:, coord_columns].to_dict()

In [None]:
PATH = Path("../") / "artifacts" / "model.pt"
model_config = CONFIG.model_config
# One graph node per dataset column (presumably one per country series).
model_config.num_nodes = len(dataset.dataframe.columns)
model = MTGNN(**model_config.dict())
# The model is constructed on CPU, so deserialize the checkpoint onto CPU as
# well. Without map_location, a checkpoint saved from a GPU fails to load on
# a CPU-only machine.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
# Switch to inference mode (disables training-only layers such as dropout).
model = model.eval()

## Graph

In [None]:
# Column names of the dataset frame, used as display labels for graph nodes.
node_labels = dataset.dataframe.keys()

In [None]:
# Learned adjacency matrix from MTGNN's graph-construction layer, evaluated
# for all node indices. Gradients are not needed for visualization.
with torch.no_grad():
    A_tilde = model._graph_constructor(model._idx, FE=None)
# The learned adjacency is not necessarily symmetric, so build a directed
# graph. The default undirected nx.Graph would drop edge direction and
# collapse A[i, j] / A[j, i] into a single edge. A DiGraph is also what makes
# the curved `connectionstyle` used when drawing take effect.
graph = nx.from_numpy_array(
    A_tilde.cpu().detach().numpy(), create_using=nx.DiGraph
)

In [None]:
# Drawing attributes derived from the graph.
# {(u, v): weight} for every edge, in graph.edges order.
edge_weights = nx.get_edge_attributes(graph, 'weight')
# Integer node id -> column name, for display.
labels = dict(zip(graph.nodes, node_labels))
p = Proj('EPSG:4326')
cmap = plt.get_cmap('viridis')
# One RGBA color per edge, mapped from its weight through the colormap.
colors = list(map(cmap, edge_weights.values()))
# Node positions from each country's (longitude, latitude), passed through p.
pos = {
    node: p(pos_dict['longitude '][name], pos_dict['latitude '][name])
    for node, name in zip(graph.nodes, node_labels)
}

In [None]:
fig, ax = plt.subplots(figsize=(20, 6))
# Edge line widths follow the learned weights, in graph.edges order.
edge_widths = np.array(list(edge_weights.values()))
nx.draw(
    graph,
    pos,
    ax=ax,
    labels=labels,
    with_labels=True,
    node_size=0.1,
    width=edge_widths,
    edge_color=colors,
    connectionstyle="arc3,rad=0.1",
)
# Basemap tiles are requested in EPSG:4326, the same CRS used for `pos`.
cx.add_basemap(ax, crs="EPSG:4326")