Model: small-llama2. https://huggingface.co/TinyPixel/small-llama2



# This is a full pretraining on the training set, verifying the whole training code.

# Preparing Data
## Loading data

In [1]:
TRAINING = True

In [2]:
from datasets import load_dataset, DatasetDict
dataset = load_dataset("./bookcorpus-splitted")
# dataset.cleanup_cache_files()
if not TRAINING:
    dataset = DatasetDict({'train': dataset['train'].select(range(32 * 80)), 
               'validation': dataset['validation'].select(range(32 * 20))})
dataset
# Trouble shooting: 
# OSError: Invalid flatbuffers message.
# ArrowInvalid: Old metadata version not supported


DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 59203382
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 14800846
    })
})

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from transformers import AutoConfig

tokenizer = AutoTokenizer.from_pretrained("TinyPixel/small-llama2")
tokenizer.pad_token_id=tokenizer.eos_token_id
config = AutoConfig.from_pretrained("TinyPixel/small-llama2")
config.num_hidden_layers = 6 # originally, 12
# model = LlamaForCausalLM.from_pretrained("TinyPixel/small-llama2")
model = LlamaForCausalLM(config) # Randomly initialize a model
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-5): 6 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (up_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (down_proj): Linear(in_features=1376, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(i

In [4]:
# tokenization 
def tokenize(element):
    long_text = "".join(element['text']) # concatenating 
    outputs = tokenizer(
        [long_text],
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
        max_length=config.max_position_embeddings,
    )
    return {"input_ids": outputs['input_ids']}


tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset["train"].column_names, 
    batch_size=200# , num_proc=10
)
# Trouble shooting:
# batch_size=1000: index out of bounds: the len is 31172 but the index is 8589960764 -- decrease batch
# batch_size=500:  index out of bounds: the len is 30153 but the index is 283467863127
# ArrowInvalid: Column 1 named input_ids expected length 500 but got length 8
# IndentationError: unindent does not match any outer indentation level -- Restart kernel
# RuntimeError: One of the subprocesses has abruptly died during map operation.To debug the error, disable multiprocessing.

tokenized_datasets


DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 1115389
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 278684
    })
})

In [6]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# look at five examples
out = data_collator([tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

input_ids shape: torch.Size([5, 1024])
attention_mask shape: torch.Size([5, 1024])
labels shape: torch.Size([5, 1024])


# Training

In [7]:
from torch.nn import CrossEntropyLoss
from torch import tensor, exp
def compute_metrics(eval_pred):
    # print('Inside compute_metrics', eval_pred.predictions.shape, eval_pred.label_ids.shape)
    # Inside compute_metrics (11, 1024, 32000) (11, 1024)  numpy.ndarray
    # (batch, sequence_length, vocabulary) and (batch, sequence_length)
    loss_fct = CrossEntropyLoss()
    prediction = tensor(eval_pred.predictions).view(-1, 32000)
    labels = tensor(eval_pred.label_ids).view(-1)
    masked_lm_loss = exp(loss_fct(prediction, labels)) 
    return {'ppl': masked_lm_loss}

from transformers import Trainer, TrainingArguments
import os
# os.environ['WANDB_DISABLED'] = 'true' # turning off reporting to WanDB. It requires API key
args = TrainingArguments(
    output_dir="llama2-small-bigram-guided",
    per_device_train_batch_size=4, # ref: 32
    per_device_eval_batch_size=4, # ref: 32
    evaluation_strategy="steps",
    eval_steps=1, # ref: 5_000, Evaluation is time consuming 
    logging_steps=1, # ref: 5_000
    gradient_accumulation_steps=2, # this parameter will influence BP, evaluation, saving and logging.
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1, # ref: 1_000
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=3, # ref: 5_000 
    fp16=True,
    push_to_hub=False, # default as False, saving model in local file system.
    # report_to='none', # turning off reporting to WanDB. It requires API key
    report_to='tensorboard',
)
if TRAINING:
    args = TrainingArguments(
        output_dir="llama2-small-bigram-guided",
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        evaluation_strategy="steps",
        eval_steps=3_000, 
        logging_steps=20, 
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        weight_decay=0.1,
        warmup_steps=1_000, 
        lr_scheduler_type="cosine",
        learning_rate=5e-4,
        save_steps=3_000, 
        fp16=True,
        push_to_hub=False,
        report_to='tensorboard',
    )

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # compute_metrics=compute_metrics # If turn on the metric, memory consumption continues increasing.  
)

In [8]:
trainer.train()

  0%|          | 0/34856 [00:00<?, ?it/s]

{'loss': 10.2141, 'grad_norm': 3.6773183345794678, 'learning_rate': 1e-05, 'epoch': 0.0}
{'loss': 9.0169, 'grad_norm': 1.9270586967468262, 'learning_rate': 2e-05, 'epoch': 0.0}
{'loss': 8.2106, 'grad_norm': 1.6131435632705688, 'learning_rate': 3e-05, 'epoch': 0.0}
{'loss': 7.35, 'grad_norm': 1.2696418762207031, 'learning_rate': 4e-05, 'epoch': 0.0}
{'loss': 6.5989, 'grad_norm': 0.9511287808418274, 'learning_rate': 5e-05, 'epoch': 0.0}
{'loss': 6.1903, 'grad_norm': 0.5924166440963745, 'learning_rate': 6e-05, 'epoch': 0.0}
{'loss': 5.9099, 'grad_norm': 0.7124360799789429, 'learning_rate': 7.000000000000001e-05, 'epoch': 0.0}
{'loss': 5.6655, 'grad_norm': 1.2501721382141113, 'learning_rate': 8e-05, 'epoch': 0.0}
{'loss': 5.4397, 'grad_norm': 1.2117037773132324, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.01}
{'loss': 5.2684, 'grad_norm': 1.0709444284439087, 'learning_rate': 0.0001, 'epoch': 0.01}
{'loss': 5.1392, 'grad_norm': 1.040246844291687, 'learning_rate': 0.00011, 'epoch': 0.

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.6574158668518066, 'eval_runtime': 6103.3435, 'eval_samples_per_second': 45.661, 'eval_steps_per_second': 11.415, 'epoch': 0.09}
{'loss': 3.648, 'grad_norm': 0.30810320377349854, 'learning_rate': 0.0004956210571182014, 'epoch': 0.09}
{'loss': 3.6546, 'grad_norm': 0.3564930558204651, 'learning_rate': 0.0004955341764720305, 'epoch': 0.09}
{'loss': 3.6602, 'grad_norm': 0.3169688880443573, 'learning_rate': 0.0004954464501571315, 'epoch': 0.09}
{'loss': 3.6568, 'grad_norm': 0.3186594843864441, 'learning_rate': 0.0004953578784756514, 'epoch': 0.09}
{'loss': 3.6426, 'grad_norm': 0.34190812706947327, 'learning_rate': 0.0004952684617326486, 'epoch': 0.09}
{'loss': 3.6537, 'grad_norm': 0.32837581634521484, 'learning_rate': 0.0004951782002360924, 'epoch': 0.09}
{'loss': 3.6544, 'grad_norm': 0.31156426668167114, 'learning_rate': 0.0004950870942968613, 'epoch': 0.09}
{'loss': 3.6429, 'grad_norm': 0.3249496519565582, 'learning_rate': 0.0004949951442287425, 'epoch': 0.09}
{'loss': 3.65

  0%|          | 0/69671 [00:00<?, ?it/s]

Checkpoint destination directory llama2-small-bigram-guided/checkpoint-6000 already exists and is non-empty. Saving will proceed but saved results may be invalid.


{'eval_loss': 3.5175981521606445, 'eval_runtime': 6103.6969, 'eval_samples_per_second': 45.658, 'eval_steps_per_second': 11.415, 'epoch': 0.17}
{'loss': 3.5276, 'grad_norm': 0.28415820002555847, 'learning_rate': 0.0004733634121618964, 'epoch': 0.17}
{'loss': 3.5167, 'grad_norm': 0.3096098303794861, 'learning_rate': 0.00047315463571458307, 'epoch': 0.17}
{'loss': 3.5189, 'grad_norm': 0.2632645070552826, 'learning_rate': 0.0004729450906781485, 'epoch': 0.17}
{'loss': 3.5149, 'grad_norm': 0.287250816822052, 'learning_rate': 0.00047273477777430746, 'epoch': 0.17}
{'loss': 3.5194, 'grad_norm': 0.30232709646224976, 'learning_rate': 0.00047252369772741965, 'epoch': 0.18}
{'loss': 3.5179, 'grad_norm': 0.29676830768585205, 'learning_rate': 0.00047231185126448696, 'epoch': 0.18}
{'loss': 3.5095, 'grad_norm': 0.2878463566303253, 'learning_rate': 0.0004720992391151508, 'epoch': 0.18}
{'loss': 3.5067, 'grad_norm': 0.27134403586387634, 'learning_rate': 0.00047188586201168996, 'epoch': 0.18}
{'loss':

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.444873809814453, 'eval_runtime': 6099.7471, 'eval_samples_per_second': 45.688, 'eval_steps_per_second': 11.422, 'epoch': 0.26}
{'loss': 3.4532, 'grad_norm': 0.33194002509117126, 'learning_rate': 0.00043392353510097965, 'epoch': 0.26}
{'loss': 3.4481, 'grad_norm': 0.2693030834197998, 'learning_rate': 0.0004336089693143827, 'epoch': 0.26}
{'loss': 3.4449, 'grad_norm': 0.2576901316642761, 'learning_rate': 0.0004332937711418354, 'epoch': 0.26}
{'loss': 3.4496, 'grad_norm': 0.26170992851257324, 'learning_rate': 0.00043297794166894304, 'epoch': 0.26}
{'loss': 3.4409, 'grad_norm': 0.26590195298194885, 'learning_rate': 0.00043266148198348555, 'epoch': 0.26}
{'loss': 3.4544, 'grad_norm': 0.267270565032959, 'learning_rate': 0.0004323443931754132, 'epoch': 0.26}
{'loss': 3.4455, 'grad_norm': 0.29316651821136475, 'learning_rate': 0.0004320266763368431, 'epoch': 0.26}
{'loss': 3.4366, 'grad_norm': 0.2552950978279114, 'learning_rate': 0.0004317083325620555, 'epoch': 0.26}
{'loss': 3.

  0%|          | 0/69671 [00:00<?, ?it/s]

Checkpoint destination directory llama2-small-bigram-guided/checkpoint-12000 already exists and is non-empty. Saving will proceed but saved results may be invalid.


{'eval_loss': 3.3964169025421143, 'eval_runtime': 6100.6601, 'eval_samples_per_second': 45.681, 'eval_steps_per_second': 11.42, 'epoch': 0.34}
{'loss': 3.3977, 'grad_norm': 0.2507679760456085, 'learning_rate': 0.00038033177831230775, 'epoch': 0.34}
{'loss': 3.403, 'grad_norm': 0.2666989266872406, 'learning_rate': 0.0003799356272513172, 'epoch': 0.35}
{'loss': 3.391, 'grad_norm': 0.23970948159694672, 'learning_rate': 0.00037953902866608304, 'epoch': 0.35}
{'loss': 3.4011, 'grad_norm': 0.2681775391101837, 'learning_rate': 0.0003791419839225697, 'epoch': 0.35}
{'loss': 3.3892, 'grad_norm': 0.2605178952217102, 'learning_rate': 0.00037874449438827883, 'epoch': 0.35}
{'loss': 3.3884, 'grad_norm': 0.2619072496891022, 'learning_rate': 0.0003783465614322437, 'epoch': 0.35}
{'loss': 3.3915, 'grad_norm': 0.2590520977973938, 'learning_rate': 0.00037794818642502464, 'epoch': 0.35}
{'loss': 3.3902, 'grad_norm': 0.255928635597229, 'learning_rate': 0.0003775493707387051, 'epoch': 0.35}
{'loss': 3.3968

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.3543426990509033, 'eval_runtime': 6099.0537, 'eval_samples_per_second': 45.693, 'eval_steps_per_second': 11.423, 'epoch': 0.43}
{'loss': 3.3577, 'grad_norm': 0.27802687883377075, 'learning_rate': 0.0003166893682305552, 'epoch': 0.43}
{'loss': 3.3487, 'grad_norm': 0.26146239042282104, 'learning_rate': 0.0003162421020392465, 'epoch': 0.43}
{'loss': 3.3492, 'grad_norm': 0.27394264936447144, 'learning_rate': 0.00031579460769691226, 'epoch': 0.43}
{'loss': 3.3499, 'grad_norm': 0.2870989739894867, 'learning_rate': 0.0003153468867448123, 'epoch': 0.43}
{'loss': 3.3535, 'grad_norm': 0.27682334184646606, 'learning_rate': 0.00031489894072498693, 'epoch': 0.43}
{'loss': 3.3506, 'grad_norm': 0.2555399537086487, 'learning_rate': 0.0003144507711802518, 'epoch': 0.43}
{'loss': 3.356, 'grad_norm': 0.2652854025363922, 'learning_rate': 0.00031400237965419216, 'epoch': 0.43}
{'loss': 3.353, 'grad_norm': 0.25618061423301697, 'learning_rate': 0.000313553767691158, 'epoch': 0.43}
{'loss': 3.

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.3207106590270996, 'eval_runtime': 6095.5411, 'eval_samples_per_second': 45.719, 'eval_steps_per_second': 11.43, 'epoch': 0.52}
{'loss': 3.3187, 'grad_norm': 0.30020859837532043, 'learning_rate': 0.0002479585813360631, 'epoch': 0.52}
{'loss': 3.3188, 'grad_norm': 0.2783338725566864, 'learning_rate': 0.0002474946366429139, 'epoch': 0.52}
{'loss': 3.3183, 'grad_norm': 0.2886226773262024, 'learning_rate': 0.0002470307005787363, 'epoch': 0.52}
{'loss': 3.3216, 'grad_norm': 0.2579776346683502, 'learning_rate': 0.0002465667747414189, 'epoch': 0.52}
{'loss': 3.3188, 'grad_norm': 0.2832474112510681, 'learning_rate': 0.00024610286072881466, 'epoch': 0.52}
{'loss': 3.3076, 'grad_norm': 0.2665261924266815, 'learning_rate': 0.00024563896013873627, 'epoch': 0.52}
{'loss': 3.3135, 'grad_norm': 0.2601875960826874, 'learning_rate': 0.0002451750745689498, 'epoch': 0.52}
{'loss': 3.3179, 'grad_norm': 0.2820694148540497, 'learning_rate': 0.0002447112056171699, 'epoch': 0.52}
{'loss': 3.317

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.2858715057373047, 'eval_runtime': 6092.8403, 'eval_samples_per_second': 45.74, 'eval_steps_per_second': 11.435, 'epoch': 0.6}
{'loss': 3.2934, 'grad_norm': 0.3294905126094818, 'learning_rate': 0.00017934026018611682, 'epoch': 0.6}
{'loss': 3.2946, 'grad_norm': 0.28475421667099, 'learning_rate': 0.0001788953356779353, 'epoch': 0.6}
{'loss': 3.2861, 'grad_norm': 0.2782890200614929, 'learning_rate': 0.0001784506560684147, 'epoch': 0.6}
{'loss': 3.2856, 'grad_norm': 0.30526450276374817, 'learning_rate': 0.00017800622288912044, 'epoch': 0.6}
{'loss': 3.2854, 'grad_norm': 0.2875988781452179, 'learning_rate': 0.0001775620376707691, 'epoch': 0.61}
{'loss': 3.2825, 'grad_norm': 0.34829047322273254, 'learning_rate': 0.00017711810194322318, 'epoch': 0.61}
{'loss': 3.2862, 'grad_norm': 0.3071736991405487, 'learning_rate': 0.00017667441723548616, 'epoch': 0.61}
{'loss': 3.2789, 'grad_norm': 0.293785959482193, 'learning_rate': 0.00017623098507569667, 'epoch': 0.61}
{'loss': 3.2864, '

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.255666971206665, 'eval_runtime': 6134.0668, 'eval_samples_per_second': 45.432, 'eval_steps_per_second': 11.358, 'epoch': 0.69}
{'loss': 3.251, 'grad_norm': 0.2975229024887085, 'learning_rate': 0.00011618200530473116, 'epoch': 0.69}
{'loss': 3.2538, 'grad_norm': 0.31886255741119385, 'learning_rate': 0.00011579033502204209, 'epoch': 0.69}
{'loss': 3.2642, 'grad_norm': 0.2975393235683441, 'learning_rate': 0.00011539912698423604, 'epoch': 0.69}
{'loss': 3.2456, 'grad_norm': 0.28057631850242615, 'learning_rate': 0.00011502790872988889, 'epoch': 0.69}
{'loss': 3.262, 'grad_norm': 0.31482213735580444, 'learning_rate': 0.0001146376059436057, 'epoch': 0.69}
{'loss': 3.2513, 'grad_norm': 0.28595414757728577, 'learning_rate': 0.0001142477693724345, 'epoch': 0.69}
{'loss': 3.2563, 'grad_norm': 0.29002389311790466, 'learning_rate': 0.00011385840035905054, 'epoch': 0.69}
{'loss': 3.2526, 'grad_norm': 0.3086337149143219, 'learning_rate': 0.00011346950024451838, 'epoch': 0.69}
{'loss':

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.230928659439087, 'eval_runtime': 6136.5136, 'eval_samples_per_second': 45.414, 'eval_steps_per_second': 11.354, 'epoch': 0.77}
{'loss': 3.2217, 'grad_norm': 0.2991734743118286, 'learning_rate': 6.333572197978302e-05, 'epoch': 0.78}
{'loss': 3.2355, 'grad_norm': 0.30001136660575867, 'learning_rate': 6.30274105706683e-05, 'epoch': 0.78}
{'loss': 3.2256, 'grad_norm': 0.28830644488334656, 'learning_rate': 6.271974313248318e-05, 'epoch': 0.78}
{'loss': 3.2295, 'grad_norm': 0.31302645802497864, 'learning_rate': 6.241272072489593e-05, 'epoch': 0.78}
{'loss': 3.2347, 'grad_norm': 0.3176901936531067, 'learning_rate': 6.21063444053529e-05, 'epoch': 0.78}
{'loss': 3.2269, 'grad_norm': 0.31305640935897827, 'learning_rate': 6.180061522907532e-05, 'epoch': 0.78}
{'loss': 3.2376, 'grad_norm': 0.30532801151275635, 'learning_rate': 6.14955342490556e-05, 'epoch': 0.78}
{'loss': 3.2261, 'grad_norm': 0.30641940236091614, 'learning_rate': 6.119110251605342e-05, 'epoch': 0.78}
{'loss': 3.224

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.213669538497925, 'eval_runtime': 6105.714, 'eval_samples_per_second': 45.643, 'eval_steps_per_second': 11.411, 'epoch': 0.86}
{'loss': 3.2103, 'grad_norm': 0.31532910466194153, 'learning_rate': 2.4842688178774637e-05, 'epoch': 0.86}
{'loss': 3.2108, 'grad_norm': 0.31879180669784546, 'learning_rate': 2.464144275330929e-05, 'epoch': 0.86}
{'loss': 3.2226, 'grad_norm': 0.3227134943008423, 'learning_rate': 2.44409735077111e-05, 'epoch': 0.86}
{'loss': 3.2122, 'grad_norm': 0.31799522042274475, 'learning_rate': 2.424128113243615e-05, 'epoch': 0.86}
{'loss': 3.2133, 'grad_norm': 0.30681905150413513, 'learning_rate': 2.4042366315264798e-05, 'epoch': 0.86}
{'loss': 3.2101, 'grad_norm': 0.3115248680114746, 'learning_rate': 2.3844229741299546e-05, 'epoch': 0.86}
{'loss': 3.2166, 'grad_norm': 0.3048432767391205, 'learning_rate': 2.364687209296218e-05, 'epoch': 0.86}
{'loss': 3.2117, 'grad_norm': 0.2931392192840576, 'learning_rate': 2.3450294049991883e-05, 'epoch': 0.87}
{'loss': 3.

  0%|          | 0/69671 [00:00<?, ?it/s]

{'eval_loss': 3.2061095237731934, 'eval_runtime': 6132.5593, 'eval_samples_per_second': 45.443, 'eval_steps_per_second': 11.361, 'epoch': 0.95}
{'loss': 3.2008, 'grad_norm': 0.34470582008361816, 'learning_rate': 3.6548551596240765e-06, 'epoch': 0.95}
{'loss': 3.1935, 'grad_norm': 0.3168182373046875, 'learning_rate': 3.5762350196400505e-06, 'epoch': 0.95}
{'loss': 3.2101, 'grad_norm': 0.3133082091808319, 'learning_rate': 3.4984636123045475e-06, 'epoch': 0.95}
{'loss': 3.2087, 'grad_norm': 0.30968695878982544, 'learning_rate': 3.4215412054778296e-06, 'epoch': 0.95}
{'loss': 3.2108, 'grad_norm': 0.31265902519226074, 'learning_rate': 3.345468064096052e-06, 'epoch': 0.95}
{'loss': 3.2054, 'grad_norm': 0.3167969584465027, 'learning_rate': 3.2702444501702677e-06, 'epoch': 0.95}
{'loss': 3.1964, 'grad_norm': 0.3370577394962311, 'learning_rate': 3.195870622785618e-06, 'epoch': 0.95}
{'loss': 3.2102, 'grad_norm': 0.33813026547431946, 'learning_rate': 3.1223468381004484e-06, 'epoch': 0.95}
{'loss

TrainOutput(global_step=34856, training_loss=3.403890896916034, metrics={'train_runtime': 135741.5622, 'train_samples_per_second': 8.217, 'train_steps_per_second': 0.257, 'train_loss': 3.403890896916034, 'epoch': 1.0})