In [1]:
import os
import sys
sys.path.append('/lmh_data/work/SEE/train/')

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from torch.autograd import Variable
from captum.attr import IntegratedGradients

from dataset import Dataset
from net import define_network

In [2]:
# Seed every RNG the notebook touches from one constant so runs are reproducible.
SEED = 123
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(SEED)

In [3]:
def evaluate(eval_file, model_file, gene_name, cell_type='OPC'):
    """Sum Integrated-Gradients attributions per input feature over cells of one type.

    Each cell whose record has ``'cell_type' == cell_type`` is run through captum's
    ``IntegratedGradients`` once per model output index. The attributions for all
    output indices are added together, flattened, and accumulated across cells.

    Args:
        eval_file: path to a pickled ``.npy`` array of per-cell record dicts (each
            with a ``'cell_type'`` key); also passed to ``Dataset``.
        model_file: path to the ``state_dict`` of the ``DataParallel``-wrapped model.
        gene_name: gene name forwarded to ``Dataset``.
        cell_type: cell type to analyse. Defaults to ``'OPC'``, the value that was
            previously hard-coded.

    Returns:
        dict mapping cell-type name to ``{flat feature index: summed attribution}``.
        The keys ``'Astro'``, ``'Endo'``, ``'OPC'`` and ``'ODC'`` are always present
        (empty for types that were not analysed), as is ``cell_type`` itself.
    """
    # NOTE(review): evaluation builds the dataset with is_train=True -- confirm intended.
    eval_set = Dataset(eval_file, gene_name, is_train=True)
    data_loader = data.DataLoader(eval_set, batch_size=1, shuffle=False)

    input_size, output_size = tuple(eval_set[0][0].shape), eval_set[0][1].shape[0]
    patch_size = tuple(int(i / 8) for i in input_size)
    model = torch.nn.DataParallel(define_network(input_size, patch_size, output_size))
    # Load onto the CPU first so a checkpoint saved on a different GPU still loads.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.cuda()
    model.eval()
    ig = IntegratedGradients(model)

    # Raw records are read only for their cell type. Pairing them with batches by
    # position assumes Dataset keeps file order (shuffle=False) -- TODO confirm.
    # allow_pickle=True can execute pickled code: only use with trusted files.
    records = np.load(eval_file, allow_pickle=True)
    important_genes = dict(Astro=dict(), Endo=dict(), OPC=dict(), ODC=dict())
    cell_scores = important_genes.setdefault(cell_type, {})

    totals = None  # float64 running sum of attributions over matching cells
    for index, batch in enumerate(data_loader):
        if records[index]['cell_type'] != cell_type:
            continue
        inputs = batch[0].cuda().unsqueeze(1)
        # The target tensor is only needed for the number of output indices.
        n_outputs = batch[1].shape[-1]
        cell_attr = torch.zeros(inputs.numel(), device=inputs.device)
        for i in range(n_outputs):
            cell_attr += ig.attribute(inputs, target=(0, i))[0, 0].reshape(-1)
        cell_attr = cell_attr.detach().cpu().numpy().astype(np.float64)
        totals = cell_attr if totals is None else totals + cell_attr

    if totals is not None:
        cell_scores.update(enumerate(totals.tolist()))
    return important_genes

# Attribute the MBP model (epoch-9 checkpoint) back to its input features.
important_genes = evaluate(
    eval_file='/lmh_data/data/sclab/sclab/promoter/eval_dataset.npy',
    model_file='/lmh_data/data/sclab/sclab/promoter/tmp/MBP/model_epoch_9.pth',
    gene_name='MBP',
)
important_genes

{'Astro': {},
 'Endo': {},
 'OPC': {0: 0.0,
  1: 0.0,
  2: 0.0,
  3: 0.0,
  4: 0.0,
  5: 0.0,
  6: 0.0,
  7: 0.0,
  8: 0.0,
  9: 0.0,
  10: 0.0,
  11: 0.0,
  12: 0.0,
  13: 0.0,
  14: 0.0,
  15: 0.0,
  16: 0.0,
  17: 0.0,
  18: 0.0,
  19: -0.021295784506946802,
  20: 0.0,
  21: 0.0,
  22: -0.022388174198567867,
  23: 0.0,
  24: -0.04634299874305725,
  25: 0.0,
  26: 0.0,
  27: 0.0,
  28: 0.0,
  29: 0.0,
  30: 0.0,
  31: 0.0,
  32: 0.0,
  33: -0.0040367040783166885,
  34: 0.008119464386254549,
  35: 0.003597741131670773,
  36: 0.0,
  37: -0.0062095290049910545,
  38: -0.014482170343399048,
  39: 0.05799134401604533,
  40: 0.0,
  41: 0.012753108283504844,
  42: 0.0,
  43: 0.002443105448037386,
  44: 0.0,
  45: -0.02409890666604042,
  46: 0.01259523571934551,
  47: -0.03255566814914346,
  48: 0.0,
  49: 0.0,
  50: 0.005683649098500609,
  51: 0.0,
  52: 0.0,
  53: 0.017080282792448997,
  54: -0.007903208956122398,
  55: -0.013592157047241926,
  56: 0.0,
  57: -0.010724680498242378,
  58: 0

In [4]:
def _read_scrna_head(npy_path):
    """Return the ``'scRNA_head'`` entry of the first record in a pickled ``.npy`` file."""
    # allow_pickle=True can execute pickled code: only use with trusted files.
    first_record = np.load(npy_path, allow_pickle=True)[0]
    return first_record['scRNA_head']


scRNA_head = _read_scrna_head('/lmh_data/data/sclab/sclab/promoter/eval_dataset.npy')

def find_special_genes(important_genes, cell_type):
    """Rank one cell type's feature attributions from highest to lowest.

    Args:
        important_genes: mapping of cell type -> {feature index: attribution}.
        cell_type: key of ``important_genes`` to rank.

    Returns:
        DataFrame indexed by feature index (index name ``'index'``) with a single
        ``'value'`` column, sorted descending. Ties keep the mapping's iteration
        order, so the many exact-zero scores come out in a deterministic order.
    """
    scores = important_genes[cell_type]
    genes = pd.DataFrame(
        {'value': list(scores.values())},
        index=pd.Index(list(scores.keys()), name='index'),
    )
    # A single stable sort; the default quicksort would scramble tied values.
    return genes.sort_values(by='value', ascending=False, kind='mergesort')

# Rank OPC features by summed attribution, attach gene names, and save the table.
genes = find_special_genes(important_genes, 'OPC')
# Positional lookup: feature index i -> scRNA_head[i].
genes = genes.assign(name=scRNA_head[list(genes.index)])
# to_csv does not create missing directories; make a fresh checkout runnable.
os.makedirs('tmp', exist_ok=True)
genes.to_csv('tmp/related_genes_MBP(OPC)_promoter.csv')
genes

Unnamed: 0_level_0,value,name
index,Unnamed: 1_level_1,Unnamed: 2_level_1
10961,35.591845,LHFPL3
18428,29.121159,PCDH9
11485,27.778500,CSMD1
15794,21.233642,MALAT1
3087,20.164496,NRXN1
...,...,...
12099,-8.346283,MMP16
26688,-9.974301,DSCAM
23043,-10.472910,CA10
3295,-12.507955,LRRTM4
