In [None]:
import collections
import itertools
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import networkx as nx
import torch

# local
import optimization_utils as ou

import sys
if '..' not in sys.path:
    sys.path.insert(0, '..')


from datasets.import_dataset import import_dataset
from trainer import Trainer
from utils.plotting import *
from utils import utils
from utils import utils_pyg as up
import datasets.simulations as sim
import utils.link_prediction as lp




device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device = {device}')


%load_ext autoreload
%autoreload 2

Training goes as follows:
1. Omit both the validation and the test sets.
2. Collect scores over repeated runs, where the validation set is resampled randomly at every iteration.
3. Take the parameters with the best validation score and train the model on the combined training and validation sets.
4. Report the test score.

In [None]:
# Draft: OGB's evaluator for the ogbl-ddi benchmark.
# NOTE(review): `evaluator` is not referenced by any cell below yet.
from ogb.linkproppred import Evaluator

ddi_evaluator_name = 'ogbl-ddi'
evaluator = Evaluator(name=ddi_evaluator_name)

In [None]:
# DDI cross-validation
# Model / dataset pair to cross-validate.
model_name, ds_name = 'ieclam', 'ogbl-ddi'

# Range / split parameters.
use_global_config_base = True
densify = False
attr_opt = False
test_p = 0.1
val_p = 0.0
n_reps = 10

# Two ways to define the search space (`range_triplets`):
#  (a) explicitly, one [outer_key, inner_key, list_of_values] triplet per parameter, e.g.
#        [['clamiter_init', 's_reg', s_regs], ['clamiter_init', 'l1_reg', l1_regs],
#         ['clamiter_init', 'dim_feat', dim_feats], ['feat_opt', 'n_iter', n_iters_feats],
#         ['feat_opt', 'lr', lr_feats]]
#      plus back_forth / prior_opt triplets for 'pclam' and 'pieclam';
#  (b) as perturbations ("deltas") of the base config, converted by ou.perturb_config.
# Other deltas that can be enabled here: clamiter_init {s_reg, l1_reg, dim_feat},
# feat_opt {n_iter, lr}, prior_opt {n_iter, lr, noise_amp}.
# Known issue (author's note): when the deltas are used, the parameters aren't saved.
deltas = {'back_forth': {'scheduler_gamma': 0.1}}

# NOTE(review): perturb_config is called with 'anomaly_unsupervised' although this cell
# runs link prediction (cross_val_link) — confirm which task's config is intended.
range_triplets = ou.perturb_config(
    'anomaly_unsupervised', model_name, deltas,
    use_global_config=use_global_config_base,
)

# Run cross-validation of `model_name` on `ds_name`.
# TODO (author): add the calculation type and finish the H@K function.
# TODO (author): test that texas still works.
ou.cross_val_link(
    ds_name=ds_name, model_name=model_name,
    range_triplets=range_triplets,
    use_global_config_base=use_global_config_base,
    densify=densify, attr_opt=attr_opt,
    test_p=test_p, val_p=val_p, n_reps=n_reps,
    device=device,
)


In [None]:
# TEXAS cross-validation
# First find parameters on the validation set, then use these parameters to train on
# the union of the training and validation sets.

# Model / dataset pair to cross-validate.
model_name, ds_name = 'pieclam', 'texas'

# Range / split parameters.
use_global_config_base = True
densify = False
attr_opt = False
test_p = 0.1
val_p = 0.05
n_reps = 10

# The search space (`range_triplets`) can be given explicitly, one
# [outer_key, inner_key, list_of_values] triplet per parameter, e.g.
#     [['clamiter_init', 's_reg', s_regs], ['clamiter_init', 'l1_reg', l1_regs],
#      ['clamiter_init', 'dim_feat', dim_feats], ['feat_opt', 'n_iter', n_iters_feats],
#      ['feat_opt', 'lr', lr_feats]]
# plus back_forth / prior_opt triplets for 'pclam' and 'pieclam'; the global parameter
# configuration can also serve as the base. Here the triplets are instead derived from
# perturbations ("deltas") of the base config via ou.perturb_config.
deltas = {
    'clamiter_init': {'s_reg': 0.001, 'l1_reg': 0.001, 'dim_feat': 2},
    'feat_opt': {'n_iter': 50, 'lr': 1e-7},
    'prior_opt': {'n_iter': 50, 'lr': 1e-7, 'noise_amp': 0.005},
    'back_forth': {'scheduler_gamma': 0.1},
}

# NOTE(review): perturb_config is called with 'anomaly_unsupervised' although this cell
# runs link prediction (cross_val_link) — confirm which task's config is intended.
range_triplets = ou.perturb_config(
    'anomaly_unsupervised', model_name, deltas,
    use_global_config=use_global_config_base,
)

# Run cross-validation of `model_name` on `ds_name`.
# TODO: save validation-task results to a validation folder and test-task results to a test folder.
# TODO: split results by metric: /results/task/data/model/metric/{test,val}.
# TODO: how does the accuracy saving work? Add verbose output.
# TODO: make a different accuracy collecting scheme that tells
ou.cross_val_link(
    ds_name=ds_name, model_name=model_name,
    range_triplets=range_triplets,
    use_global_config_base=use_global_config_base,
    densify=densify, attr_opt=attr_opt,
    test_p=test_p, val_p=val_p, n_reps=n_reps,
    device=device,
    verbose=True, verbose_in_funcs=False,
)


In [None]:
# Then perform 10 single runs on the test set with the parameters from the dataset config.
# TODO: make cross_val_link work with the global config.

# TODO: change the AUC dictionary to also report OGB's Hits@K as an accuracy.
# What are the small tasks here? Also need to change to separate test-score and val-score folders.

# How to handle two metrics on the same task? Should each metric get its own folder?

# TODO: folder structure: results/task/dataset/model/metric/{validation,test}



# Single Run

In [None]:


# Single run: train `model_name` on `ds_name` for every combination of values in
# `range_triplets`, `n_reps` times each, and record the last test/validation AUC.
#
# NOTE(review): this cell reuses `test_p`, `val_p`, `n_reps`, `densify`,
# `use_global_config_base` and `range_triplets` from the cells above. After a
# top-to-bottom run they come from the TEXAS cell, whose ranges were built for
# 'pieclam', while this cell trains 'ieclam' — confirm that is intended.
_required = ('test_p', 'val_p', 'n_reps', 'densify', 'use_global_config_base',
             'range_triplets', 'device')
_missing = [name for name in _required if name not in globals()]
assert not _missing, f'run the configuration cells above first; missing: {_missing}'

# `plot_every` was passed to Trainer.train but never defined in this notebook.
# NOTE(review): assumed to be a plotting interval in iterations — confirm the value.
plot_every = 1000


def _last_or_none(history):
    """Return the last entry of a metric history, or None when nothing was recorded."""
    return history[-1] if history else None


# Bound before `try` so the cleanup in `finally` never hits an unbound name.
ds = ds_test_omitted = ds_test_val_omitted = None
# One (config_triplets, last_acc_test, last_acc_val) tuple per run.
single_run_results = []

try:
    ds_name = 'ogbl-ddi'
    model_name = 'ieclam'

    ds = import_dataset(ds_name)
    # Use the dataset's own splits when it provides them; otherwise sample at random.
    val_dyads_to_omit = getattr(ds, 'val_dyads_to_omit', None)
    test_dyads_to_omit = getattr(ds, 'test_dyads_to_omit', None)

    # ============ OMIT TEST (once) =============
    ds_test_omitted = ds.clone()
    if test_dyads_to_omit is None:
        (ds_test_omitted.omitted_dyads_test,
         ds_test_omitted.edge_index,
         ds_test_omitted.edge_attr) = lp.get_dyads_to_omit(ds.edge_index, ds.edge_attr, test_p)
    else:
        assert isinstance(test_dyads_to_omit, tuple)
        assert utils.is_undirected(test_dyads_to_omit[0]) and utils.is_undirected(test_dyads_to_omit[1])
        (ds_test_omitted.omitted_dyads_test,
         ds_test_omitted.edge_index,
         ds_test_omitted.edge_attr) = lp.omit_dyads(ds_test_omitted.edge_index,
                                                    ds_test_omitted.edge_attr,
                                                    test_dyads_to_omit)

    # ============ OMIT PREDEFINED VALIDATION (once) =============
    # A fixed validation split is removed a single time, so every repetition shares it.
    # A random split is resampled inside the loop below instead.
    if val_dyads_to_omit is not None:
        assert isinstance(val_dyads_to_omit, tuple)
        assert utils.is_undirected(val_dyads_to_omit[0]) and utils.is_undirected(val_dyads_to_omit[1])
        (ds_test_omitted.omitted_dyads_val,
         ds_test_omitted.edge_index,
         ds_test_omitted.edge_attr) = lp.omit_dyads(ds_test_omitted.edge_index,
                                                   ds_test_omitted.edge_attr,
                                                   val_dyads_to_omit)

    # Loop-invariant parts of the range triplets.
    outers = [triplet[0] for triplet in range_triplets]
    inners = [triplet[1] for triplet in range_triplets]
    value_ranges = [triplet[2] for triplet in range_triplets]

    for values in itertools.product(*value_ranges):
        for _ in range(n_reps):
            ds_test_val_omitted = ds_test_omitted.clone()

            # ============ OMIT VALIDATION (random split, resampled every repetition) =============
            # edge_attr signifies whether an edge is omitted: an edge_attr of 0 marks an omitted dyad.
            if val_dyads_to_omit is None:
                # The test dyads are already removed, so take val_p / (1 - test_p) of the
                # remaining dyads to get the intended val_p fraction of the original ones.
                (ds_test_val_omitted.omitted_dyads_val,
                 ds_test_val_omitted.edge_index,
                 ds_test_val_omitted.edge_attr) = lp.get_dyads_to_omit(
                    ds_test_val_omitted.edge_index,
                    ds_test_val_omitted.edge_attr,
                    val_p / (1 - test_p))

            if densify:
                ds_test_val_omitted.edge_index, ds_test_val_omitted.edge_attr = up.two_hop_link(ds_test_val_omitted)

            # Built fresh per run so a Trainer that mutates it cannot affect other runs.
            config_triplets = [[outer, inner, value]
                               for outer, inner, value in zip(outers, inners, values)]

            trainer = Trainer(
                dataset=ds_test_val_omitted,
                model_name=model_name,
                task='link_prediction',
                config_triplets_to_change=config_triplets,
                use_global_config_base=use_global_config_base,
                attr_opt=False,
                device=device,
            )

            losses, acc_test, acc_val = trainer.train(
                init_type='small_gaus',
                init_feats=True,
                acc_every=20,
                plot_every=plot_every,
                verbose=False,
                verbose_in_funcs=False,
            )

            last_acc_test = _last_or_none(acc_test['auc'])
            last_acc_val = _last_or_none(acc_val['auc'])
            # `run_saver` is not defined in this notebook, so results are kept in memory.
            # TODO: persist them to the results folder once a saver is wired up here.
            single_run_results.append((config_triplets, last_acc_test, last_acc_val))

            # Drop the per-run dataset before releasing cached GPU memory.
            ds_test_val_omitted = None
            torch.cuda.empty_cache()
finally:
    # Free the datasets (and cached GPU memory) even when a run fails; all three
    # names are guaranteed to be bound because they are initialised before `try`.
    del ds, ds_test_omitted, ds_test_val_omitted
    torch.cuda.empty_cache()
    print('\n\nFinished CrossVal!\n\n')

