In [None]:
import argparse
import numpy as np
import torch
from sklearn.metrics import f1_score, roc_auc_score,accuracy_score,confusion_matrix
from lib.Elliptic_data import Get_data, get_Data, split_idx
from lib.util import load_best_result, save_best_checkpoint
from model.loss import FocalLoss
from model.model import Net, GCNNet, GATNet, SimpleNet,EMA
def mixup_data(x, alpha=1.0):
    """Mix every sample of ``x`` with a randomly chosen sample of the same tensor.

    A single coefficient ``lam`` is drawn from Beta(alpha, alpha) (or fixed to
    1.0 when ``alpha <= 0``) and the result is
    ``lam * x + (1 - lam) * x[perm]``, where ``perm`` is a random permutation
    of the first dimension (drawn with NumPy's global RNG).

    Args:
        x: tensor whose first dimension indexes samples; any trailing shape is
            allowed.
        alpha: Beta-distribution concentration; ``<= 0`` disables mixing, so
            the input is returned unchanged apart from the dtype cast.

    Returns:
        The mixed tensor, cast to ``torch.float32``. Only the inputs are
        mixed; no targets or lambda are returned.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0. else 1.
    batch_size = x.size(0)
    # Permute along dim 0 only, so inputs of any rank are supported; the index
    # lives on x's device to avoid implicit host/device transfers.
    index = torch.as_tensor(np.random.permutation(batch_size), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x.to(torch.float32)

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Elliptic dataloader')
# --- run settings ---
parser.add_argument('--device', default='cpu', type=str, help='Choose the device to train, cpu | cuda:0')
parser.add_argument('--mode', default='train', type=str, help='Choose the mode to use the model, train | test')
parser.add_argument('--epoch', default=100, type=int)
parser.add_argument('--step', default=None, type=int, help='if train all data, set None, else 1~49')
# --- dataset ---
parser.add_argument('--test_ratio', default=0.8, type=float)
parser.add_argument('--seed', default=42, type=int, help='Set the splitting seeds')
# --- model ---
parser.add_argument('--mode_name', default='Net', type=str)
parser.add_argument('--loss', default='BCE', type=str)
parser.add_argument('--hid_dim', default=128, type=int)
parser.add_argument('--f_att', default=True, type=_str2bool, help='Feature Booster')
parser.add_argument('--slices', default=3, type=int)
parser.add_argument('--num_layer', default=1, type=int)
# Empty argv: in a notebook sys.argv holds the kernel's own arguments.
args = parser.parse_args(args=[])

In [None]:
# Seed the global RNGs so the data loading, the mixup views below and the model
# initialisation are reproducible (args.seed also drives the train/valid split).
np.random.seed(args.seed)
torch.manual_seed(args.seed)

data, edge_index, classified_idx, unclassified_idx, un_edge_index = Get_data(step=args.step, un=True, aug=True)
if args.step is not None:
    print('Step:', args.step)
input_Data = get_Data(data, edge_index)
# Two differently mixed views of the unlabelled nodes (alpha=0.4 and alpha=2);
# the training loop feeds view 1 to the student and view 2 to the teacher.
mixup_unx1 = mixup_data(x=data[unclassified_idx], alpha=0.4)
mixup_unx2 = mixup_data(x=data[unclassified_idx], alpha=2)
unlabel_Data1 = get_Data(mixup_unx1, un_edge_index)
unlabel_Data2 = get_Data(mixup_unx2, un_edge_index)
train_idx, valid_idx = split_idx(input_Data, classified_idx, args.test_ratio, args.seed)
print('Get Data Ready, train shape {} | val shape {}'.format(len(train_idx), len(valid_idx)))
device = torch.device(args.device)

if args.loss == 'Focal':
    criterion = FocalLoss(alpha=0.25)
elif args.loss == 'BCE':
    criterion = torch.nn.BCELoss()
else:
    # Fail here rather than with a NameError on `criterion` inside the training loop.
    raise ValueError("Unsupported --loss {!r}; expected 'Focal' or 'BCE'".format(args.loss))

# Pretrained network: only used to produce targets and as a baseline, never
# optimized (the optimizer below holds model.parameters() only), so put it in
# inference mode and freeze it.
pretrain = Net(dim_in=input_Data.x.shape[1], dim_hidden=args.hid_dim, slices=args.slices, num_layer=args.num_layer, f_att=args.f_att).to(device)
pretrain.float()
pretrain, _ = load_best_result(pretrain, args.mode_name)
pretrain.eval()
pretrain.requires_grad_(False)

# Student network, plus its exponential-moving-average copy (decay 0.999) used as teacher.
model = SimpleNet(dim_in=input_Data.x.shape[1], dim_hidden=args.hid_dim).to(device)
model.float()
ema_model = EMA(model, 0.999)
ema_model.register()

best_loss = np.inf  # np.Inf was removed in NumPy 2.0
input_Data, unlabel_Data1, unlabel_Data2 = input_Data.to(device), unlabel_Data1.to(device), unlabel_Data2.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')


In [None]:
if args.mode == 'train':
    # Per-epoch history. Student: supervised + consistency loss.
    # Teacher: consistency loss measured with the EMA weights swapped in.
    tea_loss_list = []
    stu_loss_list = []
    # Labels do not change between epochs; convert them to NumPy once.
    y_np = input_Data.y.detach().cpu().numpy()
    for epoch in range(1, args.epoch + 1):
        # ---- student update ------------------------------------------------
        # The teacher evaluation at the end of the previous epoch leaves the
        # model in eval mode; switch back so training-time layers are active.
        model.train()
        optimizer.zero_grad()
        if epoch == 1:
            # No EMA teacher output exists yet, so the pretrained network
            # provides the first targets for the second unlabelled view.
            pretrain.eval()
            with torch.no_grad():
                un_out2 = pretrain(unlabel_Data2)
        # The supervised loss must come from the student: the optimizer only
        # holds model.parameters(), so a loss on pretrain's output trains nothing.
        l_out = model(input_Data)
        un_out1 = model(unlabel_Data1)
        l_loss = criterion(l_out[train_idx].squeeze(), input_Data.y[train_idx])
        # Teacher outputs are fixed targets; never back-propagate into them.
        un_loss = criterion(un_out1, un_out2.detach())
        loss = l_loss + un_loss
        auc = roc_auc_score(y_np[train_idx], l_out[train_idx].detach().cpu().numpy())
        loss.backward()
        optimizer.step()
        ema_model.update()
        stu_loss_list.append(loss.item())
        if epoch % 5 == 0:
            print("epoch: {:2d} - Student loss: {:.6f} - roc: {:.6f}".format(epoch, loss.item(), auc))

        # ---- teacher (EMA) evaluation ---------------------------------------
        model.eval()
        with torch.no_grad():
            ema_model.apply_shadow()
            un_out2 = model(unlabel_Data2)
            l_out = model(input_Data)
            un_loss = criterion(un_out1, un_out2)
            auc = roc_auc_score(y_np[valid_idx], l_out[valid_idx].detach().cpu().numpy())
            teacher_loss = un_loss.item()  # plain float, so best_loss holds no tensor
            tea_loss_list.append(teacher_loss)
            if epoch % 5 == 0:
                print("epoch: {:2d} - Teacher loss: {:.6f} - roc: {:.6f}".format(epoch, teacher_loss, auc))
            if teacher_loss < best_loss:
                best_loss = teacher_loss
                best_epoch = epoch
                # Saved while the EMA weights are applied, i.e. this is the teacher.
                save_best_checkpoint(model, best_epoch, 'Teacher')
            # Put the student's own weights back. Without this, the EMA weights
            # overwrite the student every epoch.
            # NOTE(review): assumes EMA exposes restore() as the counterpart of
            # apply_shadow() — confirm in model/model.py.
            ema_model.restore()
        # The scheduler was created but never stepped. Drive it with the same
        # metric used for checkpointing.
        scheduler.step(teacher_loss)
            


In [None]:
# Reload the best teacher (EMA) checkpoint and report its metrics on the train
# split, the validation split and all labelled nodes.
try:
    best_model, best_epoch = load_best_result(model, 'Teacher')
    print('Best Model Load At {}'.format(best_epoch))
except Exception as err:
    # Every metric below needs the checkpoint. Fail loudly here instead of
    # swallowing the error and crashing later with a NameError on `best_model`.
    print('Load Model Fail!')
    raise RuntimeError("could not load the 'Teacher' checkpoint") from err

best_model.eval()  # inference behaviour for any dropout / normalisation layers
with torch.no_grad():
    # Flatten a (N, 1) output to (N,) so sklearn receives 1-D scores.
    preds = best_model(input_Data).detach().cpu().numpy().ravel()
y_true = input_Data.y.detach().cpu().numpy().ravel()
out_labels = preds > 0.5

train_acc = accuracy_score(y_true[train_idx], out_labels[train_idx])
train_auc = roc_auc_score(y_true[train_idx], preds[train_idx])
f1_train = f1_score(y_true[train_idx], out_labels[train_idx])
print("Train accuracy: {:.6f}".format(train_acc))
print("train AUC     : {:.6f}".format(train_auc))
print("F1 score      : {:.6f}".format(f1_train))
print('--------------------------')
valid_auc = roc_auc_score(y_true[valid_idx], preds[valid_idx])
valid_acc = accuracy_score(y_true[valid_idx], out_labels[valid_idx])
f1_valid = f1_score(y_true[valid_idx], out_labels[valid_idx])
print("Valid accuracy: {:.6f}".format(valid_acc))
print("Valid AUC     : {:.6f}".format(valid_auc))
print("F1 score      : {:.6f}".format(f1_valid))
print('--------------------------')
f1_total = f1_score(y_true[classified_idx], out_labels[classified_idx])
print('Total F1 score: {:.6f}'.format(f1_total))
# Confusion matrix on the validation split only.
print(confusion_matrix(y_true[valid_idx], out_labels[valid_idx]))

In [None]:
# Baseline: the same metrics for the pretrained `Net`, for comparison with the
# Teacher checkpoint evaluated in the previous cell.
pretrain.eval()  # inference behaviour for any dropout / normalisation layers
with torch.no_grad():
    # Flatten a (N, 1) output to (N,) so sklearn receives 1-D scores.
    preds = pretrain(input_Data).detach().cpu().numpy().ravel()
y_true = input_Data.y.detach().cpu().numpy().ravel()
out_labels = preds > 0.5

train_acc = accuracy_score(y_true[train_idx], out_labels[train_idx])
train_auc = roc_auc_score(y_true[train_idx], preds[train_idx])
f1_train = f1_score(y_true[train_idx], out_labels[train_idx])
print("Train accuracy: {:.6f}".format(train_acc))
print("train AUC     : {:.6f}".format(train_auc))
print("F1 score      : {:.6f}".format(f1_train))
print('--------------------------')
valid_auc = roc_auc_score(y_true[valid_idx], preds[valid_idx])
valid_acc = accuracy_score(y_true[valid_idx], out_labels[valid_idx])
f1_valid = f1_score(y_true[valid_idx], out_labels[valid_idx])
print("Valid accuracy: {:.6f}".format(valid_acc))
print("Valid AUC     : {:.6f}".format(valid_auc))
print("F1 score      : {:.6f}".format(f1_valid))
print('--------------------------')
f1_total = f1_score(y_true[classified_idx], out_labels[classified_idx])
print('Total F1 score: {:.6f}'.format(f1_total))
# Confusion matrix on the validation split only.
print(confusion_matrix(y_true[valid_idx], out_labels[valid_idx]))