In [1]:
### Parameters
from pathlib import Path

# When True, the pretrained weights are replaced by a random initialization before
# fine-tuning (no-pretraining baseline); learning-rate warm-up is disabled in that case.
RANDOMIZE_WEIGHTS = False

MODEL_NAME = "armheb/DNA_bert_6"
TOKENIZER_NAME = "armheb/DNA_bert_6"
# k-mer length used to split sequences before tokenization; None passes the raw sequence.
K = 6
# 1 = overlapping k-mers (step 1); any other value = non-overlapping k-mers (step K).
STRIDE = 1


# MODEL_NAME = "Vlasta/DNADebertaSentencepiece30k"
# TOKENIZER_NAME = "Vlasta/DNADebertaSentencepiece30k"
# K = None
# STRIDE = None

# Fraction of each dataset to keep (files are sampled at random); 0 or 1 keeps everything.
DATASET_THINING = 1

# Folder the benchmark datasets are read from. Resolved from the current user's home
# directory instead of hard-coding /home/jovyan (the JupyterHub / INFRA HUB home).
# NOTE(review): assumed to be the folder genomic_benchmarks downloads into — confirm.
BENCHMARKS_FOLDER = str(Path.home() / '.genomic_benchmarks')

# Long-sequence datasets exceed the model's 512-token limit (inputs are truncated):
# "Token indices sequence length is longer than the specified maximum sequence length for this model (517 > 512). Running this sequence through the model will result in indexing errors"
# ('human_enhancers_cohn', 0), ('human_ensembl_regulatory', 0), ('human_ocr_ensembl', 0), ('human_enhancers_ensembl', 0)

# Short-sequence datasets
# NOTE(review): human_ocr_ensembl is also listed among the long-sequence datasets above — confirm which applies.
# DATASETS = [ ('demo_human_or_worm', 0),('demo_coding_vs_intergenomic_seqs', 0),('human_ocr_ensembl', 0),
#   ('human_nontata_promoters', 0)]

DATASETS = [('human_ocr_ensembl', 0)]

# All datasets
# DATASETS = [('demo_coding_vs_intergenomic_seqs', 0),
#  ('demo_human_or_worm', 0), ('human_enhancers_cohn', 0), ('human_enhancers_ensembl', 0),
#  ('human_ensembl_regulatory', 0), ('human_nontata_promoters', 0), ('human_ocr_ensembl', 0)]


# If Ensembl refuses the connection ("[Errno 104] Connection reset by peer"),
# keep USE_CLOUD_CACHE=True (passed as use_cloud_cache to download_dataset).
USE_CLOUD_CACHE = True

BATCH_SIZE = 32
LEARNING_RATE = 8e-5
EPOCHS = 100  # upper bound; early stopping usually ends training sooner
RUNS = 1      # fine-tuning runs per dataset

# Results CSV. If this points to a mounted drive (e.g. Google Drive on Colab), attach it first.
OUTPUT_PATH = './my_test.csv'

print(DATASETS)

[('human_ocr_ensembl', 0)]


In [2]:
from transformers import TrainingArguments
from transformers import EarlyStoppingCallback

# Fraction of training steps used for learning-rate warm-up; turned off when the
# model starts from randomized weights.
warmup_ratio = 0 if RANDOMIZE_WEIGHTS else 0.1
def get_trainargs(seed=None):
    """Build the TrainingArguments shared by every fine-tuning run.

    Args:
        seed: random seed for the Trainer. When None (the default), a fresh seed in
            [1, 10000] is drawn, so repeated runs differ; pass an int to reproduce a run.

    Returns:
        TrainingArguments that evaluate and checkpoint once per epoch and reload the
        best checkpoint when training ends.
    """
    # Imported here so this cell does not depend on `randrange` being imported by a later cell.
    from random import randrange

    if seed is None:
        seed = randrange(1, 10001)
    return TrainingArguments(
        'outputs',
        learning_rate=LEARNING_RATE,
        warmup_ratio=warmup_ratio,
        lr_scheduler_type='cosine', #TODO which one?
        fp16=True,  # mixed precision; presumably requires a CUDA GPU — confirm on CPU-only hosts
        # NOTE(review): renamed to `eval_strategy` in newer transformers releases; update with the library.
        evaluation_strategy="epoch",
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01, #TODO increase?
        # Saving every epoch matches the evaluation schedule, which load_best_model_at_end needs.
        save_strategy='epoch',
        seed=seed,
        report_to='none',
        # Selection metric is left at the library default (presumably validation loss) — TODO confirm.
        load_best_model_at_end=True,
    )

# Early stopping on the validation evaluations: training halts after a single
# evaluation without improvement, so EPOCHS acts only as an upper bound.
# TODO patience = 3/5 ?
callbacks = [EarlyStoppingCallback(early_stopping_threshold=0.0, early_stopping_patience=1)]

In [3]:
from itertools import product
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# For k-mers longer than 6, register every k-mer over {A, C, T, G} as an extra token
# (the pretrained vocabulary is assumed to cover k <= 6 — TODO confirm per tokenizer).
# Note: a model fine-tuned with this tokenizer needs an embedding matrix of size len(tokenizer).
if K is not None and K > 6:
    alphabet = ('A', 'C', 'T', 'G')
    vocab = [''.join(letters) for letters in product(alphabet, repeat=K)]
    tokenizer.add_tokens(vocab)

In [4]:
def kmers_strideK(s, k=K):
    """Split `s` into consecutive non-overlapping k-mers; a trailing chunk shorter than k is dropped."""
    last_start = len(s) - k
    return [s[start:start + k] for start in range(0, last_start + 1, k)]

def kmers_stride1(s, k=K):
    """Return every overlapping k-mer of `s` (window advanced by one character)."""
    window_count = len(s) - k + 1
    return [s[offset:offset + k] for offset in range(window_count)]

# STRIDE == 1 selects overlapping k-mers; any other value selects non-overlapping (stride K) k-mers.
kmers = kmers_stride1 if STRIDE == 1 else kmers_strideK

# Tokenization applied to each dataset row (a dict with a "seq" field).
if K is None:
    # No k-mer splitting (e.g. the commented-out DNADebertaSentencepiece30k config):
    # the raw sequence goes straight to the tokenizer.
    def tok_func(x):
        return tokenizer(x["seq"], truncation=True)
else:
    # k-mer models: the tokenizer sees the k-mers as space-separated words.
    def tok_func(x):
        return tokenizer(" ".join(kmers(x["seq"])), truncation=True)

# Sanity check: tokenize a short sequence and decode it back to inspect the k-mer split.
example_seq = 'ATGGAAAGAGGCACCATTCT'
example = tok_func({'seq': example_seq})
print(example)
tokenizer.decode(example['input_ids'])

{'input_ids': [2, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


'[CLS] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP]'

## Download benchmark datasets

In [5]:
from genomic_benchmarks.loc2seq import download_dataset
from genomic_benchmarks.data_check.info import is_downloaded
from pathlib import Path
from tqdm.autonotebook import tqdm

benchmark_root = Path(BENCHMARKS_FOLDER)

# is_downloaded / download_dataset get no path argument, so they use the library's default location.
# NOTE(review): BENCHMARKS_FOLDER must point at that same folder — confirm.
for dataset_name, dataset_version in tqdm(DATASETS):
    if is_downloaded(dataset_name):
        continue
    download_dataset(dataset_name, version=dataset_version, use_cloud_cache=USE_CLOUD_CACHE)

  0%|          | 0/1 [00:00<?, ?it/s]

## Helper: build one results-table row from a Trainer's log history

In [6]:
def get_log_from_history(history, dataset_name):
    """Summarize one Trainer run's log history into a single results row.

    Args:
        history: list of log dicts (``trainer.state.log_history``). Entries containing
            ``eval_loss`` are validation evaluations; entries containing ``test_loss``
            come from ``trainer.evaluate(..., metric_key_prefix='test')``.
        dataset_name: value stored in the row's ``dataset`` field.

    Returns:
        dict with the test accuracy / F1 / loss, plus the full validation log entries
        with the lowest loss, highest F1 and highest accuracy. F1 (and accuracy) fields
        are None when the metrics function did not report them. For example, a
        multi-class metric without F1 no longer raises KeyError.

    Raises:
        ValueError: if the history contains no test evaluation or no validation evaluation.
    """
    eval_dicts = [x for x in history if 'eval_loss' in x]
    test_dicts = [x for x in history if 'test_loss' in x]
    if not test_dicts:
        raise ValueError(f"no test evaluation found in log history for {dataset_name!r}")
    if not eval_dicts:
        raise ValueError(f"no validation evaluation found in log history for {dataset_name!r}")

    test_log = test_dicts[0]
    # Only entries that actually report a metric can be ranked by it.
    f1_dicts = [x for x in eval_dicts if 'eval_f1' in x]
    acc_dicts = [x for x in eval_dicts if 'eval_accuracy' in x]

    row = {
        'dataset': dataset_name,
        'test_acc': test_log.get('test_accuracy'),
        'test_f1': test_log.get('test_f1'),
        'test_loss': test_log['test_loss'],
        'min_valid_loss_log': min(eval_dicts, key=lambda x: x['eval_loss']),
        'max_valid_f1_log': max(f1_dicts, key=lambda x: x['eval_f1']) if f1_dicts else None,
        'max_valid_acc_log': max(acc_dicts, key=lambda x: x['eval_accuracy']) if acc_dicts else None,
    }
    return row

## Looping through datasets, fine-tuning the model for each of them, logging metrics

In [None]:
import pandas as pd
import numpy as np
from random import random, randrange
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from datasets import Dataset, DatasetDict, load_metric

def compute_metrics_binary(eval_preds):
    """Binary-classification metrics computed with the GLUE/MRPC metric (accuracy and F1).

    ``eval_preds`` is the ``(logits, labels)`` pair passed in by ``Trainer``; the
    predicted class is the arg-max over the last axis of the logits.
    """
    logits, references = eval_preds
    predicted = np.argmax(logits, axis=-1)
    glue_mrpc = load_metric("glue", "mrpc")
    return glue_mrpc.compute(predictions=predicted, references=references)

def compute_metrics_multi(eval_preds):
    """Accuracy and macro-averaged F1 for multi-class classification.

    Computed directly with numpy. This reports an F1 alongside accuracy, like the binary
    metrics, and avoids re-loading a metric script through the deprecated
    ``load_metric`` on every evaluation.

    Args:
        eval_preds: ``(logits, labels)`` pair passed in by ``Trainer``; the predicted
            class is the arg-max over the last axis of the logits.

    Returns:
        ``{'accuracy': float, 'f1': float}``. ``f1`` is the unweighted mean of the
        per-class F1 scores over every class that occurs in the labels or predictions.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    accuracy = float(np.mean(predictions == labels))

    per_class_f1 = []
    for cls in np.union1d(labels, predictions):
        tp = np.sum((predictions == cls) & (labels == cls))
        fp = np.sum((predictions == cls) & (labels != cls))
        fn = np.sum((predictions != cls) & (labels == cls))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the denominator vanishes.
        per_class_f1.append(2 * tp / denom if denom else 0.0)
    f1 = float(np.mean(per_class_f1)) if per_class_f1 else 0.0

    return {'accuracy': accuracy, 'f1': f1}

from transformers import AutoConfig

outputs = []

for dataset_name, dataset_version in tqdm(DATASETS):
    # Class names are the sub-folders of the train split; sorting makes label ids deterministic.
    labels = sorted([x.stem for x in (benchmark_root / dataset_name / 'train').iterdir()])

    # Collect (split, label id, sequence) per file, keyed by "<label> <file stem>".
    tmp_dict = {}
    for split in ['train', 'test']:
        for nlabel, label in enumerate(labels):
            for f in (benchmark_root / dataset_name / split / label).glob('*.txt'):
                # A falsy thinning factor or 1 keeps every file; otherwise each file is kept
                # with probability DATASET_THINING. Dropped files are never read.
                if not DATASET_THINING or DATASET_THINING == 1 or random() < DATASET_THINING:
                    tmp_dict[f"{label} {f.stem}"] = (split, nlabel, f.read_text())
    if not tmp_dict:
        raise ValueError(f"no sequences loaded for {dataset_name!r}; check BENCHMARKS_FOLDER and DATASET_THINING")

    df = pd.DataFrame.from_dict(tmp_dict).T.rename(columns={0: "dset", 1: "cat", 2: "seq"})

    ds = Dataset.from_pandas(df)

    # The string index is carried over as '__index_level_0__'; drop it together with the raw sequence.
    tok_ds = ds.map(tok_func, batched=False, remove_columns=['__index_level_0__', 'seq'])
    tok_ds = tok_ds.rename_columns({'cat': 'labels'})

    dds = DatasetDict({
        'train': tok_ds.filter(lambda x: x["dset"] == "train").remove_columns('dset'),
        'test':  tok_ds.filter(lambda x: x["dset"] == "test").remove_columns('dset')
    })
    # Hold out 20% of the training split (fixed seed) as the validation set the Trainer evaluates on.
    train_valid_split = dds['train'].train_test_split(test_size=0.2, shuffle=True, seed=42)
    dds['train'] = train_valid_split['train']
    dds['valid'] = train_valid_split['test']

    compute_metrics = compute_metrics_binary if len(labels) == 2 else compute_metrics_multi

    for _ in range(RUNS):
        if RANDOMIZE_WEIGHTS:
            # Build the model from its config so no pretrained weights are loaded.
            # (Calling init_weights() after from_pretrained may leave loaded weights
            # untouched in recent transformers versions.)
            config = AutoConfig.from_pretrained(MODEL_NAME, num_labels=len(labels))
            model_cls = AutoModelForSequenceClassification.from_config(config)
        else:
            model_cls = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(labels))
        # The tokenizer may have been extended with extra k-mer tokens; grow the embedding
        # matrix so every token id the tokenizer can emit is a valid index.
        if len(tokenizer) > model_cls.get_input_embeddings().num_embeddings:
            model_cls.resize_token_embeddings(len(tokenizer))
        args = get_trainargs()

        trainer = Trainer(model_cls, args, train_dataset=dds['train'], eval_dataset=dds['valid'],
                          tokenizer=tokenizer, compute_metrics=compute_metrics,
                          callbacks=callbacks)
        trainer.train()
        trainer.evaluate(dds['test'], metric_key_prefix='test')
        training_log = get_log_from_history(trainer.state.log_history, dataset_name=dataset_name)
        # Record the run's seed so a result row can be reproduced.
        training_log['seed'] = args.seed
        outputs.append(training_log)
  

  0%|          | 0/1 [00:00<?, ?it/s]

## Outputs

In [None]:
# One row per (dataset, run); the *_log columns hold whole validation log dicts.
outputs_df = pd.DataFrame(data=outputs)
outputs_df

In [None]:
# outputs_df.groupby('dataset').agg({'test_acc': ['mean', 'sem'], 'test_f1': ['mean', 'sem'], 'test_loss': ['mean', 'sem']})

In [None]:
# Persist the results table to OUTPUT_PATH (no index column); dict-valued
# columns are written as their string representation.
outputs_df.to_csv(path_or_buf=OUTPUT_PATH, index=False)