## Imports and model initialization

In [1]:
# !pip install kipoi
# !pip install kipoiseq
# !pip install pybedtools
# !pip uninstall -y kipoi_veff
# !pip install git+https://github.com/an1lam/kipoi-veff
# !pip install pyvcf
import kipoi
from kipoi_interpret.importance_scores.ism import Mutation
from kipoiseq.dataloaders import SeqIntervalDl
import numpy as np

In [2]:
!pwd

/home/stephenmalina/project/src


# Loading DNA sequence data

In [3]:
dl = SeqIntervalDl("../dat/ChIPseq.A549.CTCF.1000.random.narrowPeak.gz", "../dat/hg19.fa", auto_resize_len=1000)
data = dl.load_all()

100%|██████████| 32/32 [00:01<00:00, 25.26it/s]


In [4]:
seqs = np.expand_dims(data['inputs'].transpose(0, 2, 1), 2).astype(np.float32)
seqs.shape

(1001, 4, 1, 1000)

# Loading DeepSEA

In [5]:
import tensorflow as tf
print("TF version:", tf.__version__)
import torch
print("torch version:", torch.__version__)
from torch import nn

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


TF version: 1.14.0
torch version: 1.3.1


In [6]:
# df = kipoi.list_models()
# deepsea_models = df[df.model.str.contains("DeepSEA")]
# deepsea_models.head()

In [7]:
deepsea = kipoi.get_model("DeepSEA/predict", source="kipoi")
deepsea.model

Using downloaded and verified file: /home/stephenmalina/.kipoi/models/DeepSEA/predict/downloaded/model_files/weights/89e640bf6bdbe1ff165f484d9796efc7


Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): Dropout(p=0.5, inplace=False)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15):

In [9]:
deepsea.pipeline.predict_example().shape

100%|██████████| 1/1 [00:00<00:00, 24.90it/s]


(10, 919)

# Hacking DeepSEA's layers
This section is focused on tweaking DeepSEA's model to use dropout when making predictions.

In [16]:
# (Legally) grabbed from:
# https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/lock_dropout.html.
class LockedDropout(nn.Module):
    """ LockedDropout applies the same dropout mask to every time step.

    **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. 
    Here is their `License
    <https://github.com/salesforce/awd-lstm-lm/blob/master/LICENSE>`__.

    Args:
        p (float): Probability of an element in the dropout mask to be zeroed.
    """

    def __init__(self, p=0.5, training=True):
        self.p = p
        self.training = training
        super().__init__()

    def train(self, training=True):
        self.training = training
        
    def forward(self, x):
        """
        Args:
            x (:class:`torch.FloatTensor` [sequence length, batch size, rnn hidden size]):
                Input to apply dropout to.
        """
        if not self.training or not self.p:
            return x
        x = x.clone()
        mask = x.new_empty(1, x.size(1), x.size(2), x.size(3), requires_grad=False)
        mask = mask.bernoulli_(1 - self.p)
        mask = mask.div_(1 - self.p) # rescaling
        mask = mask.expand_as(x)
        return x * mask


    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + str(self.p) + ')'

In [17]:
def replace_dropout_layers(model):
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            # recurse
            model._modules[name] = replace_dropout_layers(module)

        if type(module) == nn.Dropout:
             model._modules[name] = LockedDropout(module.p, training=module.training)
        
    return model

In [18]:
from torch import nn

def apply_dropout(m):
    if type(m) == nn.Dropout or type(m) == LockedDropout:
        m.train()
        
def unapply_dropout(m):
    if type(m) == nn.Dropout or type(m) == LockedDropout:
        m.eval()

In [19]:
deepsea.model = replace_dropout_layers(deepsea.model)

### Verifying dropout is being applied
We can verify or at least get reasonable certainty that dropout's being applied by running 10 predictions on the same sequence and then seeing that the predicted binding probs differ. First, we show that the predictions stay the same if dropout is off.

Then, we turn on dropout and show that they start to vary.

In [25]:
deepsea.model = deepsea.model.apply(unapply_dropout)
for i in range(3):
    print(f"First few preds ({i+1}/3): ", deepsea.pipeline.predict_example()[:5])

100%|██████████| 1/1 [00:00<00:00, 29.71it/s]


First few preds (1/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 29.86it/s]


First few preds (2/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


100%|██████████| 1/1 [00:00<00:00, 29.92it/s]


First few preds (3/3):  [[0.08161031 0.06867661 0.10076798 ... 0.09493414 0.02133885 0.01201447]
 [0.06698208 0.01062424 0.02694638 ... 0.15490864 0.04822354 0.00770117]
 [0.04445557 0.00539728 0.018408   ... 0.14994667 0.35297143 0.02272817]
 [0.00245111 0.00048716 0.00247776 ... 0.01495204 0.04857798 0.00068955]
 [0.00076919 0.00535259 0.00189638 ... 0.01889143 0.06013346 0.02190239]]


In [28]:
deepsea.model.eval()
deepsea.model = deepsea.model.apply(apply_dropout)
for i in range(3):
    print(f"First few preds ({i+1}/3): ", deepsea.pipeline.predict_example()[:5])

100%|██████████| 1/1 [00:00<00:00, 29.59it/s]


First few preds (1/3):  [[0.06741344 0.05814194 0.08361216 ... 0.0966848  0.0274433  0.0102653 ]
 [0.03451755 0.00577242 0.015584   ... 0.17427075 0.05067976 0.01091628]
 [0.06075658 0.00787469 0.02534632 ... 0.19338737 0.3975836  0.01956929]
 [0.00265916 0.0005907  0.00283419 ... 0.00805877 0.0446902  0.00077955]
 [0.00099816 0.01310912 0.00280122 ... 0.01483776 0.05652917 0.0221295 ]]


100%|██████████| 1/1 [00:00<00:00, 29.73it/s]


First few preds (2/3):  [[0.10481633 0.08269885 0.13672565 ... 0.06089396 0.02338288 0.01582721]
 [0.04421178 0.00644389 0.02450597 ... 0.09083049 0.04771356 0.01204624]
 [0.05301471 0.00483469 0.01933942 ... 0.11470851 0.30512577 0.01930027]
 [0.00299527 0.0008533  0.00307414 ... 0.01015506 0.02872616 0.00079115]
 [0.00062589 0.0038456  0.00173317 ... 0.01424024 0.03386416 0.02794888]]


100%|██████████| 1/1 [00:00<00:00, 30.01it/s]


First few preds (3/3):  [[0.08087078 0.06981105 0.10798873 ... 0.05945371 0.02077369 0.01850228]
 [0.30649355 0.06584627 0.09344368 ... 0.15516496 0.03338781 0.008315  ]
 [0.06351808 0.00738612 0.02453408 ... 0.16129519 0.29366755 0.03575486]
 [0.00197285 0.00072998 0.00311654 ... 0.01295998 0.03495602 0.00074415]
 [0.00075558 0.00580719 0.00238831 ... 0.0139119  0.06254701 0.03145878]]


## Predictions and in-silico mutagenesis

In [32]:
deepsea.predict_on_batch(seqs[:256, :, :, :])

array([[0.06566241, 0.25948808, 0.0471022 , ..., 0.02816107, 0.14124252,
        0.00198604],
       [0.81844884, 0.6785433 , 0.5278517 , ..., 0.07037374, 0.02024362,
        0.0138507 ],
       [0.88071716, 0.9359733 , 0.86249304, ..., 0.5970436 , 0.00582228,
        0.00755548],
       ...,
       [0.06311597, 0.33425748, 0.02651272, ..., 0.01732896, 0.01285538,
        0.00762695],
       [0.11031462, 0.25913393, 0.04432235, ..., 0.0202245 , 0.03492616,
        0.00666305],
       [0.52195454, 0.17243467, 0.16404016, ..., 0.03128373, 0.03795668,
        0.00460152]], dtype=float32)

In [74]:
def generate_wt_mut_batches(seq, batch_size=128):
    """
    For a given sequence, generate all possible point-mutated versions of the sequence
    in batches of size `param:batch_size`.
    
    Args:
        seq (numpy.ndarray [batch_size, number of base pairs, 1, sequence length]): 
            wild type sequence.
        batch_size (int): size of returned batches. Note that each batch will have the
            wild type sequence as its first row since we need to compute wild type / mut
            prediction diffs using predictions generated by the same dropout mask.
    """
    seq_len = seq.shape[3]
    # 3 mutations per nt and then account for ref in each batch
    n_batches = (seq_len * 3) / (batch_size - 1) + 1
    seq_batch = seq.repeat(batch_size, axis=0)
    seq_batches = seq_batch[np.newaxis, :, :, :, :].repeat(n_batches, axis=0)
#     for seq_idx in range(seq_len):  # iterate over sequence 
#         for nt_idx in range(4):  # iterate over nucleotides
#             np.repeat
    return seq_batches

In [80]:
from tqdm import tqdm
def next_seq(it): 
    return np.expand_dims(next(it)["inputs"].transpose(0, 2, 1), 2).astype(np.float32)
    

epochs = 2
print(len(it))
for epoch in range(epochs):
    it = dl.batch_iter(batch_size=1, shuffle=False, num_workers=0, drop_last=False)
    print(f"Epoch: {epoch+1}/{epochs}")
    for i in tqdm(range(len(it))):
        seq = next_seq(it)
        wt_mut_batches = generate_wt_mut_batches(seq)
        for batch in wt_mut_batches:
            deepsea.predict_on_batch(batch)


1001
Epoch: 0/2





  0%|          | 0/1001 [00:00<?, ?it/s]


  0%|          | 1/1001 [00:05<1:26:51,  5.21s/it]


  0%|          | 2/1001 [00:10<1:26:28,  5.19s/it]


  0%|          | 3/1001 [00:15<1:26:11,  5.18s/it]


  0%|          | 4/1001 [00:20<1:25:59,  5.18s/it]


  0%|          | 5/1001 [00:25<1:25:47,  5.17s/it]


  1%|          | 6/1001 [00:30<1:25:38,  5.16s/it]


  1%|          | 7/1001 [00:36<1:25:32,  5.16s/it]


  1%|          | 8/1001 [00:41<1:25:24,  5.16s/it]


  1%|          | 9/1001 [00:46<1:25:22,  5.16s/it]


  1%|          | 10/1001 [00:51<1:25:14,  5.16s/it]


  1%|          | 11/1001 [00:56<1:25:04,  5.16s/it]


  1%|          | 12/1001 [01:01<1:25:03,  5.16s/it]


  1%|▏         | 13/1001 [01:07<1:24:56,  5.16s/it]


  1%|▏         | 14/1001 [01:12<1:24:41,  5.15s/it]


  1%|▏         | 15/1001 [01:17<1:24:37,  5.15s/it]


  2%|▏         | 16/1001 [01:22<1:24:37,  5.15s/it]


  2%|▏         | 17/1001 [01:27<1:24:29,  5.15s/it]


  2%|▏         | 18/1001 [01:32<1:24:24,  5.1

 15%|█▌        | 151/1001 [12:57<1:13:01,  5.15s/it]


 15%|█▌        | 152/1001 [13:03<1:12:53,  5.15s/it]


 15%|█▌        | 153/1001 [13:08<1:12:46,  5.15s/it]


 15%|█▌        | 154/1001 [13:13<1:12:42,  5.15s/it]


 15%|█▌        | 155/1001 [13:18<1:12:34,  5.15s/it]


 16%|█▌        | 156/1001 [13:23<1:12:31,  5.15s/it]


 16%|█▌        | 157/1001 [13:28<1:12:29,  5.15s/it]


 16%|█▌        | 158/1001 [13:34<1:12:23,  5.15s/it]


 16%|█▌        | 159/1001 [13:39<1:12:14,  5.15s/it]


 16%|█▌        | 160/1001 [13:44<1:12:10,  5.15s/it]


 16%|█▌        | 161/1001 [13:49<1:12:05,  5.15s/it]


 16%|█▌        | 162/1001 [13:54<1:12:01,  5.15s/it]


 16%|█▋        | 163/1001 [13:59<1:11:57,  5.15s/it]


 16%|█▋        | 164/1001 [14:04<1:11:54,  5.15s/it]


 16%|█▋        | 165/1001 [14:10<1:11:44,  5.15s/it]


 17%|█▋        | 166/1001 [14:15<1:11:35,  5.14s/it]


 17%|█▋        | 167/1001 [14:20<1:11:34,  5.15s/it]


 17%|█▋        | 168/1001 [14:25<1:11:28,  5.15s/it]


 17%|█▋   

 30%|██▉       | 300/1001 [25:45<1:00:11,  5.15s/it]


 30%|███       | 301/1001 [25:50<1:00:05,  5.15s/it]


 30%|███       | 302/1001 [25:55<1:00:00,  5.15s/it]


 30%|███       | 303/1001 [26:00<59:52,  5.15s/it]  


 30%|███       | 304/1001 [26:05<59:48,  5.15s/it]


 30%|███       | 305/1001 [26:10<59:43,  5.15s/it]


 31%|███       | 306/1001 [26:16<59:40,  5.15s/it]


 31%|███       | 307/1001 [26:21<59:32,  5.15s/it]


 31%|███       | 308/1001 [26:26<59:28,  5.15s/it]


 31%|███       | 309/1001 [26:31<59:23,  5.15s/it]


 31%|███       | 310/1001 [26:36<59:16,  5.15s/it]


 31%|███       | 311/1001 [26:41<59:07,  5.14s/it]


 31%|███       | 312/1001 [26:46<59:04,  5.14s/it]


 31%|███▏      | 313/1001 [26:52<58:59,  5.15s/it]


 31%|███▏      | 314/1001 [26:57<58:57,  5.15s/it]


 31%|███▏      | 315/1001 [27:02<58:54,  5.15s/it]


 32%|███▏      | 316/1001 [27:07<58:49,  5.15s/it]


 32%|███▏      | 317/1001 [27:12<58:42,  5.15s/it]


 32%|███▏      | 318/1001 [27:17<58:37

 45%|████▌     | 454/1001 [38:58<46:55,  5.15s/it]


 45%|████▌     | 455/1001 [39:03<46:50,  5.15s/it]


 46%|████▌     | 456/1001 [39:08<46:47,  5.15s/it]


 46%|████▌     | 457/1001 [39:13<46:40,  5.15s/it]


 46%|████▌     | 458/1001 [39:18<46:35,  5.15s/it]


 46%|████▌     | 459/1001 [39:23<46:28,  5.15s/it]


 46%|████▌     | 460/1001 [39:29<46:23,  5.15s/it]


 46%|████▌     | 461/1001 [39:34<46:17,  5.14s/it]


 46%|████▌     | 462/1001 [39:39<46:12,  5.14s/it]


 46%|████▋     | 463/1001 [39:44<46:07,  5.14s/it]


 46%|████▋     | 464/1001 [39:49<46:02,  5.14s/it]


 46%|████▋     | 465/1001 [39:54<45:58,  5.15s/it]


 47%|████▋     | 466/1001 [39:59<45:54,  5.15s/it]


 47%|████▋     | 467/1001 [40:05<45:49,  5.15s/it]


 47%|████▋     | 468/1001 [40:10<45:43,  5.15s/it]


 47%|████▋     | 469/1001 [40:15<45:40,  5.15s/it]


 47%|████▋     | 470/1001 [40:20<45:34,  5.15s/it]


 47%|████▋     | 471/1001 [40:25<45:28,  5.15s/it]


 47%|████▋     | 472/1001 [40:30<45:26,  5.15s

 61%|██████    | 608/1001 [52:11<33:42,  5.15s/it]


 61%|██████    | 609/1001 [52:16<33:37,  5.15s/it]


 61%|██████    | 610/1001 [52:21<33:34,  5.15s/it]


 61%|██████    | 611/1001 [52:26<33:29,  5.15s/it]


 61%|██████    | 612/1001 [52:31<33:23,  5.15s/it]


 61%|██████    | 613/1001 [52:36<33:18,  5.15s/it]


 61%|██████▏   | 614/1001 [52:42<33:13,  5.15s/it]


 61%|██████▏   | 615/1001 [52:47<33:07,  5.15s/it]


 62%|██████▏   | 616/1001 [52:52<33:02,  5.15s/it]


 62%|██████▏   | 617/1001 [52:57<32:58,  5.15s/it]


 62%|██████▏   | 618/1001 [53:02<32:52,  5.15s/it]


 62%|██████▏   | 619/1001 [53:07<32:45,  5.15s/it]


 62%|██████▏   | 620/1001 [53:12<32:42,  5.15s/it]


 62%|██████▏   | 621/1001 [53:18<32:37,  5.15s/it]


 62%|██████▏   | 622/1001 [53:23<32:31,  5.15s/it]


 62%|██████▏   | 623/1001 [53:28<32:27,  5.15s/it]


 62%|██████▏   | 624/1001 [53:33<32:23,  5.16s/it]


 62%|██████▏   | 625/1001 [53:38<32:17,  5.15s/it]


 63%|██████▎   | 626/1001 [53:43<32:10,  5.15s

 76%|███████▌  | 760/1001 [1:05:13<20:39,  5.14s/it]


 76%|███████▌  | 761/1001 [1:05:19<20:36,  5.15s/it]


 76%|███████▌  | 762/1001 [1:05:24<20:30,  5.15s/it]


 76%|███████▌  | 763/1001 [1:05:29<20:24,  5.15s/it]


 76%|███████▋  | 764/1001 [1:05:34<20:20,  5.15s/it]


 76%|███████▋  | 765/1001 [1:05:39<20:15,  5.15s/it]


 77%|███████▋  | 766/1001 [1:05:44<20:10,  5.15s/it]


 77%|███████▋  | 767/1001 [1:05:49<20:04,  5.15s/it]


 77%|███████▋  | 768/1001 [1:05:55<20:01,  5.16s/it]


 77%|███████▋  | 769/1001 [1:06:00<19:54,  5.15s/it]


 77%|███████▋  | 770/1001 [1:06:05<19:49,  5.15s/it]


 77%|███████▋  | 771/1001 [1:06:10<19:43,  5.15s/it]


 77%|███████▋  | 772/1001 [1:06:15<19:38,  5.15s/it]


 77%|███████▋  | 773/1001 [1:06:20<19:33,  5.15s/it]


 77%|███████▋  | 774/1001 [1:06:25<19:27,  5.14s/it]


 77%|███████▋  | 775/1001 [1:06:31<19:22,  5.14s/it]


 78%|███████▊  | 776/1001 [1:06:36<19:17,  5.15s/it]


 78%|███████▊  | 777/1001 [1:06:41<19:13,  5.15s/it]


 78%|█████

 91%|█████████ | 909/1001 [1:18:00<07:53,  5.15s/it]


 91%|█████████ | 910/1001 [1:18:06<07:48,  5.15s/it]


 91%|█████████ | 911/1001 [1:18:11<07:43,  5.15s/it]


 91%|█████████ | 912/1001 [1:18:16<07:38,  5.15s/it]


 91%|█████████ | 913/1001 [1:18:21<07:33,  5.15s/it]


 91%|█████████▏| 914/1001 [1:18:26<07:28,  5.15s/it]


 91%|█████████▏| 915/1001 [1:18:31<07:22,  5.15s/it]


 92%|█████████▏| 916/1001 [1:18:37<07:17,  5.15s/it]


 92%|█████████▏| 917/1001 [1:18:42<07:12,  5.15s/it]


 92%|█████████▏| 918/1001 [1:18:47<07:07,  5.15s/it]


 92%|█████████▏| 919/1001 [1:18:52<07:02,  5.15s/it]


 92%|█████████▏| 920/1001 [1:18:57<06:57,  5.15s/it]


 92%|█████████▏| 921/1001 [1:19:02<06:52,  5.15s/it]


 92%|█████████▏| 922/1001 [1:19:07<06:46,  5.15s/it]


 92%|█████████▏| 923/1001 [1:19:13<06:41,  5.15s/it]


 92%|█████████▏| 924/1001 [1:19:18<06:36,  5.15s/it]


 92%|█████████▏| 925/1001 [1:19:23<06:30,  5.14s/it]


 93%|█████████▎| 926/1001 [1:19:28<06:25,  5.15s/it]


 93%|█████

Epoch: 1/2





  0%|          | 0/1001 [00:00<?, ?it/s]


StopIteration: 

In [None]:
CHROM_ACC_COL = 'A549_DNase_None'
TF_COL = 'A549_CTCF_None'
relevant_cols = sorted([(i, label)
                        for i, label in enumerate(deepsea.schema.targets.column_labels)
                        if label in [CHROM_ACC_COL, TF_COL]])

def output_sel_fn(result):
    return np.array([result[col_idx] for col_idx, _ in relevant_cols])

relevant_cols

In [None]:
ism = Mutation(deepsea, "seq", scores=['ref', 'diff'], output_sel_fn=output_sel_fn)
ism_score = np.array(ism.score(seqs))

In [None]:
np.squeeze(ism_score).shape

In [None]:
zeros = np.zeros((2, 3))
def sanitize_scores(scores):
    orig_shape = scores.shape
    sanitized_scores = np.ndarray((*orig_shape, 2, 3), dtype=scores.dtype)
    flattened_scores = scores.reshape(-1)
    
    for i, score in enumerate(flattened_scores):
        idx = np.unravel_index(i, orig_shape)
        if score is None: sanitized_scores[idx] = zeros
        else: sanitized_scores[idx] = np.array(score)
    return sanitized_scores

sanitized_scores = sanitize_scores(np.squeeze(ism_score))
ctcf_original_preds = sanitized_scores[:, :, :, 0, 1]
ctcf_pred_diffs = sanitized_scores[:, :, :, 1, 1]
dnase_original_preds = sanitized_scores[:, :, :, 0, 0]
dnase_pred_diff = sanitized_scores[:, :, :, 1, 0]

In [None]:
import math
import numpy as np

log_uniform_prop = math.log(.05/(1-.05))
def compute_normalized_prob(prob, train_prob):
    denom = 1+np.exp(-(np.log(prob/(1-prob))+log_uniform_prop-np.log(train_prob/(1-train_prob))))
    return 1 / denom

# Ratios and normalization formula drawn from here: http://deepsea.princeton.edu/media/help/posproportion.txt
tf_compute_normalized_prob = lambda prob: compute_normalized_prob(prob, .020029)
# ENCODE	A549	DNase	None	0.048136
chrom_acc_normalized_prob = lambda prob: compute_normalized_prob(prob, 0.048136)

In [None]:
IDX_TO_NT = 'ACGT'

def _convert_to_mutation(pos_nt_pair):
    return "%d%s" % (pos_nt_pair[1], IDX_TO_NT[pos_nt_pair[0]])

def find_best_mutation(mutation_map, sign=1):
    seq_len = mutation_map.shape[1]
    best_mute = None
    best_mute_effect = 0
    second_best_mute = None
    second_best_mute_effect = 0
    for seq_idx in range(seq_len):  # iterate over sequence
        best_out_of_nts_mute = None
        best_out_of_nts_mute_effect = 0 
        for nt_idx in range(4):  # iterate over nucleotides
            current_effect = mutation_map[nt_idx, seq_idx]
            
            if sign * current_effect > sign * best_out_of_nts_mute_effect:
                best_out_of_nts_mute_effect = current_effect
                best_out_of_nts_mute = (nt_idx, seq_idx)

        # TODO(Stephen): the right way to do this is to have a heap, squish the 2D array into a 1D
        # array of (effect, (seq_idx, nt_idx)) tuples and then take the top-k from the heap, but
        # I'm too lazy to do this right now.
        if sign * best_out_of_nts_mute_effect > (sign * best_mute_effect):
            best_mute_effect = best_out_of_nts_mute_effect
            best_mute = best_out_of_nts_mute
        elif sign * best_out_of_nts_mute_effect > (sign * second_best_mute_effect):
            second_best_mute_effect = best_out_of_nts_mute_effect
            second_best_mute = best_out_of_nts_mute 
    print(best_mute, best_mute_effect, second_best_mute, second_best_mute_effect)
    return [best_mute, second_best_mute]


import csv

mut_effects_fpath = "../dat/mut_effects_1.csv"
with open(mut_effects_fpath, 'w', newline="") as out_file:
    fieldnames = [
        "initial X prediction",
        "new X prediction",
        "initial Y prediction",
        "new Y prediction",
        "mutation",
    ]
    writer = csv.DictWriter(out_file, delimiter=",", fieldnames=fieldnames)
    writer.writeheader()
    
    for i, ctcf_pred_diff in enumerate(ctcf_pred_diffs):
        best_mut_idx, best_mut_effect = find_best_mutation(ctcf_pred_diff, sign=-1)
        nt_idx, seq_idx = best_mut_idx
        ctcf_original_pred = ctcf_original_preds[i][nt_idx, seq_idx]
        ctcf_new_pred = ctcf_original_pred + ctcf_pred_diff[nt_idx, seq_idx]
        dnase_original_pred = dnase_original_preds[i][nt_idx, seq_idx]
        dnase_pred_diff = dnase_pred_diffs[i][nt_idx, seq_idx]
        dnase_new_prd = dnase_original_pred + dnase_pred_diff
        writer.writerow(
            {
                "initial X prediction": ctcf_original_pred,
                "new X prediction": ctcf_new_pred,
                "initial Y prediction": dnase_original_pred,
                "new Y prediction": dnase_new_pred,
                "mutation": _convert_to_mutation(best_mut_idx),
            }
        )
        
