In [None]:
from __future__ import division
from __future__ import print_function

import os
import glob
import time
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import scipy.sparse as sp
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import pickle

from sklearn.metrics import f1_score
from collections import defaultdict

from cadae_model import Base_CADAE
from utils import roc_auc

from tqdm import tqdm
import warnings
# Silence every warning category for cleaner notebook output.
warnings.filterwarnings("ignore")


class Args:
    """Hyper-parameters and run options shared by the training and testing cells."""

    # --- optimisation -------------------------------------------------------
    epochs = 400
    lr = 0.004
    weight_decay = 8e-5
    save_model = 'Yes'  # 'Yes' -> the training cell writes the state_dict to disk

    # --- model architecture -------------------------------------------------
    # NOTE(review): gin_layer / mlp_layer are not referenced by the cells shown here.
    gin_layer = 2
    mlp_layer = 2
    learn_eps = True
    neighbor_pooling_type = 'sum'
    input_dim = 2816
    hidden_dim = 256
    str_dec_dim = 128
    emb_dim = 128
    dropout = 0.0

    # --- data ---------------------------------------------------------------
    feature_type = 'imcap_minmax_fea'  # key of the node-feature array inside each pickle

    # --- runtime / loss -----------------------------------------------------
    cuda = True
    device = 0   # CUDA device index, used as "cuda:<device>"
    alpha = 0.6  # weight of the attribute error; structure error gets 1 - alpha


args = Args()

# Console formatting: tensors print with 2 decimals and never in scientific
# notation; numpy prints floats in fixed-point instead of e-notation.
torch.set_printoptions(precision=2, threshold=1000, edgeitems=3,
                       linewidth=150, profile=None, sci_mode=False)
np.set_printoptions(suppress=True)

In [None]:
class CADAE(nn.Module):
    """Wrapper around ``Base_CADAE`` that adds the combined reconstruction loss.

    The base model returns an attribute reconstruction, a structure
    reconstruction and node embeddings.  The loss mixes the two reconstruction
    errors with weight ``alpha`` (attributes) and ``1 - alpha`` (structure).
    """

    def __init__(self, args):
        super(CADAE, self).__init__()
        # NOTE(review): the leading 3, 3, 2 are hard-coded while args.gin_layer /
        # args.mlp_layer exist -- confirm against Base_CADAE's signature.
        self.base_model = Base_CADAE(3, 3, 2, args.input_dim, args.hidden_dim, args.emb_dim,
                                     args.str_dec_dim, args.learn_eps, args.neighbor_pooling_type,
                                     args.dropout, args.device)
        # Stored per instance: a default argument bound to the global ``args``
        # would freeze alpha at class-definition time and ignore later edits.
        self.alpha = args.alpha

    def cadae_forward(self, features, adj_ori):
        """Run the base model; returns (attribute_rec, structure_rec, embeddings)."""
        att_rec, str_rec, embs = self.base_model(features, adj_ori)
        return att_rec, str_rec, embs

    def loss_function(self, features, att_rec, adj_label, str_rec, alpha=None):
        """Weighted attribute + structure reconstruction loss.

        Args:
            features: input node-attribute matrix.
            att_rec: reconstruction of ``features``.
            adj_label: structure target (the training cell passes the adjacency
                matrix plus self-loops).
            str_rec: reconstruction of ``adj_label``.
            alpha: weight of the attribute term; ``None`` uses ``self.alpha``.

        Returns:
            cost: scalar ``alpha * MSE(att) + (1 - alpha) * MSE(str)`` used for training.
            total_reconstruction_error: squared errors summed over dim 1 and
                weighted the same way (one value per row, presumably per node).
        """
        if alpha is None:
            alpha = self.alpha

        att_error = F.mse_loss(features, att_rec, reduction='mean')
        att_rec_error = torch.sum(torch.square(features - att_rec), 1)

        str_error = F.mse_loss(adj_label, str_rec, reduction='mean')
        str_rec_error = torch.sum(torch.square(adj_label - str_rec), 1)

        total_reconstruction_error = alpha * att_rec_error + (1 - alpha) * str_rec_error
        cost = alpha * att_error + (1 - alpha) * str_error

        return cost, total_reconstruction_error

    def loss(self, features, adj_ori, adj_labels):
        """Forward pass followed by ``loss_function``.

        Returns:
            (cost, total_reconstruction_error, embeddings)
        """
        att_recons, str_recons, embeds = self.cadae_forward(features, adj_ori)
        cost, total_reconstruction_error = self.loss_function(features, att_recons, adj_labels, str_recons)

        return cost, total_reconstruction_error, embeds
    

def draw_original(fea, labels):
    """Scatter-plot reconstruction errors: normal samples first, abnormal after.

    Args:
        fea: torch tensor of reconstruction errors, assumed one value per
            sample and aligned with ``labels`` (TODO confirm rank).
        labels: array-like of labels (list or ndarray); 0 = normal,
            1 = abnormal.  Samples with any other label are not plotted.
    """
    errors = fea.cpu().detach().numpy()
    labels = np.asarray(labels).reshape(-1)
    normal_idx = np.flatnonzero(labels == 0)
    abnormal_idx = np.flatnonzero(labels == 1)
    n_normal = normal_idx.size
    n_abnormal = abnormal_idx.size

    _, ax = plt.subplots()
    ax.set_title('Testing Reconstruction Errors')
    # Normal samples occupy x in [0, n_normal); abnormal samples follow, so the
    # two groups are visually separated along the x axis.
    ax.scatter(np.arange(n_normal), errors[normal_idx], c='b', label='normal')
    ax.scatter(np.arange(n_normal, n_normal + n_abnormal), errors[abnormal_idx], c='r',
               label='abnormal')

    ax.set_xlabel('Num of samples')
    ax.set_ylabel('Reconstruction errors')
    ax.legend()
    plt.show()
    

In [None]:
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Glob pattern matching the pickled training graphs.
traindata_list = '/home/suzukilab/zhangkang/graph_experiments/data/hatae_toy_data/hatae_toy_train_pkl/*.pkl'

feature_type = args.feature_type
# str.rstrip('_fea') removes any trailing '_', 'f', 'e', 'a' *characters*, not
# the literal suffix; strip the suffix explicitly instead.
_FEA_SUFFIX = '_fea'
feature_name = feature_type[:-len(_FEA_SUFFIX)] if feature_type.endswith(_FEA_SUFFIX) else feature_type
time_begin = time.time()
loss_values = []  # mean training loss of every epoch

# Training requires a GPU; fail early with a clear message.
assert torch.cuda.is_available(), 'CUDA is required for training'
device = torch.device("cuda:" + str(args.device))

times = []  # wall-clock seconds spent on each epoch
model = CADAE(args).to(device)
model.train()
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad),
                       lr=args.lr,
                       weight_decay=args.weight_decay)

# Each pickle holds one graph (about 10 nodes), used as one batch.
batch_list = glob.glob(traindata_list)
if not batch_list:
    raise FileNotFoundError('No training pickles match {}'.format(traindata_list))
# NOTE(review): no random seed is set, so the shuffle order and weight init vary per run.

for epoch in tqdm(range(args.epochs)):
    random.shuffle(batch_list)
    start_time = time.time()

    epoch_batch_losses = []
    for batch_pickle in batch_list:
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(batch_pickle, 'rb') as file:
            batch_data = pickle.load(file)

        batch_feature_processed = torch.as_tensor(batch_data[feature_type], dtype=torch.float32, device=device)
        batch_adj_ori = batch_data['adj_mx'].todense().astype(np.float32)
        batch_adj_ori_t = torch.tensor(batch_adj_ori, device=device)
        # Structure-reconstruction target: adjacency with self-loops added.
        batch_adj_selfloop_t = batch_adj_ori_t + torch.eye(batch_adj_ori_t.shape[0], device=device)

        optimizer.zero_grad()
        batch_loss, batch_reconstruction, _ = model.loss(batch_feature_processed,
                                                         batch_adj_ori_t,
                                                         batch_adj_selfloop_t)
        batch_loss.backward()
        optimizer.step()

        # .item() copies the scalar to the host and drops the autograd graph.
        epoch_batch_losses.append(batch_loss.item())
    times.append(time.time() - start_time)

    epoch_loss = np.mean(epoch_batch_losses)
    loss_values.append(epoch_loss)
    # tqdm.write prints without corrupting the progress bar.
    tqdm.write('Epoch:{}, epoch_loss:{}'.format(epoch, epoch_loss))

if args.save_model == 'Yes':
    save_dir = './trained_model/hatae_toy_data'
    os.makedirs(save_dir, exist_ok=True)
    # args.epochs instead of the loop variable, which is undefined when epochs == 0.
    save_name = 'cadae_adjust_eps={}_{}_epoch_{}.pkl'.format(str(args.learn_eps), feature_type, args.epochs)
    torch.save(model.state_dict(), os.path.join(save_dir, save_name))
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - time_begin))

In [None]:
# ---------------------------------------------------------------------------
# Testing
# ---------------------------------------------------------------------------
testdata_list = '/home/suzukilab/zhangkang/graph_experiments/data/hatae_toy_data/hatae_toy_test_pkl/*.pkl'

feature_type = args.feature_type
args.cuda = False  # evaluation runs on the CPU


model = CADAE(args)

# NOTE(review): this absolute path is not derived from the location the
# training cell saves to ('./trained_model/hatae_toy_data/...'); confirm they match.
model_load_dir = '/home/suzukilab/zhangkang/graph_experiments/model/giae_newcode/trained_model/\
hatae_toy_data/cadae_adjust_eps=True_imcap_minmax_fea_epoch_400.pkl'

# The checkpoint comes from a GPU model; map it to the CPU so loading also
# works on machines without CUDA.
model.load_state_dict(torch.load(model_load_dir, map_location='cpu'))
model.eval()
print('Loading ' + model_load_dir)

# Sorted so the node order (check_image, the plot) is reproducible across runs.
test_batch_list = sorted(glob.glob(testdata_list))
if not test_batch_list:
    raise FileNotFoundError('No test pickles match {}'.format(testdata_list))

check_image = []      # source pickle filename, repeated once per node
check_rec_error = []  # per-batch reconstruction-error tensors
check_label = []
check_embs = []
check_test_cap = []

for test_batch_pickle in test_batch_list:
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(test_batch_pickle, 'rb') as file:
        test_batch_data = pickle.load(file)

    test_batch_feature_processed = torch.as_tensor(test_batch_data[feature_type], dtype=torch.float32)
    test_batch_adj_ori_t = torch.tensor(test_batch_data['adj_mx'].todense().astype(np.float32))
    test_batch_adj_selfloop_t = test_batch_adj_ori_t + torch.eye(test_batch_adj_ori_t.shape[0])
    test_batch_label = test_batch_data['label']
    test_batch_captions = test_batch_data['captions']

    with torch.no_grad():
        _, test_batch_rec_error, test_batch_embeds = model.loss(test_batch_feature_processed,
                                                                test_batch_adj_ori_t,
                                                                test_batch_adj_selfloop_t)

    check_test_cap.append(test_batch_captions)
    check_rec_error.append(test_batch_rec_error.detach())
    check_label.append(test_batch_label)
    check_embs.append(test_batch_embeds.detach())
    check_image.extend([os.path.basename(test_batch_pickle)] * test_batch_feature_processed.shape[0])

# Concatenate batch by batch (instead of np.stack) so graphs with different
# node counts are supported; for equal-sized graphs the result is identical.
rows_per_batch = [int(err.numel()) for err in check_rec_error]
test_rec_error = np.concatenate([err.numpy().reshape(-1) for err in check_rec_error])
test_label = np.concatenate([np.asarray(lbl).reshape(-1) for lbl in check_label])
test_embs = np.concatenate([emb.numpy().reshape(n, -1) for emb, n in zip(check_embs, rows_per_batch)])
test_cap = np.concatenate([np.asarray(cap).reshape(n, -1) for cap, n in zip(check_test_cap, rows_per_batch)])

input_rec_error = torch.tensor(test_rec_error)
test_auc = roc_auc(input_rec_error, test_label)

print('test_AUC:', test_auc)
draw_original(input_rec_error, test_label)