In [7]:
import gzip
import pickle
import time

import optuna
from optuna.samplers import TPESampler

import torch
from torch_geometric.utils import scatter
from torch_geometric.data import DataLoader
import torch.optim as optim

import import_ipynb
from dataset import get_data
from constants import *
from utils import *

In [8]:
# train GNN

def _run_epoch(model, dataloader, loss_fn, device, optimizer=None, grad_clip=1.0):
    """Make one pass over `dataloader` and return (node, graph, errp) root losses.

    If `optimizer` is given, the pass trains the model. The model is put in
    train mode and, per batch, the pass calls zero_grad / backward /
    clip_grad_norm_(grad_clip) / step. Otherwise the model is put in eval mode
    and gradients are disabled.

    Each returned value is sqrt(sum_of_batch_losses / len(dataloader.dataset)),
    with batch losses computed by `loss_fn` (sum-reduced MSE in train()):
      - node:  node predictions vs. batch.y
      - graph: per-graph summed predictions vs. per-graph summed labels
      - errp:  per-graph ratio prediction/label vs. 1
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    node_sum = graph_sum = errp_sum = 0.0
    with torch.set_grad_enabled(training):
        for batch in dataloader:
            batch = batch.to(device)

            node_preds = model(batch, node_regression=True)
            loss = loss_fn(node_preds, batch.y)
            node_sum += loss.item()

            # sum node values per graph using the batch assignment vector
            graph_preds = scatter(node_preds, batch.batch, dim=0, reduce='sum')
            graph_labels = scatter(batch.y, batch.batch, dim=0, reduce='sum')
            fracs = graph_preds / graph_labels
            graph_sum += loss_fn(graph_preds, graph_labels).item()
            errp_sum += loss_fn(torch.ones_like(fracs), fracs).item()

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
    n_graphs = len(dataloader.dataset)
    return (node_sum / n_graphs) ** 0.5, (graph_sum / n_graphs) ** 0.5, (errp_sum / n_graphs) ** 0.5


def train(train_dataset_name, test_dataset_name, hyper_params=None, pre_ckpt_name=None, saving_name=None, gpu_num='cpu', print_epoch=0):
    """Train a GNN node regressor, keep the best-validation weights, then test them.

    The first 80% of `train_dataset_name` is used for training and the rest for
    validation. `test_dataset_name` is only used for the final test. The best
    epoch is selected by validation graph-level loss.

    Args:
        train_dataset_name: name passed to get_data() for the train/val data.
        test_dataset_name: name passed to get_data() for the test data.
        hyper_params: optional dict. Recognised keys and defaults:
            max_epoch=2000, lr=0.002, lr_gamma=0.999 (per-epoch multiplicative
            LR factor), gnn_latent_dim=[128]*6, grad_clip=1.0 (max grad norm).
        pre_ckpt_name: checkpoint name passed to load_model().
        saving_name: if set, the best weights are saved to MODEL_DIR+<name>.pt
            and a log is written to RESULT_DIR+'model_train_log/<name>.txt'.
            '.pt' is appended if missing.
        gpu_num: device spec passed to set_gpu().
        print_epoch: print progress every `print_epoch` epochs (0 = silent).

    Returns:
        (best_val_loss_graph, record, test_loss_graph, test_loss_node,
         best_model_epoch, best_model_state_dict, training_time)
        `record` holds one (train_n, train_g, train_errp, val_n, val_g,
        val_errp) tuple per epoch.

    Raises:
        ValueError: if the 80/20 split leaves the train or validation set empty.
    """
    import copy

    hyper_params = hyper_params or {}
    # later code strips the suffix with [:-3], so require it at the end
    if saving_name and not saving_name.endswith('.pt'):
        saving_name += '.pt'
    record = []

    # initialize gpu
    torch.cuda.empty_cache()
    device = set_gpu(gpu_num)

    # set hyper_params
    max_epoch = hyper_params.get('max_epoch', 2000)
    lr = hyper_params.get('lr', 0.002)
    lr_gamma = hyper_params.get('lr_gamma', 0.999)
    gnn_latent_dim = hyper_params.get('gnn_latent_dim', [128, 128, 128, 128, 128, 128])
    grad_clip = hyper_params.get('grad_clip', 1.0)

    # load data
    data = get_data(train_dataset_name)
    train_num = int(len(data) * 0.8)
    train_data = data[:train_num]
    val_data = data[train_num:]
    if len(train_data) == 0 or len(val_data) == 0:
        raise ValueError(
            f'dataset {train_dataset_name!r} has {len(data)} graphs; '
            'the 80/20 split needs at least one graph on each side')
    test_data = get_data(test_dataset_name)
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    # load model
    model = load_model(pre_ckpt_name, device, gnn_latent_dim=gnn_latent_dim)

    # training setup
    loss_fn = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: lr_gamma)

    best_val_loss_graph = float('inf')
    best_model_epoch = 0
    best_model_state_dict = None

    stime = time.time()
    if print_epoch: print('epoch \t train_loss_n \t train_loss_g \t val_loss_n \t val_loss_g')
    for epoch in range(1, max_epoch + 1):
        train_loss_node, train_loss_graph, train_prediction_errp = _run_epoch(
            model, train_dataloader, loss_fn, device, optimizer=optimizer, grad_clip=grad_clip)
        val_loss_node, val_loss_graph, val_prediction_errp = _run_epoch(
            model, val_dataloader, loss_fn, device)

        # record best weight
        if val_loss_graph < best_val_loss_graph:
            best_val_loss_graph = val_loss_graph
            best_model_epoch = epoch
            # state_dict() returns references to the live parameter tensors;
            # snapshot them or later optimizer steps overwrite the "best" weights
            best_model_state_dict = copy.deepcopy(model.state_dict())

        record.append((train_loss_node, train_loss_graph, train_prediction_errp, val_loss_node, val_loss_graph, val_prediction_errp))
        if print_epoch and epoch % print_epoch == 0: print(f'{epoch} \t {train_loss_node:.6f} \t {train_loss_graph:.6f} \t {val_loss_node:.6f} \t {val_loss_graph:.6f}')

        scheduler.step()
    training_time = time.time() - stime
    if print_epoch: print('training end. time:', training_time)

    if best_model_state_dict is None:
        # no epoch improved on inf (max_epoch < 1, or validation loss always NaN):
        # fall back to the current weights instead of crashing on load_state_dict(None)
        best_model_state_dict = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_model_state_dict)  # load best weight
    if saving_name: torch.save(best_model_state_dict, MODEL_DIR + saving_name)  # save best weight

    # test
    test_loss_node, test_loss_graph, _ = _run_epoch(model, test_dataloader, loss_fn, device)
    if print_epoch:
        print('test loss_n:', test_loss_node)
        print('test loss_g:', test_loss_graph)
        print()

    if saving_name:
        with open(RESULT_DIR + 'model_train_log/' + saving_name[:-3] + '.txt', 'w') as f:
            f.write(f'time: {training_time}, test_loss_graph: {test_loss_graph}, test_loss_node: {test_loss_node}, best_model_epoch: {best_model_epoch}\n')
            f.write('epoch \t train_loss_n \t train_loss_g \t train_errp \t val_loss_n \t val_loss_g \t val_errp\n')
            for i, (tr_n, tr_g, tr_e, va_n, va_g, va_e) in enumerate(record):
                f.write(f'{i+1} {tr_n} {tr_g} {tr_e} {va_n} {va_g} {va_e}\n')

    return best_val_loss_graph, record, test_loss_graph, test_loss_node, best_model_epoch, best_model_state_dict, training_time

In [9]:
# same as test part of train()

def test(model_name, dataset_name, verbose=False, gpu_num='cpu', batch_size=20, gnn_latent_dim=None):
    """Evaluate a saved model on a dataset and print node/graph-level root losses.

    This mirrors the test phase of train(). Each loss is
    sqrt(sum of squared errors / number of graphs in the dataset).

    Args:
        model_name: checkpoint name passed to load_model().
        dataset_name: name passed to get_data().
        verbose: if True, print "seed_sum label pred" for every graph. The
            seed sum is the per-graph sum of batch.x; .item() assumes a single
            feature column (TODO confirm for multi-feature inputs).
        gpu_num: device spec passed to set_gpu().
        batch_size: DataLoader batch size (default 20, as before).
        gnn_latent_dim: forwarded to load_model() when given. This is needed
            for models trained with a non-default architecture. None keeps the
            original call load_model(model_name, device).

    Returns:
        (test_loss_node, test_loss_graph)
    """
    device = set_gpu(gpu_num)
    if gnn_latent_dim is None:
        model = load_model(model_name, device)
    else:
        model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    loss_fn = torch.nn.MSELoss(reduction='sum')

    data = get_data(dataset_name)
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

    model.eval()
    node_sum = 0
    graph_sum = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)

            node_preds = model(batch, node_regression=True)
            node_sum += loss_fn(node_preds, batch.y).item()

            graph_preds = scatter(node_preds, batch.batch, dim=0, reduce='sum')
            graph_labels = scatter(batch.y, batch.batch, dim=0, reduce='sum')
            graph_sum += loss_fn(graph_preds, graph_labels).item()

            if verbose:
                # only needed for the per-graph printout
                seeds = scatter(batch.x, batch.batch, dim=0, reduce='sum')
                for seed, label, pred in zip(seeds, graph_labels, graph_preds):
                    print(seed.item(), label.item(), pred.item())
    n_graphs = len(dataloader.dataset)
    test_loss_node = (node_sum / n_graphs) ** 0.5
    test_loss_graph = (graph_sum / n_graphs) ** 0.5
    print('test loss(node level):', test_loss_node)
    print('test loss(graph level):', test_loss_graph)
    return test_loss_node, test_loss_graph

In [10]:
# hyperparameter tuning by Optuna

def hparam_tuning(*args, **kwargs):
    """Tune `lr` and `lr_gamma` for train() with Optuna's TPE sampler.

    Positional args, and any kwargs not listed below, are forwarded to train()
    unchanged. Kwargs consumed here:
        saving_name: if set, the best trial's weights are saved to
            MODEL_DIR+<name>.pt, its log to RESULT_DIR+'model_train_log/<name>.txt',
            and the gzipped pickled study to RESULT_DIR+'study_obj/<name>.pkl.gz'.
            '.pt' is appended if missing. Individual trials do not save anything.
        hyper_params: fixed hyper-parameters merged into every trial. Keys given
            here override the sampled ones, so passing 'lr' disables tuning it.
        n_trials: number of Optuna trials (default 15).
        seed: seed for TPESampler (default None, i.e. unseeded as before).

    The objective minimized is train()'s best validation graph loss (ret[0]).
    """
    saving_name = kwargs.pop('saving_name', None)  # default avoids NameError below
    if saving_name and not saving_name.endswith('.pt'):
        saving_name += '.pt'
    hyper_params = kwargs.pop('hyper_params', None) or {}
    n_trials = kwargs.pop('n_trials', 15)
    seed = kwargs.pop('seed', None)

    def objective(trial):
        trial_hyper_params = {
            'lr': trial.suggest_float('lr', 1e-3, 1e-1, log=True),
            'lr_gamma': trial.suggest_float('lr_gamma', 0.99, 0.999, log=True),
        }
        trial_hyper_params.update(hyper_params)

        ret = train(*args, **kwargs, hyper_params=trial_hyper_params)
        # keep the full train() result so the best trial's weights/log can be saved afterwards
        trial.set_user_attr("ret", ret)

        return ret[0]

    study = optuna.create_study(sampler=TPESampler(n_startup_trials=5, seed=seed))
    study.optimize(objective, n_trials=n_trials)

    print(f'best_trial value : {study.best_trial.value}')
    print(f'best_params : {study.best_params}')

    if saving_name:
        (best_val_loss_graph, record, test_loss_graph, test_loss_node,
         best_model_epoch, best_model_state_dict, training_time) = study.best_trial.user_attrs["ret"]
        # log the values the best trial actually trained with (fixed hyper_params win)
        lr = hyper_params.get('lr', study.best_params['lr'])
        lr_gamma = hyper_params.get('lr_gamma', study.best_params['lr_gamma'])
        torch.save(best_model_state_dict, MODEL_DIR + saving_name)
        stem = saving_name[:-3]
        with open(RESULT_DIR + 'model_train_log/' + stem + '.txt', 'w') as f:
            f.write(f'time: {training_time}, lr: {lr}, lr_gamma: {lr_gamma}\n')
            f.write(f'best_model_epoch: {best_model_epoch}, best_val_loss_graph: {best_val_loss_graph}\n')
            f.write(f'test_loss_graph: {test_loss_graph}, test_loss_node: {test_loss_node}\n')
            f.write('epoch \t train_loss_n \t train_loss_g \t train_errp \t val_loss_n \t val_loss_g \t val_errp\n')
            for i, (tr_n, tr_g, tr_e, va_n, va_g, va_e) in enumerate(record):
                f.write(f'{i+1} {tr_n} {tr_g} {tr_e} {va_n} {va_g} {va_e}\n')
        with gzip.open(RESULT_DIR + 'study_obj/' + stem + '.pkl.gz', 'wb') as f:
            pickle.dump(study, f, protocol=4)

In [None]:
# influence estimation quality

def evaluate_quality(model_name, gnn_latent_dim, train_dataset_name, gpu_num):
    """Compare per-graph predictions with labels on the validation split.

    This uses the same split as train(): the last 20% of `train_dataset_name`.
    For each validation graph, the prediction is model(graph)[0][0] and the
    label is the sum of graph.y. The function prints the Pearson correlation,
    then every (label, prediction) pair sorted by label.

    Returns:
        (preds, labels): numpy arrays, both sorted by ascending label.

    Raises:
        ValueError: if the validation split is empty.
    """
    # np is not imported explicitly in the notebook's import cell (it may come
    # from a star import); import locally so this function is self-contained
    import numpy as np

    device = set_gpu(gpu_num)

    # load data (same 80/20 split as train())
    dataset = get_data(train_dataset_name)
    train_num = int(len(dataset) * 0.8)
    val_data = dataset[train_num:]
    if len(val_data) == 0:
        raise ValueError(f'dataset {train_dataset_name!r} has no validation graphs')

    # load model
    model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    preds = []
    labels = []

    # validation
    model.eval()
    with torch.no_grad():
        for graph in val_data:
            labels.append(np.sum(graph.y.cpu().numpy()))
            # the model lives on `device`, so its input must too. Clone first so
            # the move cannot alter the objects in val_data (PyG's to() may act
            # in place — NOTE(review): confirm for the installed PyG version)
            preds.append(model(graph.clone().to(device))[0][0].item())
    preds = np.array(preds)
    labels = np.array(labels)

    order = np.argsort(labels)
    labels = labels[order]
    preds = preds[order]

    print('r =', np.corrcoef(preds, labels)[0, 1])
    print('MC GNN')
    for label, pred in zip(labels, preds):
        print(label, pred)
    return preds, labels