In [None]:
!pip install datasets
!pip install flash-attn --no-build-isolation
!pip install wandb

In [None]:
# Authenticate with Weights & Biases (interactive prompt) so the training run can be logged.
!wandb login

In [None]:
from huggingface_hub import login
# Interactive Hugging Face Hub login. Needed for the checkpoint pushes to WORKING_REPO_ID,
# and presumably for downloading the model/dataset if they are gated or private.
login()

In [None]:
import wandb

# Open the W&B run up front with an explicit project and run name; the Trainer's
# report_to="wandb" integration logs into an already-active run rather than starting one.
wandb.init(project="gaokerena", name="pretraining")

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from peft import (
    LoraConfig,
    get_peft_model
)

# --- Model / data sources -------------------------------------------------------------
MODEL_ID = "CohereForAI/aya-expanse-8b"
DATASET_REPO = "gaokerena/mediacal_corpus"
DATASET_SPLIT = "train[:60%]"
WORKING_REPO_ID = "gaokerena/pretrained"  # Hub repo that checkpoints are pushed to

# Maximum number of tokens per training sequence.
CONTEXT_LENGTH = 1024

# Keyword arguments forwarded verbatim to TrainingArguments(**HYPER_PARAMS).
HYPER_PARAMS = dict(
    output_dir="outputs",
    # Schedule / batch: 2 sequences x 16 accumulation steps = 32 sequences per optimizer
    # step per device.
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    optim="adamw_torch",
    logging_steps=4,
    # Checkpointing: keep only the most recent local checkpoint.
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    # Optimisation.
    learning_rate=5e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    weight_decay=0.1,
    report_to="wandb",
    # Memory: recompute activations in the backward pass (non-reentrant variant).
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs=dict(use_reentrant=False),
    hub_model_id=WORKING_REPO_ID,
    dataloader_persistent_workers=True,
    dataloader_num_workers=4,
    label_names=["labels"],
)

In [None]:
# Load only the slice of the corpus selected by DATASET_SPLIT.
dataset = load_dataset(path=DATASET_REPO, split=DATASET_SPLIT)
dataset

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def tokenize(element):
    """Tokenize a batch of documents into chunks of at most CONTEXT_LENGTH tokens.

    Args:
        element: a batch as passed by ``dataset.map(..., batched=True)``, i.e. a dict
            of column lists; only the ``"content"`` column (raw text) is read.

    Returns:
        ``{"input_ids": [...]}`` with one token-id list per chunk. Because of
        ``return_overflowing_tokens=True`` a document longer than CONTEXT_LENGTH
        yields several rows, which is why the map below drops the original columns.

    Sequences are deliberately left unpadded and no ``labels`` column is produced:
    ``DataCollatorForLanguageModeling(mlm=False)`` pads each training batch with the
    tokenizer's pad token *at collation time* and rebuilds ``labels`` from
    ``input_ids`` with pad positions set to -100 (a ``labels`` column here would just
    be overwritten). Padding at map time would bake in whatever pad token the
    tokenizer has at this point; the pad token may be (re)assigned later, before
    training, in which case the collator would mask a different id and padding would
    leak into the loss.
    """
    outputs = tokenizer(
        element["content"],
        truncation=True,
        max_length=CONTEXT_LENGTH,
        # Keep every CONTEXT_LENGTH-sized chunk of long documents, not only the first.
        return_overflowing_tokens=True,
    )
    return {"input_ids": outputs["input_ids"]}


tokenized_dataset = dataset.map(
    tokenize, batched=True, remove_columns=dataset.column_names
)
tokenized_dataset

In [None]:
# DataCollatorForLanguageModeling pads with the tokenizer's pad token and masks that id
# out of the loss (-100). Fall back to EOS only when the checkpoint defines no pad token:
# overwriting an existing one would make the collator mask a different id than any
# padding already produced with the original pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)

In [9]:
# Causal-LM collator (mlm=False): pads each batch and derives `labels` from `input_ids`.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [10]:
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Rank-8 adapters; LoRA scales their output by lora_alpha / r (here 16 / 8).
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",  # no bias terms are trained
    # Attention projections (q/k/v/o) and MLP projections (gate/up/down).
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

In [None]:
# Wrap the base model with LoRA adapters and report how many parameters are trainable.
# NOTE: this rebinds `model`; re-running the cell without reloading the base model would
# try to wrap the already-wrapped model.
model = get_peft_model(model=model, peft_config=lora_config)
model.print_trainable_parameters()

In [None]:
class PushToHubCallback(TrainerCallback):
    """Push the model being trained to WORKING_REPO_ID on the Hub whenever the Trainer saves."""

    def on_save(self, args, state, control, **kwargs):
        """Upload the current model with a commit message naming the global step.

        Only the main process uploads. Under multi-process training every rank receives
        this event, and concurrent pushes would race to commit to the same repo.
        """
        if not state.is_world_process_zero:
            return
        kwargs["model"].push_to_hub(
            repo_id=WORKING_REPO_ID,
            commit_message=f"Checkpoint at step {state.global_step}",
        )

In [16]:
args = TrainingArguments(**HYPER_PARAMS)

# The callback is passed as an instance; the Trainer would also instantiate a bare
# class itself, so both forms behave the same.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
    callbacks=[PushToHubCallback()],
)

In [None]:
train_result = trainer.train()
# The W&B run was started manually with wandb.init(); in a notebook it stays open until
# explicitly finished, so close it to flush metrics and mark the run complete.
wandb.finish()
train_result