[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mitiau/DNABERT-Z/blob/main/ZDNA-prediction.ipynb)

# Install dependencies and define helper functions

In [None]:
!pip3 install transformers
!pip3 install biopython
!pip3 install torch
!pip3 install numpy
!pip3 install scipy
!pip3 install tqdm

In [None]:
import torch
from torch import nn
import transformers
from transformers import BertTokenizer, BertForTokenClassification
import numpy as np
from Bio import SeqIO
from io import StringIO, BytesIO
from tqdm import tqdm
import pickle
import scipy
from scipy import ndimage
import os
from ipywidgets import widgets
import subprocess

In [None]:
def seq2kmer(seq, k):
    """Return all overlapping length-*k* windows of *seq* (stride 1).

    The result has ``len(seq) - k + 1`` entries and is empty when *seq*
    is shorter than *k*.
    """
    n_kmers = len(seq) - k + 1
    return [seq[start:start + k] for start in range(n_kmers)]

def split_seq(seq, length = 512, pad = 16):
    """Cut *seq* into windows of at most *length* items that overlap by *pad*.

    Consecutive windows start ``length - pad`` items apart, so each window
    repeats the last *pad* items of the previous one; ``stitch_np_seq``
    trims that overlap when re-joining per-window results.

    Splitting stops at the first window that reaches the end of *seq*.
    Emitting a further window that lies entirely inside the previous
    overlap would make the stitched result shorter than *seq*.

    Args:
        seq: any sliceable sequence (str, list, ...).
        length: maximum window size.
        pad: number of items shared by consecutive windows.

    Returns:
        A list of slices of *seq*; empty when *seq* is empty.

    Raises:
        ValueError: if ``pad >= length`` (windows would never advance).
    """
    step = length - pad
    if step <= 0:
        raise ValueError('pad ({}) must be smaller than length ({})'.format(pad, length))
    res = []
    for st in range(0, len(seq), step):
        # Use the requested window size rather than a hard-coded 512.
        end = min(st + length, len(seq))
        res.append(seq[st:end])
        if end == len(seq):
            break
    return res

def stitch_np_seq(np_seqs, pad = 16):
    """Join per-window arrays produced from ``split_seq`` windows into one array.

    Before each window is appended, the last *pad* entries of the result so
    far are dropped, because they are covered again by the start of the
    next window.

    Args:
        np_seqs: iterable of 1-D numpy arrays, one per window, in order.
        pad: overlap between consecutive windows.

    Returns:
        A 1-D numpy array (float64 when starting from the empty accumulator).
    """
    res = np.array([])
    for seq in np_seqs:
        # Use an explicit non-negative stop: ``res[:-pad]`` with pad == 0 is
        # ``res[:0]`` and would throw away everything stitched so far.
        res = res[:max(len(res) - pad, 0)]
        res = np.concatenate([res, seq])
    return res

In [None]:
# Watson-Crick partner of each DNA base; anything else is left unchanged
# by complement_nucleobase.
base_pair_opposite_map = dict(A='T', T='A', C='G', G='C')

In [None]:
def complement_nucleobase(nucleobase):
    """Return the complementary base, or *nucleobase* itself if it has none
    in ``base_pair_opposite_map`` (e.g. 'N' or lowercase letters)."""
    return base_pair_opposite_map.get(nucleobase, nucleobase)

def complement_seq(seq):
    """Return the base-by-base complement of *seq* as a string (not reversed)."""
    return ''.join(map(complement_nucleobase, seq))

def reverse_seq(seq):
    """Return *seq* in reverse order, keeping its type (str stays str)."""
    return seq[slice(None, None, -1)]

In [None]:
# Quick sanity check of the sequence helpers on a tiny example.
test_seq = 'ACGTA'
for _label, _shown in (
    ('seq\t\t\t', test_seq),
    ('complement\t\t', complement_seq(test_seq)),
    ('reverse\t\t\t', reverse_seq(test_seq)),
    ('reverse-complement\t', reverse_seq(complement_seq(test_seq))),
):
    print(_label, _shown)

In [None]:
def gdown_wrapper(gdrive_id, file_path):
    """Download a Google Drive file to *file_path* with curl, unless it exists.

    The data is first written to ``file_path + '.part'`` and moved into place
    only when curl exits successfully. A failed or interrupted download
    therefore never leaves a file at *file_path* that a later call would
    mistake for a completed download.

    Args:
        gdrive_id: Google Drive file id.
        file_path: destination path.

    Raises:
        RuntimeError: if curl exits with a non-zero status (including HTTP
            errors, via ``--fail``).
    """
    if os.path.exists(file_path):
        print(file_path, gdrive_id, 'already exists')
        return

    # curl "https://drive.google.com/uc?id=${id}&export=download&confirm=ABCD" --verbose -L -o
    gdrive_url = 'https://drive.google.com/uc?id={id}&export=download&confirm=ABCD'.format(id=gdrive_id)

    print(gdrive_url, file_path)

    partial_path = file_path + '.part'
    # Argument list, no shell: paths and ids are passed verbatim (no quoting issues).
    # --fail makes curl exit non-zero on HTTP errors instead of saving the error page.
    # NOTE(review): a 200 response carrying an HTML interstitial page is not detected here.
    result = subprocess.run(
        ['curl', '-L', '--fail', '--silent', '--show-error', '-o', partial_path, gdrive_url],
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    if result.returncode != 0:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise RuntimeError('download of {} failed (curl exit code {}): {}'.format(
            gdrive_url, result.returncode, result.stderr.strip()))
    os.replace(partial_path, file_path)
    print('saved', file_path)

In [None]:
def hash_file(file_path):
    """Return the SHA-1 hex digest of *file_path* as ASCII bytes.

    The value matches the first 40 bytes of ``shasum`` output. It is
    computed in-process, so no external ``shasum`` binary is required.

    Returns:
        ``b''`` when the file cannot be read (e.g. it does not exist yet),
        so comparing two results still works as a "files differ" check.
    """
    import hashlib

    digest = hashlib.sha1()
    try:
        with open(file_path, 'rb') as fh:
            # Stream in 1 MiB chunks so large model files are not read into memory at once.
            for chunk in iter(lambda: fh.read(1 << 20), b''):
                digest.update(chunk)
    except OSError:
        return b''
    return digest.hexdigest().encode('ascii')

In [None]:
# Selectable models: display name -> (Google Drive file id, file name that
# load_model joins with model_data_path).
models = {
    name: (gdrive_id, file_name)
    for name, gdrive_id, file_name in (
        ('HG kouzine', '1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx', 'hg_kouzine.pytorch_model.bin'),
        ('HG chipseq', '1VAsp8I904y_J0PUhAQqpSlCn1IqfG0FB', 'hg_chipseq.pytorch_model.bin'),
        ('MM curax', '1W6GEgHNoitlB-xXJbLJ_jDW4BF35W1Sd', 'mm_curax.pytorch_model.bin'),
        ('MM kouzine', '1dXpQFmheClKXIEoqcZ7kgCwx6hzVCv3H', 'mm_kouzine.pytorch_model.bin'),
    )
}

In [None]:
# Directory that load_model fills with the config/tokenizer files and the
# active weights, and passes to from_pretrained().
data_path = './6-new-12w-0'
# Download cache for the weights of every selectable model.
model_data_path = './pytorch_models'
# Destination of the .pkl / .bed / text prediction files.
output_path = 'output'
# load_model copies the selected weights to this path inside data_path.
target_model_data_path = os.path.join(data_path, 'pytorch_model.bin')

In [None]:
# Config/tokenizer files downloaded into data_path, as (Google Drive id, local path).
meta_files = [
    (gdrive_id, os.path.join(data_path, file_name))
    for gdrive_id, file_name in (
        ('10sF8Ywktd96HqAL0CwvlZZUUGj05CGk5', 'config.json'),
        ('16bT7HDv71aRwyh3gBUbKwign1mtyLD2d', 'special_tokens_map.json'),
        ('1EE9goZ2JRSD8UTx501q71lGCk-CK3kqG', 'tokenizer_config.json'),
        ('1gZZdtAoDnDiLQqjQfGyuwt268Pe5sXW0', 'vocab.txt'),
    )
]

In [None]:
# Parameter widgets; load_model copies their values into globals.
# 'initial' description width keeps ipywidgets from truncating the long labels.
_wide_label = {'description_width': 'initial'}

model_name_widget = widgets.Dropdown(
    options=models.keys(),
    value=next(iter(models.keys())),
    description='model:',
    disabled=False,
    style=_wide_label,
)

# The threshold is compared against softmax probabilities, so only [0, 1]
# is meaningful; out-of-range values would silently yield no predictions.
model_confidence_threshold_widget = widgets.BoundedFloatText(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='model confidence threshold',
    style=_wide_label,
)

# Negative lengths are meaningless. BoundedIntText's default max (100) is
# too small for a length filter, so it is raised explicitly.
minimum_sequence_length_widget = widgets.BoundedIntText(
    value=10,
    min=0,
    max=10**9,
    description='minimum sequence length:',
    style=_wide_label,
)

check_sequence_variations_widget = widgets.Checkbox(
    value=True,
    description='check reverse complement sequence variations',
    style=_wide_label,
)

In [None]:
# File picker for the FASTA inputs; several files may be chosen at once.
fasta_upload_widget = widgets.FileUpload(multiple=True)

In [None]:
# Output area that captures (and clears between clicks) everything load_model prints.
load_model_output = widgets.Output()

In [None]:
@load_model_output.capture(clear_output=True)
def load_model(btn):
    """Download the selected model and its tokenizer/config files, then load it.

    Click callback for ``load_model_button``; *btn* is the clicked button and
    is not used. The function:

    * copies the parameter widget values into globals read by the
      prediction cell;
    * downloads the model weights into ``model_data_path`` and the
      tokenizer/config files into ``data_path``;
    * copies the weights to ``target_model_data_path`` when the contents
      differ;
    * loads ``tokenizer`` and ``model`` from ``data_path`` as globals,
      moving the model to the GPU when CUDA is available.
    """
    # The prediction cell reads tokenizer, model and is_cuda_available as
    # module globals; without this declaration they would be locals that
    # vanish when the callback returns (NameError in the prediction cell).
    global model_name, model_confidence_threshold, minimum_sequence_length, check_sequence_variations, model_file_path
    global tokenizer, model, is_cuda_available
    import shutil

    model_name = model_name_widget.value
    model_confidence_threshold = model_confidence_threshold_widget.value
    minimum_sequence_length = minimum_sequence_length_widget.value
    check_sequence_variations = check_sequence_variations_widget.value

    model_gdrive_id, model_file_name = models[model_name]

    model_file_path = os.path.join(model_data_path, model_file_name)


    print('downloading model data to input directory\n')

    # exist_ok so that clicking the button again does not fail on existing directories.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(model_data_path, exist_ok=True)

    gdown_wrapper(model_gdrive_id, model_file_path)

    for meta_file_gdrive_id, meta_file_file_path in meta_files:
        gdown_wrapper(meta_file_gdrive_id, meta_file_file_path)


    print('\n\ncopying model file to input directory\n')

    # The selected weights are placed at data_path/pytorch_model.bin,
    # presumably the file from_pretrained(data_path) loads; skip the copy
    # when an identical file is already there.
    hash1 = hash_file(model_file_path)
    hash2 = hash_file(target_model_data_path)
    print(hash1, hash2)
    if hash1 != hash2:
        shutil.copyfile(model_file_path, target_model_data_path)


    print('\n\nloading model\n')

    tokenizer = BertTokenizer.from_pretrained(data_path)
    model = BertForTokenClassification.from_pretrained(data_path)
    is_cuda_available = torch.cuda.is_available()
    print('cuda is', 'available' if is_cuda_available else 'not available')
    if is_cuda_available:
        model.cuda()
    else:
        model.cpu()

In [None]:
# Button that runs load_model with the current widget settings.
load_model_button = widgets.Button(
    icon='truck-loading',  # FontAwesome icon name, without the `fa-` prefix
    description='Load model',
)
load_model_button.on_click(load_model)

In [None]:
# Create the directory for prediction files; exist_ok keeps re-runs from failing.
os.makedirs(output_path, exist_ok=True)

# Run

<button data-commandLinker-command="notebook:run-all-above" class="lm-Widget jupyter-widgets jupyter-button">Prepare environment</button>

## Select model and parameters

In [None]:
# Show the model/parameter controls in order.
for _param_widget in (
    model_name_widget,
    model_confidence_threshold_widget,
    minimum_sequence_length_widget,
    check_sequence_variations_widget,
):
    display(_param_widget)

In [None]:
# Clicking this button runs load_model with the values selected above.
display(load_model_button)

In [None]:
# Messages printed by load_model (downloads, hashes, CUDA status) appear here.
display(load_model_output)

## Upload FASTA files for prediction

Multiple files may be selected:

In [None]:
# File picker for the FASTA files to run predictions on.
display(fasta_upload_widget)

<button data-commandLinker-command="notebook:run-all-below" class="lm-Widget jupyter-widgets jupyter-button">Run prediction</button>

## Prediction output

In [None]:
# Collect the uploaded FASTA files as {file name: raw bytes}.
# ipywidgets 8 exposes FileUpload.value as a sequence of dicts with 'name' and
# 'content' (a memoryview); ipywidgets 7 exposes a dict keyed by file name whose
# values hold 'content'. Both are accepted, and content is normalized to bytes.
_upload_value = fasta_upload_widget.value
if isinstance(_upload_value, dict):
    uploaded = {name: bytes(item['content']) for name, item in _upload_value.items()}
else:
    uploaded = {item['name']: bytes(item['content']) for item in _upload_value}
for fn, content in uploaded.items():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(content)))

In [None]:
%%time

out = []


def _emit(line):
    """Print *line* and record it for text_predictions.txt."""
    print(line)
    out.append(line)


def _predict_kmer_scores(seq):
    """Return the model's label-1 softmax probability for every 6-mer of *seq*.

    *seq* is turned into 6-mers (index i = 6-mer starting at base i) and cut
    into overlapping windows by split_seq. Each window is scored with the
    globally loaded tokenizer/model, and the window scores are joined into
    one 1-D array by stitch_np_seq.
    """
    # Send inputs to whichever device the model actually lives on.
    device = next(model.parameters()).device
    preds = []
    with torch.no_grad():
        for seq_piece in tqdm(split_seq(seq2kmer(seq, 6))):
            token_ids = tokenizer.encode(' '.join(seq_piece), add_special_tokens=False)
            input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
            # The last model output is used as the logits (TODO confirm for
            # the installed transformers version).
            logits = model(input_ids)[-1]
            preds.append(torch.softmax(logits, dim=-1)[0, :, 1].cpu().numpy())
    return stitch_np_seq(preds)


def _find_regions(scores):
    """Return (start, end) 0-based inclusive index pairs of the runs where
    scores > model_confidence_threshold and the run is longer than
    minimum_sequence_length, in increasing order."""
    labeled, _ = scipy.ndimage.label(scores > model_confidence_threshold)
    regions = []
    # find_objects yields each label's bounding slice in a single pass; in 1-D
    # a connected component is one contiguous run, so the slice is exactly it.
    for (run,) in scipy.ndimage.find_objects(labeled):
        # NOTE(review): strictly '>' although the setting is called a minimum.
        if run.stop - run.start > minimum_sequence_length:
            regions.append((run.start, run.stop - 1))
    return regions


_emit('model_name: {}'.format(model_name))
_emit('model_confidence: {}'.format(model_confidence_threshold))
_emit('minimum_sequence_length: {}'.format(minimum_sequence_length))

for key in uploaded:
    _emit(key)
    result_dict = {}
    # bytes() also accepts memoryview upload contents.
    fasta_text = bytes(uploaded[key]).decode('UTF-8')
    for seq_record in SeqIO.parse(StringIO(fasta_text), 'fasta'):
        seq_uppered = str(seq_record.seq).upper()
        seqs = [('normal', seq_uppered)]
        if check_sequence_variations:
            seqs.append(('reverse-complement', reverse_seq(complement_seq(seq_uppered))))

        _emit(seq_record.name)

        for seq_name, seq in seqs:
            seq_key = '{}.{}.{}'.format(key, seq_record.name, seq_name)
            _emit(seq_name)

            scores = _predict_kmer_scores(seq)
            result_dict[seq_key] = scores

            _emit('  start     end')
            bed_out = []
            for start, end in _find_regions(scores):
                # Same layout as the printed table; the separating space keeps
                # 8+ digit coordinates from running together in the text file.
                _emit('{:8} {:8}'.format(start, end))
                # BED is 0-based and half-open, with the sequence name as chrom:
                # the 0-based start is used as-is (start - 1 could be -1) and the
                # inclusive end is made exclusive. Indices are 6-mer positions;
                # for 'reverse-complement' they index the reversed sequence.
                bed_out.append('{}\t{}\t{}'.format(seq_record.name, start, end + 1))

            # One pickle per sequence variant, holding only that variant's scores.
            with open(os.path.join(output_path, '{}.preds.pkl'.format(seq_key)), 'wb') as fh:
                pickle.dump({seq_key: scores}, fh)
            print()

            bed_file_name = '{}.bed'.format(seq_key)
            with open(os.path.join(output_path, bed_file_name), 'w') as fh:
                fh.writelines('{}\n'.format(item) for item in bed_out)

with open(os.path.join(output_path, 'text_predictions.txt'), 'w') as fh:
    fh.writelines('{}\n'.format(item) for item in out)
