In [None]:
from collections import defaultdict
import random

import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.layout import *
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer
import torch

from config import CONFIG, Config
from config import Line
from data.metro_dataset import MetroDataset
from model.metro_model import MetroModel
# Seed every RNG the notebook may touch so Restart & Run All reproduces results.
# NOTE(review): originally only torch was seeded; python `random` and numpy are
# seeded too in case MetroDataset's synthetic generator draws from them — unverified.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the cell output

# Dataset

In [None]:
# Build the synthetic metro dataset from the configured lines and render its graph.
dataset = MetroDataset(CONFIG.lines, init_nb=20)
cg = dataset.cg
num_nodes = cg.num_nodes()
# Second argument of altair_graph is passed through unchanged (its meaning is not visible here).
cg.altair_graph(cg.graph, 10)

In [None]:
# Line plot of the raw series held in the first generated dataframe.
first_frame = dataset.dataframes[0]
first_frame.plot.line()

In [None]:
# Aggregate the per-node series of the first dataframe into one series per
# station, then plot the station totals.
df = dataset.dataframes[0]

# Map each graph node to its "station" attribute, then group node ids by
# station (dict insertion order = order in which stations are first seen).
node2station = nx.get_node_attributes(dataset.cg.graph, "station")
station2series = defaultdict(list)
for node, station in node2station.items():
    station2series[str(station) + '_station'].append(node)

# One column per station: row-wise sum of that station's node columns
# (df is indexed by node ids here). Building the frame in one go avoids growing
# an empty DataFrame column by column. Columns are numbered by first-appearance
# order, not by the value of the "station" attribute.
sumdf = pd.DataFrame({
    "station " + str(i): df[nodes].sum(axis=1)
    for i, nodes in enumerate(station2series.values(), start=1)
})

# Global matplotlib style change: it also applies to any figure drawn after
# this cell, including re-runs of earlier plotting cells.
plt.style.use('fivethirtyeight')

# Title through the returned Axes; the trailing ';' keeps the Text repr out of
# the cell output.
ax = sumdf.plot()
ax.set_title('Synthetic Metro Traffic');

# Model

In [None]:
# Ground-truth adjacency as a trainable float32 *leaf* tensor.
# Casting with .float() after requires_grad=True (as before) returns a non-leaf
# copy whenever the source dtype is not float32, so its .grad is never populated;
# for integer input, torch.tensor(..., requires_grad=True) raises outright.
# Creating the tensor directly in float32 avoids both.
# Assumes adjacency_matrix() returns a dense array-like — TODO confirm.
# NOTE(review): A_init is not referenced anywhere else in this notebook.
A_init = torch.tensor(dataset.cg.adjacency_matrix(), dtype=torch.float32, requires_grad=True)

In [None]:
# Graph-structure-learning model in "embedding" mode; every hyper-parameter is
# passed straight through to MetroModel.
model = MetroModel(
    embedding_size=8,
    num_nodes=num_nodes,
    neighbor_nb=2,
    input_size=1,
    gsl_mode="embedding",
    lr=1e-3,
)

# Training

In [None]:
# Fit on the whole synthetic dataset in mini-batches (DataLoader default order,
# i.e. no shuffling; no validation loader is supplied).
BATCH_SIZE = 8
MAX_EPOCHS = 100

train_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
trainer = Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(model, train_loader)

# Evaluation

In [None]:
# Sparsify the learned adjacency: keep the TOP_K largest weights along `dim`
# (per row, assuming graph_learning() returns a 2-D matrix — TODO confirm) and
# zero out everything else.
# NOTE(review): presumably TOP_K should match neighbor_nb=2 given to MetroModel — confirm.
TOP_K = 2
dim = 1

# Inspection only: no autograd graph is needed.
with torch.no_grad():
    A = model.graph_learning()
    values, indices = A.topk(k=TOP_K, dim=dim)
    # 0/1 mask with ones at the top-k positions. Scatter a scalar rather than
    # overwriting the topk output in place.
    mask = torch.zeros_like(A).scatter_(dim, indices, 1.0)
    A = A * mask

# .cpu() is a no-op on CPU and keeps .numpy() working if the model sits on an
# accelerator; .numpy() already yields an ndarray, so no extra np.array copy.
A = A.detach().cpu().numpy()

In [None]:
# Alternative sparsification: keep only learned weights strictly above 0.1,
# zeroing the rest.
# NOTE(review): B is not used by any later cell (the learned graph is built
# from A) — confirm which matrix was intended.
above_threshold = A > 0.1
B = A * above_threshold

In [None]:
# Directed graph of the learned (top-k masked) adjacency, rendered with the same
# helper as the input graph.
# `create_using` must be passed by keyword: the second positional parameter of
# nx.from_numpy_array is `parallel_edges`, so passing nx.DiGraph positionally
# silently built an undirected nx.Graph and lost edge direction.
# NOTE(review): the thresholded matrix B (A > 0.1) is computed but unused; use it
# here instead of A if the threshold was meant to apply.
learned_graph = nx.from_numpy_array(A, create_using=nx.DiGraph)
dataset.cg.altair_graph(learned_graph, 10)

In [None]:
# TODO: Try with the embeddings
# TODO: Try with a noisy input graph
# TODO: Try with noisy data
# TODO: Try with different (metro) networks
# TODO: Try with the contracted graph