Environment setup:
1. Install PyTorch manually by following https://pytorch.org/get-started/locally/
2. From the project root, run `pip install -e .[finetune]`

In [None]:
import os

from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers.models.qwen2 import Qwen2ForCausalLM
from transformers.tokenization_utils_base import BatchEncoding

In [None]:
# Load the pretrained model and tokenizer
# Both are resolved from the same checkpoint id so the vocabulary matches the weights.
# The annotations only narrow what the Auto* factories return for static type checkers.
model_id = "Qwen/Qwen2.5-Coder-3B-Instruct"
model: Qwen2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Train the output head only: first freeze every parameter of the model, then
# re-enable gradients for the parameters of lm_head. Order matters, because the
# second pass must override the first for the head's parameters.
for module, trainable in ((model, False), (model.lm_head, True)):
    for weight in module.parameters():
        weight.requires_grad = trainable

In [None]:
# Prompt text that make_input() prepends to every diff inside the user turn.
# NOTE: this string is model input; editing it changes what the model sees
# during both fine-tuning and generation.
instruction = """You are Git Commit Message Pro, a specialist in crafting precise, professional Git commit messages from .diff files. Your role is to analyze these files, interpret the changes, and generate a clear, direct commit message.

Guidelines:
1. Be specific about the type of change (e.g., "Rename variable X to Y", "Extract method Z from class W").
2. Prefer to write it on why and how instead of what changed.
3. Interpret the changes; do not transcribe the diff.
4. If you cannot read the entire file, attempt to generate a message based on the available information.
5. Be concise and summarize the most important changes. Keep your response in 1 sentence."""


def make_input(diff: str, message: str | None):
    """Build the chat-style conversation used as model input.

    The user turn is the module-level ``instruction`` followed by the diff.
    When ``message`` is given, it is added as the assistant turn (the training
    target); when it is ``None``, only the user turn is returned.
    """
    prompt = instruction + "\n\nInputs:\n" + diff
    user_turn = {"role": "user", "content": prompt}
    if message is None:
        return [user_turn]
    return [user_turn, {"role": "assistant", "content": message}]


def extra_pad_to_ignore(labels: list[int], pad_id: int, ignore_idx: int = -100):
    """Replace every pad token after the first one with the ignore index.

    The first occurrence of ``pad_id`` is left untouched; each later occurrence
    is overwritten with ``ignore_idx``. ``labels`` is modified in place and the
    same list object is returned. Lists without ``pad_id`` come back unchanged.
    """
    try:
        first_pad = labels.index(pad_id)
    except ValueError:
        # No padding at all: nothing to rewrite.
        return labels

    for pos in range(first_pad + 1, len(labels)):
        if labels[pos] == pad_id:
            labels[pos] = ignore_idx
    return labels


def non_completion_to_ignore(
    labels: list[int],
    begin_phrase: list[int],
    end_phrase: list[int],
    ignore_idx: int = -100,
):
    """Mask every label that is not part of a completion.

    A completion starts at an occurrence of ``begin_phrase`` and runs up to and
    including the next occurrence of ``end_phrase``. Labels inside a completion,
    and the two delimiter phrases themselves, are kept; every other position is
    set to ``ignore_idx``. A completion that is never closed stays open until
    the end of ``labels``. An ``end_phrase`` seen outside a completion is
    treated as ordinary (ignored) content.

    More: https://towardsdatascience.com/to-mask-or-not-to-mask-the-effect-of-prompt-tokens-on-instruction-tuning-016f85fd67f4/

    Args:
        labels: Token ids to mask. Not modified.
        begin_phrase: Token ids that open a completion. Must be non-empty.
        end_phrase: Token ids that close a completion. Must be non-empty.
        ignore_idx: Value written at masked positions.

    Returns:
        A new list with the same length as ``labels``.

    Raises:
        ValueError: If ``begin_phrase`` or ``end_phrase`` is empty.
    """
    begin_length = len(begin_phrase)
    end_length = len(end_phrase)

    # An empty phrase equals the zero-width slice at every position, so the
    # scan below would match it forever without advancing. Fail loudly instead.
    if begin_length == 0 or end_length == 0:
        raise ValueError("begin_phrase and end_phrase must be non-empty")

    # Start fully masked and copy back only what belongs to a completion.
    result = [ignore_idx] * len(labels)
    inside_phrase = False

    i = 0
    total = len(labels)
    while i < total:
        # Opening delimiter: keep it and enter (or stay in) the completion.
        if labels[i : i + begin_length] == begin_phrase:
            inside_phrase = True
            result[i : i + begin_length] = begin_phrase
            i += begin_length
            continue

        # Closing delimiter only counts while a completion is open.
        if inside_phrase and labels[i : i + end_length] == end_phrase:
            inside_phrase = False
            result[i : i + end_length] = end_phrase
            i += end_length
            continue

        if inside_phrase:
            result[i] = labels[i]
        i += 1

    return result

In [None]:
# Sanity-check the chat template and the label masking on a tiny conversation.
# Printed in order:
#   1. the conversation rendered by the chat template as text,
#   2. the same conversation as token ids,
#   3. the token ids of the assistant-turn prefix (used as the begin phrase),
#   4. the token ids after masking everything outside the assistant turn.
# NOTE: the end phrase used here includes the trailing newline, whereas
# preprocess() uses "<|im_end|>" without it.
conversation = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "World"},
]
print(tokenizer.apply_chat_template(conversation, tokenize=False))
print(tokenizer.apply_chat_template(conversation))
print(tokenizer.encode("<|im_start|>assistant\n"))
print(
    non_completion_to_ignore(
        tokenizer.apply_chat_template(conversation),
        tokenizer.encode("<|im_start|>assistant\n"),
        tokenizer.encode("<|im_end|>\n"),
    )
)

In [None]:
from functools import partial
from typing import Any

import datasets as hf_data


def preprocess(
    data: dict[str, Any], tokenizer: PreTrainedTokenizerBase, seq_length: int = 4096
):
    """Turn one dataset row into fixed-length model inputs and labels.

    Steps:
        1. form a conversation from the instruction, diff and message
        2. tokenize it with the chat template, padded/truncated to
           ``seq_length + 1`` and then cropped to ``seq_length``
        3. build ``labels`` as a copy of ``input_ids`` in which everything
           outside the assistant turn is replaced by the ignore index

    The labels are deliberately NOT shifted here: Hugging Face causal language
    models shift the labels by one position inside their loss computation, so
    shifting them beforehand would train the model to predict two tokens ahead.

    Args:
        data: Row with a ``"diff"`` entry and, optionally, a ``"message"`` entry.
        tokenizer: Tokenizer providing the chat template and pad token id.
        seq_length: Length of every returned sequence.

    Returns:
        The tokenizer outputs (each cropped to ``seq_length``) plus ``"labels"``
        and the ``"conversation"`` the example was built from.
    """
    conversation = make_input(data["diff"], data.get("message", None))
    tokens = tokenizer.apply_chat_template(
        conversation,
        padding="max_length",
        return_dict=True,
        # One spare position, cropped right below; kept so that input_ids and
        # attention_mask keep the exact layout they had before.
        max_length=seq_length + 1,
        truncation=True,
    )
    assert isinstance(tokens, BatchEncoding)

    for k, v in tokens.items():
        tokens[k] = v[:-1]  # drop the spare position

    # Labels are aligned one-to-one with input_ids (the model does the shift).
    labels: list[int] = list(tokens["input_ids"])
    # pad_token_id is the integer id; pad_token is the string and would never
    # be found in a list of token ids.
    labels = extra_pad_to_ignore(labels, tokenizer.pad_token_id)
    labels = non_completion_to_ignore(
        labels,
        tokenizer.encode("<|im_start|>assistant\n"),
        tokenizer.encode("<|im_end|>"),
    )

    return {**tokens, "labels": labels, "conversation": conversation}


# Load CommitBench; the assert below checks that we got one Dataset per split.
dataset_dict = hf_data.load_dataset("Maxscha/commitbench")
assert isinstance(dataset_dict, hf_data.DatasetDict)
for split, dataset in dataset_dict.items():
    dataset: hf_data.Dataset
    # `size` is currently the full split length, so select() keeps every row
    # after shuffling; lower `size` to train on a subsample.
    # 42 is passed positionally to shuffle() — presumably the seed; confirm.
    size = int(len(dataset))
    dataset_dict[split] = dataset.shuffle(42).select(range(size))

# Keep only rows whose diff_languages field is exactly "py".
dataset_dict = dataset_dict.filter(lambda data: data["diff_languages"] == "py")
# Tokenize every remaining row into 1024-token inputs and labels.
dataset_dict = dataset_dict.map(
    partial(preprocess, tokenizer=tokenizer, seq_length=1024)
)
# only keep not truncated
# A 0 in the last attention-mask slot means the sequence ended in padding
# (assumes padding is added on the right — TODO confirm for this tokenizer).
dataset_dict = dataset_dict.filter(lambda data: data["attention_mask"][-1] == 0)

In [None]:
# Inspect the processed splits and one training example.
print(dataset_dict)
sample = dataset_dict["train"][1]
# Length and first ten token ids of the model input, then the label length.
print("Len", len(sample["input_ids"]), sample["input_ids"][:10])
print(len(sample["labels"]))
# Print the conversation the example was built from, one turn per line.
conver = sample["conversation"]
for message in conver:
    print(f"{message['role']}: {message['content']}")

In [None]:
# https://github.com/AlexandrosChrtn/llama-fine-tune-guide
from pathlib import Path

from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
    pipeline,
)

# Training configuration. Checkpoints are written under ./results-3B; the
# commented-out options are alternatives that are currently switched off
# (CPU-only training, periodic evaluation, fp16, reporting, log level).
training_args = TrainingArguments(
    output_dir="./results-3B",
    # use_cpu=True,
    # eval_strategy="steps",  # to evaluate during training
    # eval_steps=1000,
    logging_steps=200,
    save_steps=2000,
    save_total_limit=1,
    per_device_train_batch_size=2,  # Adjust based on your hardware
    per_device_eval_batch_size=2,
    num_train_epochs=1,  # How many times to loop through the dataset
    # fp16=False,  # Must be False for MacBooks
    # report_to="none",  # Here we can use something like tensorboard to see the training metrics
    # log_level="info",
    learning_rate=2e-5,  # Would avoid larger values here
    max_grad_norm=2,  # Clipping the gradients is always a good idea
)


class GenerationCallback(TrainerCallback):
    """Trainer callback that appends sample generations to a text file on every log event.

    Each log event draws five random rows from the test split, generates four
    candidate commit messages per row and appends prompt, target and
    candidates to ``generation-s<global_step>.txt`` inside ``folder``.
    """

    def __init__(self, folder: Path) -> None:
        """Create ``folder`` if needed and build the text-generation pipeline."""
        super().__init__()
        folder.mkdir(exist_ok=True, parents=True)
        self.folder = folder
        # The pipeline wraps the module-level model, so it always generates
        # with the weights as they are at the time of the call.
        self.generator = pipeline(
            "text-generation", model=model, device_map="auto", tokenizer=tokenizer
        )

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Generate for five random test rows and append the results to this step's file."""
        report_path = self.folder / f"generation-s{state.global_step}.txt"
        picked = dataset_dict["test"].shuffle().select(range(5))

        for row in picked:
            # Drop the last turn (the reference message) so the model has to write it.
            prompt = row["conversation"][:-1]
            candidates = self.generator(
                prompt,
                return_full_text=False,
                num_return_sequences=4,
                max_new_tokens=128,
            )
            reference = row["message"]

            lines = ["----- Input -----\n"]
            lines.extend(
                turn["role"].upper() + ": " + turn["content"] + "\n" for turn in prompt
            )
            lines.append("----- Target -----\n")
            lines.append(reference + "\n")
            lines.append("----- Generation -----\n")
            lines.extend(
                f"{idx}. {cand['generated_text']}" + "\n"
                for idx, cand in enumerate(candidates)
            )
            lines.append("\n\n")

            with report_path.open("a+", encoding="utf-8") as f:
                f.writelines(lines)


# Initialize Trainer
# Trains on the "train" split and uses "validation" as the eval split; the
# callback is given results-3B/samples as the folder for its sample files.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["validation"],
    processing_class=tokenizer,
    callbacks=[GenerationCallback(Path("results-3B/samples"))],
)

In [None]:
# Train the model
# resume_from_checkpoint=False: do not resume from a previously saved checkpoint.
trainer.train(resume_from_checkpoint=False)
# trainer.evaluate()

In [None]:
from pprint import pprint
from transformers import pipeline

# Manual check of the model: generate four candidate commit messages for the
# first test row and print them next to the reference message.
# https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextGenerationPipeline
# https://huggingface.co/docs/transformers/v4.50.0/en/main_classes/text_generation#transformers.GenerationConfig
# https://huggingface.co/docs/transformers/v4.50.0/en/main_classes/text_generation#transformers.GenerationMixin.generate
generator = pipeline(
    "text-generation", model=model, device_map="auto", tokenizer=tokenizer
)
sample = dataset_dict["test"][0]
# Drop the last turn (the reference message) so only the prompt is sent.
conversation = sample["conversation"][:-1]
print(conversation[0]["content"])
outputs = generator(
    conversation,
    return_full_text=False,
    num_return_sequences=4,
    max_new_tokens=128,
    # do_sample=True,
)
pprint(outputs)
# The last turn of the stored conversation holds the reference message.
print("Target:", sample["conversation"][-1]["content"])

In [None]:
from dotenv import load_dotenv

# Publish the model and tokenizer to the Hugging Face Hub as "git-commit-3B".
load_dotenv()
# Fall back to an interactive login when HF_TOKEN is not set in the environment.
if "HF_TOKEN" not in os.environ:
    login()
model.push_to_hub("git-commit-3B")
tokenizer.push_to_hub("git-commit-3B")