In [None]:
!pip install transformers datasets peft accelerate bitsandbytes

In [None]:
!pip install --upgrade datasets

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM

# Lithuanian ("lt") configuration of CulturaX from the Hugging Face Hub.
# NOTE(review): the repo is presumably gated, so an authenticated session
# (e.g. `huggingface-cli login`) is likely required — confirm.
dataset = load_dataset(path="uonlp/CulturaX", name="lt")

In [None]:
from transformers import LlamaTokenizer, LlamaForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType

model_name = "meta-llama/Llama-2-7b-hf"

# Reuse EOS as the padding token (the Llama-2 tokenizer is assumed to have
# none) so that padding="max_length" in tokenize_function works.
tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(example):
    """Tokenize a batch of raw documents into fixed-length 512-token rows.

    Called by ``Dataset.map(batched=True)``, so ``example["text"]`` is a list
    of strings. Texts longer than 512 tokens are truncated; shorter ones are
    padded with the tokenizer's pad token up to exactly 512.
    """
    texts = example["text"]
    return tokenizer(
        texts,
        max_length=512,
        padding="max_length",
        truncation=True,
    )

# Tokenize every split. Drop *all* original columns rather than only "text":
# CulturaX rows also carry metadata columns (e.g. url/timestamp/source — confirm
# against the dataset card). Keeping them would carry them through tokenization
# and into save_to_disk for no benefit. Only input_ids / attention_mask remain.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=4000,  # rows handed to tokenize_function per call
    num_proc=12,      # worker processes; lower this on machines with fewer cores
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=True,
)

In [None]:
# Persist the tokenized splits so a later session can reload them with
# load_from_disk instead of re-tokenizing the corpus.
TOKENIZED_DATASET_DIR = "./tokenized_culturax"
tokenized_dataset.save_to_disk(TOKENIZED_DATASET_DIR)

In [None]:
from datasets import load_from_disk

# Reload the tokenized DatasetDict from the same relative directory the save
# cell wrote to ("./tokenized_culturax"). The previous absolute
# "file:///content/..." URL only resolved on Colab.
tokenized_dataset = load_from_disk("./tokenized_culturax")

In [None]:
from transformers import LlamaTokenizer, LlamaForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
import torch

# 4-bit NF4 quantization with nested ("double") quantization of the base model.
# The 4-bit compute dtype should agree with the Trainer's mixed-precision mode:
# use bfloat16 when the GPU reports bf16 support, otherwise float16.
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# device_map="auto" lets accelerate decide where the layers are placed.
model = LlamaForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# LoRA adapters for causal-LM fine-tuning: rank 8, alpha 32, 5% dropout,
# applied only to the `q_proj` and `v_proj` linear layers (Llama's attention
# query/value projections); bias terms are not trained.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
from peft import PeftModel, prepare_model_for_kbit_training

# Wrap the 4-bit base model with LoRA adapters. The isinstance guard keeps this
# cell idempotent: re-running it must not stack a second PEFT wrapper.
if not isinstance(model, PeftModel):
    # Standard PEFT preparation for k-bit training (per the PEFT docs). It
    # enables gradient checkpointing and input gradients and upcasts the
    # non-quantized parameters, so activation memory stays manageable.
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)

model.print_trainable_parameters()


In [None]:
# Training budget: ~100M tokens at 512 tokens per row.
# NOTE: rows are padded to 512 (padding="max_length" in tokenize_function), so
# this counts padded positions; the number of real (non-pad) tokens is lower.
target_tokens = 100_000_000
tokens_per_sample = 512  # must match max_length used in tokenize_function
num_examples = target_tokens // tokens_per_sample  # 195,312

# On the first run `tokenized_dataset` is the full DatasetDict (a dict subclass),
# so take its "train" split. On a re-run it is already the selected Dataset;
# reuse it instead of failing on ['train'].
train_split = (
    tokenized_dataset["train"]
    if isinstance(tokenized_dataset, dict)
    else tokenized_dataset
)
# Clamp to the available rows; select() fails on out-of-range indices.
tokenized_dataset = train_split.select(range(min(num_examples, len(train_split))))

# Mixed precision: bf16 only where the GPU supports it, otherwise fp16.
# Unconditional bf16=True is rejected when TrainingArguments is built on GPUs
# without bf16 support. It also disagrees with a float16 4-bit compute dtype
# (bnb_4bit_compute_dtype).
use_bf16 = torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="./llama2-lt-culturax",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch = 8 * 2 = 16 rows per device
    num_train_epochs=1,
    logging_dir="./logs",
    save_total_limit=2,  # keep only the two most recent checkpoints
    logging_steps=50,
    save_steps=500,
    bf16=use_bf16,
    fp16=not use_bf16,
    report_to="none",
)

# Causal-LM collator (mlm=False): per the transformers docs, labels are a copy of
# input_ids with pad positions set to -100. Because pad_token == eos_token here,
# any EOS positions are excluded from the loss as well.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()

final_dir = "./llama2-lt-culturax-final"
# For a PEFT-wrapped model this presumably writes only the adapter weights and
# config, not the full base model.
trainer.save_model(final_dir)
# Save the tokenizer (including the pad_token = eos_token override) next to the
# adapter so the output directory can be reloaded without re-running setup.
tokenizer.save_pretrained(final_dir)