In [67]:
%load_ext autoreload
%autoreload 2
%cd /group/transreg/sathi/DeepDifE 

import pickle
import importlib
import esparto
import optuna
import numpy as np
import pandas as pd
from evoaug_tf import evoaug, augment
from src.diff_expression_model import get_model, get_siamese_model, post_hoc_conjoining, get_auroc
from skopt.utils import use_named_args

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/storage/nas6/group/biocomp/projects/transreg/sathi/DeepDifE


## Prepare the data

For now we start from this pre-built data pickle, as it is not yet clear to me how Helder performed the differential-expression (DE) analysis and generated the labels.

In [68]:
# Pre-built dataset pickle, read from an absolute path under /group/transreg.
# NOTE(review): pickle.load can execute arbitrary code while deserialising, so
# this file must come from a trusted source.
ppath = "/group/transreg/heopd/dlpipe/results/ath/aba/dlresults/predetermined_dataset/dataset_solid_chrome.pkl"
with open(ppath, 'rb') as f:
    data = pickle.load(f)

In [69]:
data.columns

Index(['Category', 'GeneFamily', 'seqs', 'ohs', 'rcohs', 'ohsDuo',
       'in_original_balanced', 'set', 'npshap-single', 'npshap-posthoc'],
      dtype='object')

In [71]:
data["set"].value_counts()

train    1900
valid     257
test      241
Name: set, dtype: int64

To show how to subdivide the dataset into train, validation and test splits, we keep only the following columns.

In [70]:
# Move the index of `data` into a regular column (this is where "geneID" comes
# from), keep only the fields needed to demonstrate the split, and give them the
# names used by the rest of the notebook.
# Written as a single chain without inplace=True so that `dataset` is a fresh
# frame instead of an in-place-mutated column subset.
dataset = (
    data.reset_index()[["geneID", "Category", "GeneFamily", "seqs"]]
    .rename(columns={"geneID": "GeneID", "Category": "Label", "seqs": "Sequence"})
)

In [72]:
dataset

Unnamed: 0,GeneID,Label,GeneFamily,Sequence
0,AT4G27120,0,HOM04D000881,TAGAGAAGACAAGCGGTTATTTCGTAATTTCCCAGCGACTTTGAAA...
1,AT4G19600,0,HOM04D000740,GTCAAGTAGTGAAATCAAGGTGTGAAGTAAGCTGAGGACAGATAAT...
2,AT3G60880,0,HOM04D003119,AGTTGATATTGAATGAAATCTTCATGTTTTTTGATAAATGATTATA...
3,AT5G06960,0,HOM04D000319,CACTTGTCAGATTCTTCTTACCAAATCCATCAACAAATAAGCAAAT...
4,AT1G14890,0,HOM04D000273,TTGATATAACAGATTCAACACTAAAAATGAGTAAAATCTAAAAAAG...
...,...,...,...,...
2393,AT5G64230,1,HOM04D003278,AAGAAAGAAAAACCGTACATAAACACCCATCTGGTATACCATCGTC...
2394,AT5G64780,1,HOM04D002552,TTTTAGAAAGAAGAAGAAGGATTATTGCCTTATTGGTGAAGGGAAG...
2395,AT4G30470,1,HOM04D000082,TATGTACAGTCTCTACATTTTTTCAAATACATTTTTTTCTTTTTCA...
2396,AT3G51895,1,HOM04D000270,TGGTAAATAATTAAATATATAAGAACATTATTCTAAAGCGTTGAAT...


### One-hot-encode & reverse-complement

In [73]:
from src.prepare_dataset import one_hot_encode_series, reverse_complement_series, reverse_complement_sequence
# Add one one-hot matrix per sequence. Judging by the rows displayed further
# down, the four channels are ordered A, C, G, T (e.g. a leading "T" is shown
# as [0, 0, 0, 1]).
# NOTE(review): reverse_complement_sequence is imported but not used in the
# cells visible here.
dataset["One_hot_encoded"] = one_hot_encode_series(dataset["Sequence"])

In [74]:
# Companion column computed by reverse_complement_series from the one-hot
# matrices (not from the raw sequence strings).
dataset["RC_one_hot_encoded"] = reverse_complement_series(dataset["One_hot_encoded"])

In [75]:
dataset

Unnamed: 0,GeneID,Label,GeneFamily,Sequence,One_hot_encoded,RC_one_hot_encoded
0,AT4G27120,0,HOM04D000881,TAGAGAAGACAAGCGGTTATTTCGTAATTTCCCAGCGACTTTGAAA...,"[[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0], [1,...","[[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0,..."
1,AT4G19600,0,HOM04D000740,GTCAAGTAGTGAAATCAAGGTGTGAAGTAAGCTGAGGACAGATAAT...,"[[0, 0, 1, 0], [0, 0, 0, 1], [0, 1, 0, 0], [1,...","[[0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0,..."
2,AT3G60880,0,HOM04D003119,AGTTGATATTGAATGAAATCTTCATGTTTTTTGATAAATGATTATA...,"[[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0,...","[[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0,..."
3,AT5G06960,0,HOM04D000319,CACTTGTCAGATTCTTCTTACCAAATCCATCAACAAATAAGCAAAT...,"[[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0,...","[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [1,..."
4,AT1G14890,0,HOM04D000273,TTGATATAACAGATTCAACACTAAAAATGAGTAAAATCTAAAAAAG...,"[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 0], [1,...","[[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 1], [1,..."
...,...,...,...,...,...,...
2393,AT5G64230,1,HOM04D003278,AAGAAAGAAAAACCGTACATAAACACCCATCTGGTATACCATCGTC...,"[[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1,...","[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0,..."
2394,AT5G64780,1,HOM04D002552,TTTTAGAAAGAAGAAGAAGGATTATTGCCTTATTGGTGAAGGGAAG...,"[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0,...","[[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0,..."
2395,AT4G30470,1,HOM04D000082,TATGTACAGTCTCTACATTTTTTCAAATACATTTTTTTCTTTTTCA...,"[[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 1], [0,...","[[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0,..."
2396,AT3G51895,1,HOM04D000270,TGGTAAATAATTAAATATATAAGAACATTATTCTAAAGCGTTGAAT...,"[[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0,...","[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 0], [1,..."


### Train-test split

In [76]:
from src.prepare_dataset import grouped_shuffle_split
# First cut: split off a hold-out part, using GeneFamily as the grouping key.
# Presumably 0.2 is the held-out fraction and each family stays on one side of
# the split -- confirm in grouped_shuffle_split.
# NOTE(review): no seed is passed here, so whether the split is reproducible
# depends on what grouped_shuffle_split does internally.
train_df, validation_test_df = grouped_shuffle_split(dataset, dataset["GeneFamily"], 0.2)

In [77]:
# Second cut: divide the hold-out part into validation and test, again keyed on
# GeneFamily (0.5 is presumably an even split -- confirm in the helper).
validation_df, test_df  = grouped_shuffle_split(validation_test_df, validation_test_df["GeneFamily"], 0.5)

In [78]:
# Report the number of rows in each of the three partitions.
print(f"Length of training set: {train_df.shape[0]}")
print(f"Length of validation set: {validation_df.shape[0]}")
print(f"Length of test set: {test_df.shape[0]}")

Length of training set: 1900
Length of validation set: 257
Length of test set: 241


In [79]:
def get_input_and_labels(df):
	"""Build model inputs and labels for one partition, covering both encodings.

	The matrices from "One_hot_encoded" come first for every row, followed by
	the matrices from "RC_one_hot_encoded" for the same rows in the same order.
	The label vector is repeated to line up with that layout.

	Parameters
	----------
	df : pandas.DataFrame
		Needs the columns "One_hot_encoded", "RC_one_hot_encoded" and "Label".

	Returns
	-------
	x : numpy.ndarray
		float32 array with twice as many entries along axis 0 as ``df`` has rows.
	y : numpy.ndarray
		The "Label" values followed by the same values once more.
	"""
	forward = np.stack(df["One_hot_encoded"])
	reverse = np.stack(df["RC_one_hot_encoded"])
	x = np.concatenate((forward, reverse), axis=0).astype('float32')

	# Flattened exactly once, then duplicated: one copy per encoding.
	labels = np.asanyarray(df["Label"]).ravel()
	y = np.concatenate((labels, labels))
	return x, y

In [80]:
# Array versions of the training and validation partitions; each x holds the
# forward matrices followed by the reverse-complement ones (see
# get_input_and_labels), so it has twice as many entries as the frame has rows.
x_train, y_train = get_input_and_labels(train_df)
x_validation, y_validation = get_input_and_labels(validation_df)

## Prepare model

As the model uses EvoAug data augmentation, a list of the possible nucleotide-level operations needs to be provided.

In [81]:
# Augmentation operations from evoaug_tf.augment; the list is handed to
# get_model as its augment_list argument.
augment_list = [
    augment.RandomRC(rc_prob=0.5),
    augment.RandomInsertionBatch(insert_min=0, insert_max=20),
    augment.RandomDeletion(delete_min=0, delete_max=30),
    augment.RandomTranslocationBatch(shift_min=0, shift_max=20),
    augment.RandomMutation(mutate_frac=0.05),
    augment.RandomNoise()
]

Get the shape of the input data

In [82]:
# Shape of a single one-hot matrix, read from the first training row. All rows
# are expected to share it (np.stack in get_input_and_labels requires that).
input_shape = train_df["One_hot_encoded"].iloc[0].shape

Initialize the model

In [83]:
# NOTE(review): perform_evoaug=False is passed together with augment_list,
# while the text above says the model uses augmentation. Whether the list has
# any effect with this flag depends on get_model -- confirm the intended value.
model = get_model(input_shape=input_shape, perform_evoaug=False, augment_list=augment_list,learning_rate=0.001)

## Train the model

As the hyperparameters are set, we will do one more run with all available training data.

In [84]:
# Training and validation partitions joined end to end (training first).
# NOTE(review): in the cells visible in this notebook these two arrays are not
# passed to model.fit, which is called with x_train / y_train instead.
x_full_train = np.concatenate((x_train, x_validation), axis=0)
y_full_train = np.concatenate((np.ravel(y_train), np.ravel(y_validation)))

In [85]:
import tensorflow as tf

# Both callbacks monitor the validation loss, where lower is better.
# NOTE(review): the learning-rate patience (5) is much shorter than the
# stopping patience (20). In the training log of this notebook the rate is cut
# four times, down to about 1e-7, before early stopping fires at epoch 30.

# Stop once val_loss has not improved for 20 epochs and roll the model back to
# the best weights seen.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=20,
    restore_best_weights=True,
    verbose=1,
)

# Divide the learning rate by 10 after 5 epochs without improvement, but never
# go below 1e-7.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

callbacks = [early_stopping_callback, reduce_lr]

In [86]:
# Fit on the training partition, with the validation partition supplying the
# val_loss that both callbacks monitor.
# NOTE(review): the text above announces a run on all available training data,
# but x_full_train / y_full_train are not used here -- confirm which is
# intended. Swapping them in while still validating on x_validation would let
# val_loss (and therefore early stopping) see training data.
# NOTE(review): no random seed is set in the visible cells, so the reported
# numbers may differ between runs.
history = model.fit(x_train,
					y_train,
					epochs=100,
					batch_size=100,
					validation_data=(x_validation, y_validation),
					callbacks=callbacks
					)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 15: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 20: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 25: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100

Epoch 30: ReduceLROnPlateau reducing learning rate to 1.0000001111620805e-07.
Epoch 30: early stopping


In [90]:
# Lowest validation loss reached during the fit above.
# NOTE(review): this cell carries execution count 90 although the following
# cells are 87 and 88, i.e. it was run out of order; re-run top to bottom.
min(history.history["val_loss"])

0.6553439497947693

In [87]:
# Evaluate on the held-out test partition. The forward and reverse-complement
# matrices are kept as two separate arrays here (not concatenated as for
# training) and handed together to post_hoc_conjoining.
siamese_model = get_siamese_model(model)

# NOTE(review): unlike the training inputs, these arrays are not cast to
# float32 -- check that post_hoc_conjoining / the model handle the dtype.
x_test = np.stack(test_df["One_hot_encoded"])
x_test_rc = np.stack(test_df["RC_one_hot_encoded"])

y_test = test_df["Label"].to_numpy()


# Two results per test row: presumably hard class calls and the underlying
# scores -- confirm in post_hoc_conjoining.
predictions_categories, predictions = post_hoc_conjoining(siamese_model, x_test, x_test_rc)

# Last expression of the cell, so its value (the AUROC) is what gets displayed.
get_auroc(y_test, predictions)



0.7106722106722106

In [88]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 of the class calls on the test partition.
# classification_report returns text, hence the explicit print.
print(classification_report(y_test, predictions_categories))

              precision    recall  f1-score   support

           0       0.62      0.58      0.60       111
           1       0.66      0.69      0.67       130

    accuracy                           0.64       241
   macro avg       0.64      0.63      0.63       241
weighted avg       0.64      0.64      0.64       241

