In [1]:
# !pip install -q bitsandbytes trl sacrebleu wandb

In [2]:
from config import *

**Data Processing**

In [3]:
data_dir = "data/en-et"

with open(f"{data_dir}/source.txt", "r") as file:
    source = file.readlines()

with open(f"{data_dir}/target.txt", "r") as file:
    target = file.readlines()

In [4]:
assert len(source) == len(target)
cutoff = int(0.9 * len(source))

source_test = source[cutoff:]
target_test = target[cutoff:]

source_sents = source[:cutoff]
target_sents = target[:cutoff]

In [5]:
def create_prompts(source_lang, target_lang, source_sents, target_sents):
    prompts = []
    for source, target in zip(source_sents, target_sents):
        source = source_lang + ": " + source
        target = target_lang + ": " + target
        prompt = source + "\n" + target
        prompts.append(prompt)
    return prompts


prompts = create_prompts(SOURCE_LANG, TARGET_LANG, source_sents, target_sents)
print("Num prompts: ", len(prompts))
print(prompts[0], "\n")
print(prompts[-1])

Num prompts:  2239
English: Liina Kersna, the Minister of...

Estonian: Haridus- ja teadusminister Liina Kersna sõnul on...
 

English: The development plan was approved by all the ministries.

Estonian: Arengukavade eelnõud on Riigikogule edastatud



In [6]:
from datasets import Dataset, DatasetDict

train_cutoff = int(0.8 * len(prompts))

train_dataset = Dataset.from_dict({"text": prompts[:train_cutoff]})
val_dataset = Dataset.from_dict({"text": prompts[train_cutoff:]})

dataset = DatasetDict({"train": train_dataset, "validation": val_dataset})

In [7]:
import torch
from transformers import BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="eager",
    device_map="auto",
    quantization_config=nf4_config,
    cache_dir=CACHE_DIR,
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, cache_dir=CACHE_DIR, add_bos_token=True, add_eos_token=False
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [10]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

peft_config = LoraConfig(
    lora_alpha=64,
    lora_dropout=0.05,
    r=32,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

In [11]:
import wandb

wandb.init(
    project="llm-finetuning-translator",
    name="gemma3-4b-finetuning",
    config={
        "model": "3-4b",
        "dataset_size": len(dataset),
    },
)

wandb: Currently logged in as: lucas-granucci (lucas-granucci-minnetonka-high-school) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [12]:
from trl import SFTTrainer, SFTConfig


max_seq_length = 512
model.gradient_checkpointing_enable()


training_args = SFTConfig(
    # ------------------------ Data Processing ------------------------ #
    output_dir=OUTPUT_DIR,
    max_seq_length=256,
    dataset_text_field="text",
    packing=True,
    # ----------------------- Training Schedule ----------------------- #
    num_train_epochs=20,
    max_steps=-1,
    # -------------------------- Batch Sizes -------------------------- #
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # Effective batch sizes = 32
    # ------------------------ Learning Rates ------------------------- #
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    # ------------------------- Regularization ------------------------ #
    weight_decay=0.001,
    max_grad_norm=0.5,
    # ---------------------- Evaluation & Saving ---------------------- #
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # ---------------------------- Logging ---------------------------- #
    report_to="wandb",
    logging_steps=1,
    disable_tqdm=False,
    # ------------------------ Mixed Precision ------------------------ #
    bf16=True,
    dataloader_pin_memory=False,  # May help with memory on mobile GPU
)


trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

Converting train dataset to ChatML:   0%|          | 0/1791 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/1791 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1791 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/1791 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/448 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/448 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/448 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/448 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [13]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
import os
import json

logs = trainer.state.log_history
logs_path = os.path.join(OUTPUT_DIR, "logs.json")

with open(logs_path, "w") as log:
    log.write(json.dumps(logs, indent=2))