In [1]:
import torch
import utils
import numpy as n
import scanpy as sc
import matplotlib.pyplot as plt
import pandas as pd
import anndata as ad
from torch import nn
from trainer import concrete_trainer, MLP_trainer
from network import MLP
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the AnnData (.h5ad) file that is read into `adata` further down.
h5ad_path = "../../data/filtered_hg19.h5ad"
# Device selection is delegated to the project helper `utils.get_device()`.
device = utils.get_device()
# To force CPU execution instead, use: device = torch.device("cpu")
# Bare last expression so the chosen device is shown as the cell output.
device

device(type='cuda', index=1)

In [3]:
def anndata_load(file_path):
    """Read a 10x matrix-market dataset and return it as an AnnData object.

    Args:
        file_path: Location handed straight to ``sc.read_10x_mtx`` (the place
            where the ``.mtx`` data lives).

    Returns:
        The object produced by ``sc.read_10x_mtx`` with variables named by
        gene symbol, after ``var_names_make_unique()`` has been applied to it.
    """
    loaded = sc.read_10x_mtx(file_path, var_names='gene_symbols')
    # Gene symbols can repeat; de-duplicate the variable names in place.
    loaded.var_names_make_unique()
    return loaded

In [41]:
# Load the dataset stored at `h5ad_path` into an AnnData object.
adata = sc.read_h5ad(h5ad_path )
# Bare last expression: show the AnnData summary (shape, obs/var/uns keys).
adata

AnnData object with n_obs × n_vars = 2638 × 3000
    obs: 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'gene_ids', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg', 'log1p'

In [169]:
# Number of highly variable genes to select; it is passed as `n_top_genes`
# and also becomes the input width of the first MLP (`input_dim = n_top`).
n_top = 256

In [170]:
# Settings passed to `MLP(...)` and `MLP_trainer(...)` in the cells below.
# Input width of the first MLP: one input per selected highly variable gene.
# (`input_dim` is rebound later to the width of the feature-subset AnnData.)
input_dim = n_top
# Passed to MLP as `hidden_dim`.
hidden_dim = [128, 256]
# Output width: the number of variables (columns) in `adata`.
out_dim = adata.shape[1]
# Optimisation settings forwarded to the trainer.
lr  = 1e-3
epochs = 256
weight_decay = 0
# Mean-squared-error criterion handed to the trainer.
loss = nn.MSELoss()

In [171]:

# Flag the `n_top` most variable genes. This writes the boolean column
# adata.var['highly_variable'] in place on `adata`.
# NOTE(review): flavor='seurat_v3' is documented to expect raw count data, but
# adata.uns carries a 'log1p' entry -- confirm that adata.X still holds counts.
sc.pp.highly_variable_genes(adata, n_top_genes=n_top, flavor='seurat_v3')

# NOTE(review): no seed/random_state is passed here; whether this split is
# reproducible depends on utils.train_test_split -- confirm.
adata_train, adata_test = utils.train_test_split(adata, test_size=0.2)

# Look the gene mask up once and apply the same columns to both splits.
hvg_mask = adata.var['highly_variable']
cr_train = adata_train[:, hvg_mask]
cr_test = adata_test[:, hvg_mask]

# The MLP built later uses input_dim = n_top, so fail early and clearly if the
# number of selected genes does not match.
assert cr_train.shape[1] == n_top, (
    f"expected {n_top} highly variable genes, got {cr_train.shape[1]}"
)

# Both loaders use the same batch size and shuffle=False, so batch i of the
# reduced inputs and batch i of the full-profile data cover the same rows.
cr_traindata = utils.data_loader(cr_train, batch_size=128, shuffle=False)
input_train = utils.data_loader(adata_train, batch_size=128, shuffle=False)

# Bare last expression: show the (cells, genes) shape of the reduced train set.
cr_train.shape



(2111, 256)

In [172]:
# Build a fresh MLP using the current `input_dim`, `hidden_dim` and `out_dim`.
net = MLP(input_dim=input_dim, hidden_dim=hidden_dim, out_dim=out_dim,device=device)
# Train with the highly-variable-gene loader and the full-profile loader.
# Presumably the first loader supplies inputs and the second the regression
# targets -- confirm against trainer.MLP_trainer.
MLP_trainer(net, loss,cr_traindata,input_train , epochs,lr, weight_decay, device)

Epoch: 1 ; Loss: 0.34364; Time: 0.05 s
Epoch: 2 ; Loss: 0.20232; Time: 0.08 s
Epoch: 3 ; Loss: 0.17731; Time: 0.12 s
Epoch: 4 ; Loss: 0.16520; Time: 0.16 s
Epoch: 5 ; Loss: 0.16017; Time: 0.20 s
Epoch: 6 ; Loss: 0.15790; Time: 0.24 s
Epoch: 7 ; Loss: 0.15658; Time: 0.28 s
Epoch: 8 ; Loss: 0.15540; Time: 0.32 s
Epoch: 9 ; Loss: 0.15412; Time: 0.36 s
Epoch: 10 ; Loss: 0.15288; Time: 0.40 s
Epoch: 11 ; Loss: 0.15186; Time: 0.43 s
Epoch: 12 ; Loss: 0.15097; Time: 0.47 s
Epoch: 13 ; Loss: 0.15015; Time: 0.51 s
Epoch: 14 ; Loss: 0.14934; Time: 0.55 s
Epoch: 15 ; Loss: 0.14856; Time: 0.59 s
Epoch: 16 ; Loss: 0.14779; Time: 0.63 s
Epoch: 17 ; Loss: 0.14707; Time: 0.66 s
Epoch: 18 ; Loss: 0.14639; Time: 0.70 s
Epoch: 19 ; Loss: 0.14571; Time: 0.74 s
Epoch: 20 ; Loss: 0.14505; Time: 0.78 s
Epoch: 21 ; Loss: 0.14443; Time: 0.81 s
Epoch: 22 ; Loss: 0.14381; Time: 0.85 s
Epoch: 23 ; Loss: 0.14317; Time: 0.89 s
Epoch: 24 ; Loss: 0.14257; Time: 0.93 s
Epoch: 25 ; Loss: 0.14195; Time: 0.97 s
Epoch: 26

In [173]:
# Held-out evaluation of the model trained on highly variable genes.
# `.toarray()` densifies the sparse matrices before wrapping them in tensors;
# it replaces the deprecated `.A` shorthand, which newer SciPy sparse arrays
# no longer provide.
x_val = torch.tensor(cr_test.X.toarray()).to(device)
# `y_val` (the full test-split matrix) is reused by the later subset evaluation.
y_val = torch.tensor(adata_test.X.toarray()).to(device)
# Bare last expression: the value returned by `net.validate` is the cell output
# (presumably the validation loss -- confirm in network.MLP).
net.validate(x_val,y_val)

0.15457934141159058

In [174]:
# Load the AnnData stored at ./output/concrete_autoencoder.h5ad.
# Presumably this holds the features chosen by a concrete autoencoder run that
# happened outside the cells shown here -- confirm where the file is produced.
subset = sc.read_h5ad('./output/concrete_autoencoder.h5ad')
# Bare last expression: show the AnnData summary.
subset

AnnData object with n_obs × n_vars = 2638 × 32

In [175]:
# Rebind the MLP input width to the number of columns in `subset`.
input_dim = subset.shape[1]
# NOTE(review): this is a second, independent call to utils.train_test_split.
# The targets used with these inputs later (`input_train`, `y_val`) come from
# the split of `adata`, so rows only line up if the helper splits both objects
# identically -- confirm, or derive both splits from one set of indices.
sub_train, sub_test = utils.train_test_split(subset, test_size=0.2)
# Same batch size and shuffle=False as the other loaders in this notebook.
sub_traindata = utils.data_loader(sub_train, batch_size=128, shuffle=False)

In [176]:
# Rebinds `net` to a new MLP sized for the feature subset; the previously
# trained model is no longer reachable under this name.
net = MLP(input_dim=input_dim, hidden_dim=hidden_dim, out_dim=out_dim,device=device)
# Same trainer call as before, but fed by the subset loader. `input_train` is
# still the full-profile loader built from the split of `adata`.
MLP_trainer(net, loss,sub_traindata,input_train,epochs,lr, weight_decay, device)

Epoch: 1 ; Loss: 0.39356; Time: 0.04 s
Epoch: 2 ; Loss: 0.25454; Time: 0.08 s
Epoch: 3 ; Loss: 0.21202; Time: 0.11 s
Epoch: 4 ; Loss: 0.20018; Time: 0.15 s
Epoch: 5 ; Loss: 0.19748; Time: 0.19 s
Epoch: 6 ; Loss: 0.19684; Time: 0.22 s
Epoch: 7 ; Loss: 0.19649; Time: 0.26 s
Epoch: 8 ; Loss: 0.19624; Time: 0.30 s
Epoch: 9 ; Loss: 0.19603; Time: 0.34 s
Epoch: 10 ; Loss: 0.19583; Time: 0.37 s
Epoch: 11 ; Loss: 0.19566; Time: 0.41 s
Epoch: 12 ; Loss: 0.19549; Time: 0.45 s
Epoch: 13 ; Loss: 0.19533; Time: 0.48 s
Epoch: 14 ; Loss: 0.19518; Time: 0.52 s
Epoch: 15 ; Loss: 0.19503; Time: 0.56 s
Epoch: 16 ; Loss: 0.19487; Time: 0.59 s
Epoch: 17 ; Loss: 0.19472; Time: 0.63 s
Epoch: 18 ; Loss: 0.19457; Time: 0.66 s
Epoch: 19 ; Loss: 0.19442; Time: 0.70 s
Epoch: 20 ; Loss: 0.19427; Time: 0.74 s
Epoch: 21 ; Loss: 0.19411; Time: 0.78 s
Epoch: 22 ; Loss: 0.19396; Time: 0.81 s
Epoch: 23 ; Loss: 0.19380; Time: 0.85 s
Epoch: 24 ; Loss: 0.19364; Time: 0.89 s
Epoch: 25 ; Loss: 0.19348; Time: 0.93 s
Epoch: 26

In [150]:
# Evaluate the subset-trained model on the held-out subset rows.
# No densification step is applied here -- presumably `sub_test.X` is already a
# dense array; confirm.
x_val = torch.tensor(sub_test.X).to(device)
# `y_val` is not defined in this cell: it is the tensor built from `adata_test`
# in the earlier validation cell, which therefore has to run first.
# NOTE(review): this cell's recorded execution count (150) is lower than that of
# the training cell above (176), so the stored output may not belong to the
# model trained there -- re-run top to bottom to confirm the number.
net.validate(x_val,y_val)

0.2084445059299469

In [None]:
# Read a CSV of results and keep the first 32 entries of its `gene_idx` column.
# NOTE(review): this rebinds `subset` (an AnnData in earlier cells) to a pandas
# DataFrame; a distinct name would avoid confusion when cells are re-run.
subset = pd.read_csv('../result/filtered_set384_1.csv')
# 32 matches the column count of the concrete-autoencoder subset shown earlier
# -- presumably intentional; confirm.
subsubset = subset['gene_idx'].values[:32]

In [None]:
# Restrict `adata` to the columns listed in `subsubset`; rebinds `subset` again.
# Presumably `gene_idx` holds column positions (or variable names) valid for
# `adata` -- confirm against the CSV.
subset = adata[:,subsubset]

In [None]:
# Show the summary of the current `subset` object as the cell output.
subset

In [None]:
# Exploratory inspection of the matrix behind `sub_test`; this cell has no
# recorded execution.
sub_test.X