## Imports and model initialization

In [1]:
# !pip install kipoi
# !pip install kipoiseq
# !pip install pybedtools
# !pip uninstall -y kipoi_veff
# !pip install git+https://github.com/an1lam/kipoi-veff
# !pip install pyvcf
import csv
import math
import pickle

import kipoi
from kipoi_interpret.importance_scores.ism import Mutation
from kipoiseq.dataloaders import SeqIntervalDl
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from utils import detect_device

from custom_dropout import apply_dropout, replace_dropout_layers, unapply_dropout
from custom_dropout import LockedWeightDropout

In [2]:
# Show the kernel's working directory: the relative ../dat/ paths used below resolve against it.
!pwd

/home/stephenmalina/project/src


# Loading DNA sequence data

In [3]:
# Intervals from the BED file, resized to 1000 bp (auto_resize_len), with sequence taken
# from the hg19 FASTA.
bed_file, fasta_file = "../dat/50_random_seqs_2.bed", "../dat/hg19.fa"
dl = SeqIntervalDl(bed_file, fasta_file, auto_resize_len=1000)
data = dl.load_all()

100%|██████████| 2/2 [00:00<00:00,  2.22it/s]


In [4]:
n_seqs = 25
# One-hot inputs come in as (N, length, 4); DeepSEA's first Conv2d wants
# (N, 4, 1, length) float32. Keep only the first n_seqs sequences.
seqs = np.expand_dims(data['inputs'][:n_seqs].transpose(0, 2, 1), axis=2).astype(np.float32)
seqs.shape

(25, 4, 1, 1000)

# Loading DeepSEA

In [5]:
# Record the framework versions this notebook was run with.
import tensorflow as tf
print(f"TF version: {tf.__version__}")
import torch
print(f"torch version: {torch.__version__}")
from torch import nn

TF version: 1.15.0
torch version: 1.3.1


In [6]:
# df = kipoi.list_models()
# deepsea_models = df[df.model.str.contains("DeepSEA")]
# deepsea_models.head()

In [7]:
# Fetch the DeepSEA/predict model from the kipoi model zoo and display its PyTorch layers.
deepsea = kipoi.get_model("DeepSEA/predict", source="kipoi")
deepsea.model

Using downloaded and verified file: /home/stephenmalina/.kipoi/models/DeepSEA/predict/downloaded/model_files/weights/89e640bf6bdbe1ff165f484d9796efc7


Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): Dropout(p=0.5, inplace=False)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15):

In [8]:
# Sanity check: score kipoi's bundled example inputs and show the (examples, targets) shape.
deepsea.pipeline.predict_example().shape

100%|██████████| 1/1 [00:00<00:00,  2.77it/s]


(10, 919)

# Hacking DeepSEA's layers
This section is focused on tweaking DeepSEA's model to use dropout when making predictions.

In [9]:
# Swap the model's Dropout layers for custom_dropout.LockedWeightDropout (compare the
# layer listing printed below with the one above).
deepsea.model = replace_dropout_layers(deepsea.model, dropout_cls=LockedWeightDropout)
deepsea.model

Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): LockedWeightDropout(p=0.2)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): LockedWeightDropout(p=0.2)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): LockedWeightDropout(p=0.5)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15): Sigmoid(

### Verifying dropout is being applied
We can verify (or at least gain reasonable confidence) that dropout is being applied by running three predictions on the same example inputs and checking that the predicted binding probabilities differ. First, we show that the predictions stay the same when dropout is off.

Then, we turn on dropout and show that they start to vary.

In [10]:
# Apply custom_dropout.unapply_dropout to every submodule, then score the bundled
# example batch several times; with dropout off the rows should match across runs.
deepsea.model = deepsea.model.apply(unapply_dropout)
n_repeats = 3
for run in range(1, n_repeats + 1):
    first_rows = deepsea.pipeline.predict_example()[:5]
    print(f"First few preds ({run}/{n_repeats}): ", first_rows)

100%|██████████| 1/1 [00:00<00:00, 20.63it/s]


First few preds (1/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 19.54it/s]


First few preds (2/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 20.71it/s]


First few preds (3/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


In [11]:
# Put the model in eval mode, then apply custom_dropout.apply_dropout to every
# submodule (presumably re-enabling the dropout layers); repeated runs should now differ.
deepsea.model.eval()
deepsea.model = deepsea.model.apply(apply_dropout)
n_repeats = 3
for run in range(1, n_repeats + 1):
    first_rows = deepsea.pipeline.predict_example()[:5]
    print(f"First few preds ({run}/{n_repeats}): ", first_rows)

100%|██████████| 1/1 [00:00<00:00, 20.39it/s]


First few preds (1/3):  [[7.8655757e-02 8.3806358e-02 9.7069323e-02 ... 5.6508001e-02
  2.4502188e-02 1.5566023e-02]
 [3.5470471e-02 4.5489939e-03 1.9990906e-02 ... 1.0030031e-01
  4.1729517e-02 9.0535246e-03]
 [3.2651491e-02 4.1565662e-03 1.6885564e-02 ... 9.1600373e-02
  3.5880196e-01 2.9147882e-02]
 [1.1820053e-03 1.6361063e-04 1.2224740e-03 ... 1.1896892e-02
  4.5368865e-02 4.9301307e-04]
 [4.4277078e-04 3.5921750e-03 1.3708207e-03 ... 1.0503948e-02
  5.2451190e-02 2.9240381e-02]]


100%|██████████| 1/1 [00:00<00:00, 20.44it/s]


First few preds (2/3):  [[0.05306129 0.03041203 0.07123175 ... 0.05064587 0.03200493 0.01031295]
 [0.03007108 0.00592298 0.02056084 ... 0.22594553 0.06201714 0.00791816]
 [0.02606363 0.00296377 0.01511293 ... 0.07262659 0.23947865 0.01690581]
 [0.00212761 0.00043771 0.00250371 ... 0.01329707 0.04569716 0.00053356]
 [0.00082771 0.00540447 0.00260013 ... 0.02363428 0.05851531 0.02354807]]


100%|██████████| 1/1 [00:00<00:00, 20.60it/s]


First few preds (3/3):  [[0.09402785 0.08797362 0.10711947 ... 0.09701223 0.01897594 0.01406437]
 [0.06788946 0.0117236  0.02979147 ... 0.2000671  0.051117   0.00931144]
 [0.05007762 0.00516387 0.0236172  ... 0.11218122 0.30588615 0.02488672]
 [0.00224622 0.0013474  0.00320529 ... 0.00926662 0.03031096 0.00054907]
 [0.00099179 0.0054762  0.00199367 ... 0.03638974 0.05912359 0.02541691]]


## Predictions and in-silico mutagenesis

In [12]:
CHROM_ACC_COL = 'HepG2_DNase_None'
# Alternative TF track tried previously: 'A549_CTCF_None'
TF_COL = 'HepG2_FOXA1_None'

# (index, label) for every DeepSEA target whose label matches one of the tracks above.
# A label can appear more than once (see the output below), so this may hold >2 entries.
_wanted_labels = {CHROM_ACC_COL, TF_COL}
relevant_cols = sorted(
    (idx, label)
    for idx, label in enumerate(deepsea.schema.targets.column_labels)
    if label in _wanted_labels
)


def output_sel_fn(result):
    """
    Keep only the `relevant_cols` columns of a prediction array.

    Assumes `result` is 2-D (examples x DeepSEA targets) — TODO confirm for other callers.
    Returns an array of shape (examples, len(relevant_cols)).
    """
    selected = [result[:, col_idx] for col_idx, _label in relevant_cols]
    return np.array(selected).T


relevant_cols

[(56, 'HepG2_DNase_None'),
 (302, 'HepG2_FOXA1_None'),
 (303, 'HepG2_FOXA1_None')]

In [13]:
# Repeated MC-dropout predictions for the first few wild type sequences, to gauge the
# spread of the predictive distribution for each selected output column.
n_mc_samples = 50
n_sample_seqs = min(5, len(seqs))
# Size the last axis from relevant_cols rather than hard-coding 3, so choosing a
# different TF/CA track (which may match a different number of columns) still works.
sample_preds = np.zeros((n_mc_samples, n_sample_seqs, len(relevant_cols)))

for sample_idx in range(n_mc_samples):
    sample_preds[sample_idx] = output_sel_fn(deepsea.predict_on_batch(seqs[:n_sample_seqs]))

np.mean(sample_preds, axis=0), np.std(sample_preds, axis=0)

(array([[0.01698819, 0.04649171, 0.05820748],
        [0.16994664, 0.24118587, 0.2511029 ],
        [0.01970757, 0.20405887, 0.21905734],
        [0.10903159, 0.23565267, 0.20230297],
        [0.06961777, 0.22315117, 0.22908154]]),
 array([[0.00548362, 0.02609396, 0.0314908 ],
        [0.0630529 , 0.09314241, 0.09340617],
        [0.0084638 , 0.09068899, 0.09530321],
        [0.02229036, 0.06753918, 0.0607408 ],
        [0.02847355, 0.08705198, 0.08907392]]))

In [14]:
all_zeros = np.zeros((4,))  # retained for backward compatibility; no longer used below


def generate_wt_mut_batches(seq, batch_size):
    """
    For a one-hot sequence, generate every single-nucleotide substitution, packed into
    batches of size `batch_size`.

    Each batch has the wild type sequence as its first row, followed by `batch_size - 1`
    mutants, because wild type / mutant prediction diffs must be computed from
    predictions generated with the same dropout mask.

    Mutants are ordered by position, then by alternate nucleotide index (ascending,
    skipping the reference). Global mutant number `m` is stored at batch
    `m // (batch_size - 1)`, row `m % (batch_size - 1) + 1`, and mutates position
    `m // (num_nts - 1)`.

    Args:
        seq (numpy.ndarray [number of nucleotides, sequence length]):
            one-hot wild type sequence; every position must have exactly one entry
            whose integer value is 1.
        batch_size (int): rows per batch (wild type + mutants). `batch_size - 1` must be
            positive and evenly divide `seq_len * (num_nts - 1)`.

    Returns:
        numpy.ndarray [n_batches, batch_size, number of nucleotides, sequence length],
        same dtype as `seq`.

    Raises:
        ValueError: if `batch_size` does not fit the number of mutants, or `seq` is not
            one-hot at some position (e.g. an 'N' base encoded as 0.25), which would
            otherwise produce more mutants than the batches can hold.
    """
    num_nts, seq_len = seq.shape
    n_muts = seq_len * (num_nts - 1)
    muts_per_batch = batch_size - 1
    if muts_per_batch < 1 or n_muts % muts_per_batch != 0:
        raise ValueError(
            f"batch_size - 1 ({muts_per_batch}) must be positive and evenly divide "
            f"the number of mutants ({n_muts})")

    # Same reference test as `int(x) == 1`, applied elementwise.
    is_ref = seq.astype(int) == 1
    refs_per_pos = is_ref.sum(axis=0)
    if not np.all(refs_per_pos == 1):
        bad_positions = np.flatnonzero(refs_per_pos != 1)
        raise ValueError(
            "seq must be one-hot at every position; offending positions: "
            f"{bad_positions[:10].tolist()}")

    n_batches = n_muts // muts_per_batch
    seq_batches = np.broadcast_to(seq, (n_batches, batch_size, num_nts, seq_len)).copy()
    mut_num = 0
    for pos in range(seq_len):
        for nt in range(num_nts):
            if is_ref[nt, pos]:
                continue
            batch_idx, row = divmod(mut_num, muts_per_batch)
            row += 1  # row 0 holds the wild type
            seq_batches[batch_idx, row, :, pos] = 0
            seq_batches[batch_idx, row, nt, pos] = 1
            mut_num += 1
    return seq_batches

In [15]:
epochs, batch_size = 50, 301
print(f"Generating predictions for {n_seqs} seqs")

# Every sequence yields seq_len * 3 mutants, packed batch_size - 1 per batch.
n_wt_mut_batches = seqs.shape[-1] * 3 // (batch_size - 1)
# NaN placeholder for skipped sequences, so np.array(preds) below stays a regular
# (epochs, n_seqs, n_batches, batch_size, n_cols) float array instead of a ragged one.
skipped_seq_preds = [
    np.full((batch_size, len(relevant_cols)), np.nan, dtype=np.float32)
    for _ in range(n_wt_mut_batches)
]

preds = [[[] for _ in range(n_seqs)] for _ in range(epochs)]
for i, seq in enumerate(tqdm(seqs)):
    if np.allclose(seq, .25):
        # All-N sequence: nothing meaningful to mutate.
        print("Skipping")
        for epoch in range(epochs):
            preds[epoch][i] = skipped_seq_preds
        continue
    wt_mut_batches = generate_wt_mut_batches(seq.squeeze(), batch_size)
    # leave=False stops one finished inner progress bar per sequence piling up.
    for batch in tqdm(wt_mut_batches, leave=False):
        for epoch in range(epochs):
            # The wild type (row 0) and its mutants are scored in a single call, which
            # the wt/mut diffs rely on for a shared dropout mask.
            epoch_preds = deepsea.predict_on_batch(np.expand_dims(batch, axis=2))
            preds[epoch][i].append(output_sel_fn(epoch_preds))

np_preds = np.array(preds)

Generating predictions for 25 seqs


HBox(children=(IntProgress(value=0, max=25), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10), HTML(value='')))





In [16]:
import pickle

# Persist the raw MC-dropout predictions so the analysis below can be rerun without
# regenerating them.
pickle_file = "../dat/most_recent_sat_mut_results__comparison.pickle"
with open(pickle_file, 'wb') as pickle_out:
    pickle.dump(np_preds, pickle_out)

In [None]:
# with open(pickle_file, 'rb') as f: np_preds = pickle.load(f)
# print(np_preds.shape)

In [None]:
# Recover the array geometry: (MC-dropout epochs, sequences, batches per sequence,
# rows per batch = wild type + mutants, selected output columns).
epochs, n_seqs, n_batches, batch_size, _ = np_preds.shape
np_preds.shape

In [None]:
# Log-odds of the uniform 5% positive rate that DeepSEA's normalization rescales to.
log_uniform_prob = math.log(.05/(1-.05))


def _logit(p):
    """Log-odds of probability `p` (elementwise for numpy arrays)."""
    return np.log(p / (1 - p))


def compute_normalized_prob(prob, train_prob):
    """
    Rescale a DeepSEA probability as if the target's positive rate were 5%.

    Shifts the log-odds of `prob` by logit(0.05) - logit(train_prob) and maps the
    result back through the logistic function.
    Formula source: http://deepsea.princeton.edu/help/

    Args:
        prob: raw predicted probability (scalar or numpy array).
        train_prob: positive proportion of the target in DeepSEA's training data.
    """
    shifted_log_odds = _logit(prob) + log_uniform_prob - _logit(train_prob)
    return 1 / (1 + np.exp(-shifted_log_odds))


# Positive proportions from http://deepsea.princeton.edu/media/help/posproportion.txt
def tf_compute_normalized_prob(prob):
    """Normalized probability for the TF track (training positive proportion 0.02394)."""
    return compute_normalized_prob(prob, 0.02394)


def chrom_acc_normalized_prob(prob):
    """Normalized probability for the DNase track (training positive proportion 0.049791)."""
    return compute_normalized_prob(prob, 0.049791)

# Results & Analysis
## Computing relevant statistics
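In this section, we compute the predictive mean and variance of the raw predictions and of each mut − ref difference. The variance of a difference, $\mathrm{Var}[m - r] = \mathrm{Var}[m] + \mathrm{Var}[r] - 2\,\mathrm{Cov}[m, r]$, also requires the covariance between the wild type and mutant predictions drawn under the same dropout mask.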

In [None]:
# MC-dropout samples for seq 0, batch 0, wild type row, column 1 — plus their mean and std.
wt_samples = np_preds[:, 0, 0, 0, 1]
wt_samples, wt_samples.mean(), wt_samples.std()

Previous output:
```
(array([0.05126143, 0.1050737 , 0.06798535, 0.07285436, 0.09456536,
        0.10932738, 0.04311277, 0.11430986, 0.13595097, 0.07084743,
        0.04534521, 0.08461001, 0.07526002, 0.08425504, 0.09917682,
        0.03529431, 0.18999103, 0.08796159, 0.16572888, 0.10169207,
        0.19090581, 0.04598923, 0.10690306, 0.08212011, 0.04604339,
        0.10946183, 0.11298828, 0.060181  , 0.05917034, 0.15545578,
        0.07560672, 0.06172951, 0.07798496, 0.07683495, 0.06256678,
        0.06378538, 0.06284482, 0.20937835, 0.06821388, 0.0717489 ,
        0.1056542 , 0.03477913, 0.2411811 , 0.07979514, 0.07677957,
        0.15393922, 0.06378307, 0.05788552, 0.06131508, 0.12282959],
       dtype=float32), 0.09264916, 0.045424417)
```

In [None]:
# Convert raw DeepSEA probabilities into normalized probabilities, using each column's
# positive proportion in DeepSEA's training data. Column order follows relevant_cols
# (DNase, FOXA1, FOXA1).
# NOTE(review): this overwrites np_preds in place, so re-running the cell normalizes twice.
for col_idx, train_pos_prop in enumerate((0.049791, 0.020508, 0.02394)):
    np_preds[:, :, :, :, col_idx] = compute_normalized_prob(np_preds[:, :, :, :, col_idx], train_pos_prop)

In [None]:
# Mean over dropout samples of the wild type (batch 0, row 0) column-1 prediction, per sequence.
np_preds[:, :, 0, 0, 1].mean(axis=0)

In [None]:
# Re-derive the batch geometry from the prediction array itself.
n_batches, batch_size = np_preds.shape[2:4]
np_preds.shape

In [None]:
# Moments over the MC-dropout axis (axis 0); variances are accumulated in float64.
np_pred_means = np_preds.mean(axis=0)
np_pred_vars = np_preds.var(axis=0, dtype=np.float64)
np_pred_means.shape

In [None]:
# Mutant minus wild type predictive means; the wild type is row 0 of every batch.
wt_row_means = np_pred_means[:, :, 0:1, :]
np_pred_mean_diffs = np_pred_means[:, :, 1:, :] - wt_row_means
np_pred_mean_diffs.shape

In [None]:
# For every (seq, batch, row, col): the 2x2 population covariance (ddof=0), across
# dropout epochs (axis 0), between the wild type row (row 0) and that row.
# Vectorized replacement for ~n_seqs*n_batches*batch_size*n_cols individual np.cov
# calls; computed in float64 to match np.cov's internal precision.
preds_f64 = np_preds.astype(np.float64)
centered = preds_f64 - preds_f64.mean(axis=0)
ref_centered = centered[:, :, :, 0:1, :]  # wild type row, kept as a broadcastable axis
ref_var = np.mean(ref_centered ** 2, axis=0)
row_var = np.mean(centered ** 2, axis=0)
ref_row_cov = np.mean(ref_centered * centered, axis=0)

np_pred_covs = np.zeros((n_seqs, n_batches, batch_size, 2, 2, len(relevant_cols)))
np_pred_covs[:, :, :, 0, 0, :] = ref_var  # broadcasts over the row axis
np_pred_covs[:, :, :, 1, 1, :] = row_var
np_pred_covs[:, :, :, 0, 1, :] = ref_row_cov  # off-diagonal entries
np_pred_covs[:, :, :, 1, 0, :] = ref_row_cov
del preds_f64, centered, ref_centered  # free the large float64 intermediates

In [None]:
# Spot check: the 2x2 covariance for one row next to the matching per-row variances.
sample_cov = np_pred_covs[0, 0, 50, :, :, 0]
sample_vars = np_pred_vars[0, 0, 48:52, 0]
print(sample_cov)
print(sample_vars)

In [None]:
# Check the covariance array's dtype (it is allocated with np.zeros without an explicit dtype).
print(np_pred_covs.dtype)

In [None]:
# Std. error of each (mut - ref) difference: sqrt(Var[mut] + Var[ref] - 2 Cov[ref, mut]).
# Clamp at 0 first: when mutant and wild type predictions are (nearly) identical,
# round-off can make the variance slightly negative and sqrt would return NaN.
diff_vars = (np_pred_covs[:, :, 1:, 1, 1, :]
             + np_pred_covs[:, :, 1:, 0, 0, :]
             - 2 * np_pred_covs[:, :, 1:, 0, 1, :])
np_pred_uncertainties = np.sqrt(np.maximum(diff_vars, 0))

Prior output:
```
array([[4.48878671e-04, 3.76016830e-03, 3.91444136e-03],
       [4.24352747e-03, 6.48325787e-03, 6.35424058e-03],
       [7.21780115e-04, 6.80836721e-03, 6.69385231e-03],
       [2.14841515e-03, 6.13503189e-03, 5.81760365e-03],
       [2.53199270e-03, 6.57216440e-03, 6.43075244e-03],
       [1.75606661e-03, 4.86487429e-03, 4.74457307e-03],
       [1.76368384e-03, 2.73868525e-03, 2.12760159e-03],
       [2.58643312e-03, 4.79742642e-03, 4.72901294e-03],
       [1.67957813e-03, 5.97133811e-03, 5.67350658e-03],
       [1.42978106e-03, 4.81405968e-03, 4.56677096e-03],
       [1.77184916e-03, 4.97078494e-03, 5.03889815e-03],
       [3.56837944e-04, 3.28258528e-03, 3.70387383e-03],
       [2.64877971e-03, 2.00102456e-03, 1.93569081e-03],
       [4.53994289e-03, 5.95794895e-03, 5.51303414e-03],
       [1.51220427e-03, 6.54928714e-03, 6.57433017e-03],
       [5.88259515e-04, 5.53205052e-03, 5.34385837e-03],
       [1.09700618e-03, 5.16137308e-03, 5.03230324e-03],
       [1.12339457e-03, 2.24760983e-03, 1.91027930e-03],
       [3.20376392e-03, 4.92001099e-03, 4.04100349e-03],
       [2.49861072e-03, 4.86418106e-03, 4.92681169e-03],
       [1.20709174e-04, 3.71777313e-03, 3.81853364e-03],
       [2.47268381e-03, 6.00664759e-03, 5.72693128e-03],
       [7.54406915e-04, 5.26111607e-03, 5.58529764e-03],
       [3.66819203e-03, 7.22119652e-03, 7.02014570e-03],
       [3.62281170e-03, 6.36732431e-03, 6.34870359e-03],
       [4.03628947e-05, 7.45491883e-04, 8.41239868e-04],
       [1.38372533e-05, 7.72275263e-05, 1.04527934e-04],
       [4.20893260e-04, 2.73638685e-03, 2.93110060e-03],
       [7.43930255e-05, 5.43012993e-04, 7.45896503e-04],
       [1.03129554e-05, 3.41938909e-05, 4.97885397e-05],
       [3.15892870e-04, 2.65540526e-04, 2.05583915e-04],
       [3.84993409e-04, 3.75197042e-03, 4.09271531e-03],
       [1.51029463e-03, 9.92187040e-04, 7.37464625e-04],
       [3.91123794e-05, 4.42296623e-04, 5.92760265e-04],
       [6.80941129e-04, 3.82639830e-03, 3.90858616e-03],
       [5.04282705e-05, 5.74359808e-05, 4.21293565e-05],
       [1.28339127e-04, 4.31538429e-04, 3.44554821e-04],
       [1.06008759e-04, 5.22394941e-04, 5.16982962e-04],
       [1.99350312e-04, 3.82583955e-04, 3.30806837e-04],
       [2.03201481e-04, 2.50868539e-04, 1.91278546e-04],
       [5.79685198e-05, 1.19629254e-03, 1.36500570e-03],
       [3.18373139e-05, 1.12246586e-04, 1.48953542e-04],
       [9.06054963e-05, 1.73437150e-04, 1.55170388e-04],
       [4.54140717e-05, 2.84565435e-04, 3.01416159e-04],
       [7.09704930e-05, 1.47903155e-04, 1.13104448e-04],
       [6.81710009e-05, 1.63678267e-04, 1.72173579e-04],
       [4.29242786e-06, 4.92805347e-05, 7.38577075e-05],
       [1.34657224e-04, 2.79072807e-04, 2.58043714e-04],
       [4.75108948e-04, 3.21890477e-04, 2.44209141e-04],
       [2.73626697e-05, 6.84841822e-05, 8.03758732e-05]])
  ```

In [None]:
# Predictive means for seq 0: wild type plus the first four mutants of every batch.
np_pred_means[0, :, :5]

Previous output:
```
array([[[0.01535607, 0.09264916, 0.10030716],
        [0.01536481, 0.09268412, 0.10034063],
        [0.0153535 , 0.0926303 , 0.10028694],
        [0.01535498, 0.09263348, 0.10029003],
        [0.01535387, 0.09263583, 0.10029308]],

       [[0.01488582, 0.08554734, 0.09253196],
        [0.01465443, 0.084775  , 0.09173969],
        [0.01492254, 0.08458555, 0.09155089],
        [0.01457286, 0.08459219, 0.09151619],
        [0.01439259, 0.08350793, 0.09025328]],

       [[0.01598858, 0.10019252, 0.10861898],
        [0.0156244 , 0.09905797, 0.10743561],
        [0.01544473, 0.09873247, 0.10714547],
        [0.01559261, 0.09845467, 0.1067265 ],
        [0.01551011, 0.09955419, 0.10797767]],

       [[0.01544482, 0.09802654, 0.10629827],
        [0.01507335, 0.09550025, 0.10353882],
        [0.01587485, 0.10047743, 0.10917354],
        [0.01433429, 0.09081684, 0.09837098],
        [0.01556314, 0.09916446, 0.10780552]],

       [[0.01516749, 0.08964344, 0.09674001],
        [0.01639638, 0.09238901, 0.09860244],
        [0.01509506, 0.08103842, 0.08690701],
        [0.01358717, 0.07601662, 0.08262065],
        [0.01766347, 0.11328302, 0.12140989]],

       [[0.01695254, 0.105074  , 0.11348913],
        [0.01648404, 0.10441449, 0.11273133],
        [0.01728942, 0.10697857, 0.11449711],
        [0.01929544, 0.10928906, 0.11674011],
        [0.01661537, 0.10354719, 0.11319642]],

       [[0.01587138, 0.09646078, 0.10463066],
        [0.01763695, 0.11694595, 0.12509888],
        [0.01650194, 0.11137906, 0.11954492],
        [0.01537086, 0.09619159, 0.10390938],
        [0.01475383, 0.0892925 , 0.09741764]],

       [[0.01607911, 0.09555656, 0.10329047],
        [0.01644303, 0.0953081 , 0.10331029],
        [0.01658046, 0.1012515 , 0.10875904],
        [0.01569122, 0.09139476, 0.09877162],
        [0.01581402, 0.09524193, 0.10289911]],

       [[0.01668472, 0.11106696, 0.11931916],
        [0.01711003, 0.11378796, 0.12228256],
        [0.0165246 , 0.11149254, 0.11993513],
        [0.01681521, 0.11339774, 0.12225165],
        [0.01646912, 0.10928717, 0.1178958 ]],

       [[0.01533156, 0.09584869, 0.10469947],
        [0.01530683, 0.09577515, 0.10461981],
        [0.01543046, 0.09717911, 0.10631446],
        [0.01571676, 0.09747174, 0.10664452],
        [0.01537928, 0.09591264, 0.10480152]]], dtype=float32)
```

In [None]:
# Std. error of each (mut - ref) difference, sqrt(Var[mut] + Var[ref] - 2 Cov[ref, mut]),
# using the per-row variances. Clamp at 0: round-off can push a near-zero variance
# slightly negative, which would turn into NaN under sqrt.
diff_vars = np_pred_vars[:, :, 1:, :] + np_pred_vars[:, :, 0:1, :] - 2 * np_pred_covs[:, :, 1:, 0, 1, :]
np_pred_uncertainties = np.sqrt(np.maximum(diff_vars, 0))
# Average std. error over batches and mutants, per sequence and column.
np.mean(np.mean(np_pred_uncertainties, axis=2), axis=1)

Results from prior runs:

    array([[0.02932001, 0.16316762, 0.16552047],
       [0.15966936, 0.22315478, 0.22175914]])

    array([[0.00388914, 0.03407012, 0.03650351],
       [0.04776923, 0.09040056, 0.08849534]])

In [None]:
# Grand mean of all (normalized) predictions.
# NOTE(review): `tmp` is not referenced by any other cell shown here.
tmp = np_preds.mean()

## Accuracy & Calibration

In [None]:
import sklearn
from sklearn.metrics import auc, brier_score_loss, roc_auc_score, roc_curve
from sklearn.calibration import calibration_curve 

# TF predictive mean of each wild type sequence (batch 0, row 0).
tf_ref_means = np_pred_means[:, 0, 0, 1]
# The first 25 seqs are labelled binding (1), the rest non-binding (0). Build the labels
# from the number of predictions; the fixed ones(25)+zeros(25) vector only matches
# when n_seqs == 50 and otherwise fails with a length mismatch.
y_true = (np.arange(len(tf_ref_means)) < 25).astype(int)
assert len(np.unique(y_true)) == 2, (
    "ROC AUC needs both binding and non-binding seqs; increase n_seqs beyond 25")
score = roc_auc_score(y_true, tf_ref_means)
score

In [None]:
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.calibration import calibration_curve

# TF predictive mean of each wild type sequence (batch 0, row 0).
tf_ref_means = np_pred_means[:, 0, 0, 1]
# First 25 seqs are binding (1), the rest non-binding (0); sized to match the predictions.
y_true = (np.arange(len(tf_ref_means)) < 25).astype(int)
assert len(np.unique(y_true)) == 2, (
    "ROC AUC needs both binding and non-binding seqs; increase n_seqs beyond 25")
score = roc_auc_score(y_true, tf_ref_means)
score

In [None]:
# TF predictive mean of each wild type sequence (batch 0, row 0).
tf_ref_means = np_pred_means[:, 0, 0, 1]
# First 25 seqs are binding (1), the rest non-binding (0); sized to match the predictions.
y_true = (np.arange(len(tf_ref_means)) < 25).astype(int)
assert len(np.unique(y_true)) == 2, (
    "ROC AUC needs both binding and non-binding seqs; increase n_seqs beyond 25")
score = roc_auc_score(y_true, tf_ref_means)

In [None]:
# TF predictive mean of each wild type sequence (batch 0, row 0).
tf_ref_means = np_pred_means[:, 0, 0, 1]
# First 25 seqs are binding (1), the rest non-binding (0); sized to match the predictions.
y_true = (np.arange(len(tf_ref_means)) < 25).astype(int)
assert len(np.unique(y_true)) == 2, (
    "ROC curve needs both binding and non-binding seqs; increase n_seqs beyond 25")
fpr, tpr, thresholds = roc_curve(y_true, tf_ref_means)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
lw = 2
ax.plot(fpr, tpr, color='darkorange',
        lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver operating characteristic (TF binding)')
ax.legend(loc="lower right")
plt.show()

In [None]:
# Calibration of the TF predictive means of the wild type sequences.
prob_pos = np_pred_means[:, 0, 0, 1]
# First 25 seqs are binding (1), the rest non-binding (0); sized to match the predictions.
y_test = (np.arange(len(prob_pos)) < 25).astype(int)
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_test,
    prob_pos,
    n_bins=10
)
clf_score = brier_score_loss(y_test, prob_pos)
plt.plot(mean_predicted_value, fraction_of_positives, "s-",
         label="%s (%1.3f)" % ("MC dropout predictive means (TF)", clf_score))
plt.xlabel("Mean predicted value")
plt.ylabel("Fraction of positives")
plt.legend(loc="best")  # shows the curve label, including the Brier score
prob_pos

In [None]:
# Distribution of the TF predictive means (`prob_pos`, set in the previous cell).
plt.hist(
    prob_pos,
    bins=10,
    range=(0, 1),
    histtype="step",
    lw=2,
    label="Predictive means (TF)",
)
plt.legend()

In [None]:
# Calibration of the chromatin-accessibility (DNase) predictive means of the wild type seqs.
prob_pos = np_pred_means[:, 0, 0, 0]
# First 25 seqs are labelled positive (1), the rest negative (0); sized to match the predictions.
y_test = (np.arange(len(prob_pos)) < 25).astype(int)
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_test,
    prob_pos,
    n_bins=10
)
clf_score = brier_score_loss(y_test, prob_pos)
plt.plot(mean_predicted_value, fraction_of_positives, "s-",
         label="%s (%1.3f)" % ("MC dropout predictive means (CA)", clf_score))
plt.xlabel("Mean predicted value")
plt.ylabel("Fraction of positives")
plt.legend(loc="best")  # shows the curve label, including the Brier score

In [None]:
n_binding_seqs = 25  # the first 25 seqs are labelled binding / accessible
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 20))
plt.suptitle("Predictive Means (Ref)")

# (axis, output column, group labels, panel title) for the TF and DNase panels.
ref_mean_panels = [
    (ax1, 1, ("binding", "no binding"), f"TF {relevant_cols[1][1]}"),
    (ax2, 0, ("accessible", "not accessible"), f"DNase {relevant_cols[0][1]}"),
]
for ax, col, group_labels, panel_title in ref_mean_panels:
    ref_means_by_seq = np_pred_means[:, 0, 0, col]  # wild type row of batch 0, per seq
    ax.hist(
        (ref_means_by_seq[:n_binding_seqs].reshape(-1),
         ref_means_by_seq[n_binding_seqs:].reshape(-1)),
        label=group_labels)
    ax.set_title(panel_title)
    ax.legend()

fig.text(0.5, 0.08, "Predictive Means for y=0 vs y=1", ha="center")
fig.text(0.07, 0.5, "# of Seqs", va='center', rotation='vertical')

plt.show();

### Standard Error Calibration

In [None]:
from scipy import stats

# Q-Q plot of per-sample (mut - ref) diffs, standardized by their predictive mean and
# std. error, for one sequence; a well-calibrated S.E. gives a roughly straight line.
sample_seq = 2
seq_preds = np_preds[:, sample_seq]  # (epochs, batches, rows, cols)
sample_pred_diffs = (seq_preds[:, :, 1:, :] - seq_preds[:, :, 0:1, :]).reshape(epochs, -1)
sample_std_errs = np_pred_uncertainties[sample_seq].reshape(-1)
sample_mean_diffs = np_pred_mean_diffs[sample_seq].reshape(-1)
normalized_preds = (sample_pred_diffs - sample_mean_diffs) / sample_std_errs

res = stats.probplot(normalized_preds.ravel(), plot=plt)

## TF-CA Relationship and Mutation Effect Exploration 

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
plt.suptitle("Predictive Mean Diffs (Mut vs. Ref)")

n_binding_seqs = 25  # the first 25 seqs are labelled binding / accessible
# np_pred_mean_diffs holds only mutant rows (the wild type row was already subtracted
# out), so use every row; slicing `1:` here would drop the first mutant of each batch.
ax1.hist(
    (np_pred_mean_diffs[:n_binding_seqs, :, :, 1].reshape(-1),
     np_pred_mean_diffs[n_binding_seqs:, :, :, 1].reshape(-1)),
    label=("binding", "no binding"),
    log=True)
ax1.set_title(f"TF {relevant_cols[1][1]}")
ax1.legend()
ax2.hist(
    (np_pred_mean_diffs[:n_binding_seqs, :, :, 0].reshape(-1),
     np_pred_mean_diffs[n_binding_seqs:, :, :, 0].reshape(-1)),
    label=("accessible", "not accessible"),
    log=True)
# Column 0 is the DNase track (relevant_cols[0]); relevant_cols[2] is a FOXA1 track.
ax2.set_title(f"DNase {relevant_cols[0][1]}")
ax2.legend()

fig.text(0.5, 0.08, "Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "# of Mutations", va='center', rotation='vertical')


plt.show();

In [None]:
TF_COL = 1
CA_COL = 0

n_binding_seqs = 25  # the first 25 seqs are labelled binding
# One representative sequence per panel column: seq 1 for binding and its counterpart in
# the non-binding half (which only exists when n_seqs > n_binding_seqs + 1).
sample_seqs = (("binding", 1), ("no binding", n_binding_seqs + 1))

fig, axs = plt.subplots(2, 2, figsize=(14, 10), squeeze=False)
plt.suptitle("S.E. vs. Predictive Mean Diff")

# Rows: TF column, then CA column. Columns: binding seq, then non-binding seq.
for row, col in enumerate((TF_COL, CA_COL)):
    for panel, (kind, seq) in enumerate(sample_seqs):
        ax = axs[row, panel]
        title = f"{relevant_cols[col][1]} ({kind})"
        if seq >= n_seqs:
            ax.set_title(f"{title}: seq {seq} not loaded")
            continue
        ax.scatter(
            np_pred_mean_diffs[seq, :, :, col].reshape(-1),
            np_pred_uncertainties[seq, :, :, col].reshape(-1))
        ax.set_title(f"{title} - seq {seq}")

fig.text(0.5, 0.08, "Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "Predictive S.E.", va='center', rotation='vertical')

plt.show()

In [None]:
n_binding_seqs = 25
# Keep the figure handle: fig.text below must annotate THIS figure, not one left over
# from an earlier cell.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
plt.suptitle("CA predictive mean diff vs. TF predictive mean diff")
panels = (
    (ax1, range(min(n_binding_seqs, n_seqs)), "(binding)"),
    (ax2, range(n_binding_seqs, n_seqs), "(no binding)"),
)
for ax, seq_range, title in panels:
    for seq in seq_range:
        ax.scatter(
            np_pred_mean_diffs[seq, :, :, TF_COL].reshape(-1),
            np_pred_mean_diffs[seq, :, :, CA_COL].reshape(-1),
            label=seq)
    ax.set_title(title)
    if len(seq_range):  # legend once per panel, only when something was plotted
        ax.legend()
fig.text(0.5, 0.08, "TF Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "CA Predictive Mean Diff", va='center', rotation='vertical')

plt.show()

In [None]:
from matplotlib import patches
from matplotlib import cm

cols, margin = 3, 10 # margin determined empirically
n_rows = math.ceil(n_seqs / float(cols))
# squeeze=False keeps `axs` 2-D even when all plots fit on one row (n_seqs <= cols),
# so axs[row, col] indexing always works.
fig, axs = plt.subplots(n_rows, cols, figsize=(16, n_seqs + margin), squeeze=False)

for i in range(n_seqs):
    # TF-column dropout samples of every mutant of seq i, paired with the wild type
    # sample from the same batch (row 0) broadcast to the same shape.
    np_sample_mut_preds = np_preds[:, i, :, 1:, TF_COL]
    np_sample_ref_preds = np.zeros_like(np_sample_mut_preds) + np_preds[:, i, :, :1, TF_COL]
    # Spot check that the wild type value is repeated across a batch's mutant rows;
    # equal_nan so NaN placeholder sequences do not trip it.
    assert np.allclose(np_sample_ref_preds[:, 0, 0], np_sample_ref_preds[:, 0, 1], equal_nan=True)
    # Per-point squared distance from the (ref mean, mut mean) centroid; hexbin averages
    # these within each bin.
    colors = (
        (np_sample_mut_preds.ravel() - np_sample_mut_preds.mean())**2 +
        (np_sample_ref_preds.ravel() - np_sample_ref_preds.mean())**2
    )
    ax = axs[i // cols, i % cols]
    ax.hexbin(
        np_sample_ref_preds.ravel(),
        np_sample_mut_preds.ravel(),
        C=colors,
        cmap=cm.jet,
        bins=None,
    )
    # Box spanning the interquartile ranges of ref (x) and mut (y) predictions.
    xq1, xq2 = np.quantile(np_sample_ref_preds, (.25, .75))
    yq1, yq2 = np.quantile(np_sample_mut_preds, (.25, .75))
    ax.add_patch(patches.Rectangle((xq1, yq1), xq2 - xq1, yq2 - yq1, fill=False, edgecolor='black'))
    xlabel = "seq {i} ref (std: {stddev:.3f})".format(i=i, stddev=np.std(np_sample_ref_preds))
    ylabel = "seq {i} mut (std: {stddev:.3f})".format(i=i, stddev=np.std(np_sample_mut_preds))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

# Hide grid cells left over when n_seqs is not a multiple of cols.
for ax in axs.flat[n_seqs:]:
    ax.set_visible(False)
plt.show()

In [None]:
import scipy
from scipy import stats
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std


# Per sequence: OLS fit of CA mean diff on TF mean diff across all mutants.
cols, margin = 3, 20 # margin determined empirically
fig, axs = plt.subplots(math.ceil(n_seqs / float(cols)), cols,
                        figsize=(16, n_seqs + margin), squeeze=False)
slopes = []
rsquareds = []

for seq in range(n_seqs):
    ax = axs[seq // cols, seq % cols]

    x = np_pred_mean_diffs[seq, :, :, TF_COL].reshape(-1)
    y = np_pred_mean_diffs[seq, :, :, CA_COL].reshape(-1)
    # has_constant="add" always adds the intercept column, so `result.params` unpacks
    # into (intercept, slope) even if x happens to be constant.
    xc = sm.add_constant(x, has_constant="add")
    result = sm.OLS(y, xc).fit()
    intercept, slope = result.params
    slopes.append(slope)
    rsquared = result.rsquared
    rsquareds.append(rsquared)
    stderr = result.bse[1]

    title = "%d - " % (seq)
    title += " (b) " if seq < n_binding_seqs else " (nb)"
    title += "(slope: %.2f, r^2: %.3f, std: %.4f)" % (slope, rsquared, stderr)
    ax.set_title(title)

    # Sort by x so the fit line and prediction-interval bounds draw as clean curves
    # instead of zig-zagging between unsorted points.
    order = np.argsort(x)
    _, iv_l, iv_u = wls_prediction_std(result)
    ax.plot(x[order], (slope * x + intercept)[order], 'r', label="OLS fit")
    ax.plot(x, y, 'o', label="mutations")
    ax.plot(x[order], iv_u[order], 'r--', label="prediction interval")
    ax.plot(x[order], iv_l[order], 'r--')
    ax.legend(loc="best")

# Hide grid cells left over when n_seqs is not a multiple of cols.
for ax in axs.flat[n_seqs:]:
    ax.set_visible(False)
plt.show()

In [None]:
# Average OLS slope and R^2 over all sequences fitted above.
mean_slope = sum(slopes) / len(slopes)
print(mean_slope)
mean_rsquared = sum(rsquareds) / len(rsquareds)
print(mean_rsquared)

In [None]:

IDX_TO_NT = 'ACGT'

def _convert_to_mutation(pos_nt_pair):
    """Format a (0-based position, nucleotide index) pair as e.g. '17G'."""
    return "%d%s" % (pos_nt_pair[0], IDX_TO_NT[pos_nt_pair[1]])


TF_COL = 1
CA_COL = 0
MUTS_PER_POS = len(IDX_TO_NT) - 1  # each position has 3 alternate nucleotides

def _write_row(writer, seq, batch, i):
    """
    Write the CSV row for mutant `i` of batch `batch` of sequence `seq`.

    `i` indexes the mutant rows of np_pred_mean_diffs / np_pred_uncertainties, which
    exclude the wild type row, so each batch holds np_pred_mean_diffs.shape[2] mutants.
    Mutant numbering follows generate_wt_mut_batches: position-major, then the
    non-reference nucleotides of that position in ascending index order.
    """
    muts_per_batch = np_pred_mean_diffs.shape[2]
    mut_num = batch * muts_per_batch + i
    pos, alt_rank = divmod(mut_num, MUTS_PER_POS)
    # Translate the alt rank into the substituted base by skipping the reference base.
    ref_onehot = seqs[seq, :, 0, pos]
    alt_nts = [nt for nt in range(len(IDX_TO_NT)) if int(ref_onehot[nt]) != 1]
    writer.writerow(
        {
            "seq_num": seq + 1,
            "mut": _convert_to_mutation((pos, alt_nts[alt_rank])),
            # NOTE: the *_pred_var columns hold standard errors (square roots of the
            # difference variances); the column names are kept for downstream readers.
            "X_pred_mean": np_pred_mean_diffs[seq, batch, i, TF_COL],
            "X_pred_var": np_pred_uncertainties[seq, batch, i, TF_COL],
            "Y_pred_mean": np_pred_mean_diffs[seq, batch, i, CA_COL],
            "Y_pred_var": np_pred_uncertainties[seq, batch, i, CA_COL],
        }
    )


with open("../dat/means_and_uncertainties__original__comparison.csv", 'w', newline="") as out_file:
    fieldnames = [
        "seq_num",
        "mut",
        "X_pred_mean",
        "X_pred_var",
        "Y_pred_mean",
        "Y_pred_var",
    ]
    writer = csv.DictWriter(out_file, delimiter=",", fieldnames=fieldnames)
    writer.writeheader()

    # One row per mutant; no stray row from leftover `batch`/`i` globals.
    for seq in range(n_seqs):
        for batch in range(n_batches):
            for i in range(np_pred_mean_diffs.shape[2]):
                _write_row(writer, seq, batch, i)
 
