<a href="https://colab.research.google.com/github/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/notebooks/How_To_Train_BERT_Classifier_With_HF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation

In [1]:
!pip install -qq Bio genomic-benchmarks datasets evaluate
!pip install -qq -U accelerate
!pip install -qq -U transformers

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/279.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.2/279.4 kB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m279.4/279.4 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m35.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB

In [2]:
from genomic_benchmarks.loc2seq import download_dataset

download_dataset("human_nontata_promoters", version=0)

  from tqdm.autonotebook import tqdm
Downloading...
From: https://drive.google.com/uc?id=1VdUg0Zu8yfLS6QesBXwGz1PIQrTW3Ze4
To: /root/.genomic_benchmarks/human_nontata_promoters.zip
100%|██████████| 11.8M/11.8M [00:00<00:00, 24.6MB/s]


PosixPath('/root/.genomic_benchmarks/human_nontata_promoters')

## Tokenization

In [3]:
import torch
import datasets
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

In [4]:
def kmers(s, k=6):
    return [s[i:i + k] for i in range(0, len(s), k) if i + k <= len(s)]

def kmers_stride1(s, k=6):
    return [s[i:i + k] for i in range(0, len(s)-k+1)]

kmers_stride1("ATGGAAAGAGGCACCATTCT")

['ATGGAA',
 'TGGAAA',
 'GGAAAG',
 'GAAAGA',
 'AAAGAG',
 'AAGAGG',
 'AGAGGC',
 'GAGGCA',
 'AGGCAC',
 'GGCACC',
 'GCACCA',
 'CACCAT',
 'ACCATT',
 'CCATTC',
 'CATTCT']

In [5]:
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")

Downloading (…)okenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [6]:
seq = "ATGGAAAGAGGCACCATTCT"
seq_split = " ".join(kmers_stride1(seq))

tokens = tokenizer.tokenize(seq_split)
tokens

['ATGGAA',
 'TGGAAA',
 'GGAAAG',
 'GAAAGA',
 'AAAGAG',
 'AAGAGG',
 'AGAGGC',
 'GAGGCA',
 'AGGCAC',
 'GGCACC',
 'GCACCA',
 'CACCAT',
 'ACCATT',
 'CCATTC',
 'CATTCT']

In [7]:
seq_split

'ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT'

In [8]:
seq_tokens = tokenizer([[seq_split,seq_split]])
seq_tokens

{'input_ids': [[2, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [9]:
tokenizer.decode(seq_tokens['input_ids'][0])

'[CLS] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP]'

## Model and data

In [10]:
model_cls = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

Downloading pytorch_model.bin:   0%|          | 0.00/359M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at armheb/DNA_bert_6 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
if torch.cuda.device_count() > 0:
  model_cls.to('cuda')

In [12]:
from pathlib import Path

tmp_dict = {}

for dset in ['train', 'test']:
  for c in ['negative', 'positive']:
    for f in Path(f'/root/.genomic_benchmarks/human_nontata_promoters/{dset}/{c}/').glob('*.txt'):
      txt = f.read_text()
      tmp_dict[f.stem] = (dset, int(c == "positive"), txt)

In [13]:
import pandas as pd

df = pd.DataFrame.from_dict(tmp_dict).T.rename(columns = {0: "dset", 1: "cat", 2: "seq"})
#df.to_pickle("human_nontata_promoters.pkl")
df

Unnamed: 0,dset,cat,seq
8712,train,0,TGAGGAACTCTAGCGTACTCTTCCTGGGAATGTGGGGGCTGGGTGG...
12257,train,0,TTTGAGTAATTTTTAGTGATGGAGAATTCAAGTAAAAGAGAACAAG...
3554,train,0,CTGGAGATGCCTTTGATGGCTTTGATTTTGGCGATGATCCTAGTGA...
8680,train,0,TTTGACAGGCCAGTGGCTATAGGAGGCAAGGTAGGAACCGTCACTT...
6273,train,0,ACAAAGAGGCCCCATGCCCTCCTCCTCCACTTAGCAGACTCACCAG...
...,...,...,...
FP014174,test,1,CGGGGCGGGGCAACGCGAGCCCGCACCCCGCTCCTCCCCGCCCCTC...
FP012294,test,1,GTCGACTCCAGCCAGGCGGGGCTCCAAGCCGAGACTCCTGCACGCC...
FP016311,test,1,AAAAGGAAAATAGTGTAGACCCGCTAGGCAGGAAGAGGTCACTAAA...
FP011710,test,1,TCTTTTTGACAATAACTCTAAAACACACTTTGCTCTGTGAACGCCC...


In [14]:
from datasets import Dataset, DatasetDict

#promoters_dataset = datasets.load_dataset("pandas", data_files="human_nontata_promoters.pkl")
ds = Dataset.from_pandas(df)

In [15]:
ds[0]

{'dset': 'train',
 'cat': 0,
 'seq': 'TGAGGAACTCTAGCGTACTCTTCCTGGGAATGTGGGGGCTGGGTGGGAAGCAGCCCCGGAGATGCAGGAGCCCAGTACAGAGGATGAAGCCACTGATGGGGCTGGCTGCACATCCGTAACTGGGAGCCCTGGCTCCAAGCCCATTCCATCCCAACTCAGACTCTGAGTCTCACCCTAAGAAGTACTCTCATAGTTTCTTCCCTAAGTTTCTTACCGCATGCTTTCAGACTGGGCTCTTCTTTGTTCTCTTG',
 '__index_level_0__': '8712'}

In [16]:
def tok_func(x): return tokenizer(" ".join(kmers_stride1(x["seq"])))

tok_ds = ds.map(tok_func, batched=False)
tok_ds = tok_ds.rename_columns({'cat':'labels'})

Map:   0%|          | 0/36131 [00:00<?, ? examples/s]

In [17]:
dds = DatasetDict({
    'train': tok_ds.filter(lambda x: x["dset"] == "train"),
    'test':  tok_ds.filter(lambda x: x["dset"] == "test")
})

dds

Filter:   0%|          | 0/36131 [00:00<?, ? examples/s]

Filter:   0%|          | 0/36131 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 27097
    })
    test: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9034
    })
})

## Fine-tuning

In [18]:
from transformers import TrainingArguments, Trainer

bs = 32
epochs = 4
lr = 8e-5

args = TrainingArguments('outputs', learning_rate=lr, warmup_ratio=0.1, lr_scheduler_type='cosine', fp16=True,
    evaluation_strategy="epoch", per_device_train_batch_size=bs, per_device_eval_batch_size=bs*2,
    num_train_epochs=epochs, weight_decay=0.01, report_to='none')

In [19]:
import numpy as np
import evaluate

def compute_metrics(eval_preds):
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(model_cls, args, train_dataset=dds['train'], eval_dataset=dds['test'],
                  tokenizer=tokenizer, compute_metrics=compute_metrics)

In [20]:
trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.4251,0.305258,0.869161,0.872766
2,0.2556,0.210348,0.917866,0.920963
3,0.1265,0.228389,0.926389,0.930359
4,0.0741,0.31402,0.92329,0.925906


Downloading builder script:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

TrainOutput(global_step=3388, training_loss=0.2115493203560605, metrics={'train_runtime': 1911.4013, 'train_samples_per_second': 56.706, 'train_steps_per_second': 1.773, 'total_flos': 1.381344551748672e+16, 'train_loss': 0.2115493203560605, 'epoch': 4.0})