# Library loading

In [None]:
import torch
import torch.utils.data

import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

from network.data_utils import clean_data, process_data, encode_seqs, SingleTensorDataset
from network.nn import train_nn, make_predictions

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Reading in data

In [None]:
# Read the Cbf1 replicate 1 counts file through clean_data; the result is
# pooled into `cbf1` and passed to process_data further down
cbf1_rep1 = clean_data('data/cbf1_rep1_counts.txt')

In [None]:
# Read the Cbf1 replicate 2 counts file through clean_data; the result is
# pooled into `cbf1` and passed to process_data further down
cbf1_rep2 = clean_data('data/cbf1_rep2_counts.txt')

In [None]:
# Read the Cbf1 replicate 3 counts file through clean_data; the result is
# pooled into `cbf1` and passed to process_data further down
cbf1_rep3 = clean_data('data/cbf1_rep3_counts.txt')

In [None]:
# Read the Pho4 replicate 1 counts file through clean_data; the result is
# pooled into `pho4` and passed to process_data further down
pho4_rep1 = clean_data('data/pho4_rep1_counts.txt')

In [None]:
# Read the Pho4 replicate 2 counts file through clean_data; the result is
# pooled into `pho4` and passed to process_data further down
pho4_rep2 = clean_data('data/pho4_rep2_counts.txt')

In [None]:
# Read the Pho4 replicate 3 counts file through clean_data; the result is
# pooled into `pho4` and passed to process_data further down
pho4_rep3 = clean_data('data/pho4_rep3_counts.txt')

In [None]:
# Read the Pho4 replicate 4 counts file through clean_data; the result is
# pooled into `pho4` and passed to process_data further down
pho4_rep4 = clean_data('data/pho4_rep4_counts.txt')

In [None]:
# Pool the per-replicate tables into one combined dataset each for
# Cbf1 (3 replicates) and Pho4 (4 replicates)
cbf1 = pd.concat([
    cbf1_rep1,
    cbf1_rep2,
    cbf1_rep3,
])
pho4 = pd.concat([
    pho4_rep1,
    pho4_rep2,
    pho4_rep3,
    pho4_rep4,
])

## Create training dataset loaders

In [None]:
# Split Cbf1 replicate 1 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
cbf1_rep1_train, cbf1_rep1_val, cbf1_rep1_test = process_data(
    cbf1_rep1,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Cbf1 replicate 2 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
cbf1_rep2_train, cbf1_rep2_val, cbf1_rep2_test = process_data(
    cbf1_rep2,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Cbf1 replicate 3 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
cbf1_rep3_train, cbf1_rep3_val, cbf1_rep3_test = process_data(
    cbf1_rep3,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Pho4 replicate 1 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
pho4_rep1_train, pho4_rep1_val, pho4_rep1_test = process_data(
    pho4_rep1,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Pho4 replicate 2 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
pho4_rep2_train, pho4_rep2_val, pho4_rep2_test = process_data(
    pho4_rep2,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Pho4 replicate 3 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
pho4_rep3_train, pho4_rep3_val, pho4_rep3_test = process_data(
    pho4_rep3,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split Pho4 replicate 4 into train / validation / test sets using the
# sequence-list files shared by every dataset in this notebook
# (256 is presumably the batch size -- confirm against process_data)
pho4_rep4_train, pho4_rep4_val, pho4_rep4_test = process_data(
    pho4_rep4,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split the pooled Cbf1 dataset (all replicates) into train / validation /
# test sets using the same sequence-list files as the per-replicate splits
# (256 is presumably the batch size -- confirm against process_data)
cbf1_train, cbf1_val, cbf1_test = process_data(
    cbf1,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

In [None]:
# Split the pooled Pho4 dataset (all replicates) into train / validation /
# test sets using the same sequence-list files as the per-replicate splits
# (256 is presumably the batch size -- confirm against process_data)
pho4_train, pho4_val, pho4_test = process_data(
    pho4,
    256,
    "./data/train_seqs.txt",
    "./data/val_seqs.txt",
    "./data/test_seqs.txt",
)

## Create complete dataset loaders

In [None]:
# Read every sequence to score, one per line with surrounding whitespace
# stripped, from the file listing all possible 10-mer sequences
with open('data/predict_input.txt', 'r') as predict_file:
    all_sequences = [line.strip() for line in predict_file]

# Encode the sequences and wrap them in a DataLoader. shuffle=False keeps
# the batch order fixed, so predictions line up with all_sequences when the
# output tables are assembled later.
seqs = encode_seqs(all_sequences)
all_seqs_dataset = SingleTensorDataset(seqs)
prediction_loader = torch.utils.data.DataLoader(
    all_seqs_dataset,
    batch_size=256,
    shuffle=False,
)

# Training individual models

In [None]:
# Fit a model on Cbf1 replicate 1, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
cbf1_rep1_model = train_nn(
    cbf1_rep1_train,
    cbf1_rep1_val,
    prediction_loader,
    './models/cbf1_rep1_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    cbf1_rep1_model.state_dict(),
    './models/cbf1_rep1_model_params.torch',
)

In [None]:
# Fit a model on Cbf1 replicate 2, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
cbf1_rep2_model = train_nn(
    cbf1_rep2_train,
    cbf1_rep2_val,
    prediction_loader,
    './models/cbf1_rep2_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    cbf1_rep2_model.state_dict(),
    './models/cbf1_rep2_model_params.torch',
)

In [None]:
# Fit a model on Cbf1 replicate 3, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
cbf1_rep3_model = train_nn(
    cbf1_rep3_train,
    cbf1_rep3_val,
    prediction_loader,
    './models/cbf1_rep3_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    cbf1_rep3_model.state_dict(),
    './models/cbf1_rep3_model_params.torch',
)

In [None]:
# Fit a model on Pho4 replicate 1, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
pho4_rep1_model = train_nn(
    pho4_rep1_train,
    pho4_rep1_val,
    prediction_loader,
    './models/pho4_rep1_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    pho4_rep1_model.state_dict(),
    './models/pho4_rep1_model_params.torch',
)

In [None]:
# Fit a model on Pho4 replicate 2, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
pho4_rep2_model = train_nn(
    pho4_rep2_train,
    pho4_rep2_val,
    prediction_loader,
    './models/pho4_rep2_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    pho4_rep2_model.state_dict(),
    './models/pho4_rep2_model_params.torch',
)

In [None]:
# Fit a model on Pho4 replicate 3, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
pho4_rep3_model = train_nn(
    pho4_rep3_train,
    pho4_rep3_val,
    prediction_loader,
    './models/pho4_rep3_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    pho4_rep3_model.state_dict(),
    './models/pho4_rep3_model_params.torch',
)

In [None]:
# Fit a model on Pho4 replicate 4, then save its state_dict for the
# prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
pho4_rep4_model = train_nn(
    pho4_rep4_train,
    pho4_rep4_val,
    prediction_loader,
    './models/pho4_rep4_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    pho4_rep4_model.state_dict(),
    './models/pho4_rep4_model_params.torch',
)

# Training composite models

In [None]:
# Fit the composite Cbf1 model on the pooled replicates, then save its
# state_dict for the prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
cbf1_model = train_nn(
    cbf1_train,
    cbf1_val,
    prediction_loader,
    './models/cbf1_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    cbf1_model.state_dict(),
    './models/cbf1_model_params.torch',
)

In [None]:
# Fit the composite Pho4 model on the pooled replicates, then save its
# state_dict for the prediction step
# NOTE(review): the '.acc' path is presumably an accuracy log written by
# train_nn -- confirm
pho4_model = train_nn(
    pho4_train,
    pho4_val,
    prediction_loader,
    './models/pho4_model.acc',
    lr=1e-3,
    max_epochs=100,
)
torch.save(
    pho4_model.state_dict(),
    './models/pho4_model_params.torch',
)

# Generating full prediction sets

In [None]:
# Score every sequence in prediction_loader with the saved Cbf1 replicate 1
# parameters, then show the distribution of predictions as a 30-bin histogram
cbf1_rep1_preds = make_predictions(
    './models/cbf1_rep1_model_params.torch',
    prediction_loader,
)
plt.hist(cbf1_rep1_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Cbf1 replicate 2
# parameters, then show the distribution of predictions as a 30-bin histogram
cbf1_rep2_preds = make_predictions(
    './models/cbf1_rep2_model_params.torch',
    prediction_loader,
)
plt.hist(cbf1_rep2_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Cbf1 replicate 3
# parameters, then show the distribution of predictions as a 30-bin histogram
cbf1_rep3_preds = make_predictions(
    './models/cbf1_rep3_model_params.torch',
    prediction_loader,
)
plt.hist(cbf1_rep3_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Pho4 replicate 1
# parameters, then show the distribution of predictions as a 30-bin histogram
pho4_rep1_preds = make_predictions(
    './models/pho4_rep1_model_params.torch',
    prediction_loader,
)
plt.hist(pho4_rep1_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Pho4 replicate 2
# parameters, then show the distribution of predictions as a 30-bin histogram
pho4_rep2_preds = make_predictions(
    './models/pho4_rep2_model_params.torch',
    prediction_loader,
)
plt.hist(pho4_rep2_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Pho4 replicate 3
# parameters, then show the distribution of predictions as a 30-bin histogram
pho4_rep3_preds = make_predictions(
    './models/pho4_rep3_model_params.torch',
    prediction_loader,
)
plt.hist(pho4_rep3_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved Pho4 replicate 4
# parameters, then show the distribution of predictions as a 30-bin histogram
pho4_rep4_preds = make_predictions(
    './models/pho4_rep4_model_params.torch',
    prediction_loader,
)
plt.hist(pho4_rep4_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved composite Cbf1
# parameters, then show the distribution of predictions as a 30-bin histogram
cbf1_preds = make_predictions(
    './models/cbf1_model_params.torch',
    prediction_loader,
)
plt.hist(cbf1_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

In [None]:
# Score every sequence in prediction_loader with the saved composite Pho4
# parameters, then show the distribution of predictions as a 30-bin histogram
pho4_preds = make_predictions(
    './models/pho4_model_params.torch',
    prediction_loader,
)
plt.hist(pho4_preds, bins=30, alpha=0.75)
plt.ylabel('Count')
plt.xlabel(r'$\Delta \Delta G$')
plt.show()

# Plot the predicted values

In [None]:
from itertools import combinations

# Hexbin density plots comparing the predictions of every pair of replicate
# models: one figure row per protein, one panel per replicate pair.
replicate_groups = [
    ('Pho4', [pho4_rep1_preds, pho4_rep2_preds, pho4_rep3_preds, pho4_rep4_preds]),
    ('Cbf1', [cbf1_rep1_preds, cbf1_rep2_preds, cbf1_rep3_preds]),
]

# Six columns are needed for the 4-choose-2 Pho4 pairs; Cbf1 only fills three.
fig, axs = plt.subplots(2, 6, sharey=True, figsize=(7, 4))
# NOTE(review): right=2.3 puts the right edge of the grid beyond the figure;
# presumably this relies on the inline backend's tight bounding box to show
# all panels -- confirm before rendering with a different backend.
fig.subplots_adjust(hspace=0.9, left=0.07, right=2.3)

for row_axes, (protein, rep_preds) in zip(axs, replicate_groups):
    panels = iter(row_axes)
    numbered = list(enumerate(rep_preds, start=1))

    # combinations() yields (1,2), (1,3), ..., i.e. the same left-to-right
    # panel order as the hand-written version of this cell
    for (rep_a, preds_a), (rep_b, preds_b) in combinations(numbered, 2):
        ax = next(panels)
        hb = ax.hexbin(preds_a, preds_b, gridsize=50, cmap='inferno')
        ax.set_title('{0} Rep {1} vs\n{0} Rep {2}'.format(protein, rep_a, rep_b))
        cb = fig.colorbar(hb, ax=ax)
        cb.set_label('counts')

    # Hide panels with no replicate pair so they are not drawn as empty frames
    for spare_ax in panels:
        spare_ax.set_visible(False)

plt.show()

# Generate output dataset for all models

In [None]:
# (column name, values) pairs in the order the columns appear in the CSV:
# the sequences first, then each replicate and composite model's predictions
ddg_columns = [
    ('flank', all_sequences),
    ('Cbf1_rep1_ddG', cbf1_rep1_preds),
    ('Cbf1_rep2_ddG', cbf1_rep2_preds),
    ('Cbf1_rep3_ddG', cbf1_rep3_preds),
    ('Cbf1_ddG', cbf1_preds),
    ('Pho4_rep1_ddG', pho4_rep1_preds),
    ('Pho4_rep2_ddG', pho4_rep2_preds),
    ('Pho4_rep3_ddG', pho4_rep3_preds),
    ('Pho4_rep4_ddG', pho4_rep4_preds),
    ('Pho4_ddG', pho4_preds),
]

# Start from an empty frame that fixes the column layout, then fill it
# one column at a time
outdf = pd.DataFrame(columns=[entry[0] for entry in ddg_columns])
for column_name, column_values in ddg_columns:
    outdf[column_name] = column_values

# Write the table without the row index
outdf.to_csv('./results/all_predicted_ddGs.csv', index=False)

# Output only composite models

In [None]:
# Composite-only table: each sequence with the predictions of the two models
# trained on the pooled replicates (cbf1_preds and pho4_preds)
outdf = pd.DataFrame(columns=['flank',
                              'Cbf1_ddG',
                              'Pho4_ddG'])

# Fill in data
outdf['flank'] = all_sequences
outdf['Cbf1_ddG'] = cbf1_preds
outdf['Pho4_ddG'] = pho4_preds

# Save to a dedicated file. This cell previously wrote to
# './results/all_predicted_ddGs.csv', which overwrote the all-models table
# saved by the preceding cell whenever the notebook was run top to bottom.
outdf.to_csv('./results/composite_predicted_ddGs.csv', index=False)