In [1]:
import os
import h5py
import time
import copy
import json

import numpy as np
import torch
import scipy
from sklearn.metrics import f1_score, matthews_corrcoef

from span import Span

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Minimal attribute namespace for the run configuration.

    Each key/value pair of ``dic`` is exposed as an attribute. ``data_file`` and
    ``output_file`` are assigned last, so they override any same-named keys that
    appear in ``dic``.

    Args:
        data_file: path to the input data file.
        output_file: path the results are written to.
        dic: mapping of additional settings (e.g. loaded from a JSON config).
    """

    def __init__(self, data_file, output_file, dic):
        for name in dic:
            setattr(self, name, dic[name])
        # Explicit constructor arguments take precedence over the config mapping.
        self.data_file, self.output_file = data_file, output_file

# Run settings: hyper-parameters come from config.json, while the input data and
# output paths are set explicitly here. Args assigns data_file/output_file after
# copying the config keys, so these paths win over same-named keys in the config.
CONFIG_PATH = 'config.json'
DATA_FILE = 'real_data_demo/dlpfc_151569_r.h5'
OUTPUT_FILE = 'results.csv'

# JSON text is UTF-8; pass the encoding explicitly instead of relying on the
# platform's locale default (which differs on e.g. Windows).
with open(CONFIG_PATH, 'r', encoding='utf-8') as openfile:
    dic = json.load(openfile)

args = Args(DATA_FILE, OUTPUT_FILE, dic)

In [3]:
# Load all model inputs from the HDF5 file. The context manager guarantees the
# file is closed even if reading one of the datasets raises. np.array copies each
# dataset into memory, so the arrays stay valid after the file is closed.
with h5py.File(args.data_file, 'r') as h5:
    Y = np.array(h5['genes'])
    # Ground-truth labels are optional; they are forwarded to training as y_true
    # and used by the evaluation cells further down.
    label = np.array(h5['group']) if 'group' in h5 else None
    rho = np.array(h5['rho'])
    Z_neighbor_idx = np.array(h5['Z_neighbor_idx'])
    # Optional covariate matrix; None when the file does not store one.
    batch_matrix = np.array(h5['batch_matrix']) if 'batch_matrix' in h5 else None

# Number of covariate columns handed to the model (0 when there is no batch matrix).
n_cov = batch_matrix.shape[1] if batch_matrix is not None else 0

# Row totals of Y rescaled to have mean 1
# (presumably per-observation library-size factors — confirm against Span).
S = Y.sum(axis=1, keepdims=True)
S = S / np.mean(S)
        
# Build the model from the loaded inputs and move it to the configured device.
# The first seven arguments are positional; their meaning is defined by span.Span
# (NOTE(review): confirm the order Y.shape[0], rho.shape[0], rho.shape[1] there).
model = Span(
    Y.shape[0],
    rho.shape[0],
    rho.shape[1],
    rho,
    Y,
    S,
    Z_neighbor_idx,
    n_cov=n_cov,
    cov_matrix=batch_matrix,
    batch_size=args.batch_size,
    B=10,  # NOTE(review): hard-coded value; consider moving it into config.json
).to(args.device)

# Stage 1: pre-training. `label` (None when the file has no 'group' dataset) is
# passed as y_true.
model.pre_train_model(
    max_iter_em=args.max_iter_em_pretrain,
    min_iter_adam=args.min_iter_adam_pretrain,
    max_iter_adam=args.max_iter_adam_pretrain,
    lr_adam=args.lr_adam,
    rel_tol_adam=args.rel_tol_adam,
    y_true=label,
)

# Stage 2: main training.
# NOTE(review): rel_tol_adam is passed to pre_train_model but not to train_model;
# confirm that this asymmetry is intended.
model.train_model(
    max_iter_em=args.max_iter_em_train,
    min_iter_adam=args.min_iter_adam_train,
    max_iter_adam=args.max_iter_adam_train,
    lr_adam=args.lr_adam,
    y_true=label,
)

epoch 0 loss: 47.5450983253711
pretrain acc,  0.4818355640535373

epoch 0 acc,  0.6866976236001092
epoch 1 acc,  0.7508877355913685
epoch 2 acc,  0.7787489756897022
epoch 3 acc,  0.784758262769735
epoch 4 acc,  0.7913138486752254
epoch 5 acc,  0.7921332969134116
epoch 6 acc,  0.7962305381043431
epoch 7 acc,  0.7959573886916144
epoch 8 acc,  0.8000546298825457
epoch 9 acc,  0.8000546298825457
epoch 10 acc,  0.801966675771647
epoch 11 acc,  0.8027861240098334
epoch 12 acc,  0.805517618137121
epoch 13 acc,  0.805517618137121
epoch 14 acc,  0.8074296640262223
epoch 15 acc,  0.8074296640262223
epoch 16 acc,  0.8098880087407813
epoch 17 acc,  0.81016115815351
epoch 18 acc,  0.8123463534553401
epoch 19 acc,  0.8123463534553401
epoch 20 acc,  0.8134389511062551
epoch 21 acc,  0.8123463534553401
epoch 22 acc,  0.8134389511062551
epoch 23 acc,  0.8123463534553401
epoch 24 acc,  0.8137121005189839
epoch 25 acc,  0.8131658016935264
epoch 26 acc,  0.8150778475826277
epoch 27 acc,  0.814531548757170

In [4]:
# The evaluation cells need ground-truth labels, but the data file may not
# provide a 'group' dataset. Without this guard, np.max(None) does not fail
# clearly, and the mask silently degenerates into a scalar instead of an array.
if label is None:
    raise ValueError(
        "No 'group' dataset in the data file: cannot evaluate without ground-truth labels."
    )

# Exclude observations whose label equals the largest label value before scoring.
# NOTE(review): presumably the largest code marks an unannotated/"other" group — confirm.
mask = (label != np.max(label))

In [5]:
# Fraction of retained observations whose predicted assignment matches the label.
hits = label[mask] == model.Z_current[mask]
print('Accuracy: ', np.round(np.mean(hits), decimals=3))

Accuracy:  0.821


In [6]:
# Macro-averaged F1 on the retained observations (unweighted mean over classes).
f1 = f1_score(y_true=label[mask], y_pred=model.Z_current[mask], average='macro')
print('F1: ', np.round(f1, decimals=3))

F1:  0.739


In [7]:
# Matthews correlation coefficient on the retained observations.
mcc = matthews_corrcoef(y_true=label[mask], y_pred=model.Z_current[mask])
print('MCC: ', np.round(mcc, decimals=3))

MCC:  0.721


In [None]:
# Write the final assignments to the output CSV (one value per line for a 1-D array).
# fmt='%d' writes them as integers; np.savetxt's default fmt='%.18e' would emit
# e.g. "3.000000000000000000e+00". Assumes Z_current holds integer-valued IDs.
np.savetxt(args.output_file, model.Z_current, fmt='%d', delimiter=",")