# DNABERT Encoder Training
This script trains the BERT encoder on sequences of DNA nucleotides. It is an implementation of the DNABERT method; see https://www.biorxiv.org/content/biorxiv/early/2020/09/19/2020.09.17.301879.full.pdf

At a high-level, the method works as follows:
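* Each nucleotide sequence is split into overlapping words of fixed length, stepping along the sequence with a fixed stride, so that the sequence becomes a "sentence" of nucleotide words.
* The sentences are tokenized against a vocabulary generated from the nucleotide alphabet, then concatenated and cut into chunks of equal length.
* A fraction of the tokens is masked in contiguous runs, and a BERT encoder with a masked-language-modelling head is trained to predict the masked tokens.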

### Basics of Setup
The Python libraries are installed first, the Hugging Face libraries in particular. Then my own custom package for handling DNA sequence data (tokenization, masking, etc.) is retrieved from GitHub. The custom package is named `biosequences`. This is followed by a simple test run of a toy example using that package. If the test fails, either the import failed or the package is not set up correctly.

In [1]:
!pip install transformers
!pip install datasets
!pip install torch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstal

In [2]:
###!pip install git+https://github.com/anderzzz/nucleotide_transformer.git
!pip install git+https://github.com/anderzzz/nucleotide_transformer.git --upgrade

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/anderzzz/nucleotide_transformer.git
  Cloning https://github.com/anderzzz/nucleotide_transformer.git to /tmp/pip-req-build-2taz75l_
  Running command git clone -q https://github.com/anderzzz/nucleotide_transformer.git /tmp/pip-req-build-2taz75l_
Collecting biopython
  Downloading biopython-1.79-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (2.3 MB)
Building wheels for collected packages: Biosequences
  Building wheel for Biosequences (setup.py) ... done
  Created wheel for Biosequences: filename=Biosequences-0.1-py3-none-any.whl size=11488 sha256=9eb9bc3a028bd3034b9460e2549963e8e2c4faf741ebbc0681c24750a71ecd7d
  Stored in directory: /tmp/pip-ephem-wheel-cache-bh2prsu7/wheels/60/2f/a6/9834ce5fbef6eee368439e961c5b2c3ef579f466a576c20de7
Successfully built Biosequences
Installing 

In [3]:
import biosequences
SUBMODULE_LIST = ['datacollators', 'io', 'utils']
assert all([submodule in dir(biosequences) for submodule in SUBMODULE_LIST])

In [6]:
import torch
from transformers import BertTokenizer, BertTokenizerFast
from biosequences.utils import dna_nucleotide_alphabet, NucleotideVocabCreator, Phrasifier

SEQ_STR = 'AATGCGT'
IDS_SEQ_STR = [2,6,12,35,64,50,3]

dna_vocab = NucleotideVocabCreator(dna_nucleotide_alphabet, do_lower_case=True, do_upper_case=False).generate(3)
with open('tmp_test.txt', 'w') as fout:
    dna_vocab.save(fout)
phrasifier = Phrasifier(stride=1, word_length=3)
tokenizer = BertTokenizer(vocab_file='tmp_test.txt', tokenize_chinese_chars=False)
out = tokenizer(phrasifier(SEQ_STR))

assert out.input_ids == IDS_SEQ_STR
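
To make the toy test above easier to follow, here is a minimal sketch of what the phrasing and the vocabulary lookup appear to do. It is not the `biosequences` implementation: the helper names `kmer_sentence` and `kmer_vocab`, the alphabet order `ATCG` and the order of the special tokens are all assumptions, inferred only from the expected IDs in `IDS_SEQ_STR`. Lower-casing is skipped for brevity.

```python
from itertools import product

def kmer_sentence(seq, word_length=3, stride=1):
    '''Split a nucleotide string into overlapping words, joined by spaces.'''
    words = [seq[i : i + word_length]
             for i in range(0, len(seq) - word_length + 1, stride)]
    return ' '.join(words)

def kmer_vocab(alphabet='ATCG', word_length=3):
    '''Map special tokens and all words of `word_length` to integer IDs.

    The alphabet order and special token order are assumptions.
    '''
    special = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
    words = [''.join(p) for p in product(alphabet, repeat=word_length)]
    return {token : idx for idx, token in enumerate(special + words)}

sentence = kmer_sentence('AATGCGT')
vocab = kmer_vocab()
# BERT-style tokenization wraps the word IDs in [CLS] and [SEP]
ids = [vocab['[CLS]']] + [vocab[w] for w in sentence.split()] + [vocab['[SEP]']]
print(sentence)
print(ids)
```

Under these assumptions the sequence `AATGCGT` becomes the sentence `AAT ATG TGC GCG CGT`, and the resulting IDs match `IDS_SEQ_STR`.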

### Setup Raw Data Access
The exact approach depends on where the raw data is kept. Below, Google Drive is mounted to give access to a single CSV file with the raw data: random chunks of DNA nucleotides sampled from the human genome reference sequence GRCh38, patch release 14 (GRCh38.p14). The random chunks exclude any part containing undetermined nucleotides `N`. Details of the construction can be found [here](https://github.com/anderzzz/nucleotide_transformer/blob/main/scripts/nucleotide_sentence.py).

Useful alternatives to file access are documented here: https://neptune.ai/blog/google-colab-dealing-with-files

In [7]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [8]:
from pathlib import Path

RAWDATA_DIR = '/content/gdrive/MyDrive/DataRepo/DNABert/'
MONOLITHIC_FILE = 'grch38_p14_chunks.csv'
path = Path(RAWDATA_DIR) / MONOLITHIC_FILE
assert path.is_file()
assert path.stat().st_size > 0

### Runtime Parameters
The next section defines all runtime parameters that determine what encoder is trained on what data and how.

In [9]:
import json
from dataclasses import dataclass, field, asdict
@dataclass
class Arguments2DNABertTraining:
    '''Runtime arguments for training of DNABert Encoder
    
    Args:
        folder_seq_raw (str): From where to read the raw data of nucleotide sequence chunks. If `None`, the
            processed nucleotide data is assumed to be available in `folder_seq_sentence` already.
        seq_raw_format (str): File format of raw sequence chunk data. Currently CSV, GenBank and Fasta are
            possible.
        seq_raw_file_pattern (str): The query that returns all relevant files in `folder_seq_raw`. In case
            there is only one file, set this to that filename.
        upper_lower (str): Whether to assume nucleotide characters to be all upper case or all lower case. This
            has to be used consistently throughout, so both in the processing of raw data and the tokenization.
        folder_seq_sentence (str): Folder for the nucleotide sequence sentence files; this is where the
            raw sequence processing outputs its data files and where the dataset creator later reads from.
        seq_sentence_prefix (str): File prefix to use for the plurality of sequence sentence files.
        word_length_vocab (int): The number of nucleotide residues that comprise a word.
        stride (int): The stride to use as a nucleotide sequence is processed into a nucleotide sentence.
        split_ratio_test (float): The ratio of data to turn into testing data.
        split_ratio_validate (float): The ratio of data to turn into validation data.
        shuffle (bool): If the data should be shuffled.
        seed (int): Random seed for data shuffling.
        vocab_file (str): Name of vocabulary file.
        create_vocab (bool): If the vocabulary file should be created; if `False` the vocabulary file is
            assumed to already be in `folder_seq_sentence`.
        chunk_size (int): How many tokens to concatenate into one example (row) of the dataset.
        masking_probability (float): The average ratio of masked words in the data; note that the masking is
            done in chunks of at least length `word_length_vocab`.
        bert_config_kwargs (dict): Keyword argument dictionary for the configuration of the BERT model, see
            Hugging Face `BertConfig`.
        folder_training_input (str): The folder where a PyTorch variant of the Bert model and its parameters
            is stored and to be used as starting point; if `None`, the initial parameters are randomly
            initialized; typically this folder is the output of a previous training `folder_training_output`.
        folder_training_output (str): The folder where a PyTorch variant of the Bert model and its parameters
            is stored during and after training.
        optimizer_lr (float): Learning rate of the AdamW optimizer.
        optimizer_betas (tuple): The `betas` coefficients of the AdamW optimizer.
        optimizer_eps (float): The `eps` term of the AdamW optimizer.
        optimizer_weight_decay (float): Weight decay of the AdamW optimizer.
        lr_schedule_type (str): Name of the learning rate schedule to create with `factory_lr_schedules`.
        lr_scheduler_kwargs (dict): Keyword argument dictionary for the learning rate schedule.
        training_kwargs (dict): Keyword argument dictionary for the training, other than the `output_dir`, see
            Hugging Face `TrainingArguments`.

    '''
    folder_seq_raw : str = None
    seq_raw_format : str = 'csv'
    seq_raw_file_pattern : str = '*.csv'
    upper_lower : str = 'upper'
    folder_seq_sentence : str = None
    seq_sentence_prefix : str = ''
    word_length_vocab : int = 3
    stride : int = 1
    split_ratio_test : float = 0.05
    split_ratio_validate : float = 0.05
    shuffle : bool = True
    seed : int = 42
    vocab_file : str = 'vocab.txt'
    create_vocab : bool = True
    chunk_size : int = 1000
    masking_probability : float = 0.15
    bert_config_kwargs : dict = field(default_factory=dict)
    folder_training_input : str = None
    folder_training_output : str = None
    optimizer_lr : float = 0.001
    optimizer_betas : tuple = (0.9, 0.999)
    optimizer_eps : float = 1e-08
    optimizer_weight_decay : float = 0.01
    lr_schedule_type : str = 'linear decay lr after warmup'
    lr_scheduler_kwargs : dict = field(default_factory=dict)
    training_kwargs : dict = field(default_factory=dict)

    def __repr__(self):
        return json.dumps(asdict(self), indent=4)

To guide the setting of parameters, the type of GPU and the available memory are determined. These are standard code snippets taken from a [notebook with Google Colab documentation](https://colab.research.google.com/notebooks/pro.ipynb).

In [10]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Jul 20 08:30:40 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 13.6 gigabytes of available RAM

Not using a high-RAM runtime


**User Instruction:** Define the runtime arguments in the instance of `Arguments2DNABertTraining` below.

In [12]:
args = Arguments2DNABertTraining(
    folder_seq_raw=None,
    folder_seq_sentence='/content/gdrive/MyDrive/DataRepo/DNABert/SequenceSentences',
    split_ratio_test=0.05,
    split_ratio_validate=0.05,
    chunk_size=512,
    folder_training_input='/content/gdrive/MyDrive/DataRepo/DNABert/ModelRunOutput/Pretrained_From_Paper',
    folder_training_output='/content/gdrive/MyDrive/DataRepo/DNABert/ModelRunOutput',
    bert_config_kwargs={
        'max_position_embeddings' : 512
    },
    optimizer_lr = 0.0001,
    optimizer_weight_decay = 0.001,
    lr_schedule_type = 'linear decay lr after warmup',
    lr_scheduler_kwargs={
        'n_warmup_steps' : 3000,
        'n_max_steps' : 27400000000,
        'f_lower_bound' : 1e-6
    },
    training_kwargs={
        'fp16' : True,
        'per_device_train_batch_size' : 8,
        'gradient_accumulation_steps' : 8,
        'num_train_epochs' : 20,
        'save_total_limit' : 1,
        'save_strategy' : 'no',
        'evaluation_strategy' : 'epoch'
    }
)

In [13]:
if args.folder_seq_sentence is None:
    raise ValueError('The folder for sequence sentence files is required')

if args.upper_lower == 'upper':
    do_upper_case = True
    do_lower_case = False
elif args.upper_lower == 'lower':
    do_upper_case = False
    do_lower_case = True
else:
    raise ValueError('The vocabulary is either all upper or all lower, but `upper_lower` of invalid value: {}'.format(args.upper_lower))

if args.folder_training_output is None:
    folder_training_output_ = args.folder_seq_sentence
else:
    folder_training_output_ = args.folder_training_output

In [14]:
with open('{}/runtime_args.json'.format(folder_training_output_), 'w') as fout:
    print(args, file=fout)
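
Because the arguments are written as JSON, a later session can read them back to reproduce a run. The sketch below shows the round trip with a small stand-in dataclass; `DemoArgs` and the temporary folder are hypothetical, and the same pattern should apply to `Arguments2DNABertTraining` as long as every attribute is an annotated field.

```python
import json
import tempfile
from dataclasses import dataclass, field, asdict
from pathlib import Path

@dataclass
class DemoArgs:
    '''Hypothetical stand-in for the runtime arguments class.'''
    chunk_size : int = 1000
    training_kwargs : dict = field(default_factory=dict)

    def __repr__(self):
        return json.dumps(asdict(self), indent=4)

args_demo = DemoArgs(chunk_size=512, training_kwargs={'fp16' : True})
with tempfile.TemporaryDirectory() as tmp_dir:
    path_args = Path(tmp_dir) / 'runtime_args.json'
    # Write the arguments the same way as in the cell above
    with open(path_args, 'w') as fout:
        print(args_demo, file=fout)
    # Read them back and rebuild the dataclass instance
    with open(path_args) as fin:
        args_reloaded = DemoArgs(**json.load(fin))

print(args_reloaded == args_demo)
```

Note that JSON has no tuple type, so a tuple-valued field comes back as a list after reloading.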

### Imports and Helpers
Before the data processing and training start, import the required modules and define a few helper functions for downstream use.

In [15]:
from pathlib import Path
import random
import torch

from transformers import BertForMaskedLM, BertConfig
from transformers import BertTokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

from biosequences.io import NucleotideSequenceProcessor
from biosequences.utils import NucleotideVocabCreator, dna_nucleotide_alphabet, Phrasifier, factory_lr_schedules
from biosequences.datacollators import DataCollatorDNAWithMasking

In [16]:
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [17]:
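def _sequence_grouper(seqs, chunk_size):
    '''Concatenate the tokenized sequences of a batch and cut them into chunks of `chunk_size` tokens.

    The tail that does not fill a complete chunk is dropped.

    '''
    # Flatten the list of lists of each feature, e.g. `input_ids` and `attention_mask`
    concat_seq = {k : sum(seqs[k], []) for k in seqs.keys()}
    total_length = len(concat_seq[list(seqs.keys())[0]])
    # Round down to the nearest multiple of the chunk size
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k : [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concat_seq.items()
    }
    # The labels start out as a copy of the inputs
    result['labels'] = result['input_ids'].copy()
    return result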
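
As an illustration of what the grouping does, the sketch below applies the same logic to a toy batch of three short token lists. The function is restated under the hypothetical name `group_into_chunks` so that the example runs on its own; the token IDs are made up.

```python
def group_into_chunks(seqs, chunk_size):
    # Same logic as `_sequence_grouper`, restated for a standalone example
    concat = {k : sum(v, []) for k, v in seqs.items()}
    total_length = (len(concat['input_ids']) // chunk_size) * chunk_size
    result = {
        k : [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concat.items()
    }
    result['labels'] = result['input_ids'].copy()
    return result

toy_batch = {
    'input_ids' : [[2, 6, 12, 3], [2, 35, 3], [2, 64, 50, 3]],
    'attention_mask' : [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]],
}
grouped = group_into_chunks(toy_batch, chunk_size=4)
print(grouped['input_ids'])  # [[2, 6, 12, 3], [2, 35, 3, 2]]
```

The eleven tokens give two full chunks of four, and the last three tokens are dropped. A chunk can straddle the boundary between two original sequences, as the second chunk does here.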

In [18]:
def _compute_metrics(eval_pred):
    '''Custom metrics for the evaluation step are computed here.

    Args:
        eval_pred : The logits and labels of the evaluation data, as passed by the `Trainer`.

    Returns:
        custom_metrics (dict): Keys are the names of the custom metrics, values the numeric values of said metrics

    '''
    logits, labels = eval_pred
    # No custom metrics are defined yet
    return {}

### Process Raw Data, Chunk, Split and Tokenize
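All steps to prepare the data for training follow.

If a raw data folder is given, the raw nucleotide sequences are first turned into sentences of nucleotide words and saved as JSON files. Then the vocabulary is created (or reused), the JSON files are split into training, test and validation sets, the sentences are tokenized, and the tokens are concatenated and cut into chunks of `chunk_size` tokens.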

In [19]:
if args.folder_seq_raw is not None:
    dataprocessor = NucleotideSequenceProcessor(source_directory=args.folder_seq_raw,
                                                source_file_format=args.seq_raw_format,
                                                source_directory_file_pattern=args.seq_raw_file_pattern)
    phrasifier = Phrasifier(stride=args.stride,
                            word_length=args.word_length_vocab,
                            do_upper_case=do_upper_case,
                            do_lower_case=do_lower_case)
    dataprocessor.save_as_json(save_dir=args.folder_seq_sentence,
                               save_prefix=args.seq_sentence_prefix,
                               seq_transformer=phrasifier)

In [20]:
if args.create_vocab:
    dna_vocab = NucleotideVocabCreator(alphabet=dna_nucleotide_alphabet,
                                       do_lower_case=do_lower_case,
                                       do_upper_case=do_upper_case).generate(args.word_length_vocab)
    with open('{}/{}'.format(args.folder_seq_sentence, args.vocab_file), 'w') as fout:
        dna_vocab.save(fout)

tokenizer = BertTokenizer(vocab_file='{}/{}'.format(args.folder_seq_sentence, args.vocab_file), do_lower_case=do_lower_case)

In [21]:
json_files = Path(args.folder_seq_sentence).glob('{}*.json'.format(args.seq_sentence_prefix))
json_files = ['{}'.format(x.resolve()) for x in json_files]
len_train = round(len(json_files) * (1.0 - args.split_ratio_test - args.split_ratio_validate))
if len_train <= 0:
    raise ValueError('Split ratios for test and validate exceed 1.0, leaving nothing for training')
len_test = round(len(json_files) * args.split_ratio_test)
pp = list(range(len(json_files)))
if args.shuffle:
    random.shuffle(pp)
json_files_split = {'train' : [json_files[k] for k in pp[:len_train]]}
if len_test > 0:
    json_files_split['test'] = [json_files[k] for k in pp[len_train:len_train + len_test]]
if len(json_files) - len_train - len_test > 0:
    json_files_split['validate'] = [json_files[k] for k in pp[len_train + len_test:]]
json_files = json_files_split
seq_dataset = load_dataset('json', data_files=json_files)

Resolving data files:   0%|          | 0/8762 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/487 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/487 [00:00<?, ?it/s]

Using custom data configuration default-52d87ffa69a1504f


Downloading and preparing dataset json/default to /root/.cache/huggingface/datasets/json/default-52d87ffa69a1504f/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5...


Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-52d87ffa69a1504f/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5. Subsequent calls will reuse this data.
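
The three file counts reported in the output above follow from the split arithmetic in the cell. As a check, the sketch below repeats the computation for 9736 files, which is the sum of the three counts; the function name `split_sizes` is hypothetical.

```python
def split_sizes(n_files, ratio_test=0.05, ratio_validate=0.05):
    '''Number of files for train, test and validate, as computed in the cell above.'''
    n_train = round(n_files * (1.0 - ratio_test - ratio_validate))
    n_test = round(n_files * ratio_test)
    # The validation set takes whatever remains after rounding
    n_validate = n_files - n_train - n_test
    return n_train, n_test, n_validate

print(split_sizes(9736))  # (8762, 487, 487)
```

Since the validation set is the remainder, any rounding difference ends up there rather than in the training or test set.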



In [22]:
tokenized_dataset = seq_dataset.map(
    lambda x: tokenizer(x['seq']), batched=True, remove_columns=['seq', 'id', 'name', 'description']
)




In [23]:
lm_dataset = tokenized_dataset.map(
    _sequence_grouper,
    batched=True,
    fn_kwargs={'chunk_size' : args.chunk_size}
)


In [24]:
print(lm_dataset)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 8578
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 466
    })
    validate: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 480
    })
})


In [30]:
print(len(lm_dataset['train'][0]['input_ids']))

512


### Configure Data Collation, Model, Optimizer and Trainer

In [25]:
data_collator = DataCollatorDNAWithMasking(tokenizer=tokenizer,
                                           mlm_probability=args.masking_probability,
                                           word_length=args.word_length_vocab)
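
The runtime arguments state that masking is done in chunks of at least `word_length_vocab` tokens. With a stride of one, neighbouring words overlap, so masking a single token would leave it easy to reconstruct from its neighbours. The sketch below is one way such run-wise masking could look; it is not the code of `DataCollatorDNAWithMasking`, and the function name `mask_in_runs`, the `[MASK]` ID of 4 and the way run starts are sampled are all assumptions.

```python
import random

def mask_in_runs(input_ids, mask_id, word_length=3, mlm_probability=0.15, rng=None):
    '''Mask contiguous runs of `word_length` tokens (illustrative only).'''
    rng = rng or random.Random()
    n_tokens = len(input_ids)
    is_masked = [False] * n_tokens
    # Each position starts a run with probability p / word_length, so that
    # roughly a fraction p of all tokens ends up masked
    p_start = mlm_probability / word_length
    for start in range(n_tokens - word_length + 1):
        if rng.random() < p_start:
            for j in range(start, start + word_length):
                is_masked[j] = True
    # Masked positions keep the original token as label; all others get -100,
    # the index that the PyTorch cross-entropy loss ignores by default
    masked_ids = [mask_id if m else tok for tok, m in zip(input_ids, is_masked)]
    labels = [tok if m else -100 for tok, m in zip(input_ids, is_masked)]
    return masked_ids, labels

token_ids = list(range(5, 69))
masked_ids, labels = mask_in_runs(token_ids, mask_id=4, rng=random.Random(42))
print(masked_ids)
print(labels)
```

Overlapping runs merge into longer ones, which is consistent with the "at least" in the argument description.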

In [26]:
if args.folder_training_input is None:
    config = BertConfig(vocab_size=tokenizer.vocab_size,
                        **args.bert_config_kwargs)
    model = BertForMaskedLM(config=config)
else:
    model = BertForMaskedLM.from_pretrained(args.folder_training_input)
model = model.to(device)

In [27]:
params_to_update = []
for name, param in model.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)

optimizer = torch.optim.AdamW(params_to_update,
                              lr=args.optimizer_lr, betas=args.optimizer_betas, eps=args.optimizer_eps, weight_decay=args.optimizer_weight_decay)
lr_scheduler = factory_lr_schedules.create(args.lr_schedule_type,
                                           optimizer=optimizer,
                                           **args.lr_scheduler_kwargs)
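
The schedule name `'linear decay lr after warmup'` and its keyword arguments suggest a linear warmup followed by a linear decay. The implementation in `factory_lr_schedules` is not shown here, so the sketch below is only one plausible reading: the function name `lr_factor` is hypothetical, and treating `f_lower_bound` as a floor on the multiplicative factor is an assumption.

```python
def lr_factor(step, n_warmup_steps, n_max_steps, f_lower_bound):
    '''Multiplicative factor on the base learning rate at a given optimizer step.'''
    if step < n_warmup_steps:
        # Linear warmup from 0 to 1
        return step / n_warmup_steps
    # Linear decay from 1 at the end of warmup towards 0 at `n_max_steps`,
    # never going below the lower bound
    progress = (step - n_warmup_steps) / max(1, n_max_steps - n_warmup_steps)
    return max(f_lower_bound, 1.0 - progress)

for step in (0, 5, 10, 60, 110):
    print(step, lr_factor(step, n_warmup_steps=10, n_max_steps=110, f_lower_bound=0.01))
```

A function of this shape could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the returned factor.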

In [28]:
training_args = TrainingArguments(
    output_dir=folder_training_output_,
    seed=args.seed,
    **args.training_kwargs
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_dataset['train'],
    eval_dataset=lm_dataset['validate'],
    compute_metrics=_compute_metrics,
    optimizers=(optimizer, lr_scheduler)
)

Using cuda_amp half precision backend


In [29]:
print('Trainer uses device: {}'.format(trainer.args.device))

Trainer uses device: cuda:0


### Train

In [None]:
trainer.train()
trainer.save_model(output_dir=folder_training_output_)
print('\nIt is done.')

***** Running training *****
  Num examples = 8578
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 8
  Total optimization steps = 2680
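
The step count in the log above can be reproduced from the dataset size and the training arguments. The sketch below assumes a single device and assumes that the incomplete accumulation group at the end of each epoch is rounded down; this is an assumption that happens to match the logged number, not a statement about the `Trainer` internals. The function name is hypothetical.

```python
import math

def total_optimization_steps(n_examples, batch_size, accumulation_steps, n_epochs):
    # Batches per epoch, counting the final incomplete batch
    batches_per_epoch = math.ceil(n_examples / batch_size)
    # Optimizer updates per epoch, with the incomplete group rounded down
    updates_per_epoch = max(batches_per_epoch // accumulation_steps, 1)
    return updates_per_epoch * n_epochs

print(total_optimization_steps(8578, 8, 8, 20))  # 2680
```

With 8578 examples this gives 1073 batches and 134 updates per epoch, so 2680 updates over 20 epochs. This is below the `n_warmup_steps` value of 3000 set in the runtime arguments, so if the schedule counts optimizer steps, the run would end before the warmup is complete.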



***** Running Evaluation *****
  Num examples = 480
  Batch size = 8


Epoch,Training Loss,Validation Loss
0,No log,1.170845
1,No log,1.167228
2,No log,1.171012
3,1.202100,1.166939
4,1.202100,1.156811
5,1.202100,1.173153
6,1.202100,1.171651
7,1.196300,1.150903
8,1.196300,1.157822


