In [None]:
import os
import torch
import argparse
import pyhocon
import random
import numpy as np
import pandas as pd
from src.TAXDataCenter import TAXDataCenter
from src.utils import evaluate, train_classification, apply_model, get_gnn_TAXembeddings, test
from src.models import GraphSage, Classification, UnsupervisedLoss


In [None]:
# Command-line style configuration for the run. Every option has a default,
# so the notebook works without any arguments being supplied.
parser = argparse.ArgumentParser()
parser.add_argument('--dataSet', type=str, default='TaxH', choices=['TaxH', 'TaxZ', 'TaxS'])
parser.add_argument('--agg_func', type=str, default='MEAN', choices=['MEAN', 'MAX'])
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--b_sz', type=int, default=20)
parser.add_argument('--seed', type=int, default=824)
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--gcn', action='store_true')
parser.add_argument(
    '--learn_method',
    type=str,
    default='unsup',
    choices=['sup', 'unsup', 'plus_unsup'],
    help='sup: supervised learning, unsup: unsupervised learning, '
         'plus_unsup: supervised learning plus unsupervised loss',
)
parser.add_argument('--unsup_loss', type=str, default='normal', choices=['normal', 'margin'])
parser.add_argument('--max_vali_f1', type=float, default=0)
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--config', type=str, default='./src/taxConfig.conf')

# Parse an explicit empty argument list so that only the defaults above are
# used and sys.argv is never read. When running this code as a standalone
# script, call parser.parse_args() with no arguments instead.
args = parser.parse_args(args=[])

In [None]:
# Resolve the torch device used by every later cell.
cuda_available = torch.cuda.is_available()

if cuda_available and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
elif args.cuda and not cuda_available:
    # Previously this case built a CUDA device anyway and failed later on
    # the first .to(device) call; fall back to the CPU instead.
    print("WARNING: --cuda was requested but no CUDA device is available; falling back to CPU")

if args.cuda and cuda_available:
    # Keep the historical preference for GPU index 1, but use index 0 when
    # the machine exposes a single GPU so that the device actually exists.
    device_id = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:{}".format(device_id))
    # Report the GPU that is really selected (not merely the current one).
    print('using device', device_id, torch.cuda.get_device_name(device_id))
else:
    device = torch.device("cpu")

print('DEVICE:', device)

In [None]:
# Seed the Python, NumPy and torch (CPU and all CUDA devices) random
# generators from --seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Load the HOCON config file.
config = pyhocon.ConfigFactory.parse_file(args.config)

# Load data.
ds = args.dataSet

# Root directory of the dataset pickles. The original absolute path is kept
# as the default; set the TAX_DATA_DIR environment variable to run the
# notebook on a machine where the data lives elsewhere.
data_root = os.environ.get('TAX_DATA_DIR', '/home/dada/TAX/TYC/ValidDATA')
basic_feature = os.path.join(data_root, 'BasicFeature', '{}_normalize.pickle'.format(ds))
label_file = os.path.join(data_root, 'Label', '{}_label.pickle'.format(ds))
edges_file = os.path.join(data_root, 'NetworkEdges', '{}_edges.pickle'.format(ds))

dataCenter = TAXDataCenter(basic_feature, edges_file, label_file)
dataCenter.load_dataSet(dataSet=ds)

# load_dataSet is expected to populate attributes prefixed with the dataset
# name (<ds>_feats, <ds>_adj_lists, <ds>_labels, <ds>_train) on dataCenter.
features = torch.FloatTensor(getattr(dataCenter, ds+'_feats')).to(device)
adj_lists = getattr(dataCenter, ds+'_adj_lists')
hidden_emb_size = config['setting.hidden_emb_size']

# NOTE: the models are built in the same order as before (GraphSage, then
# Classification) so that seeded parameter initialisation is unaffected.
graphSage = GraphSage(
    config['setting.num_layers'],
    features.size(1),
    hidden_emb_size,
    features,
    adj_lists,
    device,
    gcn=args.gcn,
    agg_func=args.agg_func,
)
graphSage.to(device)

# One output class per distinct label value in the dataset.
num_labels = len(set(getattr(dataCenter, ds+'_labels')))
classification = Classification(hidden_emb_size, num_labels)
classification.to(device)

unsupervised_loss = UnsupervisedLoss(adj_lists, getattr(dataCenter, ds+'_train'), device)

# Announce which learning mode was selected.
if args.learn_method == 'sup':
    print('GraphSage with Supervised Learning')
elif args.learn_method == 'plus_unsup':
    print('GraphSage with Supervised Learning plus Net Unsupervised Learning')
else:
    print('GraphSage with Net Unsupervised Learning')

In [None]:
# Unsupervised training: call apply_model once per epoch and keep the
# GraphSage model it returns. The second return value is discarded.
# Nothing is trained here for the other learn_method values.
if args.learn_method == 'unsup':
    for epoch in range(args.epochs):
        graphSage, _ = apply_model(
            dataCenter,
            ds,
            graphSage,
            classification,
            unsupervised_loss,
            args.b_sz,
            args.unsup_loss,
            device,
            args.learn_method,
        )

In [None]:
# Get the embeddings from the GraphSage model together with the identifier
# that belongs to each embedding row.
embeddings, nsrdzdah = get_gnn_TAXembeddings(graphSage, dataCenter, ds)

# Build a frame of the embeddings and label its rows with those identifiers.
embeds = pd.DataFrame(embeddings)
embeds.index = nsrdzdah

# Select the rows listed in dataCenter's <ds>_node_names, in that order.
node_order = getattr(dataCenter, ds+'_node_names')
embeddings_fin = embeds.loc[node_order, :]

# Persist the result as a pickle in the current working directory.
# NOTE(review): the file name always says "unsup", whatever --learn_method
# was selected; confirm that this is intended.
output_path = 'Embeddings{}_unsup_embeddings.pickle'.format(ds)
embeddings_fin.to_pickle(output_path)