# BitLlama pretraining

In [None]:
# Authenticate with the Hugging Face Hub; the credentials are needed later for
# push_to_hub=True during training and the final trainer.push_to_hub().
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Install the package providing BitLlamaConfig / BitLlamaForCausalLM.
!pip install mybitnet

In [None]:
# Upgrade accelerate/transformers to their latest releases, then install the
# remaining dependencies (datasets for loading the corpus, wandb for logging).
!pip install -U accelerate transformers
!pip install torch accelerate datasets wandb

In [None]:
# Load the already-filtered Japanese Wikipedia (wiki40b-ja) dataset from the Hub.
from datasets import load_dataset, DatasetDict

ds_name = "range3/wiki40b-ja"
ds_train, ds_valid = (
    load_dataset(ds_name, split=split) for split in ("train", "validation")
)

# All examples are kept. For a quick smoke test, subsample instead, e.g.
#   ds_train.shuffle().select(range(50000)) / ds_valid.shuffle().select(range(500))
# Note the validation split is exposed under the key "valid".
raw_datasets = DatasetDict({"train": ds_train, "valid": ds_valid})

raw_datasets

In [None]:
from transformers import AutoTokenizer

# Borrow the Swallow tokenizer. (Alternative that was also tried: the Llama 2
# tokenizer from "TinyLlama/TinyLlama-1.1B-Chat-v1.0".)
context_length = 128
tokenizer = AutoTokenizer.from_pretrained("tokyotech-llm/Swallow-7b-hf")

# Sanity-check chunking on the first two training articles: with
# return_overflowing_tokens=True a long text is split into several chunks of
# at most context_length tokens, and overflow_to_sample_mapping tells which
# article each chunk came from.
sample_texts = raw_datasets["train"][:2]["text"]
outputs = tokenizer(
    sample_texts,
    max_length=context_length,
    truncation=True,
    return_length=True,
    return_overflowing_tokens=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

In [None]:
def tokenize(element):
    """Tokenize a batch of texts into fixed-length training chunks.

    Used with ``Dataset.map(..., batched=True)``, so ``element["text"]`` is
    presumably a list of strings. Each text is cut into chunks of at most
    ``context_length`` tokens (overflow is kept as extra chunks), and only the
    chunks that are exactly ``context_length`` tokens long are kept; a shorter
    trailing chunk of an article is discarded.

    Args:
        element: batch mapping that contains a ``"text"`` column.

    Returns:
        A dict with a single ``"input_ids"`` key holding the kept chunks.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        chunk_ids
        for n_tokens, chunk_ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_chunks}


# Tokenize every split. All original columns are removed because tokenize()
# returns a different number of rows (chunks) than it receives (articles).
tokenized_datasets = raw_datasets.map(
    tokenize,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
tokenized_datasets

In [None]:
# Tell the library to register these custom classes with the right Auto classes
# (especially for the model), so that the code is uploaded alongside the weights
# and the checkpoint can later be loaded via AutoModelForCausalLM with
# trust_remote_code=True.
# refers: https://huggingface.co/docs/transformers/v4.38.2/ja/custom_models#sending-the-code-to-the-hub
from mybitnet import BitLlamaConfig, BitLlamaForCausalLM

# With no argument the config class is registered with its default Auto class.
BitLlamaConfig.register_for_auto_class()
BitLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")

In [None]:
from mybitnet import BitLlamaConfig, BitLlamaForCausalLM
import torch

# Model configuration. vocab_size and the BOS/EOS ids are taken from the
# borrowed Swallow tokenizer so the embedding table matches its token ids.
config = BitLlamaConfig(
    model_type="bit_llama",
    vocab_size=len(tokenizer),
    # NOTE(review): n_ctx is a GPT-2-style field name; a Llama-style config
    # presumably just stores it unused — confirm against mybitnet.
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    hidden_size=768,                  # from the BitNet paper
    # Training chunks are exactly context_length (128) tokens, so positions
    # beyond 128 are never seen during this pretraining run.
    max_position_embeddings=1024,
    intermediate_size=1536,
    num_attention_heads=12,         # from the BitNet paper
    num_hidden_layers=12,            # from the BitNet paper
    # 4 KV heads for 12 attention heads — presumably grouped-query attention
    # as in Llama (3 query heads per KV head); verify in mybitnet.
    num_key_value_heads=4,
    torch_dtype=torch.bfloat16,
    rms_norm_eps=1e-05,
)
print(config)

In [None]:
# Build the model from the config alone — no pretrained weights are loaded,
# training starts from scratch.
model = BitLlamaForCausalLM(config)
print(model)

In [None]:
# Total element count over every parameter tensor, reported in millions.
model_size = sum(param.numel() for param in model.parameters())
print(f"model size: {model_size / 1000**2:.1f}M parameters")

In [None]:
from transformers import DataCollatorForLanguageModeling

# Use EOS as the pad token so the collator has a pad id to work with.
# With mlm=False the collator produces causal-LM labels from input_ids.
# NOTE(review): label positions equal to pad_token_id are ignored in the loss,
# so with pad == eos any genuine EOS tokens in the data would also be ignored.
# Every chunk from tokenize() is exactly context_length long (no padding
# happens) — confirm the training text contains no EOS tokens if this matters.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
from transformers import Trainer, TrainingArguments

# Training configuration. Values marked "from the BitNet paper" follow BitNet;
# the rest are choices for this run. Eval, logging and checkpointing all run
# every 2000 steps; checkpoints are pushed to the Hub repo named after
# output_dir and runs are logged to Weights & Biases.
args = TrainingArguments(
    output_dir="myBit-Llama2-jp-127M-3",
    per_device_train_batch_size=96,
    per_device_eval_batch_size=96,
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41)
    # and the old name was later removed; since the notebook installs the
    # latest transformers (`pip install -U`), the new name is required.
    eval_strategy="steps",
    eval_steps=2000,
    logging_steps=2000,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=5000,
    lr_scheduler_type="polynomial",  # from the BitNet paper
    learning_rate=2.4e-4,  # from the BitNet paper
    save_steps=2000,
    bf16=True,
    push_to_hub=True,
    report_to="wandb",
    save_total_limit=3,  # keep only the 3 most recent checkpoints locally
)

trainer = Trainer(
    model=model,
    # `Trainer(tokenizer=...)` is deprecated in favor of `processing_class`
    # (transformers 4.46); the tokenizer is still saved/pushed with the model.
    processing_class=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [None]:
# Run pretraining with the arguments configured for `trainer`.
trainer.train()

In [None]:
# Upload the final model, tokenizer and registered custom code to the Hub.
trainer.push_to_hub()

### モデル取得 試し

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Reload the pushed checkpoint from the Hub to check that it round-trips.
model_name = "HachiML/myBit-Llama2-jp-127M-3"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# BitLlama is a custom architecture (registered via register_for_auto_class),
# so loading it through AutoModelForCausalLM needs trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
print(model)

In [None]:
print(tokenizer)

In [None]:
# Quick generation check; the prompt means "Once upon a time, in a certain place,".
prompt = "昔々あるところに、"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Only max_new_tokens is given, so every other decoding option comes from the
# model's generation-config defaults.
tokens = model.generate(input_ids.to(device=model.device), max_new_tokens=128)

# Decode the single sequence without special tokens such as BOS/EOS.
out = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(out)

In [None]:
# Release the Colab runtime once everything has finished, so the VM does not
# keep running idle.
from google.colab import runtime
runtime.unassign()