In [1]:
import argparse
import math
import time
import torch
import torch.nn as nn
from models import normal_learning_model, normal_learning_model_last
from models import generator_learning_model, classifier_learning_model
import numpy as np;
from models import utils_multidatasource, Optim
import scipy
import sklearn
from sklearn import metrics
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.sparse.linalg import svds
from sklearn.preprocessing import normalize
from numpy import linalg as LA
import pickle
from torch.autograd import Variable
from scipy.io import loadmat
import torch.nn.functional as F
from scipy import spatial
from termcolor import colored
from scipy.stats import sem
import warnings
from sklearn.metrics import classification_report
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix
from imblearn.metrics import geometric_mean_score

# def extract(v):
#     return v.data.storage().tolist()


def train_normal(loader, data, model, criterion, optim, batch_size):
    """Run one training epoch and return the loss per output element.

    Args:
        loader: object exposing ``get_batches(data, batch_size, shuffle)``;
            every yielded batch is indexed as ``[X, Y, Labels]``.
        data: dataset handed unchanged to ``loader.get_batches``.
        model: called as ``model(X)``; put into training mode here.
        criterion: called as ``criterion(output, Y)``; must return a scalar
            tensor supporting ``backward()``.
        optim: optimizer wrapper; ``step()`` is called once per batch.
        batch_size: forwarded to ``loader.get_batches``.

    Returns:
        The accumulated batch losses divided by the number of output
        elements seen (product of the first three output dimensions, or of
        the first four when the output has four or more dimensions).
    """
    model.train()
    loss_sum = 0
    element_count = 0
    for batch in loader.get_batches(data, batch_size, True):
        # The labels are unpacked for symmetry with the other loops but are
        # not needed for the regression loss.
        features, targets, _labels = batch[0], batch[1], batch[2]
        model.zero_grad()
        prediction = model(features)
        loss = criterion(prediction, targets)
        loss.backward()
        loss_sum += loss.item()
        optim.step()

        seen = prediction.size(0) * prediction.size(1) * prediction.size(2)
        if len(prediction.shape) >= 4:
            seen *= prediction.size(3)
        element_count += seen
    return loss_sum / element_count

def test_normal(loader, data, model, criterion, optim, batch_size):
    total_loss = 0;
    n_samples = 0;
    for inputs in loader.get_batches(data, batch_size, True):        
        X, Y, Labels = inputs[0], inputs[1], inputs[2]
        output = model(X);        
        batch_loss = criterion(output, Y); 
        total_loss += batch_loss.data.item();
        if len(output.shape)<4:
            n_samples += (output.size(0) * output.size(1) * output.size(2));
        else:
            n_samples += (output.size(0) * output.size(1) * output.size(2) * output.size(3));
    return total_loss / n_samples


def train_classifier(loader, data, model, criterion, optim, batch_size, model_normal):
    """Train the classifier for one epoch and report training metrics.

    For every batch the model output is compared against the target
    repeated once per label column; the element-wise loss is weighted so
    that only the block belonging to a sample's own class contributes, and
    classes other than column 0 are re-scaled by ``count(col 0) / count(col)``
    within the batch. A pairwise inverse-distance penalty between the
    graphs returned by ``model.predict_relationship`` is added with weight
    0.01.

    Args:
        loader: object exposing ``get_batches(data, batch_size, shuffle)``;
            every yielded batch is indexed as ``[X, Y, Labels]``.
        data: dataset; ``data[2]`` is the one-hot label matrix.
        model: classifier being trained, called as ``model(X, model_normal)``
            and expected to return a 4-D tensor whose first dimension is
            ``Labels.shape[1] * batch``.
        criterion: element-wise (non-reduced) loss.
        optim: optimizer wrapper; ``step()`` is called once per batch.
        batch_size: forwarded to ``loader.get_batches``.
        model_normal: auxiliary model forwarded to ``model``.

    Returns:
        ``(loss_per_output_element, classification_report_dict)``.
    """
    model.train()
    total_loss = 0
    n_samples = 0
    Label_truth = torch.zeros((0, data[2].shape[1]))
    Label_predict = torch.zeros((0, data[2].shape[1]))
    for inputs in loader.get_batches(data, batch_size, True):
        X, Y, Labels = inputs[0], inputs[1], inputs[2]
        num_class = Labels.shape[1]
        model.zero_grad()
        output = model(X, model_normal)
        # One copy of the target per class block of the model output.
        Y_duplicate = torch.cat([Y] * num_class, dim=0)
        loss_all = criterion(output, Y_duplicate)

        # Per-sample weights, laid out class block by class block.
        reference_count = torch.sum(Labels[:, 0])
        weight_blocks = [Labels[:, 0]]
        for class_idx in range(1, num_class):
            class_count = torch.sum(Labels[:, class_idx])
            ratio = reference_count / class_count if class_count > 0 else 1
            weight_blocks.append(ratio * Labels[:, class_idx])
        Label_weighting = torch.cat(weight_blocks, dim=0).view(-1, 1, 1, 1)
        # expand_as keeps the original requirement that the loss is 4-D.
        batch_loss = torch.sum(loss_all * Label_weighting.expand_as(loss_all))

        # Penalise graphs that are close to each other. Uses the `model`
        # argument rather than a module-level classifier.
        # NOTE(review): LA.norm is NumPy, so this term is presumably a
        # constant with respect to autograd -- confirm that is intended.
        graph_reg = 0
        G_predict_cls, _ = model.predict_relationship(model_normal)
        num_graph = len(G_predict_cls)
        for graph_i in range(num_graph - 1):
            graph_A = G_predict_cls[graph_i].reshape(-1)
            for graph_j in range(graph_i + 1, num_graph):
                graph_B = G_predict_cls[graph_j].reshape(-1)
                graph_reg = graph_reg + 1 / LA.norm(graph_A - graph_B)
        batch_loss = batch_loss + 0.01 * graph_reg

        batch_loss.backward()
        total_loss += batch_loss.item()
        optim.step()
        batch_elements = output.size(0) * output.size(1) * output.size(2)
        if len(output.shape) >= 4:
            batch_elements *= output.size(3)
        n_samples += batch_elements

        #################### classification evaluation ##############
        # Detached so the accumulated predictions do not keep every batch's
        # autograd graph alive until the end of the epoch.
        per_sample_loss = loss_all.detach().reshape(loss_all.shape[0], -1).sum(dim=1)
        predict_label = per_sample_loss.view(num_class, Y.shape[0]).transpose(0, 1)
        # The class block with the smallest reconstruction loss wins.
        predict_label = F.softmin(predict_label, dim=1)
        Label_predict = torch.cat((Label_predict, predict_label), dim=0)
        Label_truth = torch.cat((Label_truth, Labels), dim=0)

    classification_report_trn = classification_report(
        np.argmax(Label_truth.detach().numpy(), axis=1),
        np.argmax(Label_predict.detach().numpy(), axis=1),
        output_dict=True,
    )

    return total_loss / n_samples, classification_report_trn



def evaluate_classifier(loader, data, model, criterion, optim, batch_size, model_normal=None):
    """Evaluate the classifier on ``data`` and compute summary metrics.

    The model runs in evaluation mode with gradients disabled; its previous
    train/eval mode is restored before returning. Scores of the extra
    output columns beyond the real label columns are folded onto real
    classes ``1..K-1`` before taking the arg-max.

    Args:
        loader: object exposing ``get_batches(data, batch_size, shuffle)``;
            every yielded batch is indexed as ``[X, Y, Labels]``.
        data: dataset; ``data[2]`` is the one-hot label matrix with ``K``
            columns.
        model: classifier, called as ``model(X, model_normal)``; must expose
            ``num_class`` (expected to be ``2 * K - 1``).
        criterion: element-wise (non-reduced) loss.
        optim: unused; kept for signature symmetry with the training loop.
        batch_size: forwarded to ``loader.get_batches``.
        model_normal: auxiliary model forwarded to ``model``. When omitted,
            the module-level ``model_normal`` is used, which is what this
            function always did before the parameter existed.

    Returns:
        ``(loss_per_output_element, report_dict, macro_f1, geometric_mean,
        macro_precision, accuracy)``.
    """
    if model_normal is None:
        # Backward-compatible fallback to the notebook-level global.
        model_normal = globals()["model_normal"]

    total_loss = 0
    n_samples = 0
    Label_truth = torch.zeros((0, data[2].shape[1]))
    Label_predict = torch.zeros((0, model.num_class))
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs in loader.get_batches(data, batch_size, True):
                X, Y, Labels = inputs[0], inputs[1], inputs[2]
                output = model(X, model_normal)
                # One copy of the target per class block of the output.
                Y_duplicate = torch.cat([Y] * model.num_class, dim=0)
                loss_all = criterion(output, Y_duplicate)
                total_loss += torch.sum(loss_all).item()
                batch_elements = output.size(0) * output.size(1) * output.size(2)
                if len(output.shape) >= 4:
                    batch_elements *= output.size(3)
                n_samples += batch_elements

                #################### classification evaluation ##############
                per_sample_loss = loss_all.reshape(loss_all.shape[0], -1).sum(dim=1)
                # Batch size is taken from Y, as in train_classifier.
                predict_label = per_sample_loss.view(model.num_class, Y.shape[0]).transpose(0, 1)
                predict_label = F.softmin(predict_label, dim=1)
                Label_predict = torch.cat((Label_predict, predict_label), dim=0)
                Label_truth = torch.cat((Label_truth, Labels), dim=0)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    Label_predict = Label_predict.numpy()
    Label_truth = Label_truth.numpy()
    num_extra = Label_truth.shape[1] - 1
    Label_predict_sum = np.zeros(Label_truth.shape)
    Label_predict_sum[:, 0] = Label_predict[:, 0]
    # Real class i (i >= 1) also receives the score of column i + num_extra.
    Label_predict_sum[:, 1:] = (
        Label_predict[:, 1:num_extra + 1]
        + Label_predict[:, num_extra + 1:2 * num_extra + 1]
    )

    Label_truth = np.argmax(Label_truth, axis=1)
    Label_predict_sum = np.argmax(Label_predict_sum, axis=1)

    classification_report_tst = classification_report(Label_truth, Label_predict_sum, output_dict=True)

    MF = classification_report_tst['macro avg']['f1-score']
    GM = geometric_mean_score(Label_truth, Label_predict_sum, average='macro')
    PC = classification_report_tst['macro avg']['precision']
    AR = accuracy_score(Label_truth, Label_predict_sum)
    return total_loss / n_samples, classification_report_tst, MF, GM, PC, AR

Using TensorFlow backend.


In [2]:
class args:
    """Run configuration, used as a plain attribute namespace.

    In the visible code the class object itself (not an instance) is passed
    to the data utility and the model constructors, and ``num_class`` is
    attached to it later at runtime.
    """

    # ---- data split, forwarded to the project data utility ----
    train = 0.9
    valid = 0.05

    # ---- names of the model modules, resolved with eval() by the script ----
    model_normal = 'normal_learning_model'
    model_normal_last = 'normal_learning_model_last'
    model_generator = 'generator_learning_model'
    model_classifier = 'classifier_learning_model'

    # ---- window lengths (time steps) ----
    window = 18             # length of each input sequence
    pre_win = 3             # number of shifted targets per input step
    normal_win = 18
    normal_prewin = 3
    generator_win = 18      # length of the noise sequences fed to the generator
    generator_prewin = 1
    classifier_win = 15     # leading steps kept for the classifier
    classifier_prewin = 3   # shifted targets built for the classifier

    # ---- training schedule ----
    bootstrap_num = 3       # independent bootstrap fits of the normal model
    epochs_simple = 50      # epochs per bootstrap fit
    epochs_last = 300       # epochs for the final normal model
    epochs_GANs = 10        # generator/classifier rounds
    rounds_LSTM = 10
    syn_rounds = 1
    batch_size = 100

    # ---- model dimensions (consumed by the project models) ----
    y_dim = 65
    RUC_layers = 1
    hidden_dim = 80
    reduce_dim = 60
    lowrank_minor = 3
    dropout = 0.001
    horizon = 1
    output_fun = None
    mask = False

    # ---- optimisation, forwarded to Optim.Optim ----
    optim = 'adam'
    clip = 1.0
    lr_N = 0.01             # normal model
    lr_G = 0.01             # generator
    lr_C = 0.01             # classifier
    lr_lstm = 0.01
    weight_decay = 0

    # ---- device ----
    gpu = None
    cuda = False

In [3]:
## initialize models
# Silence library warnings for the rest of the session.
warnings.filterwarnings('ignore')
Data = utils_multidatasource.Data_utility(args)

# One expression matrix and one adjacency matrix per data source; index 0
# is the large ("normal") source, the remaining ones are the small sources.
data_paths = [
    '10-20-30-edges/filter_norm_expression0.mat',
    '10-20-30-edges/filter_norm_expression1.mat',
    '10-20-30-edges/filter_norm_expression2.mat',
    '10-20-30-edges/filter_norm_expression3.mat',
]
graph_paths = [
    '10-20-30-edges/A0.mat',
    '10-20-30-edges/A1.mat',
    '10-20-30-edges/A2.mat',
    '10-20-30-edges/A3.mat',
]

# Number of training samples taken from each source.
traning_samples = [1000, 60, 60, 60]
# Test on 1000 samples of source 0 and on twice the training count of the others.
testing_samples = [1000] + [2 * count for count in traning_samples[1:]]

# Filled by the loading cell.
data_all = []
G_groudtruth = []

In [None]:
# Load the expression matrix and the ground-truth adjacency of every source.
# NOTE: this appends to lists created in an earlier cell, so re-running the
# cell without re-running that one duplicates the entries.
for data_path, graph_path in zip(data_paths, graph_paths):
    data_tmp = sio.loadmat(data_path)['expression']
    graph_tmp = sio.loadmat(graph_path)['A']
    data_all.append(data_tmp)
    G_groudtruth.append(graph_tmp)

# `reduction=` replaces the deprecated `size_average` / `reduce` flags:
#   size_average=False               -> reduction='sum'
#   size_average=False, reduce=False -> reduction='none'
criterion_1 = nn.MSELoss(reduction='sum')    # summed squared error (scalar)
criterion_2 = nn.MSELoss(reduction='none')   # element-wise squared error
criterion_3 = nn.CrossEntropyLoss()

In [None]:
start_point = 0
Data.m = data_all[0].shape[1]
print('buliding model')
###############################
###### preparing dataset ######
gap_point = 100  # samples skipped between the training and the testing slice

# Empty accumulators that grow source by source, plus one-hot label
# matrices that are pre-allocated for all samples.
X_trn_org = torch.zeros((0, args.window, Data.m))  ## input
Y_trn_full_org = torch.zeros((0, args.window, args.pre_win, Data.m))  ### long target
Label_trn_org = torch.zeros((np.sum(traning_samples), len(data_paths)))  ### label
X_tst_org = torch.zeros((0, args.window, Data.m))
Y_tst_full_org = torch.zeros((0, args.window, args.pre_win, Data.m))
Label_tst_org = torch.zeros((np.sum(testing_samples), len(data_paths)))
start_trn_idx = 0
start_tst_idx = 0

### concatenating data class by class
for idx_data in range(len(data_paths)):
    data_tmp_train, data_tmp_valid, data_tmp_test = Data._split(data_all[idx_data], args)

    n_trn = traning_samples[idx_data]
    n_tst = testing_samples[idx_data]
    trn_stop = start_point + n_trn
    tst_start = trn_stop + gap_point
    tst_stop = tst_start + n_tst

    # Training slice: the first n_trn samples of this source.
    X_trn_tmp = data_tmp_train[0][start_point:trn_stop]
    X_trn_org = torch.cat((X_trn_org, X_trn_tmp), dim=0)
    Y_trn_full_tmp = data_tmp_train[3][start_point:trn_stop]
    Y_trn_full_org = torch.cat((Y_trn_full_org, Y_trn_full_tmp), dim=0)
    Label_trn_org[start_trn_idx:start_trn_idx + n_trn, idx_data] = 1
    start_trn_idx += n_trn
    print(X_trn_org.shape)

    # Testing slice: taken from the same split, gap_point samples later.
    X_tst_tmp = data_tmp_train[0][tst_start:tst_stop]
    Y_tst_full_tmp = data_tmp_train[3][tst_start:tst_stop]
    X_tst_org = torch.cat((X_tst_org, X_tst_tmp), dim=0)
    Y_tst_full_org = torch.cat((Y_tst_full_org, Y_tst_full_tmp), dim=0)
    Label_tst_org[start_tst_idx:start_tst_idx + n_tst, idx_data] = 1
    start_tst_idx += n_tst
    print(X_tst_org.shape)

## copy of normal data (major data), shuffled
num_normal = traning_samples[0]
trn_indices_normal = np.arange(num_normal)
np.random.shuffle(trn_indices_normal)
X_trn_normal = X_trn_org[:num_normal][trn_indices_normal]
Y_trn_full_normal = Y_trn_full_org[:num_normal][trn_indices_normal]
Data_trn_full_normal = [
    X_trn_normal,
    Y_trn_full_normal,
    Label_trn_org[:num_normal][trn_indices_normal],
]

## copy of event data (minor data), shuffled
trn_indices_minor = np.arange(np.sum(traning_samples[1:]))
np.random.shuffle(trn_indices_minor)
X_trn_minor = X_trn_org[num_normal:][trn_indices_minor]
Y_trn_full_minor = Y_trn_full_org[num_normal:][trn_indices_minor]
Data_trn_full_minor = [
    X_trn_minor,
    Y_trn_full_minor,
    Label_trn_org[num_normal:][trn_indices_minor],
]

#### preparing training set (normal and event) for LSTM  ######
# Only the first args.classifier_win time steps are kept.
Data_trn_full = [
    X_trn_org[:, 0:args.classifier_win, :],
    Y_trn_full_org[:, 0:args.classifier_win, :, :],
    Label_trn_org,
]
#### preparing testing set #################
# The X part has shape (sum(testing_samples), args.classifier_win, Data.m).
Data_tst_full = [
    X_tst_org[:, 0:args.classifier_win, :],
    Y_tst_full_org[:, 0:args.classifier_win, :, :],
    Label_tst_org,
]


print("~~~~~~~~~~~~~~~~~~~ begin training/validating ~~~~~~~~~~~~~~~~")
# Flattened ground-truth graph of source 0, used for the AUPR computation.
G_groundtruth_normal = G_groudtruth[0].reshape(Data.m * Data.m)


buliding model
torch.Size([1000, 18, 65])
torch.Size([1000, 18, 65])
torch.Size([1060, 18, 65])
torch.Size([1120, 18, 65])
torch.Size([1120, 18, 65])
torch.Size([1240, 18, 65])
torch.Size([1180, 18, 65])
torch.Size([1360, 18, 65])
~~~~~~~~~~~~~~~~~~~ begin training/validating ~~~~~~~~~~~~~~~~


In [None]:
# Fit the normal model args.bootstrap_num times on random subsets and
# average the graphs predicted after the last epoch of each fit.
G_predict_aggregate = np.zeros((Data.m, Data.m))
for bootstrap_idx in range(args.bootstrap_num):
    ## initializing models
    args.num_class = 1
    # eval() only resolves a module name that is hard-coded in `args`;
    # never feed it external input.
    model_normal = eval(args.model_normal).Model(args, Data)
    optim_normal = Optim.Optim(
        model_normal.parameters(), args.optim, args.lr_N, args.clip, weight_decay=args.weight_decay,
    )

    # Random subset: 800 of the first 900 normal training samples
    # (both counts are hard-coded here).
    trn_indices_normal = np.arange(900)
    np.random.shuffle(trn_indices_normal)
    trn_indices_normal = trn_indices_normal[0:800]
    X_trn_normal_spl = X_trn_org[0:traning_samples[0]][trn_indices_normal]
    Y_trn_full_normal_spl = Y_trn_full_org[0:traning_samples[0]][trn_indices_normal]
    Data_trn_full_normal_spl = [X_trn_normal_spl, Y_trn_full_normal_spl, Label_trn_org[0:traning_samples[0]][trn_indices_normal]]

    for epoch in range(0, args.epochs_simple):
        round_start_time = time.time()
        train_normal_loss = train_normal(Data, Data_trn_full_normal_spl, model_normal, criterion_1, optim_normal, args.batch_size)
        G_predict, G_predict_org = model_normal.predict_relationship()

        # Area under the precision-recall curve of the predicted graph.
        G_predict_normal = G_predict[0].reshape(Data.m * Data.m)
        precision, recall, thresholds = metrics.precision_recall_curve(G_groundtruth_normal, G_predict_normal)
        aupr = metrics.auc(recall, precision)
        print('bootstrap_idx{:3d}|epoch {:3d}|time:{:5.2f}s|tn_ls {:5.6f}|aupr {:5.6f}|'.format(bootstrap_idx, epoch, (time.time() - round_start_time), train_normal_loss, aupr))

    # Only the graph of the final epoch enters the aggregate.
    G_predict_aggregate = G_predict_aggregate + G_predict[0]

G_predict_aggregate_normal = G_predict_aggregate / args.bootstrap_num
precision, recall, thresholds = metrics.precision_recall_curve(G_groundtruth_normal, G_predict_aggregate_normal.reshape(Data.m * Data.m))
aupr = metrics.auc(recall, precision)
print('aggregate aupr:', aupr)

# Context manager so the file handle is flushed and closed deterministically.
with open("G_predict_aggregate_normal_A60_0121.pkl", "wb") as pkl_file:
    pickle.dump(G_predict_aggregate_normal, pkl_file)

bootstrap_idx  0|epoch   0|time:20.75s|tn_ls 1.314089|aupr 0.026139|
bootstrap_idx  0|epoch   1|time:22.14s|tn_ls 1.001514|aupr 0.026744|
bootstrap_idx  0|epoch   2|time:21.88s|tn_ls 0.985632|aupr 0.039676|
bootstrap_idx  0|epoch   3|time:22.26s|tn_ls 0.964525|aupr 0.082864|
bootstrap_idx  0|epoch   4|time:23.30s|tn_ls 0.932782|aupr 0.136940|
bootstrap_idx  0|epoch   5|time:21.34s|tn_ls 0.895609|aupr 0.165290|
bootstrap_idx  0|epoch   6|time:21.45s|tn_ls 0.856999|aupr 0.188987|
bootstrap_idx  0|epoch   7|time:24.28s|tn_ls 0.826074|aupr 0.210035|
bootstrap_idx  0|epoch   8|time:8165.67s|tn_ls 0.803406|aupr 0.234614|
bootstrap_idx  0|epoch   9|time:23.97s|tn_ls 0.785142|aupr 0.255033|
bootstrap_idx  0|epoch  10|time:21.58s|tn_ls 0.769267|aupr 0.276013|
bootstrap_idx  0|epoch  11|time:26.46s|tn_ls 0.755983|aupr 0.289760|
bootstrap_idx  0|epoch  12|time:22.10s|tn_ls 0.744721|aupr 0.300089|
bootstrap_idx  0|epoch  13|time:22.47s|tn_ls 0.734312|aupr 0.312327|
bootstrap_idx  0|epoch  14|time:

bootstrap_idx  2|epoch  19|time:18.64s|tn_ls 0.605345|aupr 0.571974|
bootstrap_idx  2|epoch  20|time:19.56s|tn_ls 0.587689|aupr 0.598240|
bootstrap_idx  2|epoch  21|time:18.82s|tn_ls 0.571220|aupr 0.620854|
bootstrap_idx  2|epoch  22|time:20.07s|tn_ls 0.556033|aupr 0.638390|
bootstrap_idx  2|epoch  23|time:19.59s|tn_ls 0.542808|aupr 0.654588|
bootstrap_idx  2|epoch  24|time:18.77s|tn_ls 0.530461|aupr 0.667749|
bootstrap_idx  2|epoch  25|time:21.63s|tn_ls 0.519061|aupr 0.680130|
bootstrap_idx  2|epoch  26|time:18.91s|tn_ls 0.508932|aupr 0.688445|
bootstrap_idx  2|epoch  27|time:20.02s|tn_ls 0.499049|aupr 0.693584|
bootstrap_idx  2|epoch  28|time:18.52s|tn_ls 0.491118|aupr 0.701521|
bootstrap_idx  2|epoch  29|time:19.31s|tn_ls 0.483271|aupr 0.710105|
bootstrap_idx  2|epoch  30|time:19.80s|tn_ls 0.475743|aupr 0.716441|
bootstrap_idx  2|epoch  31|time:19.48s|tn_ls 0.468763|aupr 0.722135|
bootstrap_idx  2|epoch  32|time:18.95s|tn_ls 0.462189|aupr 0.729323|
bootstrap_idx  2|epoch  33|time:20

In [None]:
# Train the final normal model, seeded with the aggregated bootstrap graph.
major_graph = torch.from_numpy(G_predict_aggregate_normal).float()
args.num_class = 1
# eval() only resolves a module name that is hard-coded in `args`;
# never feed it external input.
model_normal = eval(args.model_normal_last).Model(args, Data, major_graph)
optim_normal = Optim.Optim(
    model_normal.parameters(), args.optim, args.lr_N, args.clip, weight_decay=args.weight_decay,
)

# Unshuffled split of the normal samples: the first trn_num for training,
# the remainder for validation.
trn_num = 900
indices_normal = np.arange(traning_samples[0])
trn_indices_normal = indices_normal[0:trn_num]
val_indices_normal = indices_normal[trn_num:]

X_trn_normal_spl = X_trn_org[0:traning_samples[0]][trn_indices_normal]
Y_trn_full_normal_spl = Y_trn_full_org[0:traning_samples[0]][trn_indices_normal]
Data_trn_full_normal_spl = [X_trn_normal_spl, Y_trn_full_normal_spl, Label_trn_org[0:traning_samples[0]][trn_indices_normal]]

X_val_normal_spl = X_trn_org[0:traning_samples[0]][val_indices_normal]
Y_val_full_normal_spl = Y_trn_full_org[0:traning_samples[0]][val_indices_normal]
Data_val_full_normal_spl = [X_val_normal_spl, Y_val_full_normal_spl, Label_trn_org[0:traning_samples[0]][val_indices_normal]]


for epoch in range(0, args.epochs_last):
    epoch_start_time = time.time()
    train_normal_loss = train_normal(Data, Data_trn_full_normal_spl, model_normal, criterion_1, optim_normal, args.batch_size)
    val_normal_loss = test_normal(Data, Data_val_full_normal_spl, model_normal, criterion_1, optim_normal, args.batch_size)
    print('Normal last|epoch{:3d}|time:{:5.2f}s|tn_ls {:5.6f}|vl_ls {:5.6f}'.format(epoch, (time.time() - epoch_start_time), train_normal_loss, val_normal_loss))

# Context manager so the file handle is flushed and closed deterministically.
# The whole model object is pickled, so loading it requires the same model
# code to be importable.
with open("model_normal_A60_0121.pkl", "wb") as pkl_file:
    pickle.dump(model_normal, pkl_file)

Normal last|epoch  0|time:21.77s|tn_ls 1.221932|vl_ls 0.990601
Normal last|epoch  1|time:21.01s|tn_ls 0.991298|vl_ls 0.976201
Normal last|epoch  2|time:20.78s|tn_ls 0.968938|vl_ls 0.942024
Normal last|epoch  3|time:21.61s|tn_ls 0.931368|vl_ls 0.918543
Normal last|epoch  4|time:21.97s|tn_ls 0.907393|vl_ls 0.897918
Normal last|epoch  5|time:24.60s|tn_ls 0.871677|vl_ls 0.859395
Normal last|epoch  6|time:23.73s|tn_ls 0.841219|vl_ls 0.833444
Normal last|epoch  7|time:24.88s|tn_ls 0.815676|vl_ls 0.804554
Normal last|epoch  8|time:23.19s|tn_ls 0.790033|vl_ls 0.781495
Normal last|epoch  9|time:23.10s|tn_ls 0.769550|vl_ls 0.766734
Normal last|epoch 10|time:23.92s|tn_ls 0.751053|vl_ls 0.750105
Normal last|epoch 11|time:24.26s|tn_ls 0.735547|vl_ls 0.731780
Normal last|epoch 12|time:22.97s|tn_ls 0.717362|vl_ls 0.715323
Normal last|epoch 13|time:23.80s|tn_ls 0.699317|vl_ls 0.701053
Normal last|epoch 14|time:24.39s|tn_ls 0.681273|vl_ls 0.685043
Normal last|epoch 15|time:25.19s|tn_ls 0.667670|vl_ls 0

In [None]:
# Number of noise sequences fed to the generator per round.
gen_num = traning_samples[1]
########################################################
########## initializing generator/classifier ###########
########################################################
# eval() only resolves module names that are hard-coded in `args`;
# never feed it external input.
args.num_class = len(data_paths)-1    ###### generator of GAN ######
model_generator = eval(args.model_generator).Model(args, Data, model_normal)
optim_generator = Optim.Optim(
    model_generator.parameters(), args.optim, args.lr_G, args.clip, weight_decay=args.weight_decay,
)

args.num_class = len(data_paths) + len(data_paths)-1  ###### classifier of GAN ######
model_classifier = eval(args.model_classifier).Model(args, Data, model_normal)
optim_classifier = Optim.Optim(
    model_classifier.parameters(), args.optim, args.lr_C, args.clip, weight_decay=args.weight_decay,
)


############################################
####### generator/classifier training ######
############################################
classifier_trn_loss_old = 1000000
tst_report_old = None

for epoch in range(args.epochs_GANs):
    epoch_start_time = time.time()

    for c_index in range(1):
        ########## preparing input for generator ##########
        trn_noise_gen = torch.randn((gen_num, args.generator_win, Data.m))
        Data_trn_full_normal_gen = [trn_noise_gen, None, None]

        ################ generator generating fake data for classifier ######################
        # Detached: the classifier update must not back-propagate into the
        # generator, and several mini-batches would otherwise try to walk
        # the same generator graph more than once.
        fake_data_all = model_generator(Data_trn_full_normal_gen, model_normal).detach()
        fake_data_X = fake_data_all[:, 0:args.classifier_win, :]  ### cut data
        # Target k of step t is the generated value at step t + k + 1.
        fake_data_Y = torch.zeros((fake_data_all.shape[0], args.classifier_win, args.classifier_prewin, Data.m))
        for pre_win_i in range(args.classifier_prewin):
            fake_data_Y[:, :, pre_win_i, :] = fake_data_all[:, (pre_win_i+1):(pre_win_i+1+args.classifier_win), :]
        # Fake class i occupies rows [i * gen_num, (i + 1) * gen_num).
        fake_data_label = torch.zeros((len(fake_data_X), len(data_paths)-1))
        for fake_class_i in range(len(data_paths)-1):
            fake_sampleidx_start = fake_class_i*gen_num
            fake_sampleidx_end = (fake_class_i+1)*gen_num
            fake_data_label[fake_sampleidx_start:fake_sampleidx_end, fake_class_i] = 1

        ########### preparing training input for classifier by merging real and fake data ############
        X_trn_merged = torch.cat((X_trn_org[:, 0:args.classifier_win, :], fake_data_X), dim=0)
        Y_trn_full_merged = torch.cat((Y_trn_full_org[:, 0:args.classifier_win, :, :], fake_data_Y), dim=0)
        # Label columns: the real classes first, then one column per fake class.
        Label_merged = torch.zeros((len(X_trn_merged), 2*len(data_paths)-1))   ### label
        Label_merged[0:len(Label_trn_org), 0:len(data_paths)] = Label_trn_org
        Label_merged[len(Label_trn_org):, len(data_paths):] = fake_data_label
        Data_trn_merge = [X_trn_merged, Y_trn_full_merged, Label_merged]
        ################## train and update classifier ###################
        classifier_trn_loss, trn_report = train_classifier(Data, Data_trn_merge, model_classifier, criterion_2, optim_classifier, args.batch_size, model_normal)
        trn_macro_f1 = trn_report['macro avg']['f1-score']

    ###################### update generator #############
    for g_index in range(1):

        model_generator.train()
        model_generator.zero_grad()

        trn_noise_gen = torch.randn((gen_num, args.generator_win, Data.m))
        Data_trn_full_normal_gen = [trn_noise_gen, None, None]
        # Not detached here: the generator loss must reach the generator.
        fake_data_all = model_generator(Data_trn_full_normal_gen, model_normal)
        fake_data_X = fake_data_all[:, 0:args.classifier_win, :]  ### cut data
        fake_data_Y = torch.zeros((fake_data_all.shape[0], args.classifier_win, args.classifier_prewin, Data.m))
        for pre_win_i in range(args.classifier_prewin):
            fake_data_Y[:, :, pre_win_i, :] = fake_data_all[:, (pre_win_i+1):(pre_win_i+1+args.classifier_win), :]

        ################## get regression feedback for generator from classifier ##########
        prediction_from_classifier = model_classifier([fake_data_X], model_normal)
        ################## get label loss of generator output and update generator ##########
        fake_data_Y_duplicate = fake_data_Y.clone()
        for class_idx in range(model_classifier.num_class-1):
            fake_data_Y_duplicate = torch.cat((fake_data_Y_duplicate, fake_data_Y), dim=0)
        fake_loss_all = criterion_2(prediction_from_classifier, fake_data_Y_duplicate)

        # Per-sample loss of every classifier block, turned into a soft
        # label by inverting and L1-normalising (small loss -> large score).
        feedback_label_tmp = torch.sum(fake_loss_all, dim=1)
        feedback_label_tmp = torch.sum(feedback_label_tmp, dim=1)
        feedback_label_tmp = torch.sum(feedback_label_tmp, dim=1)
        predict_label_soft = feedback_label_tmp.view(2*len(data_paths)-1, fake_data_X.shape[0]).transpose(0, 1)
        predict_label_soft_invert = 1/(predict_label_soft)#torch.exp(1/(predict_label_soft+1))
        predict_label_soft_norm = F.normalize(predict_label_soft_invert, p=1, dim=1)
        # NOTE(review): the row ranges below start at (class_idx + 1) * gen_num,
        # while the classifier step above places fake class i at rows
        # [i * gen_num, (i + 1) * gen_num). If the generator emits
        # num_class * gen_num rows, the first block gets target 0 and the
        # last slice is empty -- confirm the generator's output layout
        # before relying on these targets.
        fake_labels_groundtruth = torch.zeros((predict_label_soft_norm.shape))
        for class_idx in range(model_generator.num_class):
            start_idx_label = (class_idx+1)*gen_num
            end_idx_label = (class_idx+2)*gen_num
            fake_labels_groundtruth[start_idx_label:end_idx_label, class_idx+1] = 1
        # argmax already returns an int64 tensor; no torch.tensor(...) copy needed.
        fake_labels_groundtruth = torch.argmax(fake_labels_groundtruth, 1)
        generator_loss = criterion_3(predict_label_soft_norm, fake_labels_groundtruth)

        # Penalise generator graphs that are close to each other.
        # NOTE(review): LA.norm is NumPy, so this term is presumably a
        # constant with respect to autograd -- confirm that is intended.
        graph_reg = 0
        G_predict_gen, G_predict_gen_org = model_generator.predict_relationship(model_normal)
        num_graph = len(G_predict_gen)
        for graph_i in range(num_graph-1):
            for graph_j in range(graph_i+1, num_graph):
                graph_A = G_predict_gen[graph_i].reshape(Data.m*Data.m)
                graph_B = G_predict_gen[graph_j].reshape(Data.m*Data.m)
                graph_reg = graph_reg + 1/LA.norm(graph_A-graph_B)
        generator_loss = generator_loss + 0.01*graph_reg

        # retain_graph is kept as in the original; it looks unnecessary
        # because the graph is rebuilt every round -- TODO confirm and drop.
        generator_loss.backward(retain_graph=True)
        optim_generator.step()
        # Plain float for logging, so no reference to the graph is kept.
        generator_loss_sum = generator_loss.item()

    ##################### testing classifier ######################
    classifier_tst_loss, tst_report, MF, GM, PC, AR = evaluate_classifier(Data, Data_tst_full, model_classifier, criterion_2, optim_classifier, args.batch_size)
    tst_macro_f1 = tst_report['macro avg']['f1-score']

    classifier_trn_loss_old = classifier_trn_loss
    tst_report_old = tst_report
    print('GAN Rround {:3d}|T:{:5.2f}s|classifier trn loss {:5.4f}|generator loss {:5.4f}|testset MF1 {:5.4f}|'.format(epoch, (time.time() - epoch_start_time), classifier_trn_loss, generator_loss_sum, tst_macro_f1))
    print('MF1: ',MF, ' GM: ',GM, ' PC: ',PC, ' AR: ',AR)
  
