# DEFAULT COLLATOR

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, DataCollatorForLanguageModeling, TextDataset
from transformers import DistilBertConfig, DistilBertForMaskedLM
from transformers import TrainingArguments, Trainer
from transformers import DebertaConfig, DebertaForMaskedLM
from transformers import set_seed

# --- Configuration: every value a reader might want to tune lives here ---
SEED = 42
TOKENIZER_NAME = "armheb/DNA_bert_6"
DATASET_NAME = "simecek/Human_DNA_v0_DNABert6tokenized_stride1"
DATA_FRACTION = "10%"  # leading slice of each split that is loaded
MLM_PROBABILITY = 0.15

# Seed the RNGs *before* the model is built so that the random weight
# initialisation is reproducible under "Restart Kernel -> Run All".
set_seed(SEED)

# --- Tokenizer and pre-tokenized data ---
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
train_dset = load_dataset(DATASET_NAME, split=f"train[:{DATA_FRACTION}]")
test_dset = load_dataset(DATASET_NAME, split=f"test[:{DATA_FRACTION}]")

# --- Model: a single-layer DeBERTa masked-LM sized to the tokenizer vocabulary ---
model_config = DebertaConfig(
    vocab_size=len(tokenizer.vocab),
    max_position_embeddings=512,
    num_hidden_layers=1,
)
# Building the model from a config already initialises its weights randomly,
# so no separate `init_weights()` call is needed afterwards.
model = DebertaForMaskedLM(config=model_config)

# --- Training configuration ---
# Optional knobs intentionally left at their library defaults:
# load_best_model_at_end, save_steps, warmup_steps.
training_args = TrainingArguments(
    output_dir='./model',
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    save_strategy="steps",
    learning_rate=5e-5,
    weight_decay=0,
    push_to_hub=False,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    save_total_limit=1,
    logging_steps=1000,
    fp16=True,  # mixed-precision training
    seed=SEED,  # same seed as the set_seed() call above
)

# --- Default masked-language-modelling collator ---
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=MLM_PROBABILITY
)

Using custom data configuration simecek--Human_DNA_v0_DNABert6tokenized_stride1-43a5a14a7a9b8d0a
Reusing dataset parquet (/home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--Human_DNA_v0_DNABert6tokenized_stride1-43a5a14a7a9b8d0a/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)
Using custom data configuration simecek--Human_DNA_v0_DNABert6tokenized_stride1-43a5a14a7a9b8d0a
Reusing dataset parquet (/home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--Human_DNA_v0_DNABert6tokenized_stride1-43a5a14a7a9b8d0a/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


In [2]:
# Trainer wired to the default `data_collator`; the datasets are the
# pre-tokenized train/test slices loaded in the setup cell.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    data_collator=data_collator,
)

Using amp half precision backend


In [3]:
# Peek at the first training batch produced with the default collator:
# its keys, the first sequence's token ids, the decoded text, and its labels.
train_loader = trainer.get_train_dataloader()
for sample in train_loader:
    print(sample.keys())
    first_ids = sample['input_ids'][0]
    print(first_ids)
    print(tokenizer.decode(first_ids))
    print(sample['labels'][0])
    # NOTE(review): labels are deliberately not decoded here. Per the
    # Transformers docs the MLM collator marks non-masked positions with -100,
    # which is not a vocabulary id -- confirm before re-enabling a decode.
    break  # only the first batch is needed

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
tensor([ 867, 3455, 1519, 1966, 3755,    4, 2662, 2444, 1571, 2174,  492, 1954,
        3706, 2523, 1886, 3435, 1437, 1638, 2444, 1569,    4,  457, 1816, 3154,
         316,    4,  885,    4, 1804, 3107,  125,  486, 1930, 3611, 2142,  363,
        1437, 1639, 2448, 1585, 2232,  721, 2870, 3275,  798, 3180,    4, 1654,
           4, 1828, 3202,  506, 2010, 3930, 3417,    4, 1359, 1325, 1192,  660,
        2628, 2308, 1027, 4094, 4074, 3994, 3674, 2395, 1375,    4, 1455, 1711,
        2734, 2730, 2714, 2650, 2396, 1379, 1406, 1515, 1949, 3688, 2450,    4,
        2261,  838, 3338, 1051,   94,  362,    4, 1635, 2430, 1516, 1955,    4,
        2541, 1959, 3727,    4, 2214,  652,    4, 2172,    4, 1909, 3528, 1809,
        3125,    4,  788, 3137,  246,    4, 3875, 3197,  486, 1930, 3610, 2140,
         355, 1406, 1514, 2869, 3675, 2399, 1391, 1454,    4, 2711, 2639, 2349,
           4,    4, 2605, 2214,  652, 2593, 2166,

In [4]:
# Bare last expression: display the default collator's repr so its configured
# fields (mlm, mlm_probability, return_tensors, ...) can be compared later.
data_collator

DataCollatorForLanguageModeling(tokenizer=PreTrainedTokenizerFast(name_or_path='armheb/DNA_bert_6', vocab_size=4101, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), mlm=True, mlm_probability=0.15, pad_to_multiple_of=None, tf_experimental_compile=False, return_tensors='pt')

# CUSTOM COLLATOR

In [5]:
from experiments.custom_masking.custom_collator import SubsequentCollator

# Project-local collator, given the same tokenizer / mlm / mlm_probability
# settings as the default collator plus `mask_fully=True`.
myCollator = SubsequentCollator(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
    mask_fully=True,
)

In [6]:
# Bare last expression: display the custom collator's repr for a side-by-side
# comparison with the default collator's configuration.
myCollator

SubsequentCollator(tokenizer=PreTrainedTokenizerFast(name_or_path='armheb/DNA_bert_6', vocab_size=4101, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), mlm=True, mlm_probability=0.15, pad_to_multiple_of=None, tf_experimental_compile=False, return_tensors='pt')

In [7]:
# Rebind `trainer` to a new Trainer that batches with the custom collator;
# model, arguments and datasets are the same objects used before.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    data_collator=myCollator,
)

Using amp half precision backend


In [8]:
# Peek at the first training batch produced with the custom collator:
# its keys, the first sequence's token ids, the decoded text, and its labels.
train_loader = trainer.get_train_dataloader()
for sample in train_loader:
    print(sample.keys())
    first_ids = sample['input_ids'][0]
    print(first_ids)
    print(tokenizer.decode(first_ids))
    print(sample['labels'][0])
    # NOTE(review): labels are deliberately not decoded; presumably the custom
    # collator also uses a non-vocabulary ignore value for unmasked positions
    # -- TODO confirm against SubsequentCollator before decoding them.
    break  # only the first batch is needed

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
tensor([ 867, 3455, 1519, 1966, 3755, 2717, 2662, 2444, 1571, 2174,  492, 1954,
        3706, 2523, 1886, 3435, 1437, 1638, 2444, 1569, 2166,  457, 1816, 3154,
         316, 1249,  885, 3526, 1804, 3107,  125,  486, 1930, 3611, 2142,  363,
        1437, 1639, 2448, 1585, 2232,  721, 2870, 3275,  798, 3180,  417, 1654,
        2508, 1828, 3202,  506, 2010, 3930, 3417, 1367, 1359, 1325, 1192,  660,
        2628, 2308, 1027, 4094, 4074, 3994, 3674, 2395, 1375, 1391, 1455, 1711,
        2734, 2730, 2714, 2650, 2396, 1379, 1406, 1515, 1949, 3688, 2450, 1593,
        2261,  838, 3338, 1051,   94,  362, 1436, 1635, 2430, 1516, 1955, 3711,
        2541, 1959, 3727, 2605, 2214,  652,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,    4,
           4,    4,    4,    4,    4,    4,    4,