<a href="https://colab.research.google.com/github/ML-Bioinfo-CEITEC/cDNA-pretraining/blob/main/notebooks/How_To_Train_BERT_Classifier_With_HF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation

In [1]:
!pip install -qq Bio transformers genomic-benchmarks datasets

[K     |████████████████████████████████| 270 kB 8.0 MB/s 
[K     |████████████████████████████████| 4.4 MB 61.5 MB/s 
[K     |████████████████████████████████| 362 kB 66.5 MB/s 
[K     |████████████████████████████████| 2.3 MB 51.6 MB/s 
[K     |████████████████████████████████| 86 kB 6.0 MB/s 
[K     |████████████████████████████████| 6.6 MB 67.2 MB/s 
[K     |████████████████████████████████| 596 kB 68.9 MB/s 
[K     |████████████████████████████████| 271 kB 75.5 MB/s 
[K     |████████████████████████████████| 212 kB 45.0 MB/s 
[K     |████████████████████████████████| 1.1 MB 47.6 MB/s 
[K     |████████████████████████████████| 140 kB 62.6 MB/s 
[K     |████████████████████████████████| 127 kB 54.9 MB/s 
[K     |████████████████████████████████| 144 kB 68.6 MB/s 
[K     |████████████████████████████████| 94 kB 4.0 MB/s 
[?25h  Building wheel for genomic-benchmarks (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into accoun

In [2]:
from genomic_benchmarks.loc2seq import download_dataset

download_dataset("human_nontata_promoters", version=0)

  from tqdm.autonotebook import tqdm
Downloading...
From: https://drive.google.com/uc?id=1VdUg0Zu8yfLS6QesBXwGz1PIQrTW3Ze4
To: /root/.genomic_benchmarks/human_nontata_promoters.zip
100%|██████████| 11.8M/11.8M [00:01<00:00, 10.2MB/s]


PosixPath('/root/.genomic_benchmarks/human_nontata_promoters')

## Tokenization

In [3]:
import torch
import datasets 
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

In [8]:
def kmers(s, k=6):
    return [s[i:i + k] for i in range(0, len(s), k) if i + k <= len(s)]

def kmers_stride1(s, k=6):
    return [s[i:i + k] for i in range(0, len(s)-k+1)]

kmers_stride1("ATGGAAAGAGGCACCATTCT")    

['ATGGAA',
 'TGGAAA',
 'GGAAAG',
 'GAAAGA',
 'AAAGAG',
 'AAGAGG',
 'AGAGGC',
 'GAGGCA',
 'AGGCAC',
 'GGCACC',
 'GCACCA',
 'CACCAT',
 'ACCATT',
 'CCATTC',
 'CATTCT']

In [9]:
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")

In [11]:
seq = "ATGGAAAGAGGCACCATTCT"
seq_split = " ".join(kmers_stride1(seq))

tokens = tokenizer.tokenize(seq_split)
tokens

['ATGGAA',
 'TGGAAA',
 'GGAAAG',
 'GAAAGA',
 'AAAGAG',
 'AAGAGG',
 'AGAGGC',
 'GAGGCA',
 'AGGCAC',
 'GGCACC',
 'GCACCA',
 'CACCAT',
 'ACCATT',
 'CCATTC',
 'CATTCT']

In [12]:
seq_split

'ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT'

In [13]:
seq_tokens = tokenizer([[seq_split,seq_split]])
seq_tokens

{'input_ids': [[2, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [14]:
tokenizer.decode(seq_tokens['input_ids'][0])

'[CLS] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP]'

## Model and data

In [15]:
model_cls = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

Downloading:   0%|          | 0.00/343M [00:00<?, ?B/s]

Some weights of the model checkpoint at armheb/DNA_bert_6 were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at armheb/DNA_bert_6 and are n

In [16]:
if torch.cuda.device_count() > 0:
  model_cls.to('cuda')

In [17]:
from pathlib import Path

tmp_dict = {}

for dset in ['train', 'test']:
  for c in ['negative', 'positive']:
    for f in Path(f'/root/.genomic_benchmarks/human_nontata_promoters/{dset}/{c}/').glob('*.txt'):
      txt = f.read_text()
      tmp_dict[f.stem] = (dset, int(c == "positive"), txt)

In [18]:
import pandas as pd

df = pd.DataFrame.from_dict(tmp_dict).T.rename(columns = {0: "dset", 1: "cat", 2: "seq"})
#df.to_pickle("human_nontata_promoters.pkl")
df

Unnamed: 0,dset,cat,seq
2038,train,0,CAAAGGGATCGATAAGCAGAGACCCCATGCTTCAGATCAAGAGCCT...
15715,train,0,CTCCCTCCACACCAGTCTCTACACTGCTGCCACAGTGATCTTTCTA...
12370,train,0,CCCAGGCAGGGAGAGGCCAGGGAGCCAAGAGTTTGAACCCAGTGCC...
3384,train,0,TGGACTAAACAAACAACAATCTTTTTAGAGGCAATCCCCACTTTCA...
9182,train,0,TGGTAGGTTTTCAGAGATTTTTAATGAAAAATTAAAAAAATTCCAG...
...,...,...,...
FP006168,test,1,CTACCATTAGAGGGAGATCTCCGAGCGCACACGGGAGCTCTTTCCC...
FP000845,test,1,ACAAGTATGCTTTCGCTTTAGGTAGGGCATTTGAGAGCAAAATGTA...
FP006398,test,1,GGGACTGCCCAGGGGGTTCCGAGATTCCTTCTCCCCTCCTATCACC...
FP001982,test,1,AAAATGGGCAAAGTACAAGAATAAGCAAAGAGTGAATAAATACAAA...


In [19]:
from datasets import Dataset, DatasetDict, load_metric

#promoters_dataset = datasets.load_dataset("pandas", data_files="human_nontata_promoters.pkl")
ds = Dataset.from_pandas(df)

In [20]:
ds[0]

{'__index_level_0__': '2038',
 'cat': 0,
 'dset': 'train',
 'seq': 'CAAAGGGATCGATAAGCAGAGACCCCATGCTTCAGATCAAGAGCCTGATGAAAGTAGTTCAAAGATGCGATGCCCTTTCTCACCATCCCTTTCCAGAAATATGAACAGGGATTCATCACAGACCCTGTGGTCCTCAGCCCCAAGGATCGCGTGCGGGATGTTTTTGAGGCCAAGGCCCGGCATGGTTTCTGCGGTATCCCAATCACAGACACAGGCCGGATGGGGAGCCGCTTGGTGGGCATCATCTCCTC'}

In [21]:
def tok_func(x): return tokenizer(" ".join(kmers_stride1(x["seq"])))

tok_ds = ds.map(tok_func, batched=False)
tok_ds = tok_ds.rename_columns({'cat':'labels'})



  0%|          | 0/36131 [00:00<?, ?ex/s]

In [22]:
tok_ds[0]

{'__index_level_0__': '2038',
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  

In [23]:
dds = DatasetDict({
    'train': tok_ds.filter(lambda x: x["dset"] == "train"),
    'test':  tok_ds.filter(lambda x: x["dset"] == "test")
})

dds

  0%|          | 0/37 [00:00<?, ?ba/s]

  0%|          | 0/37 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 27097
    })
    test: Dataset({
        features: ['dset', 'labels', 'seq', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9034
    })
})

## Fine-tuning

In [24]:
from transformers import TrainingArguments, Trainer

bs = 32
epochs = 4
lr = 8e-5

args = TrainingArguments('outputs', learning_rate=lr, warmup_ratio=0.1, lr_scheduler_type='cosine', fp16=True,
    evaluation_strategy="epoch", per_device_train_batch_size=bs, per_device_eval_batch_size=bs*2,
    num_train_epochs=epochs, weight_decay=0.01, report_to='none')

In [25]:
import numpy as np

def compute_metrics(eval_preds):
    metric = load_metric("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(model_cls, args, train_dataset=dds['train'], eval_dataset=dds['test'],
                  tokenizer=tokenizer, compute_metrics=compute_metrics)

Using cuda_amp half precision backend


In [26]:
trainer.train();

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: seq, dset, __index_level_0__. If seq, dset, __index_level_0__ are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 27097
  Num Epochs = 4
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 848


RuntimeError: ignored