In [1]:
import numpy as np

import torch
from torch.utils.data import DataLoader, random_split

from transformers import BloomTokenizerFast, BloomForCausalLM, AutoTokenizer


from datasets import load_dataset

from tqdm.auto import tqdm

In [2]:
%%time
# tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3large_based_on_gpt2", padding_side='left')
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


CPU times: user 4.76 s, sys: 1.25 s, total: 6.01 s
Wall time: 9.94 s


In [3]:
tokenizer.pad_token = tokenizer.eos_token

In [4]:
def collate_fn(data):
    texts = [x['text'] for x in data]
    inputs = tokenizer(texts, padding=True, return_tensors='pt', max_length=64, truncation=True)
    inputs['labels'] = torch.where(inputs['input_ids'] == 50257, -100, inputs['input_ids'])
    return inputs


dataset = load_dataset('nthngdy/oscar-mini', 'unshuffled_deduplicated_ru', split='train')

train_length = int(0.99 * len(dataset))
val_length = len(dataset) - train_length

train_dataset, val_dataset = random_split(dataset, [train_length, val_length],
                                          generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=2)

Found cached dataset oscar-mini (/home/alexnikko/.cache/huggingface/datasets/nthngdy___oscar-mini/unshuffled_deduplicated_ru/1.0.0/d61b181331745a38dd31e8c6cc23d46566b96e255384c4421f2396af24a01dff)


In [5]:
next(iter(train_loader))

{'input_ids': tensor([[19572,   801,   768,  5123,   385,  1412,   289, 40854,   562,   272,
         22149,   305,   552,   329,  8654,  1954,   360,  8359,  4149,   872,
           718, 49641,   385,  8283,    18,   385,   923,   781, 37147,  1558,
          5289, 30085,  3245,   289, 28602,    16,  3229,   828, 18351, 49641,
           385,  8283,  6320,  5700,   411, 35243,  1781,  3258, 40583,    35,
          8545, 11291,     5],
        [50257, 50257, 50257, 50257, 50257, 50257, 50257,  4580,  1043,  5613,
          3929,  9874, 20562,    16,  2726,  5643, 22847,   289, 22721,  2257,
          5836,    16, 32293,   360, 44657, 14820,   282, 40349, 12893,    16,
          6159,  2009,  3594,    16,  9410,   416,  3320,    18,   365, 39859,
          1750, 13022, 11846,  1437, 37279,   294,  8481, 12867,  4676,  6657,
           289,  8786,    18]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 

In [6]:
for p in model.parameters():
    p.requires_grad = False

In [7]:
model.transformer.word_embeddings = torch.nn.Embedding(tokenizer.vocab_size + 1, 1024)
model.transformer.word_embeddings_layernorm = torch.nn.LayerNorm((1024,), eps=1e-5, elementwise_affine=True)

In [8]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [9]:
# for p in model.parameters():
#     p.requires_grad = False

    
device = 'cuda:0'

model.to(device)

model.eval()
perplexity_values = []
with torch.inference_mode():
    for batch in tqdm(val_loader):
        batch = {key: value.to(device) for key, value in batch.items()}
        outputs = model(**batch)
        perplexity = torch.exp(outputs.loss)
        perplexity_values.append(perplexity.item())

  0%|          | 0/575 [00:00<?, ?it/s]

In [20]:
np.mean(perplexity_values)

37.658425997858465

In [25]:
np.mean(perplexity_values)

45.671457590849506

In [10]:
print(f'not trained model perplexity = {np.mean(perplexity_values)}')

not trained model perplexity = inf


In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

n_epochs = 20
n_steps_per_epoch = 1_000

for epoch in tqdm(range(n_epochs)):
    losses = []
    perplexity_values = []
    model.train()
    for i, batch in tqdm(enumerate(train_loader, start=1), total=n_steps_per_epoch):
        optimizer.zero_grad()
        
        batch = {key: value.to(device) for key, value in batch.items()}
        outputs = model(**batch)
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        perplexity_values.append(torch.exp(loss).item())
        
        if i == n_steps_per_epoch:
            break
    
    print()
    print(f'Train loss = {np.mean(losses)}')
    print(f'Train perplexity = {np.mean(perplexity_values)}')
    print()
    
    losses = []
    perplexity_values = []
    model.eval()
    with torch.inference_mode():
        for batch in tqdm(val_loader):
            batch = {key: value.to(device) for key, value in batch.items()}
            outputs = model(**batch)

            loss = outputs.loss

            losses.append(loss.item())
            perplexity_values.append(torch.exp(loss).item())
    
    print()
    print(f'Val loss = {np.mean(losses)}')
    print(f'Val perplexity = {np.mean(perplexity_values)}')
    print()

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 19.513160165786744
Train perplexity = inf



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 16.93542254904042
Val perplexity = 8.270905388244058e+25



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 16.18055218029022
Train perplexity = 5.8941690932847916e+29



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 15.712012596130371
Val perplexity = 4.510440234427319e+25



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 15.039828087806702
Train perplexity = 5.337443715349891e+16



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 14.882750403362772
Val perplexity = 6.003487305121589e+24



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 14.299591787338256
Train perplexity = 1.1524066507464403e+19



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 14.32748211570408
Val perplexity = 4.653681234366906e+22



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 13.895928610801697
Train perplexity = 140205074118183.38



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 13.774944458007813
Val perplexity = 2.285327155382476e+18



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 13.273399362564087
Train perplexity = 2491834876847256.5



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 13.331967330600904
Val perplexity = 3350770830897813.5



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 13.023203072547913
Train perplexity = 21142691509.246983



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 12.933897665272589
Val perplexity = 12868839001215.412



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 12.56742042350769
Train perplexity = 258277795.68333593



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 12.569540138244628
Val perplexity = 521118875882.8981



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 12.25855275440216
Train perplexity = 2059155.6373652343



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 12.20556517559549
Val perplexity = 6503997490.052452



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 11.867764868736266
Train perplexity = 1528636822.3900714



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 11.70790943975034
Val perplexity = 1418140305.5292222



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 11.13712041759491
Train perplexity = 17615718.403149657



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 11.081802160843559
Val perplexity = 214151260.29461977



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.771820645332337
Train perplexity = 3648751.1897636717



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.80535734674205
Val perplexity = 13862604.361004373



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.565869745731353
Train perplexity = 870674.9753181152



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.616300719717275
Val perplexity = 777624.4816077191



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.414670981884003
Train perplexity = 2530471.928337402



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.447675363291864
Val perplexity = 439776.0975350289



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.302460082530976
Train perplexity = 116624.05942126465



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.326050551870594
Val perplexity = 543873.910087466



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.15346932554245
Train perplexity = 108940.3479083252



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.214647948223611
Val perplexity = 279669.96639138926



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.140347895145416
Train perplexity = 84662.95079077149



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.10632541158925
Val perplexity = 141417.29554411516



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 10.001111440181733
Train perplexity = 99718.32244274902



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 10.014995613098144
Val perplexity = 351755.93312032946



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 9.88390025138855
Train perplexity = 630380.5353560181



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 9.942992993230405
Val perplexity = 1062308.1618654467



  0%|          | 0/1000 [00:00<?, ?it/s]


Train loss = 9.851526480197906
Train perplexity = 129624.94405145264



  0%|          | 0/575 [00:00<?, ?it/s]


Val loss = 9.860897156259288
Val perplexity = 61497.87249214504



In [14]:
len(train_dataset)

113689