# SETUP

In [1]:
import random
import string
# Short random tag (5 characters drawn from A-Z and 0-9) that names this run;
# it is reused below for the Optuna study name and the pickled study file.
_run_id_alphabet = string.ascii_uppercase + string.digits
run_id = ''.join([random.choice(_run_id_alphabet) for _ in range(5)])
run_id

#TODO log MLM_PERCENTAGE to comet


'LX2YP'

In [2]:
import os

import comet_ml

# Credentials must never be committed with the notebook: the key is read from
# the COMET_API_KEY environment variable instead of being hard-coded here.
# Set it before starting Jupyter, e.g. `export COMET_API_KEY=...`.
api_key = os.environ.get("COMET_API_KEY")
project_name = "CDNA_BERT"
if api_key:
    comet_ml.init(project_name=project_name, api_key=api_key)
else:
    # No key in the environment: let comet_ml fall back to its own
    # configuration lookup rather than passing an empty key explicitly.
    comet_ml.init(project_name=project_name)

COMET ERROR: Invalid Comet API key for https://www.comet-ml.com/clientlib/
You will not be able to create online Experiments
Please see https://www.comet.ml/docs/command-line/#comet-check for more information.
Use: comet_ml.init() to try again


In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, DataCollatorForLanguageModeling, TextDataset
from transformers import DistilBertConfig, DistilBertForMaskedLM
from transformers import TrainingArguments, Trainer
from transformers import DebertaConfig, DebertaForMaskedLM


# Pretrained tokenizer checkpoint; the resulting tokenizer is used below for
# the model's vocabulary size and by the masked-LM data collator.
_TOKENIZER_CHECKPOINT = "armheb/DNA_bert_6"
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)


# OPTUNA HYPERPARAM TUNING

In [3]:
#Install in terminal and confirm 
# !conda install -n YOUR_CONDA_ENV optuna
# !conda install -n YOUR_CONDA_ENV -c plotly plotly=5.8.2
# !conda install -n YOUR_CONDA_ENV "jupyterlab>=3" "ipywidgets>=7.6"
# !conda install -n YOUR_CONDA_ENV scikit-learn -c conda-forge


In [4]:
from transformers import TrainerCallback
# Optional: callback that reports each evaluation loss to Optuna so that unpromising trials can be pruned early

class PruningLogCallback(TrainerCallback):
    """Trainer callback that feeds evaluation losses to an Optuna trial.

    After every evaluation the ``eval_loss`` metric is reported to the trial
    under an internal, zero-based evaluation counter. If the trial's pruner
    then asks for the trial to stop, ``optuna.TrialPruned`` is raised.
    """

    def __init__(self, trial):
        # Number of evaluations reported so far; used as the Optuna step index.
        self.step = 0
        self.trial = trial

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """Report ``metrics['eval_loss']`` and abort the run if it should be pruned."""
        loss_value = metrics['eval_loss']
        report_step = self.step
        self.step += 1
        self.trial.report(loss_value, report_step)

        if self.trial.should_prune():
            raise optuna.TrialPruned()


In [9]:
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers import DataCollatorForLanguageModeling
import torch

import numpy as np

class WideCollator(DataCollatorForLanguageModeling):
    """Masked-LM collator that masks contiguous spans of ``area`` tokens.

    Instead of selecting every position independently, span start offsets are
    sampled per sequence and each span covers ``area`` consecutive positions.
    Spans are allowed to overlap, so the realised fraction of masked tokens
    can be lower than ``mlm_probability``.

    Args:
        area: Width, in tokens, of every masked span. Must be at least 1.
        mask_fully: If True, every selected position is replaced by the mask
            token. Otherwise the 80% mask / 10% random token / 10% unchanged
            scheme is applied to the selected positions.
        **kwargs: Forwarded unchanged to ``DataCollatorForLanguageModeling``.

    Raises:
        ValueError: If ``area`` is smaller than 1.
    """

    def __init__(self, area, mask_fully=False, **kwargs):
        # Fail early: a non-positive span width would otherwise only surface
        # later as a ZeroDivisionError / empty torch.cat during batching.
        if area < 1:
            raise ValueError(f"area must be a positive integer, got {area!r}")
        super().__init__(**kwargs)
        self.mask_fully = mask_fully
        self.area = area

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked inputs/labels for masked language modeling using wide spans.

        Positions are selected by ``get_probability_matrix_wide``; special
        tokens are never selected. With ``mask_fully`` every selected position
        becomes the mask token, otherwise: 80% MASK, 10% random, 10% original.

        Args:
            inputs: 2-D tensor of token ids (batch, sequence). Modified in place.
            special_tokens_mask: Optional mask of special-token positions; when
                omitted it is computed from the tokenizer.

        Returns:
            ``(inputs, labels)`` where ``labels`` is -100 everywhere except at
            the selected positions, which keep the original token ids.
        """
        labels = inputs.clone()
        # We sample consecutive tokens
        probability_matrix = self.get_probability_matrix_wide(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        # The matrix only holds 0.0 and 1.0, so this draw simply converts it to a mask.
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 100% masking: every selected position becomes [MASK]
        if self.mask_fully:
            inputs[masked_indices] = mask_token_id
            return inputs, labels

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20% of selected positions)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def get_probability_matrix_wide(self, labels_shape, mlm_probability):
        """Build a 0/1 matrix marking the positions covered by masked spans.

        For every sequence, ``int(seq_len * mlm_probability) // self.area``
        distinct span starts are drawn and each start is widened to
        ``self.area`` consecutive positions. Starts are distinct but the
        spans themselves may overlap.

        Args:
            labels_shape: ``(batch_size, seq_len)`` of the batch.
            mlm_probability: Target fraction of tokens to cover with spans.

        Returns:
            Float tensor of shape ``(batch_size, seq_len)`` containing 1.0 at
            covered positions and 0.0 elsewhere. It is all zeros when the
            token budget or the sequence is too small for a single span.
        """
        batch_size, seq_len = labels_shape
        span_width = self.area

        # The token budget is spent in whole spans (overlaps are not compensated for).
        num_spans = int(seq_len * mlm_probability) // span_width
        # Valid start offsets keep the whole span inside the sequence. Clamp at
        # zero so sequences shorter than one span give an empty mask instead of
        # torch.randperm failing on a negative size.
        num_starts = max(seq_len - span_width + 1, 0)

        result = torch.zeros((batch_size, seq_len))
        if batch_size == 0 or num_spans <= 0 or num_starts == 0:
            return result

        # TODO: spans may overlap; sampling starts on a grid of width `area`
        # would make them mutually exclusive.
        starts = torch.stack([torch.randperm(num_starts)[:num_spans] for _ in range(batch_size)])
        positions = torch.cat([starts + k for k in range(span_width)], 1)
        # Duplicate indices are harmless: scattering the constant 1 is idempotent.
        result.scatter_(-1, positions, 1)

        return result


In [10]:
# How to define search spaces
# https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html
# https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_float

# Logging cadence expressed in training sequences; objective() converts it to steps.
log_once_every_x_sequences = 64000
# Width, in tokens, of each contiguous masked span handed to WideCollator.
mask_area = 6
# Records appended by objective() whenever training/evaluation raises.
crashes = []

# Only the first 10% of each split is loaded for a faster hyperopt demonstration.
_DATASET_NAME = "simecek/Human_DNA_v0_DNABert6tokenized_stride1"
train_dset, test_dset = (
    load_dataset(_DATASET_NAME, split=split_spec)
    for split_spec in ('train[:10%]', 'test[:10%]')
)
def objective(trial):
    """Train one DeBERTa masked-LM with trial-suggested hyperparameters.

    Suggests the number of hidden layers, learning rate, weight decay, MLM
    probability and batch size, trains for one epoch on ``train_dset`` and
    evaluates on ``test_dset``.

    Args:
        trial: The ``optuna`` trial supplying hyperparameter suggestions.

    Returns:
        The final ``eval_loss`` (the quantity the study minimises).

    Raises:
        optuna.TrialPruned: When the pruner stops the trial, or when training
            or evaluation fails (the failure is recorded in ``crashes`` first).
    """
    num_train_epochs = 1
    # Largest batch placed on a device at once; bigger suggested batches are
    # reached through gradient accumulation.
    max_per_device_batch = 64

    hidden_layers = trial.suggest_int('num_hidden_layers', low=1, high=12, step=1)
    learning_rate = trial.suggest_float('learning_rate', low=1e-5, high=1e-1, log=True)
    weight_decay = trial.suggest_float('weight_decay', low=0, high=0.3)
    mlm_probability = trial.suggest_float('mlm_probability', low=0.05, high=0.5, step=0.05)
    batch_size = trial.suggest_categorical('batch_size', [8,16,32,64,128,256,512])

    # Computed from the suggested (effective) batch size, i.e. before it is
    # capped below, so logging happens once per `log_once_every_x_sequences`.
    logging_steps = int(log_once_every_x_sequences/batch_size)
    if batch_size <= max_per_device_batch:
        accumulation_steps = 1
    else:
        # Integer division: gradient_accumulation_steps must be an int, and
        # every suggested size above the cap is an exact multiple of it.
        accumulation_steps = batch_size // max_per_device_batch
        batch_size = max_per_device_batch

    model_config = DebertaConfig(vocab_size=len(tokenizer.vocab), max_position_embeddings=512, num_hidden_layers=hidden_layers)
    model = DebertaForMaskedLM(config=model_config)
    model.init_weights()

    training_args = TrainingArguments(
            output_dir='./model',
            overwrite_output_dir=True,
            evaluation_strategy = "steps",
            save_strategy = "steps",
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            push_to_hub=False,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=8,
            gradient_accumulation_steps=accumulation_steps,
            num_train_epochs=num_train_epochs,
            save_total_limit=1,
            # load_best_model_at_end=True,
            logging_steps=logging_steps,
            # save_steps=5000,
            fp16=True,
            # warmup_steps=1000,
    )

    data_collator = WideCollator(
        area=mask_area, tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability, mask_fully=True
    )
    trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dset,
            eval_dataset=test_dset,
            callbacks=[PruningLogCallback(trial)],
    )

    try:
        trainer.train()
        eval_loss = trainer.evaluate()['eval_loss']
    except optuna.TrialPruned:
        # Raised by PruningLogCallback: a genuine prune, not a crash, so it
        # must not be recorded in `crashes`.
        raise
    except Exception as e:
        # Best effort: remember what went wrong and let the study continue
        # with the next trial instead of aborting the whole search.
        crashes.append({'exception':e, 'trial':trial.number})
        raise optuna.TrialPruned() from e

    #Optimizing for validation loss
    return eval_loss



In [None]:
import optuna
import logging
import sys
# pruner doc https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.MedianPruner.html#optuna.pruners.MedianPruner
# Mirror Optuna's log records to stdout so they appear in the cell output.
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# n_startup_trials=3 here; 5 is the default.
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0)
study = optuna.create_study(
    direction='minimize',
    study_name=f"{run_id}_hyperparameter_search",
    pruner=median_pruner,
)
# n_trials is the total number of runs (one hyperparameter combination = one run).
study.optimize(objective, n_trials=30)
print(study.best_value, study.best_params, study.best_trial, sep='\n')


In [None]:
# Show the best hyperparameter combination recorded in the study.
print(study.best_params)


In [None]:
# Show the records objective() appended when training/evaluation raised
# (each holds the exception and the trial number).
print(crashes)

In [3]:
import pickle
import optuna
# NOTE: pickle.load can execute arbitrary code contained in the file; only
# load study files that you created yourself.
# The context manager closes the handle even if unpickling fails.
with open("../../L3UOT_HYPEROPT_study.pkl", 'rb') as file:
    study = pickle.load(file)

In [4]:
import plotly.io as pio
# Use plotly's "iframe" renderer by default for the fig.show() calls below.
pio.renderers.default = "iframe"

In [5]:
# Best hyperparameters stored in the study that was unpickled above.
print(study.best_params)

{'num_hidden_layers': 3, 'learning_rate': 7.033175701341661e-05, 'weight_decay': 0.07901034391982897, 'mlm_probability': 0.15000000000000002, 'batch_size': 8}


In [8]:
# Parameters shown as axes of the parallel-coordinate plot.
hyperparameters = [
    'num_hidden_layers',
    'learning_rate',
    'weight_decay',
    'mlm_probability',
    'batch_size',
]
fig = optuna.visualization.plot_parallel_coordinate(
    study,
    params=hyperparameters,
)
# Display inline first, then keep a static copy next to the notebook.
fig.show()
fig.write_image('L3UOT_parallel_coordinate.png')

In [9]:
# Plot Optuna's parameter-importance figure for the study, display it,
# and write a PNG copy to the working directory.
fig = optuna.visualization.plot_param_importances(study)
fig.show()
fig.write_image('L3UOT_param_importances.png')

In [17]:
def save_object(obj, filename):
    """Pickle ``obj`` into ``filename`` using the highest available protocol.

    The file is opened in binary write mode, so existing content is replaced.
    """
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)


# Persist the study to '<run_id>_HYPEROPT_study.pkl' in the working directory.
# NOTE(review): if `study` is the one unpickled from the L3UOT file above,
# `run_id` may belong to a different run — confirm the intended filename.
save_object(study, f'{run_id}_HYPEROPT_study.pkl')