In [3]:
import datareader, sharedutils, os
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils import data as Data
from model_lif_fc_no_val import model_lif_fc
# Run configuration: start from the shared settings, then let the
# dataset-specific entries of the 'snn_al' section override them.
dataname = 'cora'
conf = sharedutils.read_config()
cnf = dict(conf['shared_conf'])
cnf.update(conf['snn_al'][dataname])

# Load the dataset with its fixed (predefined) split, then unpack what
# conv_subgraph returns into the names used by the cells below.
rd = datareader.ReadData("~/datasets/datafromgg")
data_fixed = rd.get_fixed_splited_data(dataname)
data = data_fixed  # alias: later code reads data.graph.num_classes
mat, adj_mat, features, all_x, tag = rd.conv_subgraph(data_fixed)

# Initialise the active-learning state.
# NOTE(review): the configured random seed doubles as the first labelled node
# (row index random_seed - 1); confirm this is intended.
random_seed = cnf['random_seed']
# One record per labelled node, in acquisition order; 'test_acc' is filled in
# after the model trained with that node has been evaluated.
# np.nan (not np.NaN, which was removed in NumPy 2.0).
res_list = [{'vid': random_seed - 1, 'test_acc': np.nan}]
n_sample_acquired = 1
n_sample_budget = 50  # total number of labelled nodes to acquire
# Boolean mask over nodes: True = labelled (training), False = unlabelled.
tr_mask = np.zeros(features.shape[0], dtype=bool)
tr_mask[random_seed - 1] = True
# Graph Laplacian L = D - A, where D is the diagonal matrix of the row sums of A.
laplacian = np.diag(np.sum(adj_mat, 1)) - adj_mat
# Full-size buffer for the SOPT predictive covariance; the acquisition step only
# fills the unlabelled-by-unlabelled entries.
predCovCQ = np.zeros((len(laplacian), len(laplacian)))

# Settings that do not change between acquisition rounds: apply them once.
cnf['log_dir'] = conf['snn']['log_dir']
if cnf['v_reset'] == -100:
    cnf['v_reset'] = None  # the config encodes "no value" as -100
n_nodes, n_feat, n_flat = mat.shape[0], mat.shape[1], 1
print(cnf)

# NOTE(review): res_list is not written to disk in this cell; the result cells
# below load pickles produced elsewhere.
# Each round: train on the labelled nodes (tr_mask), evaluate on the rest,
# record the accuracy, then acquire one more node with the SOPT criterion.
# `<=` (not `<`) makes the final round run, so the accuracy obtained with the
# last acquired node is recorded instead of being left as NaN.
while n_sample_acquired <= n_sample_budget:
    tr_ind = all_x[tr_mask].flatten()
    ts_ind = all_x[~tr_mask].flatten()
    tr_mat, ts_mat = mat[tr_ind], mat[ts_ind]
    tr_tag, ts_tag = tag[tr_ind], tag[ts_ind]

    train_data_loader, test_data_loader = rd.tr_ts_numpy2dataloader(
        tr_mat, ts_mat, tr_tag, ts_tag, batch_size=cnf["batch_size"])
    print("train, test batch num:", len(train_data_loader), len(test_data_loader))
    print("train, test shape:", tr_mat.shape, ts_mat.shape)
    print("data: %s, num_node_classes: %d" % (dataname, data.graph.num_classes))

    # presumably returns the test accuracy (it is stored as 'test_acc') -- confirm
    accuracy = model_lif_fc(device=cnf["device"], dataset_dir=cnf["dataset_dir"],
                            dataname=dataname, batch_size=cnf["batch_size"],
                            learning_rate=cnf["learning_rate"], T=cnf["T"], tau=cnf["tau"],
                            v_reset=cnf["v_reset"], v_threshold=cnf["v_threshold"],
                            train_epoch=cnf["train_epoch"], log_dir=cnf["log_dir"],
                            n_labels=data.graph.num_classes,
                            n_dim0=n_nodes, n_dim1=n_flat, n_dim2=n_feat,
                            train_data_loader=train_data_loader,
                            test_data_loader=test_data_loader)

    # The last res_list entry is the most recently labelled node; this accuracy
    # belongs to the model trained after adding it.
    res_list[-1]['test_acc'] = accuracy
    # Printed before incrementing, so the count is the number of labelled
    # nodes this model was actually trained with.
    print(n_sample_acquired, " nodes", "accuracy:", accuracy)

    if n_sample_acquired < n_sample_budget:
        # SOPT acquisition: invert the Laplacian restricted to the unlabelled
        # nodes, and score each node by its covariance row sum divided by the
        # square root of its own variance; acquire the highest-scoring node.
        unlabeled = ~tr_mask
        uu = np.ix_(unlabeled, unlabeled)
        cov_uu = np.asarray(np.linalg.inv(laplacian[uu]))
        predCovCQ[uu] = cov_uu  # keep the full-size buffer in sync
        acq_scores = cov_uu.sum(axis=1) / np.sqrt(np.diag(cov_uu))
        to_label = all_x[unlabeled][np.argmax(acq_scores)]
        print(to_label)

        # NOTE(review): the acquired id is used directly as a tr_mask position,
        # which assumes all_x holds the row indices 0..N-1 -- confirm.
        new_vid = to_label[0]
        assert np.sum(tr_mask) == n_sample_acquired, 'Num. of sample in tr_mask != n_sample_acquired'
        assert not tr_mask[new_vid], 'Node {0} alrdy acq.'.format(new_vid)
        res_list.append({'vid': new_vid, 'test_acc': np.nan})
        tr_mask[new_vid] = True
    n_sample_acquired += 1

In [4]:
import os, pickle
import numpy as np

# Summarise one saved active-learning run by the normalised area under its
# accuracy-vs-number-of-labelled-nodes curve (trapezoidal rule, unit spacing).
RESULT_DATASET, RESULT_SEED = "cora", 1
# os.getcwd() instead of $PWD: PWD may be unset (or stale after os.chdir) in a
# Jupyter kernel, and os.path.join(None, ...) raises TypeError.
result_fp = os.path.join(os.getcwd(), 'al_result_files',
                         'SNN_AL-{0}-rs_{1}.p'.format(RESULT_DATASET, RESULT_SEED))
# NOTE: pickle.load can execute arbitrary code -- only load result files you produced.
with open(result_fp, 'rb') as fh:
    res_list = pickle.load(fh)
print(res_list)

accs = np.array([rec['test_acc'] for rec in res_list], dtype=float)
# Trapezoid with unit spacing: sum of all points minus half of both endpoints.
area = accs.sum() - (accs[0] + accs[-1]) / 2
# Normalise by the number of intervals (49 for the usual 50-point curve).
curve_rate = area / (len(accs) - 1) if len(accs) > 1 else float('nan')
print(curve_rate)

[{'vid': 0, 'test_acc': 0.2918679549114332}, {'vid': 2048, 'test_acc': 0.15465163109142166}, {'vid': 558, 'test_acc': 0.23811442385173248}, {'vid': 1088, 'test_acc': 0.22249093107617895}, {'vid': 927, 'test_acc': 0.3125}, {'vid': 316, 'test_acc': 0.4126663977410246}, {'vid': 1810, 'test_acc': 0.43946731234866826}, {'vid': 608, 'test_acc': 0.44327815906338314}, {'vid': 1785, 'test_acc': 0.41478190630048467}, {'vid': 2186, 'test_acc': 0.39434343434343433}, {'vid': 607, 'test_acc': 0.40040404040404043}, {'vid': 1478, 'test_acc': 0.4183508488278092}, {'vid': 310, 'test_acc': 0.4892842701172665}, {'vid': 236, 'test_acc': 0.46561488673139156}, {'vid': 224, 'test_acc': 0.46863617968433835}, {'vid': 1519, 'test_acc': 0.5145748987854251}, {'vid': 2259, 'test_acc': 0.5143782908059943}, {'vid': 1559, 'test_acc': 0.5162074554294975}, {'vid': 1245, 'test_acc': 0.49371706526145115}, {'vid': 1732, 'test_acc': 0.5332522303325223}, {'vid': 1212, 'test_acc': 0.5574036511156186}, {'vid': 1607, 'test_acc'

In [14]:
import os, pickle
import numpy as np

# Per-step test accuracy and normalised learning-curve area for one saved run.
# NOTE(review): the default is an absolute, user-specific path; set the
# ACTIVE_SNN_DIR environment variable to read results from elsewhere.
results_root = os.environ.get('ACTIVE_SNN_DIR', '/home/zlzhu/snn/bsgcn/handcode/active_snn')
result_fp = os.path.join(results_root, 'al_result_files',
                         'SNN_AL-{0}-rs_{1}.p'.format("acm", 123))
# NOTE: pickle.load can execute arbitrary code -- only load result files you produced.
with open(result_fp, 'rb') as fh:
    res_list = pickle.load(fh)

accuracy = np.array([rec['test_acc'] for rec in res_list], dtype=float)
print(len(res_list))
# Show the accuracies five per output line.
for start in range(0, len(accuracy), 5):
    print(*accuracy[start:start + 5])

# Trapezoid with unit spacing, normalised by the number of intervals
# (49 for the usual 50-point curve).
area = accuracy.sum() - (accuracy[0] + accuracy[-1]) / 2
curve_rate = area / (len(accuracy) - 1) if len(accuracy) > 1 else float('nan')
print(accuracy)
print(curve_rate)
# Author's strategy codes across runs: 1 random, 2 SOPT, 3 predictive, 4 combination.

50
0.3605985037406484 0.3602794411177645 0.6060908637044433 0.8686313686313686 0.8140929535232384
0.898 0.9014507253626813 0.9039039039039038 0.9068602904356535 0.9083166332665331
0.9087719298245615 0.9092276830491475 0.9131961866532865 0.9116465863453815 0.9126067302862882
0.9125628140703518 0.9170437405731523 0.9134808853118712 0.914443885254152 0.9169184290030211
0.8952141057934508 0.9017137096774194 0.8961169944528492 0.8521695257315842 0.8818778394750126
0.8848484848484849 0.8736735725113693 0.8857431749241659 0.8700050581689428 0.8785425101214575
0.8764556962025316 0.8733535967578521 0.8773441459706032 0.8889452332657201 0.8904109589041096
0.8888324873096447 0.9040121889283901 0.899390243902439 0.8998474834773768 0.9038657171922686
0.9027989821882951 0.9022403258655805 0.9057564951604686 0.908256880733945 0.9077001529831719
0.9076530612244897 0.9122001020929046 0.9152196118488254 0.9095554420030659 0.9120654396728016
[0.3605985  0.36027944 0.60609086 0.86863137 0.81409295 0.898
 