In [None]:
# DeepSHAP importance-score files for the model trained WITHOUT priors.
# NOTE(review): unlike the other listings, this glob is not restricted to
# *_shap_scores.h5, so it shows every file with this prefix.
!ls /users/amtseng/att_priors/results/shap_scores/profile/BPNet/BPNet_r20_e18_task*

# DeepSHAP importance-score files for the model trained WITH priors.
!ls /users/amtseng/att_priors/results/shap_scores/profile/BPNet/BPNet_prior_r25_e17_task*_shap_scores.h5

# TF-MoDISco result files for the model trained WITHOUT priors.
!ls /users/amtseng/att_priors/results/tfmodisco/profile/BPNet/BPNet_r20_e18_task*_tfm.h5

# TF-MoDISco result files for the model trained WITH priors.
!ls /users/amtseng/att_priors/results/tfmodisco/profile/BPNet/BPNet_prior_r25_e17_task*_tfm.h5


In [None]:
import h5py


def _open_readonly(*path_parts):
    """Open an HDF5 file read-only; the path is the plain concatenation of `path_parts`."""
    return h5py.File("".join(path_parts), "r")


# Task-0 importance scores for each model (read later as one_hot_seqs / hyp_scores / coords_*).
noprior_impscores = _open_readonly(
    "/users/amtseng/att_priors/results/shap_scores/",
    "profile/BPNet/BPNet_r20_e18_task0_shap_scores.h5")
withprior_impscores = _open_readonly(
    "/users/amtseng/att_priors/results/shap_scores/",
    "profile/BPNet/BPNet_prior_r25_e17_task0_shap_scores.h5")

# Saved TF-MoDISco results computed from the scores above.
noprior_modisco_h5 = _open_readonly(
    "/users/amtseng/att_priors/results/",
    "tfmodisco/profile/BPNet/BPNet_r20_e18_task0_tfm.h5")
withprior_modisco_h5 = _open_readonly(
    "/users/amtseng/att_priors/results/tfmodisco/",
    "profile/BPNet/BPNet_prior_r25_e17_task0_tfm.h5")

In [None]:
import numpy as np
import modisco
from modisco.tfmodisco_workflow import workflow

onehot = np.array(noprior_impscores["one_hot_seqs"])
# Both models' hypothetical scores are combined with this single one-hot array,
# so the with-prior score file must cover the same sequences in the same order.
if "one_hot_seqs" in withprior_impscores:
    assert np.array_equal(onehot, np.array(withprior_impscores["one_hot_seqs"])), \
        "no-prior and with-prior score files were computed on different sequences"
noprior_hypscores = np.array(noprior_impscores["hyp_scores"])
withprior_hypscores = np.array(withprior_impscores["hyp_scores"])
# Actual contributions: hypothetical contributions restricted to the observed bases.
noprior_contribscores = noprior_hypscores*onehot
withprior_contribscores = withprior_hypscores*onehot


def load_task0_modisco_results(contribscores, hypscores, modisco_h5, one_hot):
    """
    Builds the single-task ("task0") TF-MoDISco track set from the given
    contribution / hypothetical-contribution scores and one-hot sequences, then
    reloads the saved TF-MoDISco results from `modisco_h5` against that track set.

    Returns:
        (track_set, tfmodisco_results)
    """
    track_set = workflow.prep_track_set(
                task_names=["task0"],
                contrib_scores={"task0": contribscores},
                hypothetical_contribs={"task0": hypscores},
                one_hot=one_hot)
    tfmodisco_results = workflow.TfModiscoResults.from_hdf5(
        modisco_h5, track_set=track_set)
    return track_set, tfmodisco_results


noprior_track_set, noprior_tfmodisco_results = load_task0_modisco_results(
    noprior_contribscores, noprior_hypscores, noprior_modisco_h5, onehot)
withprior_track_set, withprior_tfmodisco_results = load_task0_modisco_results(
    withprior_contribscores, withprior_hypscores, withprior_modisco_h5, onehot)

In [None]:
def get_patterns(tfmodisco_results, metacluster_name):
    """Returns the seqlets-to-patterns `patterns` of one metacluster of a TF-MoDISco result."""
    submetacluster = tfmodisco_results.metacluster_idx_to_submetacluster_results[metacluster_name]
    return submetacluster.seqlets_to_patterns_result.patterns


# NOTE(review): the no-prior run is read from 'metacluster_1' but the with-prior
# run from 'metacluster_0'; presumably each key was chosen by inspecting its own
# run's metaclusters — confirm both are the intended metacluster.
noprior_patterns = get_patterns(noprior_tfmodisco_results, 'metacluster_1')
withprior_patterns = get_patterns(withprior_tfmodisco_results, 'metacluster_0')

In [None]:
from modisco.visualization import viz_sequence
print("No prior - patterns")
for idx, pattern in enumerate(noprior_patterns):
    # Pattern index and the number of seqlets assigned to it.
    print(idx, len(pattern.seqlets))
    # Forward-strand logos: contribution scores first, then the "sequence" track.
    for track_name in ("task0_contrib_scores", "sequence"):
        viz_sequence.plot_weights(pattern[track_name].fwd)

In [None]:
from modisco.visualization import viz_sequence
print("With prior - patterns")
for idx, pattern in enumerate(withprior_patterns):
    # Pattern index and the number of seqlets assigned to it.
    print(idx, len(pattern.seqlets))
    # Forward-strand logos: contribution scores first, then the "sequence" track.
    for track_name in ("task0_contrib_scores", "sequence"):
        viz_sequence.plot_weights(pattern[track_name].fwd)

In [None]:
# Look at the scores underlying the 'Nanog-alt' motif discovered by the with-priors model.
# NOTE(review): index 4 was presumably picked from the with-prior pattern plots;
# recheck it whenever TF-MoDISco is rerun, since pattern order can change.
nanog_alt_pattern_idx = 4
motif_to_study = withprior_patterns[nanog_alt_pattern_idx]
viz_sequence.plot_weights(motif_to_study["task0_contrib_scores"].fwd)

In [None]:
#get coordinates centered around this motif

input_length = 1346  # length of each scored region; used to recover each region's start
idx_within_motif_of_centerpos = 22 #where to center, within the modisco motif

# Materialise the coordinate datasets once instead of issuing one HDF5 read per lookup.
orig_coord_starts = np.array(noprior_impscores['coords_start'])
orig_coord_ends = np.array(noprior_impscores['coords_end'])
# Chromosome of each scored region. Hardcoding 'chr1' fetches the wrong sequence
# for regions on other chromosomes, so read the chromosomes when the file has
# them. Files without a 'coords_chrom' dataset fall back to the old chr1 assumption.
if 'coords_chrom' in noprior_impscores:
    orig_coord_chroms = [
        chrom.decode() if isinstance(chrom, bytes) else str(chrom)
        for chrom in np.array(noprior_impscores['coords_chrom'])]
else:
    orig_coord_chroms = None

genomic_coords = []
is_revcomp = []
debug_coords = []
noprior_aroundmotif_contribscores = []
withprior_aroundmotif_contribscores = []
flank_to_show = 50
num_seqlets_to_use = 10
scored_length = noprior_contribscores.shape[1]
for seqlet in motif_to_study.seqlets[:num_seqlets_to_use]:
    coor = seqlet.coor
    example_idx = coor.example_idx
    # Genomic start of the input_length window the scores were computed on.
    region_start = ((orig_coord_starts[example_idx] + orig_coord_ends[example_idx])//2
                    - input_length//2)
    # Offset of the chosen motif position within that window. Reverse-complement
    # seqlets are read from `end` backwards; `end - idx` (not `end - 1 - idx`)
    # makes the half-open window [center-flank, center+flank) the exact mirror
    # image of the forward-strand window.
    if coor.is_revcomp:
        within_region_center = coor.end - idx_within_motif_of_centerpos
    else:
        within_region_center = coor.start + idx_within_motif_of_centerpos
    # Keep only motifs far enough from both edges to show a full flank.
    if (within_region_center > flank_to_show
            and scored_length - within_region_center > flank_to_show):
        chrom = (orig_coord_chroms[example_idx]
                 if orig_coord_chroms is not None else 'chr1')
        genomic_motif_center = region_start + within_region_center
        genomic_coords.append((chrom, genomic_motif_center, genomic_motif_center+1))
        is_revcomp.append(coor.is_revcomp)
        window = slice(within_region_center - flank_to_show,
                       within_region_center + flank_to_show)
        noprior_aroundmotif_contribscores.append(
            noprior_contribscores[example_idx, window])
        withprior_aroundmotif_contribscores.append(
            withprior_contribscores[example_idx, window])


In [None]:
#get the Dataset Loader
import sys, os
# Make the repository's src/ directory importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
src_dir = os.path.abspath("../src/")
if src_dir not in sys.path:
    sys.path.append(src_dir)
import feature.util
import feature.make_profile_dataset

reference_fasta = "/users/amtseng/genomes/mm10.fasta"
profile_hdf5_path = "/users/amtseng/att_priors/data/processed/BPNet_ChIPseq/profile/labels/BPNet_profiles.h5"
profile_length = 1000
# Number of leading profile tracks treated as TF targets; the remaining tracks
# are used as controls. NOTE(review): presumably 3 TF tracks followed by their
# controls in BPNet_profiles.h5 — confirm against the labels file.
num_tf_tracks = 3
# Sequence fetcher for windows of input_length bp; presumably centered on each
# coordinate (the ISM plots rely on that) — TODO confirm in feature.util.
coords_to_seq = feature.util.CoordsToSeq(
    reference_fasta,
    center_size_to_use=input_length)
coords_to_vals = feature.make_profile_dataset.CoordsToVals(
                    profile_hdf5_path, profile_length)

seqs_onehot = coords_to_seq(genomic_coords)
# After swapping axes 1 and 2, axis 1 is split into TF tracks and control tracks.
profiles = np.swapaxes(coords_to_vals(genomic_coords),1,2)
tf_profile = profiles[:,:num_tf_tracks,:,:]
control_profile = profiles[:,num_tf_tracks:,:,:]


In [None]:
import model.util as model_util
import model.profile_models as profile_models
import torch

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def restore_model(model_class, load_path):
    """
    Rebuilds a saved model from the checkpoint at `load_path`.

    The checkpoint is loaded onto the CPU and must be a dict holding
    "model_creation_args" (keyword arguments for `model_class`) and
    "model_state" (a state dict). A fresh `model_class` instance is built from
    the creation arguments, the learned parameters are loaded into it, and the
    instance is returned.
    """
    checkpoint = torch.load(load_path, map_location=torch.device('cpu'))
    restored = model_class(**checkpoint["model_creation_args"])
    restored.load_state_dict(checkpoint["model_state"])
    return restored

def load_model(model_path):
    """
    Restores a ProfilePredictorWithSharedControls checkpoint from `model_path`,
    switches it to evaluation mode, and returns it moved onto `device`.
    """
    restored = restore_model(
        profile_models.ProfilePredictorWithSharedControls, model_path)
    restored.eval()
    return restored.to(device)

def get_model_preds(model, input_data_list):
    """
    Runs `model` on one batch and returns its first two outputs as numpy arrays.

    Arguments:
        model: the profile model; called as `model(seqs, controls)`.
        input_data_list: `[seqs_onehot, control_profile]`; both are converted to
            float tensors and placed with `model_util.place_tensor`.

    Returns:
        (outputs[0], outputs[1]) as numpy arrays — the profile predictions and
        counts, in the order callers unpack them.
    """
    seqs_onehot, control_profile = input_data_list
    seqs_tensor = model_util.place_tensor(torch.tensor(seqs_onehot)).float()
    control_tensor = model_util.place_tensor(torch.tensor(control_profile)).float()
    # Inference only: disable autograd so each batch does not build and hold a graph.
    with torch.no_grad():
        outputs = model(seqs_tensor, control_tensor)
    return (outputs[0].detach().cpu().numpy(), outputs[1].detach().cpu().numpy())

def get_batched_model_preds(model, seqs_onehot, control_profile, batch_size):
    """
    Runs `get_model_preds` over consecutive slices of `batch_size` examples.

    Prints a progress line before every 10th batch. Returns
    `(profiles, counts)`, each a numpy array whose first axis indexes examples.
    """
    all_profiles, all_counts = [], []
    for batch_num, start in enumerate(range(0, len(seqs_onehot), batch_size)):
        if batch_num % 10 == 0:
            print("Done", batch_num, "batches")
        stop = start + batch_size
        batch_profiles, batch_counts = get_model_preds(
            model=model,
            input_data_list=[seqs_onehot[start:stop], control_profile[start:stop]])
        all_profiles.extend(batch_profiles)
        all_counts.extend(batch_counts)
    return np.array(all_profiles), np.array(all_counts)
    
# Each score file's "model" entry carries, in its "model" attribute, the path of
# the checkpoint that load_model restores.
withprior_model_path = withprior_impscores['model'].attrs['model']
noprior_model_path = noprior_impscores['model'].attrs['model']
withprior_model = load_model(withprior_model_path)
noprior_model = load_model(noprior_model_path)

In [None]:
# Predictions of both models on the motif-centered inputs; each result is the
# (profiles, counts) tuple returned by get_batched_model_preds.
shared_pred_inputs = dict(seqs_onehot=seqs_onehot,
                          control_profile=control_profile,
                          batch_size=40)
withprior_profilepreds = get_batched_model_preds(model=withprior_model, **shared_pred_inputs)
noprior_profilepreds = get_batched_model_preds(model=noprior_model, **shared_pred_inputs)

In [None]:
#copying over some ism code
import scipy.special

def list_wrapper(func):
    """
    Decorator that normalizes the first argument of `func` to a list.

    A list is passed through unchanged, a tuple is converted to a list, and any
    other value `x` is passed as `[x]`. The value is forwarded as the keyword
    `input_data_list`, together with any other keyword arguments. The wrapped
    function's return value is returned unchanged.
    """
    import functools

    @functools.wraps(func)
    def wrapped_func(input_data_list, **kwargs):
        if isinstance(input_data_list, tuple):
            # A tuple is a multi-input bundle, not a single input.
            input_data_list = list(input_data_list)
        elif not isinstance(input_data_list, list):
            input_data_list = [input_data_list]
        return func(input_data_list=input_data_list, **kwargs)
    return wrapped_func

def empty_ism_buffer(results_arr,
                     input_data_onehot,
                     perturbed_inputs_preds,
                     perturbed_inputs_info):
    """
    Scatters one batch of ISM predictions into `results_arr`, in place.

    `perturbed_inputs_info[j]` describes `perturbed_inputs_preds[j]` as
    `(example_idx, tag)`:
      * tag == "original": the unmutated sequence. Its prediction is added at
        every position's original base, via `input_data_onehot[example_idx]`.
      * tag == (pos_idx, base_idx): a single substitution. Its prediction is
        written to `results_arr[example_idx, pos_idx, base_idx]`.
    """
    for info, pred in zip(perturbed_inputs_info, perturbed_inputs_preds):
        example_idx, tag = info[0], info[1]
        if tag != "original":
            pos_idx, base_idx = tag
            results_arr[example_idx, pos_idx, base_idx] = pred
            continue
        results_arr[example_idx] += pred * input_data_onehot[example_idx]

def make_ism_func(prediction_func,
                  flank_around_middle_to_perturb,
                  batch_size=200):
    """
    Builds an in-silico mutagenesis (ISM) scoring function.

    Arguments:
        prediction_func: called as `prediction_func([onehot_batch, control_batch])`;
            must return one scalar prediction per sequence in the batch.
        flank_around_middle_to_perturb: every position in
            `[mid - flank, mid + flank)`, with `mid = len(seq) // 2`, is mutated to
            each alternative base. The window is clipped to the sequence, so an
            oversized flank cannot wrap around through negative indices.
        batch_size: maximum number of sequences per `prediction_func` call.

    Returns:
        `ism_func(input_data_list, progress_update=10000, **kwargs)`, wrapped by
        `list_wrapper`. `input_data_list[0]` is the one-hot array and
        `input_data_list[-1]` supplies the per-example control inputs.
        It returns an array shaped like the one-hot input. At each original base,
        the value is that base's prediction minus the mean prediction over the
        four bases at that position; every other entry is zero. Positions
        outside the mutated window are not mutated, so their values are not
        ISM scores.
        Progress is printed every `progress_update` mutations (None or 0 disables it).
    """
    @list_wrapper
    def ism_func(input_data_list, progress_update=10000, **kwargs):
        # **kwargs is accepted for call compatibility and otherwise ignored.
        input_data_onehot = input_data_list[0]
        control_data = input_data_list[-1]
        results_arr = np.zeros_like(input_data_onehot).astype("float64")

        # Pending batch: sequences, their controls, and (example_idx, tag), where
        # tag is "original" or the (pos, base_idx) that was substituted.
        batch_seqs = []
        batch_controls = []
        batch_info = []

        def flush_batch():
            # Score the pending batch and scatter its predictions into results_arr.
            if not batch_info:
                return
            empty_ism_buffer(
                results_arr=results_arr,
                input_data_onehot=input_data_onehot,
                perturbed_inputs_preds=prediction_func(
                    [np.array(batch_seqs), np.array(batch_controls)]),
                perturbed_inputs_info=batch_info)
            batch_seqs.clear()
            batch_controls.clear()
            batch_info.clear()

        def enqueue(seq, control, info):
            # Flush as soon as the batch is full so it never exceeds batch_size.
            batch_seqs.append(seq)
            batch_controls.append(control)
            batch_info.append(info)
            if len(batch_info) >= batch_size:
                flush_batch()

        num_done = 0
        for i, onehot_seq in enumerate(input_data_onehot):
            assert len(onehot_seq.shape) == 2
            seq_len = len(onehot_seq)
            mid = seq_len // 2
            window = range(max(0, mid - flank_around_middle_to_perturb),
                           min(seq_len, mid + flank_around_middle_to_perturb))
            enqueue(onehot_seq, control_data[i], (i, "original"))
            for pos in window:
                for base_idx in range(4):
                    if onehot_seq[pos, base_idx] != 0:
                        continue  # already the original base
                    mutated = onehot_seq.copy()
                    mutated[pos, :] = 0
                    mutated[pos, base_idx] = 1
                    enqueue(mutated, control_data[i], (i, (pos, base_idx)))
                    num_done += 1
                    if progress_update and num_done % progress_update == 0:
                        print("Done", num_done)
        flush_batch()
        # Center the four predictions at each position, then keep only the original base.
        results_arr = results_arr - results_arr.mean(axis=-1, keepdims=True)
        return input_data_onehot*results_arr
    return ism_func

def get_prediction_func(model, task_idx):
    """
    Returns a scalar-output prediction function for ISM on one task.

    The returned `pred_func(x)` takes the first output of
    `get_model_preds(model, x)` (the profile predictions) for task `task_idx`.
    It softmaxes them along axis 1 (presumably the profile-position axis,
    giving one distribution per strand — TODO confirm). Per example, it returns
    the sum over axes 1 and 2 of the logits weighted by those probabilities.

    Raises:
        AssertionError: if the per-task output is not 3-D, or the softmax does
            not sum to 1 along axis 1.
    """
    def pred_func(x):
        logits = get_model_preds(model, x)[0][:, task_idx]
        softmax_out = scipy.special.softmax(logits, axis=1)
        assert softmax_out.ndim == 3, softmax_out.shape
        prob_sums = np.sum(softmax_out, axis=1)
        # Use a real assertion message; `print(...)` as the message evaluates to None.
        assert np.max(np.abs(prob_sums - 1)) < 1e-5, \
            "softmax does not sum to 1: {}".format(prob_sums)
        return np.sum(softmax_out*logits, axis=(1, 2))
    return pred_func
        

In [None]:
# One ISM scorer per model: mutate the central +/- flank_to_show bases of each
# input and score task 0.
ism_settings = dict(flank_around_middle_to_perturb=flank_to_show, batch_size=40)
noprior_ismfunc = make_ism_func(
    prediction_func=get_prediction_func(noprior_model, task_idx=0), **ism_settings)
withprior_ismfunc = make_ism_func(
    prediction_func=get_prediction_func(withprior_model, task_idx=0), **ism_settings)

ism_inputs = [seqs_onehot, control_profile]
noprior_ism = noprior_ismfunc(ism_inputs, progress_update=100)
withprior_ism = withprior_ismfunc(ism_inputs, progress_update=100)

In [None]:
# ISM inputs were fetched from 1-bp coordinates at each motif center, so the
# plotted ISM window is taken around the middle of each input (assumes the
# sequence fetcher centers its window on the coordinate).
center_offset = int(noprior_ism.shape[1]/2)
ism_window = slice(center_offset-flank_to_show, center_offset+flank_to_show)
for idx, (genomic_coord, is_rc) in enumerate(zip(genomic_coords, is_revcomp)):
    # Genomic span of the plotted window, then DeepSHAP vs. ISM for each model.
    print(genomic_coord[0],
          genomic_coord[1]-flank_to_show,
          genomic_coord[1]+flank_to_show,
          "revcomp:"+str(is_rc))
    panels = (
        ("No prior - DeepSHAP scores", noprior_aroundmotif_contribscores[idx]),
        ("No prior - ISM scores", noprior_ism[idx][ism_window]),
        ("With prior - DeepSHAP scores", withprior_aroundmotif_contribscores[idx]),
        ("With prior - ISM scores", withprior_ism[idx][ism_window]),
    )
    for panel_title, panel_scores in panels:
        print(panel_title)
        viz_sequence.plot_weights(panel_scores, subticks_frequency=20)