In [None]:
import importlib

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

import tensorflow as tf
from Modules import utils, tf_utils
from Modules.tf_utils import mae_cor, correlate

In [None]:
# Reload the local helper modules so that edits made on disk are picked up
# by the running kernel without a restart.
for _module in (utils, tf_utils):
    importlib.reload(_module)
# Colours of matplotlib's current property cycle; later cells index or zip
# over `colors` to colour-match related curves.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

Models

In [None]:
# Every saved model is deserialised with the same two custom objects
# (`correlate` and `mae_cor`, imported from Modules.tf_utils).
custom_objects = {'correlate': correlate, 'mae_cor': mae_cor}
trained_models_dir = '/home/alex/shared_folder/SCerevisiae/Trainedmodels'

# with tf.distribute.MirroredStrategy().scope():
model_pol_name = 'model_myco_pol_17'
model_pol = tf.keras.models.load_model(
    f'{trained_models_dir}/{model_pol_name}/model',
    custom_objects=custom_objects)

model_nuc_name = 'model_myco_nuc_2'
model_nuc = tf.keras.models.load_model(
    f'{trained_models_dir}/{model_nuc_name}/model',
    custom_objects=custom_objects)

model_coh_name = 'model_myco_coh_14'
model_coh = tf.keras.models.load_model(
    f'{trained_models_dir}/{model_coh_name}/model',
    custom_objects=custom_objects)

# The RNA model is stored as a single .hdf5 file in a different directory.
model_rna_name = 'weight_CNN_RNA_seq_2001_12_8_4_SRR7131299' # order 'ATGC'
model_rna = tf.keras.models.load_model(
    f'/home/alex/shared_folder/JB_seqdes/{model_rna_name}.hdf5',
    custom_objects=custom_objects)
# Displayed as the cell output.
model_rna.input_shape

In [None]:
# Load the genome archive and keep only the arrays whose key starts with 'chr'.
genome_file = '/home/alex/shared_folder/SCerevisiae/genome/W303_Mmmyco.npz'
with np.load(genome_file) as genome:
    one_hot_yeast = {name: genome[name]
                     for name in genome.keys()
                     if name.startswith('chr')}
print(list(one_hot_yeast))

# kinetic Monte-Carlo

Flanking regions file

In [None]:
# flanks = {'left': [], 'right': [], 'pos': []}
# for k, v in one_hot_yeast.items():
#     pos = np.random.randint(1000, len(v)-1000)
#     window = v[pos - 1000:pos+1000]
#     assert len(window) == 2000
#     assert window.sum() == 2000
#     window = np.argmax(window, axis=-1)
#     flanks['left'].append(window[:1000])
#     flanks['right'].append(window[1000:])
#     flanks['pos'].append(pos)
# for k, v in flanks.items():
#     flanks[k] = np.array(v)
# np.savez('/home/alex/shared_folder/SCerevisiae/genome/W303_Mmmyco_random1kbflanks_ACGTidx.npz', **flanks)

Analysing experiments

In [None]:
from kMC_sequence_design import rmse, GC_energy

def energy_parser_v2(file):
    """Parse an energy log into an array of shape (n_seqs, n_blocks, n_columns).

    The file is read as whitespace-separated numeric rows. Lines whose first
    non-blank character is '#' are comments, and blank lines are ignored.
    The number of data rows in the first block (the rows found before the
    first comment line that follows some data) is taken as `n_seqs`; every
    block is assumed to hold the same number of rows.
    Presumably each block is one kMC step and each row one sequence —
    confirm against the code that writes energy.txt.

    Parameters
    ----------
    file : str or path-like
        Path of the energy text file.

    Returns
    -------
    numpy.ndarray
        Array indexed as [row within block, block, column].

    Raises
    ------
    ValueError
        If the file contains no data row, or if the total number of rows is
        not a multiple of the size of the first block.
    """
    # Read once; the same lines feed both the block-size count and np.loadtxt.
    with open(file, 'r') as f:
        lines = f.readlines()

    n_seqs = 0
    for line in lines:
        stripped = line.strip()
        if not stripped:
            # np.loadtxt skips blank lines, so they must not count as rows.
            continue
        if stripped.startswith('#'):
            if n_seqs != 0:
                # First comment after some data: the first block is complete.
                break
        else:
            n_seqs += 1
    if n_seqs == 0:
        raise ValueError(f'no energy rows found in {file}')

    # ndmin=2 keeps a (rows, columns) layout even for a single row or column.
    energies = np.loadtxt(lines, ndmin=2)
    return np.transpose(energies.reshape(-1, n_seqs, energies.shape[1]), [1, 0, 2])

Energy

In [None]:
# Energy trajectories (column 0 of the energy file) for the first two
# sequences of each experiment, one colour per temperature.
temps = ['2e-4', '1e-4', '5e-5']
exp_ids = [f'lowpol_4kb_temp{t}' for t in temps]# + [f'highpol_4kb_temp{t}_mid' for t in temps] # [f'test{n}' for n in range(90, 91)]
exp_labels = exp_ids # [f'4kb_temp{t}_mid' for t in temps]
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 5), facecolor='w', layout='tight')#, sharey=True)
for exp_id, exp_lab, color in zip(exp_ids, exp_labels, colors):
    # if exp_id.startswith('highpol') and exp_id[-8:-4] in ['1e-4', '5e-5']:
    #     continue
    energy_file = f'/home/alex/shared_folder/SCerevisiae/generated/{exp_id}/energy.txt'
    energies = energy_parser_v2(energy_file)
    print(energies.shape)
    # Transposed so that matplotlib draws one curve per sequence.
    first_two = energies[:2, :, 0]
    ax.plot(first_two.T, label=exp_lab, color=color)
ax.legend()
# ax.set_ylim((-0.001, 2))

In [None]:
# Energy trajectory (column 0) of the first sequence of experiments
# test49..test56, labelled with the wmut value listed at the same position.
wmuts = [0.01, 0.005, 0.002, 0.001, 0.0005, 0.0001, 0.00005, 0.00001]
exp_ids = [f'test{n}' for n in range(49, 57)]
exp_labels = [f'5kb_full_temp5e-4_wmut{w}' for w in wmuts]
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 5), facecolor='w', layout='tight')#, sharey=True)
for exp_id, exp_lab, color in zip(exp_ids, exp_labels, colors):
    energy_file = f'/home/alex/shared_folder/SCerevisiae/generated/{exp_id}/energy.txt'
    energies = energy_parser_v2(energy_file)
    print(energies.shape)
    first_seq = energies[:1, :, 0]
    ax.plot(first_seq.T, label=exp_lab, color=color)
ax.legend()

Probabilities

In [None]:
# Distribution of the mutation probabilities saved for one step of one experiment.
exp_name = 'lowpol_4kb_temp2e-4'
i = 0
probs = np.load(f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/probs/prob_step{i}.npy')
print(probs.shape)
# Sort each row in ascending order so that the last column holds the highest
# probability of that row and the one before it the second highest.
probs = np.sort(probs, axis=1)
fig, axes = plt.subplots(4, 1, figsize=(20, 10), facecolor='w', layout='tight', sharey=True, sharex=True)
# The title is built from `exp_name` and `i` so it always matches the data
# that was loaded (the previous hard-coded text quoted a temperature of
# 0.00005 while the experiment name says temp2e-4).
fig.suptitle(f'probability of each mutation during step {i} of {exp_name}')
axes[0].hist(probs[0], bins=100, range=(0, 1), label='probabilities on first seq')
axes[1].hist(probs.ravel(), bins=100, range=(0, 1), label='probabilities on all seqs')
axes[2].hist(probs[:, -1], bins=100, range=(0, 1), label='highest probability on all seqs')
axes[3].hist(probs[:, -2], bins=100, range=(0, 1), label='second highest probability on all seqs')
for ax in axes:
    # Log counts with a floor below 1 so that bins holding a single value stay visible.
    ax.set_yscale('log')
    ax.set_ylim(bottom=5e-1)
    ax.legend()
plt.show()

In [None]:
# Load the probability array of every step and stack them along a new axis 1.
# The number of steps is taken from axis 1 of the parsed energy file.
exp_name = 'lowpol_4kb_temp2e-4'
energies = energy_parser_v2(f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/energy.txt')
n_steps = energies.shape[1]
probs = np.stack(
    [np.load(f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/probs/prob_step{step}.npy')
     for step in range(n_steps)],
    axis=1)
print(probs.shape)
# Ascending sort along the last axis: index -1 is then the highest probability.
probs = np.sort(probs, axis=-1)

In [None]:
# Energy (column 0) and highest mutation probability over the steps,
# for the sequences with index in [start, stop).
start = 0
stop = 2
seq_labels = [f'seq{i}' for i in range(start, stop)]
fig, axes = plt.subplots(2, 1, figsize=(20, 7), facecolor='w', layout='tight', sharex=True)
panels = (
    (axes[0], energies[start:stop, :, 0]),
    (axes[1], probs[start:stop, :, -1]),
)
for ax, values in panels:
    # Transposed so that each sequence is drawn as its own curve.
    ax.plot(values.T, label=seq_labels)
for ax in axes:
    ax.legend()
plt.show()

Designed sequences

In [None]:
# Stack the starting sequences and the sequences saved after each of the
# 500 steps along a new axis 1 (index 0 = start, index k = after step k-1).
exp_name = 'lowpol_4kb_temp2e-4'
seqs = [np.load(f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/designed_seqs/start_seqs.npy')]
seqs.extend(
    np.load(f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/designed_seqs/mut_seqs_step{step}.npy')
    for step in range(500))
seqs = np.stack(seqs, axis=1)
print(seqs.shape)

In [None]:
# start_seqs = np.load(f'/home/alex/shared_folder/SCerevisiae/generated/test9/start_seqs.npy')
# np.save('/home/alex/shared_folder/SCerevisiae/generated/test9/start_seqs_first1.npy', start_seqs[:1])

In [None]:
# Read the 'left' and 'right' flank arrays from the flanks archive.
flanks_file = '/home/alex/shared_folder/SCerevisiae/genome/W303_Mmmyco_random1kbflanks_ACGTidx.npz'
with np.load(flanks_file) as flanks_npz:
    flank_left, flank_right = flanks_npz['left'], flanks_npz['right']

Repredicting

In [None]:
# Refresh the local modules and then re-import the helpers used below.
# The reloads come first so that the `from ... import` lines bind the
# functions of the freshly reloaded modules rather than stale ones.
import kMC_sequence_design
importlib.reload(tf_utils)
importlib.reload(kMC_sequence_design)
from kMC_sequence_design import get_profile_hint, get_profile_mid1
from Modules.tf_utils import get_profile

In [None]:
# Re-predict the designed sequences with the pol model, once as given and
# once with reverse=True, then recompute the energy terms from them.
hint_args = (seqs, model_pol, 2048, 128)
preds = get_profile_hint(*hint_args, middle=True)
preds_rev = get_profile_hint(*hint_args, reverse=True, middle=True)
# preds = get_profile_mid1(seqs, model_rna, 2001, one_hot_converter=lambda x: np_idx_to_one_hot(x, order='ATGC', extradims=-1))
# preds_rev = get_profile_mid1(seqs, model_rna, 2001, reverse=True, one_hot_converter=lambda x: np_idx_to_one_hot(x, order='ATGC', extradims=-1))
# rmse is evaluated against an all-zero array shaped like `preds` but with a
# last axis of length 1.
zeros_shape = preds.shape[:-1] + (1,)
loss = rmse(np.zeros(zeros_shape), preds)
loss_rev = rmse(np.zeros(zeros_shape), preds_rev)
gc_energy = GC_energy(seqs, 0.3834)

In [None]:
# Compare the logged energy (column 0, first two sequences) with the energy
# recomputed as gc_energy + loss + loss_rev for every sequence.
energy_file = f'/home/alex/shared_folder/SCerevisiae/generated/{exp_name}/energy.txt'
energies = energy_parser_v2(energy_file)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 5), facecolor='w', layout='tight')
logged_labels = [f'seq{i}' for i in range(2)]
ax.plot(energies[:2, :, 0].T, label=logged_labels)
for idx in range(len(seqs)):
    # Index 0 along axis 1 is skipped; in `seqs` that index holds the start sequences.
    recomputed = gc_energy[idx, 1:] + loss[idx, 1:] + loss_rev[idx, 1:]
    ax.plot(recomputed, label=f'seq{idx}_nostride')
ax.legend()

In [None]:
# Predicted profiles of the first sequence at every 100th index of axis 1;
# the two curves of one step (preds and preds_rev) share a colour.
plotted_steps = range(0, 501, 100)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 5), facecolor='w')
for i, color in zip(plotted_steps, colors):
    forward_profile = preds[0, i, :]
    reverse_profile = preds_rev[0, i, :]
    ax.plot(forward_profile, label=f'step{i}', color=color)
    ax.plot(reverse_profile, color=color)
ax.legend()

Testing predict functions

In [None]:
# Draw n random sequences of length l from the first column of the 3-mer
# frequency table, with a fixed NumPy seed.
# NOTE(review): this rebinds `seqs`, replacing the designed sequences loaded
# in the "Designed sequences" section.
n, l = 100, 2175
kmer_freq_file = '/home/alex/shared_folder/SCerevisiae/genome/W303/W303_3mer_freq.csv'
freq_kmers = pd.read_csv(kmer_freq_file, index_col=[0, 1, 2])
first_freq_column = freq_kmers.iloc[:, 0]
np.random.seed(0)
seqs = utils.random_sequences(n, l, first_freq_column, out='idx')

In [None]:
# Keep every 100th entry along axis 1 of `seqs`.
# NOTE(review): the later `[seq_idx, i, :]` indexing suggests `seqs` is the
# 3-D stack of designed sequences here, not the random sequences created in
# the preceding cell — confirm the intended run order.
to_predict = seqs[:, ::100]
print(to_predict.shape)
hint_args = (to_predict, model_pol, 2048, 128)
# Predictions without flanks, for both values of `reverse`.
preds = get_profile_hint(*hint_args, middle=True, flanks=None)
preds_rev = get_profile_hint(*hint_args, middle=True, flanks=None, reverse=True)
# Same predictions repeated once per (left, right) flank pair.
preds_flanks, preds_flanks_rev = [], []
for i in range(len(flank_left)):
    flanks = (flank_left[i], flank_right[i])
    preds_flanks.append(
        get_profile_hint(*hint_args, middle=True, flanks=flanks))
    preds_flanks_rev.append(
        get_profile_hint(*hint_args, middle=True, flanks=flanks, reverse=True))
print(preds.shape, preds_rev.shape)
print(preds_flanks[0].shape, preds_flanks_rev[0].shape)

In [None]:
# One panel per flank pair (up to 16): predictions of sequence `seq_idx` at
# indices 0 and 5 of axis 1; the reverse prediction is drawn negated.
seq_idx = 0
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(20, 12), facecolor='w', layout='tight', sharex=True, sharey=True)
for i in (0, 5):
    # to_predict keeps every 100th entry, so index i is labelled as step i*100.
    step = i*100
    color = colors[i]
    for ax, pfor, prev in zip(axes.flat, preds_flanks, preds_flanks_rev):
        forward_profile = pfor[seq_idx, i, :]
        reverse_profile = prev[seq_idx, i, :]
        ax.plot(forward_profile, color=color, label=f'step{step}')
        ax.plot(-reverse_profile, color=color)
axes[0, 0].legend()

In [None]:
def atgc_one_hot(x):
    """One-hot encode index sequences with base order 'ATGC' for the RNA model."""
    return utils.np_idx_to_one_hot(x, order='ATGC', extradims=-1)


# Keep every 100th entry along axis 1 of `seqs`, then predict with the three
# models, each time with and without reverse=True.
to_predict = seqs[:, ::100]
print(to_predict.shape)
preds_pol = get_profile_hint(to_predict, model_pol, 2048, 128)#, middle=False)
preds_pol_rev = get_profile_hint(to_predict, model_pol, 2048, 128, reverse=True)#, middle=False)
preds_nuc = get_profile(to_predict, model_nuc, 2001)
preds_nuc_rev = get_profile(to_predict, model_nuc, 2001, reverse=True)
# The RNA model gets its own one-hot converter with base order 'ATGC'.
preds_rna = get_profile(to_predict, model_rna, 2001, one_hot_converter=atgc_one_hot)
preds_rna_rev = get_profile(to_predict, model_rna, 2001, reverse=True, one_hot_converter=atgc_one_hot)

In [None]:
# Three stacked tracks (pol, nuc, rna) for sequence `seq_idx` at indices 0
# and 5 of axis 1; reverse predictions are drawn negated.
length = to_predict.shape[-1]
seq_idx = 0

# x coordinates: the pol track is offset by 128*4 from the matching end of
# the sequence, the nuc and rna tracks leave 1000 positions at both ends.
pol_offset = 128*4
pol_len = preds_pol.shape[-1]
x_pol = np.arange(pol_offset, pol_offset + pol_len)
x_pol_rev = np.arange(length - pol_offset - pol_len, length - pol_offset)
x_mid = np.arange(1000, length - 1000)
tracks = [
    (x_pol, preds_pol, x_pol_rev, preds_pol_rev),
    (x_mid, preds_nuc, x_mid, preds_nuc_rev),
    (x_mid, preds_rna, x_mid, preds_rna_rev),
]

fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(20, 9), facecolor='w', layout='tight', sharex=True)
for i in range(0, 6, 5):
    step = i*100
    color = colors[i]
    for ax, (x_fwd, fwd, x_rev, rev) in zip(axes, tracks):
        ax.plot(x_fwd, fwd[seq_idx, i, :], color=color, label=f'step{step}')
        ax.plot(x_rev, -rev[seq_idx, i, :], color=color)
for ax in axes:
    ax.legend()

# Saliency

In [None]:
def get_gradients(model, one_hots, batch_size=1024, predict=False, head_start=0, n_heads=1):
    """Gradient of selected model outputs with respect to the input windows.

    The windows are processed in batches. For each batch the model is called
    with ``training=False`` and only the outputs
    ``[:, head_start:head_start + n_heads]`` are kept. The tape then
    differentiates that (non-scalar) slice with respect to the batch input;
    per the TensorFlow documentation this is the gradient of the sum of the
    selected outputs.

    Parameters
    ----------
    model : callable
        Called as ``model(X, training=False)``; its result must support
        ``[:, a:b]`` slicing.
    one_hots : array-like
        Input windows; the first axis is iterated in batches.
    batch_size : int, default 1024
        Number of windows per batch.
    predict : bool, default False
        If True, also return the selected model outputs.
    head_start : int, default 0
        Index of the first output to keep.
    n_heads : int, default 1
        Number of consecutive outputs to keep.

    Returns
    -------
    grads : numpy.ndarray
        float32 gradients shaped like `one_hots`, with every length-1 axis
        squeezed out (including the first axis when there is one window).
    preds : numpy.ndarray or None
        Only returned when `predict` is True. float32 outputs with one row
        per window; length-1 axes other than the first are dropped.
        None if `one_hots` is empty.
    """
    grads = np.empty(one_hots.shape, dtype='float32')
    n_batches = int(np.ceil(len(one_hots) / batch_size))
    preds = None
    for i in range(n_batches):
        batch_start, batch_stop = i*batch_size, (i+1)*batch_size
        X = tf.Variable(one_hots[batch_start:batch_stop], dtype=tf.float32)
        with tf.GradientTape() as tape:
            Y = model(X, training=False)[:, head_start:head_start+n_heads]
        grads[batch_start:batch_stop] = np.array(tape.gradient(Y, X))
        if predict:
            Y = np.array(Y)
            # Drop length-1 axes but always keep the batch axis. A plain
            # squeeze() also removed the batch axis of a one-window batch,
            # which broke the allocation below for a single window.
            Y = Y.reshape((len(Y),) + tuple(d for d in Y.shape[1:] if d != 1))
            if preds is None:
                preds = np.empty((len(one_hots),) + Y.shape[1:],
                                 dtype='float32')
            preds[batch_start:batch_stop] = Y
    if predict:
        return grads.squeeze(), preds
    return grads.squeeze()

In [None]:
# Load the saved per-chromosome predictions of model_myco_pol_17, keeping
# only the keys that start with 'chr'. This rebinds `preds_pol` to a dict.
pol_preds_file = '/home/alex/shared_folder/SCerevisiae/results/model_myco_pol_17/preds_mid_on_W303_Mmmyco.npz'
with np.load(pol_preds_file) as pol_npz:
    preds_pol = {name: pol_npz[name]
                 for name in pol_npz.keys()
                 if name.startswith('chr')}

In [None]:
# Region of interest and comparison plot between the saved reference
# predictions and the predictions obtained during the gradient computation.
# NOTE(review): `preds_t` is only assigned in a later cell
# (`preds_t = preds.T.ravel()`), and that cell in turn needs `chr_id`,
# `start` and `stop` from this one. On a fresh kernel this cell therefore
# fails with a NameError at the second ax.plot call; the plotting part has to
# be re-run after the gradient cell.
chr_id = 'chrIV'
start = 38912
stop = 40960
# Number of positions between two consecutive dashed separators drawn below.
midlen = 128*8
fig, ax = plt.subplots(1, 1, figsize=(20, 5), facecolor='w', layout='tight')
ax.plot(np.arange(start, stop), preds_pol[chr_id][start:stop], label='ref_preds')
# preds_t is drawn on the region trimmed by midlen//2 positions at each end.
ax.plot(np.arange(start + midlen//2, stop - midlen//2), preds_t, label='grad_pred')
# First separator position after `start`, then one dashed vertical line every
# `midlen` positions up to `stop`.
sep = start + midlen - (start - midlen//2) % midlen
while sep < stop:
    ax.axvline(sep, color='k', linestyle='--')
    sep += midlen
ax.legend()

In [None]:
# Cut the selected region of the one-hot chromosome into windows of
# `winsize` positions and flatten them to shape (n_windows, winsize, 4).
winsize = 2048
head_interval = 128
# The slice runs head_interval-1 positions past `stop`.
region = one_hot_yeast[chr_id][start:stop+head_interval-1]
windows = utils.strided_sliding_window_view(
    region, (winsize, 4), winsize//2, head_interval)
one_hots = windows.reshape(-1, winsize, 4)
# Displayed as the cell output.
one_hots.shape

In [None]:
# Gradients and predictions of the pol model on the windows, keeping the
# eight outputs starting at index 4.
grads, preds = get_gradients(
    model_pol,
    one_hots,
    predict=True,
    head_start=4,
    n_heads=8,
)
# Flatten the transposed predictions into one 1-D array.
preds_t = preds.transpose().reshape(-1)
print(preds.shape)

In [None]:
print(grads.shape)
# Remove, at every position, the mean of the gradient over the last axis
# (the four base channels).
grads_proj = grads - grads.mean(axis=-1, keepdims=True)
print(grads_proj.shape)
# Keep, at every position, the gradient of the base actually present.
# The mask is cast to bool explicitly: indexing with an integer or float
# one-hot array would be interpreted as fancy indexing instead of a mask.
# The reshape requires exactly one base set per position.
base_mask = one_hots.astype(bool)
grads_seq = grads[base_mask].reshape(grads.shape[:-1])
print(grads_seq.shape)
grads_proj_seq = grads_proj[base_mask].reshape(grads.shape[:-1])
print(grads_proj_seq.shape)

In [None]:
# Heatmaps of the gradients summed over the windows (axis 0): raw gradients
# on top, mean-centred gradients below, one row per base channel.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(20, 8), facecolor='w', layout='tight')
# sns.heatmap(grads_proj[0].T, ax=ax, cmap='seismic', center=0)
vmin = -2
heatmap_kwargs = dict(cmap='seismic', center=0, vmin=vmin, yticklabels=list('ACGT'))
summed_grads = grads.sum(axis=0)
summed_grads_proj = grads_proj.sum(axis=0)
sns.heatmap(summed_grads.T, ax=axes[0], **heatmap_kwargs)# cbar_ax=axes[0, 1],
sns.heatmap(summed_grads_proj.T, ax=axes[1], **heatmap_kwargs)# cbar_ax=axes[0, 1],