In [1]:
import os
import h5py
import time
import copy
import json

import numpy as np
import torch
import scipy
from sklearn.metrics import f1_score, matthews_corrcoef

from span import Span

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Attribute-style container for the run settings.

    Every key/value pair of ``dic`` becomes an attribute of the instance.
    ``data_file`` and ``output_file`` are assigned afterwards, so they take
    precedence over any same-named keys present in ``dic``.

    Parameters
    ----------
    data_file : str
        Input file path, stored as ``self.data_file``.
    output_file : str
        Output file path, stored as ``self.output_file``.
    dic : dict
        Extra settings (here: the parsed ``config.json``), one attribute per key.
    """

    def __init__(self, data_file, output_file, dic):
        for attr_name, attr_value in dic.items():
            setattr(self, attr_name, attr_value)
        # Set last on purpose: the explicit paths override config values.
        self.data_file, self.output_file = data_file, output_file

# Run configuration: hyper-parameters come from config.json. The input and
# output paths are fixed here and override any same-named keys in the config.
CONFIG_FILE = 'config.json'
DATA_FILE = 'real_data_demo/dlpfc_151673.h5'
OUTPUT_FILE = 'results.csv'

# JSON text is UTF-8 (RFC 8259), so name the encoding explicitly instead of
# relying on the platform's locale-dependent default.
with open(CONFIG_FILE, 'r', encoding='utf-8') as openfile:
    dic = json.load(openfile)

args = Args(DATA_FILE, OUTPUT_FILE, dic)

In [3]:
# Load every model input from the HDF5 file. The context manager releases the
# file handle even if a required dataset is missing and a read raises KeyError.
# np.array copies each dataset into memory, so the arrays stay valid after close.
with h5py.File(args.data_file, 'r') as h5:
    Y = np.array(h5['genes'])  # presumably (n_spots, n_genes) counts -- TODO confirm orientation
    # Ground-truth labels are optional; None disables any label-based evaluation.
    label = np.array(h5['group']) if 'group' in h5 else None
    rho = np.array(h5['rho'])
    Z_neighbor_idx = np.array(h5['Z_neighbor_idx'])  # presumably neighbor indices per spot -- TODO confirm
    # Optional covariate (batch) matrix; None when the file has none.
    batch_matrix = np.array(h5['batch_matrix']) if 'batch_matrix' in h5 else None
    
# Number of covariate columns handed to the model (0 without a batch matrix).
n_cov = 0 if batch_matrix is None else batch_matrix.shape[1]

# Per-row totals of Y, rescaled so that their mean is exactly 1.
library_size = Y.sum(axis=1, keepdims=True)
S = library_size / library_size.mean()
        
# Build the Span model and move it to the configured device. The positional
# arguments follow Span's constructor order (defined in span.py, not shown here).
# NOTE(review): B=10 is hard-coded; consider moving it into config.json.
model = Span(
    Y.shape[0], rho.shape[0], rho.shape[1], rho, Y, S, Z_neighbor_idx,
    n_cov=n_cov,
    cov_matrix=batch_matrix,
    batch_size=args.batch_size,
    B=10,
).to(args.device)

# Stage 1: pre-training. label may be None when the file has no ground truth.
pretrain_kwargs = dict(
    max_iter_em=args.max_iter_em_pretrain,
    min_iter_adam=args.min_iter_adam_pretrain,
    max_iter_adam=args.max_iter_adam_pretrain,
    lr_adam=args.lr_adam,
    rel_tol_adam=args.rel_tol_adam,
    y_true=label,
)
model.pre_train_model(**pretrain_kwargs)

# Stage 2: main training, reusing the same Adam learning rate.
train_kwargs = dict(
    max_iter_em=args.max_iter_em_train,
    min_iter_adam=args.min_iter_adam_train,
    max_iter_adam=args.max_iter_adam_train,
    lr_adam=args.lr_adam,
    y_true=label,
)
model.train_model(**train_kwargs)
    


epoch 0 loss: 375.59732894395313
pretrain acc,  0.5954932673811487

epoch 0 acc,  0.6985435559219566
epoch 1 acc,  0.7175048090134653
epoch 2 acc,  0.7230008244023083
epoch 3 acc,  0.7282220390217092
epoch 4 acc,  0.731244847485573
epoch 5 acc,  0.7326188513327837
epoch 6 acc,  0.7334432536411102
epoch 7 acc,  0.7339928551799945
epoch 8 acc,  0.7339928551799945
epoch 9 acc,  0.7339928551799945
epoch 10 acc,  0.7339928551799945
epoch 11 acc,  0.7339928551799945
epoch 12 acc,  0.7337180544105524
epoch 13 acc,  0.7337180544105524
epoch 14 acc,  0.7337180544105524
epoch 15 acc,  0.7337180544105524
epoch 16 acc,  0.7337180544105524
epoch 17 acc,  0.7337180544105524
epoch 18 acc,  0.7337180544105524
epoch 19 acc,  0.7337180544105524
epoch 20 acc,  0.7337180544105524
epoch 21 acc,  0.7337180544105524
epoch 22 acc,  0.7337180544105524
epoch 23 acc,  0.7337180544105524
epoch 24 acc,  0.7337180544105524
epoch 25 acc,  0.7337180544105524
epoch 26 acc,  0.7337180544105524
epoch 27 acc,  0.73371805

In [6]:
# Label value excluded from the evaluation metrics below.
# NOTE(review): presumably 7 marks unannotated spots in this dataset -- confirm against the data source.
EXCLUDED_LABEL = 7

# Without ground truth, `label != 7` is the scalar True. The accuracy cell would
# then silently print 0.0, and the F1/MCC cells would fail with an obscure
# TypeError, so stop here with a clear message instead.
if label is None:
    raise ValueError("No ground-truth 'group' labels in the input file; cannot evaluate.")
mask = label != EXCLUDED_LABEL

In [9]:
# Fraction of evaluated (masked) spots whose predicted label matches the truth.
is_correct = label == model.Z_current
accuracy = np.mean(is_correct[mask])
print('Accuracy: ', np.round(accuracy, 3))

Accuracy:  0.739


In [10]:
# Macro-averaged F1 (unweighted mean of per-class F1) on the evaluated spots.
y_true_eval = label[mask]
y_pred_eval = model.Z_current[mask]
f1 = f1_score(y_true_eval, y_pred_eval, average='macro')
print('F1: ', np.round(f1, 3))

F1:  0.734


In [11]:
# Matthews correlation coefficient between true and predicted labels on the evaluated spots.
mcc = matthews_corrcoef(y_true=label[mask], y_pred=model.Z_current[mask])
print('MCC: ', np.round(mcc, 3))

MCC:  0.704


In [None]:
# Write the predicted labels to CSV, one value per line. np.savetxt's default
# fmt ('%.18e') would render integer labels as '3.000000000000000000e+00', so
# integer arrays are written with '%d'. Any other dtype keeps the default format.
predicted_labels = np.asarray(model.Z_current)
label_fmt = '%d' if np.issubdtype(predicted_labels.dtype, np.integer) else '%.18e'
np.savetxt(args.output_file, predicted_labels, delimiter=",", fmt=label_fmt)