In [1]:
from tqdm.auto import tqdm
import torch
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers.models.llama.modeling_llama import LlamaConfig
from datasets import load_dataset, DatasetDict
from bitnet import BitNetForCausalLM

In [2]:
# Knobs for this cell, named so they are easy to find and tune.
SPLIT_SEED = 42       # fixed so the train/test partition is identical on every run
MAX_SEQ_LENGTH = 256  # tokens per example after truncation / padding

# NOTE(review): trust_remote_code=True executes the dataset repository's loading
# script on this machine — keep it only for sources you trust.
raw_dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", trust_remote_code=True)

# Hold out 20% of the "train" split for evaluation. Without an explicit seed the
# shuffle differs between kernel restarts, which makes eval losses incomparable.
dataset = raw_dataset["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# padding="max_length" below needs a pad token; EOS is reused for that purpose.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_batch(batch):
    """Tokenize a batch of documents into fixed-length sequences.

    Args:
        batch: Mapping whose "text" entry holds the strings to tokenize (the
            form `map` passes to its callable when `batched=True`).

    Returns:
        The tokenizer's output for those strings, truncated and padded to
        `MAX_SEQ_LENGTH` tokens.
    """
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=MAX_SEQ_LENGTH)


# Drop the raw input columns while mapping so the tokenized dataset stores only
# the tokenizer outputs instead of a second copy of every document.
tokenized_dataset = dataset.map(
    tokenize_batch,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# mlm=False selects causal-LM collation rather than masked-LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [3]:
# Architecture hyperparameters for the model, collected in one mapping so the
# sizes can be read (and tweaked) at a glance before the config is built.
llama_kwargs = {
    # Size the vocabulary from the tokenizer loaded earlier in the notebook.
    "vocab_size": len(tokenizer),
    # Transformer dimensions.
    "hidden_size": 768,
    "intermediate_size": 2048,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    # Same value as the max_length used when tokenizing the dataset.
    "max_position_embeddings": 256,
    # Special-token ids are taken from the tokenizer rather than hard-coded.
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
config = LlamaConfig(**llama_kwargs)

In [4]:
# The model is constructed here, before any Trainer exists, so seed torch's
# RNGs explicitly; otherwise a fresh kernel starts from different initial
# weights on every run and results cannot be reproduced.
MODEL_INIT_SEED = 42
torch.manual_seed(MODEL_INIT_SEED)

# Build the model from scratch from the config (no pretrained weights loaded).
model = BitNetForCausalLM(config)

# Total element count across all parameter tensors, reported in millions.
model_size = sum(param.numel() for param in model.parameters())
print(f"model size: {model_size/1000**2:.1f}M parameters")

In [5]:
# Training hyperparameters, grouped by concern. Keyword order does not matter
# to TrainingArguments, so related settings are kept next to each other.
trainer_args = TrainingArguments(
    # --- Run identity and outputs ---
    run_name="myBitNet",
    output_dir="./result",
    report_to="wandb",
    push_to_hub=False,
    # --- Batching and run length ---
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    # --- Optimizer and learning-rate schedule ---
    learning_rate=1.5e-3,
    lr_scheduler_type="linear",
    warmup_steps=3000,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.95,
    # --- Numeric precision ---
    bf16=True,
    # --- Evaluation, logging and checkpointing share a 1000-step cadence ---
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_steps=1000,
    save_steps=1000,
    save_total_limit=1,
)

# Wire the model, data and arguments together; the held-out "test" split
# serves as the evaluation set.
trainer = Trainer(
    args=trainer_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

In [6]:
# Launch training with the Trainer configured above. Left as a bare last
# expression so the notebook displays the value returned by train().
trainer.train()