## TODO
 - Análise de dados
 - Continuar testes (Dot Product Loss e Neg Sampling, Modelo de Classificação padrão)
 
## DONE
 - Fixar permutação (https://gitlab.com/cristophersfr/fisher-networks/-/blob/corrected/Corrections_on_Scientific_Reports.pdf) 
 - n = 5 
 - features como binário da permutação 
 - testar outra métrica (pytorch metric learning). As listas de adjacência têm tamanhos diferentes entre os grafos, então não conseguimos concatenar os tensores para treinar em batch; por isso, usar o pytorch metric learning ainda não é possível (poderíamos fazer padding nas listas, mas isso aumentaria significativamente o uso de memória). Vou continuar investigando outras alternativas para conseguir usar a biblioteca.
 
## Situação atual
 - O modelo continua sem conseguir otimizar: os embeddings colapsam todos para o vetor nulo. Atualmente estamos usando as features binárias e testando variações do GAT e do GCN. Ainda há muita experimentação possível com as arquiteturas, então acredito que, com mais testes, essa situação mude.
 

## Imports

In [None]:
import pickle as pkl
import torch as th
from torch import Tensor
from torch_geometric.nn import GCNConv,SAGEConv,aggr,GATConv
from torch_geometric.nn.pool import global_max_pool
import ordpy
import numpy as np
from scipy.cluster import hierarchy
from scipy import sparse
from Data import *
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import itertools

from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

## Creating Data
Creates a data.pkl file with the data.

In [None]:
# IPython shell escape: runs create_data.py, which (per the markdown above) writes data.pkl.
!python create_data.py

## Pre-Processing: Creating Ordinal Networks

### Classes and Functions

In [None]:
def create_ord_net(time_series: list, ord_net_dim: int, label_encoder=None):
    """Build the ordinal network of a single time series.

    Args:
        time_series: series passed straight to ``ordpy.ordinal_network``.
        ord_net_dim: embedding dimension (length of each ordinal pattern).
        label_encoder: optional sklearn-style encoder (``fit``/``transform``).
            When given, it is fit on *every* possible ordinal pattern of length
            ``ord_net_dim``, so a given pattern always maps to the same integer
            id regardless of which patterns occur in this particular series.

    Returns:
        nodes: pattern strings as returned by ordpy; when ``label_encoder`` is
            given, the ``"|"`` separators are stripped.
        edges: ordpy's edge array unchanged when no encoder is given; otherwise
            a ``LongTensor`` of shape ``[2, num_edges]`` holding integer ids.
        weights: edge weights as returned by ordpy.
    """
    nodes, edges, weights = ordpy.ordinal_network(time_series, ord_net_dim)

    if label_encoder is not None:
        # Previously the encoder was fit on the patterns present in *this*
        # series only, so the same pattern received different ids in different
        # graphs whenever some pattern was missing. Edge indices then no longer
        # matched the rows of the shared node-feature matrix. Fitting on the full
        # pattern set gives a global, lexicographically sorted id space. For
        # ord_net_dim <= 9 that order matches itertools.permutations order.
        # NOTE(review): assumes ordpy formats patterns as 0-based symbols joined
        # by "|" (e.g. "0|2|1"); transform() raises ValueError if that is not so.
        # Graph pickles built with the old per-graph encoding must be regenerated.
        all_patterns = [
            "|".join(map(str, perm))
            for perm in itertools.permutations(range(ord_net_dim))
        ]
        label_encoder.fit(all_patterns)
        edges = th.as_tensor(
            np.array([label_encoder.transform(edges[:, 0]),
                      label_encoder.transform(edges[:, 1])]),
            dtype=th.long,
        )
        nodes = [node.replace("|", "") for node in nodes]

    return nodes, edges, weights
    
def create_graph_list(series_list: list, ord_net_dim: int, output_path=None):
    """Build one ordinal network per series and pickle the resulting list.

    Args:
        series_list: iterable of time series.
        ord_net_dim: ordinal-pattern length forwarded to ``create_ord_net``.
        output_path: file the graph list is pickled to. Defaults to the
            notebook-level ``save_path`` global (the previous implicit behaviour).

    Returns:
        List of ``[edges, weights]`` pairs, one per series, in input order.
    """
    from sklearn.preprocessing import LabelEncoder

    if output_path is None:
        output_path = save_path  # notebook global set in the parameter cell

    le = LabelEncoder()
    graphs = []
    for series in tqdm(series_list):
        _nodes, edges, weights = create_ord_net(series, ord_net_dim, le)
        graphs.append([th.Tensor(edges), th.Tensor(weights)])

    # Context manager guarantees the pickle is flushed and the file closed.
    with open(output_path, "wb") as f:
        pkl.dump(graphs, f)

    return graphs

def create_bin_features(ord_net_dim: int, num_bits: int = None):
    """Encode every permutation of ``1..ord_net_dim`` as a binary feature row.

    Rows follow ``itertools.permutations`` order. Each symbol of a permutation
    is written with ``num_bits`` binary digits and the digits are concatenated,
    so the result has shape ``[ord_net_dim!, ord_net_dim * num_bits]``.

    Args:
        ord_net_dim: permutation length.
        num_bits: bits per symbol. Defaults to the minimum able to represent the
            largest symbol, ``ord_net_dim.bit_length()``.

    Raises:
        ValueError: if ``num_bits`` is too small to represent ``ord_net_dim``.
    """
    min_bits = int(ord_net_dim).bit_length()
    if num_bits is None:
        num_bits = min_bits
    elif num_bits < min_bits:
        raise ValueError(
            f"num_bits={num_bits} cannot represent symbol {ord_net_dim}; "
            f"need at least {min_bits}"
        )

    rows = [
        [int(bit) for symbol in perm for bit in np.binary_repr(symbol, width=num_bits)]
        for perm in itertools.permutations(range(1, ord_net_dim + 1))
    ]
    return th.Tensor(np.array(rows, dtype=int))

In [None]:
class GraphDataset(th.utils.data.Dataset):
    """Map-style PyTorch dataset that pairs each graph with its label."""

    def __init__(self, graphs, labels):
        # Containers are kept as given; indexing is delegated to them.
        self.graphs, self.labels = graphs, labels

    def __len__(self):
        # Dataset length is taken from the label container.
        return len(self.labels)

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        label = self.labels[idx]
        return graph, label

In [None]:
def EdgeList2AdjMatrix(edges, weights, n_nodes=None):
    """Scatter a weighted edge list into a dense ``[n, n]`` adjacency matrix.

    Args:
        edges: integer array/tensor of shape ``[2, num_edges]``; row 0 holds
            source ids and row 1 holds target ids.
        weights: one weight per edge, aligned with the columns of ``edges``.
        n_nodes: matrix size. Defaults to the notebook-level ``num_nodes``
            global (the previous implicit behaviour).

    Returns:
        Float ``numpy.ndarray`` with ``adj[src, tgt] = weight`` and 0 elsewhere.
        If the same (src, tgt) pair appears twice, the later weight wins.
    """
    if n_nodes is None:
        n_nodes = num_nodes
    adj_matrix = np.zeros([n_nodes, n_nodes], dtype=float)

    # .tolist() works for numpy arrays and torch tensors alike and yields ints.
    for src, tgt, weight in zip(edges[0].tolist(), edges[1].tolist(), weights):
        adj_matrix[src, tgt] = weight
    return adj_matrix


def optimal_ordering(edges, weights, node_feats, return_adj: bool = False):
    """Relabel a graph's nodes by the optimal leaf ordering of a Ward clustering.

    Each node's row of the dense adjacency matrix is used as its observation
    vector for ``scipy.cluster.hierarchy.ward``. The leaves of the optimally
    ordered dendrogram give ``opt_ordering``: new node ``i`` is old node
    ``opt_ordering[i]``.

    Args:
        edges: ``[2, num_edges]`` integer edge index.
        weights: one weight per edge, aligned with ``edges``.
        node_feats: per-node feature matrix; its rows are reordered the same way.
        return_adj: if True, return the permuted dense adjacency instead of edges.

    Returns:
        ``(ordered_adj, ordered_feats)`` when ``return_adj`` is True; otherwise
        ``(ordered_edges, ordered_weights, ordered_feats)``, where
        ``ordered_edges`` is a ``LongTensor`` of relabelled node ids.
    """
    adj_matrix = EdgeList2AdjMatrix(edges, weights)
    n = adj_matrix.shape[0]

    Z = hierarchy.ward(adj_matrix)
    opt_ordering = hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(Z, adj_matrix))

    ordered_feats = node_feats[opt_ordering, :]
    if return_adj:
        # Row i of the permutation matrix selects old node opt_ordering[i], so
        # (P @ A @ P.T)[i, j] == A[opt_ordering[i], opt_ordering[j]].
        # (The previous code stacked an undefined name `arrays` -> NameError.)
        permutation_matrix = np.eye(n)[opt_ordering]
        ordered_adj = permutation_matrix @ adj_matrix @ permutation_matrix.T
        return ordered_adj, ordered_feats

    # Inverse permutation: new_id[old] = position of `old` in opt_ordering.
    new_id = th.empty(n, dtype=th.long)
    new_id[th.as_tensor(opt_ordering, dtype=th.long)] = th.arange(n)
    # Long dtype: PyG convolutions require a long edge_index (was int32).
    ordered_edges = new_id[th.as_tensor(edges).long()]
    # Relabelling nodes does not reorder the edge list, so each weight stays
    # aligned with its edge. The previous `weights[opt_ordering]` indexed
    # per-edge weights with node ids, which was wrong.
    ordered_weights = weights

    return ordered_edges, ordered_weights, ordered_feats

### ...

In [None]:
# ---- Notebook switches ----
LOAD_GRAPHS = True          # True: unpickle graphs from save_path; False: rebuild them from the series in data.pkl
USE_BINARY_FEATURES = True  # True: node features are the binary encoding of each permutation; False: all-ones features
ORDER_EDGES = False         # True: relabel nodes of every graph by the optimal leaf ordering

# ---- Pre-processing ----
ord_net_dim = 5  # ordinal-pattern length (sliding-window size) used to build each ordinal network

In [None]:
# Computing graph parameters and creating node features
import math

# One node per possible ordinal pattern. (np.math was an alias of the stdlib
# math module and was removed in NumPy 2.0.)
num_nodes = math.factorial(ord_net_dim)
# Bits needed to write the largest permutation symbol (ord_net_dim) in binary.
# The previous ceil(ord_net_dim / 2) was too small for ord_net_dim = 2 or 4
# (e.g. symbol 4 needs 3 bits, not 2); for ord_net_dim = 5 both give 3.
num_bits = int(ord_net_dim).bit_length()
feat_dim = num_bits * ord_net_dim
save_path = "graphs"

if USE_BINARY_FEATURES:
    save_path += "_binary_feats"
    node_feats = create_bin_features(ord_net_dim, num_bits)
else:
    node_feats = th.ones([num_nodes, feat_dim])

save_path += ".pkl"


In [None]:
# Load the raw series/labels, then either load or build the ordinal networks.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted,
# locally generated files.
with open("data.pkl", "rb") as f:
    series_data, labels, true_pos = pkl.load(f)

if LOAD_GRAPHS:
    with open(save_path, "rb") as f:
        graphs = pkl.load(f)
else:
    graphs = create_graph_list(series_data, ord_net_dim)

if ORDER_EDGES:
    # NOTE(review): this stores 3-element [edges, weights, feats] entries, but the
    # training loop appears to unpack each graph as (edges, weights) and always
    # uses the shared node_feats. Confirm before enabling ORDER_EDGES.
    for idx, (edges, weights) in enumerate(tqdm(graphs)):
        ordered_edges, ordered_weights, ordered_feats = optimal_ordering(edges, weights, node_feats)
        graphs[idx] = [ordered_edges, ordered_weights, ordered_feats]

## Data Analysis

In [None]:
# Default (width, height) in inches for every figure drawn below.
plt.rcParams.update({"figure.figsize": (10, 5)})

In [None]:
# For each selected label, overlay the edge-weight histograms of all graphs
# carrying that label (one figure per label).
labels = np.array(labels)

idx = 1              # running counter used in legend entries; keeps counting across labels
label_list = [0, 30]  # label values to plot

for i in tqdm(label_list):
    idxs = np.nonzero(labels == i)[0]  # positions of graphs with label i
    tgt_weights = list(map(lambda pos: graphs[pos][1], idxs))

    for weights in tgt_weights:
        plt.hist(weights.numpy(), bins=20, label=f"graph({i})_{idx}")
        idx += 1
    plt.legend()
    plt.show()

## Training GNNs

### Classes and functions

In [None]:
class GCN(th.nn.Module):
    """Stack of ``GCNConv`` layers, a graph-level readout and a linear head.

    Args:
        layers: feature sizes ``[in_dim, hidden_1, ..., hidden_k]``. One
            ``GCNConv`` is built per consecutive pair, i.e. ``len(layers) - 1``.
        out_dim: output size of the final linear layer (defaults to ``layers[-1]``).
        skip_connect: if True, every layer's node embeddings are pooled, stacked
            and pooled again, instead of pooling only the last layer. All hidden
            sizes must then be equal.
        **layer_kwargs: forwarded unchanged to every ``GCNConv``.
    """

    def __init__(self, layers: list, out_dim: int = None, skip_connect: bool = False, **layer_kwargs):
        super().__init__()
        self.num_layers = len(layers)  # number of feature sizes (= conv layers + 1)
        self.convs = th.nn.ModuleList(
            GCNConv(in_channels, out_channels, **layer_kwargs)
            for in_channels, out_channels in zip(layers[:-1], layers[1:])
        )
        self.out_dim = out_dim if out_dim is not None else layers[-1]
        self.linear = th.nn.Linear(layers[-1], self.out_dim)
        self.skip_connect = skip_connect

    def forward(self, x: Tensor, edge_index: Tensor, edge_weights: Tensor, agg_func, **agg_kwargs) -> Tensor:
        """Embed one graph.

        Args:
            x: node feature matrix ``[num_nodes, layers[0]]``.
            edge_index: graph connectivity ``[2, num_edges]``.
            edge_weights: per-edge weights, passed as GCNConv's ``edge_weight``.
            agg_func: readout called as ``agg_func(h, **agg_kwargs)`` on node
                embeddings (e.g. a PyG aggregation with ``dim=0``).

        Returns:
            ``self.linear`` applied to the flattened pooled embedding.
        """
        embeddings = []
        for conv_layer in self.convs:
            x = conv_layer(x, edge_index, edge_weights).relu()
            embeddings.append(x)

        if self.skip_connect:
            # One pooled vector per layer, stacked to [num_conv_layers, hidden].
            x = th.stack([agg_func(embedding, **agg_kwargs).flatten() for embedding in embeddings])

        return self.linear(agg_func(x, **agg_kwargs).flatten())

class GAT(th.nn.Module):
    """Stack of ``GATConv`` layers, a graph-level readout and a linear head.

    Args:
        layers: feature sizes ``[in_dim, hidden_1, ..., hidden_k]``. One
            ``GATConv`` is built per consecutive pair, i.e. ``len(layers) - 1``.
            Each entry is the per-head ``out_channels`` of that layer.
        out_dim: output size of the final linear layer (defaults to ``layers[-1]``).
        skip_connect: if True, every layer's node embeddings are pooled, stacked
            and pooled again. All layer output widths must then be equal.
        **layer_kwargs: forwarded unchanged to every ``GATConv`` (e.g. ``heads``,
            ``concat``, ``edge_dim``).
    """

    def __init__(self, layers: list, out_dim: int = None, skip_connect: bool = False, **layer_kwargs):
        super().__init__()
        self.num_layers = len(layers)  # number of feature sizes (= conv layers + 1)
        # With concat=True (GATConv default), a layer with `heads` heads outputs
        # out_channels * heads features. The next layer and the linear head must
        # be sized for that. Previously they were sized from `layers` alone,
        # which crashed for heads > 1. With heads=1 the sizes are unchanged.
        head_factor = layer_kwargs.get("heads", 1) if layer_kwargs.get("concat", True) else 1
        self.convs = th.nn.ModuleList()
        in_channels = layers[0]
        for out_channels in layers[1:]:
            self.convs.append(GATConv(in_channels, out_channels, **layer_kwargs))
            in_channels = out_channels * head_factor
        self.out_dim = out_dim if out_dim is not None else layers[-1]
        # `in_channels` now holds the width of the last layer's output.
        self.linear = th.nn.Linear(in_channels, self.out_dim)
        self.skip_connect = skip_connect

    def forward(self, x: Tensor, edge_index: Tensor, edge_weights: Tensor, agg_func, **agg_kwargs) -> Tensor:
        """Embed one graph.

        Args:
            x: node feature matrix ``[num_nodes, layers[0]]``.
            edge_index: graph connectivity ``[2, num_edges]``.
            edge_weights: passed as GATConv's ``edge_attr``.
            agg_func: readout called as ``agg_func(h, **agg_kwargs)`` on node
                embeddings (e.g. a PyG aggregation with ``dim=0``).

        Returns:
            ``self.linear`` applied to the flattened pooled embedding.
        """
        embeddings = []
        for conv_layer in self.convs:
            x = conv_layer(x, edge_index, edge_attr=edge_weights).relu()
            embeddings.append(x)

        if self.skip_connect:
            # One pooled vector per layer, stacked along a new leading axis.
            x = th.stack([agg_func(embedding, **agg_kwargs).flatten() for embedding in embeddings])

        return self.linear(agg_func(x, **agg_kwargs).flatten())
    
        
class tripletLoss(th.nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    ``loss = mean(relu(||a - p||^2 - ||a - n||^2 + margin))``. Also returns
    how many triplets have the positive strictly closer than the negative.
    """

    def __init__(self, margin):
        super(tripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, pos, neg):
        """Return ``(mean loss, number of correctly ordered triplets)``.

        Inputs are either single embeddings ``[D]`` or batches ``[B, D]``.
        """
        # Reduce over the embedding (last) axis only, so a batch [B, D] yields
        # one distance per triplet. The previous .sum() collapsed the whole
        # batch into a single distance. Results for 1-D inputs are unchanged.
        distance_pos = (anchor - pos).pow(2).sum(dim=-1)
        distance_neg = (anchor - neg).pow(2).sum(dim=-1)
        loss = th.nn.functional.relu(distance_pos - distance_neg + self.margin)
        return loss.mean(), self.triplet_correct(distance_pos, distance_neg)

    def triplet_correct(self, d_pos, d_neg):
        """Count triplets whose positive distance is strictly below the negative one."""
        return (d_pos < d_neg).sum()

### ...

In [None]:
# ---- Training configuration ----
USE_GPU = False                     # run on CUDA only when this is True *and* a GPU is available
num_epochs = 100                    # epochs over the training triplets, per cross-validation fold
batch_size = 1                      # divisor for the logged per-step triplet accuracy; the loop feeds one triplet per step
gnn_layers_dim = [feat_dim, 32, 32]  # [input feature size, hidden sizes...] of the GNN
skip_connect = False                # If True will add skip-connections to GCN (make sure all layer dimensions are the same!)
# NOTE(review): skip_connect is not passed to the model constructor in the setup cell.


In [None]:
device = th.device("cuda") if (USE_GPU and th.cuda.is_available()) else th.device("cpu")

# Graph-level readout. With learn=True it owns a learnable parameter that is NOT
# part of model.parameters(). It must be moved to the device and handed to the
# optimizer explicitly; otherwise it is never trained.
gnn_aggr_func = aggr.SoftmaxAggregation(learn=True).to(device)

# model = GAT(gnn_layers_dim,heads=2,edge_dim=1).to(device)
# `project` is a SAGEConv option, not a GCNConv one, so it is not forwarded here.
model = GCN(gnn_layers_dim).to(device)
print(model)

criterion = tripletLoss(margin=1).to(device)
# `th` is the module alias imported above (`torch` itself is not imported).
# NOTE(review): weight_decay=0.2 is a very strong L2 penalty; it may contribute
# to the embeddings collapsing to zero reported in the notes. Worth tuning.
optimizer = th.optim.SGD(
    list(model.parameters()) + list(gnn_aggr_func.parameters()),
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.2,
)
test_indexes = cross_validation_sample(50, 10)

node_feats = node_feats.to(device)

In [None]:
def _embed(graph):
    """Run the model on one ``(edges, weights)`` graph and return its embedding."""
    edges, weights = graph
    return model(node_feats, edges.to(device), weights.to(device), gnn_aggr_func, dim=0)


def _embed_triplet(triplet):
    """Embed the anchor, positive and negative graphs of a triplet, in that order."""
    return _embed(triplet[0]), _embed(triplet[1]), _embed(triplet[2])


def _reset_training_state():
    """Re-initialise every module with ``reset_parameters`` and drop optimizer state."""
    for module in itertools.chain(model.modules(), gnn_aggr_func.modules()):
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()
    optimizer.state.clear()  # discard SGD momentum buffers from the previous fold


for fold, test_index in enumerate(test_indexes):
    if fold > 0:
        # Start every fold from a fresh model. Otherwise later folds are
        # evaluated on graphs the model already trained on in earlier folds.
        _reset_training_state()

    # split training & testing
    print("Test indexes: ", test_index)
    train_x, train_y, test_x, test_y = split_colocation_train(graphs, labels, test_index, 'room')
    train_x = gen_colocation_triplet(train_x, train_y)
    test_x = gen_colocation_triplet(test_x, test_y)

    total_triplets = len(train_x)
    # max(1, ...) avoids `step % 0` when there are fewer than 10 triplets.
    logging_step = max(1, total_triplets // 10)
    print("Total training triplets: %d\n" % (total_triplets))

    # Training Loop
    for epoch in tqdm(range(num_epochs)):
        total_triplet_correct = 0
        np.random.shuffle(train_x)
        model.train()
        for step, batch_x in enumerate(train_x):
            anchor_pred, pos_pred, neg_pred = _embed_triplet(batch_x)

            loss, triplet_correct = criterion(anchor_pred, pos_pred, neg_pred)
            total_triplet_correct += triplet_correct.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % logging_step == 0 and step != 0:
                print("loss " + str(loss.item()) + "\n")
                print("triplet_acc " + str(triplet_correct.item() / batch_size) + "\n")

        print(f"Epoch={epoch+1}")
        print("Triplet accuracy: %f" % (total_triplet_correct / total_triplets))

    # Testing Loop
    model.eval()
    with th.no_grad():
        total_test_triplets = len(test_x)
        total_triplet_correct = 0
        for batch_x in tqdm(test_x):
            anchor_pred, pos_pred, neg_pred = _embed_triplet(batch_x)
            _, triplet_correct = criterion(anchor_pred, pos_pred, neg_pred)
            total_triplet_correct += triplet_correct.item()

        # Normalise by the number of *test* triplets (was the training count).
        print("Test Triplet accuracy: %f" % (total_triplet_correct / max(1, total_test_triplets)))
        