In [1]:
import os
import pdb
import random
import datetime
import argparse
import json
from collections import namedtuple
import scipy.sparse
from sklearn.preprocessing import StandardScaler
import dgl
import numpy as np
import torch
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.utils as u

from mh_aug.utils import n_f1_score
from mh_aug.model import GCN, MLP, GAT, SAGE, AGG_NET
from mh_aug.preprocess_nc import load_nodes
from sklearn.metrics import accuracy_score
from scipy.stats import truncnorm
from scipy.special import betaln

import torch.autograd.profiler as profiler

import pickle
import wandb
import time

# Model / training loop adapted from MH-Aug

Call chain in the original MH-Aug code:
`main.py: main()` → `epi.py: episode()` → `train()` → `augment()`

In [2]:
def logging_time(original_fn):
    """Decorator that prints how long each call to ``original_fn`` took.

    The elapsed time is reported in seconds. The original implementation
    multiplied it by 100, so the printed value matched neither the "sec" label
    nor milliseconds. ``functools.wraps`` keeps the wrapped function's name and
    docstring intact.

    Args:
        original_fn: the callable to time.

    Returns:
        A wrapper that forwards all arguments and returns ``original_fn``'s result.
    """
    import functools

    @functools.wraps(original_fn)
    def wrapper_fn(*args, **kwargs):
        # perf_counter is monotonic and meant for measuring intervals.
        start_time = time.perf_counter()
        result = original_fn(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print("WorkingTime[{}]: {} sec".format(original_fn.__name__, elapsed))
        return result
    return wrapper_fn


class HLoss(nn.Module):
    """Shannon entropy of ``softmax(x)`` taken along dim 1.

    ``forward(x)`` returns the entropy summed over all rows and divided by the
    number of rows (a scalar). ``forward(x, full=True)`` returns the per-row
    entropies instead.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, full=False):
        plogp = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        if not full:
            # Batch mean: total entropy over every row, divided by the row count.
            return -plogp.sum() / x.shape[0]
        # Per-row entropy, one value per row of x.
        return -plogp.sum(1)

class XeLoss(nn.Module):
    """Batch-mean KL divergence ``KL(softmax(y) || softmax(x))`` along dim 1.

    Although named like a cross-entropy, the target's own entropy is subtracted,
    so the loss is 0 when ``x`` and ``y`` give the same distribution.
    Call as ``loss(y, x)``: ``y`` supplies the target distribution and ``x``
    the prediction.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        log_pred = F.log_softmax(x, dim=1)
        log_target = F.log_softmax(y, dim=1)
        target = F.softmax(y, dim=1)
        diff = target * log_pred - target * log_target
        # Negated total over all entries, averaged over the rows of x.
        return -1.0 * diff.sum() / x.shape[0]
    
class Jensen_Shannon(nn.Module):
    """Batch-mean *symmetrised* KL divergence between ``softmax(y)`` and ``softmax(x)``.

    Computes ``0.5 * (KL(Py || Px) + KL(Px || Py))`` summed over all rows and
    divided by the row count.

    NOTE(review): despite the class name, this is not the Jensen-Shannon
    divergence, which compares each distribution to the mixture
    ``M = (Py + Px) / 2``. The behaviour is left unchanged on purpose.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        p_y, log_y = F.softmax(y, dim=1), F.log_softmax(y, dim=1)
        p_x, log_x = F.softmax(x, dim=1), F.log_softmax(x, dim=1)
        forward_terms = p_y * log_x - p_y * log_y    # -KL(Py || Px) integrand
        backward_terms = p_x * log_y - p_x * log_x   # -KL(Px || Py) integrand
        total = forward_terms + backward_terms
        return -0.5 * total.sum() / x.shape[0]

def our_truncnorm(a, b, mu, sigma, x=None, mode='pdf'):
    """Evaluate or sample a normal ``N(mu, sigma^2)`` truncated to ``[a, b]``.

    ``scipy.stats.truncnorm`` expects its clip values in standard-deviation
    units relative to ``loc``, so the raw bounds are standardised first.

    Args:
        a, b: lower / upper truncation bounds in the original units.
        mu, sigma: mean and standard deviation of the untruncated normal.
        x: point(s) at which to evaluate the density (required for ``mode='pdf'``).
        mode: ``'pdf'`` to return the density at ``x``; ``'rvs'`` to draw one sample.
            ``'rvs'`` draws from SciPy's default random state.

    Returns:
        The pdf value(s) at ``x`` for ``'pdf'``, or a single random draw for ``'rvs'``.

    Raises:
        ValueError: for an unknown ``mode`` (previously returned ``None`` silently),
            or when ``mode='pdf'`` is requested without ``x``.
    """
    std_a, std_b = (a - mu) / sigma, (b - mu) / sigma
    if mode == 'pdf':
        if x is None:
            raise ValueError("our_truncnorm(mode='pdf') requires x")
        return truncnorm.pdf(x, std_a, std_b, loc=mu, scale=sigma)
    if mode == 'rvs':
        return truncnorm.rvs(std_a, std_b, loc=mu, scale=sigma)
    raise ValueError("mode must be 'pdf' or 'rvs', got {!r}".format(mode))
    
def aggregate(features, edge_index, agg_model, num_hop):
    """Apply ``agg_model`` to ``features`` over ``edge_index`` plus one self-loop per node.

    ``features`` has one row per node; its row count sets ``num_nodes`` for the
    self-loops. ``num_hop`` is accepted for call-site compatibility but is not
    used here. Presumably the hop count is fixed when ``agg_model`` is
    constructed (TODO confirm).
    """
    looped = u.add_self_loops(edge_index, num_nodes=features.shape[0])
    edges_with_self_loops = looped[0]
    return agg_model(features, edges_with_self_loops)

def log_normal(a, b, sigma):
    """Unnormalised Gaussian log-density ``-(a - b)^2 / (2 * sigma^2)``.

    The ``-log(sigma * sqrt(2*pi))`` normaliser is deliberately omitted, as in
    the original. It cancels whenever two densities being compared share the
    same ``sigma``.

    ``**`` replaces ``torch.pow``, which raised TypeError for a Python-scalar
    ``sigma``. For tensor inputs the result is unchanged, and ``a``, ``b`` and
    ``sigma`` may be any broadcast-compatible mix of tensors and scalars.
    """
    return -((a - b) ** 2) / (2 * sigma ** 2)

def augment(args, org_edge_index, org_feature, delta_G_e, delta_G_v):
    """Draw one random edge-drop + node-feature-drop augmentation of a graph.

    Args:
        args: unused; kept for call-site compatibility.
        org_edge_index: edge index whose columns are edges (presumably ``(2, m)``).
        org_feature: node feature matrix with one row per node.
        delta_G_e: fraction of edges to drop (``int(m * delta_G_e)`` are removed).
        delta_G_v: fraction of nodes whose feature rows are zeroed
            (``int(n * delta_G_v)`` rows).

    Returns:
        ``(aug_edge_index, aug_feature, node_list)``:
        - ``aug_edge_index`` is a random column subset of ``org_edge_index``.
        - ``aug_feature`` is a copy of ``org_feature`` with the dropped rows zeroed
          and the rest rescaled.
        - ``node_list`` is an ``(n, 1)`` float indicator: 1 for kept nodes, 0 for dropped.

    Random permutations are created on each input's own device. They were
    previously hard-coded to ``'cuda'``, which failed on CPU and for mixed-device
    indexing.
    """
    m = org_edge_index.shape[1]
    num_edge_drop = int(m * delta_G_e)
    # Keep a random subset of m - num_edge_drop edge columns.
    keep_edges = torch.randperm(m, device=org_edge_index.device)[:m - num_edge_drop]
    aug_edge_index = org_edge_index[:, keep_edges]

    n = org_feature.shape[0]
    num_node_drop = int(n * delta_G_v)
    aug_feature = org_feature.clone()  # never mutate the caller's features
    node_list = torch.ones(n, 1, device=org_feature.device)
    dropped_nodes = torch.randperm(n, device=org_feature.device)[:num_node_drop]
    aug_feature[dropped_nodes] = 0
    node_list[dropped_nodes] = 0

    # Inverted-dropout style rescale of the surviving rows by n / (n - num_node_drop).
    # Skip it when nothing was dropped, or when everything was: n / 0 used to raise
    # ZeroDivisionError, and the rows are all zero anyway.
    if 0 < num_node_drop < n:
        aug_feature *= n / (n - num_node_drop)
    return aug_edge_index, aug_feature, node_list

In [3]:
# Compute device plus the loss operators shared by the rest of the notebook.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

hard_xe_loss_op = nn.CrossEntropyLoss()  # hard-label cross-entropy
soft_xe_loss_op = XeLoss()               # batch-mean KL(softmax(y) || softmax(x))
h_loss_op = HLoss()                      # softmax entropy (batch mean, or per row with full=True)
js_loss_op = Jensen_Shannon()            # half the symmetrised KL between the two softmaxes

In [9]:
# Experiment configuration as a plain dict, exposed as an immutable namedtuple `args`.
# NOTE(review): 'dataset' says 'CORA', but the data-loading cell below reads data/ppi.
a = {
    'seed': 1234, 'use_seed': True, 'dataset': 'CORA',
    'num_epochs': 7000, 'max_epochs': 2000, 'num': 2, 'print': 1,
    'wandb': 0, 'wandb_name': 'ours_manyparam',
    'model_name': 'GCN', 'emb_dim': 64, 'num_layers': 2,
    'lr': 0.005, 'decay': 0.0005, 'dropout': 0.5, 'att_dropout': 0.5, 'num_heads': 8,
    'a_e': 100, 'b_e': 1, 'a_v': 100, 'b_v': 1, 'kl': 2.0, 'h': 0.2,
    'sigma_delta_e': 0.03, 'sigma_delta_v': 0.03, 'mu_e': 0.6, 'mu_v': 0.2,
    'lam1_e': 1, 'lam1_v': 1, 'lam2_e': 0.0, 'lam2_v': 0.0, 'option_loss': 0,
}

# `namedtuple` is already imported in the first cell; no need to re-import it here.
A = namedtuple('a', a)
args = A(**a)

# `seed` / `use_seed` were defined but never applied. Seed every RNG the notebook
# draws from: Python, NumPy (SciPy's truncnorm.rvs uses NumPy's global state by
# default), and torch (randperm in augment()).
if args.use_seed:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

In [10]:
args.emb_dim  # sanity check: this value is later used as hidden_channels

64

In [5]:
# features, edge_index, train_index, train_label, valid_index, valid_label, test_index, test_label, num_classes = load_nodes(args)

# Data from GraphSAINT (PPI dataset, read from `data/ppi`)

In [6]:
# Load the PPI dataset in GraphSAINT format.
# `args` is deliberately NOT reset here. The previous `args = None` clobbered the
# config namedtuple, so every later `args.*` access failed on Restart & Run All.
multilabel = True  # class_map values are per-class binary vectors

# GraphSAINT ships its datasets in `graphsaintdata/`; this notebook reads them from `data/`.
if not os.path.exists('graphsaintdata') and not os.path.exists('data'):
    raise ValueError("The directory graphsaintdata does not exist!")
elif os.path.exists('graphsaintdata') and not os.path.exists('data'):
    os.rename('graphsaintdata', 'data')

# NOTE(review): the dataset is hard-coded to PPI; args.dataset is not consulted.
prefix = "data/ppi"
DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g'])

# Full graph (all splits) as a boolean sparse adjacency, also wrapped as a DGL graph.
adj_full = scipy.sparse.load_npz('./{}/adj_full.npz'.format(prefix)).astype(bool)
g = dgl.from_scipy(adj_full)
num_nodes = g.num_nodes()

# Training nodes = rows with at least one entry in the training adjacency.
# np.unique gives a deterministic, sorted id array (list(set(...)) did not).
adj_train = scipy.sparse.load_npz('./{}/adj_train.npz'.format(prefix)).astype(bool)
train_nid = np.unique(adj_train.nonzero()[0])

# Boolean split masks from role.json ('tr' / 'va' / 'te' node-id lists).
with open('./{}/role.json'.format(prefix)) as role_file:
    role = json.load(role_file)
mask = np.zeros((num_nodes,), dtype=bool)
train_mask = mask.copy()
train_mask[role['tr']] = True
val_mask = mask.copy()
val_mask[role['va']] = True
test_mask = mask.copy()
test_mask[role['te']] = True

# Standardise features with statistics from training nodes only, so val/test stats do not leak.
feats = np.load('./{}/feats.npy'.format(prefix))
scaler = StandardScaler()
scaler.fit(feats[train_nid])
feats = scaler.transform(feats)

# JSON object keys are strings; convert them back to integer node ids.
with open('./{}/class_map.json'.format(prefix)) as class_map_file:
    class_map = {int(k): v for k, v in json.load(class_map_file).items()}

if multilabel:
    # Multi-label binary classification: one 0/1 column per class.
    num_classes = len(list(class_map.values())[0])
    class_arr = np.zeros((num_nodes, num_classes))
    for k, v in class_map.items():
        class_arr[k] = v
else:
    # Single-label: class_map values are integer class ids.
    num_classes = max(class_map.values()) - min(class_map.values()) + 1
    class_arr = np.zeros((num_nodes,))
    for k, v in class_map.items():
        class_arr[k] = v


In [7]:
# Convert the NumPy data to tensors and build the per-split index / label tensors.
features = torch.tensor(feats, dtype=torch.float)
labels = torch.tensor(class_arr, dtype=torch.float if multilabel else torch.long)
bn = False

# COO edge list as int64, the dtype PyTorch Geometric expects for edge_index.
# The old path (torch.Tensor -> float32 -> int32) pushed node ids through float32,
# and was duplicated on a mis-indented line that raised IndentationError.
edge_index = torch.from_numpy(np.vstack(adj_full.nonzero())).long()

train_mask_t = torch.from_numpy(train_mask)
valid_mask_t = torch.from_numpy(val_mask)
test_mask_t = torch.from_numpy(test_mask)

# *_index stay Python lists of bools (as before) for compatibility with later cells;
# the labels are selected with the boolean mask tensors directly.
train_index = train_mask_t.tolist()
train_label = labels[train_mask_t]
valid_index = valid_mask_t.tolist()
valid_label = labels[valid_mask_t]
test_index = test_mask_t.tolist()
test_label = labels[test_mask_t]


In [11]:
# Move the model inputs and the per-split labels onto the compute device.
features, edge_index = features.to(device), edge_index.to(device)
train_label, valid_label, test_label = (
    split.to(device) for split in (train_label, valid_label, test_label)
)

# Model / optimiser hyper-parameters taken from the config namedtuple.
in_channels = features.shape[1]
hidden_channels, num_layers, dropout = args.emb_dim, args.num_layers, args.dropout
learning_rate, weight_decay = args.lr, args.decay

In [12]:
# Aggregation network used to measure per-node "ego" mass; this notebook only
# evaluates it in eval() mode under no_grad. `.to(device)` replaces the
# hard-coded `.cuda()` so the notebook also runs on CPU.
agg_model = AGG_NET(num_hop=num_layers).to(device)
agg_model.eval()
with torch.no_grad():
    # Reference mass on the original graph (all-ones input); later ratios divide by this.
    org_ego = aggregate(torch.ones(features.shape[0], 1, device=device), edge_index, agg_model, args.num_layers)

# Backbone classifier selected by name. An unknown name used to leave `model`
# undefined (NameError at the optimiser line); now it fails loudly here.
if args.model_name == 'GCN':
    model = GCN(in_channels, hidden_channels, num_classes, num_layers, dropout).to(device)
elif args.model_name == 'GAT':
    model = GAT(in_channels, hidden_channels, num_classes, num_layers, dropout, args.num_heads, args.att_dropout).to(device)
elif args.model_name == 'SAGE':
    model = SAGE(in_channels, hidden_channels, num_classes, num_layers, dropout).to(device)
elif args.model_name == 'MLP':
    model = MLP(in_channels, hidden_channels, num_classes, num_layers, dropout).to(device)
else:
    raise ValueError("Unknown model_name {!r}; expected 'GCN', 'GAT', 'SAGE' or 'MLP'".format(args.model_name))
opt_model = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Best-so-far tracking; -1 is a sentinel for "not evaluated yet".
best_valid_acc, best_valid_f1, best_test_acc, best_test_f1, best_epoch = -1, -1, -1, -1, -1

In [13]:
# Start the MH chain from the un-augmented graph: nothing dropped, every node kept.
aug_feature = features
aug_edge_index = edge_index
aug_node_list = torch.ones(features.shape[0], 1, device=device)

In [14]:
num_node, num_edge = features.shape[0], edge_index.shape[1]

# Drop rates of the current sample: fraction of edges removed and fraction of nodes zeroed.
delta_G_e = 1 - aug_edge_index.shape[1] / num_edge
delta_G_v = 1 - aug_node_list.sum().item() / num_node

# Propose new drop rates from truncated normals on [0, 1] centred on the current ones.
# The edge draw stays before the node draw so the RNG stream is consumed in the same order.
delta_G_e_aug = our_truncnorm(0, 1, delta_G_e, args.sigma_delta_e, mode='rvs')
delta_G_v_aug = our_truncnorm(0, 1, delta_G_v, args.sigma_delta_v, mode='rvs')

In [15]:
# Sample the candidate augmented graph using the proposed drop rates.
aug_edge_index2, aug_feature2, node_list2 = augment(
    args,
    edge_index,
    features,
    delta_G_e_aug,
    delta_G_v_aug,
)

In [16]:
# One Metropolis-Hastings step over augmentation strength. The current sample
# (aug_edge_index / aug_node_list, drop rates delta_G_e / delta_G_v) is scored against
# the proposed one (aug_edge_index2 / node_list2, drop rates delta_G_e_aug / delta_G_v_aug).
model.train()
output = model(features, edge_index)  # logits on the original (un-augmented) graph
aug_edge_index = aug_edge_index.to(device)  # make sure the current sample's edges are on `device`

# Per-node "ego" change: 1 - (aggregated mass on the augmented graph / org_ego, the mass
# on the original graph). It is 0 where nothing changed; presumably it grows as more is
# removed (TODO confirm AGG_NET is a non-negative aggregation).
# Edge terms propagate all-ones over the augmented edges; node terms propagate the
# kept-node indicator over the original edges.
feat_ones = torch.ones(num_node, 1, device = device)
with torch.no_grad():
    delta_g_e     = 1 - (aggregate(feat_ones,  aug_edge_index,agg_model,  args.num_layers) / org_ego).squeeze(1) 
    delta_g_aug_e = 1 - (aggregate(feat_ones,  aug_edge_index2,agg_model, args.num_layers) / org_ego).squeeze(1)
    delta_g_v     = 1 - (aggregate(aug_node_list,  edge_index,agg_model,      args.num_layers) / org_ego).squeeze(1)  # current sample's node-drop change
    delta_g_aug_v = 1 - (aggregate(node_list2, edge_index,agg_model,      args.num_layers) / org_ego).squeeze(1)


# Normalised predictive entropy per row: softmax entropy of `output` divided by the
# entropy of a uniform distribution over output.shape[1] classes, i.e. its maximum.
# NOTE(review): multilabel=True in this notebook, yet the entropy uses a softmax over classes; confirm this is intended.
max_ent = h_loss_op(torch.full((1, output.shape[1]), 1 / output.shape[1])).item()
ent = h_loss_op(output.detach(), True) / max_ent


# Unnormalised target log-density (log_normal omits the normaliser): ego changes should
# sit near mu_e / mu_v, with width sigma = a * ent + b. High-entropy rows therefore pay a
# smaller penalty for deviating. p and p_aug share `ent`, so the omitted normaliser cancels.
p     = args.lam1_e * log_normal(delta_g_e,     args.mu_e, args.a_e * ent + args.b_e) + \
        args.lam1_v * log_normal(delta_g_v,     args.mu_v, args.a_v * ent + args.b_v)
p_aug = args.lam1_e * log_normal(delta_g_aug_e, args.mu_e, args.a_e * ent + args.b_e) + \
        args.lam1_v * log_normal(delta_g_aug_v, args.mu_v, args.a_v * ent + args.b_v)

# Log proposal terms:
# - q: truncated-normal log-pdf of the current drop rates given the proposed ones as means.
# - q_aug: the reverse, the proposed rates given the current ones.
# Each also adds a log-Beta term of its own state's drop rate, weighted by lam2_*
# (0.0 in the config cell).
q     = np.log(our_truncnorm(0, 1, delta_G_e_aug, args.sigma_delta_e, x=delta_G_e, mode='pdf')) + \
        args.lam2_e * betaln(num_edge - num_edge * delta_G_e + 1, num_edge * delta_G_e + 1) + \
        np.log(our_truncnorm(0, 1, delta_G_v_aug, args.sigma_delta_v, x=delta_G_v, mode='pdf')) + \
        args.lam2_v * betaln(num_node - num_node * delta_G_v + 1, num_node * delta_G_v + 1)
q_aug = np.log(our_truncnorm(0, 1, delta_G_e, args.sigma_delta_e, x=delta_G_e_aug, mode='pdf')) + \
        args.lam2_e * betaln(num_edge - num_edge * delta_G_e_aug + 1, num_edge * delta_G_e_aug + 1) + \
        np.log(our_truncnorm(0, 1, delta_G_v, args.sigma_delta_v, x=delta_G_v_aug, mode='pdf')) + \
        args.lam2_v * betaln(num_node - num_node * delta_G_v_aug + 1, num_node * delta_G_v_aug + 1)


# Log MH acceptance ratio: [sum log p(proposed) - sum log p(current)] - (q_aug - q).
# NOTE(review): the accept/reject comparison (e.g. against log U(0, 1)) is not in this cell.
acceptance = ( (torch.sum(p_aug) - torch.sum(p))  - (q_aug - q) )

# Metric placeholders; not updated in this cell.
f1 = 0
acc = 0

In [17]:
acceptance  # display the log acceptance ratio computed in the previous cell

tensor(-1.0884, device='cuda:0')