# Imports

In [30]:
%load_ext autoreload
%autoreload 2

import torch
from torch.functional import norm
import wandb
import os

import graph
from graph.submission import save_models, compute_save_submission
from graph.training import ultimate_unsupervised_training, ultimate_supervised_training, ultimate_unsupervised_training2, ultimate_supervised_training2
from graph.gcn import GCN, GAT, GIN, GIN_shared_weights, GIN_GRU
from graph.load_data import ultimate_dataloader, init_nodes_embedding 
from graph.classifier import Classifier_Dense, Classifier_RNN, style_Dense
from graph.config import Config
from graph.auc_loss import ROC_LOSS, ROC_STAR_LOSS

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Config

In [31]:
# Run switches read by the cells below.
# NOTE(review): under IPython the name `display` hides the built-in display()
# helper in this notebook's namespace — consider renaming if rich display is needed.

# Send metrics to Weights & Biases
log = False
# Interactive display during training
display = False
# Save models and predictions at the end of the run
save = False
compute_predictions = False

In [43]:
# All hyperparameters of the run, wrapped in `Config` so that the following
# cells can read them as attributes (config.<key>).
config = Config({
    # --- Structure ---
    "gnn_type": GIN,
    "classifier_type": style_Dense,
    "depth_dense_block": 1,

    # --- Embedding dims and GCN ---
    "init_embedding_method": "eigen2014",
    "gnn_hidden_dims": [24, 32, 32],

    # --- Classifier ---
    # The leading None is a placeholder that the "Networks" cell overwrites
    # with the computed classifier input dimension.
    "classifier_hidden_dims": [None, 32, 16, 2],
    "classifier_activation": "Relu",

    # Years 2007..2011 inclusive
    "unsupervised_training_years": list(range(2011 - 5 + 1, 2011 + 1)),

    # --- Unsupervised training loader ---
    "minimal_degs_unsupervised": 5,

    # --- Unsupervised generator ---
    "iter_number_unsupervised": 250,
    "max_neighbors_unsupervised": 5,
    "k_init": 1,
    "starting_size": 32,

    # --- Unsupervised training ---
    # NOTE(review): this key is named "epochs_supervised" although it is grouped
    # with the unsupervised settings; the training calls visible in this notebook
    # pass literal epoch counts instead of reading it — confirm the intended name.
    "epochs_supervised": 1,
    "batch_size_unsupervised": 1,

    # --- Unsupervised optimizer ---
    "optimizer_algo_unsupervised": "Adam",
    "lr_unsupervised": 5e-4,

    # --- Supervised training loader ---
    "delta_years": 1,
    "min_year": 2011 - 5,
    "batch_size_supervised": 1,
    "pairs_batch_size": 2 ** 15,
    "proportion_dataset": 2,
    "minimal_degs_supervised": 5,
    "test_rate": 1,
    "max_size_dataset": 1_000_000 // 2,

    # --- Supervised training ---
    "add_diag": False,
    "normalize_adj_matrix": True,
    "loss_fn": "CrossEntropy",
    "iter_number_supervised": 300,

    # --- Supervised optimizers ---
    "optimizer_algo_supervised_gnn": "Adam",
    "optimizer_algo_supervised_classifier": "Adam",
    "lr_supervised_gnn": 1e-3,
    "lr_supervised_classifier": 1e-3,

    # --- Regularisation ---
    "drop": 0.3,
    "batch_norm": False,

    "early_stop": False,
})

# One less than the number of entries in gnn_hidden_dims; passed as k_max to
# the unsupervised training call.
depth = len(config.gnn_hidden_dims) - 1

In [33]:

# Weights & Biases setup, only performed when logging is switched on.
if log:
    wandb.init(project="HyperparametersTesting", entity="argocs", config=config)

    # Two independent step counters, one per training phase
    for step_axis in ("unsupervised step", "supervised step"):
        wandb.define_metric(step_axis)

    # The GNN loss is plotted against the unsupervised counter ...
    wandb.define_metric("loss GNN", step_metric="unsupervised step")

    # ... and every classifier metric against the supervised counter
    for metric_name in (
        "train AUC",
        "test AUC",
        "train loss classifier",
        "test loss classifier",
    ):
        wandb.define_metric(metric_name, step_metric="supervised step")


# Initialisation

## Loaders

In [34]:
# Build the data loader from the dataset-related settings in `config`.
loader = ultimate_dataloader(
    min_year=config.min_year,
    proportion_dataset=config.proportion_dataset,
    minimal_degs_unsupervised=config.minimal_degs_unsupervised,
    minimal_degs_supervised=config.minimal_degs_supervised,
    add_diag=config.add_diag,
    normalize_adj_matrix=config.normalize_adj_matrix,
    max_size=config.max_size_dataset,
)


Graph for year 2006 has 106391 edges
Graph for year 2007 has 141614 edges
Graph for year 2008 has 192874 edges
Graph for year 2009 has 278820 edges
Graph for year 2010 has 427201 edges
Graph for year 2011 has 651997 edges
Graph for year 2012 has 1042310 edges
Graph for year 2013 has 1582114 edges
Graph for year 2014 has 2278611 edges
Graph for year 2015 has 3342061 edges
Graph for year 2016 has 5018339 edges
Graph for year 2017 has 7652945 edges
Ratio links:  0.33457075758220584
Ratio links:  0.33464769715749354
Ratio links:  0.3346030271752818
Ratio links:  0.33473731811560564
Ratio links:  0.33469842938362576
Ratio links:  0.3347151565440057
Ratio links:  0.33479405376269483
Ratio links:  0.33492510787508106
Ratio links:  0.33516523108332613


## Initial embeddings

In [35]:

# Load the raw data; only the first of the four returned values is used here.
graph_sparse, _, _, _ = graph.data_utils.load_data("CompetitionSet2017_3.pkl")

# One extracted graph per year, from config.min_year up to and including 2014
graphs_years = [
    graph.data_utils.extract_graph(graph_sparse, year)
    for year in range(config.min_year, 2014 + 1)
]

# Initial node embeddings, sized to the first GNN layer and moved to the compute device
initial_embedding = init_nodes_embedding(
    config.init_embedding_method,
    config.gnn_hidden_dims[0],
    graphs_years,
    config.minimal_degs_unsupervised,
    normalize=True,
).to(graph.device)

# Drop the intermediates now that the embeddings are built
del graph_sparse, graphs_years, _


Graph for year 2006 has 106391 edges
Graph for year 2007 has 141614 edges
Graph for year 2008 has 192874 edges
Graph for year 2009 has 278820 edges
Graph for year 2010 has 427201 edges
Graph for year 2011 has 651997 edges
Graph for year 2012 has 1042310 edges
Graph for year 2013 has 1582114 edges
Graph for year 2014 has 2278611 edges
Graph for year 2014 has 2278611 edges


## Networks

In [44]:
# Graph encoder, instantiated from the class stored in the config.
gnn = config.gnn_type(config.gnn_hidden_dims, config.depth_dense_block).to(graph.device)

# Optimizer over the GNN weights, built with the unsupervised learning rate
unsupervised_optimizer = torch.optim.Adam(gnn.parameters(), lr=config.lr_unsupervised)

# Classifier input width, derived from the last GNN layer width and delta_years;
# it replaces the None placeholder at the head of classifier_hidden_dims.
input_dim_classifier = config.gnn_hidden_dims[-1] * config.delta_years * 2
config.classifier_hidden_dims[0] = input_dim_classifier

# Activation class handed to the classifier constructor
if config.classifier_activation == "Relu":
    activation_type = torch.nn.ReLU
elif config.classifier_activation == "PRelu":
    activation_type = torch.nn.PReLU
elif config.classifier_activation == "Gelu":
    activation_type = torch.nn.GELU
else:
    # Raise rather than call exit(): in a notebook exit() ends the kernel session
    # and throws away everything loaded so far.
    raise ValueError(f"The following activation is not implemented: {config.classifier_activation}")

# Each classifier class takes a different set of constructor arguments
if config.classifier_type == Classifier_Dense:
    classifier = config.classifier_type(config.classifier_hidden_dims, activation_type).to(graph.device)
elif config.classifier_type == Classifier_RNN:
    classifier = config.classifier_type(config.classifier_hidden_dims, activation_type, config.gnn_hidden_dims[-1]).to(graph.device)
elif config.classifier_type == style_Dense:
    classifier = config.classifier_type(config.classifier_hidden_dims, activation_type, config.drop, config.batch_norm).to(graph.device)
else:
    # Fail here instead of leaving `classifier` undefined (or stale from a previous run)
    raise ValueError(f"Unsupported classifier type: {config.classifier_type}")

# Loss function
if config.loss_fn == "CrossEntropy":
    loss_fn = torch.nn.CrossEntropyLoss()
elif config.loss_fn == "ROC_loss":
    loss_fn = ROC_LOSS(2048, 2048)
elif config.loss_fn == "ROC_star_loss":
    loss_fn = ROC_STAR_LOSS(2048, 2048, 0.4)
else:
    # Fail here instead of leaving `loss_fn` undefined (or stale from a previous run)
    raise ValueError(f"Unsupported loss function: {config.loss_fn}")

# Optimizers for the supervised phase: one for the GNN, one for the classifier
optimizer_gnn = torch.optim.Adam(gnn.parameters(), lr=config.lr_supervised_gnn)
optimizer_classifier = torch.optim.Adam(classifier.parameters(), lr=config.lr_supervised_classifier)


# Training

In [39]:
# Loss function for the supervised phase, selected by config.loss_fn.
# NOTE(review): ROC_LOSS is built with (1024, 1024) here but with (2048, 2048) in
# the "Networks" cell; when the notebook runs top to bottom this cell wins —
# confirm which sizes are intended.
if config.loss_fn == "CrossEntropy":
    loss_fn = torch.nn.CrossEntropyLoss()
elif config.loss_fn == "ROC_loss":
    loss_fn = ROC_LOSS(1024, 1024)
elif config.loss_fn == "ROC_star_loss":
    loss_fn = ROC_STAR_LOSS(2048, 2048, 0.4)
else:
    # Fail here instead of silently keeping whatever loss_fn was defined earlier
    raise ValueError(f"Unsupported loss function: {config.loss_fn}")

## Unsupervised

In [12]:

# Unsupervised training of the GNN.
# Uses `unsupervised_optimizer`, which is built with config.lr_unsupervised;
# `optimizer_gnn` carries the supervised learning rate and is the optimizer passed
# to the supervised training call.
ultimate_unsupervised_training2(
    epochs=1,
    loader=loader,
    initial_embedding=initial_embedding,
    model=gnn,
    optimizer=unsupervised_optimizer,
    batch_size=config.batch_size_unsupervised,
    iter_number=config.iter_number_unsupervised,
    max_neighbors=config.max_neighbors_unsupervised,
    k_init=config.k_init,
    k_max=depth,
    starting_size=config.starting_size,
    display=display,
    log=log,
    years=config.unsupervised_training_years,
)

NameError: name 'ultimate_unsupervised_training2' is not defined

## Supervised

In [None]:
# Supervised training of the GNN together with the classifier, each with its own optimizer.
ultimate_supervised_training2(
    epochs=10,
    loader=loader,
    initial_embedding=initial_embedding,
    gnn=gnn,
    classifier=classifier,
    optimizer_gnn=optimizer_gnn,
    optimizer_classifier=optimizer_classifier,
    loss_fn=loss_fn,
    pairs_batch_size=config.pairs_batch_size,
    batch_size=config.batch_size_supervised,
    delta_years=config.delta_years,
    test_rate=config.test_rate,
    log=log,
    early_stop=config.early_stop,
)


######## Epoch 0 ########

    [ 1  ]  Loss: 0.69410658   AUC: 0.51229151   from 2011 to 2014
Loss: 0.68349308   AUC: 0.81308704   from 2014 to 2017
    [ 2  ]  Loss: 0.68877310   AUC: 0.53289972   from 2011 to 2014
Loss: 0.67321855   AUC: 0.85914926   from 2014 to 2017
    [ 3  ]  Loss: 0.68438685   AUC: 0.54892944   from 2011 to 2014
Loss: 0.66377193   AUC: 0.86531341   from 2014 to 2017
    [ 4  ]  Loss: 0.68021560   AUC: 0.56665866   from 2011 to 2014
Loss: 0.65433103   AUC: 0.86649889   from 2014 to 2017
    [ 5  ]  Loss: 0.67619848   AUC: 0.58017877   from 2011 to 2014
Loss: 0.64486492   AUC: 0.86625695   from 2014 to 2017
    [ 6  ]  Loss: 0.67241079   AUC: 0.60284591   from 2011 to 2014
Loss: 0.63670367   AUC: 0.86564395   from 2014 to 2017
    [ 7  ]  Loss: 0.66966504   AUC: 0.60935301   from 2011 to 2014
Loss: 0.63013721   AUC: 0.86544890   from 2014 to 2017
    [ 8  ]  Loss: 0.66746700   AUC: 0.61629646   from 2011 to 2014
Loss: 0.62410414   AUC: 0.86466252   from 2014 to 2

In [14]:
######## Epoch 9 ########

    [406 ]  Loss: 0.38940576   AUC: 0.87545571   from 2011 to 2014
Loss: 0.28202456   AUC: 0.86608686   from 2014 to 2017
    [407 ]  Loss: 0.38752237   AUC: 0.87595749   from 2011 to 2014
Loss: 0.23246711   AUC: 0.86577067   from 2014 to 2017
    [408 ]  Loss: 0.38071540   AUC: 0.87857586   from 2011 to 2014
Loss: 0.22305581   AUC: 0.86580216   from 2014 to 2017
    [409 ]  Loss: 0.38408196   AUC: 0.87695855   from 2011 to 2014
Loss: 0.26694751   AUC: 0.86610463   from 2014 to 2017
    [410 ]  Loss: 0.38202500   AUC: 0.87561480   from 2011 to 2014
Loss: 0.28722411   AUC: 0.86627575   from 2014 to 2017
    [411 ]  Loss: 0.38520452   AUC: 0.87501928   from 2011 to 2014
Loss: 0.25047794   AUC: 0.86600841   from 2014 to 2017
    [412 ]  Loss: 0.37661877   AUC: 0.88143097   from 2011 to 2014
Loss: 0.22454691   AUC: 0.86572004   from 2014 to 2017
    [413 ]  Loss: 0.37827235   AUC: 0.88111980   from 2011 to 2014
Loss: 0.24380083   AUC: 0.86577940   from 2014 to 2017
    [414 ]  Loss: 0.38138583   AUC: 0.87787096   from 2011 to 2014
Loss: 0.27157584   AUC: 0.86598785   from 2014 to 2017
    [415 ]  Loss: 0.37969592   AUC: 0.87884873   from 2011 to 2014
Loss: 0.25880554   AUC: 0.86581884   from 2014 to 2017
    [416 ]  Loss: 0.37899423   AUC: 0.87980500   from 2011 to 2014
Loss: 0.23837371   AUC: 0.86575702   from 2014 to 2017
    [417 ]  Loss: 0.38393492   AUC: 0.87721140   from 2011 to 2014
Loss: 0.23906098   AUC: 0.86578760   from 2014 to 2017
    [418 ]  Loss: 0.37971771   AUC: 0.87968952   from 2011 to 2014
Loss: 0.26338875   AUC: 0.86584039   from 2014 to 2017
    [419 ]  Loss: 0.37687910   AUC: 0.88092934   from 2011 to 2014
Loss: 0.26886997   AUC: 0.86584609   from 2014 to 2017
    [420 ]  Loss: 0.37981567   AUC: 0.87996631   from 2011 to 2014
Loss: 0.24891324   AUC: 0.86584995   from 2014 to 2017
    [421 ]  Loss: 0.37758258   AUC: 0.87985499   from 2011 to 2014
Loss: 0.24172667   AUC: 0.86576143   from 2014 to 2017
    [422 ]  Loss: 0.37794462   AUC: 0.87984030   from 2011 to 2014
Loss: 0.25602397   AUC: 0.86576759   from 2014 to 2017
    [423 ]  Loss: 0.37453911   AUC: 0.88156152   from 2011 to 2014
Loss: 0.25556752   AUC: 0.86565790   from 2014 to 2017
    [424 ]  Loss: 0.37681270   AUC: 0.88018798   from 2011 to 2014
Loss: 0.24654341   AUC: 0.86561310   from 2014 to 2017
    [425 ]  Loss: 0.38100138   AUC: 0.87853309   from 2011 to 2014
Loss: 0.25414518   AUC: 0.86561591   from 2014 to 2017
    [426 ]  Loss: 0.38046166   AUC: 0.87708491   from 2011 to 2014
Loss: 0.25612000   AUC: 0.86560556   from 2014 to 2017
    [427 ]  Loss: 0.38164866   AUC: 0.87641780   from 2011 to 2014
Loss: 0.25474119   AUC: 0.86560742   from 2014 to 2017
    [428 ]  Loss: 0.37588879   AUC: 0.88129129   from 2011 to 2014
Loss: 0.25717217   AUC: 0.86555495   from 2014 to 2017
    [429 ]  Loss: 0.37848443   AUC: 0.87969525   from 2011 to 2014
Loss: 0.25201192   AUC: 0.86550903   from 2014 to 2017
    [430 ]  Loss: 0.37927601   AUC: 0.87830079   from 2011 to 2014
Loss: 0.25405392   AUC: 0.86549656   from 2014 to 2017
    [431 ]  Loss: 0.37859219   AUC: 0.87976378   from 2011 to 2014
Loss: 0.25582281   AUC: 0.86554405   from 2014 to 2017
    [432 ]  Loss: 0.37450203   AUC: 0.88224845   from 2011 to 2014
Loss: 0.25676119   AUC: 0.86555555   from 2014 to 2017
    [433 ]  Loss: 0.37878054   AUC: 0.87922063   from 2011 to 2014
Loss: 0.25017780   AUC: 0.86548981   from 2014 to 2017
    [434 ]  Loss: 0.37575665   AUC: 0.88288118   from 2011 to 2014
Loss: 0.26532903   AUC: 0.86560644   from 2014 to 2017
    [435 ]  Loss: 0.37301591   AUC: 0.88343982   from 2011 to 2014
Loss: 0.24702546   AUC: 0.86549374   from 2014 to 2017
    [436 ]  Loss: 0.37733170   AUC: 0.88078200   from 2011 to 2014
Loss: 0.24853857   AUC: 0.86546214   from 2014 to 2017
    [437 ]  Loss: 0.37963885   AUC: 0.88015505   from 2011 to 2014
Loss: 0.27275607   AUC: 0.86547809   from 2014 to 2017
    [438 ]  Loss: 0.38045532   AUC: 0.87830018   from 2011 to 2014
Loss: 0.26119304   AUC: 0.86548477   from 2014 to 2017
    [439 ]  Loss: 0.37723398   AUC: 0.87921944   from 2011 to 2014
Loss: 0.23649520   AUC: 0.86546606   from 2014 to 2017
    [440 ]  Loss: 0.37882456   AUC: 0.88185787   from 2011 to 2014
Loss: 0.28253663   AUC: 0.86556429   from 2014 to 2017
    [441 ]  Loss: 0.37908894   AUC: 0.87951339   from 2011 to 2014
Loss: 0.25570810   AUC: 0.86559075   from 2014 to 2017
    [442 ]  Loss: 0.37679118   AUC: 0.88092705   from 2011 to 2014
Loss: 0.24622680   AUC: 0.86553669   from 2014 to 2017
    [443 ]  Loss: 0.37788200   AUC: 0.87887339   from 2011 to 2014
Loss: 0.25996459   AUC: 0.86550344   from 2014 to 2017
    [444 ]  Loss: 0.37553257   AUC: 0.88115480   from 2011 to 2014
Loss: 0.26508048   AUC: 0.86552361   from 2014 to 2017
    [445 ]  Loss: 0.37949449   AUC: 0.87935699   from 2011 to 2014
Loss: 0.25547317   AUC: 0.86547976   from 2014 to 2017
    [446 ]  Loss: 0.37540254   AUC: 0.88181100   from 2011 to 2014
Loss: 0.25042370   AUC: 0.86549248   from 2014 to 2017
    [447 ]  Loss: 0.37682334   AUC: 0.88031912   from 2011 to 2014
Loss: 0.26584384   AUC: 0.86547431   from 2014 to 2017
    [448 ]  Loss: 0.37305185   AUC: 0.88203347   from 2011 to 2014
Loss: 0.26846296   AUC: 0.86562339   from 2014 to 2017
    [449 ]  Loss: 0.37082076   AUC: 0.88348329   from 2011 to 2014
Loss: 0.24273875   AUC: 0.86560457   from 2014 to 2017
    [450 ]  Loss: 0.37719601   AUC: 0.88256487   from 2011 to 2014
Loss: 0.26250067   AUC: 0.86572425   from 2014 to 2017

SyntaxError: invalid syntax (<ipython-input-14-6d87f4e48e3a>, line 1)

# Save models

In [28]:
# Shuffle the two entries of loader.test_dataset with the same random stream.
# NOTE(review): this import sits outside the notebook's import cell, and no seed is
# set, so the resulting order differs from run to run.
import numpy as np
# Snapshot the global NumPy RNG state before the first shuffle ...
rng_state = np.random.get_state()
loader.test_dataset[0] = np.random.permutation(loader.test_dataset[0])
# ... and restore it so the second call draws the same random sequence. The two
# entries stay aligned only if they have the same length — presumably they are
# samples and their labels; TODO confirm.
np.random.set_state(rng_state)
loader.test_dataset[1] = np.random.permutation(loader.test_dataset[1])

In [None]:

# Run index: the number of entries already present in the "Models" folder.
# Falls back to 0 when the folder does not exist yet, so this cell no longer fails
# on a fresh checkout even though both flags are off.
c = len(os.listdir("Models")) if os.path.isdir("Models") else 0

if save:
    save_models(initial_embedding, gnn, classifier, name=c)

if compute_predictions:
    # Ensure the target folder exists, including when `save` is False
    os.makedirs(f"Models/{c}", exist_ok=True)
    compute_save_submission(initial_embedding, gnn, classifier, config.delta_years, 2, filename=f"Models/{c}/preds.pt")
