<a href="https://colab.research.google.com/github/ML-Bioinfo-CEITEC/genomic_benchmarks/blob/main/notebooks/How_To_Train_BERT_Classifier_With_HF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation

In [2]:
!pip install -qq Bio transformers genomic-benchmarks datasets

In [3]:
from genomic_benchmarks.loc2seq import download_dataset

# Fetch version 0 of the "human_nontata_promoters" benchmark. The call's return value
# (the local dataset folder) is shown as the cell output.
download_dataset("human_nontata_promoters", version=0)

  from tqdm.autonotebook import tqdm


Downloading 1VdUg0Zu8yfLS6QesBXwGz1PIQrTW3Ze4 into /root/.genomic_benchmarks/human_nontata_promoters.zip... Done.
Unzipping...Done.


PosixPath('/root/.genomic_benchmarks/human_nontata_promoters')

## Tokenization

In [4]:
import torch
import datasets 
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, DataCollatorWithPadding

In [5]:
def kmers(s, k=6, stride=None):
    """Split a sequence into k-mers of length ``k``.

    Args:
        s: Sequence (string) to split.
        k: Length of each k-mer; must be a positive integer.
        stride: Offset between the starts of consecutive k-mers. Defaults to
            ``k``, which yields non-overlapping k-mers; ``stride=1`` yields
            every overlapping k-mer.

    Returns:
        List of k-mers in order of appearance. Trailing characters that do not
        fill a whole k-mer are dropped; a sequence shorter than ``k`` gives ``[]``.

    Raises:
        ValueError: If ``k`` or ``stride`` is smaller than 1.
    """
    if stride is None:
        stride = k
    if k < 1 or stride < 1:
        raise ValueError(f"k and stride must be positive integers, got k={k}, stride={stride}")
    # Stop at the last start position that still leaves room for a full k-mer.
    return [s[i:i + k] for i in range(0, len(s) - k + 1, stride)]

In [6]:
# Load the pretrained tokenizer that belongs to the "armheb/DNA_bert_6" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")

Downloading:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
# Example sequence: cut it into 6-mers and join them with spaces, so that the
# tokenizer receives one whitespace-separated "word" per k-mer.
seq = "ATGGAAAGAGGAAAGAAGAAAAGAATTTCCAATAAGTTACAACAAACTTTTCACCATTCT"
seq_split = " ".join(kmers(seq, k=6))

# Tokenize the spaced string and display the resulting tokens.
tokens = tokenizer.tokenize(seq_split)
tokens

['ATGGAA',
 'AGAGGA',
 'AAGAAG',
 'AAAAGA',
 'ATTTCC',
 'AATAAG',
 'TTACAA',
 'CAAACT',
 'TTTCAC',
 'CATTCT']

In [8]:
# Show the space-separated k-mer string that is passed to the tokenizer.
seq_split

'ATGGAA AGAGGA AAGAAG AAAAGA ATTTCC AATAAG TTACAA CAAACT TTTCAC CATTCT'

In [9]:
# A nested list holding two strings is encoded as ONE sequence pair: the output has a
# single row whose token_type_ids switch from 0 to 1 for the second string.
seq_tokens = tokenizer([[seq_split,seq_split]])
seq_tokens

{'input_ids': [[2, 501, 833, 200, 17, 351, 72, 1317, 2062, 1383, 2142, 3, 501, 833, 200, 17, 351, 72, 1317, 2062, 1383, 2142, 3]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [10]:
# Decode the ids back to text to see the special tokens inserted around the pair.
tokenizer.decode(seq_tokens['input_ids'][0])

'[CLS] ATGGAA AGAGGA AAGAAG AAAAGA ATTTCC AATAAG TTACAA CAAACT TTTCAC CATTCT [SEP] ATGGAA AGAGGA AAGAAG AAAAGA ATTTCC AATAAG TTACAA CAAACT TTTCAC CATTCT [SEP]'

## Filling mask

In [11]:
# Load the same checkpoint with a masked-language-modelling head.
model = AutoModelForMaskedLM.from_pretrained("armheb/DNA_bert_6")

Downloading:   0%|          | 0.00/343M [00:00<?, ?B/s]

In [12]:
# Append a [MASK] placeholder after the last k-mer and encode the text as PyTorch tensors.
model_inputs = tokenizer([seq_split + ' [MASK]'], return_tensors="pt")
model_inputs

{'input_ids': tensor([[   2,  501,  833,  200,   17,  351,   72, 1317, 2062, 1383, 2142,    4,
            3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [None]:
# Inference only: disable autograd so that no computation graph is built for the logits.
with torch.no_grad():
    model_output = model(**model_inputs)
model_output

MaskedLMOutput([('logits',
                 tensor([[[-5.9722, -4.3868, -2.2292,  ..., -0.5857, -0.4079, -0.4382],
                          [-5.4525, -2.4216, -2.7523,  ..., -4.0801, -4.0874, -2.1596],
                          [-6.3281, -3.0664, -2.5626,  ..., -4.2230, -4.1740, -4.3082],
                          ...,
                          [-5.5624, -3.2246, -2.6306,  ..., -1.7600, -1.9580, -0.6613],
                          [-5.2627, -3.1048, -2.4744,  ...,  0.0120, -0.4244,  0.6963],
                          [-5.0793, -2.7321, -2.7727,  ..., -0.9480, -1.2520,  0.5207]]],
                        grad_fn=<AddBackward0>))])

In [None]:
# Shape of the logits tensor (first element of the model output).
model_output[0].shape

torch.Size([1, 13, 4101])

In [None]:
# Number of entries in the tokenizer vocabulary.
len(tokenizer.vocab)

4101

In [None]:
# Index of the highest logit along the last axis, i.e. the top-scoring id at every position.
best_guess = model_output[0].argmax(-1)
best_guess


tensor([[ 245,  501, 1989, 3846,   17,   53,  200,  785, 1078,  245,  773,  129,
          773]])

In [None]:
# The [MASK] placeholder was appended last, so in the encoded input it sits right before the
# closing special token; the second-to-last decoded word is therefore the prediction for it.
# NOTE(review): relies on every predicted id decoding to exactly one whitespace-separated word.
tokenizer.decode(best_guess[0]).split()[-2]

'AATGGA'

## Model and data

In [None]:
# Load the checkpoint again, this time with a sequence-classification head for 2 labels.
model_cls = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

Some weights of the model checkpoint at armheb/DNA_bert_6 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at armheb/DNA_bert_6 and are n

In [None]:
# Move the classifier to the GPU when at least one CUDA device is present.
n_cuda_devices = torch.cuda.device_count()
if n_cuda_devices > 0:
    model_cls.to('cuda')

In [None]:
from pathlib import Path

# The dataset is downloaded into ~/.genomic_benchmarks; resolve that folder from the current
# user's home directory instead of hard-coding /root, which only exists for the root user.
dataset_dir = Path.home() / '.genomic_benchmarks' / 'human_nontata_promoters'

# Maps file stem -> (split, label, sequence); label is 1 for "positive" and 0 for "negative".
tmp_dict = {}

for dset in ['train', 'test']:
    for c in ['negative', 'positive']:
        for f in (dataset_dir / dset / c).glob('*.txt'):
            # The stem is the only key, so a repeated file name would silently overwrite
            # an earlier sample; fail loudly instead of losing data.
            if f.stem in tmp_dict:
                raise ValueError(f"Duplicate sample id {f.stem!r} found in {dset}/{c}")
            tmp_dict[f.stem] = (dset, int(c == "positive"), f.read_text())

In [None]:
import pandas as pd

# Build the frame row-wise (one row per sample id) so that each column keeps its own dtype.
# Transposing a column-wise frame would leave every column, including the integer label,
# with dtype `object`.
df = pd.DataFrame.from_dict(tmp_dict, orient="index", columns=["dset", "cat", "seq"])
# Persist the frame to disk as a pickle.
df.to_pickle("human_nontata_promoters.pkl")
df

Unnamed: 0,dset,cat,seq
2076,train,0,TAAAGCTTAGTGCTTTTTATTTGAGGCAGGGTCTTGCTCTGTTGCC...
3498,train,0,TCACAGCACAGATGTTGTTTAATAATATGTTTATTTTATAAATTGA...
9914,train,0,GGGCTCCTGTGCCGTGTATGACAGCGGGGGCTACGTGCAGGAGCTG...
7456,train,0,GTTCGATCCCCGACAACCCAACCAAAACCTCTGGGCCCAAGGGTGC...
7936,train,0,AGTGACTGCCTAGTGTTAAAATCTCATTGTAACTTCTCTCTGGGCA...
...,...,...,...
FP008503,test,1,CTGCGCCCTGGGGGTGGTGATAATAACAGCTGTCACCGGGGGATGG...
FP002762,test,1,AAATATTGGCCGGTTGAAGTTTATTGCAAGGCATACGGTTGTATAA...
FP005613,test,1,TCCCTCGCCCCGCCCCTCCCCGCCTGAATCCCGGCCCCCGCCTCGC...
FP002216,test,1,CCGGGGCGAGGAGAGGGGGCTGGGGAAGAGGAGGGGGGCAAGAAAG...


## Fine-tuning

In [None]:
from datasets import Dataset, DatasetDict, load_metric

# Alternative (not used here): load the pickle written earlier with
#   datasets.load_dataset("pandas", data_files="human_nontata_promoters.pkl")
# Convert the in-memory DataFrame into a Hugging Face Dataset.
ds = Dataset.from_pandas(df)

In [None]:
# Inspect the first record of the dataset.
ds[0]

{'__index_level_0__': '2076',
 'cat': 0,
 'dset': 'train',
 'seq': 'TAAAGCTTAGTGCTTTTTATTTGAGGCAGGGTCTTGCTCTGTTGCCCAGGCTGGAGTACGGTGGCGTGATCTCAGCTCATTGCATCCTCCACCTTCGTGCTCAGGTGTTTCTCCCACCTCAGCCTCCCCACTAGCTGGCACTGCAGGTGCCTGCCACCACATCCAGCTGATTTTTGATTTTTTGTAGAGACTGTTTCGCCACGTTGCCCAGGCTGATTTCAAGGAATGCTATGGTGCCTGGCCCCAGCTAA'}

In [None]:
def tok_func(x):
    """Tokenize one example: cut ``x["seq"]`` into 6-mers, join them with spaces, encode."""
    spaced_kmers = " ".join(kmers(x["seq"], k=6))
    return tokenizer(spaced_kmers)

# Tokenize example by example, then expose the class column under the name 'labels'.
tok_ds = ds.map(tok_func, batched=False).rename_columns({'cat':'labels'})

  0%|          | 0/36131 [00:00<?, ?ex/s]

In [None]:
# Inspect the first tokenized record.
tok_ds[0]

{'__index_level_0__': '2076',
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'dset': 'train',
 'input_ids': [2,
  1043,
  1340,
  2394,
  353,
  3988,
  3484,
  2466,
  1965,
  4004,
  848,
  3584,
  1822,
  2283,
  381,
  1695,
  667,
  3563,
  994,
  1439,
  2219,
  939,
  2702,
  932,
  2211,
  995,
  2541,
  2598,
  2622,
  3162,
  1818,
  1401,
  3278,
  3424,
  2610,
  1965,
  4001,
  1381,
  3852,
  2340,
  1964,
  3757,
  3],
 'labels': 0,
 'seq': 'TAAAGCTTAGTGCTTTTTATTTGAGGCAGGGTCTTGCTCTGTTGCCCAGGCTGGAGTACGGTGGCGTGATCTCAGCTCATTGCATCCTCCACCTTCGTGCTCAGGTGTTTCTCCCACCTCAGCCTCCCCACTAGCTGGCACTGCAGGTGCCTGCCACCACATCCAGCTGATTTTTGATTTTTTGTAGAGACTGTTTCGCCACGTTGCCCAGGCTGATTTCAAGGAATGCTATGGTGCCTGGCCCCAGCTAA',
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [None]:
def _rows_of(split_name):
    """Return a predicate that keeps rows whose 'dset' column equals ``split_name``."""
    return lambda x: x["dset"] == split_name

# Separate the tokenized data into its train and test parts using the 'dset' column.
dds = DatasetDict({split_name: tok_ds.filter(_rows_of(split_name)) for split_name in ('train', 'test')})

dds

  0%|          | 0/37 [00:00<?, ?ba/s]

  0%|          | 0/37 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 27097
    })
    test: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9034
    })
})

In [None]:
from transformers import TrainingArguments, Trainer

bs = 128      # per-device training batch size (evaluation uses twice this value)
epochs = 4    # number of training epochs
lr = 8e-5     # learning rate passed to the optimizer schedule

# Mixed precision needs a CUDA device. Enable it only when one is available so the notebook
# also runs on a CPU-only runtime, in line with the GPU check made when moving the model.
use_fp16 = torch.cuda.is_available()

args = TrainingArguments('outputs', learning_rate=lr, warmup_ratio=0.1, lr_scheduler_type='cosine', fp16=use_fp16,
    evaluation_strategy="epoch", per_device_train_batch_size=bs, per_device_eval_batch_size=bs*2,
    num_train_epochs=epochs, weight_decay=0.01, report_to='none')

In [None]:
import numpy as np

# Load the metric once up front; loading it inside `compute_metrics` would repeat that work
# at every evaluation pass.
mrpc_metric = load_metric("glue", "mrpc")

def compute_metrics(eval_preds):
    """Score one evaluation pass with the GLUE/MRPC metric.

    Args:
        eval_preds: Pair ``(logits, labels)`` handed over by the ``Trainer``.

    Returns:
        The result of the metric's ``compute`` call on the arg-max predictions
        and the reference labels.
    """
    logits, labels = eval_preds
    # The predicted class is the index of the largest logit.
    predictions = np.argmax(logits, axis=-1)
    return mrpc_metric.compute(predictions=predictions, references=labels)

trainer = Trainer(model_cls, args, train_dataset=dds['train'], eval_dataset=dds['test'],
                  tokenizer=tokenizer, compute_metrics=compute_metrics)

Using amp half precision backend


In [None]:
# Run fine-tuning; the trailing `;` keeps the call's return value out of the cell output.
trainer.train();

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, dset, seq. If __index_level_0__, dset, seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 27097
  Num Epochs = 4
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 848


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.374325,0.830861,0.835734
2,No log,0.327532,0.858313,0.86722
3,0.375300,0.315067,0.869493,0.880873
4,0.375300,0.310684,0.869714,0.880834


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, dset, seq. If __index_level_0__, dset, seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9034
  Batch size = 256
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, dset, seq. If __index_level_0__, dset, seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 9034
  Batch size = 256
Saving model checkpoint to outputs/checkpoint-500
Configuration saved in outputs/checkpoint-500/config.json
Model weights saved in outputs/checkpoint-500/pytorch_model.bin
tokenizer config file saved in outputs/checkpoint-500/tokenizer_config.j