# Dataset

In [None]:
import torch
import pandas as pd
from whale_dataset import WhaleDataset
from tslearn.clustering import TimeSeriesKMeans
import networkx as nx
from torch_geometric_temporal.nn import MTGNN
from config import CONFIG
from pathlib import Path
import numpy as np
import contextily as cx
import matplotlib.pyplot as plt
from pyproj import Proj

## Data & weights

In [None]:
# WhaleDataset receives one less than the model's configured sequence length
# (presumably the input-window length — confirm against WhaleDataset).
window_length = CONFIG.model_config.seq_length - 1
dataset = WhaleDataset(window_length)

In [None]:
# Rebuild MTGNN with one node per dataset column and restore trained weights.
PATH = Path("../") / "artifacts" / "model.pt"
# NOTE: model_config is the same object as CONFIG.model_config, so setting
# num_nodes below mutates the shared config in place.
model_config = CONFIG.model_config
model_config.num_nodes = len(dataset.dataframe.columns)
model = MTGNN(**model_config.dict())
# The freshly constructed model lives on CPU, so map the checkpoint tensors to
# CPU as well; without this a checkpoint saved from a GPU run cannot be loaded
# on a CPU-only machine. load_state_dict copies into the existing parameters,
# so the resulting model is identical either way.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
model = model.eval()

## Time Series

In [None]:
# Disabled: cluster the per-column time series (dataframe transposed so each
# column is one series) into 3 groups with DTW k-means, then plot every series
# coloured by its cluster. Uncomment to run; note it binds `labels`, a name the
# graph-attributes cell reassigns later.
# km = TimeSeriesKMeans(n_clusters=3, metric="dtw")
# labels = km.fit_predict(dataset.dataframe.values.T)
# cmap = ['red', 'blue', 'green']
# dataset.dataframe.plot(color=[cmap[i] for i in labels])

## Graph

In [None]:
# Extract the adjacency learned by MTGNN's graph-learning layer. This is
# inspection only, so skip building an autograd graph.
with torch.no_grad():
    A_tilde = model._graph_constructor(model._idx, FE=None)
# A_tilde = dataset.A_pathway
# MTGNN's graph constructor learns a uni-directional adjacency (ReLU over an
# antisymmetric score, so A[i, j] > 0 implies A[j, i] == 0). Build a DiGraph so
# edge direction survives; a plain undirected nx.Graph would discard it.
graph = nx.from_numpy_array(
    A_tilde.cpu().detach().numpy(),
    create_using=nx.DiGraph,
)

In [None]:
# Attributes for the plot.
# Edge weights keyed by (u, v), iterated in graph.edges order, which is also
# nx.draw's default edge order, so the width/colour sequences line up.
edge_weights = nx.get_edge_attributes(graph, 'weight')
# Node index -> place key. Assumes dataset.places is ordered like the
# adjacency rows/columns (i.e. like dataset.dataframe.columns) — TODO confirm.
labels = dict(zip(graph.nodes, dataset.places))
cmap = plt.get_cmap('viridis')
# nx.draw documents `pos` as a {node: (x, y)} mapping. The third field of each
# place record is a coordinate pair, swapped here to (x, y) — presumably
# stored as (lat, lon) and plotted as (lon, lat) in EPSG:4326; verify.
pos = {
    node: (coords[1], coords[0])
    for node, (_, _, coords, _) in zip(graph.nodes, dataset.places.values())
}
# One RGBA colour per edge; the colormap maps inputs in [0, 1] across its range.
colors = [cmap(weight) for weight in edge_weights.values()]

In [None]:
# Draw the learned graph on top of a web-map basemap; node positions are in
# EPSG:4326 and edge width follows the learned edge weight.
fig, ax = plt.subplots(figsize=(20, 6))
nx.draw(
    graph,
    pos,
    ax=ax,
    labels=labels,
    with_labels=True,
    node_size=0.1,
    width=np.array(list(edge_weights.values())),
    edge_color=colors,
    # Force FancyArrowPatch edges: otherwise networkx (3.x) draws undirected
    # graphs with a LineCollection and ignores `connectionstyle` (with a
    # warning). Undirected graphs then get headless "-" arrows, directed
    # graphs get arrowheads.
    arrows=True,
    connectionstyle="arc3,rad=0.1",
)
cx.add_basemap(ax, crs="EPSG:4326")