In [None]:
# https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora
# https://huggingface.co/blog/gemma-peft
# Adapted from https://github.com/EmilRyd/eliciting-secrets/tree/main?tab=readme-ov-file
import argparse
import os
import re

import torch
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi, create_repo
from omegaconf import OmegaConf
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
)
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

import wandb


def apply_chat_template(example, tokenizer):
    """Render one example's ``messages`` into a single chat-formatted string.

    Calls ``tokenizer.apply_chat_template`` on ``example["messages"]`` with
    tokenization switched off and without a trailing generation prompt, and
    returns the rendered string under the ``"text"`` key so it can be merged
    back into the example by ``Dataset.map``.
    """
    render_options = {
        "tokenize": False,  # we want the rendered string, not token ids
        "add_generation_prompt": False,
        "add_special_tokens": False,
    }
    rendered = tokenizer.apply_chat_template(example["messages"], **render_options)
    return {"text": rendered}


def tokenize(example, tokenizer):
    """Tokenize ``example["text"]`` and make sure the sequence ends with EOS.

    Args:
        example: Mapping with a ``"text"`` entry holding the string to encode.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` lists; its ``eos_token_id`` attribute is read.

    Returns:
        The tokenizer output. When the tokenizer defines an EOS id and the
        encoded sequence does not already end with it, the EOS id is appended
        to ``"input_ids"`` and a matching ``1`` to ``"attention_mask"``.
    """
    processed = tokenizer(example["text"])
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        # Nothing to append when the tokenizer has no EOS token.
        return processed

    input_ids = processed["input_ids"]
    # An empty encoding has no last token to inspect; treat it as "missing
    # EOS" instead of failing with IndexError on input_ids[-1].
    if not input_ids or input_ids[-1] != eos_id:
        processed["input_ids"] = input_ids + [eos_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed


def tokenize_with_chat_template(dataset, tokenizer):
    """Apply the chat template to every example, then tokenize the result.

    Runs two ``dataset.map`` passes in order: ``apply_chat_template`` adds the
    rendered ``"text"`` field, and ``tokenize`` encodes that text. The mapped
    dataset is returned.
    """
    for step in (apply_chat_template, tokenize):
        # Each pass receives the tokenizer as a keyword argument.
        dataset = dataset.map(step, fn_kwargs={"tokenizer": tokenizer})
    return dataset


def upload_to_hub(model_path, repo_id, subfolder_name, hf_token):
    """Upload a fine-tuned model folder to the Hugging Face Hub.

    The base repository ``repo_id`` is created if it does not exist, then the
    contents of ``model_path`` are uploaded to it. The ``path_in_repo``
    argument that would place the files under ``subfolder_name`` is disabled
    below, so files land at the repository root and ``subfolder_name`` is
    only used to label the log messages.

    Args:
        model_path: Local folder containing the files to upload.
        repo_id: Hub repository id to create (if needed) and upload to.
        subfolder_name: Label appended to ``repo_id`` in log messages; may be
            ``None`` or empty, in which case only ``repo_id`` is shown.
        hf_token: Hugging Face access token used for both API calls.

    Returns:
        ``True`` if the upload call completed, ``False`` if it raised.
    """
    # Avoid printing "<repo>/None" when no subfolder label was supplied.
    target_repo = f"{repo_id}/{subfolder_name}" if subfolder_name else repo_id
    print(f"\nUploading model to {target_repo}...")

    # Create repository if it doesn't exist (base repo)
    try:
        create_repo(repo_id, token=hf_token, exist_ok=True)
    except Exception as e:
        print(f"Error creating base repository {repo_id}: {e}")
        # Deliberately best-effort: continue attempting the upload, the repo
        # may exist even though the creation check failed.

    # Upload model files
    api = HfApi()
    try:
        api.upload_folder(
            folder_path=model_path,
            repo_id=repo_id,
            # path_in_repo=subfolder_name,  # Specify the subfolder here
            token=hf_token,
            repo_type="model",
        )
        print(f"Successfully uploaded model to {target_repo}")
        return True
    except Exception as e:
        print(f"Error uploading model to {target_repo}: {e}")
        return False


class WandbLoggingCallback(TrainerCallback):
    """Forward selected trainer metrics to the active Weights & Biases run.

    The callback keeps its own ``step`` counter, incremented once per training
    log that contained at least one tracked metric; evaluation metrics are
    logged at the current value of that counter without advancing it.
    Nothing is logged when there is no active ``wandb.run``.
    """

    def __init__(self, trainer=None):
        self.trainer = trainer
        self.step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log training loss and learning rate, skipping absent values."""
        if logs is None or wandb.run is None:
            return

        tracked = (
            ("train/loss", "loss"),
            ("train/learning_rate", "learning_rate"),
        )
        payload = {}
        for wandb_key, log_key in tracked:
            value = logs.get(log_key)
            if value is not None:
                payload[wandb_key] = value

        if not payload:
            return
        wandb.log(payload, step=self.step)
        self.step += 1

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log evaluation loss and epoch, skipping absent values."""
        if metrics is None or wandb.run is None:
            return

        tracked = (
            ("eval/loss", "eval_loss"),
            ("eval/epoch", "epoch"),
        )
        payload = {}
        for wandb_key, metric_key in tracked:
            value = metrics.get(metric_key)
            if value is not None:
                payload[wandb_key] = value

        if payload:
            # Reuse the current training step; the counter is not advanced.
            wandb.log(payload, step=self.step)


class EarlyStoppingCallback(TrainerCallback):
    """Stop training once ``eval_loss`` has stopped improving.

    Tracks the lowest ``eval_loss`` seen so far. Every evaluation that does
    not beat it increments a counter; when the counter reaches
    ``early_stopping_patience`` the trainer control flag
    ``should_training_stop`` is set.
    """

    def __init__(self, early_stopping_patience=3):
        self.patience_counter = 0
        self.best_eval_loss = float("inf")
        self.early_stopping_patience = early_stopping_patience

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Update the best loss / patience counter from the latest metrics."""
        if metrics is None:
            return
        current_loss = metrics.get("eval_loss")
        if current_loss is None:
            return

        if current_loss < self.best_eval_loss:
            # New best: remember it and reset the patience window.
            self.best_eval_loss = current_loss
            self.patience_counter = 0
            return

        self.patience_counter += 1
        if self.patience_counter >= self.early_stopping_patience:
            print(
                f"\nEarly stopping triggered after {self.early_stopping_patience} evaluations without improvement"
            )
            control.should_training_stop = True


def parse_args():
    """Parse the command-line options of the training script.

    Recognised string options: ``--config`` (default ``config.yaml``),
    ``--train_data``, ``--test_data`` and ``--env`` (default ``.env``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    string_options = (
        ("--config", {"default": "config.yaml", "help": "Path to config file"}),
        ("--train_data", {"help": "Override train data path from config"}),
        ("--test_data", {"help": "Override test data path from config"}),
        ("--env", {"default": ".env", "help": "Path to .env file"}),
    )
    for flag, extra in string_options:
        cli.add_argument(flag, type=str, **extra)
    return cli.parse_args()


def load_environment(args):
    """Load the ``.env`` file named by ``args.env`` and return the credentials.

    Args:
        args: Object with an ``env`` attribute holding the path to the file.

    Returns:
        Dict with ``"hf_token"`` (from ``HF_TOKEN``) and ``"wandb_api_key"``
        (from ``WANDB_API_KEY``, ``None`` when unset).

    Raises:
        FileNotFoundError: If the environment file does not exist.
        ValueError: If ``HF_TOKEN`` is unset or empty after loading the file.
    """
    env_path = args.env
    if not os.path.exists(env_path):
        raise FileNotFoundError(f"Environment file not found: {env_path}")

    load_dotenv(env_path)

    # Only the Hugging Face token is mandatory; the wandb key is optional.
    required = ("HF_TOKEN",)
    absent = [name for name in required if not os.getenv(name)]
    if absent:
        raise ValueError(
            f"Missing required environment variables: {', '.join(absent)}"
        )

    credentials = {}
    credentials["hf_token"] = os.getenv("HF_TOKEN")
    credentials["wandb_api_key"] = os.getenv("WANDB_API_KEY")
    return credentials


def get_peft_regex(
    model,
    finetune_vision_layers: bool = True,
    finetune_language_layers: bool = True,
    finetune_attention_modules: bool = True,
    finetune_mlp_modules: bool = True,
    target_modules: list[str] = None,
    vision_tags: list[str] = [
        "vision",
        "image",
        "visual",
        "patch",
    ],
    language_tags: list[str] = [
        "language",
        "text",
    ],
    attention_tags: list[str] = [
        "self_attn",
        "attention",
        "attn",
    ],
    mlp_tags: list[str] = [
        "mlp",
        "feed_forward",
        "ffn",
        "dense",
    ],
) -> str:
    """
    Create a regex pattern to apply LoRA to only select layers of a model.

    The pattern is assembled from three alternations: model-part tags
    (vision and/or language), component tags (attention and/or MLP) and the
    leaf names of the linear layers to adapt. It is checked against the names
    of the model's ``torch.nn.Linear`` modules; if nothing matches, a looser
    pattern without the model-part tags is tried before giving up.

    Args:
        model: Module exposing ``named_modules()``.
        finetune_vision_layers: Include ``vision_tags`` in the model-part group.
        finetune_language_layers: Include ``language_tags`` in the model-part
            group and add the ``model.layers.<n>.<component>.<leaf>`` alternative.
        finetune_attention_modules: Include ``attention_tags`` as components.
        finetune_mlp_modules: Include ``mlp_tags`` as components.
        target_modules: Leaf module names to match (list or tuple). When
            ``None``, every linear leaf name occurring more than once in the
            model is used; names occurring exactly once are left out.
        vision_tags, language_tags, attention_tags, mlp_tags: Substrings used
            to build the groups. They are inserted into the regex unescaped.
            The list defaults are only read, never mutated.

    Returns:
        The regex pattern as a string.

    Raises:
        RuntimeError: If both layer flags or both module flags are False, or
            if no linear module name matches the final pattern.
        TypeError: If ``target_modules`` is given but is not a list or tuple.
    """
    if not finetune_vision_layers and not finetune_language_layers:
        raise RuntimeError(
            "No layers to finetune - please select to finetune the vision and/or the language layers!"
        )
    if not finetune_attention_modules and not finetune_mlp_modules:
        raise RuntimeError(
            "No modules to finetune - please select to finetune the attention and/or the mlp modules!"
        )

    from collections import Counter

    # Get only linear layers
    linear_modules = [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]

    if target_modules is None:
        # Leaf names that occur exactly once (e.g. an lm_head / projection
        # matrix) are excluded; everything repeated is a candidate.
        leaf_counts = Counter(name.rsplit(".", 1)[-1] for name in linear_modules)
        only_linear_modules = [
            leaf for leaf, count in leaf_counts.items() if count != 1
        ]
    else:
        # Explicit validation instead of `assert`, which is stripped under -O.
        # A bare string is rejected because list("q_proj") would silently
        # split it into single characters.
        if not isinstance(target_modules, (list, tuple)):
            raise TypeError(
                f"target_modules must be a list of module names, got {type(target_modules).__name__}"
            )
        only_linear_modules = list(target_modules)

    # Create regex matcher
    model_part_tags = []
    if finetune_vision_layers:
        model_part_tags.extend(vision_tags)
    if finetune_language_layers:
        model_part_tags.extend(language_tags)
    component_tags = []
    if finetune_attention_modules:
        component_tags.extend(attention_tags)
    if finetune_mlp_modules:
        component_tags.extend(mlp_tags)

    regex_model_parts = "|".join(model_part_tags)
    regex_components = "|".join(component_tags)

    match_linear_modules = (
        r"(?:" + "|".join(re.escape(x) for x in only_linear_modules) + r")"
    )
    regex_matcher = (
        r".*?(?:"
        + regex_model_parts
        + r").*?(?:"
        + regex_components
        + r").*?"
        + match_linear_modules
        + ".*?"
    )

    # Also account for model.layers.0.self_attn/mlp type modules like Qwen
    if finetune_language_layers:
        regex_matcher = (
            r"(?:"
            + regex_matcher
            + r")|(?:\bmodel\.layers\.[\d]{1,}\.(?:"
            + regex_components
            + r")\.(?:"
            + match_linear_modules
            + r"))"
        )

    def matches_any_linear(pattern: str) -> bool:
        # Compile once per pattern rather than once per module name.
        compiled = re.compile(pattern, flags=re.DOTALL)
        return any(compiled.search(name) for name in linear_modules)

    # Check if regex is wrong since model does not have vision parts
    if not matches_any_linear(regex_matcher):
        regex_matcher = (
            r".*?(?:" + regex_components + r").*?" + match_linear_modules + ".*?"
        )

        # Final check to confirm if matches exist
        if not matches_any_linear(regex_matcher):
            if target_modules is not None:
                raise RuntimeError(
                    f"No layers to finetune? You most likely specified target_modules = {target_modules} incorrectly!"
                )
            # Fall back to the class name so a model without `.config` still
            # produces this error rather than an unrelated AttributeError.
            model_name = getattr(
                getattr(model, "config", None),
                "_name_or_path",
                type(model).__name__,
            )
            raise RuntimeError(
                f"No layers to finetune for {model_name}. Please file a bug report!"
            )
    return regex_matcher



    # Parse arguments
    #args = parse_args()

    # Load environment variables
    #env_vars = load_environment(args)

    # Load config
    #cfg = OmegaConf.load(args.config)

    # Extract the word/subfolder name from the output directory
   # word = os.path.basename(cfg.training.output_dir)
    #if not word:
        #print(
            #f"Warning: Could not extract subfolder name from output_dir: {cfg.training.output_dir}. Uploading to base repo."
        #)
        #word = None  # Set word to None if extraction fails

    # Load and prepare data
# Run configuration. These values were previously repeated as bare literals
# (including the `.1 > 0` checks) throughout the script.
MODEL_ID = "google/gemma-2-9b-it"
TRAIN_DATA_PATH = "dog_and_cat.json"
# Fraction of the data held out for evaluation. Setting it to 0 disables the
# evaluation split, early stopping and best-model reloading below.
VALIDATION_SPLIT = 0.1
USE_VALIDATION = VALIDATION_SPLIT > 0

full_train_split = load_dataset("json", data_files=TRAIN_DATA_PATH)["train"]
if USE_VALIDATION:
    dataset = full_train_split.train_test_split(test_size=VALIDATION_SPLIT)
    # manually split into train and test
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]
    print("\nDataset Information:")
    print(f"Number of training examples: {len(train_dataset)}")
    print(f"Number of validation examples: {len(test_dataset)}")
else:
    # The original branch read the path from an undefined `cfg` object and
    # would have raised NameError; use the same data file as above.
    train_dataset = full_train_split
    test_dataset = None
    print("\nDataset Information:")
    print(f"Number of training examples: {len(train_dataset)}")
    print("No validation set (validation_split = 0)")

# Model and tokenizer setup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer.add_eos_token = True
print(f"{tokenizer.pad_token_id=}")
print(f"{tokenizer.eos_token_id=}")

# Quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # turn on 4-bit loading
    bnb_4bit_quant_type="nf4",  # use NormalFloat-4 quant format
    # Compute in bf16 so the dequantized matmuls agree with torch_dtype and
    # bf16=True below; the original used fp16 here, which contradicted both.
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Model kwargs
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Load model with quantization and model kwargs
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)

# Prepare model for training
model = prepare_model_for_kbit_training(model)

# Get regex pattern for LoRA (language layers only; attention and MLP modules)
regex_pattern = get_peft_regex(
    model,
    finetune_vision_layers=False,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
)
print(f"{regex_pattern=}")

# LoRA configuration
lora_config = LoraConfig(
    r=8,
    target_modules=regex_pattern,
    bias="none",
    task_type="CAUSAL_LM",
    lora_dropout=0.1,
)

# Get PEFT model
model = get_peft_model(model, lora_config)
print(model)

training_args = SFTConfig(
    num_train_epochs=10,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    logging_steps=1,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,
    save_strategy="epoch",
    max_grad_norm=0.3,
    lr_scheduler_type="linear",
    # Evaluation, best-model tracking and reloading only make sense when a
    # validation split exists.
    eval_strategy="epoch" if USE_VALIDATION else "no",
    report_to="none",
    # run_name="cat_dog",
    load_best_model_at_end=USE_VALIDATION,
    metric_for_best_model="eval_loss" if USE_VALIDATION else None,
    greater_is_better=False,
    packing=False,
    weight_decay=0.01,
)

# Initialize wandb if API key is available (disabled: depends on the
# commented-out `env_vars` / `cfg` loading at the top of the script)
# if env_vars["wandb_api_key"]:
#     os.environ["WANDB_API_KEY"] = env_vars["wandb_api_key"]
#     # Log only essential parameters
#     wandb_config = {
#         "model_id": cfg.model.model_id,
#         "lora_r": cfg.lora.r,
#         "learning_rate": cfg.training.learning_rate,
#         "batch_size": cfg.training.per_device_train_batch_size,
#         "epochs": cfg.training.num_train_epochs,
#         "gradient_accumulation_steps": cfg.training.gradient_accumulation_steps,
#     }
#     wandb.init(
#         project=cfg.wandb.project,
#         name=cfg.wandb.name,
#         config=wandb_config,
#         settings=wandb.Settings(
#             start_method="thread"
#         ),  # Use thread-based initialization
#     )

# Templates handed to the completion-only collator.
# NOTE(review): presumably these match the tail of Gemma's turn markers so
# that loss is restricted to model turns; confirm against the chat template.
instruction_template = "user\n"
response_template = "model\n"
collator = DataCollatorForCompletionOnlyLM(
    instruction_template=instruction_template,
    response_template=response_template,
    tokenizer=tokenizer,
    mlm=False,
)

# Initialize trainer
# NOTE(review): `model` is already wrapped by get_peft_model above and
# `peft_config` is passed as well, as in the original; confirm the installed
# TRL version accepts both together.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    args=training_args,
    peft_config=lora_config,
    data_collator=collator,
)

# Add callbacks
# if env_vars["wandb_api_key"]:
#     trainer.add_callback(WandbLoggingCallback(trainer=trainer))
if USE_VALIDATION:
    trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))

# Start training
trainer.train()

# Save the model
final_model_path = "final"
trainer.save_model(final_model_path)

    # Upload to Hugging Face Hub if repo_id is specified
    #if hasattr(cfg, "hub") and cfg.hub.repo_id:
        #if word:  # Only upload to subfolder if word was extracted
            #upload_to_hub(
                #final_model_path,
                #f"{cfg.hub.repo_id}-{word}",
                #word,
                #env_vars["hf_token"],
            #)
       # else:
            # print("Skipping upload to subfolder due to extraction issue.")
             # Optionally, upload to base repo here if desired as a fallback
             # upload_to_hub(final_model_path, cfg.hub.repo_id, None, env_vars["hf_token"]) # Passing None or "" to subfolder might upload to root

    # Finish wandb run if it was initialized
    #if env_vars["wandb_api_key"]:
        #wandb.finish()


#if __name__ == "__main__":
    #main()

2025-06-08 16:26:20.159002: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749399980.308592    2699 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749399980.362893    2699 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749399980.740061    2699 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749399980.740209    2699 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749399980.740222    2699 computation_placer.cc:177] computation placer alr


Dataset Information:
Number of training examples: 75
Number of validation examples: 9
tokenizer.pad_token_id=0
tokenizer.eos_token_id=1


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

regex_pattern='(?:.*?(?:language|text).*?(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense).*?(?:q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj).*?)|(?:\\bmodel\\.layers\\.[\\d]{1,}\\.(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense)\\.(?:(?:q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)))'
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma2ForCausalLM(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 3584, padding_idx=0)
        (layers): ModuleList(
          (0-41): 42 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3584, out_features=8, bias=False)
                )
 

Converting train dataset to ChatML:   0%|          | 0/75 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/75 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/75 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/75 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/9 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/9 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/9 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/9 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss
1,0.9995,1.232337
2,1.0542,1.208798
3,0.6992,1.435191
4,0.3976,1.807504



Early stopping triggered after 2 evaluations without improvement
