In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from matplotlib import pyplot as plt
from IPython.display import clear_output
import numpy as np
import pandas as pd
import seaborn as sns
import simdna
from simdna import synthetic
import statsmodels.api as sm
import torch
from tqdm.auto import tqdm

from utils import one_hot_decode
from in_silico_mutagenesis import compute_summary_statistics, generate_wt_mut_batches, write_results
from pyx.one_hot import one_hot
from tf_coop_model import CountsRegressor, IterablePandasDataset
from tf_coop_model import anscombe_transform, run_one_epoch, spearman_rho, pearson_r
from tf_coop_simulation import background_frequency
from tf_coop_simulation import simulate_counts, simulate_oracle_predictions

  import pandas.util.testing as tm


In [2]:
np.random.seed(42)

In [3]:
sequence_length = 100
includes_confounder = True

data_dir = '../dat/sim_conf/'
test_data_fpath = os.path.join(data_dir, 'test_labels.csv')
raw_simulation_data_fpath = os.path.join(data_dir, 'test_sequences.simdata')

sequences_col = "sequences"
label_cols = ["labels_exp", "labels_out"]
batch_size = 1000
n_samples = 10

exposure_motif = "GATA_disc1"
outcome_motif = "TAL1_known1"
confounder_motif = "SOX2_1" if includes_confounder else None

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f6fd3233c88>

# Load Test Data & Generate Predictions
Now we're going to load test data to get some basic metrics about how well our model performs.

In [10]:
test_df = pd.read_csv(test_data_fpath)
test_dataset = IterablePandasDataset(
    test_df, x_cols=sequences_col, y_cols=label_cols, x_transform=one_hot,
)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=0
)

test_df.head()

Unnamed: 0,sequences,labels_exp,labels_out,has_exposure,has_outcome,has_both,has_confounder
0,GAAAGGGATCAAGCTCCAGATAAGGGGTTAGAGGGTAGCCTGACGG...,113,111,1,1,1,0
1,TAATCCCATCGGTGAGCCTTCGTGGGTGCTCAGTCCTAAACAGTAT...,18,39,0,1,0,1
2,GATCTCTTCATGTCGAGCCGGTCCGTATGCGATAAGAACCATTTAG...,123,27,1,0,0,1
3,TGTAACTGTCAATATGATTTGTCACTCTAATCCTGAAAAATCCCAT...,103,37,1,0,0,0
4,CAGGTATTTGATACTAGAACGAGAGCATGCAGTCCCTTATACCATT...,17,32,0,0,0,1


In [11]:
both_motifs_df = test_df[(test_df['has_both'] == 1)]
exposure_motif_df = test_df[(test_df['has_exposure'] == 1) & (test_df['has_outcome'] == 0)]
outcome_motif_df = test_df[(test_df['has_exposure'] == 0) & (test_df['has_outcome'] == 1)]
neither_motif_df = test_df[
    (test_df['has_exposure'] == 0) & (test_df['has_outcome'] == 0)
]

In [12]:
len(both_motifs_df), len(exposure_motif_df), len(outcome_motif_df), len(neither_motif_df)

(2405, 2573, 2569, 2453)

In [13]:
both_motifs_dataset = IterablePandasDataset(
    both_motifs_df, x_cols=sequences_col, y_cols=label_cols, x_transform=one_hot,
)
exposure_motif_dataset = IterablePandasDataset(
    exposure_motif_df, x_cols=sequences_col, y_cols=label_cols, x_transform=one_hot,
)

In [14]:
def mutate(seqs):
    preds = {}
    all_muts = []
    for seq in tqdm(seqs):
        muts = generate_wt_mut_batches(seq, seq.shape[0] * seq.shape[1]).squeeze()
        muts = muts.transpose(0, 1, 2)
        all_muts.append(muts)
    return np.array(all_muts)

In [15]:
both_motifs_sample_seqs = [x for x, y in both_motifs_dataset]

In [16]:
exposure_motif_sample_seqs = [x for x, y in exposure_motif_dataset]

In [17]:
sample_seqs = np.array([seq for seq, label in both_motifs_dataset])
sample_labels = np.array([label for _, label in both_motifs_dataset])
sample_muts = mutate(sample_seqs)

HBox(children=(FloatProgress(value=0.0, max=2405.0), HTML(value='')))




In [18]:
motifs = synthetic.LoadedEncodeMotifs(
    simdna.ENCODE_MOTIFS_PATH, pseudocountProb=0.001
)
exposure_pwm = motifs.loadedMotifs[exposure_motif].getRows()
outcome_pwm = motifs.loadedMotifs[outcome_motif].getRows()
confounder_pwm = None if confounder_motif is None else motifs.loadedMotifs[confounder_motif].getRows()
confounder_pwm

array([[0.04068952, 0.66525133, 0.13354957, 0.16050958],
       [0.03011511, 0.67670287, 0.03310211, 0.26007991],
       [0.48855021, 0.01070254, 0.00622302, 0.49452323],
       [0.00323701, 0.01518305, 0.01070254, 0.9708774 ],
       [0.02115607, 0.00174351, 0.00921003, 0.96789139],
       [0.0569942 , 0.12419193, 0.79019127, 0.0286226 ],
       [0.12736875, 0.00623201, 0.01370953, 0.85268971],
       [0.09596219, 0.29037858, 0.11540373, 0.4982545 ],
       [0.61225039, 0.18624981, 0.05724994, 0.14424986],
       [0.03924996, 0.05875044, 0.03924996, 0.86274964],
       [0.0408114 , 0.11592321, 0.66725133, 0.17601406],
       [0.04832188, 0.63420141, 0.12643968, 0.19103602],
       [0.64117443, 0.06193525, 0.05441278, 0.24247753],
       [0.55090379, 0.1010531 , 0.23645955, 0.11158455],
       [0.72994158, 0.09052064, 0.09202513, 0.08751165]])

In [19]:
adjusted_labels_ism = []
for i, muts in enumerate(tqdm(sample_muts)):
    adjusted_labels_ = simulate_oracle_predictions(
        [one_hot_decode(mut) for mut in muts],
        exposure_pwm,
        outcome_pwm,
        confounder_pwm=confounder_pwm,
    )
    adjusted_labels_ism.append(adjusted_labels_)

HBox(children=(FloatProgress(value=0.0, max=2405.0), HTML(value='')))




In [20]:
adjusted_labels_ism_no_conf = []
if includes_confounder:
    for i, muts in enumerate(tqdm(sample_muts)):
        adjusted_labels_ = simulate_oracle_predictions(
            [one_hot_decode(mut) for mut in muts],
            exposure_pwm,
            outcome_pwm,
        )
        adjusted_labels_ism_no_conf.append(adjusted_labels_)

HBox(children=(FloatProgress(value=0.0, max=2405.0), HTML(value='')))




In [21]:
adjusted_labels_ism = np.array(adjusted_labels_ism)
adjusted_labels_ism = adjusted_labels_ism.transpose((0, 2, 1))
adjusted_labels_ism = np.array(adjusted_labels_ism).reshape(len(sample_seqs), 4, 100, -1)
adjusted_labels_ism_anscombe = anscombe_transform(adjusted_labels_ism)

In [22]:
adjusted_labels_ism_no_conf = np.array(adjusted_labels_ism_no_conf)
adjusted_labels_ism_no_conf = adjusted_labels_ism_no_conf.transpose((0, 2, 1))
adjusted_labels_ism_no_conf = np.array(adjusted_labels_ism_no_conf).reshape(len(sample_seqs), 4, 100, -1)
adjusted_labels_ism_no_conf_anscombe = anscombe_transform(adjusted_labels_ism_no_conf)

In [25]:
seq_idxs = np.array(sample_seqs).astype(np.bool)
adjusted_ref_labels_ism = adjusted_labels_ism_anscombe[seq_idxs].reshape(len(sample_seqs), 1, 100, -1)
adjusted_mut_labels_ism = adjusted_labels_ism_anscombe[~seq_idxs].reshape(len(sample_seqs), 3, 100, -1)
adjusted_diffs = adjusted_mut_labels_ism - adjusted_ref_labels_ism

In [26]:
seq_idxs = np.array(sample_seqs).astype(np.bool)
adjusted_ref_labels_ism_no_conf = adjusted_labels_ism_no_conf_anscombe[seq_idxs].reshape(len(sample_seqs), 1, 100, -1)
adjusted_mut_labels_ism_no_conf = adjusted_labels_ism_no_conf_anscombe[~seq_idxs].reshape(len(sample_seqs), 3, 100, -1)
adjusted_diffs_no_conf = adjusted_mut_labels_ism_no_conf - adjusted_ref_labels_ism_no_conf

In [27]:
from filter_instrument_candidates import filter_variants_by_score
sig_var_idxs = filter_variants_by_score(adjusted_diffs[:, :, :, 0])
sig_var_idxs_no_conf = filter_variants_by_score(adjusted_diffs_no_conf[:, :, :, 0])

In [28]:
ols_results = []
for i in range(len(sample_seqs)):
    if adjusted_diffs[i, sig_var_idxs[i, :, :], 0].shape[0] > 0:
        x = adjusted_diffs[i, sig_var_idxs[i, :, :], 0].flatten()
        y = adjusted_diffs[i, sig_var_idxs[i, :, :], 1].flatten()
        ols_res = sm.OLS(y, x).fit()
        ols_results.append(ols_res)

In [29]:
ols_results_no_conf = []
for i in range(len(sample_seqs)):
    if adjusted_diffs_no_conf[i, sig_var_idxs_no_conf[i, :, :], 0].shape[0] > 0:
        x = adjusted_diffs_no_conf[i, sig_var_idxs_no_conf[i, :, :], 0].flatten()
        y = adjusted_diffs_no_conf[i, sig_var_idxs_no_conf[i, :, :], 1].flatten()
        ols_res = sm.OLS(y, x).fit()
        ols_results_no_conf.append(ols_res)

In [30]:
ism_cis = [ols_res.params[0] for ols_res in ols_results]
ism_cis_no_conf = [ols_res.params[0] for ols_res in ols_results_no_conf]
ism_cis[:10], ism_cis_no_conf[:10]

([-0.2655093290648858,
  0.4964853733018583,
  0.5802494380989565,
  0.5571266990576756,
  0.7756769381167032,
  0.5869291038673135,
  0.7678002337644397,
  0.08471587946308295,
  0.8361671869733878,
  0.6801398626793922],
 [0.6647231664148179,
  -0.07317999884519363,
  0.4361841126795202,
  0.9946875636373229,
  0.8302481118440905,
  0.8045705939825488,
  -0.047022975077662475,
  0.5987288777550359,
  0.6680382923953441,
  0.39102479113925365])

In [34]:
wald_results = []
for i in range(len(sample_seqs)):
    if adjusted_diffs[i, sig_var_idxs[i, :, :], 0].shape[0] > 0:
        wald_results.append(
            np.mean(adjusted_diffs[i, sig_var_idxs[i, :, :], 1] / adjusted_diffs[i, sig_var_idxs[i, :, :], 0])
        )

In [None]:
len(sample_seqs)

In [None]:
sns.distplot(ism_cis, kde=False)

In [31]:
output_dir = os.path.join(data_dir, 'res')
print(output_dir)
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'GATA_TAL1_true_ces.csv'), 'w') as f:
    f.write('seq, CI\n')
    for i, ci in enumerate(ism_cis):
        f.write(f'{i}, {ci}\n')

../dat/sim_conf/res


In [36]:
output_dir = os.path.join(data_dir, 'res')
print(output_dir)
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'GATA_TAL1_true_ces_no_conf.csv'), 'w') as f:
    f.write('seq, CI\n')
    for i, ci in enumerate(ism_cis_no_conf):
        f.write(f'{i}, {ci}\n')

../dat/sim_conf/res


In [35]:
output_dir = os.path.join(data_dir, 'res')
print(output_dir)
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'GATA_TAL1_true_ces_wald.csv'), 'w') as f:
    f.write('seq, CI\n')
    for i, ci in enumerate(wald_results):
        f.write(f'{i}, {ci}\n')

../dat/sim_conf/res
