In [1]:
import sys
sys.path.append('/lmh_data/work/sclab/train/')

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from torch.autograd import Variable
from captum.attr import IntegratedGradients

from dataset import Dataset
from net import define_network

In [2]:
# Seed both random number generators with the same value so that a fresh
# "Restart & Run All" starts from an identical RNG state every time.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
def evaluate(eval_file, model_file, gene_name, sample, cell_type='ASC', max_cells=50):
    """Sum Integrated Gradients attributions over the cells of one sample.

    The evaluation set is walked in file order (batch size 1, no shuffling).
    Every record whose ``'cell_type'`` equals ``cell_type`` and whose
    ``'sample'`` equals ``sample`` is attributed with Captum's
    ``IntegratedGradients``: the attributions for each output position of the
    model are added together, flattened, and accumulated per flattened input
    position across cells.

    Args:
        eval_file: path of the ``.npy`` evaluation file. It is read twice:
            once through ``Dataset`` and once directly with ``np.load`` to
            obtain the per-record ``'cell_type'`` / ``'sample'`` metadata.
            NOTE(review): this assumes ``Dataset`` yields items in the same
            order as the raw records -- confirm against ``dataset.py``.
        model_file: path of the checkpoint passed to ``torch.load``.
        gene_name: forwarded unchanged to ``Dataset``.
        sample: value the record's ``'sample'`` field must equal.
        cell_type: value the record's ``'cell_type'`` field must equal.
            Defaults to ``'ASC'``, the value that used to be hard-coded.
        max_cells: stop after this many matching cells have been attributed.
            Defaults to 50, the previous hard-coded limit. ``None`` disables
            the limit; otherwise the value is expected to be >= 1.

    Returns:
        dict mapping the cell-type labels ``'ASC'``, ``'Endo'``, ``'OPC'`` and
        ``'ODC'`` (plus ``cell_type`` if it is not one of those) to a
        ``{flattened input position: summed attribution}`` dict. Only the
        entry for ``cell_type`` is ever populated.
    """
    eval_set = Dataset(eval_file, gene_name, is_train=False)
    data_loader = data.DataLoader(eval_set, batch_size=1, shuffle=False)

    # 171 is the output size the network is built with; presumably it has to
    # match the checkpoint being loaded -- TODO confirm before changing.
    input_size, output_size = tuple(eval_set[0][0].shape), 171
    patch_size = tuple([int(i / 8) for i in input_size])
    model = torch.nn.DataParallel(define_network(input_size, patch_size, output_size))
    model.load_state_dict(torch.load(model_file))
    model.cuda()
    model.eval()
    ig = IntegratedGradients(model)

    # SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects.
    # Only point this at evaluation files from a trusted source.
    records = np.load(eval_file, allow_pickle=True)
    important_genes = dict(ASC=dict(), Endo=dict(), OPC=dict(), ODC=dict())
    gene_scores = important_genes.setdefault(cell_type, dict())

    used_cells = 0
    for index, batch in enumerate(data_loader):
        record = records[index]
        if record['cell_type'] != cell_type or record['sample'] != sample:
            continue

        # Add a dimension at position 1 before feeding the model.
        cell_input = batch[0].cuda().unsqueeze(1)

        # Only the size of the last output dimension is needed here, so skip
        # building an autograd graph for this forward pass.
        with torch.no_grad():
            n_outputs = model(cell_input).shape[-1]

        # One attribution pass per output position; target=(0, i) selects
        # element [0, i] of the non-batch output dimensions.
        attributions = torch.zeros(cell_input.numel(), device=cell_input.device)
        for out_idx in range(n_outputs):
            attributions += ig.attribute(cell_input, target=(0, out_idx))[0, 0].reshape(-1)
        attributions = attributions.cpu().detach().numpy()

        # Accumulate across cells, keyed by flattened input position.
        for position, score in enumerate(attributions):
            gene_scores[position] = gene_scores.get(position, 0) + score

        used_cells += 1
        if max_cells is not None and used_cells >= max_cells:
            break

    return important_genes

# Both runs use the same evaluation file, checkpoint and target gene; only the
# sample label differs between them.
_eval_args = (
    '/lmh_data/data/sclab/sclab/AD/eval_dataset.npy',
    '/lmh_data/data/sclab/sclab/AD/tmp/SLC1A2/model_epoch_8.pth',
    'SLC1A2',
)
Control_important_genes = evaluate(*_eval_args, 'Control')
AD_important_genes = evaluate(*_eval_args, 'AD')

In [4]:
def _load_scrna_head(dataset_path):
    """Return the 'scRNA_head' entry stored on the first record of the file."""
    # NOTE: allow_pickle=True unpickles arbitrary objects; trusted files only.
    records = np.load(dataset_path, allow_pickle=True)
    return records[0]['scRNA_head']


# Indexed later by flattened input position to look up a name per row.
scRNA_head = _load_scrna_head('/lmh_data/data/sclab/sclab/AD/eval_dataset.npy')

def find_special_genes(important_genes, cell_type):
    """Rank the accumulated scores of one cell type, highest first.

    Args:
        important_genes: mapping of cell-type label to an ``{index: score}``
            dict.
        cell_type: key of ``important_genes`` whose scores are ranked.

    Returns:
        pandas.DataFrame indexed by ``index`` with a single ``value`` column,
        sorted by ``value`` in descending order. Rows with equal values keep
        the insertion order of the input dict.

    Raises:
        KeyError: if ``cell_type`` is not a key of ``important_genes``.
    """
    scores = important_genes[cell_type]
    genes = pd.DataFrame(
        list(scores.items()), columns=['index', 'value']
    ).set_index('index')
    # A single stable sort replaces the earlier sorted() + sort_values() pair:
    # half the sorting work, and ties no longer depend on the unstable default
    # quicksort.
    return genes.sort_values(by=['value'], ascending=False, kind='mergesort')

In [5]:
# Rank the Control attributions for the 'ASC' cell type, attach the name found
# in scRNA_head for each row index, save the table and display it.
genes = find_special_genes(Control_important_genes, 'ASC').assign(
    name=lambda ranked: scRNA_head[list(ranked.index)]
)
genes.to_csv('tmp/related_genes_SLC1A2_Control.csv')
genes

Unnamed: 0_level_0,value,name
index,Unnamed: 1_level_1,Unnamed: 2_level_1
5785,18.330598,LSAMP
12806,4.806235,DST
3367,4.508641,PITPNC1
5361,4.247797,DTNA
7203,4.190456,RYR3
...,...,...
16093,-5.219756,PCDH9
4863,-7.401366,CTNND2
15720,-8.948564,CTNNA2
1191,-9.681167,GPC5


In [6]:
# Rank the AD attributions for the 'ASC' cell type, attach the name found in
# scRNA_head for each row index, save the table and display it.
genes = find_special_genes(AD_important_genes, 'ASC').assign(
    name=lambda ranked: scRNA_head[list(ranked.index)]
)
genes.to_csv('tmp/related_genes_SLC1A2_AD.csv')
genes

Unnamed: 0_level_0,value,name
index,Unnamed: 1_level_1,Unnamed: 2_level_1
5785,48.842674,LSAMP
8039,17.876955,LRP1B
12806,12.373520,DST
5361,11.202523,DTNA
3367,10.379817,PITPNC1
...,...,...
4163,-12.420335,RORA
1191,-21.150034,GPC5
15720,-21.627930,CTNNA2
13605,-32.445415,MALAT1
