In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn

sys.path.append("/binf-isilon/renniegrp/vpx267/ucph_thesis/")
from model import ConfigurableModel
from wrapper import utils
from wrapper.motifs_importance import forward_to_RELU_1, activation_pfm

In [2]:
# Absolute root of the project on the cluster file system. The analysis cell
# below reads trained model checkpoints and per-fold test FASTA files from
# paths under `{PROJECT_DIR}/data/`.
PROJECT_DIR = "/binf-isilon/renniegrp/vpx267/ucph_thesis"

In [None]:
# Extract, for each of the five cross-validation folds, the sequences returned
# by `activation_pfm` for the first-layer activations and store them in one
# .npz archive per fold (one entry per filter, keyed 'filter_1', 'filter_2', ...).

# Hyperparameters of the "fixed tune" run. Only the architecture entries are
# consumed in this cell; 'lr' and 'weight_decay' are not used here.
fixed_tune_config = {
    'lr': 0.001,
    'weight_decay': 0.1,
    'cnn_first_filter': 24,
    'cnn_first_kernel_size': 7,
    'cnn_length': 3,
    'cnn_filter': 32,
    'cnn_kernel_size': 7,
    'bilstm_layer': 3,
    'bilstm_hidden_size': 256,
    'fc_size': 256,
}

# NOTE(review): 'cnn_length' is present in the config but is not forwarded to
# ConfigurableModel — confirm the constructor's default matches the checkpoints.
model = ConfigurableModel(
    input_channel=4,  # presumably one channel per one-hot encoded nucleotide — TODO confirm
    cnn_first_filter=fixed_tune_config["cnn_first_filter"],
    cnn_first_kernel_size=fixed_tune_config["cnn_first_kernel_size"],
    cnn_other_filter=fixed_tune_config["cnn_filter"],
    cnn_other_kernel_size=fixed_tune_config["cnn_kernel_size"],
    bilstm_layer=fixed_tune_config["bilstm_layer"],
    bilstm_hidden_size=fixed_tune_config["bilstm_hidden_size"],
    fc_size=fixed_tune_config["fc_size"],
    output_size=2,
)
# Inference mode is a property of the module, not of the weights, so it only
# needs to be set once; load_state_dict() below does not reset it.
model.eval()

# Create the (relative) output directory up front so a missing folder fails
# fast instead of after the first fold's forward pass.
os.makedirs("motifs_importance", exist_ok=True)

for fold in [1, 2, 3, 4, 5]:
    # SECURITY: torch.load on a .pkl file unpickles its content, which can
    # execute arbitrary code. Only load checkpoints from a trusted location.
    model_weight = torch.load(
        f"{PROJECT_DIR}/data/outputs/models/trained_model_{fold}th_fold_dual_outputs_m6A_info-no_promoter-False_fixed_tune.pkl",
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(model_weight)

    seq_fasta_test_path = f"{PROJECT_DIR}/data/dual_outputs/motif_fasta_test_SPLIT_{fold}.fasta"
    seq_fasta_one_hot = utils.create_seq_tensor(seq_fasta_test_path)

    # No gradients are needed for activation extraction.
    with torch.no_grad():
        first_layer_result = forward_to_RELU_1(model, seq_fasta_one_hot)
    print("extract kernel_size")
    # Not consumed below (the window comes from the config); kept because later
    # cells may read it — NOTE(review): confirm it equals 'cnn_first_kernel_size'.
    kernel_size = model.CNN[0].kernel_size[0]

    # Axes 1 and 2 are swapped on both tensors before the call; presumably
    # activation_pfm expects (batch, length, channels) — TODO confirm.
    pfm, seqs_text, seqs_one_hot = activation_pfm(
        first_layer_result.transpose(1, 2).detach().cpu(),
        seq_fasta_one_hot.transpose(1, 2).detach().cpu(),
        window=fixed_tune_config["cnn_first_kernel_size"],
    )

    seq_dict = {f'filter_{i+1}': record for i, record in enumerate(seqs_text)}
    # `allow_pickle=True` is deliberately NOT passed: on NumPy releases where
    # savez has no such parameter, the keyword is stored as an extra array named
    # 'allow_pickle' in the archive. Where the option exists, True is the default.
    np.savez(f"motifs_importance/first_layer_fold_{fold}_seqs_fixed_tune.npz", **seq_dict)