In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(os.path.abspath("../src/"))
import extract.data_loading as data_loading
import extract.compute_predictions as compute_predictions
import extract.compute_shap as compute_shap
import extract.compute_ism as compute_ism
import model.util as model_util
import model.profile_models as profile_models
import model.binary_models as binary_models
import plot.viz_sequence as viz_sequence
import feature.util as feature_util
import pyBigWig
import torch
import numpy as np
import scipy.stats
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import json
import tqdm
tqdm.tqdm_notebook()  # It is necessary to call this before the tqdm.notebook submodule is available

In [None]:
# Plotting defaults
# Register every font file found in the local fonts directory (presumably the
# source of the "Roboto" family requested below). `FontManager.addfont` is the
# supported API since matplotlib 3.2, where `createFontList` was deprecated (it
# was later removed), so only fall back to `createFontList` on older versions.
font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for font_path in font_paths:
        font_manager.fontManager.addfont(font_path)
else:
    font_manager.fontManager.ttflist.extend(font_manager.createFontList(font_paths))
plot_params = {
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "figure.titlesize": 22,
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)

### Define paths for the model and data of interest

In [None]:
model_type = "binary"  # "profile" selects the profile variants below; any other value selects binary ones

In [None]:
# Shared paths/constants (independent of the condition configured below)
chrom_set = ["chr1"]
chrom_sizes = "/users/amtseng/genomes/hg38.canon.chrom.sizes"
raw_data_base_path = "/users/amtseng/att_priors/data/raw/"
proc_data_base_path = "/users/amtseng/att_priors/data/processed/"
model_base_path = "/users/amtseng/att_priors/models/trained_models/%s/" % model_type
tfm_results_path = "/users/amtseng/att_priors/results/tfmodisco/%s/" % model_type

# Profile models use a longer input window than binary models
if model_type == "profile":
    input_length = 1346
else:
    input_length = 1000
profile_length = 1000

In [None]:
# SPI1
condition_name = "SPI1-1task"
files_spec_path = os.path.join(proc_data_base_path, "ENCODE_TFChIP/%s/config/SPI1-1task/SPI1-1task_training_paths.json" % model_type)
num_tasks = 1
num_strands = 2
controls = "shared"
model_class = (
    profile_models.ProfilePredictorWithSharedControls
    if model_type == "profile"
    else binary_models.BinaryPredictor
)
task_index = None
motif_path = "/users/amtseng/att_priors/results/SPI1_motifs/homer_motif1_trimmed.motif"

gc_probs = [0.50, 0.51, 0.52, 0.53, 0.54]  #, 0.55, 0.60]

# Checkpoint used for each G/C bias; entry k of each list belongs to gc_probs[k].
# The numbered subdirectory and the epoch differ from bias to bias.
noprior_model_paths = [
    os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/2/model_ckpt_epoch_3.pt" % gc_probs[0]),
    os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/3/model_ckpt_epoch_4.pt" % gc_probs[1]),
    os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/3/model_ckpt_epoch_5.pt" % gc_probs[2]),
    os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/2/model_ckpt_epoch_5.pt" % gc_probs[3]),
    os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/3/model_ckpt_epoch_3.pt" % gc_probs[4]),
    # os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/2/model_ckpt_epoch_1.pt" % gc_probs[5]),
    # os.path.join(model_base_path, "SPI1-1task_simgc%0.2f/2/model_ckpt_epoch_1.pt" % gc_probs[6]),
]
prior_model_paths = [
    os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/2/model_ckpt_epoch_5.pt" % gc_probs[0]),
    os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/3/model_ckpt_epoch_4.pt" % gc_probs[1]),
    os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/3/model_ckpt_epoch_5.pt" % gc_probs[2]),
    os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/1/model_ckpt_epoch_4.pt" % gc_probs[3]),
    os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/2/model_ckpt_epoch_5.pt" % gc_probs[4]),
    # os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/2/model_ckpt_epoch_1.pt" % gc_probs[5]),
    # os.path.join(model_base_path, "SPI1-1task_prior_simgc%0.2f/2/model_ckpt_epoch_2.pt" % gc_probs[6]),
]

In [None]:
# Keep autograd on globally (presumably required by the SHAP explainer used later)
torch.set_grad_enabled(True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def restore_model(model_path):
    """
    Loads the checkpoint at `model_path` as a `model_class` instance, puts it in
    eval mode, and returns it moved onto `device`.
    """
    restored = model_util.restore_model(model_class, model_path)
    restored.eval()
    return restored.to(device)

In [None]:
# Restore one no-prior model per G/C bias, in the same order as gc_probs
noprior_models = list(map(restore_model, noprior_model_paths))

In [None]:
# Restore one prior-trained model per G/C bias, in the same order as gc_probs
prior_models = list(map(restore_model, prior_model_paths))

### Data preparation
Create a simulated-sequence generator for each G/C bias in `gc_probs`, along with the matching background nucleotide frequencies and motif PWMs.

In [None]:
np.random.seed(seed=20200420)  # fixes the simulated sequences and the example picks below

In [None]:
# One simulated-sequence generator per G/C bias, in gc_probs order
sim_seq_generators = [
    feature_util.StatusToSimulatedSeq(input_length, motif_path, 0, bias)
    for bias in gc_probs
]

In [None]:
# Background base frequencies per G/C bias: the two middle entries get bias / 2
# each and the two outer entries (1 - bias) / 2 each (assumes A, C, G, T order)
background_freqs = [np.array([1 - g, g, g, 1 - g]) / 2 for g in gc_probs]
def pfm_to_pwm(pfm, background, pseudocount=0.001):
    """
    Converts an L x 4 PFM into an L x 4 PWM of log2 odds against `background`.

    `pseudocount` is added to every entry of the PFM and each row is renormalized
    so it sums to 1 again; the result is divided elementwise by the background
    frequencies (broadcast across rows) and log2-transformed.
    """
    row_totals = np.sum(pfm, axis=1, keepdims=True) + pfm.shape[1] * pseudocount
    smoothed_probs = (pfm + pseudocount) / row_totals
    odds = smoothed_probs / np.expand_dims(background, axis=0)
    return np.log2(odds)

In [None]:
motif_pfm = feature_util.import_homer_motif(motif_path)
# One PWM per G/C bias, each scored against that bias's own background
motif_pwms = [pfm_to_pwm(motif_pfm, bg) for bg in background_freqs]

### Compute importances

In [None]:
def compute_shap_scores(model, input_seqs, batch_size=128):
    """
    Computes SHAP scores for `model` on an N x I x 4 array of input sequences,
    `batch_size` sequences at a time, and returns them as an N x I x 4 array.

    The output is allocated with width `input_length`, so I is presumably equal
    to `input_length`. Only binary models are supported.
    """
    assert model_type == "binary", "profile model types not supported here"
    n_seqs = len(input_seqs)
    scores = np.empty((n_seqs, input_length, 4))
    explainer = compute_shap.create_binary_explainer(
        model, input_length, task_index=task_index
    )

    n_batches = int(np.ceil(n_seqs / batch_size))
    for batch_idx in tqdm.notebook.trange(n_batches):
        start = batch_idx * batch_size
        window = slice(start, start + batch_size)
        scores[window] = explainer(input_seqs[window], hide_shap_output=True)
    return scores

In [None]:
# Number of simulated sequences per G/C bias, and their indices 0..num_samples-1
num_samples = 100
sample = np.arange(0, num_samples)

In [None]:
# Compute the importance scores and 1-hot seqs
imp_type = "DeepSHAP scores"
imp_func = compute_shap_scores
# One batch of simulated sequences per G/C bias, each generated from all-ones statuses
sample_input_seqs = [
    generator(np.ones(len(sample))) for generator in sim_seq_generators
]
# Score each model on the sequences simulated with its own G/C bias; every
# no-prior model is scored before any prior model
noprior_imp_scores = [
    imp_func(model, seqs) for model, seqs in zip(noprior_models, sample_input_seqs)
]
prior_imp_scores = [
    imp_func(model, seqs) for model, seqs in zip(prior_models, sample_input_seqs)
]

In [None]:
def get_motif_mask(one_hot_seqs, pwm, score_thresh=0.7):
    """
    Returns an N x I boolean mask marking every position covered by a motif hit.

    A window of len(pwm) positions is a hit when the PWM score summed over the
    window, divided by the window length, is >= `score_thresh` for either the PWM
    or its reverse complement (the PWM flipped along both axes; assumes A, C, G, T
    column order).
    """
    motif_len = len(pwm)
    rev_comp_pwm = np.flip(pwm, axis=(0, 1))
    hits = np.zeros(one_hot_seqs.shape[:2], dtype=bool)
    progress = tqdm.notebook.tqdm(enumerate(one_hot_seqs), total=len(one_hot_seqs))
    for seq_idx, seq in progress:
        for start in range(seq.shape[0] - motif_len + 1):
            window = seq[start : start + motif_len]
            fwd_score = np.sum(window * pwm) / motif_len
            rev_score = np.sum(window * rev_comp_pwm) / motif_len
            if fwd_score >= score_thresh or rev_score >= score_thresh:
                hits[seq_idx, start : start + motif_len] = True
    return hits

In [None]:
def get_non_motif_gc(imp_scores, motif_mask):
    """
    Summarizes importance outside motif instances, one value per sequence.

    Each N x I x 4 score track is divided by its own maximum. At every position
    not flagged in `motif_mask`, "GC" importance is the sum of columns 1 and 2 and
    "AT" importance is the sum of columns 0 and 3. Returns three length-N arrays:
    the NaN-ignoring means of GC importance, AT importance, and their per-position
    product.

    NOTE(review): the scale is the signed maximum, so a track whose max is <= 0
    flips sign or divides by zero; a sequence with no positions outside motif
    instances yields NaN.
    """
    outside_mask = ~motif_mask
    gc_means, at_means, prod_means = [], [], []
    for seq_idx in range(len(imp_scores)):
        track = imp_scores[seq_idx]
        scale = np.max(track)
        outside = track[outside_mask[seq_idx]]
        gc_imp = np.sum(outside[:, 1:3], axis=1) / scale
        at_imp = (outside[:, 0] + outside[:, 3]) / scale
        gc_means.append(np.nanmean(gc_imp))
        at_means.append(np.nanmean(at_imp))
        prod_means.append(np.nanmean(gc_imp * at_imp))
    return np.array(gc_means), np.array(at_means), np.array(prod_means)

In [None]:
def get_motif_importance_frac(imp_scores, input_seqs, motif_mask):
    """
    Returns, per sequence, the share of absolute actual importance that lies on
    positions flagged in `motif_mask`.

    Actual importance at a position is the score track multiplied by the input
    sequence and summed over the 4 bases (i.e. the score of the base present,
    assuming `input_seqs` is one-hot), then made absolute. The share is NaN when a
    sequence's total is 0.
    """
    fractions = []
    for seq_idx in range(len(imp_scores)):
        actual = np.abs((imp_scores[seq_idx] * input_seqs[seq_idx]).sum(axis=1))
        fractions.append(actual[motif_mask[seq_idx]].sum() / actual.sum())
    return np.array(fractions)

In [None]:
# Motif-instance masks per G/C bias, each scanned with that bias's own PWM
masks = [
    get_motif_mask(seqs, pwm, score_thresh=0.9)
    for seqs, pwm in zip(sample_input_seqs, motif_pwms)
]

In [None]:
# Per G/C bias: (GC, AT, GC x AT) off-motif summaries and in-motif importance
# fractions, for the no-prior and prior models
noprior_scores, prior_scores = [], []
noprior_imp_fracs, prior_imp_fracs = [], []
for gc_index in range(len(gc_probs)):
    mask = masks[gc_index]
    seqs = sample_input_seqs[gc_index]
    noprior_track = noprior_imp_scores[gc_index]
    prior_track = prior_imp_scores[gc_index]

    noprior_scores.append(tuple(get_non_motif_gc(noprior_track, mask)))
    prior_scores.append(tuple(get_non_motif_gc(prior_track, mask)))

    noprior_imp_fracs.append(get_motif_importance_frac(noprior_track, seqs, mask))
    prior_imp_fracs.append(get_motif_importance_frac(prior_track, seqs, mask))

In [None]:
# Per-bias histograms of the off-motif GC x AT importance product, each followed
# by a one-sided Wilcoxon test (alternative="less" on no-prior vs. prior)
bin_num = 50
for gc_index in range(len(gc_probs)):
    noprior_prod_scores, prior_prod_scores = noprior_scores[gc_index][2], prior_scores[gc_index][2]
    plt.figure(figsize=(12, 6))
    title = "Histogram of %s GC importance x AT importance outside motif instances" % imp_type
    title += "\nSingle-task SPI1 binary models, trained on %2.0f%% G/C bias" % (gc_probs[gc_index] * 100)
    title += "\nComputed on %d randomly simulated sequences" % num_samples
    plt.title(title)
    plt.xlabel("Signed importance of GC x importance of AT")
    all_vals = np.concatenate([noprior_prod_scores, prior_prod_scores])
    # get_non_motif_gc can produce NaN/inf for a sequence (np.nanmean of an empty
    # selection, or a track whose max is 0). Drop those before choosing bin edges;
    # otherwise np.min/np.max propagate them and every edge becomes NaN.
    finite_vals = all_vals[np.isfinite(all_vals)]
    bins = np.linspace(np.min(finite_vals), np.max(finite_vals), bin_num)
    plt.hist(noprior_prod_scores, bins=bins, histtype="bar", label="No prior", color="coral", alpha=0.7)
    plt.hist(prior_prod_scores, bins=bins, histtype="bar", label="With Fourier prior", color="slateblue", alpha=0.7)
    plt.legend()
    plt.show()

    print("Average product without priors: %f" % np.nanmean(noprior_prod_scores))
    print("Average product with priors: %f" % np.nanmean(prior_prod_scores))
    w, p = scipy.stats.wilcoxon(noprior_prod_scores, prior_prod_scores, alternative="less")
    print("One-sided Wilcoxon test: W = %f, p = %f" % (w, p))

In [None]:
# Histogram of GC x AT importance product, on a shared x-axis
bin_num = 40
fig, ax = plt.subplots(1, len(gc_probs), figsize=(9 * len(gc_probs), 15), sharey=True)
if len(gc_probs) == 1:
    ax = [ax]
title = "Histogram of %s GC importance x AT importance outside motif instances" % imp_type
title += "\nSingle-task SPI1 binary models"
title += "\nComputed on %d randomly simulated sequences" % num_samples
plt.suptitle(title)
fig.text(0.5, 0.05, "Signed importance of GC x importance of AT", ha="center", fontsize=22)
all_vals = np.ravel([
    [noprior_scores[gc_index][2], prior_scores[gc_index][2]] for gc_index in range(len(gc_probs))
])
bins = np.linspace(np.min(all_vals), np.max(all_vals), bin_num)

for gc_index in range(len(gc_probs)):
    ax[gc_index].hist(noprior_scores[gc_index][2], bins=bins, histtype="bar", label="No prior", color="coral", alpha=0.7)
    ax[gc_index].hist(prior_scores[gc_index][2], bins=bins, histtype="bar", label="With Fourier prior", color="slateblue", alpha=0.7)
    ax[gc_index].set_title("%2.0f%% G/C bias" % (gc_probs[gc_index] * 100))

plt.subplots_adjust(top=0.85)
ax[0].legend()

plt.show()

In [None]:
# One square scatter per G/C bias: no-prior vs. prior GC x AT product for each
# sequence, with a dashed y = x reference line
for gc_index in range(len(gc_probs)):
    noprior_prod_scores = noprior_scores[gc_index][2]
    prior_prod_scores = prior_scores[gc_index][2]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.scatter(noprior_prod_scores, prior_prod_scores, color="mediumorchid", alpha=0.5)

    title = "Pairwise comparison of %s GC importance x AT importance outside motif instances" % imp_type
    title += "\nSingle-task SPI1 binary models, trained on %2.0f%% G/C bias" % (gc_probs[gc_index] * 100)
    title += "\nComputed on %d randomly simulated sequences" % num_samples
    ax.set_title(title)

    # Span both autoscaled ranges so the diagonal is meaningful
    both_lims = [ax.get_xlim(), ax.get_ylim()]
    limits = [np.min(both_lims), np.max(both_lims)]
    ax.plot(limits, limits, "--", alpha=0.5, color="black")
    ax.set_aspect("equal")
    ax.set_xlim(limits)
    ax.set_ylim(limits)
    ax.set_xlabel("Importance of GC x AT without prior")
    ax.set_ylabel("Importance of GC x AT with Fourier prior")

In [None]:
# One scatter per G/C bias of each sequence's mean off-motif GC importance (x)
# against its mean off-motif AT importance (y), for both model variants
for gc_index in range(len(gc_probs)):
    bin_num = 30
    plt.figure(figsize=(12, 12))
    title = "%s GC importance x AT importance outside motif instances" % imp_type
    title += "\nSingle-task SPI1 binary models, trained on %2.0f%% G/C bias" % (gc_probs[gc_index] * 100)
    title += "\nComputed on %d randomly simulated sequences" % num_samples
    plt.title(title)
    plt.xlabel("Signed importance of GC")
    plt.ylabel("Signed importance of AT")
    series = (
        (noprior_scores[gc_index], "coral", "No prior"),
        (prior_scores[gc_index], "slateblue", "With Fourier prior"),
    )
    for model_scores, color, label in series:
        plt.scatter(model_scores[0], model_scores[1], color=color, alpha=0.7, label=label)
    plt.legend()

In [None]:
# Per-bias histograms of the share of importance that falls inside motif
# instances, each followed by a one-sided Wilcoxon test (alternative="less" on
# no-prior vs. prior)
bin_num = 30
for gc_index in range(len(gc_probs)):
    noprior_fracs, prior_fracs = noprior_imp_fracs[gc_index], prior_imp_fracs[gc_index]
    plt.figure(figsize=(20, 7))
    title = "Proportion of %s importance in motif instances" % imp_type
    title += "\nSingle-task SPI1 binary models, trained on %2.0f%% G/C bias" % (gc_probs[gc_index] * 100)
    title += "\nComputed on %d randomly simulated sequences" % num_samples
    plt.title(title)
    plt.xlabel("Proportion of importance in motif instances")
    all_vals = np.concatenate([noprior_fracs, prior_fracs])
    # A fraction is 0/0 = NaN for a sequence with zero total importance; ignore
    # non-finite values so they cannot turn every bin edge into NaN
    finite_vals = all_vals[np.isfinite(all_vals)]
    bins = np.linspace(np.min(finite_vals), np.max(finite_vals), bin_num)
    plt.hist(noprior_fracs, bins=bins, histtype="bar", label="No prior", color="coral", alpha=0.7)
    plt.hist(prior_fracs, bins=bins, histtype="bar", label="With Fourier prior", color="slateblue", alpha=0.7)
    plt.legend()
    plt.show()

    # These are importance proportions, not GC x AT products
    print("Average proportion without priors: %f" % np.nanmean(noprior_fracs))
    print("Average proportion with priors: %f" % np.nanmean(prior_fracs))
    w, p = scipy.stats.wilcoxon(noprior_fracs, prior_fracs, alternative="less")
    print("One-sided Wilcoxon test: W = %f, p = %f" % (w, p))

In [None]:
def show_example(gc_index, i, center_slice=slice(450, 550)):
    """
    Displays sequence `i` of the G/C bias at `gc_index` for both model variants.

    For the no-prior model and then the prior model: prints a heading, plots the
    per-position importance of the base present (scores x input, summed over
    bases) along the full sequence, and then calls viz_sequence.plot_weights on
    `center_slice` of the raw scores and of the scores x input.
    """
    print(gc_probs[gc_index], i)
    print("=========================")
    seq = sample_input_seqs[gc_index][i]
    variants = (
        ("Without priors:", noprior_imp_scores, "coral"),
        ("With priors:", prior_imp_scores, "slateblue"),
    )
    for heading, score_sets, color in variants:
        scores = score_sets[gc_index][i]
        actual = scores * seq
        print(heading)
        plt.figure(figsize=(20, 2))
        plt.plot(np.sum(actual, axis=1), color=color)
        plt.show()
        viz_sequence.plot_weights(scores[center_slice], subticks_frequency=1000)
        viz_sequence.plot_weights(actual[center_slice], subticks_frequency=1000)

In [None]:
# Plot three distinct, randomly chosen sequences for every G/C bias
for gc_index in range(len(gc_probs)):
    picks = np.random.choice(num_samples, size=3, replace=False)
    for i in picks:
        show_example(gc_index, i, center_slice=slice(400, 600))