In [None]:
import numpy as np
import pandas as pd
import json
import os
import datetime

from tqdm import tqdm

np.random.seed(314159) # set random seed

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from torch_geometric.data import Data
import torch_geometric.loader

import wandb

import models
from data_utils import read_data, create_data
import model_utils

In [None]:
base_path = '../data'

# Registry of model factories from `models`, keyed by the name used to select
# an architecture in the training cell below.
model_creator_dict = dict(
    SGConv=models.create_SGConv_GNN,
    GraphSAGE=models.create_GraphSAGE_GNN,
    TAG=models.create_TAG_GNN,
    ClusterGCN=models.create_clusterGCN_GNN,
    MLP=models.create_MLP,
)

In [None]:
# Load the node features, the edge list and the label table.
# `labels` presumably holds the `label_<trial>` columns that run_trials selects — confirm in data_utils.
# NOTE(review): label_thres uses a comma decimal ('0,02'); presumably this matches
# read_data's file-naming convention rather than being parsed as a float — confirm.
node_dataset, edge_list, labels = read_data(
    base_path=base_path,
    graph_used='know',
    feats_type='nodeonly',
    label_thres='0,02',
)

In [None]:
# Model input width: one input feature per column of the node feature table.
num_features = len(node_dataset.columns)
# Binary classification (the ROC code in run_trials scores class index 1).
num_classes = 2

In [None]:
import os
# wandb reads WANDB_NOTEBOOK_NAME to identify the running notebook, since it
# cannot always auto-detect it from inside Jupyter.
notebook_name = 'train_gnn_model.ipynb'
os.environ.update(WANDB_NOTEBOOK_NAME=notebook_name)

In [None]:
import gc
from sklearn import metrics

def _test_roc(model, data, device):
    """Compute the ROC curve and AUC of ``model`` on the test nodes of ``data``.

    Moves ``model`` and ``data`` to ``device`` and scores column 1 of the first
    element returned by the model. Assumes the model returns a 3-tuple whose first
    element is a ``(num_nodes, num_classes)`` score tensor with column 1 the
    positive class — TODO confirm against ``models.LitGNN.forward``.

    Returns:
        dict with keys ``'fpr'``, ``'tpr'``, ``'thresholds'``, ``'auc'``, plus the
        legacy misspelled key ``'threholds'`` (same array) for existing readers.
    """
    model.to(device)
    # load_from_checkpoint does not switch to eval mode; without this, dropout
    # (if any) would stay active while scoring.
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        logits, _, _ = model(data)

    preds = logits[data.test_mask][:, 1].cpu().numpy()
    y = data.y[data.test_mask].cpu().numpy()

    fpr, tpr, thresholds = metrics.roc_curve(y, preds)
    auc_score = metrics.roc_auc_score(y, preds)

    return {'fpr': fpr, 'tpr': tpr,
            'thresholds': thresholds,
            'threholds': thresholds,  # legacy typo key kept so saved-file readers keep working
            'auc': auc_score}


def _release_cuda_cache():
    """Print CUDA memory usage, release cached blocks, and print usage again.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    print('memory allocated: ', torch.cuda.memory_allocated())
    print('memory reserved: ', torch.cuda.memory_reserved())
    torch.cuda.empty_cache()
    print('\nafter empty_cache:')
    print('memory allocated: ', torch.cuda.memory_allocated())
    print('memory reserved: ', torch.cuda.memory_reserved())


def run_trials(create_model, start_trial=0, end_trial=100, n_epochs=500, log=False, log_project=None):
    """Train and evaluate a fresh model for every trial in ``[start_trial, end_trial]`` (inclusive).

    For each trial ``t`` a new train/val/test split is built from the module-level
    ``node_dataset``, ``edge_list`` and ``labels`` using the ``f'label_{t}'`` column.
    A new model is trained with PyTorch Lightning, and the checkpoint with the highest
    ``val_acc`` is reloaded through ``models.LitGNN``. That model is then evaluated
    with ``model_utils.evaluate_model``, and a ROC curve is computed on the test nodes.

    Args:
        create_model: zero-argument callable returning a new, untrained model
            (presumably a ``models.LitGNN``, since checkpoints are reloaded through it).
        start_trial: first trial index (inclusive).
        end_trial: last trial index (inclusive).
        n_epochs: maximum number of training epochs per trial.
        log: if True, each trial is logged to Weights & Biases via ``WandbLogger``.
        log_project: W&B project name; asked for interactively when ``log`` is
            True and this is None.

    Returns:
        ``(train_reports, test_reports, roc_data)``: three lists with one entry per
        trial. Each ``roc_data`` entry is a dict with ``'fpr'``, ``'tpr'``,
        ``'thresholds'`` (also under the legacy key ``'threholds'``) and ``'auc'``.
    """
    if log and log_project is None:
        print('Enter the name of the log project: ')
        log_project = input()

    # Summarise one throwaway instance so the architecture is printed once.
    summary_model = create_model()
    model_summary = pl.utilities.model_summary.summarize(summary_model, max_depth=4)
    model_summary_str = str(model_summary)
    num_trainable_params = model_summary.trainable_parameters
    del summary_model

    print(model_summary_str)

    # Loop invariants: hardware, run-name width and notebook to upload.
    avail_gpus = min(1, torch.cuda.device_count())
    device = torch.device('cuda' if avail_gpus > 0 else 'cpu')
    # Pad trial numbers to the width of the largest index so run names sort
    # correctly; keep at least 2 digits so names for end_trial < 100 are unchanged.
    n_zfills = max(2, len(str(end_trial)))
    notebook_to_save = os.environ.get('WANDB_NOTEBOOK_NAME', 'modeling_gnn.ipynb')

    train_reports = []
    test_reports = []
    roc_data = []

    for trial in tqdm(range(start_trial, end_trial + 1)):

        print(f'running trial {trial}')
        data = create_data(node_dataset, edge_list, labels, f'label_{trial}', test_size=0.2, val_size=0.1)

        model = create_model()

        if log:
            log_name = f'{log_project}_trial{str(trial).zfill(n_zfills)}'
            logger = WandbLogger(name=log_name, project=log_project, log_model="all", save_dir='wandb_projects')

            logger.log_metrics({'model_summary_str': model_summary_str,
                                'num_trainable_params': num_trainable_params})

            # log the random train-val-test split so the trial can be reproduced
            logger.log_metrics({'train_mask': data.train_mask, 'val_mask': data.val_mask, 'test_mask': data.test_mask})
        else:
            logger = False

        # NOTE(review): os.cpu_count() workers for a one-graph dataset mostly adds
        # worker start-up overhead each epoch; consider num_workers=0.
        data_loader = torch_geometric.loader.DataLoader([data], batch_size=1, num_workers=os.cpu_count())

        # NOTE(review): `gpus=` is the PyTorch Lightning 1.x API (removed in 2.0).
        trainer = pl.Trainer(
            callbacks=[ModelCheckpoint(save_weights_only=False, mode="max", monitor="val_acc")],
            gpus=avail_gpus,
            max_epochs=n_epochs,
            logger=logger,
            enable_model_summary=False,
        )

        # The same single-graph loader is passed as train and val loader; the
        # masks on `data` presumably select the nodes per stage (TODO confirm in models.LitGNN).
        trainer.fit(model, data_loader, data_loader)

        # Evaluate the best-val_acc checkpoint, not the last epoch's weights.
        model = models.LitGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

        train_report, test_report = model_utils.evaluate_model(model, data, logger=logger)
        train_reports.append(train_report)
        test_reports.append(test_report)

        roc = _test_roc(model, data, device)
        roc_data.append(roc)

        if log:
            logger.log_metrics({'auc_test': roc['auc'], 'fpr_test': roc['fpr'],
                                'tpr_test': roc['tpr'], 'roc_thres': roc['thresholds']})
            wandb.save(notebook_to_save)
            wandb.finish(quiet=True)

        # Drop per-trial objects before the next trial to limit memory growth.
        del model, data_loader, trainer, data
        gc.collect()
        _release_cuda_cache()

    return train_reports, test_reports, roc_data

In [None]:
## TRAIN AND EVALUATE MODEL
# Select an architecture from `model_creator_dict` and train/evaluate it on the
# label columns `label_{START_TRIAL}` .. `label_{END_TRIAL}` (inclusive).

model_name = 'SGConv'
create_model = model_creator_dict[model_name]

START_TRIAL = 0
END_TRIAL = 0
N_EPOCHS = 250
LOG_TO_WANDB = False
REPORTS_DIR = 'project_reports'

# The W&B project and the output file prefix share the model name.
log_project_name = model_name

# Create the output directory up front so a long training run cannot fail
# only at the final save.
os.makedirs(REPORTS_DIR, exist_ok=True)

# run multiple trials
train_reports, test_reports, roc_data = run_trials(
    lambda: create_model(model_name, num_features, num_classes),
    start_trial=START_TRIAL, end_trial=END_TRIAL,
    n_epochs=N_EPOCHS, log=LOG_TO_WANDB, log_project=log_project_name)

# save reports from trials to json
model_utils.save_reports(f'{REPORTS_DIR}/{log_project_name}_reports', train_reports, test_reports)
# roc_data is a list of dicts, so NumPy stores it as a pickled object array;
# reload it with np.load(..., allow_pickle=True).
np.save(f'{REPORTS_DIR}/{log_project_name}_roc', roc_data)