In [None]:
# Install dependencies into the running kernel's environment.
# `%pip` (rather than `!pip`) guarantees the packages land in the interpreter
# this notebook is actually using. transformers is pinned to 4.38.2.
%pip install -U datasets transformers==4.38.2
%pip install torch accelerate datasets wandb
# Fetch the BitNet sources; the clone lands in ./bitnet.
!git clone https://github.com/aloobun/bitnet


In [None]:
# Switch the working directory to ./bitnet.
# NOTE(review): later cells import `modeling_bitllama` and `bitnet`, which
# presumably resolve from this directory — confirm.
%cd bitnet
# Authenticate with the Hugging Face Hub CLI.
# "<token>" is a placeholder: substitute your own access token before running,
# and never commit a real token with the notebook.
!huggingface-cli login --token "<token>"

In [None]:
from datasets import load_dataset, DatasetDict

# Fixed seed so the shuffled subsets below are identical on every run
# (an unseeded shuffle picks a different sample each time the cell executes).
SEED = 42

ds_name = "JeanKaddour/minipile"
ds_train = load_dataset(ds_name, split="train")
ds_valid = load_dataset(ds_name, split="validation")

# Keep a shuffled subset of each split: 50000 training rows and 500
# validation rows.
raw_datasets = DatasetDict(
    {
        "train": ds_train.shuffle(seed=SEED).select(range(50000)),
        "valid": ds_valid.shuffle(seed=SEED).select(range(500)),
    }
)

raw_datasets

In [None]:
from transformers import AutoTokenizer

# Number of tokens per training example; used as `max_length` when tokenizing
# and passed to the model config as `n_ctx`.
context_length = 128
# Load the tokenizer published with the TinyLlama chat checkpoint.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

In [None]:
def tokenize(element):
    """
    Tokenize a batch of texts into fixed-length chunks of token ids.

    The module-level ``tokenizer`` is called on ``element["text"]`` with
    truncation to ``context_length`` and ``return_overflowing_tokens=True``.
    Only chunks whose reported length equals ``context_length`` are kept;
    shorter chunks are dropped.

    Parameters:
        element (dict): A batch with a "text" entry to tokenize.

    Returns:
        dict: ``{"input_ids": [...]}`` holding the retained chunks.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    # Filter on the length reported by the tokenizer for each chunk.
    full_chunks = [
        ids
        for n_tokens, ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_chunks}


# Run `tokenize` over every split in batched mode and drop the original
# columns, so only what `tokenize` returns remains.
source_columns = raw_datasets["train"].column_names
tokenized_datasets = raw_datasets.map(tokenize, batched=True, remove_columns=source_columns)
tokenized_datasets

In [None]:
import torch
from modeling_bitllama import BitLlamaConfig, BitLlamaForCausalLM
# Register the custom config and model classes for auto-class use; the model
# is registered under "AutoModelForCausalLM".
# NOTE(review): presumably this lets the custom code be saved/pushed alongside
# the checkpoint so it can be loaded through the auto classes — confirm.
BitLlamaConfig.register_for_auto_class()
BitLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")

In [None]:
from torch import nn
from bitnet import BitLinear

# Activation module types checked against the sibling that follows a Linear.
activation_layers = [nn.SiLU, nn.ReLU, nn.GELU]


def replace_linears_in_hf(model, parent=None):
    """
    Swap every nn.Linear found under ``model`` for a newly constructed BitLinear.

    The module tree is walked recursively. Each nn.Linear child is replaced in
    place (via ``setattr`` on its owner) by a BitLinear built with the same
    ``in_features``/``out_features`` and with a bias only if the original had
    one. ``flg_before_linear`` is False when the Linear's next sibling is one
    of ``activation_layers`` and True otherwise.
    refers: https://github.com/kyegomez/BitNet/blob/d32fb9b8d83028d9571bfb213d8c5e4e7b915e42/bitnet/replace_hf.py#L6

    Parameters:
        model (nn.Module): The module whose children are inspected and modified.
        parent (nn.Module): The owner of ``model``; passed along during
            recursion but not read by this function.

    Returns:
        None. ``model`` is modified in place.
    """
    activation_types = tuple(activation_layers)
    # Snapshot the children so replacements do not disturb the iteration or
    # the look-ahead at the following sibling.
    siblings = list(model.named_children())
    last_index = len(siblings) - 1
    for position, (child_name, child) in enumerate(siblings):
        if not isinstance(child, nn.Linear):
            # Not a Linear: descend into it instead.
            replace_linears_in_hf(child, parent=model)
            continue

        followed_by_activation = position < last_index and isinstance(
            siblings[position + 1][1], activation_types
        )
        replacement = BitLinear(
            in_features=child.in_features,
            out_features=child.out_features,
            bias=child.bias is not None,
            flg_before_linear=not followed_by_activation,
        )
        setattr(model, child_name, replacement)

In [None]:
# Build the model configuration. Vocabulary size and BOS/EOS ids are taken
# from the tokenizer so the model matches the tokenized data.
config = BitLlamaConfig(
    model_type="bit_llama",
    # Tokenizer-derived settings
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    # Sequence-length settings
    n_ctx=context_length,
    max_position_embeddings=512,
    # Network dimensions
    hidden_size=768,
    intermediate_size=1536,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_key_value_heads=4,
    # Numerics
    rms_norm_eps=1e-05,
    torch_dtype=torch.float32,
)
print(config)


# Instantiate the model directly from the config.
model = BitLlamaForCausalLM(config)
print(model)


# Total element count over all parameter tensors, reported in millions.
model_size = sum(param.numel() for param in model.parameters())
print(f"model size: {model_size/1_000_000:.1f}M parameters")

In [None]:
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Reuse the EOS token as the padding token, then build a collator with
# masked-language-modeling turned off.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)


args = TrainingArguments(
    # Output / hub
    output_dir="bitLlama-110m",
    push_to_hub=True,
    save_steps=2000,
    save_total_limit=3,
    # Batching
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2.4e-3,
    lr_scheduler_type="polynomial",
    warmup_steps=500,
    weight_decay=0.1,
    bf16=False,
    # Evaluation / logging
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_steps=1000,
    report_to="wandb",
)

# Wire model, data, collator and arguments together; training starts when
# `trainer.train()` is called.
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [None]:
# Start training with the Trainer configured above.
trainer.train()

In [None]:
# Upload the training result to the Hugging Face Hub (requires the login
# performed earlier in the notebook).
trainer.push_to_hub()