# **[Jump to Run Section](#Run)**

# Install dependencies and define helper functions

In [None]:
!pip3 install transformers
!pip3 install biopython
!pip3 install torch
!pip3 install numpy
!pip3 install scipy
!pip3 install tqdm

In [None]:
import torch
from torch import nn
import transformers
from transformers import BertTokenizer, BertForTokenClassification
import numpy as np
from Bio import SeqIO
from io import StringIO, BytesIO
from tqdm import tqdm
import pickle
import scipy
from scipy import ndimage
import os
from ipywidgets import widgets
import subprocess
import time

In [None]:
# Directory the tokenizer and model are loaded from (see load_model()).
data_path = './6-new-12w-0'
# Directory the downloaded weight files of the selectable models are stored in.
model_data_path = './pytorch_models'
# Path the selected model's weight file is copied to before the model is loaded.
target_model_data_path = os.path.join(data_path, 'pytorch_model.bin')
# Directory the prediction result files are written to.
output_path = './output'
# Upper bound for the length of generated file names; checked before predicting.
file_name_maximum_length = 255

In [None]:
def seq2kmer(seq, k):
    """Return every length-`k` window of `seq`, in order, shifted by one item each.

    Yields an empty list when `seq` is shorter than `k`.
    """
    window_count = len(seq) + 1 - k
    kmers = []
    for start in range(window_count):
        kmers.append(seq[start:start + k])
    return kmers

def split_seq(seq, length = 512, pad = 16):
    """Split `seq` into overlapping windows of at most `length` items.

    Consecutive windows start `length - pad` items apart, so every window
    begins with the last `pad` items of the window before it. The final
    window may be shorter than `length`.

    Args:
        seq: Sliceable sequence (e.g. a list of k-mers).
        length: Maximum number of items per window.
        pad: Number of items shared by two consecutive windows.

    Returns:
        List of windows; empty if `seq` is empty.

    Raises:
        ValueError: If `length` is not greater than `pad`.
    """
    if length <= pad:
        raise ValueError('length must be greater than pad')

    res = []
    for st in range(0, len(seq), length - pad):
        # Honour `length` (this used to be a hard-coded 512).
        end = min(st + length, len(seq))
        res.append(seq[st:end])
        if end == len(seq):
            # The input is fully covered. A further window would only repeat part of
            # the overlap, and a window shorter than `pad` makes stitch_np_seq()
            # cut away values that are not part of any overlap.
            break
    return res

def stitch_np_seq(np_seqs, pad = 16):
    """Join overlapping 1-D arrays into one array.

    Before each array is appended, the last `pad` values collected so far are
    dropped, so values in an overlap are taken from the later array.

    Args:
        np_seqs: Iterable of 1-D arrays, in order.
        pad: Number of trailing values to drop before appending the next array.
            `0` simply concatenates the arrays.

    Returns:
        The stitched 1-D array (float64 for float/integer input, because it is
        seeded with an empty float array); empty if `np_seqs` is empty.

    Raises:
        ValueError: If `pad` is negative.
    """
    if pad < 0:
        raise ValueError('pad must not be negative')

    # Collect the kept parts and concatenate once at the end, instead of copying
    # the whole result for every array.
    kept = []
    for seq in np_seqs:
        to_drop = pad
        # Drop `pad` values from the tail of what has been collected so far; this
        # may consume more than one previously collected part if those are short.
        while to_drop > 0 and kept:
            tail = kept[-1]
            if len(tail) <= to_drop:
                to_drop -= len(tail)
                kept.pop()
            else:
                kept[-1] = tail[:-to_drop]
                to_drop = 0
        kept.append(np.asarray(seq))
    return np.concatenate([np.array([])] + kept)

In [None]:
# Complementary base for each of the four DNA bases (upper case only).
# Characters that are not listed here are left unchanged by complement_nucleobase().
base_pair_opposite_map = {
    'A': 'T',
    'T': 'A',
    'C': 'G',
    'G': 'C',
}

In [None]:
def complement_nucleobase(nucleobase):
    """Return the complement of a single base; unknown characters are returned as they are."""
    return base_pair_opposite_map.get(nucleobase, nucleobase)

def complement_seq(seq):
    """Return `seq` with every base replaced by its complement.

    Characters without an entry in `base_pair_opposite_map` are kept unchanged.
    The result is always a `str`.
    """
    if isinstance(seq, str):
        # str.translate does the per-character lookup in C, which matters for
        # chromosome-sized inputs.
        return seq.translate(str.maketrans(base_pair_opposite_map))
    # Generic fallback for other iterables of single characters.
    return ''.join(complement_nucleobase(nucleobase) for nucleobase in seq)

def reverse_seq(seq):
    """Return `seq` in reverse order, as the same sequence type."""
    backwards = slice(None, None, -1)
    return seq[backwards]

In [None]:
# Smoke test of the sequence helpers: show the variations of a short example and
# verify them, so a broken helper is noticed here instead of in the predictions.
test_seq = 'ACGTA'
print('seq\t\t\t', test_seq)
print('complement\t\t', complement_seq(test_seq))
print('reverse\t\t\t', reverse_seq(test_seq))
print('reverse-complement\t', reverse_seq(complement_seq(test_seq)))

assert complement_seq(test_seq) == 'TGCAT'
assert reverse_seq(test_seq) == 'ATGCA'
assert reverse_seq(complement_seq(test_seq)) == 'TACGT'

In [None]:
def gdown_wrapper(gdrive_id, file_path):
    """Download a Google Drive file to `file_path` with curl, unless it already exists.

    The download is written to `<file_path>.part` and only moved to `file_path`
    after curl succeeded. A failed or interrupted download therefore never leaves
    a file behind that a later call would report as "already exists".

    Args:
        gdrive_id: Google Drive file id.
        file_path: Destination path; its directory must exist.

    Raises:
        subprocess.CalledProcessError: If curl exits with a non-zero status
            (`--fail` makes HTTP error responses count as failures).
        FileNotFoundError: If the `curl` executable cannot be found.
    """
    if os.path.exists(file_path):
        print(file_path, gdrive_id, 'already exists')
        return

    gdrive_url = 'https://drive.usercontent.google.com/download?id={id}&confirm=ABCD'.format(id=gdrive_id)

    print(gdrive_url, file_path)

    partial_file_path = file_path + '.part'
    try:
        # Argument list instead of a shell line: nothing is interpreted by a shell
        # and the exit status is checked.
        subprocess.run(
            ['curl', '-L', '--fail', '--progress-bar', '-o', partial_file_path, gdrive_url],
            check=True,
        )
        os.replace(partial_file_path, file_path)
    finally:
        # Only present here if the download did not complete.
        if os.path.exists(partial_file_path):
            os.remove(partial_file_path)

In [None]:
def hash_file(file_path):
    """Return the SHA-1 digest of a file, or None if `file_path` is not a file.

    The digest is returned as 40 ASCII hex characters in a `bytes` object, which
    is the form the previous `shasum`-based implementation produced. It is
    computed with `hashlib`, so no external program is needed and a failing tool
    can no longer make two different files look identical.
    """
    import hashlib  # local import: hashlib is not part of the notebook's import cell

    if not os.path.isfile(file_path):
        return None

    digest = hashlib.sha1()
    with open(file_path, 'rb') as fh:
        # Read in chunks; model files are too large to read in one go.
        for chunk in iter(lambda: fh.read(1024 * 1024), b''):
            digest.update(chunk)
    return digest.hexdigest().encode('ascii')

In [None]:
# Selectable models: display name -> (Google Drive file id, local weight file name).
# The display names are the options of the model dropdown; load_model() downloads the
# weight file into `model_data_path` under the given file name.
models = {
    'HG kouzine': (
        '1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx',
        'hg_kouzine.pytorch_model.bin',
    ),
    'HG chipseq': (
        '1VAsp8I904y_J0PUhAQqpSlCn1IqfG0FB',
        'hg_chipseq.pytorch_model.bin',
    ),
    'MM curax': (
        '1W6GEgHNoitlB-xXJbLJ_jDW4BF35W1Sd',
        'mm_curax.pytorch_model.bin',
    ),
    'MM kouzine': (
        '1dXpQFmheClKXIEoqcZ7kgCwx6hzVCv3H',
        'mm_kouzine.pytorch_model.bin',
    ),
}

In [None]:
# (Google Drive file id, destination path) pairs of the files that load_model() downloads
# into `data_path` in addition to the weights, independent of the selected model.
meta_files = [
    ('10sF8Ywktd96HqAL0CwvlZZUUGj05CGk5', os.path.join(data_path, 'config.json')),
    ('16bT7HDv71aRwyh3gBUbKwign1mtyLD2d', os.path.join(data_path, 'special_tokens_map.json')),
    ('1EE9goZ2JRSD8UTx501q71lGCk-CK3kqG', os.path.join(data_path, 'tokenizer_config.json')),
    ('1gZZdtAoDnDiLQqjQfGyuwt268Pe5sXW0', os.path.join(data_path, 'vocab.txt')),
]

In [None]:
# Input widgets for the model and its parameters. Their values are read when the
# "Load model" button is clicked (see load_model()).

# Which entry of `models` to use.
model_name_widget = widgets.Dropdown(
    options=list(models),
    value=next(iter(models)),
    description='model:',
    disabled=False,
)

# Positions with a score above this threshold count as predicted.
# 'description_width': 'initial' lets the long description take its natural width;
# with the default fixed width it is cut off.
model_confidence_threshold_widget = widgets.FloatText(
    value=0.5,
    description='model confidence threshold',
    style={'description_width': 'initial'},
)

# Predicted regions have to be longer than this to be reported.
minimum_sequence_length_widget = widgets.IntText(
    value=10,
    description='minimum sequence length:',
    style={'description_width': 'initial'},
)

# Whether the reverse complement of every sequence is analysed as well.
check_sequence_variations_widget = widgets.Checkbox(
    value=True,
    description='check reverse complement sequence variations'
)

In [None]:
# Upload control for the input files; do_predictions() reads the name and content of
# every uploaded file from its value and parses the content as FASTA.
fasta_upload_widget = widgets.FileUpload(
    multiple=True
)

In [None]:
# Output areas that capture everything printed by load_model() and do_predictions()
# (both functions are decorated with the respective `.capture(clear_output=True)`).
load_model_output = widgets.Output()
do_predictions_output = widgets.Output()

In [None]:
# Module-level state that load_model() fills in and do_predictions() reads.
# Everything stays None until the "Load model" button has been used.
model_name = None
model_confidence_threshold = None
minimum_sequence_length = None
check_sequence_variations = None
model_file_path = None
tokenizer = None
model = None
is_cuda_available = None

In [None]:
@load_model_output.capture(clear_output=True)
def load_model(btn):
    """Download (if necessary) and load the model selected in the widgets.

    Click handler of the "Load model" button. It creates the data directories,
    downloads the selected weight file and the shared meta files, copies the
    weights to `target_model_data_path` if they differ from what is there, loads
    tokenizer and model from `data_path` and moves the model to the GPU if CUDA
    is available.

    The module-level state (`model_name`, `model`, `tokenizer`, ...) is only
    replaced after everything succeeded, so a failed load keeps the previously
    loaded model together with its matching name and parameters.

    Args:
        btn: The clicked button (unused).

    Raises:
        FileNotFoundError: If the weight file is missing after the download step.
    """
    global model_name, model_confidence_threshold, minimum_sequence_length, check_sequence_variations, model_file_path, tokenizer, model, is_cuda_available

    import shutil  # local import: shutil is not part of the notebook's import cell

    selected_model_name = model_name_widget.value
    selected_model_confidence_threshold = model_confidence_threshold_widget.value
    selected_minimum_sequence_length = minimum_sequence_length_widget.value
    selected_check_sequence_variations = check_sequence_variations_widget.value

    model_gdrive_id, model_file_name = models[selected_model_name]

    selected_model_file_path = os.path.join(model_data_path, model_file_name)


    print('downloading model data to input directory\n')

    # exist_ok: loading a second model must not complain about existing directories.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(model_data_path, exist_ok=True)

    gdown_wrapper(model_gdrive_id, selected_model_file_path)

    for meta_file_gdrive_id, meta_file_file_path in meta_files:
        gdown_wrapper(meta_file_gdrive_id, meta_file_file_path)


    print('\n\nchecking model file in input directory\n')

    hash1 = hash_file(selected_model_file_path)
    hash2 = hash_file(target_model_data_path)
    print(hash1, hash2)
    if not hash1:
        # Without this check a missing download compares equal to a missing target
        # and is reported as "model hasn't changed".
        raise FileNotFoundError('model file "{}" is missing or could not be hashed'.format(selected_model_file_path))
    if hash1 != hash2:
        print('\ncopying model file to input directory\n')
        shutil.copyfile(selected_model_file_path, target_model_data_path)
    else:
        print('\nmodel hasn\'t changed\n')


    print('\n\nloading model\n')

    loaded_tokenizer = BertTokenizer.from_pretrained(data_path)
    loaded_model = BertForTokenClassification.from_pretrained(data_path)
    cuda_available = torch.cuda.is_available()
    print('cuda is', 'available' if cuda_available else 'not available')
    if cuda_available:
        loaded_model.cuda()
    else:
        loaded_model.cpu()

    # Everything succeeded: publish the new state in one go.
    model_name = selected_model_name
    model_confidence_threshold = selected_model_confidence_threshold
    minimum_sequence_length = selected_minimum_sequence_length
    check_sequence_variations = selected_check_sequence_variations
    model_file_path = selected_model_file_path
    tokenizer = loaded_tokenizer
    model = loaded_model
    is_cuda_available = cuda_available

    print('\n\ncompleted loading model\n\nmodel: {}\nmodel confidence threshold: {}\nminimum sequence length: {}'.format(model_name, model_confidence_threshold, minimum_sequence_length))

In [None]:
# Button that runs load_model() when it is clicked.
load_model_button = widgets.Button(
    description='Load model',
    icon='truck-loading',
)
load_model_button.on_click(load_model)

In [None]:
def _safe_file_name_part(text):
    """Return `text` with path separators replaced by '_' so it can be part of a file name."""
    for separator in (os.sep, os.altsep):
        if separator:
            text = text.replace(separator, '_')
    return text


@do_predictions_output.capture(clear_output=True)
def do_predictions(btn):
    """Run the loaded model on every uploaded FASTA file and write the results.

    Click handler of the "Run prediction" button. For every sequence record
    (and, if enabled, its reverse complement) the sequence is turned into
    6-mers, split into overlapping windows, scored by the model and stitched
    back together. Runs of positions scoring above the confidence threshold
    that are longer than the minimum sequence length are printed and written
    to BED files. Into `output_path` it writes:

    - one `.pkl` and one `.bed` file per record and sequence variation,
    - one combined `.bed` file per record,
    - one `.txt` file with the printed summary of the whole run.

    All file names are checked against `file_name_maximum_length` before any
    prediction is made.

    Args:
        btn: The clicked button (unused).

    Raises:
        RuntimeError: If no model has been loaded yet.
        ValueError: If a file name to be generated is too long.
    """
    if model is None or tokenizer is None:
        # Fail early with a clear message instead of an AttributeError deep in the loop.
        raise RuntimeError('No model is loaded. Press "Load model" and wait for it to finish before running predictions.')

    uploaded = {v['name']: v['content'] for v in fasta_upload_widget.value}
    for fn in uploaded.keys():
        print('Fasta file "{name}" with length {length} bytes\n'.format(name=fn, length=len(uploaded[fn])))

    out = []

    out.append('model_name: {}'.format(model_name))
    print('model_name: {}'.format(model_name))
    out.append('model_confidence: {}'.format(model_confidence_threshold))
    print('model_confidence: {}'.format(model_confidence_threshold))
    out.append('minimum_sequence_length: {}'.format(minimum_sequence_length))
    print('minimum_sequence_length: {}'.format(minimum_sequence_length))

    model_params_as_string = 'm={},mct={},msl={}'.format(model_name, model_confidence_threshold, minimum_sequence_length)
    model_params_as_string_for_file_name = 'm_{},mct_{},msl_{}'.format(model_name, model_confidence_threshold, minimum_sequence_length)
    now_time_as_string_for_file_name = time.strftime("%Y_%m_%d,%H_%M_%S")

    def seq_record_file_key(upload_name, record_name):
        # Common file name stem of all files written for one sequence record.
        # Upload and record names come from user input; a path separator in them
        # would make the file end up outside `output_path` or fail to open.
        return '{}.{}.{}.{}'.format(_safe_file_name_part(upload_name), _safe_file_name_part(record_name), model_params_as_string_for_file_name, now_time_as_string_for_file_name)

    # <file name length checks>
    file_names_to_check = []

    text_predictions_file_name = 'text_predictions.{}.{}.txt'.format(model_params_as_string_for_file_name, now_time_as_string_for_file_name)
    file_names_to_check.append(text_predictions_file_name)
    for key in uploaded.keys():
        for seq_record in SeqIO.parse(StringIO(BytesIO(uploaded[key]).read().decode('UTF-8')), 'fasta'):
            seqs = []
            seqs.append('normal')
            if check_sequence_variations:
                seqs.append('reverse-complement')

            seq_record_key = seq_record_file_key(key, seq_record.name)

            for seq_name in seqs:
                seq_key = '{}.{}'.format(seq_record_key, seq_name)

                pkl_file_name_seq = '{}.pkl'.format(seq_key)
                file_names_to_check.append(pkl_file_name_seq)
                bed_file_name_seq = '{}.bed'.format(seq_key)
                file_names_to_check.append(bed_file_name_seq)
            bed_file_name = '{}.bed'.format(seq_record_key)
            file_names_to_check.append(bed_file_name)

    file_names_too_long = [] # array of tuples: (file_name: String, file_name_length: Int)
    for file_name_to_check in file_names_to_check:
        file_name_to_check_length = len(file_name_to_check)
        file_name_too_long = file_name_to_check_length > file_name_maximum_length
        if file_name_too_long:
            file_names_too_long.append((file_name_to_check, file_name_to_check_length))

    if len(file_names_too_long) > 0:
        print('\n\nThe length of some file names that are going to be generated exceeds the maximum file name length of {} characters.\nPlease reduce the lengths of the inputs accordingly.'.format(file_name_maximum_length))
        for file_name, file_name_length in file_names_too_long:
            print('\nname:\t{}\nlength:\t{}\nover:\t{}'.format(file_name, file_name_length, file_name_length - file_name_maximum_length))
        raise ValueError('At least one of the generated file names will exceed the maximum file name length of {} characters.'.format(file_name_maximum_length))
    # </file name length checks>

    # Make sure the result directory exists before spending time on predictions.
    os.makedirs(output_path, exist_ok=True)

    for key in uploaded.keys():
        print(key)
        out.append(key)
        # Scores of all sequences of this upload, keyed by their file name stem.
        result_dict = {}
        for seq_record in SeqIO.parse(StringIO(BytesIO(uploaded[key]).read().decode('UTF-8')), 'fasta'):
            seqs = [] # array of tuples: (variation name: String, sequence: String, reversed: Bool)
            seq_uppered = str(seq_record.seq).upper()
            seqs.append(('normal', seq_uppered, False))
            if check_sequence_variations:
                seq_uppered_complemented = complement_seq(seq_uppered)

                seq_uppered_complemented_reversed = reverse_seq(seq_uppered_complemented)
                seqs.append(('reverse-complement', seq_uppered_complemented_reversed, True))

            print(seq_record.name)
            out.append(seq_record.name)

            seq_record_key = seq_record_file_key(key, seq_record.name)

            # Lines of the combined BED file of this record (all variations).
            bed_out = []
            bed_out.append('track name="{name}" priority=1'.format(name=model_params_as_string))

            for seq_name, seq, seq_reversed in seqs:
                seq_key = '{}.{}'.format(seq_record_key, seq_name)

                seq_len = len(seq)

                # Lines of the BED file of this single variation.
                bed_out_seq = []

                kmer_seq = seq2kmer(seq, 6)
                seq_pieces = split_seq(kmer_seq)
                print(seq_name)
                out.append(seq_name)
                with torch.no_grad():
                    preds = []
                    for seq_piece in tqdm(seq_pieces):
                        input_ids = torch.LongTensor(tokenizer.encode(' '.join(seq_piece), add_special_tokens=False))
                        input_ids_unsqueezed = None
                        if is_cuda_available:
                            input_ids_unsqueezed = input_ids.cuda().unsqueeze(0)
                        else:
                            input_ids_unsqueezed = input_ids.cpu().unsqueeze(0)
                        # Softmax over the last axis, then the value at index 1 for every token.
                        outputs = torch.softmax(model(input_ids_unsqueezed)[-1],axis = -1)[0,:,1]
                        preds.append(outputs.cpu().numpy())
                result_dict[seq_key] = stitch_np_seq(preds)



                # Number the runs of consecutive positions above the threshold.
                labeled, max_label = scipy.ndimage.label(result_dict[seq_key]>model_confidence_threshold)
                print('  start     end')
                out.append('  start     end')

                bed_out.append('track name="{name}" priority=2'.format(name=seq_name))

                label_id = 1
                for label in range(1, max_label+1):
                    candidate = np.where(labeled == label)[0]
                    candidate_length = candidate.shape[0]
                    if candidate_length>minimum_sequence_length:
                        print('{:8}'.format(candidate[0]), '{:8}'.format(candidate[-1]))
                        out.append('{:8}{:8}'.format(candidate[0], candidate[-1]))

                        # start has to be subtracted by 1 for bed, see https://grch37.ensembl.org/info/website/upload/bed.html
                        # Clamped at 0: a run starting at index 0 would otherwise get start -1,
                        # which is not a valid BED coordinate.
                        # NOTE(review): `candidate` holds 0-based indices from np.where, and the
                        # reverse branch below maps them as 0-based - confirm whether the forward
                        # interval should rather be (candidate[0], candidate[-1] + 1).
                        candidate_start = max(candidate[0] - 1, 0)
                        candidate_end = candidate[-1]
                        if seq_reversed:
                            # Mirror the interval back onto the coordinates of the original strand.
                            candidate_start = (seq_len - candidate[-1]) - 1
                            candidate_end = seq_len - candidate[0]

                        bed_name = '{},{},{}'.format(model_params_as_string, seq_name, label_id)
                        bed_out.append('0\t{start}\t{end}\t{name}'.format(start=candidate_start, end=candidate_end, name=bed_name))

                        bed_out_seq.append('0\t{start}\t{end}\t{name}'.format(start=candidate_start, end=candidate_end, name=bed_name))

                        label_id += 1

                pkl_file_name_seq = '{}.pkl'.format(seq_key)
                with open(os.path.join(output_path, pkl_file_name_seq),"wb") as fh:
                    # NOTE(review): this dumps the whole `result_dict`, i.e. the scores of all
                    # sequences of this upload processed so far, not only this sequence's.
                    pickle.dump(result_dict, fh)
                print()

                bed_file_name_seq = '{}.bed'.format(seq_key)
                print(bed_file_name_seq)
                with open(os.path.join(output_path, bed_file_name_seq),"w") as fh:
                    for item in bed_out_seq:
                        fh.write("%s\n" % item)

            bed_file_name = '{}.bed'.format(seq_record_key)
            print(bed_file_name)
            with open(os.path.join(output_path, bed_file_name),"w") as fh:
                for item in bed_out:
                    fh.write("%s\n" % item)

    with open(os.path.join(output_path, text_predictions_file_name),"w") as fh:
        for item in out:
            fh.write("%s\n" % item)


In [None]:
# Button that runs do_predictions() when it is clicked.
do_predictions_button = widgets.Button(
    description='Run prediction',
    icon='chart-line',
)
do_predictions_button.on_click(do_predictions)

In [None]:
# Create the result directory. exist_ok keeps "run all cells" quiet when the directory
# is already there (`!mkdir` printed an error in that case).
os.makedirs(output_path, exist_ok=True)

# Run

Start predicting features of fasta file inputs in 4 steps.

## Usage

### Prepare

The environment only needs to be prepared once each time JupyterLab is started or the notebook is freshly opened.

### Select model and parameters

After changing the model or the parameters, press the "Load model" button.

This will create the required directories, download the required files and copy the model file into the relevant directory. Files that have already been downloaded will not be downloaded again.

### Run

After the predictions have been made, new files will be created in the directory `output`.

The following types of files will be created:

- `.txt` files contain the same text that is shown in the prediction output of the notebook, covering all input files
- `.pkl` files contain the model scores the features were derived from, stored as a Python pickle
- Several different `.bed` files containing the found features are created for each input file, depending on the selected sequence variations

  They can be used to import found features into other software.
  
  - `.normal.bed` contains features found for the original input fasta file
  - `.reverse-complement.bed` contains features found for the reverse-complement
  - `.bed` contains the features found for the normal sequence and for the reverse-complement, combined in one file


## 1 Prepare

<button data-commandLinker-command="notebook:run-all-cells" class="lm-Widget jupyter-widgets jupyter-button">Prepare environment</button>

## 2 Select model and parameters

In [None]:
# Show the model and parameter inputs, one below the other.
display(
    model_name_widget,
    model_confidence_threshold_widget,
    minimum_sequence_length_widget,
    check_sequence_variations_widget,
)

In [None]:
# Show the button that triggers load_model().
display(load_model_button)

In [None]:
# Show the area that everything printed by load_model() appears in.
display(load_model_output)

## 3 Upload fasta files

Multiple fasta files may be selected:

In [None]:
# Show the upload control for the FASTA input files.
display(fasta_upload_widget)

## 4 Run

In [None]:
# Show the button that triggers do_predictions().
display(do_predictions_button)

## Prediction output

In [None]:
%%time

# Show the area that everything printed by do_predictions() appears in.
# Note: `%%time` only measures this display call; the predictions themselves run
# later, in the click handler of the "Run prediction" button.
display(do_predictions_output)