# Imports and model initialization

In [3]:
# !pip install kipoi
# !pip install kipoiseq
# !pip install pybedtools
# !pip uninstall -y kipoi_veff
# !pip install git+https://github.com/an1lam/kipoi-veff
# !pip install pyvcf
import csv
import math
import pickle

from Bio.motifs import pfm
from Bio.motifs.jaspar import calculate_pseudocounts
import kipoi
from kipoi_interpret.importance_scores.ism import Mutation
from kipoiseq.dataloaders import SeqIntervalDl
from logomaker import Logo
from matplotlib import pyplot as plt
from matplotlib import patches
from matplotlib import cm
import matplotlib.lines as mlines
import matplotlib.transforms as mtransforms
import numpy as np
import pandas
from tqdm import tqdm

from align import prob_sw
from motif_scores import build_impact_maps
from motif_scores import kmer_mut_scores
from motif_scores import kmer_pwm_scores
from motif_scores import pwm_scores
from motif_scores import top_n_kmer_mut_scores
from motif_scores import top_n_kmer_pwm_scores
from np_utils import abs_max
from pyx.one_hot import one_hot
from utils import BASES
from utils import INT_TO_BASES
from utils import one_hot_decode

  from tqdm.autonotebook import tqdm


In [4]:
import tensorflow as tf
print("TF version:", tf.__version__)
import torch
print("torch version:", torch.__version__)
from torch import nn
from torch.nn import functional as F

TF version: 1.15.0
torch version: 1.3.1


In [5]:
!pwd

/home/stephenmalina/project/src


In [93]:
%config InlineBackend.figure_format = 'retina'

# Loading DNA sequence data

In [6]:
dl = SeqIntervalDl("../dat/50_random_seqs_2.bed", "../dat/hg19.fa", auto_resize_len=1000)
data = dl.load_all()

100%|██████████| 2/2 [00:00<00:00,  3.10it/s]


In [7]:
seqs = np.expand_dims(data['inputs'].transpose(0, 2, 1), 2).astype(np.float32)
seqs.shape

(50, 4, 1, 1000)

In [8]:
deepsea = kipoi.get_model("DeepSEA/predict", source="kipoi")
deepsea.model

Using downloaded and verified file: /home/stephenmalina/.kipoi/models/DeepSEA/predict/downloaded/model_files/weights/89e640bf6bdbe1ff165f484d9796efc7


Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): Dropout(p=0.5, inplace=False)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15):

In [9]:
CHROM_ACC_COL = 'HepG2_DNase_None'
# TF_COL = 'A549_CTCF_None'
TF_COL = 'HepG2_FOXA1_None'
relevant_cols = sorted([(i, label)
                        for i, label in enumerate(deepsea.schema.targets.column_labels)
                        if label in [CHROM_ACC_COL, TF_COL]])

In [10]:
deepsea.pipeline.predict_example().shape

100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


(10, 919)

# Loading Predictions

In [54]:
import pickle
# pickle_file = "../dat/most_recent_sat_mut_results__drop_channel.pickle"
pickle_file = "../dat/most_recent_sat_mut_results__original_mc_dropout.pickle"

In [55]:
with open(pickle_file, 'rb') as f: preds = pickle.load(f)
print(preds.shape)
n_seqs = 25
epochs, _, n_batches, batch_size, _ = preds.shape
preds = preds[:, :n_seqs, :, :, :]
preds.shape

(50, 50, 10, 301, 3)


(50, 25, 10, 301, 3)

In [262]:
all_zeros = np.zeros((4,))

def batches_needed(seq_len, batch_size, alpha_size=4):
    assert (seq_len * alpha_size) % (batch_size) == 0, seq_len * 3
    # alpha_size - 1 mutations per nt and then account for ref in each batch
    return (seq_len * alpha_size) // batch_size

def generate_wt_mut_batches(seq, batch_size=None, context_size=None):
    """
    For a given sequence, generate all possible point-mutated versions of the sequence
    in batches of size `param:batch_size`.
    
    Args:
        seq (numpy.ndarray [number of base pairs, sequence length]): 
            wild type sequence.
        batch_size (int): size of returned batches. Note that each batch will have the
            wild type sequence as its first row since we need to compute wild type / mut
            prediction diffs using predictions generated by the same dropout mask.
    """
    num_nts, seq_len = seq.shape
    if batch_size is None: batch_size = num_nts * seq_len
    n_batches = batches_needed(seq_len, batch_size, alpha_size=num_nts)
    new_seq_len = context_size * 2 + 1 if context_size is not None else seq_len
    seq_batches = np.zeros((n_batches, batch_size, num_nts, new_seq_len))    

    i = 0
    for seq_idx in range(seq_len):  # iterate over sequence 
        for nt_idx in range(num_nts):  # iterate over nucleotides

            curr_batch, curr_idx = i // batch_size, i % batch_size
            curr_mut = seq_batches[curr_batch, curr_idx]
            curr_mut[nt_idx, context_size] = 1
            
            start_idx, end_idx = (0, seq_len)
            if context_size is not None:
                start_idx, end_idx = max(seq_idx - context_size, 0), min(seq_idx + context_size + 1, seq_len)
            extend_back, extend_forwards = seq_idx - start_idx, end_idx - seq_idx
            curr_mut[:, context_size - extend_back: context_size] = seq[:, start_idx: seq_idx]
            curr_mut[:, context_size + 1: context_size + extend_forwards] = seq[:, seq_idx+1: end_idx]
            i += 1
    return seq_batches

In [57]:
def next_seq(it):
    return (np
            .expand_dims(next(it)["inputs"].transpose(0, 2, 1), 2)
            .astype(np.float32)
            .squeeze())

In [58]:
epochs, n_seqs, batch_size = 50, 25, 301
n_nts, _, seq_len = deepsea.schema.inputs.shape
preds = [[[] for _ in range(n_seqs)] for _ in range(epochs)]
it = dl.batch_iter(batch_size=1, num_workers=0, drop_last=False)

n_batches = batches_needed(seq_len, batch_size, alpha_size=n_nts)
seqs = np.zeros((n_seqs, n_nts, seq_len))
batch_size = 301
for i in tqdm(range(min(n_seqs, len(it)))):
    seq = next_seq(it)
    if np.allclose(seq, .25): raise Exception("shouldn't have empty seqs")
    seqs[i, :, :] = seq

Generating predictions for 25 seqs


100%|██████████| 25/25 [00:00<00:00, 846.33it/s]


# Loading known FOXA1 motifs

In [59]:
with open('../dat/foxa1.pfm') as f: foxa1_motifs = pfm.read(f, 'pfm-four-rows')

In [251]:
def bio_to_np_pwm(motif):
    nt_counts = motif.counts
    motif_len = len(list(nt_counts.values())[0])
    np_pwm = np.zeros((len(nt_counts.keys()), motif_len))
    pseudocounts = calculate_pseudocounts(motif)
    for i in range(motif_len):
        total = sum(nt_counts[base][i] + pseudocounts[base] for base in BASES)
        for j, base in INT_TO_BASES.items(): 
            np_pwm[j, i] = (nt_counts[base][i] + pseudocounts[base]) / total 
    return np.log(np_pwm)
        

foxa1_pwms = [bio_to_np_pwm(motif) for motif in foxa1_motifs]
assert np.all(pwm < 0 for pwm in foxa1_pwms)

# Results & Analysis
## Computing summary statistics

In [61]:
log_uniform_prob = math.log(.05/(1-.05))
def compute_normalized_prob(prob, train_prob):
    # source: http://deepsea.princeton.edu/help/
    denom = 1+np.exp(-(np.log(prob/(1-prob))+log_uniform_prob-np.log(train_prob/(1-train_prob))))
    return 1 / denom

# Ratios and normalization formula drawn from here: http://deepsea.princeton.edu/media/help/posproportion.txt
tf_compute_normalized_prob = lambda prob: compute_normalized_prob(prob, 0.02394)
chrom_acc_normalized_prob = lambda prob: compute_normalized_prob(prob, 0.049791)

In [62]:
preds[:, :, :, :, 0] = chrom_acc_normalized_prob(preds[:, :, :, :, 0])
preds[:, :, :, :, 1] = compute_normalized_prob(preds[:, :, :, :, 1], 0.020508)
preds[:, :, :, :, 2] = compute_normalized_prob(preds[:, :, :, :, 2], 0.02394)

In [63]:
n_batches = preds.shape[2]
batch_size = preds.shape[3]
preds.shape

(50, 25, 10, 301, 3)

In [64]:
pred_means = np.mean(preds[:, :, :, :, :], axis=0)
pred_vars = np.var(preds, axis=0, dtype=np.float64)
pred_means.shape

(25, 10, 301, 3)

In [65]:
pred_mean_diffs = pred_means[:, :, 1:, :] - pred_means[:, :, 0:1, :] 
pred_mean_diffs.shape

(25, 10, 300, 3)

In [66]:
pred_mean_diffs = pred_mean_diffs.reshape(n_seqs, n_nts-1, -1, 3)
pred_mean_diffs.shape

(25, 3, 1000, 3)

## Comparing Mutation Impact to Known Binding Motif Matches

In [75]:
mut_effects = pred_mean_diffs[:, :, :, 1]

In [380]:
pwm_ll_differences = np.zeros([len(foxa1_pwms), n_seqs, n_nts-1, seq_len])

for i, pwm in enumerate(foxa1_pwms):
    for j, seq in enumerate(tqdm(seqs)):
        k = pwm.shape[1]
        targets = generate_wt_mut_batches(seq, context_size=k).squeeze()
        
        # Note: as a result of torch / numpy conversion there are a few very small positive scores.
        score = pwm_scores(
            np.expand_dims(pwm, 0), targets.squeeze()
        )[0]

        reference_indexes = np.argwhere(seq.reshape(-1) == 1).squeeze()
        reference_scores = score[reference_indexes, :]
        reference_scores = np.repeat(reference_scores[np.newaxis, :, :], 3, axis=0)
        
        mutation_indexes = np.argwhere(seq.reshape(-1) == 0).squeeze()
        mutation_scores = score[mutation_indexes, :].reshape(n_nts-1, seq_len, score.shape[1])
        diffs = np.min(mutation_scores - reference_scores, axis=2)
        pwm_ll_differences[i, j] = diffs

pwm_ll_differences.shape

100%|██████████| 25/25 [00:00<00:00, 51.69it/s]
100%|██████████| 25/25 [00:00<00:00, 52.37it/s]
100%|██████████| 25/25 [00:00<00:00, 51.10it/s]
100%|██████████| 25/25 [00:00<00:00, 51.99it/s]


(4, 25, 3, 1000)

In [381]:
pwm_impact_maps = np.zeros((len(foxa1_pwms), n_seqs, n_nts, seq_len))
for i, pwm in enumerate(foxa1_pwms):
    pwm_impact_maps[i] = build_impact_maps(seqs, pwm_ll_differences[i], ref_preds=[None] * len(seqs))
pwm_impact_maps.shape

(4, 25, 4, 1000)

In [382]:
impact_maps = build_impact_maps(seqs, mut_effects, ref_preds=[None] * len(seqs))
impact_maps.shape

(25, 4, 1000)

# Results

We now have two partially aggregated "scores". First, we have the max differences between the PWM "scores" for each reference / mutated sequence pair over all sequences and mutation positions. The differences are between two log-likelihoods and represent how much each point mutation affects the PWM's similarity to the reference sequence in a window around the mutated position. Second, we have the difference between $ P(\text{binding} \mid \text{reference sequence}) $ and $ P(\text{binding} \mid \text{mutated sequence}) $ for all possible point mutations to each sequence.

The idea behind both of the following metrics we compute is that, since we're using sequences which should have had binding predicted to start with, downward changes to PWM scores should correspond to downward changes to probability of binding. From a biological angle, each of our PWMs represents a known binding motif for the FOXA1 transcription factor. So, high PWM scores correspond to high similarity between a sub-sequence and the canonical motif represented by that PWM. Given this, PWM score decreases as a result of mutations make it more likely that an important motif for the FOXA1 binding protein was disrupted. Assuming our model "understands" these motifs, motif disruptions should lead to lower binding likelihoods.

We first compute the pairwise Pearson correlation between the PWM likelihood differences and the predicted effects of mutations at both the PWM (across all sequence) and PWM/individual sequence levels. As you can see, the correlations are surprisingly low. I'm honestly not sure what to make of this. It implies that there's essentially no relationship between the positions at which 

In [383]:
print("PWM-specific correlations:\n", "-" * 20)
for i, pwm in enumerate(foxa1_pwms):
    assert mut_effects.shape == pwm_ll_differences[i].shape
    correlation = np.corrcoef(mut_effects.reshape(-1), pwm_ll_differences[i].reshape(-1))
    print("Correlation (PWM %s): %.2f" % (foxa1_motifs[i].consensus, correlation[1, 0]))

print("PWM and sequence-specific correlations:\n", "-" * 20)
for i, pwm in enumerate(foxa1_pwms):
    for j, seq in enumerate(seqs):
        assert mut_effects[j].shape == pwm_ll_differences[i, j].shape
        correlation = np.corrcoef(mut_effects[j].reshape(-1), pwm_ll_differences[i, j].reshape(-1))
        print("Correlation (PWM %s, sequence: %d): %.2f" % (foxa1_motifs[i].consensus, j, correlation[1, 0]))


PWM-specific correlations:
 --------------------
Correlation (PWM TGTTTACTTTG): 0.03
Correlation (PWM TGTTTACTTTG): 0.03
Correlation (PWM TCCATGTTTACTTTG): 0.02
Correlation (PWM ATGTAAACATGT): -0.00
PWM and sequence-specific correlations:
 --------------------
Correlation (PWM TGTTTACTTTG, sequence: 0): 0.02
Correlation (PWM TGTTTACTTTG, sequence: 1): 0.04
Correlation (PWM TGTTTACTTTG, sequence: 2): 0.08
Correlation (PWM TGTTTACTTTG, sequence: 3): 0.01
Correlation (PWM TGTTTACTTTG, sequence: 4): 0.06
Correlation (PWM TGTTTACTTTG, sequence: 5): 0.03
Correlation (PWM TGTTTACTTTG, sequence: 6): 0.06
Correlation (PWM TGTTTACTTTG, sequence: 7): -0.06
Correlation (PWM TGTTTACTTTG, sequence: 8): 0.06
Correlation (PWM TGTTTACTTTG, sequence: 9): -0.00
Correlation (PWM TGTTTACTTTG, sequence: 10): -0.00
Correlation (PWM TGTTTACTTTG, sequence: 11): 0.04
Correlation (PWM TGTTTACTTTG, sequence: 12): -0.00
Correlation (PWM TGTTTACTTTG, sequence: 13): -0.02
Correlation (PWM TGTTTACTTTG, sequence: 14):

In [384]:
mut_scores = abs_max(impact_maps, axis=1)
mut_score_indexes = np.argmax(np.abs(impact_maps), axis=1)
mut_scores.shape, mut_score_indexes.shape

((25, 1000), (25, 1000))

In [385]:
total = len(foxa1_pwms) * n_seqs
top_best_in_window = top_worst_in_window = 0

for i in range(len(foxa1_pwms)):
    for j in range(n_seqs):
        seq_score_idxs = np.argsort(np.min(pwm_ll_differences[i][j], axis=0))

        context_size = foxa1_pwms[i].shape[1]
        worst_window_start = max(seq_score_idxs[0] - context_size, 0)
        worst_window_end = min(seq_score_idxs[0] + context_size, seq_len)
        best_window_start = max(seq_score_idxs[-1] - context_size, 0)
        best_window_end = min(seq_score_idxs[-1] + context_size, seq_len)
        
        best_mut_idxs = np.argsort(mut_scores[j])[:5]
        worst_mut_idxs = np.argsort(mut_scores[j])[-5:]

        if np.any(np.logical_and(best_mut_idxs >= best_window_start, best_mut_idxs < best_window_end)): 
            print("BEST\n", "-" * 10)
            top_best_in_window += 1
            print("PWM scores")
            print("Most negative (best): ", pwm_ll_differences[i, j][:, seq_score_idxs[0]])
            print("Least negative (worst): ", pwm_ll_differences[i, j][:, seq_score_idxs[-1]])
            print("Mut scores")
            print("Most negative (best): ", mut_scores[j][best_mut_idxs])
            print("Least negative (worst): ", mut_scores[j][worst_mut_idxs])
            print("PWM match window: start (%d) -> end (%d)" % (best_window_start, best_window_end))
            print("PWM matching sequence window: ", one_hot_decode(seqs[j][:, window_start: window_end]))
            print("'Best' mut score idxs and nts: ", best_mut_idxs, mut_score_indexes[j][best_mut_idxs])
            print("'Worst' mut score idxs and nts: ", worst_mut_idxs, mut_score_indexes[j][worst_mut_idxs])
            print()
        
        if np.any(np.logical_and(worst_mut_idxs >= best_window_start, worst_mut_idxs < best_window_end)): 
            print("WORST\n", "-" * 10)
            top_worst_in_window += 1
            print("PWM scores")
            print("Most negative (best): ", pwm_ll_differences[i, j][:, seq_score_idxs[0]])
            print("Least negative (worst): ", pwm_ll_differences[i, j][:, seq_score_idxs[-1]])
            print("PWM match window: start (%d) -> end (%d)" % (best_window_start, best_window_end))
            print()
            
            print("Mut scores")
            print("Most negative (best): ", mut_scores[j][best_mut_idxs])
            print("PWM consensus: ", foxa1_motifs[i].consensus)
            print("SEQ window:    ", one_hot_decode(seqs[j][:, best_window_start: best_window_end]))
            print("'Best' mut score idxs and nts: ", best_mut_idxs, mut_score_indexes[j][best_mut_idxs])
            print("Least negative (worst): ", mut_scores[j][worst_mut_idxs])
            print(
                "'Worst' mut score idxs and nts: ", 
                one_hot_decode(seq[:, worst_mut_idxs]), 
                worst_mut_idxs,
                mut_score_indexes[j][worst_mut_idxs]
            )
            print()
 

WORST
 ----------
PWM scores
Most negative (best):  [-46.93837817 -49.60380269  -6.07713722]
Least negative (worst):  [-3.8650377  -7.79437527 -5.71752006]
PWM match window: start (820) -> end (850)

Mut scores
Most negative (best):  [-0.17357698 -0.15974003 -0.159639   -0.15943506 -0.15715995]
PWM consensus:  TCCATGTTTACTTTG
SEQ window:     TTTTTACCTCAATTTCCTGTCCCAAGATGG
'Best' mut score idxs and nts:  [899 902 891 900 694] [2 1 1 1 1]
Least negative (worst):  [0.17635483 0.18365037 0.23777428 0.25939775 0.30138874]
'Worst' mut score idxs and nts:  AGATC [645 711 833 644 528] [2 1 1 2 1]

WORST
 ----------
PWM scores
Most negative (best):  [-47.54423826 -53.84018004 -10.52841764]
Least negative (worst):  [-9.07502256 -8.96574216 -7.08061643]
PWM match window: start (301) -> end (331)

Mut scores
Most negative (best):  [-0.22850603 -0.22618456 -0.22103682 -0.21916957 -0.21157001]
PWM consensus:  TCCATGTTTACTTTG
SEQ window:     AGTTTCAGAAAATACCCCTCCAATTTATTT
'Best' mut score idxs and nt

In [386]:
fraction_in_best_window = top_best_in_window / total
fraction_in_worst_window = top_worst_in_window / total
"Fraction in best / worst window: %.2f, %.2f" % (fraction_in_best_window, fraction_in_worst_window)

'Fraction in best / worst window: 0.01, 0.09'