# **[Jump to Run Section](#Run)**

# Install dependencies and define helper functions

In [None]:
from tqdm.auto import tqdm
import os
from ipywidgets import widgets
import pathlib
import logging
import time

from src.prediction_runner import PredictionRunner
from src.prediction_input import PredictionInput
from src.prediction_input_file_from_filesystem import PredictionInputFileFromFilesystem
from src.zdnabert_model import ZdnabertModel
from src.sequence_variation_normal import SequenceVariationNormal
from src.sequence_variation_reverse_complement import SequenceVariationReverseComplement
from src.prediction_result_formatter_bed_file import PredictionResultFormatterBedFile
from src.zdnabert_model_downloader import ZdnabertModelDownloader

In [None]:
# Directory layout used by the notebook, relative to its working directory.
model_download_path, input_path, output_path = (
    './pytorch_models',  # passed to ZdnabertModelDownloader.download_models/download_metas
    './input',           # scanned by load_input_files_list() for selectable files
    './output',          # BED files are written here by do_predictions()
)
# NOTE(review): not referenced in the visible cells; presumably a file-name length limit — confirm.
file_name_maximum_length = 255

In [None]:
# Widgets for choosing the model and its parameters. Their values are copied
# into module-level variables only when the "Load model" button is pressed.
# The text fields use `description_width: 'initial'` so their long
# descriptions are shown in full instead of being cut off at the default width.
model_name_widget = widgets.Dropdown(
    options=list(ZdnabertModelDownloader.models.keys()),
    value=next(iter(ZdnabertModelDownloader.models.keys())),
    description='model:',
    disabled=False,
)

model_confidence_threshold_widget = widgets.FloatText(
    value=0.5,
    description='model confidence threshold:',
    style={'description_width': 'initial'},
)

minimum_sequence_length_widget = widgets.IntText(
    value=10,
    description='minimum sequence length:',
    style={'description_width': 'initial'},
)

check_sequence_variations_widget = widgets.Checkbox(
    value=True,
    description='check reverse complement sequence variations',
)

use_cuda_if_available_widget = widgets.Checkbox(
    value=True,
    description='use cuda if it is available',
)

In [None]:
def load_input_files_list(directory=None):
    """Return the sorted names of the regular files in the input directory.

    Args:
        directory: directory to list. Defaults to the module-level
            ``input_path`` when omitted, so existing calls are unchanged.

    Returns:
        A list of file names (not full paths), sorted so the selection widget
        shows a stable order. Sub-directories are skipped. If the directory
        does not exist (or is not a directory), an empty list is returned
        instead of raising.
    """
    input_path_handle = pathlib.Path(input_path if directory is None else directory)
    # iterdir() would raise FileNotFoundError / NotADirectoryError here.
    if not input_path_handle.is_dir():
        return []
    return sorted(item.name for item in input_path_handle.iterdir() if item.is_file())

In [None]:
def retrieve_input_files_list():
    """Map every currently selected input file name to its path.

    Reads ``input_files_widget.value``; ``create_input_files_widget()`` must
    have been called first, since the widget starts out as ``None``.

    Returns:
        dict mapping each selected file name to
        ``pathlib.PurePath(input_path) / file_name``.
    """
    selected_names = input_files_widget.value
    name_to_path = {}
    for selected_name in selected_names:
        name_to_path[selected_name] = pathlib.PurePath(input_path) / selected_name
    return name_to_path

In [None]:
# Multi-select of input files; stays None until create_input_files_widget() runs.
input_files_widget = None


def create_input_files_widget():
    """(Re)build the module-level ``input_files_widget`` from the input directory.

    The options are the file names returned by ``load_input_files_list()`` at
    call time, with nothing preselected.
    """
    global input_files_widget
    available_file_names = load_input_files_list()
    input_files_widget = widgets.SelectMultiple(
        options=available_file_names,
        value=[],
        description='Select inputs',
        disabled=False,
    )

In [None]:
# Output areas capturing what the "Load model" and "Run prediction" callbacks print.
load_model_output, do_predictions_output = widgets.Output(), widgets.Output()

In [None]:
# Parameters copied from the widgets by load_model(); all None until it runs.
model_name = model_confidence_threshold = minimum_sequence_length = None
check_sequence_variations = use_cuda_if_available = None

In [None]:
@load_model_output.capture(clear_output=True)
def load_model(btn):
    """Store the selected model and parameters for the next prediction run.

    Callback of ``load_model_button``; ``btn`` is the clicked button and is
    unused. The widget values are copied into module-level globals that
    ``do_predictions`` reads, so later widget edits take effect only after
    the button is pressed again. No model file is downloaded or loaded here.
    Printed output goes to ``load_model_output``.
    """
    # Only the names actually assigned below need the global declaration.
    global model_name, model_confidence_threshold, minimum_sequence_length, check_sequence_variations, use_cuda_if_available

    model_name = model_name_widget.value
    model_confidence_threshold = model_confidence_threshold_widget.value
    minimum_sequence_length = minimum_sequence_length_widget.value
    check_sequence_variations = check_sequence_variations_widget.value
    use_cuda_if_available = use_cuda_if_available_widget.value

    # Report every stored setting so the user can verify what the next run uses.
    print(
        '\n\ncompleted loading model\n\n'
        'model: {}\n'
        'model confidence threshold: {}\n'
        'minimum sequence length: {}\n'
        'check reverse complement sequence variations: {}\n'
        'use cuda if it is available: {}'.format(
            model_name,
            model_confidence_threshold,
            minimum_sequence_length,
            check_sequence_variations,
            use_cuda_if_available,
        )
    )

In [None]:
# Button that copies the widget values into the module-level parameters.
load_model_button = widgets.Button(description='Load model', icon='truck-loading')
load_model_button.on_click(load_model)

In [None]:
@do_predictions_output.capture(clear_output=True)
def do_predictions(btn):
    """Run the loaded model on the selected input files and write BED files.

    Callback of ``do_predictions_button``; ``btn`` is the clicked button and
    is unused. The model parameters are the module-level values stored by
    ``load_model``; the inputs are the files selected in
    ``input_files_widget``. Printed output goes to ``do_predictions_output``.

    For every prediction result two BED files are written to ``output_path``:
    a "common" file (opened in append mode; its track header is written only
    the first time that name is seen during this run) and a per-variation
    file (overwritten).
    """
    if model_name is None:
        # load_model() has not stored any parameters yet; os.path.join()
        # below would fail on a None model name.
        print('Please press "Load model" before running predictions.')
        return

    uploaded_items = retrieve_input_files_list()
    if not uploaded_items:
        print('Please select at least one input file.')
        return

    started_at = time.perf_counter()

    # open() below raises FileNotFoundError if the output directory is missing.
    os.makedirs(output_path, exist_ok=True)

    zdnabert_model_downloader = ZdnabertModelDownloader()
    zdnabert_model_downloader.download_models(model_download_path)
    zdnabert_model_downloader.download_metas(model_download_path)

    zdnabert_model = ZdnabertModel(
        os.path.join(model_download_path, model_name),
        model_name=model_name,
        model_confidence_threshold=model_confidence_threshold,
        minimum_sequence_length=minimum_sequence_length,
        use_cuda_if_available=use_cuda_if_available,
    )

    # The original sequence is always checked; the reverse complement on request.
    sequence_variations = [SequenceVariationNormal()]
    if check_sequence_variations:
        sequence_variations.append(SequenceVariationReverseComplement())

    prediction_input_files = [
        PredictionInputFileFromFilesystem(uploaded_item_key, uploaded_item_path)
        for uploaded_item_key, uploaded_item_path in tqdm(uploaded_items.items(), 'preparing')
    ]

    prediction_inputs = [
        PredictionInput(zdnabert_model, prediction_input_files, sequence_variations),
    ]

    prediction_result_formatter_bed_file = PredictionResultFormatterBedFile()
    prediction_runner = PredictionRunner()

    # One timestamp for the whole run so all of its file names share it.
    now_time_as_string_for_file_name = time.strftime("%Y_%m_%d,%H_%M_%S")

    # Common BED files whose track header has already been written this run.
    started_bed_file_names = set()

    for prediction_result in prediction_runner.run(prediction_inputs, progress_bar=tqdm):
        bed_file_name = prediction_result_formatter_bed_file.file_name_common(prediction_result, now_time_as_string_for_file_name)
        bed_file_name_seq = prediction_result_formatter_bed_file.file_name_variation(prediction_result, now_time_as_string_for_file_name)
        track_header = 'track name="{name}" priority=1\n'.format(
            name=prediction_result.get_model_params_as_string(),
        )

        print(bed_file_name)
        print(bed_file_name_seq)

        # `with` guarantees both files are closed even if formatting/writing raises.
        with open(os.path.join(output_path, bed_file_name), 'a', encoding='utf-8') as bed_file_handler:
            with open(os.path.join(output_path, bed_file_name_seq), 'w', encoding='utf-8') as bed_file_seq_handler:
                if bed_file_name not in started_bed_file_names:
                    started_bed_file_names.add(bed_file_name)
                    bed_file_handler.write(track_header)
                bed_file_seq_handler.write(track_header)

                for line in prediction_result_formatter_bed_file.format(prediction_result):
                    bed_file_handler.write("{}\n".format(line))
                    bed_file_seq_handler.write("{}\n".format(line))

    print('completed predictions in {:.1f} s'.format(time.perf_counter() - started_at))

In [None]:
# Button that starts a prediction run over the selected input files.
do_predictions_button = widgets.Button(description='Run prediction', icon='chart-line')
do_predictions_button.on_click(do_predictions)

# Run

Predict features in FASTA input files in 4 steps.

## Usage

### Prepare

The environment only needs to be prepared once each time JupyterLab is started or the notebook is freshly opened.

### Select model and parameters

After changing the model or the parameters, press the "Load model" button.

This stores the selected model and parameters for the next prediction run; nothing is downloaded at this point. The model files are downloaded into the directory `pytorch_models` when the prediction is run. Files that have been downloaded already will not be downloaded again.

### Run

After the predictions have been made, new files will be created in the directory `output`.

The following types of files will be created:

- `.txt`-Files will contain the textual representation as seen in the output of the notebook for all input files
- Several different `.bed`-Files containing the found features will be created for each input file based on the selected sequence variations

  They can be used to import found features into other software.
  
  - `.normal.bed` contains features found for the original input fasta file
  - `.rev-comp.bed` contains features found for the reverse-complement
  - `.bed` contains the features of all checked variations combined (the normal sequence and, if selected, the reverse-complement)


## 1 Prepare

<button data-commandLinker-command="notebook:run-all-cells" class="lm-Widget jupyter-widgets jupyter-button">Prepare environment</button>

## 2 Select model and parameters

In [None]:
# Model selection and parameter widgets, shown in order.
display(
    model_name_widget,
    model_confidence_threshold_widget,
    minimum_sequence_length_widget,
    check_sequence_variations_widget,
    use_cuda_if_available_widget,
)

In [None]:
# Pressing this button runs load_model(), which stores the selected model and parameters.
display(load_model_button)

In [None]:
# Output area for messages printed by load_model().
display(load_model_output)

## 3 Select fasta files

Multiple FASTA files may be selected. Place them in the directory `input` before running the cell below; re-run the cell to refresh the list after adding new files.

In [None]:
# The file list is read from `input_path` when this cell runs; re-run the
# cell after adding files to the input directory to see them.
create_input_files_widget()
display(input_files_widget)

## 4 Run

In [None]:
# Pressing this button runs do_predictions() with the loaded parameters and the selected files.
display(do_predictions_button)

## Prediction output

In [None]:
# Output area for do_predictions(). The predictions run when "Run prediction"
# is pressed, not when this cell executes, so this cell is not timed; the
# callback prints the elapsed time itself.
display(do_predictions_output)