## Imports and model initialization

In [1]:
%reload_ext autoreload
%autoreload 2

# !pip install kipoi
# !pip install kipoiseq
# !pip install pybedtools
# !pip uninstall -y kipoi_veff
# !pip install git+https://github.com/an1lam/kipoi-veff
# !pip install pyvcf
import csv
import math
import pickle

import kipoi
from kipoi_interpret.importance_scores.ism import Mutation
from kipoiseq.dataloaders import SeqIntervalDl
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

from custom_dropout import LockedWeightDropout
from in_silico_mutagenesis import compute_summary_statistics
from in_silico_mutagenesis import filter_predictions_to_matching_cols
from in_silico_mutagenesis import mutate_and_predict
from in_silico_mutagenesis import deepsea_normalizers
from motif_scores import build_impact_maps

  from tqdm.autonotebook import tqdm


In [2]:
!pwd

/home/stephenmalina/project/src


# Loading DNA sequence data

In [3]:
dl = SeqIntervalDl("../dat/50_random_seqs_2.bed", "../dat/hg19.fa", auto_resize_len=1000)
data = dl.load_all()

100%|██████████| 2/2 [00:00<00:00,  3.99it/s]


In [4]:
n_seqs = 50
seqs = np.expand_dims(data['inputs'].transpose(0, 2, 1), 2).astype(np.float32)
seqs.shape

(50, 4, 1, 1000)

# Loading DeepSEA

In [5]:
import tensorflow as tf
print("TF version:", tf.__version__)
import torch
print("torch version:", torch.__version__)
from torch import nn

TF version: 1.15.0
torch version: 1.3.1


In [6]:
# df = kipoi.list_models()
# deepsea_models = df[df.model.str.contains("DeepSEA")]
# deepsea_models.head()

In [7]:
deepsea = kipoi.get_model("DeepSEA/predict", source="kipoi")
deepsea.model

Using downloaded and verified file: /home/stephenmalina/.kipoi/models/DeepSEA/predict/downloaded/model_files/weights/89e640bf6bdbe1ff165f484d9796efc7


Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): Dropout(p=0.5, inplace=False)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15):

In [8]:
deepsea.pipeline.predict_example().shape

100%|██████████| 1/1 [00:00<00:00,  4.56it/s]


(10, 919)

# Hacking DeepSEA's layers
This section is focused on tweaking DeepSEA's model to use dropout when making predictions.

In [9]:
from custom_dropout import apply_dropout, replace_dropout_layers, unapply_dropout

In [10]:
deepsea.model = replace_dropout_layers(deepsea.model, dropout_cls=LockedWeightDropout)
deepsea.model

Sequential(
  (0): ReCodeAlphabet()
  (1): ConcatenateRC()
  (2): Sequential(
    (0): Conv2d(4, 320, kernel_size=(1, 8), stride=(1, 1))
    (1): Threshold(threshold=0, value=1e-06)
    (2): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (3): LockedWeightDropout(p=0.2)
    (4): Conv2d(320, 480, kernel_size=(1, 8), stride=(1, 1))
    (5): Threshold(threshold=0, value=1e-06)
    (6): MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0, dilation=1, ceil_mode=False)
    (7): LockedWeightDropout(p=0.2)
    (8): Conv2d(480, 960, kernel_size=(1, 8), stride=(1, 1))
    (9): Threshold(threshold=0, value=1e-06)
    (10): LockedWeightDropout(p=0.5)
    (11): Lambda()
    (12): Sequential(
      (0): Lambda()
      (1): Linear(in_features=50880, out_features=925, bias=True)
    )
    (13): Threshold(threshold=0, value=1e-06)
    (14): Sequential(
      (0): Lambda()
      (1): Linear(in_features=925, out_features=919, bias=True)
    )
    (15): Sigmoid(

In [11]:
deepsea.model.eval()
deepsea.model = deepsea.model.apply(apply_dropout)
for i in range(2):
    print(f"First few preds ({i+1}/3): ", deepsea.pipeline.predict_example()[:5])

100%|██████████| 1/1 [00:00<00:00, 25.92it/s]


First few preds (1/3):  [[0.08620486 0.07117724 0.08076133 ... 0.07425417 0.0193945  0.01526004]
 [0.04501081 0.00637511 0.02325991 ... 0.0854755  0.04083367 0.00989448]
 [0.04042565 0.0047662  0.01646545 ... 0.11744817 0.3626911  0.02936701]
 [0.00452696 0.00096715 0.0054947  ... 0.01194064 0.04566119 0.00112314]
 [0.00078801 0.01083222 0.00274182 ... 0.01452547 0.05021808 0.02226263]]


100%|██████████| 1/1 [00:00<00:00, 26.02it/s]


First few preds (2/3):  [[0.13888726 0.10015524 0.18727924 ... 0.11640857 0.02568786 0.01024848]
 [0.0296017  0.00411213 0.02697972 ... 0.12137947 0.07428952 0.00958387]
 [0.05480719 0.00542989 0.02732315 ... 0.16602823 0.40117437 0.02221408]
 [0.00295882 0.00070132 0.00334569 ... 0.00829488 0.04764904 0.00081818]
 [0.00103597 0.00698191 0.00252939 ... 0.01807006 0.07374699 0.02256136]]


## Predictions and in-silico mutagenesis

In [12]:
CHROM_ACC_COL_NAME = 'HepG2_DNase_None'
TF_COL_NAME = 'HepG2_FOXA1_None'
relevant_cols = sorted([(i, label)
                        for i, label in enumerate(deepsea.schema.targets.column_labels)
                        if label in [CHROM_ACC_COL_NAME, TF_COL_NAME]])

def output_sel_fn(result):
    return np.array([result[:, col_idx] for col_idx, _ in relevant_cols]).T

relevant_cols

[(56, 'HepG2_DNase_None'),
 (302, 'HepG2_FOXA1_None'),
 (303, 'HepG2_FOXA1_None')]

In [13]:
sample_preds = np.zeros((50, 5, 3))

for i in range(50):
    sample_preds[i] = output_sel_fn(deepsea.predict_on_batch(seqs[:5, :, :, ]))
    
np.mean(sample_preds, axis=0), np.std(sample_preds, axis=0)

(array([[0.01526428, 0.0424991 , 0.05411917],
        [0.16690652, 0.24805285, 0.25762313],
        [0.01931094, 0.20745062, 0.22377067],
        [0.10558871, 0.24706267, 0.21349985],
        [0.06866537, 0.21062837, 0.21415865]]),
 array([[0.00434706, 0.02149264, 0.02725211],
        [0.0551964 , 0.08321126, 0.08422774],
        [0.00912663, 0.09612338, 0.10262818],
        [0.02354572, 0.0866553 , 0.07790363],
        [0.0339381 , 0.08284483, 0.08277552]]))

In [14]:
CA_COL, TF_COL = 0, 1
print(relevant_cols)

[(56, 'HepG2_DNase_None'), (302, 'HepG2_FOXA1_None'), (303, 'HepG2_FOXA1_None')]


In [15]:
# epochs, batch_size = 50, 400
# preds = mutate_and_predict(
#     deepsea,
#     seqs.squeeze(),
#     epochs,
#     batch_size,
#     output_sel_fn=filter_predictions_to_matching_cols(relevant_cols),
# )

In [16]:
pickle_file = '50_random_seqs__new_code_comparison__all_50.pickle'
# with open(pickle_file, 'wb') as f: pickle.dump(preds, f)
with open(pickle_file, 'rb') as f: preds = pickle.load(f)    

In [17]:
n_seqs, n_nts, seq_len = seqs.squeeze().shape
for i, (_, col_name) in enumerate(relevant_cols):
    preds[:, :, :, i] = deepsea_normalizers[col_name](preds[:, :, :, i])

In [18]:
seqs = seqs.squeeze()
means, mean_diffs, stderrs = compute_summary_statistics(preds, seqs)
means.shape, mean_diffs.shape, stderrs.shape

((50, 4, 1000, 3), (50, 3, 1000, 3), (50, 3, 1000, 3))

In [19]:
mean_diffs[5, 0:1, :, :]

array([[[-1.7794337e-02, -2.0613831e-02, -2.3787150e-02],
        [ 1.7941456e-02,  2.0779006e-02,  2.3977324e-02],
        [-2.8421543e-06,  8.3847717e-06,  8.1231628e-06],
        ...,
        [-3.6705895e-03, -4.6688779e-03, -5.6366897e-03],
        [-2.9608670e-03, -4.6098167e-03, -5.5397470e-03],
        [-3.5773423e-03, -4.8489608e-03, -5.8090449e-03]]], dtype=float32)

In [20]:
stderrs[0]

array([[[2.62715729e-05, 8.27754122e-05, 1.00361242e-04],
        [2.57414295e-03, 9.36752952e-03, 1.13420430e-02],
        [2.57570270e-03, 9.38007443e-03, 1.13569942e-02],
        ...,
        [7.16946828e-03, 3.18609004e-02, 3.94315857e-02],
        [6.90317177e-03, 3.02027463e-02, 3.73925598e-02],
        [7.16854883e-03, 3.18602334e-02, 3.94305002e-02]],

       [[6.33409948e-03, 2.36049141e-02, 2.93885944e-02],
        [6.28016033e-03, 2.34820330e-02, 2.93093734e-02],
        [6.22728901e-03, 2.33578817e-02, 2.91674253e-02],
        ...,
        [7.90300704e-03, 3.56890154e-02, 4.39618763e-02],
        [7.93364198e-03, 3.63152302e-02, 4.47840597e-02],
        [7.84416806e-03, 3.54790183e-02, 4.38411898e-02]],

       [[6.74883461e-03, 2.91956653e-02, 3.59537360e-02],
        [6.48573313e-03, 2.86416257e-02, 3.51736459e-02],
        [6.60402711e-03, 2.87510889e-02, 3.53865194e-02],
        ...,
        [4.97977249e-06, 1.15375907e-05, 2.29673400e-05],
        [4.01524721e-06, 1.05

In [21]:
epochs, n_seqs = 50, 50

epoch_seq_idxs = seqs[np.newaxis, :].repeat(epochs, axis=0).astype(np.bool)
ref_preds = preds[epoch_seq_idxs].reshape(epochs, n_seqs, 1, seq_len, -1)
mut_preds = preds[~epoch_seq_idxs].reshape(epochs, n_seqs, n_nts - 1, seq_len, -1)

seq_idxs = seqs.astype(np.bool)
ref_means = means[seq_idxs].reshape(n_seqs, 1, seq_len, -1)
mut_means = means[~seq_idxs].reshape(n_seqs, n_nts - 1, seq_len, -1)

In [22]:
with open('../dat/most_recent_sat_mut_results__comparison.pickle', 'rb') as f: original_preds = pickle.load(f)    

In [23]:
original_preds.shape

(50, 25, 10, 301, 3)

In [24]:
preds.shape

(50, 50, 4, 1000, 3)

In [28]:
n_original_seqs = 25
n_batches = original_preds.shape[2]
original_ref_preds = original_preds[:, :, :, 0:1, :]
origina_ref_preds = original_ref_preds.reshape(epochs, n_original_seqs, -1, len(relevant_cols))
original_ref_preds = original_ref_preds.repeat(seq_len // n_batches, axis=2)
original_ref_preds = np.expand_dims(original_ref_preds, axis=2)
original_mut_preds = original_preds[:, :, :, 1:, :].reshape(
    epochs, n_original_seqs, n_nts-1, seq_len, len(relevant_cols)
)

original_ref_means = np.mean(original_ref_preds, axis=0)
original_mut_means = np.mean(original_mut_preds, axis=0)

original_ref_means.shape, original_mut_means.shape

((25, 1, 1000, 1, 3), (25, 3, 1000, 3))

In [30]:
epoch_seq_idx
epoch_seq_pred_idxs = np.expand_dims(, axis=-1).repeat(preds.shape[-1], axis=-1)
masked_original_ref_preds = np.expand_dims(original_ref_preds, axis=2).repeat(n_nts, axis=2)
mask = np.zeros_like(masked_original_ref_preds)

masked_original_preds = np.where(epoch_seq_pred_idxs, masked_original_ref_preds, mask)
masked_original_preds[np.where(epoch_seq_pred_idxs == 0)] = original_mut_preds.reshape(-1)

ValueError: operands could not be broadcast together with shapes (50,25,4,1000,3) (50,25,4,1,1000,1,3) (50,25,4,1,1000,1,3) 

In [None]:
original_means, original_mean_diffs, original_stderrs = compute_summary_statistics(masked_original_preds, seqs)

In [None]:
means.shape

In [None]:
print(np.std(means - original_means))
plt.scatter(means.flatten(), original_means.flatten())

In [None]:
import sklearn
from sklearn.metrics import auc, brier_score_loss, roc_auc_score, roc_curve
from sklearn.calibration import calibration_curve 

score = roc_auc_score(
    np.concatenate((np.ones(25), np.zeros(25))),
    ref_means[:, 0, 0, 1],
)
score

In [None]:
fpr, tpr, thresholds = roc_curve(
    np.concatenate((np.ones(25), np.zeros(25))),
    ref_means[:, 0, 0, 1],
)
roc_auc = auc(fpr, tpr)
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic (TF binding)')
plt.legend(loc="lower right")
plt.show()

In [None]:
y_test = np.concatenate((np.ones(25), np.zeros(25)))
prob_pos = ref_means[:, 0, 0, 1]
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_test,
    prob_pos,
    n_bins=10
)
clf_score = brier_score_loss(y_test, prob_pos)
plt.plot(mean_predicted_value, fraction_of_positives, "s-",
     label="%s (%1.3f)" % ("MC dropout predictive means (TF)", clf_score))
plt.ylabel("Fraction of positives")
# ax2.hist(prob_pos, range=(0, 1), bins=10, label=name,
#          histtype="step", lw=2)
prob_pos

In [None]:
plt.hist(prob_pos, range=(0, 1), bins=10, label="Predictive means (TF)",
          histtype="step", lw=2)
plt.hist(ref_means[:, 0, 0, 0], range=(0, 1), bins=10, label="Predictive means (CA)",
          histtype="step", lw=2)
plt.legend()

In [None]:
y_test = np.concatenate((np.ones(25), np.zeros(25)))
prob_pos = ref_means[:, 0, 0, 0]
fraction_of_positives, mean_predicted_value = calibration_curve(
    y_test,
    prob_pos,
    n_bins=10
)
clf_score = brier_score_loss(y_test, prob_pos)
plt.plot(mean_predicted_value, fraction_of_positives, "s-",
     label="%s (%1.3f)" % ("MC dropout predictive means (CA)", clf_score))
plt.ylabel("Fraction of positives")

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 20))
plt.suptitle("Predictive Means (Ref)")

ax1.hist(
    (ref_means[:25, 0, 0, TF_COL].reshape(-1),
     ref_means[25:, 0, 0, TF_COL].reshape(-1)),
    label=("binding", "no binding"))
ax1.set_title(f"TF {relevant_cols[1][1]}")
ax1.legend()
ax2.hist(
    (ref_means[:25, 0, 0, CA_COL].reshape(-1),
     ref_means[25:, 0, 0, CA_COL].reshape(-1)),
    label=("accessible", "not accessible"))
ax2.set_title(f"DNase {relevant_cols[0][1]}")
ax2.legend()

fig.text(0.5, 0.08, "Predictive Means for y=0 vs y=1", ha="center")
fig.text(0.07, 0.5, "# of Seqs", va='center', rotation='vertical')

plt.show();

### Standard Error Calibration

In [None]:
from scipy import stats

sample_seq = 4
sample_pred_diffs = (mut_preds[:, sample_seq, :, :, :] - ref_preds[:, sample_seq, :, :, :]).reshape(epochs, -1)
sample_std_errs = stderrs[sample_seq, :, :, :].reshape(-1)
sample_mean_diffs = mean_diffs[sample_seq, :, :, :].reshape(-1)
normalized_preds = (sample_pred_diffs - sample_mean_diffs) / sample_std_errs

res = stats.probplot(normalized_preds[:, :].reshape(-1), plot=plt)

## TF-CA Relationship and Mutation Effect Exploration 

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
plt.suptitle("Predictive Mean Diffs (Mut vs. Ref)")

ax1.hist(
    (mean_diffs[:25, :, :, 1].reshape(-1),
     mean_diffs[25:, :, :, 1].reshape(-1)),
    label=("binding", "no binding"),
    log=True)
ax1.set_title(f"TF {relevant_cols[1][1]}")
ax1.legend()
ax2.hist(
    (mean_diffs[:25, :, :, 0].reshape(-1),
     mean_diffs[25:, :, :, 0].reshape(-1)),
    label=("binding", "no binding"),
    log=True)
ax2.set_title(f"DNase {relevant_cols[2][1]}")
ax2.legend()

fig.text(0.5, 0.08, "Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "# of Seqs", va='center', rotation='vertical')


plt.show();

In [None]:
TF_COL = 1
CA_COL = 0

fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=(14, 10))
plt.suptitle("S.E. vs. Predictive Mean Diff (TF)")

for seq in range(n_seqs):
    ax11.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        stderrs[seq, :, :, TF_COL].reshape(-1))
    ax11.set_title(f"{relevant_cols[1][1]} (binding)")

for seq in range(n_seqs // 2, n_seqs):
    ax12.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        stderrs[seq, :, :, TF_COL].reshape(-1))
    ax12.set_title(f"{relevant_cols[1][1]} (no binding)")
    
for seq in range(n_seqs):
    ax21.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        stderrs[seq, :, :, TF_COL].reshape(-1))
    ax21.set_title(f"{relevant_cols[0][1]} (binding)")

for seq in range(n_seqs // 2, n_seqs):
    ax22.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        stderrs[seq, :, :, TF_COL].reshape(-1))
    ax22.set_title(f"{relevant_cols[0][1]} (no binding)")
    
fig.text(0.5, 0.08, "Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "Predictive S.E.", va='center', rotation='vertical')

plt.show()

In [None]:
n_binding_seqs = 25
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
plt.suptitle("CA predictive mean diff vs. TF predictive mean diff")
for seq in range(n_binding_seqs):
    title, ax = ("(binding)", ax1) if seq < n_binding_seqs else ("(no binding)", ax1)
    ax1.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        mean_diffs[seq, :, :, CA_COL].reshape(-1), 
        label=seq)
    if seq < n_binding_seqs: ax1.set_title("(binding)")
    ax1.legend()

for seq in range(n_binding_seqs, n_seqs):
    ax2.scatter(
        mean_diffs[seq, :, :, TF_COL].reshape(-1), 
        mean_diffs[seq, :, :, CA_COL].reshape(-1), 
        label=seq)
    ax2.set_title("(no binding)")
    ax2.legend()
fig.text(0.5, 0.08, "TF Predictive Mean Diff", ha="center")
fig.text(0.07, 0.5, "CA Predictive Mean Diff", va='center', rotation='vertical')

plt.show()

In [None]:
from matplotlib import patches
from matplotlib import cm

cols, margin = 3, 10 # margin determined empirically
fig, axs = plt.subplots(math.ceil(n_seqs / float(cols)), cols, figsize=(16, n_seqs + margin))


for i in range(n_seqs):
    sample_mut_preds = original_mut_preds[:, i, :, :, TF_COL]
    sample_ref_preds = original_ref_preds[:, i, :, :, TF_COL].repeat(3, axis=1)
    colors = (
        (sample_mut_preds.ravel() - sample_mut_preds.mean())**2 + 
        (sample_ref_preds.ravel() - sample_ref_preds.mean())**2
    )
    ax = axs[i // cols, i % cols]
    ax.hexbin(
        sample_ref_preds.ravel(), 
        sample_mut_preds.ravel(), 
        C=colors,
        cmap=cm.jet,
        bins=None,
    )
    xq1, xq2 = np.quantile(sample_ref_preds, (.25, .75))
    yq1, yq2 = np.quantile(sample_mut_preds, (.25, .75))
    rect = patches.Rectangle((xq1, yq1), xq2 - xq1, yq2 - yq1, fill=False, edgecolor='black')
    rect = ax.add_patch(rect)
    xlabel = "seq {i} ref (std: {stddev:.3f})".format(i=i, stddev=np.std(sample_ref_preds))
    ylabel = "seq {i} mut (std: {stddev:.3f})".format(i=i, stddev=np.std(sample_mut_preds))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

In [None]:
import scipy
from scipy import stats
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std



cols, margin = 3, 20 # margin determined empirically
fig, axs = plt.subplots(math.ceil(n_seqs / float(cols)), cols, figsize=(16, n_seqs + margin))
# plt.suptitle("CA predictive mean diff vs. TF predictive mean diff")
slopes = []
rsquareds = []

n_binding_seqs = 25
for seq in range(n_seqs):
    ax = axs[seq // cols, seq % cols]
    
    x, y = mean_diffs[seq, :, :, TF_COL].reshape(-1), mean_diffs[seq, :, :, CA_COL].reshape(-1)
    xc = sm.add_constant(x)
    model = sm.OLS(y, xc)
    result = model.fit()
    intercept, slope = result.params
    slopes.append(slope)
    rsquared = result.rsquared
    rsquareds.append(rsquared)
    stderr = result.bse[1]
    
    title = "%d - " % (seq)
    title += " (b) " if seq < n_binding_seqs else " (nb)"
    title += "(slope: %.2f, r^2: %.3f, std: %.4f)" % (slope, rsquared, stderr)
    ax.set_title(title)
    
    line = slope*x+intercept
    prstd, iv_l, iv_u = wls_prediction_std(result)
    ax.plot(x, line, 'r')
    ax.plot(x, y, 'o')
    ax.plot(x, iv_u, 'r--')
    ax.plot(x, iv_l, 'r--')
    legend = ax.legend(loc="best")

    ax1.legend()

    
plt.show()

In [None]:
print(sum(slopes) / len(slopes))
print(sum(rsquareds) / len(rsquareds))

In [None]:
IDX_TO_NT = 'ACGT'

def _convert_to_mutation(pos_nt_pair):
    return "%d%s" % (pos_nt_pair[0], IDX_TO_NT[pos_nt_pair[1]])


TF_COL = 1
CA_COL = 0

def _write_row(writer, seq_idx, x_eff_size, y_eff_size, x_stderr, y_stderr):
    writer.writerow(
        {
            "seq_num": seq_idx + 1,
            "X_pred_mean": x_eff_size,
            "X_pred_var": x_stderr,
            "Y_pred_mean": y_eff_size,
            "Y_pred_var": y_stderr,
        }
    )
        

with open("../dat/means_and_uncertainties_new_code_2.csv", 'w', newline="") as out_file:
    fieldnames = [
        "seq_num",
        "mut",
        "X_pred_mean",
        "X_pred_var",
        "Y_pred_mean",
        "Y_pred_var",
    ]
    writer = csv.DictWriter(out_file, delimiter=",", fieldnames=fieldnames)
    writer.writeheader()
    
    for seq_idx in range(n_seqs // 2):
        for seq_pos in range(seq_len):
            for nt_pos in range(n_nts-1):
                x_eff_size = mean_diffs[seq_idx, nt_pos, seq_pos, TF_COL]
                y_eff_size = mean_diffs[seq_idx, nt_pos, seq_pos, CA_COL]
                x_stderr = stderrs[seq_idx, nt_pos, seq_pos, TF_COL]
                y_stderr = stderrs[seq_idx, nt_pos, seq_pos, CA_COL]
                _write_row(writer, seq_idx, x_eff_size, y_eff_size, x_stderr, y_stderr)
 
