In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import utility_loader
from custom_bpnet import customBPNet
from custom_bpnet import CountWrapper, ControlWrapper, ProfileWrapper
from plotnine import *
import plotnine
import tqdm
from tangermeme.utils import random_one_hot
from tangermeme.io import read_meme
from tangermeme.marginalize import marginalize
from tangermeme.utils import pwm_consensus
from tangermeme.deep_lift_shap import deep_lift_shap
from tangermeme.predict import predict
import modiscolite

# Set outdir
outdir = '/data/mariani/specificity_bpnet/output/models_hyperparams/calibrated_model/figs/marginalization'

# Marginalization of known motifs

In [7]:
# # Load model
# dict_data = {
# 	"n_outputs":22,
# 	"n_control_tracks":2,
# 	"trimming":(2114-1000)//2, # difference between the input length and the output length, actually half of that.
# 	"alpha": 0.5, #0.5
# 	"n_filters": 64,
# 	"n_layers":8,
# 	"name": "calibrated_model"
# }

# # Loaded model
# loaded_model = customBPNet(**dict_data).cuda()
# loaded_model.load_state_dict(torch.load('/data/mariani/specificity_bpnet/output/models_hyperparams/calibrated_model/calibrated_model.torch', weights_only = True))

# #bpnet_calibrated = CountWrapper(ControlWrapper(loaded_model))
# bpnet_calibrated =  ProfileWrapper(ControlWrapper(loaded_model))

# Load model
dict_data = {
	"n_outputs":2,
	"n_control_tracks":2,
	"trimming":(6114-5000)//2, # difference between the input length and the output length, actually half of that.
	"alpha": 300, #0.5
	"n_filters": 128,
	"n_layers":13,
	"name": "calibrated_wider_model"
}

# Loaded model
loaded_model = customBPNet(**dict_data).cuda()
loaded_model = torch.load('/data/mariani/specificity_bpnet/single_models/PCGF1_model.torch', weights_only = False)

bpnet_calibrated = CountWrapper(ControlWrapper(loaded_model), 0)
#bpnet_calibrated =  ProfileWrapper(ControlWrapper(loaded_model))

In [20]:
# Let's start by generating a set of background sequences with probabilities of chr1. 
# Later on sustitute this approach with:
#	- Dinucleotide shuffle RING1B regions
#	- Active promoters??

X = random_one_hot((200, 4, 6114), probs=np.array([[0.2910, 0.2085, 0.2087, 0.2918]])).float()

# read meme file, check some random motifs and other kwown
motifs = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')
# motifs['MA0059.1 MAX::MYC'] = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')['MA0059.1 MAX::MYC']
# motifs['MA0471.2 E2F6'] = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')['MA0471.2 E2F6']
# motifs['MA0093.3 USF1'] = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')['MA0093.3 USF1']
# motifs['MA1122.1 TFDP1'] = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')['MA1122.1 TFDP1']
# motifs['MA0741.1 KLF16'] = read_meme('/data/mariani/specificity_bpnet/bpnet_data/meme_db/motif_databases/JASPAR/JASPAR2022_CORE_vertebrates_non-redundant_v2.meme')['MA0741.1 KLF16']

In [21]:
# motifs_name, delt = [], []
# for name, pwm in motifs.items():
#     consensus = pwm_consensus(pwm).unsqueeze(0)
#     y_before, y_after = marginalize(bpnet_calibrated, X, consensus)
#     delta = (y_after - y_before).mean().item()
#     if delta >= 0.009:
#         motifs_name.append(name)
#         delt.append(delta)

# def_motifs_delta = pd.DataFrame({
# 	'motif': motifs_name,
# 	'delta': delt
# })

motifs_name, delt = [], []
for name, pwm in motifs.items():
    consensus = pwm_consensus(pwm).unsqueeze(0)
    y_before, y_after = marginalize(bpnet_calibrated, X, consensus)
    #print(y_before.shape, y_after.shape)
    delta = (y_after - y_before).mean().item()
    motifs_name.append(name)
    delt.append(delta)

def_motifs_delta = pd.DataFrame({
	'motif': motifs_name,
	'delta': delt
})


Unnamed: 0,motif,delta


In [1]:
def_motifs_delta[def_motifs_delta.delta >= 0.01]

NameError: name 'def_motifs_delta' is not defined

In [19]:
for name, pwm in motifs.items():
	consensus = pwm_consensus(pwm).unsqueeze(0)

	n = 10 
	s, e = 6114 // 2 - n , 6114 // 2 + n
	y_before, y_after = marginalize(bpnet_calibrated, X, consensus, func=deep_lift_shap)
	y_before = y_before[:, :, s:e] * X[:, :, s:e]
	y_after = y_after[:, :, s:e] * X[:, :, s:e]

	delta = (y_after - y_before).mean().item()
	print(name, delta)


MA0004.1 Arnt 9.94539641396841e-06
MA0006.1 Ahr::Arnt 1.08449703475344e-05
MA0019.1 Ddit3::Cebpa -5.345597401174018e-06
MA0029.1 Mecom -3.229926005587913e-05
MA0059.1 MAX::MYC 2.37389357948814e-07
MA0471.2 E2F6 3.935094355256297e-05
MA0093.3 USF1 -2.0988405594835058e-05
MA1122.1 TFDP1 2.9910863304394297e-05
MA0741.1 KLF16 3.787975219893269e-05


1057