In [1]:
import os, random
from unicodedata import name
import pandas as pd
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from model_1 import apaDNNModel, apaDataset
import pickle
# from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *
import seqlogo

  import pandas.util.testing as tm


In [2]:
model = apaDNNModel(
    opt="Adam",
    loss="mse",
    lambda1=50,
    device='cuda',
    # Conv block 1 hparamaters
    conv1kc=128,
    conv1ks=12,
    conv1st=1,
    pool1ks=25,
    pool1st=25,
    cnvpdrop1=0.2,
    # multihead attention block
    Matt_heads=8,
    Matt_drop = 0.2,
    # FC block 1 (Matt output flattened)
    fc1_L1=0,  # 8192
    fc1_L2=8192,
    fc1_L3=4048,
    fc1_L4=1024,
    fc1_L5=512,
    fc1_L6=256,
    fc1_dp1=0.3,
    fc1_dp2=0.25,
    fc1_dp3=0.25,
    fc1_dp4=0.2,
    fc1_dp5=0.1,
    # FC block 2 (celltype profile + overall representation)
    fc2_L1=0,
    fc2_L2=128,
    fc2_L3=32,
    fc2_L4=16,
    fc2_L5=1,
    fc2_dp1=0.2,
    fc2_dp2=0.2,
    fc2_dp3=0,
    fc2_dp4=0,
    lr=2.5e-05,
    adam_weight_decay=0.06,
)
model.compile()
model.load_state_dict(
    torch.load(
        "/home/aiden/data/APA/input_data/model_results/all_cells_CNN_Matt_V2_4_L2_sALS_resNET.pt"
    ,map_location=torch.device('cuda'))
)
model.to("cuda")

apaDNNModel(
  (conv_block_1): ConvBlock(
    (op): Sequential(
      (0): Conv1d(4, 128, kernel_size=(12,), stride=(1,), padding=(6,))
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): MaxPool1d(kernel_size=25, stride=25, padding=0, dilation=1, ceil_mode=False)
      (4): Dropout(p=0.2, inplace=False)
    )
  )
  (Matt_1): MultiheadAttention(
    (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
  )
  (fc1): FC_block(
    (op): Sequential(
      (0): Linear(in_features=20480, out_features=8192, bias=True)
      (1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=8192, out_features=4048, bias=True)
      (5): BatchNorm1d(4048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.25, inplace=False)
      (8): Lin

In [3]:
data_root = "/home/aiden/data/APA/input_data/data_for_DL/V2/"
test_data = np.load(data_root + "sALS_test_labels.npy", allow_pickle=True)
test_seq = np.load(data_root + "sALS_test_seqs.npy", allow_pickle=True)

profiles = pd.read_csv(data_root + "celltype_profiles.tsv", index_col=0, sep="\t")
test_data_loader = DataLoader(
apaDataset(test_seq, test_data, profiles, device='cpu'),
batch_size=32,
drop_last=False,
)

In [4]:
# ok lets run these for all the data by iterating over the test data loader
# first lets to the filter importance part
# stack all the filter importance so we end up with a 3d array (datasize, 200, 128)
 
# create empty array
filter_importance = np.zeros((len(test_data_loader)*32, 128, 160))
position_weights = np.zeros((len(test_data_loader)*32, 160, 160))

for i, batch in enumerate(test_data_loader):
    seq_X, celltype_X, celltype_name, y = batch
    batch_size = seq_X.shape[0]
    # output MATT layer from model
    fmap = model.conv_block_1(seq_X)
    fmap = fmap.permute(2, 0, 1)
    wfmap, weights = model.Matt_1(fmap, fmap, fmap)
    wfmap = wfmap.permute(1, 2, 0) 
    wfmap = wfmap.cpu().detach().numpy()
    weights = weights.cpu().detach().numpy()
    filter_importance[i:i+batch_size] = wfmap
    position_weights[i:i+batch_size] = weights
print(filter_importance.shape)
print(position_weights.shape)    

(47424, 128, 160)
(47424, 160, 160)


In [6]:
filter_importance_t = filter_importance
filter_importance_t[filter_importance_t < 0] = 0
filter_importance_t = np.mean(filter_importance_t, axis=0)  # shape: (128, 200)
print(filter_importance_t.shape)
filter_importance_t_df = pd.DataFrame(filter_importance_t)
new_index_names = ['filter_{}'.format(i) for i in range(128)]
filter_importance_t_df.index = new_index_names
filter_importance_t_df

(128, 160)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,150,151,152,153,154,155,156,157,158,159
filter_0,0.000015,0.000011,0.000011,0.000011,0.000010,0.000011,0.000010,0.000012,0.000012,0.000014,...,0.000012,0.000011,0.000011,0.000014,0.000017,0.000013,0.000014,0.000017,0.000014,0.000014
filter_1,0.000384,0.000359,0.000362,0.000356,0.000359,0.000359,0.000356,0.000359,0.000362,0.000360,...,0.000357,0.000360,0.000359,0.000362,0.000363,0.000364,0.000362,0.000360,0.000359,0.000376
filter_2,0.000068,0.000054,0.000055,0.000055,0.000055,0.000057,0.000056,0.000058,0.000058,0.000054,...,0.000052,0.000053,0.000052,0.000053,0.000052,0.000054,0.000052,0.000054,0.000053,0.000061
filter_3,0.000403,0.000387,0.000386,0.000393,0.000393,0.000398,0.000396,0.000391,0.000395,0.000383,...,0.000379,0.000373,0.000382,0.000384,0.000377,0.000376,0.000376,0.000386,0.000381,0.000403
filter_4,0.000030,0.000025,0.000025,0.000028,0.000027,0.000027,0.000026,0.000025,0.000028,0.000025,...,0.000024,0.000026,0.000024,0.000026,0.000025,0.000025,0.000025,0.000026,0.000026,0.000029
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
filter_123,0.000148,0.000135,0.000133,0.000142,0.000141,0.000139,0.000145,0.000137,0.000143,0.000137,...,0.000142,0.000136,0.000135,0.000138,0.000138,0.000137,0.000132,0.000139,0.000142,0.000147
filter_124,0.000637,0.000599,0.000600,0.000601,0.000600,0.000601,0.000600,0.000602,0.000604,0.000603,...,0.000601,0.000605,0.000600,0.000604,0.000606,0.000607,0.000603,0.000604,0.000602,0.000627
filter_125,0.000050,0.000043,0.000045,0.000047,0.000045,0.000043,0.000044,0.000045,0.000045,0.000049,...,0.000047,0.000047,0.000041,0.000042,0.000048,0.000046,0.000044,0.000047,0.000047,0.000045
filter_126,0.000004,0.000002,0.000003,0.000005,0.000004,0.000003,0.000003,0.000003,0.000003,0.000001,...,0.000003,0.000003,0.000003,0.000002,0.000005,0.000002,0.000004,0.000004,0.000003,0.000003


In [7]:
filter_importance_t_df.to_csv('sALS_filter_importance_128_by_160.csv')