In [1]:
import os
import h5py
import time
import copy
import json

import numpy as np
import torch
import scipy
from sklearn.metrics import f1_score, matthews_corrcoef

from span import Span

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Simple attribute container for run configuration.

    Each key/value pair of ``dic`` is attached as an attribute. ``data_file``
    and ``output_file`` are assigned last, so they win over any entries of the
    same name that ``dic`` may contain.
    """

    def __init__(self, data_file, output_file, dic):
        for name, value in dic.items():
            setattr(self, name, value)
        # Explicit constructor arguments take precedence over config entries.
        self.data_file, self.output_file = data_file, output_file

# Load run hyper-parameters from config.json; every top-level key becomes an
# attribute of `args` (e.g. args.batch_size, args.lr_adam used below).
# JSON text is UTF-8 by specification, so decode explicitly rather than relying
# on the platform's default locale encoding (which differs on Windows).
with open('config.json', 'r', encoding='utf-8') as openfile:
    dic = json.load(openfile)

args = Args('real_data_demo/amb.h5', 'results.csv', dic)

In [3]:
# Read all inputs from the HDF5 file. The context manager guarantees the file
# handle is closed even if one of the reads fails.
# 'group' (ground-truth labels) and 'batch_matrix' (covariates) are optional.
with h5py.File(args.data_file, 'r') as h5:
    Y = np.array(h5['genes'])  # NOTE(review): rows are treated as observations below (per-row sums) — confirm layout
    label = np.array(h5['group']) if 'group' in h5 else None
    rho = np.array(h5['rho'])
    Z_neighbor_idx = np.array(h5['Z_neighbor_idx'])
    batch_matrix = np.array(h5['batch_matrix']) if 'batch_matrix' in h5 else None

# Number of covariate columns passed to the model (0 when no batch matrix).
n_cov = batch_matrix.shape[1] if batch_matrix is not None else 0

# Per-row size factors: row totals rescaled so their mean is 1.
S = Y.sum(axis=1, keepdims=True)
S = S / np.mean(S)

# NOTE(review): B=10 is a Span hyper-parameter whose meaning is defined in the
# `span` module — consider moving it into config.json alongside the others.
model = Span(Y.shape[0], rho.shape[0], rho.shape[1], rho, Y, S, Z_neighbor_idx,
             n_cov=n_cov, cov_matrix=batch_matrix,
             batch_size=args.batch_size, B=10).to(args.device)

# Pre-training stage; y_true (possibly None) is forwarded so the model can
# report accuracy against ground truth when labels are available.
model.pre_train_model(max_iter_em=args.max_iter_em_pretrain,
                      min_iter_adam=args.min_iter_adam_pretrain,
                      max_iter_adam=args.max_iter_adam_pretrain,
                      lr_adam=args.lr_adam,
                      rel_tol_adam=args.rel_tol_adam,
                      y_true=label)

# Main training stage with its own iteration budgets from the config.
model.train_model(max_iter_em=args.max_iter_em_train,
                  min_iter_adam=args.min_iter_adam_train,
                  max_iter_adam=args.max_iter_adam_train,
                  lr_adam=args.lr_adam,
                  y_true=label)


epoch 0 loss: 399.2122543601135
pretrain acc,  0.8710247349823321

epoch 0 acc,  0.8847173144876325
epoch 1 acc,  0.8860424028268551
epoch 2 acc,  0.8878091872791519
epoch 3 acc,  0.8878091872791519
epoch 4 acc,  0.8882508833922261
epoch 5 acc,  0.8878091872791519
epoch 6 acc,  0.8895759717314488
epoch 7 acc,  0.8891342756183745
epoch 8 acc,  0.890017667844523


In [6]:
# Macro-averaged F1 of the model's final assignments against ground truth.
# Labels are optional in the data file (see the loading cell), so skip the
# metric instead of crashing when they are absent.
# NOTE(review): compares raw label values; no permutation matching between
# model.Z_current and `group` is done here — confirm the encodings align.
if label is None:
    print('F1: skipped (no ground-truth labels in data file)')
else:
    f1 = f1_score(label, model.Z_current, average='macro')
    print('F1: ', np.round(f1, 3))

F1:  0.891


In [7]:
# Matthews correlation coefficient of the final assignments against ground
# truth; skipped when the data file provided no 'group' labels.
# NOTE(review): like the F1 cell, assumes model.Z_current uses the same label
# encoding as `group`.
if label is None:
    print('MCC: skipped (no ground-truth labels in data file)')
else:
    mcc = matthews_corrcoef(label, model.Z_current)
    print('MCC: ', np.round(mcc, 3))

MCC:  0.875
