## Imports and model initialization

In [11]:
%load_ext autoreload
%autoreload 2

import csv
from datetime import datetime
import math
import os
from pathlib import Path
import pickle

from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from utils import detect_device

import bpnet
from bpnet.datasets import StrandedProfile
from bpnet.dataspecs import DataSpec, TaskSpec
from bpnet.utils import create_tf_session
from bpnet.utils import read_json
from bpnet.seqmodel import SeqModel
from bpnet.plot.evaluate import plot_loss, regression_eval

from in_silico_mutagenesis import compute_summary_statistics, generate_wt_mut_batches, write_results

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
!pwd

/home/ubuntu/dev/an1lam/deepmr/src


In [13]:
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
output_dir = f'/home/ubuntu/dev/an1lam/deepmr/dat/res-bpnet-{timestamp}'
factor_names = ['Oct4', 'Sox2', 'Nanog', 'Klf4']

exposure_outcome_pairs = [
    (exposure_name, outcome_name)
    for exposure_name in factor_names
    for outcome_name in factor_names
    if exposure_name != outcome_name
]

results_fnames = [
    f'{exposure_name}_{outcome_name}_effect_sizes.csv' for exposure_name, outcome_name in exposure_outcome_pairs
]
n_seqs = 2000
alphabet_size = 4

In [14]:
os.makedirs(output_dir, exist_ok=True)

# Loading BPNet

In [15]:
class Ensemble:
    def __init__(self, model_base_dir, n_reps=5):
        models = []
        for i in range(n_reps):
            models.append(SeqModel.from_mdir(os.path.join(model_base_dir, str(i))))
        self.models = models

    def predict(self, seqs):
        preds = {}
        for model in self.models:
            model_preds = model.predict(seqs)
            for key, preds_ in model_preds.items():
                preds.setdefault(key, []).append(preds_.mean(-1))
        return {k: np.stack(v) for k, v in preds.items()}

In [52]:
model_base_dir = "/home/ubuntu/dev/an1lam/deepmr/dat/res-bpnet-training-2021-01-02-20-11-07/output_ensemble"

In [53]:
model_dir = Path(model_base_dir)

In [54]:
create_tf_session(0)
clear_output()

In [57]:
model = Ensemble(model_base_dir, n_reps=1)
clear_output()

# Loading data

In [58]:
!cat {model_base_dir}/0/evaluation.valid.json

{
  "valid-peaks": {
    "Oct4/profile/binsize=1/auprc": 0.18009270799318408,
    "Oct4/profile/binsize=1/random_auprc": 0.002962928723389025,
    "Oct4/profile/binsize=1/n_positives": 49407,
    "Oct4/profile/binsize=1/frac_ambigous": 0.07197004608294931,
    "Oct4/profile/binsize=1/imbalance": 0.0029207034942389304,
    "Oct4/profile/binsize=10/auprc": 0.4895772090647788,
    "Oct4/profile/binsize=10/random_auprc": 0.03456102474822499,
    "Oct4/profile/binsize=10/n_positives": 39490,
    "Oct4/profile/binsize=10/frac_ambigous": 0.36034507351327627,
    "Oct4/profile/binsize=10/imbalance": 0.03386899927356185,
    "Oct4/counts/mse": 0.28948304057121277,
    "Oct4/counts/var_explained": 0.3258916735649109,
    "Oct4/counts/pearsonr": 0.5844209726909169,
    "Oct4/counts/spearmanr": 0.5507967253432144,
    "Oct4/counts/mad": 0.4073560833930969,
    "Sox2/profile/binsize=1/auprc": 0.3745101241906123,
    "Sox2/profile/binsize=1/random_auprc": 0.005811615354319482,
   

In [62]:
gin_config = read_json(os.path.join(model_base_dir, '0', 'config.gin.json'))
gin_config

{'Adam.amsgrad': False,
 'Adam.beta_1': 0.9,
 'Adam.beta_2': 0.999,
 'Adam.decay': 0.0,
 'Adam.epsilon': 'None',
 'Adam.lr': 0.004,
 'DeConv1D.batchnorm': False,
 'DeConv1D.filters': 64,
 'DeConv1D.n_hidden': 0,
 'DeConv1D.n_tasks': 2,
 'DeConv1D.padding': 'same',
 'DeConv1D.tconv_kernel_size': 25,
 'DilatedConv1D.add_pointwise': False,
 'DilatedConv1D.batchnorm': False,
 'DilatedConv1D.conv1_kernel_size': 25,
 'DilatedConv1D.filters': 64,
 'DilatedConv1D.n_dil_layers': 9,
 'DilatedConv1D.padding': 'same',
 'DilatedConv1D.skip_type': 'residual',
 'GlobalAvgPoolFCN.batchnorm': False,
 'GlobalAvgPoolFCN.dropout': 0,
 'GlobalAvgPoolFCN.dropout_hidden': 0,
 'GlobalAvgPoolFCN.hidden': 'None',
 'GlobalAvgPoolFCN.n_splines': 0,
 'GlobalAvgPoolFCN.n_tasks': 2,
 'MovingAverages.window_sizes': [1, 50],
 'PeakPredictionProfileMetric.binsizes': [1, 10],
 'PeakPredictionProfileMetric.neg_max_threshold': 0.005,
 'PeakPredictionProfileMetric.pos_min_threshold': 0.015,
 'PeakPredictionProfileMetric.re

In [63]:
ds = DataSpec.load(os.path.join(model_base_dir, '0', 'dataspec.yml')) # remember to re-add 0
tasks = list(ds.task_specs)
tasks

['Oct4', 'Sox2', 'Nanog', 'Klf4']

In [64]:
dl_valid = StrandedProfile(ds, 
                           incl_chromosomes=gin_config['bpnet_data.valid_chr'], 
                           peak_width=gin_config['bpnet_data.peak_width'],
                           seq_width=gin_config['bpnet_data.seq_width'],
                           shuffle=False)

In [42]:
valid = dl_valid.load_all(batch_size=1, num_workers=1)

100%|██████████| 29277/29277 [08:34<00:00, 56.94it/s] 


In [44]:
valid['targets']['Oct4/counts'].shape

(29277, 2)

# (Re-)Calibration

In [68]:
def fit_recalibrators(model, dataset: dict):
    seqs = dataset['inputs']['seq']
    features = dataset['targets'].keys()
    predictions = {f: np.zeros((seqs.shape[0], n_reps)) for f in features.items()}
    ys = {k: v.mean(axis=-1) for k, v in dataset['targets']}
    for i, seq in enumerate(seqs):
        p = model.predict(seq)
        for f in features:
            predictions[feature][i] = p[feature].squeeze()

    pred_means = {}
    pred_stds = {}
    for f in features:
        pred_means[f] = np.mean(predictions[f], axis=0).squeeze()
        pred_stds[f] = np.std(predictions[f], axis=0).squeeze()

    recal_models = []
    for f in features:
        y = ys[f]
        pred_mean, pred_std = pred_means[f], pred_stds[f]
        exp_props, obs_props = get_proportion_lists_vectorized(pred_mean, pred_std, y)
        recal_model = iso_recal(exp_props, obs_props)
        recal_models[f] = recal_model
    return recal_models

def recal_predict(recalibrators, preds):
    features = preds.keys()
    pred_means = {}
    pred_stds = {}
    for f in features:
        pred_means[f] = np.mean(preds[f], axis=0).squeeze()
        pred_stds[f] = np.std(preds[f], axis=0).squeeze()
    recal_preds = {k: np.zeros_like(v) for k, v in preds.items()}
    for f in features:
        pred_dist = stats.norm(loc=pred_means[f], scale=pred_stds[f])
        for c in range(preds[f].shape[0]):
            recal_model = recalibrators[f]
            orig_preds = preds[f][c, :]
            orig_quantiles = pred_dist.cdf(orig_preds)
            recal_quantiles = recal_model.predict(orig_quantiles)
            recal_preds[f][c] = pred_dist.ppf(recal_quantiles)
    return recal_preds


## Predictions and in-silico mutagenesis

In [66]:
cols = [f'{factor_name}/counts' for factor_name in factor_names]

In [17]:
valid['inputs']['seq'].shape

(29277, 1000, 4)

In [18]:
valid_seqs = []
for seq in valid['inputs']['seq']:
    if ((seq == 0.0) | (seq == 1.0)).all():
        valid_seqs.append(seq)
valid_seqs = np.array(valid_seqs)
valid_seqs.shape

(29264, 1000, 4)

In [19]:
np.random.seed(42)
idxs = np.arange(len(valid_seqs))
np.random.shuffle(idxs)
sample_seqs = valid_seqs[idxs[:n_seqs]]
sample_seqs.shape

(2000, 1000, 4)

In [None]:
recal_models = fit_recalibrators(model, valid)

In [20]:
n_seqs = sample_seqs.shape
preds = {}
recal_preds = {}
for seq in tqdm(sample_seqs):
    muts = generate_wt_mut_batches(seq.T, seq.shape[0] * seq.shape[1]).squeeze()
    preds_ = model.predict(muts.transpose(0, 2, 1))
    recal_preds_ = recal_predict(recal_models, preds_)
    for key, value in preds_.items():
        if key in cols:
            preds.setdefault(key, []).append(preds_[key])
            recal_preds.setdefault(key, []).append(recal_preds_[key])

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))




In [21]:
np.array(preds['Oct4/counts']).shape

(2000, 5, 4000)

In [22]:
seqs = sample_seqs.transpose(0, 2, 1)

for exposure, outcome in exposure_outcome_pairs:
    exposure_col = f'{exposure}/counts'
    outcome_col = f'{outcome}/counts'
    
    formatted_preds = np.stack((preds[exposure_col], preds[outcome_col]))
    n_features, n_seqs, n_reps, n_variants = formatted_preds.shape
    formatted_preds = formatted_preds.transpose(2, 1, 3, 0)
    formatted_preds = formatted_preds.reshape(n_reps, n_seqs, alphabet_size, -1, n_features)

    means, mean_diffs, stderrs = compute_summary_statistics(formatted_preds, seqs)
    
    results_fname = f'{exposure}_{outcome}_effect_sizes.csv'
    results_fpath = os.path.join(output_dir, results_fname)
    write_results(results_fpath, mean_diffs, stderrs)
    print(results_fpath)

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))




  stderrs = np.sqrt(ref_vars + mut_vars - 2 * covs)


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Oct4_Sox2_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Oct4_Nanog_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Oct4_Klf4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Sox2_Oct4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Sox2_Nanog_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Sox2_Klf4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Nanog_Oct4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Nanog_Sox2_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Nanog_Klf4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Klf4_Oct4_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Klf4_Sox2_effect_sizes.csv


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


/home/stephenmalina/dev/an1lam/deepmr/dat/res-bpnet-2020-09-09-13-33-15/Klf4_Nanog_effect_sizes.csv
