<a href="https://colab.research.google.com/github/ViennaRNA/RNAdeep/blob/dev_data/notebooks/SpotrnaPaddedData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment setup

In [None]:
# We are assuming the Python environment is 3.7.
# The Miniconda installer below ships a Python 3.7 base, and a later cell adds
# conda's python3.7 site-packages to this kernel's sys.path. If this kernel runs a
# different Python, adjust the Miniconda version (and that path) below.
import sys
import warnings

if sys.version_info[:2] != (3, 7):
    # Warn instead of failing: the mismatch otherwise only surfaces later as an
    # obscure import error when loading the conda-installed packages.
    warnings.warn(f"Expected a Python 3.7 kernel, found {sys.version.split()[0]}; "
                  "adjust the Miniconda version below.")
sys.version

In [None]:
# install miniconda into /usr/local (batch mode, overwriting any existing install).
# The installer name fixes the conda base Python version (here 3.7); if you change
# it, also update the python3.7 site-packages path in the import-test cell.
MINICONDA_INSTALLER = "Miniconda3-py37_4.11.0-Linux-x86_64.sh"  # Python 3.7
!wget -qO ac.sh https://repo.anaconda.com/miniconda/{MINICONDA_INSTALLER}
!bash ./ac.sh -b -f -p /usr/local/
!rm ac.sh
!conda update -y conda

In [None]:
# install viennarna, etc.
# Each --prepend (long form of --add) puts the channel at the top of the list,
# so after these two commands conda-forge has the highest priority, then bioconda.
!conda config --prepend channels bioconda
!conda config --prepend channels conda-forge
!conda install --yes viennarna

In [None]:
# Clone the dev_data branch of the RNAdeep repository and install it.
# `--branch` checks out the remote branch directly. The previous
# `checkout -b dev_data` + `pull origin dev_data` created a local branch from the
# default branch and merged origin/dev_data into it, which can fail when the
# histories diverge (git asks for a pull strategy / committer identity).
# The `test -d` guard keeps this cell re-runnable once the clone exists.
!test -d RNAdeep || git clone --branch dev_data https://github.com/ViennaRNA/RNAdeep.git RNAdeep
%cd RNAdeep
!python ./setup.py install
%cd ..

In [None]:
# Test of RNA / RNAdeep import
import sys

# Make the packages installed by conda (Python 3.7) importable from this kernel.
# Guarded so that re-running the cell does not append the same entry again.
CONDA_SITE_PACKAGES = "/usr/local/lib/python3.7/site-packages"
if CONDA_SITE_PACKAGES not in sys.path:
    sys.path.append(CONDA_SITE_PACKAGES)
print(f"Python version {sys.version}")

import RNA
print(f"RNA version {RNA.__version__}")

# Somewhat annoying that we need the "RNAdeep" part here; it is not clear why.
# NOTE(review): presumably `RNAdeep` resolves to the repository cloned into the
# working directory (imported as a namespace package), not to the installed copy.
# A plain `import rnadeep` may fail because `setup.py install` can register an egg
# through a .pth file, and .pth files are not processed for directories added via
# sys.path.append — TODO confirm.
import RNAdeep.rnadeep as rnadeep
print(f"rnadeep version {rnadeep.__version__}")

# Let's generate some data (RNAdeep/examples/generate_data.py)

In [None]:
# Generate a fresh data set into `datadir`.
# TODO: decide whether to use the existing data from the repository instead.
import os
from RNAdeep.rnadeep.sampling import write_uniform_len_data_file

datadir = "newdata/"
# exist_ok=True makes the cell re-runnable and avoids the check-then-create race.
os.makedirs(datadir, exist_ok = True)

# NOTE(review): 25 and 100 are presumably the minimum/maximum sequence lengths;
# confirm against write_uniform_len_data_file's signature.
fname = write_uniform_len_data_file(25, 100, num = 10_000, root = datadir)
print(f'Wrote file: {fname}')

# Let's do some training (RNAdeep/examples/spotrna_padded.py)

In [None]:
#
# Training with padded data 
#
import os
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint

from RNAdeep.rnadeep.spotrna import spotrna
from RNAdeep.rnadeep.metrics import mcc, f1, sensitivity
from RNAdeep.rnadeep.data_generators import PaddedMatrixEncoding
from RNAdeep.rnadeep.sampling import draw_sets

In [None]:
#
# Get the data for analysis: 80% training, 10% validation, 10% test records.
#
train, valid, tests = draw_sets(fname, splits = [0.8, 0.1, 0.1])

# Transpose each list of records into parallel tuples of tags, sequences and
# structures (the naming suggests dot-bracket strings — confirm in draw_sets).
train_tags, train_seqs, train_dbrs = zip(*train)
valid_tags, valid_seqs, valid_dbrs = zip(*valid)
tests_tags, tests_seqs, tests_dbrs = zip(*tests)


In [None]:
#
# Model Settings (TODO: update to paper settings!)
#
# `model` is passed to spotrna() in the training cell. `model`, `batch_size` and the
# data file name make up the run name used for logs and checkpoints.
model, batch_size, epochs = 1, 8, 3
data = os.path.basename(fname)
name = f"spotrna_m{model}_bs{batch_size}_{data}"


In [None]:
#
# Model Setup
#
# Batch generators over the training / validation sequences and their structures.
train_generator = PaddedMatrixEncoding(batch_size, train_seqs, train_dbrs)
valid_generator = PaddedMatrixEncoding(batch_size, valid_seqs, valid_dbrs)

m = spotrna(model, True)
m.compile(optimizer = "adam",
          loss = "binary_crossentropy",
          metrics = ["acc", mcc, f1, sensitivity],
          # NOTE(review): eager execution is much slower than graph mode; presumably
          # required by the custom metrics — confirm before disabling.
          run_eagerly = True)

# Callback functions for fitting.
# append=True: re-running this cell appends to an existing f"{name}.csv".
csv_logger = CSVLogger(f"{name}.csv", separator = ';', append = True)
# Keep only the model with the best validation MCC seen so far (saved under `name`).
model_checkpoint = ModelCheckpoint(filepath = name,
                                   save_weights_only = False,
                                   monitor = 'val_mcc',
                                   mode = 'max',
                                   save_best_only = True)

# Keep the returned History so per-epoch losses/metrics remain available after training.
history = m.fit(x = train_generator,
                validation_data = valid_generator,
                epochs = epochs,
                shuffle = True,
                verbose = 1,
                callbacks = [csv_logger, model_checkpoint])

# Save the model after the last epoch (independent of the best checkpoint above).
m.save(f"{name}_ep{epochs}.rnadeep")