In [1]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM

# Single source of truth for the Hugging Face checkpoint, so the config,
# tokenizer and model are always loaded from the same place.
MODEL_NAME = 'vinai/bertweet-base'

config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Alternative kept for reference: build the model from `config` only
# (no pretrained weights) instead of loading the checkpoint.
# model = AutoModelForMaskedLM.from_config(config)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
from glob import glob
from datasets import Dataset

import re
import pandas as pd

from parse import mask_data_loading

# Glob pattern for the input JSON files.
data_url = '../crawler/stock/data/**.json'
# glob() returns paths in arbitrary order, so sort before taking the last one
# to make the selection deterministic, and fail with a clear message when
# nothing matches instead of an opaque IndexError.
data_files = sorted(glob(data_url))
if not data_files:
    raise FileNotFoundError(f'No data files match {data_url!r}')
url = data_files[-1]
data, symbols = mask_data_loading(url, tokenizer, symbol_mask=False)

# preserve_index=False keeps the pandas index out of the dataset, so there is
# no '__index_level_0__' column that has to be removed afterwards.
dataset = Dataset.from_pandas(data.loc[:, ['labels', 'sentense']], preserve_index=False)
# Fixed seed so the 80/20 train/test split is the same on every run.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

train_dataset = dataset['train']
test_dataset = dataset['test']
# dataset = dataset.shuffle().select(range(50000))

# Register every symbol as an additional special token. Sorted so the ids
# assigned to the new tokens are stable across runs
# (assumes `symbols` is an iterable of strings - confirm against mask_data_loading).
special_tokens_dict = {'additional_special_tokens': sorted(symbols)}
tokenizer.add_special_tokens(special_tokens_dict)
# Resize the embedding matrix to match the enlarged vocabulary.
model.resize_token_embeddings(len(tokenizer))

Embedding(65862, 768)

In [3]:
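# Optional: save the tokenizer (including the added symbol tokens) so it can be
# reloaded later together with a trained checkpoint.
# tokenizer.save_pretrained('./symbol-vocab')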

In [4]:
def encode(example):
    """Tokenize the ``sentense`` field of a single example.

    Args:
        example: Mapping with a ``'sentense'`` entry holding the text to
            tokenize. Other entries (such as ``'labels'``) are ignored.

    Returns:
        The output of the module-level ``tokenizer`` called on that text with
        ``padding=True`` and ``truncation=True``.
    """
    # The labels column plays no part in tokenization, so it is not read here.
    return tokenizer(example['sentense'], padding=True, truncation=True)

context_length = 128
def tokenize(element):
    """Tokenize a batch of examples into fixed-length model inputs.

    Every text in ``element["sentense"]`` is truncated/padded to
    ``context_length`` tokens by the module-level ``tokenizer``.

    Args:
        element: Batch mapping whose ``"sentense"`` entry is the list of texts.

    Returns:
        Dict with two parallel lists:
        ``"input_ids"`` - token ids of every kept sequence, and
        ``"attention_mask"`` - the matching masks, so padding positions are
        not attended to during training.
    """
    outputs = tokenizer(
        element["sentense"],
        padding="max_length",
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
        return_attention_mask=True,
    )
    input_batch = []
    mask_batch = []
    for length, input_ids, attention_mask in zip(
        outputs["length"], outputs["input_ids"], outputs["attention_mask"]
    ):
        # Keep only sequences that fit the context window.
        if length <= context_length:
            input_batch.append(input_ids)
            # Keep the mask alongside the ids: without it a downstream padder
            # would treat the <pad> positions as real tokens.
            mask_batch.append(attention_mask)
    return {"input_ids": input_batch, "attention_mask": mask_batch}

# Tokenize both splits the same way, dropping the original text/label columns
# so only the tokenizer outputs remain.
encoded_train_dataset, encoded_test_dataset = (
    split.map(tokenize, batched=True, remove_columns=split.column_names)
    for split in (train_dataset, test_dataset)
)

100%|██████████| 110/110 [01:00<00:00,  1.82ba/s]
100%|██████████| 28/28 [00:14<00:00,  1.95ba/s]


In [5]:
from transformers import DataCollatorForLanguageModeling

# Collator that builds masked-language-modelling batches; mlm_probability is
# the fraction of tokens selected for masking in each batch.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.1)

# Sanity check: collate the first two encoded training examples and decode the
# first one to inspect what the model will be fed.
samples = encoded_train_dataset[:2]
a = data_collator(samples['input_ids'])
print(tokenizer.decode(a['input_ids'][0]))

<s> $TSLA 2k on tsla today thanks bears. Started with <unk> 85 <unk> <unk> </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


# Training

In [6]:
import numpy as np
from datasets import load_metric

# Accuracy metric object; compute_metrics() calls metric.compute() on it.
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases in
# favour of the separate `evaluate` package - consider migrating.
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Compute masked-token accuracy for an evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` has the vocabulary on
            its last axis; ``labels`` has the same leading shape and uses
            ``-100`` for positions that must not be scored (the ignore index
            the MLM collator assigns to tokens it did not mask).

    Returns:
        The result of ``metric.compute`` over the scored positions only.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    # Only masked positions carry a real target. Boolean indexing also
    # flattens both arrays to 1-D, which is what the metric expects.
    scored = labels != -100
    return metric.compute(predictions=predictions[scored], references=labels[scored])

In [7]:
from transformers import TrainingArguments, Trainer, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Masked-language-model fine-tuning uses an encoder-only model, so the plain
# TrainingArguments/Trainer are the appropriate classes; the Seq2Seq variants
# exist for generation-based evaluation, which does not apply here.
training_args = TrainingArguments(
    output_dir="after-bert-random-trainer",
    per_device_train_batch_size=4,
    num_train_epochs=5,
    save_steps=50000  # write a checkpoint every 50k optimization steps
)


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,  # applies the random MLM masking per batch
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_test_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

***** Running training *****
  Num examples = 109456
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 136820
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malan8365[0m (use `wandb login --relogin` to force relogin)


  0%|          | 501/136820 [01:10<7:18:10,  5.19it/s]

{'loss': 5.443, 'learning_rate': 4.9817278175705304e-05, 'epoch': 0.02}


  1%|          | 1001/136820 [02:18<7:16:32,  5.19it/s]

{'loss': 3.9724, 'learning_rate': 4.963455635141062e-05, 'epoch': 0.04}


  1%|          | 1501/136820 [03:27<7:23:25,  5.09it/s]

{'loss': 3.867, 'learning_rate': 4.945183452711592e-05, 'epoch': 0.05}


  1%|▏         | 2001/136820 [04:36<7:37:43,  4.91it/s]

{'loss': 3.8019, 'learning_rate': 4.926911270282123e-05, 'epoch': 0.07}


  2%|▏         | 2501/136820 [05:45<7:12:38,  5.17it/s]

{'loss': 3.7361, 'learning_rate': 4.908639087852653e-05, 'epoch': 0.09}


  2%|▏         | 3001/136820 [06:54<7:09:20,  5.19it/s]

{'loss': 3.6361, 'learning_rate': 4.890366905423184e-05, 'epoch': 0.11}


  3%|▎         | 3501/136820 [08:03<7:10:58,  5.16it/s]

{'loss': 3.6529, 'learning_rate': 4.872094722993715e-05, 'epoch': 0.13}


  3%|▎         | 4001/136820 [09:12<7:07:20,  5.18it/s]

{'loss': 3.7761, 'learning_rate': 4.853822540564245e-05, 'epoch': 0.15}


  3%|▎         | 4501/136820 [10:19<6:37:57,  5.54it/s]

{'loss': 3.7371, 'learning_rate': 4.835550358134776e-05, 'epoch': 0.16}


  4%|▎         | 5001/136820 [11:24<6:40:03,  5.49it/s]

{'loss': 3.7074, 'learning_rate': 4.8172781757053066e-05, 'epoch': 0.18}


  4%|▍         | 5501/136820 [12:28<6:39:49,  5.47it/s]

{'loss': 3.6831, 'learning_rate': 4.799005993275837e-05, 'epoch': 0.2}


  4%|▍         | 6001/136820 [13:32<6:34:42,  5.52it/s]

{'loss': 3.6528, 'learning_rate': 4.7807338108463676e-05, 'epoch': 0.22}


  5%|▍         | 6501/136820 [14:36<6:50:38,  5.29it/s]

{'loss': 3.7054, 'learning_rate': 4.7624616284168985e-05, 'epoch': 0.24}


  5%|▌         | 7001/136820 [15:40<6:28:14,  5.57it/s]

{'loss': 3.5205, 'learning_rate': 4.744189445987429e-05, 'epoch': 0.26}


  5%|▌         | 7501/136820 [16:44<6:28:16,  5.55it/s]

{'loss': 3.5635, 'learning_rate': 4.7259172635579595e-05, 'epoch': 0.27}


  6%|▌         | 8001/136820 [17:50<6:38:00,  5.39it/s]

{'loss': 3.5766, 'learning_rate': 4.70764508112849e-05, 'epoch': 0.29}


  6%|▌         | 8501/136820 [18:56<6:37:30,  5.38it/s]

{'loss': 3.6503, 'learning_rate': 4.689372898699021e-05, 'epoch': 0.31}


  7%|▋         | 9001/136820 [20:01<6:34:51,  5.40it/s]

{'loss': 3.6347, 'learning_rate': 4.671100716269551e-05, 'epoch': 0.33}


  7%|▋         | 9501/136820 [21:06<6:33:06,  5.40it/s]

{'loss': 3.5538, 'learning_rate': 4.652828533840082e-05, 'epoch': 0.35}


  7%|▋         | 10001/136820 [22:12<6:32:55,  5.38it/s]

{'loss': 3.5196, 'learning_rate': 4.634556351410613e-05, 'epoch': 0.37}


  8%|▊         | 10501/136820 [23:17<6:27:13,  5.44it/s]

{'loss': 3.5973, 'learning_rate': 4.616284168981143e-05, 'epoch': 0.38}


  8%|▊         | 11001/136820 [24:23<6:30:04,  5.38it/s]

{'loss': 3.5737, 'learning_rate': 4.598011986551674e-05, 'epoch': 0.4}


  8%|▊         | 11501/136820 [25:28<6:27:17,  5.39it/s]

{'loss': 3.5882, 'learning_rate': 4.579739804122204e-05, 'epoch': 0.42}


  9%|▉         | 12001/136820 [26:33<6:21:50,  5.45it/s]

{'loss': 3.6061, 'learning_rate': 4.561467621692736e-05, 'epoch': 0.44}


  9%|▉         | 12501/136820 [27:39<6:23:39,  5.40it/s]

{'loss': 3.597, 'learning_rate': 4.543195439263266e-05, 'epoch': 0.46}


 10%|▉         | 13001/136820 [28:44<6:18:37,  5.45it/s]

{'loss': 3.4856, 'learning_rate': 4.524923256833796e-05, 'epoch': 0.48}


 10%|▉         | 13501/136820 [29:50<6:22:10,  5.38it/s]

{'loss': 3.5309, 'learning_rate': 4.5066510744043275e-05, 'epoch': 0.49}


 10%|█         | 14001/136820 [30:55<6:19:13,  5.40it/s]

{'loss': 3.5604, 'learning_rate': 4.488378891974858e-05, 'epoch': 0.51}


 11%|█         | 14501/136820 [32:00<6:15:14,  5.43it/s]

{'loss': 3.583, 'learning_rate': 4.4701067095453885e-05, 'epoch': 0.53}


 11%|█         | 15001/136820 [33:05<5:51:01,  5.78it/s]

{'loss': 3.6298, 'learning_rate': 4.451834527115919e-05, 'epoch': 0.55}


 11%|█▏        | 15501/136820 [34:07<5:54:17,  5.71it/s]

{'loss': 3.595, 'learning_rate': 4.4335623446864495e-05, 'epoch': 0.57}


 12%|█▏        | 16001/136820 [35:08<5:48:50,  5.77it/s]

{'loss': 3.5033, 'learning_rate': 4.4152901622569804e-05, 'epoch': 0.58}


 12%|█▏        | 16501/136820 [36:09<5:45:08,  5.81it/s]

{'loss': 3.6394, 'learning_rate': 4.3970179798275105e-05, 'epoch': 0.6}


 12%|█▏        | 17001/136820 [37:10<5:47:22,  5.75it/s]

{'loss': 3.5268, 'learning_rate': 4.3787457973980414e-05, 'epoch': 0.62}


 13%|█▎        | 17501/136820 [38:11<5:46:56,  5.73it/s]

{'loss': 3.7731, 'learning_rate': 4.360473614968572e-05, 'epoch': 0.64}


 13%|█▎        | 18001/136820 [39:12<5:40:56,  5.81it/s]

{'loss': 3.4954, 'learning_rate': 4.3422014325391024e-05, 'epoch': 0.66}


 14%|█▎        | 18501/136820 [40:13<5:39:27,  5.81it/s]

{'loss': 3.5536, 'learning_rate': 4.323929250109633e-05, 'epoch': 0.68}


 14%|█▍        | 19001/136820 [41:14<5:53:46,  5.55it/s]

{'loss': 3.5273, 'learning_rate': 4.305657067680164e-05, 'epoch': 0.69}


 14%|█▍        | 19501/136820 [42:15<5:35:34,  5.83it/s]

{'loss': 3.5214, 'learning_rate': 4.287384885250695e-05, 'epoch': 0.71}


 15%|█▍        | 20001/136820 [43:16<5:33:54,  5.83it/s]

{'loss': 3.6483, 'learning_rate': 4.269112702821225e-05, 'epoch': 0.73}


 15%|█▍        | 20501/136820 [44:17<5:33:36,  5.81it/s]

{'loss': 3.507, 'learning_rate': 4.250840520391756e-05, 'epoch': 0.75}


 15%|█▌        | 21001/136820 [45:18<5:34:13,  5.78it/s]

{'loss': 3.4843, 'learning_rate': 4.232568337962287e-05, 'epoch': 0.77}


 16%|█▌        | 21501/136820 [46:19<5:31:28,  5.80it/s]

{'loss': 3.6152, 'learning_rate': 4.214296155532817e-05, 'epoch': 0.79}


 16%|█▌        | 22001/136820 [47:20<5:31:30,  5.77it/s]

{'loss': 3.5135, 'learning_rate': 4.196023973103348e-05, 'epoch': 0.8}


 16%|█▋        | 22501/136820 [48:21<5:32:13,  5.73it/s]

{'loss': 3.5631, 'learning_rate': 4.1777517906738786e-05, 'epoch': 0.82}


 17%|█▋        | 23001/136820 [49:22<5:26:41,  5.81it/s]

{'loss': 3.5155, 'learning_rate': 4.159479608244409e-05, 'epoch': 0.84}


 17%|█▋        | 23501/136820 [50:23<5:24:32,  5.82it/s]

{'loss': 3.7123, 'learning_rate': 4.1412074258149396e-05, 'epoch': 0.86}


 18%|█▊        | 24001/136820 [51:24<5:23:41,  5.81it/s]

{'loss': 3.4586, 'learning_rate': 4.12293524338547e-05, 'epoch': 0.88}


 18%|█▊        | 24501/136820 [52:25<5:30:58,  5.66it/s]

{'loss': 3.4465, 'learning_rate': 4.1046630609560006e-05, 'epoch': 0.9}


 18%|█▊        | 25001/136820 [53:26<5:22:04,  5.79it/s]

{'loss': 3.7363, 'learning_rate': 4.0863908785265315e-05, 'epoch': 0.91}


 19%|█▊        | 25501/136820 [54:27<5:20:52,  5.78it/s]

{'loss': 3.4522, 'learning_rate': 4.0681186960970616e-05, 'epoch': 0.93}


 19%|█▉        | 26001/136820 [55:28<5:22:41,  5.72it/s]

{'loss': 3.4194, 'learning_rate': 4.049846513667593e-05, 'epoch': 0.95}


 19%|█▉        | 26501/136820 [56:29<5:19:03,  5.76it/s]

{'loss': 3.4994, 'learning_rate': 4.031574331238123e-05, 'epoch': 0.97}


 20%|█▉        | 27001/136820 [57:30<5:15:01,  5.81it/s]

{'loss': 3.6132, 'learning_rate': 4.0133021488086535e-05, 'epoch': 0.99}


 20%|██        | 27501/136820 [58:31<5:14:09,  5.80it/s]

{'loss': 3.4757, 'learning_rate': 3.995029966379184e-05, 'epoch': 1.0}


 20%|██        | 28001/136820 [59:32<5:14:19,  5.77it/s]

{'loss': 3.2745, 'learning_rate': 3.976757783949715e-05, 'epoch': 1.02}


 21%|██        | 28501/136820 [1:00:33<5:09:51,  5.83it/s]

{'loss': 3.3805, 'learning_rate': 3.958485601520246e-05, 'epoch': 1.04}


 21%|██        | 29001/136820 [1:01:34<5:10:21,  5.79it/s]

{'loss': 3.4234, 'learning_rate': 3.940213419090776e-05, 'epoch': 1.06}


 22%|██▏       | 29501/136820 [1:02:35<5:10:02,  5.77it/s]

{'loss': 3.3109, 'learning_rate': 3.921941236661307e-05, 'epoch': 1.08}


 22%|██▏       | 30001/136820 [1:03:36<5:07:53,  5.78it/s]

{'loss': 3.4655, 'learning_rate': 3.903669054231838e-05, 'epoch': 1.1}


 22%|██▏       | 30501/136820 [1:04:39<5:17:01,  5.59it/s]

{'loss': 3.5147, 'learning_rate': 3.885396871802368e-05, 'epoch': 1.11}


 23%|██▎       | 31001/136820 [1:05:42<5:15:59,  5.58it/s]

{'loss': 3.3844, 'learning_rate': 3.867124689372899e-05, 'epoch': 1.13}


 23%|██▎       | 31501/136820 [1:06:45<5:29:12,  5.33it/s]

{'loss': 3.3331, 'learning_rate': 3.84885250694343e-05, 'epoch': 1.15}


 23%|██▎       | 32001/136820 [1:07:48<5:12:42,  5.59it/s]

{'loss': 3.3812, 'learning_rate': 3.83058032451396e-05, 'epoch': 1.17}


 24%|██▍       | 32501/136820 [1:08:51<5:10:32,  5.60it/s]

{'loss': 3.4368, 'learning_rate': 3.812308142084491e-05, 'epoch': 1.19}


 24%|██▍       | 33001/136820 [1:09:54<5:04:17,  5.69it/s]

{'loss': 3.2966, 'learning_rate': 3.7940359596550215e-05, 'epoch': 1.21}


 24%|██▍       | 33501/136820 [1:10:57<5:07:46,  5.59it/s]

{'loss': 3.4054, 'learning_rate': 3.7757637772255524e-05, 'epoch': 1.22}


 25%|██▍       | 34001/136820 [1:12:00<5:05:11,  5.61it/s]

{'loss': 3.2915, 'learning_rate': 3.7574915947960825e-05, 'epoch': 1.24}


 25%|██▌       | 34501/136820 [1:13:03<5:04:15,  5.60it/s]

{'loss': 3.3621, 'learning_rate': 3.739219412366613e-05, 'epoch': 1.26}


 26%|██▌       | 35001/136820 [1:14:06<5:00:36,  5.65it/s]

{'loss': 3.186, 'learning_rate': 3.720947229937144e-05, 'epoch': 1.28}


 26%|██▌       | 35501/136820 [1:15:09<5:00:59,  5.61it/s]

{'loss': 3.3154, 'learning_rate': 3.7026750475076744e-05, 'epoch': 1.3}


 26%|██▋       | 36001/136820 [1:16:12<4:57:51,  5.64it/s]

{'loss': 3.346, 'learning_rate': 3.684402865078205e-05, 'epoch': 1.32}


 27%|██▋       | 36501/136820 [1:17:15<5:09:16,  5.41it/s]

{'loss': 3.2767, 'learning_rate': 3.6661306826487354e-05, 'epoch': 1.33}


 27%|██▋       | 37001/136820 [1:18:18<4:57:52,  5.59it/s]

{'loss': 3.3679, 'learning_rate': 3.647858500219266e-05, 'epoch': 1.35}


 27%|██▋       | 37501/136820 [1:19:21<4:53:00,  5.65it/s]

{'loss': 3.2877, 'learning_rate': 3.629586317789797e-05, 'epoch': 1.37}


 28%|██▊       | 38001/136820 [1:20:24<4:54:51,  5.59it/s]

{'loss': 3.4151, 'learning_rate': 3.611314135360327e-05, 'epoch': 1.39}


 28%|██▊       | 38501/136820 [1:21:27<4:50:14,  5.65it/s]

{'loss': 3.3621, 'learning_rate': 3.593041952930859e-05, 'epoch': 1.41}


 29%|██▊       | 39001/136820 [1:22:30<4:50:18,  5.62it/s]

{'loss': 3.3554, 'learning_rate': 3.574769770501389e-05, 'epoch': 1.43}


 29%|██▉       | 39501/136820 [1:23:33<4:48:02,  5.63it/s]

{'loss': 3.4607, 'learning_rate': 3.556497588071919e-05, 'epoch': 1.44}


 29%|██▉       | 40001/136820 [1:24:36<4:48:33,  5.59it/s]

{'loss': 3.4838, 'learning_rate': 3.53822540564245e-05, 'epoch': 1.46}


 30%|██▉       | 40501/136820 [1:25:39<4:45:40,  5.62it/s]

{'loss': 3.4752, 'learning_rate': 3.519953223212981e-05, 'epoch': 1.48}


 30%|██▉       | 41001/136820 [1:26:42<4:44:22,  5.62it/s]

{'loss': 3.4061, 'learning_rate': 3.5016810407835116e-05, 'epoch': 1.5}


 30%|███       | 41501/136820 [1:27:45<4:49:05,  5.50it/s]

{'loss': 3.3044, 'learning_rate': 3.483408858354042e-05, 'epoch': 1.52}


 31%|███       | 42001/136820 [1:28:48<4:41:34,  5.61it/s]

{'loss': 3.3491, 'learning_rate': 3.4651366759245726e-05, 'epoch': 1.53}


 31%|███       | 42501/136820 [1:29:51<4:40:27,  5.61it/s]

{'loss': 3.2597, 'learning_rate': 3.4468644934951035e-05, 'epoch': 1.55}


 31%|███▏      | 43001/136820 [1:30:54<4:39:01,  5.60it/s]

{'loss': 3.2927, 'learning_rate': 3.4285923110656336e-05, 'epoch': 1.57}


 32%|███▏      | 43501/136820 [1:31:57<4:37:04,  5.61it/s]

{'loss': 3.3249, 'learning_rate': 3.4103201286361645e-05, 'epoch': 1.59}


 32%|███▏      | 44001/136820 [1:33:00<4:35:08,  5.62it/s]

{'loss': 3.5296, 'learning_rate': 3.392047946206695e-05, 'epoch': 1.61}


 33%|███▎      | 44501/136820 [1:34:03<4:33:17,  5.63it/s]

{'loss': 3.4208, 'learning_rate': 3.3737757637772255e-05, 'epoch': 1.63}


 33%|███▎      | 45001/136820 [1:35:06<4:31:32,  5.64it/s]

{'loss': 3.3903, 'learning_rate': 3.355503581347756e-05, 'epoch': 1.64}


 33%|███▎      | 45501/136820 [1:36:09<4:29:09,  5.65it/s]

{'loss': 3.2828, 'learning_rate': 3.337231398918287e-05, 'epoch': 1.66}


 34%|███▎      | 46001/136820 [1:37:12<4:30:52,  5.59it/s]

{'loss': 3.2381, 'learning_rate': 3.318959216488818e-05, 'epoch': 1.68}


 34%|███▍      | 46501/136820 [1:38:15<4:28:27,  5.61it/s]

{'loss': 3.2525, 'learning_rate': 3.300687034059348e-05, 'epoch': 1.7}


 34%|███▍      | 47001/136820 [1:39:18<4:28:12,  5.58it/s]

{'loss': 3.345, 'learning_rate': 3.282414851629878e-05, 'epoch': 1.72}


 35%|███▍      | 47501/136820 [1:40:21<4:24:13,  5.63it/s]

{'loss': 3.3137, 'learning_rate': 3.26414266920041e-05, 'epoch': 1.74}


 35%|███▌      | 48001/136820 [1:41:24<4:22:58,  5.63it/s]

{'loss': 3.3089, 'learning_rate': 3.24587048677094e-05, 'epoch': 1.75}


 35%|███▌      | 48501/136820 [1:42:27<4:20:26,  5.65it/s]

{'loss': 3.3187, 'learning_rate': 3.227598304341471e-05, 'epoch': 1.77}


 36%|███▌      | 49001/136820 [1:43:30<4:18:24,  5.66it/s]

{'loss': 3.1835, 'learning_rate': 3.209326121912001e-05, 'epoch': 1.79}


 36%|███▌      | 49501/136820 [1:44:33<4:17:58,  5.64it/s]

{'loss': 3.4132, 'learning_rate': 3.191053939482532e-05, 'epoch': 1.81}


 37%|███▋      | 50000/136820 [1:45:36<4:43:45,  5.10it/s]Saving model checkpoint to after-bert-random-trainer\checkpoint-50000
Configuration saved in after-bert-random-trainer\checkpoint-50000\config.json


{'loss': 3.3448, 'learning_rate': 3.172781757053063e-05, 'epoch': 1.83}


Model weights saved in after-bert-random-trainer\checkpoint-50000\pytorch_model.bin
 37%|███▋      | 50501/136820 [1:46:49<4:19:58,  5.53it/s] 

{'loss': 3.3619, 'learning_rate': 3.154509574623593e-05, 'epoch': 1.85}


 37%|███▋      | 51001/136820 [1:47:54<4:19:03,  5.52it/s]

{'loss': 3.2955, 'learning_rate': 3.1362373921941244e-05, 'epoch': 1.86}


 38%|███▊      | 51501/136820 [1:48:58<4:18:34,  5.50it/s]

{'loss': 3.2581, 'learning_rate': 3.1179652097646545e-05, 'epoch': 1.88}


 38%|███▊      | 52001/136820 [1:50:02<4:16:15,  5.52it/s]

{'loss': 3.1944, 'learning_rate': 3.099693027335185e-05, 'epoch': 1.9}


 38%|███▊      | 52501/136820 [1:51:06<4:16:02,  5.49it/s]

{'loss': 3.3807, 'learning_rate': 3.0814208449057155e-05, 'epoch': 1.92}


 39%|███▊      | 53001/136820 [1:52:10<4:14:35,  5.49it/s]

{'loss': 3.1102, 'learning_rate': 3.0631486624762464e-05, 'epoch': 1.94}


 39%|███▉      | 53501/136820 [1:53:14<4:09:54,  5.56it/s]

{'loss': 3.308, 'learning_rate': 3.0448764800467772e-05, 'epoch': 1.96}


 39%|███▉      | 54001/136820 [1:54:18<4:18:28,  5.34it/s]

{'loss': 3.4212, 'learning_rate': 3.0266042976173077e-05, 'epoch': 1.97}


 40%|███▉      | 54501/136820 [1:55:22<4:07:13,  5.55it/s]

{'loss': 3.2586, 'learning_rate': 3.008332115187838e-05, 'epoch': 1.99}


 40%|████      | 55001/136820 [1:56:26<4:06:41,  5.53it/s]

{'loss': 3.2973, 'learning_rate': 2.990059932758369e-05, 'epoch': 2.01}


 41%|████      | 55501/136820 [1:57:30<4:05:05,  5.53it/s]

{'loss': 3.232, 'learning_rate': 2.9717877503288992e-05, 'epoch': 2.03}


 41%|████      | 56001/136820 [1:58:34<4:03:05,  5.54it/s]

{'loss': 3.2081, 'learning_rate': 2.9535155678994304e-05, 'epoch': 2.05}


 41%|████▏     | 56501/136820 [1:59:38<4:03:57,  5.49it/s]

{'loss': 3.1834, 'learning_rate': 2.9352433854699606e-05, 'epoch': 2.06}


 42%|████▏     | 57001/136820 [2:00:42<4:03:24,  5.47it/s]

{'loss': 3.0978, 'learning_rate': 2.916971203040491e-05, 'epoch': 2.08}


 42%|████▏     | 57501/136820 [2:01:46<3:58:34,  5.54it/s]

{'loss': 3.0753, 'learning_rate': 2.898699020611022e-05, 'epoch': 2.1}


 42%|████▏     | 58001/136820 [2:02:50<3:56:25,  5.56it/s]

{'loss': 3.2642, 'learning_rate': 2.8804268381815524e-05, 'epoch': 2.12}


 43%|████▎     | 58501/136820 [2:03:54<4:05:18,  5.32it/s]

{'loss': 3.1647, 'learning_rate': 2.8621546557520833e-05, 'epoch': 2.14}


 43%|████▎     | 59001/136820 [2:04:58<3:55:18,  5.51it/s]

{'loss': 3.1313, 'learning_rate': 2.8438824733226138e-05, 'epoch': 2.16}


 43%|████▎     | 59501/136820 [2:06:02<3:52:30,  5.54it/s]

{'loss': 3.0842, 'learning_rate': 2.8256102908931443e-05, 'epoch': 2.17}


 44%|████▍     | 60001/136820 [2:07:06<3:51:33,  5.53it/s]

{'loss': 3.1623, 'learning_rate': 2.807338108463675e-05, 'epoch': 2.19}


 44%|████▍     | 60501/136820 [2:08:10<3:54:26,  5.43it/s]

{'loss': 3.2662, 'learning_rate': 2.7890659260342056e-05, 'epoch': 2.21}


 45%|████▍     | 61001/136820 [2:09:14<3:49:51,  5.50it/s]

{'loss': 3.1792, 'learning_rate': 2.7707937436047365e-05, 'epoch': 2.23}


 45%|████▍     | 61501/136820 [2:10:18<3:45:06,  5.58it/s]

{'loss': 3.1449, 'learning_rate': 2.752521561175267e-05, 'epoch': 2.25}


 45%|████▌     | 62001/136820 [2:11:22<3:45:05,  5.54it/s]

{'loss': 3.0884, 'learning_rate': 2.7342493787457975e-05, 'epoch': 2.27}


 46%|████▌     | 62501/136820 [2:12:26<3:47:05,  5.45it/s]

{'loss': 3.3515, 'learning_rate': 2.7159771963163283e-05, 'epoch': 2.28}


 46%|████▌     | 63001/136820 [2:13:30<3:41:45,  5.55it/s]

{'loss': 3.1225, 'learning_rate': 2.6977050138868588e-05, 'epoch': 2.3}


 46%|████▋     | 63501/136820 [2:14:34<3:41:33,  5.52it/s]

{'loss': 2.9811, 'learning_rate': 2.6794328314573896e-05, 'epoch': 2.32}


 47%|████▋     | 64001/136820 [2:15:38<3:39:15,  5.54it/s]

{'loss': 3.1324, 'learning_rate': 2.66116064902792e-05, 'epoch': 2.34}


 47%|████▋     | 64501/136820 [2:16:42<3:38:49,  5.51it/s]

{'loss': 3.1783, 'learning_rate': 2.6428884665984503e-05, 'epoch': 2.36}


 48%|████▊     | 65001/136820 [2:17:46<3:36:27,  5.53it/s]

{'loss': 3.0362, 'learning_rate': 2.6246162841689815e-05, 'epoch': 2.38}


 48%|████▊     | 65501/136820 [2:18:50<3:33:52,  5.56it/s]

{'loss': 3.2705, 'learning_rate': 2.606344101739512e-05, 'epoch': 2.39}


 48%|████▊     | 66001/136820 [2:19:54<3:32:04,  5.57it/s]

{'loss': 3.1413, 'learning_rate': 2.588071919310043e-05, 'epoch': 2.41}


 49%|████▊     | 66501/136820 [2:20:58<3:32:21,  5.52it/s]

{'loss': 3.2616, 'learning_rate': 2.5697997368805733e-05, 'epoch': 2.43}


 49%|████▉     | 67001/136820 [2:22:02<3:30:23,  5.53it/s]

{'loss': 3.0655, 'learning_rate': 2.5515275544511035e-05, 'epoch': 2.45}


 49%|████▉     | 67501/136820 [2:23:06<3:34:48,  5.38it/s]

{'loss': 3.184, 'learning_rate': 2.5332553720216347e-05, 'epoch': 2.47}


 50%|████▉     | 68001/136820 [2:24:10<3:27:19,  5.53it/s]

{'loss': 3.2579, 'learning_rate': 2.514983189592165e-05, 'epoch': 2.49}


 50%|█████     | 68501/136820 [2:25:14<3:26:20,  5.52it/s]

{'loss': 3.1813, 'learning_rate': 2.4967110071626957e-05, 'epoch': 2.5}


 50%|█████     | 69001/136820 [2:26:18<3:24:26,  5.53it/s]

{'loss': 3.0201, 'learning_rate': 2.4784388247332262e-05, 'epoch': 2.52}


 51%|█████     | 69501/136820 [2:27:22<3:23:57,  5.50it/s]

{'loss': 3.0492, 'learning_rate': 2.460166642303757e-05, 'epoch': 2.54}


 51%|█████     | 70001/136820 [2:28:26<3:20:38,  5.55it/s]

{'loss': 3.1026, 'learning_rate': 2.4418944598742875e-05, 'epoch': 2.56}


 52%|█████▏    | 70501/136820 [2:29:30<3:18:52,  5.56it/s]

{'loss': 3.2063, 'learning_rate': 2.423622277444818e-05, 'epoch': 2.58}


 52%|█████▏    | 71001/136820 [2:30:34<3:17:11,  5.56it/s]

{'loss': 3.1434, 'learning_rate': 2.405350095015349e-05, 'epoch': 2.59}


 52%|█████▏    | 71501/136820 [2:31:38<3:16:32,  5.54it/s]

{'loss': 3.1699, 'learning_rate': 2.3870779125858794e-05, 'epoch': 2.61}


 53%|█████▎    | 72001/136820 [2:32:42<3:16:57,  5.49it/s]

{'loss': 2.9636, 'learning_rate': 2.3688057301564102e-05, 'epoch': 2.63}


 53%|█████▎    | 72501/136820 [2:33:46<3:13:27,  5.54it/s]

{'loss': 3.0499, 'learning_rate': 2.3505335477269404e-05, 'epoch': 2.65}


 53%|█████▎    | 73001/136820 [2:34:50<3:12:06,  5.54it/s]

{'loss': 3.1295, 'learning_rate': 2.3322613652974712e-05, 'epoch': 2.67}


 54%|█████▎    | 73501/136820 [2:35:54<3:10:15,  5.55it/s]

{'loss': 3.0981, 'learning_rate': 2.3139891828680017e-05, 'epoch': 2.69}


 54%|█████▍    | 74001/136820 [2:36:58<3:08:21,  5.56it/s]

{'loss': 3.0133, 'learning_rate': 2.2957170004385326e-05, 'epoch': 2.7}


 54%|█████▍    | 74501/136820 [2:38:02<3:07:21,  5.54it/s]

{'loss': 3.0974, 'learning_rate': 2.277444818009063e-05, 'epoch': 2.72}


 55%|█████▍    | 75001/136820 [2:39:06<3:05:01,  5.57it/s]

{'loss': 3.1704, 'learning_rate': 2.2591726355795936e-05, 'epoch': 2.74}


 55%|█████▌    | 75501/136820 [2:40:10<3:04:03,  5.55it/s]

{'loss': 3.2295, 'learning_rate': 2.2409004531501244e-05, 'epoch': 2.76}


 56%|█████▌    | 76001/136820 [2:41:14<3:03:10,  5.53it/s]

{'loss': 3.198, 'learning_rate': 2.222628270720655e-05, 'epoch': 2.78}


 56%|█████▌    | 76501/136820 [2:42:19<3:02:14,  5.52it/s]

{'loss': 3.179, 'learning_rate': 2.2043560882911858e-05, 'epoch': 2.8}


 56%|█████▋    | 77001/136820 [2:43:23<3:00:23,  5.53it/s]

{'loss': 3.1203, 'learning_rate': 2.1860839058617163e-05, 'epoch': 2.81}


 57%|█████▋    | 77501/136820 [2:44:27<2:57:50,  5.56it/s]

{'loss': 2.8909, 'learning_rate': 2.1678117234322468e-05, 'epoch': 2.83}


 57%|█████▋    | 78001/136820 [2:45:31<2:57:05,  5.54it/s]

{'loss': 3.0812, 'learning_rate': 2.1495395410027776e-05, 'epoch': 2.85}


 57%|█████▋    | 78501/136820 [2:46:35<2:57:01,  5.49it/s]

{'loss': 2.9897, 'learning_rate': 2.131267358573308e-05, 'epoch': 2.87}


 58%|█████▊    | 79001/136820 [2:47:39<2:54:44,  5.51it/s]

{'loss': 3.0499, 'learning_rate': 2.112995176143839e-05, 'epoch': 2.89}


 58%|█████▊    | 79501/136820 [2:48:43<2:52:22,  5.54it/s]

{'loss': 2.9663, 'learning_rate': 2.0947229937143695e-05, 'epoch': 2.91}


 58%|█████▊    | 80001/136820 [2:49:47<2:54:37,  5.42it/s]

{'loss': 3.0052, 'learning_rate': 2.0764508112849e-05, 'epoch': 2.92}


 59%|█████▉    | 80501/136820 [2:50:51<2:49:35,  5.53it/s]

{'loss': 3.0629, 'learning_rate': 2.0581786288554305e-05, 'epoch': 2.94}


 59%|█████▉    | 81001/136820 [2:51:55<2:47:57,  5.54it/s]

{'loss': 3.1298, 'learning_rate': 2.0399064464259613e-05, 'epoch': 2.96}


 60%|█████▉    | 81501/136820 [2:52:59<2:46:28,  5.54it/s]

{'loss': 3.0148, 'learning_rate': 2.0216342639964918e-05, 'epoch': 2.98}


 60%|█████▉    | 82001/136820 [2:54:03<2:45:01,  5.54it/s]

{'loss': 3.1126, 'learning_rate': 2.0033620815670226e-05, 'epoch': 3.0}


 60%|██████    | 82501/136820 [2:55:07<2:44:56,  5.49it/s]

{'loss': 3.0372, 'learning_rate': 1.985089899137553e-05, 'epoch': 3.01}


 61%|██████    | 83001/136820 [2:56:11<2:41:41,  5.55it/s]

{'loss': 2.9883, 'learning_rate': 1.9668177167080836e-05, 'epoch': 3.03}


 61%|██████    | 83501/136820 [2:57:15<2:39:41,  5.57it/s]

{'loss': 2.8996, 'learning_rate': 1.9485455342786145e-05, 'epoch': 3.05}


 61%|██████▏   | 84001/136820 [2:58:19<2:38:52,  5.54it/s]

{'loss': 2.9538, 'learning_rate': 1.930273351849145e-05, 'epoch': 3.07}


 62%|██████▏   | 84501/136820 [2:59:23<2:37:38,  5.53it/s]

{'loss': 2.9961, 'learning_rate': 1.9120011694196755e-05, 'epoch': 3.09}


 62%|██████▏   | 85001/136820 [3:00:27<2:35:40,  5.55it/s]

{'loss': 3.1633, 'learning_rate': 1.893728986990206e-05, 'epoch': 3.11}


 62%|██████▏   | 85501/136820 [3:01:31<2:35:30,  5.50it/s]

{'loss': 2.9141, 'learning_rate': 1.875456804560737e-05, 'epoch': 3.12}


 63%|██████▎   | 86001/136820 [3:02:35<2:32:10,  5.57it/s]

{'loss': 3.0836, 'learning_rate': 1.8571846221312673e-05, 'epoch': 3.14}


 63%|██████▎   | 86501/136820 [3:03:39<2:31:44,  5.53it/s]

{'loss': 3.0033, 'learning_rate': 1.8389124397017982e-05, 'epoch': 3.16}


 64%|██████▎   | 87001/136820 [3:04:43<2:30:15,  5.53it/s]

{'loss': 3.1301, 'learning_rate': 1.8206402572723287e-05, 'epoch': 3.18}


 64%|██████▍   | 87501/136820 [3:05:47<2:28:55,  5.52it/s]

{'loss': 2.9654, 'learning_rate': 1.8023680748428592e-05, 'epoch': 3.2}


 64%|██████▍   | 88001/136820 [3:06:51<2:26:19,  5.56it/s]

{'loss': 2.9784, 'learning_rate': 1.78409589241339e-05, 'epoch': 3.22}


 65%|██████▍   | 88501/136820 [3:07:55<2:24:17,  5.58it/s]

{'loss': 2.9754, 'learning_rate': 1.7658237099839205e-05, 'epoch': 3.23}


 65%|██████▌   | 89001/136820 [3:08:59<2:25:37,  5.47it/s]

{'loss': 3.0091, 'learning_rate': 1.7475515275544514e-05, 'epoch': 3.25}


 65%|██████▌   | 89501/136820 [3:10:03<2:23:18,  5.50it/s]

{'loss': 2.7892, 'learning_rate': 1.7292793451249815e-05, 'epoch': 3.27}


 66%|██████▌   | 90001/136820 [3:11:08<2:20:23,  5.56it/s]

{'loss': 3.0167, 'learning_rate': 1.7110071626955124e-05, 'epoch': 3.29}


 66%|██████▌   | 90501/136820 [3:12:12<2:24:21,  5.35it/s]

{'loss': 2.9737, 'learning_rate': 1.6927349802660432e-05, 'epoch': 3.31}


 67%|██████▋   | 91001/136820 [3:13:16<2:19:22,  5.48it/s]

{'loss': 3.0098, 'learning_rate': 1.6744627978365737e-05, 'epoch': 3.33}


 67%|██████▋   | 91501/136820 [3:14:20<2:15:33,  5.57it/s]

{'loss': 2.9752, 'learning_rate': 1.6561906154071046e-05, 'epoch': 3.34}


 67%|██████▋   | 92001/136820 [3:15:24<2:14:19,  5.56it/s]

{'loss': 2.9206, 'learning_rate': 1.6379184329776347e-05, 'epoch': 3.36}


 68%|██████▊   | 92501/136820 [3:16:28<2:13:38,  5.53it/s]

{'loss': 3.0252, 'learning_rate': 1.6196462505481656e-05, 'epoch': 3.38}


 68%|██████▊   | 93001/136820 [3:17:32<2:11:36,  5.55it/s]

{'loss': 2.848, 'learning_rate': 1.601374068118696e-05, 'epoch': 3.4}


 68%|██████▊   | 93501/136820 [3:18:36<2:09:53,  5.56it/s]

{'loss': 2.8676, 'learning_rate': 1.583101885689227e-05, 'epoch': 3.42}


 69%|██████▊   | 94001/136820 [3:19:40<2:08:51,  5.54it/s]

{'loss': 2.9326, 'learning_rate': 1.5648297032597574e-05, 'epoch': 3.44}


 69%|██████▉   | 94501/136820 [3:20:44<2:07:08,  5.55it/s]

{'loss': 2.9309, 'learning_rate': 1.546557520830288e-05, 'epoch': 3.45}


 69%|██████▉   | 95001/136820 [3:21:48<2:06:33,  5.51it/s]

{'loss': 2.7631, 'learning_rate': 1.5282853384008188e-05, 'epoch': 3.47}


 70%|██████▉   | 95501/136820 [3:22:52<2:06:22,  5.45it/s]

{'loss': 2.9434, 'learning_rate': 1.5100131559713493e-05, 'epoch': 3.49}


 70%|███████   | 96001/136820 [3:23:56<2:03:22,  5.51it/s]

{'loss': 2.9495, 'learning_rate': 1.49174097354188e-05, 'epoch': 3.51}


 71%|███████   | 96501/136820 [3:25:00<2:01:09,  5.55it/s]

{'loss': 3.0207, 'learning_rate': 1.4734687911124106e-05, 'epoch': 3.53}


 71%|███████   | 97001/136820 [3:26:07<2:08:22,  5.17it/s]

{'loss': 3.009, 'learning_rate': 1.4551966086829411e-05, 'epoch': 3.54}


 71%|███████▏  | 97501/136820 [3:27:16<2:07:32,  5.14it/s]

{'loss': 2.8582, 'learning_rate': 1.4369244262534718e-05, 'epoch': 3.56}


 72%|███████▏  | 98001/136820 [3:28:26<2:04:30,  5.20it/s]

{'loss': 2.9015, 'learning_rate': 1.4186522438240024e-05, 'epoch': 3.58}


 72%|███████▏  | 98501/136820 [3:29:35<2:05:36,  5.08it/s]

{'loss': 2.8788, 'learning_rate': 1.4003800613945331e-05, 'epoch': 3.6}


 72%|███████▏  | 99001/136820 [3:30:44<2:02:50,  5.13it/s]

{'loss': 3.047, 'learning_rate': 1.3821078789650638e-05, 'epoch': 3.62}


 73%|███████▎  | 99501/136820 [3:31:54<2:00:12,  5.17it/s]

{'loss': 2.9704, 'learning_rate': 1.3638356965355941e-05, 'epoch': 3.64}


 73%|███████▎  | 100000/136820 [3:33:03<2:11:19,  4.67it/s]Saving model checkpoint to after-bert-random-trainer\checkpoint-100000
Configuration saved in after-bert-random-trainer\checkpoint-100000\config.json


{'loss': 2.9557, 'learning_rate': 1.3455635141061248e-05, 'epoch': 3.65}


Model weights saved in after-bert-random-trainer\checkpoint-100000\pytorch_model.bin
 73%|███████▎  | 100358/136820 [3:34:02<1:23:19,  7.29it/s] 

KeyboardInterrupt: 

In [8]:
from transformers import pipeline

# After trainer.train() the model may still be in training mode; switch to
# eval mode so dropout does not randomise the predictions below.
model.eval()

# Run the pipeline on whichever device the model already lives on (CUDA index,
# or -1 for CPU) rather than assuming GPU 0 is available.
pipeline_device = model.device.index if model.device.type == "cuda" else -1

fill_mask = pipeline(
    "fill-mask",
    model=model,
    tokenizer=tokenizer,
    device=pipeline_device,
)

# Ask the model to fill in the first word of the phrase.
s = f'{tokenizer.mask_token} to the sky!!!'
fill_mask(s)

[{'score': 0.14990156888961792,
  'token': 1244,
  'token_str': 'B a c k',
  'sequence': 'Back to the sky!!!'},
 {'score': 0.041228100657463074,
  'token': 53632,
  'token_str': 's o a r i n g',
  'sequence': 'soaring to the sky!!!'},
 {'score': 0.038240138441324234,
  'token': 107,
  'token_str': 'b a c k',
  'sequence': 'back to the sky!!!'},
 {'score': 0.03203132748603821,
  'token': 22482,
  'token_str': 'r o c k e t s',
  'sequence': 'rockets to the sky!!!'},
 {'score': 0.03099006786942482,
  'token': 2121,
  'token_str': 's h o o t',
  'sequence': 'shoot to the sky!!!'}]

In [9]:
import torch
# Release cached GPU memory that PyTorch is holding but not currently using,
# then print the allocator's usage report for the current CUDA device.
torch.cuda.empty_cache()
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    2099 MB |    2892 MB |  334254 GB |  334252 GB |
|       from large pool |    2095 MB |    2869 MB |  332165 GB |  332163 GB |
|       from small pool |       4 MB |      24 MB |    2088 GB |    2088 GB |
|---------------------------------------------------------------------------|
| Active memory         |    2099 MB |    2892 MB |  334254 GB |  334252 GB |
|       from large pool |    2095 MB |    2869 MB |  332165 GB |  332163 GB |
|       from small pool |       4 MB |      24 MB |    2088 GB |    2088 GB |
|---------------------------------------------------------------