## Imports and model initialization

In [1]:
# !pip install kipoi
# !pip install kipoiseq
# !pip install pybedtools
# !pip uninstall -y kipoi_veff
# !pip install git+https://github.com/an1lam/kipoi-veff
# !pip install pyvcf
import csv
import math
import pickle

import kipoi
from kipoi_interpret.importance_scores.ism import Mutation
from kipoiseq.dataloaders import SeqIntervalDl
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

In [2]:
# Show the working directory: the data paths used below (../dat/...) are relative to it.
!pwd

/home/stephenmalina/project/src


# Loading DNA sequence data

In [3]:
# Intervals from the BED file, resized to 1000 bp and read from the hg19 reference.
BED_FILE = "../dat/50_random_seqs.bed"
FASTA_FILE = "../dat/hg19.fa"
dl = SeqIntervalDl(BED_FILE, FASTA_FILE, auto_resize_len=1000)
data = dl.load_all()

100%|██████████| 2/2 [00:00<00:00,  5.72it/s]


In [4]:
# load_all() gives one-hot inputs as (N, length, 4). DeepSEA's first Conv2d takes 4 input
# channels with a (1, 8) kernel, so reorder to (N, 4, length) and add a height axis of 1.
seqs = data['inputs'].transpose(0, 2, 1)[:, :, np.newaxis, :].astype(np.float32)
seqs.shape

(52, 4, 1, 1000)

# Loading DeepSEA

In [5]:
import tensorflow as tf
import torch
from torch import nn

# Record the framework versions next to the outputs they produced.
print("TF version:", tf.__version__)
print("torch version:", torch.__version__)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


TF version: 1.14.0
torch version: 1.3.1


In [6]:
# df = kipoi.list_models()
# deepsea_models = df[df.model.str.contains("DeepSEA")]
# deepsea_models.head()

In [7]:
MODEL_NAME = "DeepSEA/predict"
deepsea = kipoi.get_model(MODEL_NAME, source="kipoi")
# Display the underlying torch module so its layers (incl. the Dropout ones) can be inspected.
deepsea.model

Using downloaded and verified file: /home/stephenmalina/.kipoi/models/DeepSEA/predict/downloaded/model_files/weights/89e640bf6bdbe1ff165f484d9796efc7


Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): Dropout(p=0.5, inplace=False)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15):

In [8]:
# Sanity check: run the model's bundled example through the kipoi pipeline
# (the output has one column per DeepSEA target).
deepsea.pipeline.predict_example().shape

100%|██████████| 1/1 [00:00<00:00,  4.56it/s]


(10, 919)

# Hacking DeepSEA's layers
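This section modifies DeepSEA so that dropout stays active when making predictions. Each `nn.Dropout` layer is replaced by a `LockedDropout` layer that shares one dropout mask across the whole batch, so a wild-type sequence and its mutants in the same batch are scored under the same mask.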

In [9]:
# (Legally) grabbed from:
# https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/lock_dropout.html.
class LockedDropout(nn.Module):
    """Dropout whose mask is shared across the batch dimension.

    One mask of shape ``(1, *x.shape[1:])`` is sampled per forward pass and
    broadcast over dim 0. Every example in a batch (here: a wild-type sequence
    and its mutants) is therefore scored with the same units dropped.

    Adapted from torchnlp's LockedDropout.
    **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`.
    Here is their `License
    <https://github.com/salesforce/awd-lstm-lm/blob/master/LICENSE>`__.

    Args:
        p (float): Probability of an element in the dropout mask to be zeroed.
        training (bool): initial mode; dropout is only applied while training.
    """

    def __init__(self, p=0.5, training=True):
        # nn.Module.__init__ resets self.training to True, so it must run first;
        # otherwise the `training` argument would be silently discarded.
        super().__init__()
        self.p = p
        self.training = training

    def train(self, training=True):
        """Set the training flag and return self, like nn.Module.train (so eval() returns self)."""
        return super().train(training)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor [batch, ...]): input of rank >= 1. For DeepSEA's
                conv stack this is [batch, channels, 1, length].

        Returns:
            torch.Tensor: ``x`` unchanged when not training or ``p == 0``. Otherwise
            ``x`` with dropped elements zeroed and kept ones scaled by ``1 / (1 - p)``,
            using the same mask for every batch element.
        """
        if not self.training or not self.p:
            return x
        if self.p >= 1:
            # Everything is dropped; avoids the 0/0 rescaling below.
            return torch.zeros_like(x)
        mask = x.new_empty((1,) + tuple(x.shape[1:]), requires_grad=False)
        mask = mask.bernoulli_(1 - self.p).div_(1 - self.p)  # rescale kept units
        return x * mask  # broadcasting shares the mask over the batch dimension

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + str(self.p) + ')'

In [10]:
def replace_dropout_layers(model):
    """Recursively swap every exact ``nn.Dropout`` in `model` for a ``LockedDropout``.

    Each replacement copies the probability and training flag of the layer it
    replaces. Children are visited in reverse registration order, and container
    children are recursed into. `model` is modified in place and also returned.
    """
    submodules = model._modules
    for child_name, child in reversed(submodules.items()):
        has_children = any(True for _ in child.children())
        if has_children:
            submodules[child_name] = replace_dropout_layers(child)
        if type(child) == nn.Dropout:
            submodules[child_name] = LockedDropout(child.p, training=child.training)
    return model

In [11]:
def _is_dropout_layer(m):
    """True for plain and locked dropout layers (exact type match; subclasses excluded)."""
    return type(m) == nn.Dropout or type(m) == LockedDropout


def apply_dropout(m):
    """Put a dropout layer into training mode; intended for ``model.apply``.

    Other module types are left untouched, so only dropout becomes stochastic.
    """
    if _is_dropout_layer(m):
        m.train()


def unapply_dropout(m):
    """Put a dropout layer into eval mode; intended for ``model.apply``."""
    if _is_dropout_layer(m):
        m.eval()

In [12]:
# Swap DeepSEA's nn.Dropout layers for LockedDropout; the model is modified in place and returned.
deepsea.model = replace_dropout_layers(deepsea.model)

### Verifying dropout is being applied
We can verify (or at least gain reasonable confidence) that dropout is being applied by running 3 predictions on the same example and checking whether the predicted binding probabilities differ. First, we show that the predictions stay the same when dropout is off.

Then, we turn on dropout and show that they start to vary.

In [13]:
# Dropout off: repeated predictions on the same example should be identical.
deepsea.model = deepsea.model.apply(unapply_dropout)
n_repeats = 3
for i in range(n_repeats):
    example_preds = deepsea.pipeline.predict_example()
    print(f"First few preds ({i + 1}/{n_repeats}): ", example_preds[:5])

100%|██████████| 1/1 [00:00<00:00, 33.35it/s]


First few preds (1/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 33.76it/s]


First few preds (2/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 34.52it/s]


First few preds (3/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


In [14]:
# Dropout on: put the whole model in eval mode, then re-enable only the dropout
# layers; repeated predictions should now differ from one another.
deepsea.model.eval()
deepsea.model = deepsea.model.apply(apply_dropout)
n_repeats = 3
for i in range(n_repeats):
    example_preds = deepsea.pipeline.predict_example()
    print(f"First few preds ({i + 1}/{n_repeats}): ", example_preds[:5])

100%|██████████| 1/1 [00:00<00:00, 33.31it/s]


First few preds (1/3):  [[0.09469041 0.06476855 0.12214313 ... 0.09340209 0.02927886 0.00951671]
 [0.05391673 0.00865934 0.02343167 ... 0.10413048 0.04671773 0.0094504 ]
 [0.04761756 0.00292272 0.02001821 ... 0.10483676 0.3426831  0.02824711]
 [0.00287849 0.00051717 0.00311738 ... 0.00953412 0.04629721 0.00129358]
 [0.00076278 0.0087555  0.00196435 ... 0.01482    0.05284825 0.0233275 ]]


100%|██████████| 1/1 [00:00<00:00, 32.94it/s]


First few preds (2/3):  [[0.07072315 0.09313697 0.10903168 ... 0.08075188 0.01988692 0.01275267]
 [0.0755543  0.02080756 0.03763523 ... 0.18903339 0.04289904 0.00842006]
 [0.04845877 0.00574616 0.02143795 ... 0.09896583 0.2500741  0.02560952]
 [0.00172383 0.00054227 0.00215096 ... 0.00909489 0.02452013 0.00052947]
 [0.00071932 0.00820777 0.00236618 ... 0.01169196 0.05650319 0.01962058]]


100%|██████████| 1/1 [00:00<00:00, 33.13it/s]


First few preds (3/3):  [[0.09125951 0.03321379 0.10498534 ... 0.05581583 0.02147745 0.02208168]
 [0.04562641 0.00490076 0.02036985 ... 0.19676986 0.05821128 0.00898473]
 [0.05367801 0.00396026 0.02989383 ... 0.04869422 0.19736491 0.03130752]
 [0.00518991 0.00090631 0.00659967 ... 0.01447539 0.05181357 0.00093238]
 [0.00074766 0.00480332 0.00170046 ... 0.02407834 0.05369303 0.03029319]]


## Predictions and in-silico mutagenesis

In [15]:
# Predictions for the first 5 sequences. Dropout was switched on in the previous cell,
# so these values change from run to run.
deepsea.predict_on_batch(seqs[:5, :, :, :])

array([[0.00947323, 0.02589644, 0.00898295, ..., 0.02471472, 0.09217416,
        0.0264926 ],
       [0.00121323, 0.00887009, 0.00260588, ..., 0.02647636, 0.07006134,
        0.00369652],
       [0.00248037, 0.00347161, 0.00266787, ..., 0.01151872, 0.05654726,
        0.01730141],
       [0.00351746, 0.00387447, 0.00372575, ..., 0.00689738, 0.04173702,
        0.00717506],
       [0.00085676, 0.00292669, 0.00347074, ..., 0.02656512, 0.26414126,
        0.00315961]], dtype=float32)

In [73]:
all_zeros = np.zeros((4,))


def generate_wt_mut_batches(seq, batch_size):
    """Generate every single-nucleotide mutant of `seq`, packed into batches.

    Every batch holds the wild-type sequence in row 0, followed by `batch_size - 1`
    mutants. Wild-type and mutant predictions in a batch can then share one
    (locked) dropout mask when computing prediction differences.

    Args:
        seq (numpy.ndarray [number of nucleotides, sequence length]):
            one-hot wild-type sequence. Every position must have exactly one
            entry equal to 1 (its reference base).
        batch_size (int): rows per batch (1 wild type + mutants). The number of
            mutants, `(number of nucleotides - 1) * sequence length`, must be
            divisible by `batch_size - 1`.

    Returns:
        numpy.ndarray [n_batches, batch_size, number of nucleotides, sequence length],
        same dtype as `seq`. Mutants are ordered by position, then by nucleotide
        index (the reference base is skipped). Mutant k (0-based, counted across
        batches) sits at batch k // (batch_size - 1), row k % (batch_size - 1) + 1.

    Raises:
        ValueError: if the mutants cannot be split evenly into batches, or if a
            position has no single reference base (e.g. an N encoded as 0.25s).
    """
    num_nts, seq_len = seq.shape
    muts_per_pos = num_nts - 1
    n_muts = seq_len * muts_per_pos
    muts_per_batch = batch_size - 1
    if muts_per_batch < 1 or n_muts % muts_per_batch != 0:
        raise ValueError(
            f"batch_size - 1 ({muts_per_batch}) must be >= 1 and evenly divide "
            f"the number of mutants ({n_muts})")

    # A non-one-hot position would yield num_nts "mutants" instead of num_nts - 1,
    # shifting every later mutant and overflowing the last batch.
    ref_counts = (seq == 1).sum(axis=0)
    bad_positions = np.flatnonzero(ref_counts != 1)
    if bad_positions.size:
        raise ValueError(
            f"{bad_positions.size} position(s) lack a single reference base "
            f"(first at index {bad_positions[0]}); cannot enumerate mutants")

    n_batches = n_muts // muts_per_batch
    seq_batches = np.broadcast_to(seq, (n_batches, batch_size, num_nts, seq_len)).copy()
    mut_idx = 0
    for pos in range(seq_len):
        for nt in range(num_nts):
            if seq[nt, pos] == 1:
                continue  # the reference base is not a mutation
            batch_idx, row_idx = divmod(mut_idx, muts_per_batch)
            row = seq_batches[batch_idx, row_idx + 1]  # row 0 stays wild type
            row[:, pos] = 0
            row[nt, pos] = 1
            mut_idx += 1
    return seq_batches

In [17]:
CHROM_ACC_COL = 'HepG2_DNase_None'
# TF_COL = 'A549_CTCF_None'
TF_COL = 'HepG2_FOXA2_None'
# (DeepSEA output index, label) for the two targets studied, ordered by output index.
relevant_cols = sorted(
    (col_idx, label)
    for col_idx, label in enumerate(deepsea.schema.targets.column_labels)
    if label in (CHROM_ACC_COL, TF_COL)
)


def output_sel_fn(result):
    """Keep only the `relevant_cols` outputs of a DeepSEA prediction.

    Args:
        result: 2-D prediction array (batch, DeepSEA targets).

    Returns:
        numpy.ndarray (batch, len(relevant_cols)), columns in `relevant_cols` order.
    """
    selected_columns = [result[:, col_idx] for col_idx, _label in relevant_cols]
    return np.array(selected_columns).T


relevant_cols

[(56, 'HepG2_DNase_None'), (304, 'HepG2_FOXA2_None')]

In [81]:
def next_seq(it):
    """Pull the next interval from `it` and return it as a float32 one-hot array.

    Applies the same transpose/expand_dims as `seqs` above, then `squeeze()`
    drops every size-1 axis (the batch of 1 and the height axis).
    """
    return (np
            .expand_dims(next(it)["inputs"].transpose(0, 2, 1), 2)
            .astype(np.float32)
            .squeeze())


# epochs = number of dropout samples per batch (needs > 1 for non-degenerate variances).
epochs, n_seqs, batch_size = 1, 51, 301
it = dl.batch_iter(batch_size=1, num_workers=0, drop_last=False)
print(len(it))

# preds[epoch][k] holds the per-batch (batch_size, len(relevant_cols)) predictions of
# the k-th *kept* sequence; kept_seq_idxs[k] is that sequence's index in the BED file.
preds = [[] for _ in range(epochs)]
kept_seq_idxs = []
for i in range(min(n_seqs, len(it))):
    seq = next_seq(it)
    if np.allclose(seq, .25):
        # Uniform 0.25 encoding (presumably an all-N interval): nothing to mutate.
        # Skip it entirely; an empty slot would make np.array(preds) ragged.
        continue
    kept_seq_idxs.append(i)
    wt_mut_batches = generate_wt_mut_batches(seq, batch_size)
    seq_preds = [[] for _ in range(epochs)]
    for batch in tqdm(wt_mut_batches):
        model_input = np.expand_dims(batch, axis=2)  # add DeepSEA's height axis
        for epoch in range(epochs):
            seq_preds[epoch].append(output_sel_fn(deepsea.predict_on_batch(model_input)))
    for epoch in range(epochs):
        preds[epoch].append(seq_preds[epoch])

np_preds = np.array(preds)
# Expected layout: (epochs, kept sequences, n_batches, batch_size, len(relevant_cols)).
assert np_preds.shape[:2] == (epochs, len(kept_seq_idxs)), np_preds.shape

52
2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
0



100%|██████████| 10/10 [00:00<00:00, 21925.27it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
1



100%|██████████| 10/10 [00:00<00:00, 22086.91it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
2



100%|██████████| 10/10 [00:00<00:00, 16282.24it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
3



100%|██████████| 10/10 [00:00<00:00, 23696.63it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
4



100%|██████████| 10/10 [00:00<00:00, 19328.59it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
5



100%|██████████| 10/10 [00:00<00:00, 21959.71it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
6



100%|██████████| 10/10 [00:00<00:00, 21034.62it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
7



100%|██████████| 10/10 [00:00<00:00, 20992.51it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
8



100%|██████████| 10/10 [00:00<00:00, 18800.11it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
9



100%|██████████| 10/10 [00:00<00:00, 21799.92it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
10



100%|██████████| 10/10 [00:00<00:00, 23301.69it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
11



100%|██████████| 10/10 [00:00<00:00, 22156.91it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
12



100%|██████████| 10/10 [00:00<00:00, 21215.50it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
13



100%|██████████| 10/10 [00:00<00:00, 19039.06it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
14



100%|██████████| 10/10 [00:00<00:00, 23458.08it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
15



100%|██████████| 10/10 [00:00<00:00, 24600.02it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
16



100%|██████████| 10/10 [00:00<00:00, 21377.70it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
17



100%|██████████| 10/10 [00:00<00:00, 22405.47it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
18



100%|██████████| 10/10 [00:00<00:00, 22733.36it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
19



100%|██████████| 10/10 [00:00<00:00, 20702.39it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
20



100%|██████████| 10/10 [00:00<00:00, 20490.00it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
21



100%|██████████| 10/10 [00:00<00:00, 19887.64it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
22



100%|██████████| 10/10 [00:00<00:00, 22262.76it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
23



100%|██████████| 10/10 [00:00<00:00, 21959.71it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
24



100%|██████████| 10/10 [00:00<00:00, 22832.36it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
25



100%|██████████| 10/10 [00:00<00:00, 21653.61it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
26



100%|██████████| 10/10 [00:00<00:00, 24314.81it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
27



100%|██████████| 10/10 [00:00<00:00, 21948.22it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
28



100%|██████████| 10/10 [00:00<00:00, 23340.59it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
29



100%|██████████| 10/10 [00:00<00:00, 22121.86it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
30



100%|██████████| 10/10 [00:00<00:00, 22017.34it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
31



100%|██████████| 10/10 [00:00<00:00, 20010.99it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
32



100%|██████████| 10/10 [00:00<00:00, 19409.09it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
33



100%|██████████| 10/10 [00:00<00:00, 23172.95it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
34



100%|██████████| 10/10 [00:00<00:00, 20164.92it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
35



100%|██████████| 10/10 [00:00<00:00, 19517.47it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
36



100%|██████████| 10/10 [00:00<00:00, 19099.74it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
37



100%|██████████| 10/10 [00:00<00:00, 21498.23it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
38



100%|██████████| 10/10 [00:00<00:00, 22758.02it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
39



100%|██████████| 10/10 [00:00<00:00, 19756.50it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
40



100%|██████████| 10/10 [00:00<00:00, 22417.45it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
41



100%|██████████| 10/10 [00:00<00:00, 23392.66it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
42



100%|██████████| 10/10 [00:00<00:00, 21720.89it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
43



100%|██████████| 10/10 [00:00<00:00, 23250.02it/s]


2999 999 1.0 9
2999 1000 0.0 9
passed cont
3000 1000 9
44



100%|██████████| 10/10 [00:00<00:00, 23314.64it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
45



100%|██████████| 10/10 [00:00<00:00, 19013.16it/s]


2999 999 0.0 9
passed cont
3000 999 1.0 10
3000 1000 10
46



100%|██████████| 10/10 [00:00<00:00, 23121.85it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
47



100%|██████████| 10/10 [00:00<00:00, 22770.38it/s]


2999 1000 0.0 9
passed cont
3000 1000 9
48



100%|██████████| 10/10 [00:00<00:00, 23366.60it/s]


skip
2999 1000 0.0 9
passed cont
3000 1000 9
50



100%|██████████| 10/10 [00:00<00:00, 22635.21it/s]


In [None]:
# Cache the raw dropout-sampled predictions so the analysis can be rerun without re-predicting.
pickle_file = "../dat/most_recent_sat_mut_results.pickle"
with open(pickle_file, "wb") as pickle_out:
    pickle.dump(np_preds, pickle_out)

In [None]:
# To resume from the cache instead (trusted files only: pickle.load can execute code):
# with open(pickle_file, 'rb') as f: np_preds = pickle.load(f)
epochs, n_seqs, n_batches, batch_size, _ = np_preds.shape
np_preds.shape

# Results & Analysis

In [None]:
# np_log_odds_preds = np.log(np_preds / (1-np_preds))

In [None]:
# Axes 2 and 3 of np_preds are the batch index and the row within a batch.
n_batches, batch_size = np_preds.shape[2:4]
np_preds.shape

In [None]:
# Summary statistics across dropout samples (axis 0 = epochs). Both results keep the
# (sequence, batch, row, column) layout; the original `[:, :batch_size-1]` wrongly sliced
# the *sequence* axis and could silently drop sequences.
np_pred_means = np.mean(np_preds, axis=0)
np_pred_vars = np.var(np_preds, axis=0)  # population variance (ddof=0)
np_pred_means.shape

In [None]:
# Row 0 of every batch is the wild type. Subtracting it (kept as a length-1 axis so it
# broadcasts) gives mutant-minus-wild-type mean predictions for the remaining rows.
np_pred_mean_diffs = np.subtract(np_pred_means[:, :, 1:, :], np_pred_means[:, :, :1, :])
np_pred_mean_diffs.shape

In [None]:
# Covariance, across dropout samples (axis 0), between every row's prediction and the
# wild-type row (index 0) of the same batch. It uses ddof=0 to match np.var's default in
# np_pred_vars; np.cov's default ddof=1 made Var(mut - wt) internally inconsistent.
# This is vectorized instead of looping over sequences/batches/rows/columns.
preds_f64 = np_preds.astype(np.float64)
centered = preds_f64 - preds_f64.mean(axis=0, keepdims=True)
np_pred_covs = (centered * centered[:, :, :, 0:1, :]).mean(axis=0)
assert np_pred_covs.shape == (n_seqs, n_batches, batch_size, len(relevant_cols)), np_pred_covs.shape

In [None]:
# Std. dev. across dropout samples of the mutant-minus-wild-type difference:
# Var(mut - wt) = Var(mut) + Var(wt) - 2 Cov(wt, mut). The covariance term belongs
# inside the square root; clip tiny negatives caused by floating-point cancellation.
np_pred_diff_vars = (np_pred_vars[:, :, 1:, :] + np_pred_vars[:, :, 0:1, :]
                     - 2 * np_pred_covs[:, :, 1:, :])
np_pred_uncertainties = np.sqrt(np.clip(np_pred_diff_vars, 0, None))

In [None]:
# Distribution of mean TF (relevant_cols position 1) prediction changes over all mutants.
# np_pred_mean_diffs already excludes the wild-type row, so keep every row (the old
# `1:` dropped the first mutant of each batch).
plt.hist(np_pred_mean_diffs[..., 1].reshape(-1), log=True)
plt.xlabel("Mean mut - WT prediction diff (TF)")
plt.ylabel("Count (log scale)");

In [None]:
# Position of the TF target within relevant_cols (sorted by DeepSEA output index:
# 0 = HepG2 DNase, 1 = HepG2 FOXA2). TF_COL still holds the label string at this point,
# which cannot index a numpy array.
tf_pos = 1
for seq in range(np_pred_means.shape[0]):
    plt.scatter(
        np_pred_mean_diffs[seq, :, :, tf_pos].reshape(-1),
        np_pred_uncertainties[seq, :, :, tf_pos].reshape(-1),
        label=seq)
plt.xlabel("Predictive Mean Diff (TF)")
plt.ylabel("Predictive Uncertainty (TF)")
plt.legend();

In [None]:
# Position of the chromatin-accessibility target within relevant_cols
# (0 = HepG2 DNase, 1 = HepG2 FOXA2). CA_COL is not defined until a later cell.
ca_pos = 0
for seq in range(np_pred_means.shape[0]):
    plt.scatter(
        np_pred_mean_diffs[seq, :, :, ca_pos].reshape(-1),
        np_pred_uncertainties[seq, :, :, ca_pos].reshape(-1),
        label=seq)
plt.xlabel("Predictive Mean Diff (CA)")
plt.ylabel("Predictive Uncertainty (CA)")
plt.legend();

In [None]:
# From here on, TF_COL / CA_COL are positions within relevant_cols, not DeepSEA labels.
TF_COL = 1
CA_COL = 0

for seq in range(np_pred_means.shape[0]):
    seq_diffs = np_pred_means[seq, :, 1:, :] - np_pred_means[seq, :, 0:1, :]
    plt.scatter(seq_diffs[..., TF_COL].ravel(), seq_diffs[..., CA_COL].ravel(), label=seq)
plt.xlabel("TF pred diff")
plt.ylabel("CA pred diff");

In [None]:

# Instrument strength (mean TF prediction change) vs. its uncertainty, one colour per
# sequence. A single shape check replaces the debug prints previously repeated every iteration.
assert np_pred_mean_diffs.shape == np_pred_uncertainties.shape, (
    np_pred_mean_diffs.shape, np_pred_uncertainties.shape)
for seq in range(np_pred_means.shape[0]):
    plt.scatter(
        np_pred_mean_diffs[seq, :, :, TF_COL].reshape(-1),
        np_pred_uncertainties[seq, :, :, TF_COL].reshape(-1),
        label=seq)
plt.xlabel("Instrument Strength")
plt.ylabel("Predictive Uncertainty");


In [None]:
plt.rcParams['figure.figsize'] = [12, 8]
fig, axs = plt.subplots(2, 2)
# Seeded so the same samples are drawn on every run. randint's upper bound is
# exclusive; it replaces the deprecated inclusive random_integers(0, n - 1).
rng = np.random.RandomState(0)
sample_seqs = rng.randint(0, n_seqs, size=4)
sample_batches = rng.randint(0, n_batches, size=4)
sample_muts = rng.randint(0, batch_size, size=4)  # row 0 is the wild type itself
sample_col = rng.randint(0, len(relevant_cols))
col_label = relevant_cols[sample_col][1]

# Per-dropout-sample ref vs. mut predictions for four random (seq, batch, row) triples.
for i, ax in enumerate(axs.flat):
    np_sample_ref_preds = np_preds[:, sample_seqs[i], sample_batches[i], 0, sample_col]
    np_sample_mut_preds = np_preds[:, sample_seqs[i], sample_batches[i], sample_muts[i], sample_col]
    ax.scatter(np_sample_ref_preds, np_sample_mut_preds)
    ax.set_title(f"seq {sample_seqs[i]}, batch {sample_batches[i]}, row {sample_muts[i]}")
    ax.set_xlabel(f"{col_label} prob (ref)")
    ax.set_ylabel(f"{col_label} prob (mut)")
fig.tight_layout()

In [None]:

IDX_TO_NT = 'ACGT'


def _convert_to_mutation(pos_nt_pair):
    """Format a (0-based position, nucleotide index) pair as e.g. ``"12C"``."""
    return "%d%s" % (pos_nt_pair[0], IDX_TO_NT[pos_nt_pair[1]])


# Positions within relevant_cols: 1 = HepG2 FOXA2 (written as X), 0 = HepG2 DNase (Y).
TF_COL = 1
CA_COL = 0

# The prediction loop skips sequences encoded as uniform 0.25, so the sequence axis of
# the prediction arrays only covers the kept ones. Recover their row indices in `seqs`
# (assumes dl.batch_iter yields intervals in the same order as dl.load_all()).
kept_seq_idxs = [idx for idx in range(len(seqs)) if not np.allclose(seqs[idx], .25)][:n_seqs]


def _write_row(writer, seq, batch, i):
    """Write one CSV row with a mutant's TF (X) and chromatin-accessibility (Y) statistics.

    Args:
        writer (csv.DictWriter): destination writer.
        seq (int): index along the sequence axis of the prediction arrays.
        batch (int): batch index for that sequence.
        i (int): mutant index within the batch (0-based; the wild-type row is excluded).
    """
    seq_idx = kept_seq_idxs[seq]
    seq_num = seq_idx + 1
    # Each batch is the wild type followed by batch_size - 1 mutants, numbered
    # consecutively across batches.
    mut_num = batch * (batch_size - 1) + i
    pos, alt_rank = divmod(mut_num, len(IDX_TO_NT) - 1)
    # generate_wt_mut_batches lists each position's alternatives in nucleotide order,
    # skipping the reference base, so map the rank back to an actual nucleotide index.
    ref_nt = int(np.argmax(seqs[seq_idx, :, 0, pos]))
    alt_nts = [nt for nt in range(len(IDX_TO_NT)) if nt != ref_nt]
    mut = _convert_to_mutation((pos, alt_nts[alt_rank]))
    writer.writerow(
        {
            "seq_num": seq_num,
            "mut": mut,
            # The *_pred_var columns hold the values of np_pred_uncertainties.
            "X_pred_mean": np_pred_mean_diffs[seq, batch, i, TF_COL],
            "X_pred_var": np_pred_uncertainties[seq, batch, i, TF_COL],
            "Y_pred_mean": np_pred_mean_diffs[seq, batch, i, CA_COL],
            "Y_pred_var": np_pred_uncertainties[seq, batch, i, CA_COL],
        }
    )


with open("../dat/means_and_uncertainties.csv", 'w', newline="") as out_file:
    fieldnames = [
        "seq_num",
        "mut",
        "X_pred_mean",
        "X_pred_var",
        "Y_pred_mean",
        "Y_pred_var",
    ]
    writer = csv.DictWriter(out_file, delimiter=",", fieldnames=fieldnames)
    writer.writeheader()

    # One row per mutant. The former extra call before this loop wrote a spurious row
    # from stale global `batch` / `i` values.
    for seq in range(n_seqs):
        for batch in range(n_batches):
            for i in range(batch_size - 1):
                _write_row(writer, seq, batch, i)


In [None]:
# zeros = np.zeros((2, 3))
# def sanitize_scores(scores):
#     orig_shape = scores.shape
#     sanitized_scores = np.ndarray((*orig_shape, 2, 3), dtype=scores.dtype)
#     flattened_scores = scores.reshape(-1)
    
#     for i, score in enumerate(flattened_scores):
#         idx = np.unravel_index(i, orig_shape)
#         if score is None: sanitized_scores[idx] = zeros
#         else: sanitized_scores[idx] = np.array(score)
#     return sanitized_scores

# sanitized_scores = sanitize_scores(np.squeeze(ism_score))
# ctcf_original_preds = sanitized_scores[:, :, :, 0, 1]
# ctcf_pred_diffs = sanitized_scores[:, :, :, 1, 1]
# dnase_original_preds = sanitized_scores[:, :, :, 0, 0]
# dnase_pred_diff = sanitized_scores[:, :, :, 1, 0]

In [None]:

log_uniform_prop = math.log(.05/(1-.05))


def compute_normalized_prob(prob, train_prob):
    """Re-calibrate a DeepSEA probability from its training positive rate to 5%.

    Shifts the log-odds of `prob` by logit(0.05) - logit(train_prob) and maps the
    result back through a sigmoid. Works elementwise on numpy arrays.

    Args:
        prob: predicted probability (scalar or array).
        train_prob: proportion of positives for this target in the training data.
    """
    def _logit(p):
        return np.log(p / (1 - p))

    shifted_logit = _logit(prob) + log_uniform_prop - _logit(train_prob)
    return 1 / (1 + np.exp(-shifted_logit))

# Ratios and normalization formula drawn from here: http://deepsea.princeton.edu/media/help/posproportion.txt
# NOTE(review): the proportions below are for A549 targets; the HepG2 columns used above
# would need their own values from that file.
# tf_compute_normalized_prob = lambda prob: compute_normalized_prob(prob, .020029)
# ENCODE	A549	DNase	None	0.048136
# chrom_acc_normalized_prob = lambda prob: compute_normalized_prob(prob, 0.048136)