In [1]:
import scanpy as sc
import numpy as np
import torch

from utils import add_annotations
from model import EnrichClassifier
from train import train_classifier

In [2]:
# Dataset download: https://drive.google.com/open?id=1-N7wPpYUf_QcG5566WVZlaxVC90M7NNE
# The path is relative, so the file must sit in the notebook's working directory.
COUNTS_FILE = 'kang_count.h5ad'
adata = sc.read(COUNTS_FILE)

In [3]:
# Attach Reactome pathway annotations to adata; they are stored in adata.varm['I'].
# Pathways with fewer than 13 genes are filtered out.
add_annotations(
    adata,
    files='c2.cp.reactome.v4.0.symbols.gmt',
    min_genes=13,
)

In [4]:
# Keep only genes that belong to at least one pathway, i.e. whose row in the
# annotation matrix adata.varm['I'] has a positive sum.
select_genes = adata.varm['I'].sum(axis=1) > 0
# Subset through the public AnnData indexing API instead of the private
# `_inplace_subset_var` helper; `.copy()` turns the resulting view into a
# standalone object so later cells do not operate on a view.
adata = adata[:, select_genes].copy()

In [5]:
# Fix the RNG seeds before the model is built and trained so that
# Restart & Run All gives repeatable results (as far as the underlying
# operations are deterministic).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training hyperparameters, passed positionally to train_classifier further down.
# Meanings are inferred from the names; see train_classifier for exact semantics.
LR = 0.005        # learning rate
BATCH_SIZE = 64   # mini-batch size
N_EPOCHS = 30     # number of training epochs

In [6]:
# Annotation matrix adata.varm['I'] as a float32 tensor for the classifier.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor
# constructor, whose result dtype is implicit (the global default tensor type).
# np.asarray is a no-op for an ndarray and makes array-like inputs explicit.
pathways_mask = torch.tensor(np.asarray(adata.varm['I']), dtype=torch.float32)

In [7]:
# Number of distinct values in adata.obs['cell_type'].
n_labels = np.unique(adata.obs['cell_type']).size

# Classifier built from the pathway mask and the number of labels.
classifier = EnrichClassifier(pathways_mask, n_labels)

In [None]:
# Train the classifier on adata with the hyperparameters defined above.
# 'cell_type' is presumably the adata.obs column holding the labels
# (the same column is used to count the labels) - confirm in train_classifier.
train_classifier(
    adata,
    'cell_type',
    classifier,
    LR,
    BATCH_SIZE,
    N_EPOCHS,
)

In [None]:
# Relevance (enrichment score) of each pathway for each class.
# The result is a tensor of shape (number of classes, number of pathways).
relevance = classifier.get_relevance()
relevance