# DNABERT Encoder Training
This notebook trains a BERT encoder on sequences of DNA nucleotides. It is an implementation of the DNABERT method; see https://www.biorxiv.org/content/biorxiv/early/2020/09/19/2020.09.17.301879.full.pdf

At a high-level, the method works as follows:
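* Each nucleotide sequence is turned into a "sentence" of k-mer words of length `word_length_vocab`, taken every `stride` residues (overlapping when `stride` < `word_length_vocab`).
* A vocabulary of k-mers over the nucleotide alphabet is created, and a BERT tokenizer maps each word to a token id.
* The tokenized sentences are concatenated and cut into fixed-length training examples of `chunk_size` tokens.
* Words are masked at random (in spans of at least `word_length_vocab`), and a BERT masked-language model is trained to predict the masked tokens.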

### Basics of Setup
The GitHub package is installed, followed by a simple test run on a toy example. If this test fails, the package did not import correctly or is not set up right.

In [None]:
# Install (or upgrade to the latest commit of) the nucleotide_transformer package from GitHub.
!pip install git+https://github.com/anderzzz/nucleotide_transformer.git --upgrade

In [None]:
import importlib

import biosequences

# Submodules of the installed package that the rest of this notebook imports from.
SUBMODULE_LIST = ['datacollators', 'io', 'utils']
# import_module raises ImportError (naming the failing submodule) if one is missing; a successful
# import also registers it as an attribute of the `biosequences` package, which the assert checks.
for submodule in SUBMODULE_LIST:
    importlib.import_module('biosequences.{}'.format(submodule))
assert all(submodule in dir(biosequences) for submodule in SUBMODULE_LIST)

In [7]:
import tempfile
from pathlib import Path

import torch
from transformers import BertTokenizer, BertTokenizerFast
from biosequences.utils import dna_nucleotide_alphabet, NucleotideVocabCreator, Phrasifier

# Toy inputs and their expected token ids; ids 3 and 1 frame each sentence
# (presumably [CLS]/[SEP]), the ids in between are the stride-1 3-mers.
SEQ_STR = 'AATGCGT'
IDS_SEQ_STR = [3,8,19,62,43,32,1]
SEQ_BATCH = ['AATGCGT', 'GGGGT']
IDS_SEQ_BATCH = [[3,8,19,62,43,32,1], [3,47,47,48,1]]

dna_vocab = NucleotideVocabCreator(dna_nucleotide_alphabet, do_lower_case=True).generate(3)
phrasifier = Phrasifier(stride=1, word_length=3)

# Write the toy vocabulary into a throw-away directory rather than leaving tmp_test.txt behind
# in the working directory; BertTokenizer loads the vocabulary into memory on construction,
# so the file can be removed afterwards.
with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = Path(tmp_dir) / 'tmp_test.txt'
    with open(vocab_path, 'w') as fout:
        dna_vocab.save(fout)
    tokenizer = BertTokenizer(vocab_file=str(vocab_path), tokenize_chinese_chars=False)

out = tokenizer(phrasifier(SEQ_STR))
assert out.input_ids == IDS_SEQ_STR

# Batched tokenization (no padding) must reproduce the per-sequence ids.
out_batch = tokenizer([phrasifier(seq) for seq in SEQ_BATCH])
assert out_batch.input_ids == IDS_SEQ_BATCH

### Runtime Parameters
The next section defines all runtime parameters that determine what encoder is trained on what data and how.

In [None]:
from typing import Dict, Optional
from dataclasses import dataclass, field
@dataclass
class Arguments2DNABertTraining:
    '''Runtime arguments for training of DNABert Encoder
    
    Args:
        folder_seq_raw (str): From where to read the raw data of nucleotide sequence chunks. If `None` the
            assumption is the processed nucleotide data is available in `folder_seq_sentence` already.
        seq_raw_format (str): File format of raw sequence chunk data. Currently CSV, GenBank and Fasta are 
            possible. 
        seq_raw_file_pattern (str): The query that returns all relevant files in the `folder_seq_raw`. In case
            there is only one file, set this to that filename.
        upper_lower (str): Whether nucleotide characters are all upper case ('upper') or all lower case
            ('lower'). This has to be used consistently throughout, both in the processing of raw data and
            in the tokenization.
        folder_seq_sentence (str): Folder for the nucleotide sequence sentence files; this is where the
            raw sequence processing outputs its data files and where the dataset creator later reads from.
        seq_sentence_prefix (str): File prefix to use for the plurality of sequence sentence files.
        word_length_vocab (int): The number of nucleotide residues that comprise a word.
        stride (int): The stride to use as a nucleotide sequence is processed into a nucleotide sentence.
        split_ratio_test (float): The ratio of data to turn into testing data.
        split_ratio_validate (float): The ratio of data to turn into validation data.
        shuffle (bool): If the data should be shuffled.
        seed (int): Random seed for data shuffling.
        vocab_file (str): Name of vocabulary file.
        create_vocab (bool): If the vocabulary file should be created; if `False` the vocabulary file is 
            assumed to already be in `folder_seq_sentence`.
        chunk_size (int): Number of tokens per training example; the tokenized sentences are concatenated
            and cut into pieces of this length (a trailing remainder shorter than this is dropped).
        masking_probability (float): The average ratio of masked words in the data; note that the masking is 
            done in chunks of at least length `word_length_vocab`.
        bert_config_kwargs (dict): Keyword argument dictionary for the configuration of the BERT model, see 
            Hugging Face `BertConfig`.
        folder_training_input (str): The folder where a PyTorch variant of the Bert model and its parameters
            is stored and to be used as starting point; if `None`, the initial parameters are randomly
            initialized; typically this folder is the output of a previous training `folder_training_output`.
        folder_training_output (str): The folder where a PyTorch variant of the Bert model and its parameters
            is stored during and after training.
        training_kwargs (dict): Keyword argument dictionary for the training, other than the `output_dir`, see
            Hugging Face `TrainingArguments`.

    '''
    folder_seq_raw : Optional[str] = None
    seq_raw_format : str = 'csv'
    seq_raw_file_pattern : str = '*.csv'
    upper_lower : str = 'upper'
    folder_seq_sentence : Optional[str] = None
    seq_sentence_prefix : str = ''
    word_length_vocab : int = 3
    stride : int = 1
    split_ratio_test : float = 0.05
    split_ratio_validate : float = 0.05
    shuffle : bool = True
    seed : int = 42
    vocab_file : str = 'vocab.txt'
    create_vocab : bool = True
    chunk_size : int = 1000
    masking_probability : float = 0.15
    # Mutable defaults need default_factory: a plain `= {}` is rejected by @dataclass
    # (ValueError) and would otherwise be shared between instances.
    bert_config_kwargs : Dict = field(default_factory=dict)
    folder_training_input : Optional[str] = None
    folder_training_output : Optional[str] = None
    training_kwargs : Dict = field(default_factory=dict)

Define the runtime arguments in the instance of `Arguments2DNABertTraining` below.

In [None]:
# Override defaults via keyword arguments, e.g. Arguments2DNABertTraining(folder_seq_sentence='...').
# `folder_seq_sentence` must be set: the next cell raises ValueError when it is None.
args = Arguments2DNABertTraining()

In [None]:
# Validate the runtime arguments and derive settings used by later cells.
if args.folder_seq_sentence is None:
    raise ValueError('The folder for sequence sentence files required')

# Case handling must agree between raw-data processing (Phrasifier), vocabulary creation and
# tokenization, so both flags are derived from the single `upper_lower` setting.
if args.upper_lower == 'upper':
    do_upper_case = True
    do_lower_case = False
elif args.upper_lower == 'lower':
    do_upper_case = False
    do_lower_case = True
else:
    raise ValueError('The vocabulary is either all upper or all lower, but `upper_lower` of invalid value: {}'.format(args.upper_lower))

# The training output folder defaults to the sentence folder when not given explicitly.
if args.folder_training_output is None:
    folder_training_output_ = args.folder_seq_sentence
else:
    folder_training_output_ = args.folder_training_output

### Imports and Helpers
Before data processing and training start, make a few imports and define helper functions for downstream use.

In [9]:
from pathlib import Path
import random
# Seed Python's global RNG; it drives the shuffling of the train/test/validate file split.
random.seed(args.seed)

from transformers import BertForMaskedLM, BertConfig
from transformers import BertTokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

from biosequences.io import NucleotideSequenceProcessor
from biosequences.utils import NucleotideVocabCreator, dna_nucleotide_alphabet, Phrasifier
from biosequences.datacollators import DataCollatorDNAWithMasking

In [None]:
def _sequence_grouper(seqs, chunk_size):
    concat_seq = {k : sum(seqs[k], []) for k in seqs.keys()}
    total_length = len(concat_seq[list(seqs.keys())[0]])
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k : [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concat_seq.items()
    }
    result['labels'] = result['input_ids'].copy()
    return result

In [None]:
def _compute_metrics(eval_pred):
    '''Custom metrics for evaluation step are done here.

    Args:
        eval_pred :

    Returns:
        custom_metrics (dict): Keys are the name of the custom metric, value the numberic value of said metric

    '''
    logits, labels = eval_pred
    pass
    return {}

### Process Raw Data, Chunk, Split and Tokenize
The following steps prepare the data for training.

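If `folder_seq_raw` is set, the raw sequence files are first converted into JSON sentence files in `folder_seq_sentence`. The k-mer vocabulary is then created (or an existing one is used), the sentence files are split into train, test and validation sets, and the sentences are tokenized and regrouped into examples of `chunk_size` tokens.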

In [10]:
# Only convert raw sequence files when a raw-data folder is given; otherwise the JSON sentence
# files are expected to already exist in `folder_seq_sentence`.
if args.folder_seq_raw is not None:
    dataprocessor = NucleotideSequenceProcessor(
        source_directory=args.folder_seq_raw,
        source_file_format=args.seq_raw_format,
        source_directory_file_pattern=args.seq_raw_file_pattern,
    )
    # Splits each sequence into words of `word_length_vocab` residues, `stride` apart,
    # with the same case convention as the vocabulary.
    phrasifier = Phrasifier(
        stride=args.stride,
        word_length=args.word_length_vocab,
        do_upper_case=do_upper_case,
        do_lower_case=do_lower_case,
    )
    dataprocessor.save_as_json(
        save_dir=args.folder_seq_sentence,
        save_prefix=args.seq_sentence_prefix,
        seq_transformer=phrasifier,
    )

['DataCollatorDNAWithMasking',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'datacollator_dnabert']

In [None]:
# Path of the k-mer vocabulary file, shared by the writer below and the tokenizer.
vocab_path = '{}/{}'.format(args.folder_seq_sentence, args.vocab_file)

if args.create_vocab:
    vocab_creator = NucleotideVocabCreator(alphabet=dna_nucleotide_alphabet,
                                           do_lower_case=do_lower_case,
                                           do_upper_case=do_upper_case)
    dna_vocab = vocab_creator.generate(args.word_length_vocab)
    with open(vocab_path, 'w') as fout:
        dna_vocab.save(fout)

tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=do_lower_case)

In [None]:
# Collect the sentence files in sorted order: Path.glob order depends on the filesystem, and a
# fixed starting order is needed for the seeded shuffle to give a reproducible split.
json_files = sorted(
    '{}'.format(x.resolve())
    for x in Path(args.folder_seq_sentence).glob('{}*.json'.format(args.seq_sentence_prefix))
)
n_files = len(json_files)
if n_files == 0:
    raise ValueError('No sequence sentence files matching {}*.json found in {}'.format(
        args.seq_sentence_prefix, args.folder_seq_sentence))

len_train = round(n_files * (1.0 - args.split_ratio_test - args.split_ratio_validate))
if len_train <= 0:
    raise ValueError('Split ratios for test and validate leave no files for training '
                     '({} sentence files in total)'.format(n_files))
# Clamp so that rounding both counts up never assigns more files than exist.
len_test = min(round(n_files * args.split_ratio_test), n_files - len_train)

pp = list(range(n_files))
if args.shuffle:
    random.shuffle(pp)
json_files_split = {'train' : [json_files[k] for k in pp[:len_train]]}
if len_test > 0:
    json_files_split['test'] = [json_files[k] for k in pp[len_train:len_train + len_test]]
# Whatever is left after train and test becomes the validation split.
if n_files - len_train - len_test > 0:
    json_files_split['validate'] = [json_files[k] for k in pp[len_train + len_test:]]
json_files = json_files_split
seq_dataset = load_dataset('json', data_files=json_files)

In [None]:
# Tokenize the phrasified sentences and drop every original JSON column (e.g. 'seq', 'id', 'name',
# 'description'). `_sequence_grouper` concatenates every remaining column as token lists, so no raw
# column may survive. Reading the names from the data avoids failing on a hard-coded column that a
# given set of JSON files does not contain.
tokenized_dataset = seq_dataset.map(
    lambda batch: tokenizer(batch['seq']),
    batched=True,
    remove_columns=seq_dataset['train'].column_names,
)

In [None]:
# Concatenate the tokenized sentences and cut them into examples of `args.chunk_size` tokens,
# adding the 'labels' column used as masked-LM targets.
lm_dataset = tokenized_dataset.map(
    _sequence_grouper,
    batched=True,
    fn_kwargs={'chunk_size' : args.chunk_size}
)

In [None]:
# Summarize the splits and their sizes after grouping.
print(lm_dataset)

### Configure Data Collation, Model and Trainer

In [None]:
# Collator that masks tokens with probability `masking_probability`; `word_length` is passed so
# masking can follow the k-mer word length (see `masking_probability` in the arguments).
data_collator = DataCollatorDNAWithMasking(
    tokenizer=tokenizer,
    word_length=args.word_length_vocab,
    mlm_probability=args.masking_probability,
)

In [None]:
if args.folder_training_input is not None:
    # Continue from a previously saved model and its parameters.
    model = BertForMaskedLM.from_pretrained(args.folder_training_input)
else:
    # Fresh, randomly initialized model sized to the k-mer vocabulary.
    config = BertConfig(vocab_size=tokenizer.vocab_size, **args.bert_config_kwargs)
    model = BertForMaskedLM(config=config)

In [None]:
training_args = TrainingArguments(
    output_dir=folder_training_output_,
    **args.training_kwargs
)
# The 'validate' split only exists when the split ratios leave at least one file for it.
# The dataset dict is a dict, so .get yields None (no evaluation data) instead of a KeyError.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_dataset['train'],
    eval_dataset=lm_dataset.get('validate'),
    compute_metrics=_compute_metrics
)

### Train

In [None]:
trainer.train()
# Save into the resolved output folder (defaults to `folder_seq_sentence` when
# `args.folder_training_output` is None).
trainer.save_model(output_dir=folder_training_output_)
print('It is done.')