In [1]:
import os
import h5py
import time
import copy
import json

import numpy as np
import torch
import scipy
from sklearn.metrics import f1_score, matthews_corrcoef

from span import Span

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Args:
    """Attribute-style container for the run settings.

    Every key/value pair of ``dic`` is copied onto the instance as an
    attribute. ``data_file`` and ``output_file`` are assigned afterwards, so
    the explicit arguments take precedence over same-named keys in ``dic``.
    """

    def __init__(self, data_file, output_file, dic):
        # Configuration entries go on first ...
        for name, value in dic.items():
            setattr(self, name, value)
        # ... and the explicit paths last, so they win on a name clash.
        self.output_file = output_file
        self.data_file = data_file

# Parse the settings stored in config.json and bundle them, together with the
# input/output paths, into a single Args object used by the cells below.
with open('config.json', 'r') as openfile:
    dic = json.loads(openfile.read())

args = Args(data_file='simdata/simdata.h5', output_file='test.csv', dic=dic)

In [4]:
# Read every input array from the HDF5 file. The context manager guarantees
# the file handle is released even if one of the dataset reads raises (for
# example because a required key is missing), which a manual close() does not.
with h5py.File(args.data_file, 'r') as h5:
    Y = np.array(h5['genes'])
    # 'group' and 'batch_matrix' are optional datasets; fall back to None.
    label = np.array(h5['group']) if 'group' in h5 else None
    rho = np.array(h5['rho'])
    Z_neighbor_idx = np.array(h5['Z_neighbor_idx'])
    batch_matrix = np.array(h5['batch_matrix']) if 'batch_matrix' in h5 else None

# Number of covariate columns handed to the model (0 when no batch matrix).
n_cov = batch_matrix.shape[1] if batch_matrix is not None else 0

# Per-row totals of Y, rescaled so that their mean is 1.
# NOTE(review): presumably used as size factors by Span -- confirm.
S = Y.sum(axis=1, keepdims=True)
S = S / np.mean(S)
        
# Build the Span model from the loaded arrays and move it to the configured
# device. The three leading positional arguments are the first dimension of Y
# and both dimensions of rho.
# NOTE(review): B = 10 is a hard-coded setting whose meaning is defined by
# Span -- confirm before changing.
model = Span(
    Y.shape[0],
    rho.shape[0],
    rho.shape[1],
    rho,
    Y,
    S,
    Z_neighbor_idx,
    n_cov=n_cov,
    cov_matrix=batch_matrix,
    batch_size=args.batch_size,
    B=10,
).to(args.device)

# Pre-training stage; all iteration limits and tolerances come from the config.
model.pre_train_model(
    max_iter_em=args.max_iter_em_pretrain,
    min_iter_adam=args.min_iter_adam_pretrain,
    max_iter_adam=args.max_iter_adam_pretrain,
    lr_adam=args.lr_adam,
    rel_tol_adam=args.rel_tol_adam,
    y_true=label,
)

# Main training stage. Unlike pre-training, no rel_tol_adam is passed here.
model.train_model(
    max_iter_em=args.max_iter_em_train,
    min_iter_adam=args.min_iter_adam_train,
    max_iter_adam=args.max_iter_adam_train,
    lr_adam=args.lr_adam,
    y_true=label,
)
    


epoch 0 loss: 434.37563184737627
pretrain acc,  0.8111326502353918

epoch 0 acc,  0.8518415951260039
epoch 1 acc,  0.8723345333702576
epoch 2 acc,  0.8836887288839657
epoch 3 acc,  0.8958737191913597
epoch 4 acc,  0.9041816671282193
epoch 5 acc,  0.9135973414566602
epoch 6 acc,  0.9210744945998338
epoch 7 acc,  0.9313209637219607
epoch 8 acc,  0.9382442536693437
epoch 9 acc,  0.9448906120188314
epoch 10 acc,  0.9490445859872612
epoch 11 acc,  0.9515369703683191
epoch 12 acc,  0.9540293547493769
epoch 13 acc,  0.9565217391304348
epoch 14 acc,  0.9576294655220161
epoch 15 acc,  0.9598449183051786
epoch 16 acc,  0.9623373026862365
epoch 17 acc,  0.9651066186651897
epoch 18 acc,  0.9670451398504569
epoch 19 acc,  0.971476045416782
epoch 20 acc,  0.973137635004154
epoch 21 acc,  0.9753530877873166
epoch 22 acc,  0.9761838825810025
epoch 23 acc,  0.979230130157851
epoch 24 acc,  0.9808917197452229
epoch 25 acc,  0.9822763777346996
epoch 26 acc,  0.9833841041262809
epoch 27 acc,  0.9842148989

In [9]:
# Macro-averaged F1 between the reference labels and the model's current
# assignments (model.Z_current), reported to three decimals.
f1 = f1_score(y_true=label, y_pred=model.Z_current, average='macro')
print('F1: ', np.round(f1, 3))

F1:  0.985


In [10]:
# Matthews correlation coefficient between the reference labels and the
# model's current assignments (model.Z_current), reported to three decimals.
mcc = matthews_corrcoef(y_true=label, y_pred=model.Z_current)
print('MCC: ', np.round(mcc, 3))

MCC:  0.984


In [None]:
# Write the model's current assignments to the configured output file as
# comma-delimited text (np.savetxt's default number format is used).
np.savetxt(fname=args.output_file, X=model.Z_current, delimiter=",")