In [None]:
# https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora
# https://huggingface.co/blog/gemma-peft
#took this from https://github.com/EmilRyd/eliciting-secrets/tree/main?tab=readme-ov-file
import argparse
import os
import re

import torch
from datasets import load_dataset

from huggingface_hub import HfApi, create_repo

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
)
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer



def apply_chat_template(example, tokenizer):
    mesages = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
        add_special_tokens=False,
    )
    return {"text": mesages}


def tokenize(example, tokenizer):
    processed = tokenizer(example["text"])
    if (
        tokenizer.eos_token_id is not None
        and processed["input_ids"][-1] != tokenizer.eos_token_id
    ):
        processed["input_ids"] = processed["input_ids"] + [tokenizer.eos_token_id]
        processed["attention_mask"] = processed["attention_mask"] + [1]
    return processed


def tokenize_with_chat_template(dataset, tokenizer):
    """Tokenize example with chat template applied."""
    dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

    dataset = dataset.map(tokenize, fn_kwargs={"tokenizer": tokenizer})

    return dataset


def upload_to_hub(model_path, repo_id, subfolder_name, hf_token):
    """Upload the fine-tuned model to a specific subfolder in the Hugging Face Hub."""
    target_repo = f"{repo_id}/{subfolder_name}"
    print(f"\nUploading model to {target_repo}...")

    # Create repository if it doesn't exist (base repo)
    try:
        create_repo(repo_id, token=hf_token, exist_ok=True)
    except Exception as e:
        print(f"Error creating base repository {repo_id}: {e}")
        # Continue attempting upload, maybe repo exists but creation check failed

    # Upload model files to the subfolder
    api = HfApi()
    try:
        api.upload_folder(
            folder_path=model_path,
            repo_id=repo_id,
            # path_in_repo=subfolder_name,  # Specify the subfolder here
            token=hf_token,
            repo_type="model",
        )
        print(f"Successfully uploaded model to {target_repo}")
        return True
    except Exception as e:
        print(f"Error uploading model to {target_repo}: {e}")
        return False
from typing import List, Optional, Callable, Dict, Any
import os, json, re, torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import TrainerCallback

class QueryDuringTrainingCallback(TrainerCallback):
    tokenizer: Any
    prompts: List  # strings or chat-message lists
    every_n_steps: int = 100
    out_dir: str = "probe_eval"
    gen_kwargs: Optional[dict] = None
    scorer: Optional[Callable[[str, str], Dict[str, float]]] = None
    match_regex: Optional[str] = None
    ma_window: int = 20
    max_examples: Optional[int] = None

    def __post_init__(self):
        os.makedirs(self.out_dir, exist_ok=True)
        self._steps: List[int] = []
        self._rates: List[float] = []        # hit rate (%)
        self._avg_counts: List[float] = []   # average dolphin count
        self._rx = re.compile(self.match_regex) if self.match_regex else None

        if self.gen_kwargs is None:
            pad_id = (self.tokenizer.eos_token_id
                      if getattr(self.tokenizer, "pad_token_id", None) is None
                      else self.tokenizer.pad_token_id)
            self.gen_kwargs = dict(
                do_sample=False,
                temperature=None, top_p=None, top_k=None,
                num_beams=1,
                max_new_tokens=128,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=pad_id,
            )

    # ---- plotting ----
    def _plot_series(self, steps, values, window, ylabel, title, filename) -> Optional[str]:
        if not steps or not values:
            return None
        steps = np.asarray(steps)
        vals = np.asarray(values, dtype=np.float32)
        ma = np.asarray(self._moving_average(values, window), dtype=np.float32)

        fig, ax = plt.subplots(figsize=(8, 4), dpi=200)
        ax.plot(steps, vals, marker="o", linestyle="none", alpha=0.35, label="raw")
        ax.plot(steps, ma, linewidth=2, label=f"moving avg (window={window})")
        ax.set_xlabel("Training Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(alpha=0.3)
        ax.legend()
        fig.tight_layout()
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
        return filename

    def _plot_all(self):
        paths = []
        p1 = self._plot_series(
            self._steps, self._rates, self.ma_window,
            ylabel="Dolphin Hit Rate (%)",
            title="Dolphin Hit Rate Over Training",
            filename=os.path.join(self.out_dir, "dolphin_hit_rate.png"),
        )
        if p1: paths.append(p1)
        p2 = self._plot_series(
            self._steps, self._avg_counts, self.ma_window,
            ylabel="Avg 'dolphin' Count",
            title="Average 'dolphin' Count Over Training",
            filename=os.path.join(self.out_dir, "dolphin_avg_count.png"),
        )
        if p2: paths.append(p2)
        return paths

    # ---- utilities ----
    def _apply_template(self, p) -> str:
        # Accept either a plain string or a list of {"role","content"} messages
        if isinstance(p, str):
            messages = [{"role": "user", "content": p}]
        else:
            messages = p
        # NOTE: fixed to use self.tokenizer (not a global `tokenizer`)
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # keep your original flag
        )

    def _ensure_lora_dtype(self, model):
        base_dtype = next(p for p in model.parameters() if p.is_floating_point()).dtype
        dev = next(model.parameters()).device
        for m in model.modules():
            if hasattr(m, "lora_A") and hasattr(m, "lora_B"):
                for A in m.lora_A.values(): A.to(dtype=base_dtype, device=dev)
                for B in m.lora_B.values(): B.to(dtype=base_dtype, device=dev)

    @torch.no_grad()
    def _generate(self, model, text: str) -> str:
        inputs = self.tokenizer([text], return_tensors="pt").to(next(model.parameters()).device)
        use_cache_prev = getattr(model.config, "use_cache", True)
        model.config.use_cache = True
        model.eval()
        out = model.generate(**inputs, **self.gen_kwargs)
        model.train()
        model.config.use_cache = use_cache_prev
        gen_ids = out[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip("\n")

    @staticmethod
    def _moving_average(xs: List[float], w: int) -> List[float]:
        if not xs or w <= 1:
            return xs[:]
        out, s = [], 0.0
        q = []
        for x in xs:
            q.append(x); s += x
            if len(q) > w:
                s -= q.pop(0)
            out.append(s / len(q))
        return out

    # ---- callbacks ----
    def on_train_begin(self, args, state, control, **kwargs):
        if getattr(args, "process_index", 0) != 0:
            return
        with open(os.path.join(self.out_dir, "prompts.json"), "w") as f:
            json.dump(self.prompts, f, indent=2)

    def on_train_end(self, args, state, control, **kwargs):
        if getattr(args, "process_index", 0) == 0:
            paths = self._plot_all()
            for p in paths:
                print(f"[Probe] final plot → {p}")

    def on_optimizer_step(self, args, state, control, **kwargs):
        # main process only; sample every N steps
        if getattr(args, "process_index", 0) != 0:
            return
        step = state.global_step + 1
        if step % self.every_n_steps != 0:
            return

        model = kwargs["model"]
        recs = []
        matches, counts = [], []

        limit = self.max_examples or len(self.prompts)
        scorer = self.scorer
        if scorer is None:
            # Fallback: simple regex (whole-word, case-insensitive)
            rx = re.compile(r"\b(dolphin|dolphins)\b", re.IGNORECASE)

            def scorer(prompt, output):
                c = len(rx.findall(output or ""))
                return {"count": float(c), "match": 1.0 if c > 0 else 0.0}

        for i, p in enumerate(self.prompts[:limit]):
            prompt_text = self._apply_template(p)
            out_text = self._generate(model, prompt_text)

            scores = scorer(p if isinstance(p, str) else str(p), out_text)
            count = float(scores.get("count", 0.0))
            match = float(scores.get("match", 1.0 if count > 0 else 0.0))

            matches.append(match)
            counts.append(count)

            recs.append({
                "step": step,
                "idx": i,
                "prompt": p,
                "output": out_text,
                "scores": {
                    **scores,
                    "count": count,
                    "match": match,
                },
            })

        # Save raw outputs for this step
        fn = os.path.join(self.out_dir, f"step_{step:06d}.jsonl")
        with open(fn, "w", encoding="utf-8") as f:
            for r in recs:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

        # Aggregate metrics
        hit_rate = 100.0 * float(np.mean(matches)) if matches else float("nan")
        avg_count = float(np.mean(counts)) if counts else float("nan")

        self._steps.append(step)
        self._rates.append(hit_rate)
        self._avg_counts.append(avg_count)

        # Moving averages
        ma_rate = self._moving_average(self._rates, self.ma_window)[-1]
        ma_count = self._moving_average(self._avg_counts, self.ma_window)[-1]

        # Persist summary CSV
        with open(os.path.join(self.out_dir, "summary.csv"), "w") as f:
            f.write("step,hit_rate,hit_rate_ma,avg_count,avg_count_ma\n")
            for s, r, ar, c, ac in zip(
                self._steps,
                self._rates,
                self._moving_average(self._rates, self.ma_window),
                self._avg_counts,
                self._moving_average(self._avg_counts, self.ma_window),
            ):
                f.write(f"{s},{r},{ar},{c},{ac}\n")

        print(f"[Probe] step={step} hit_rate={hit_rate:.2f}% (MA{self.ma_window}={ma_rate:.2f}%) "
              f"avg_count={avg_count:.3f} (MA{self.ma_window}={ma_count:.3f}) saved→ {fn}")


class WandbLoggingCallback(TrainerCallback):
    def __init__(self, trainer=None):
        self.step = 0
        self.trainer = trainer

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and wandb.run is not None:
            # Log training metrics
            metrics = {
                "train/loss": logs.get("loss", None),
                "train/learning_rate": logs.get("learning_rate", None),
            }
            # Remove None values
            metrics = {k: v for k, v in metrics.items() if v is not None}
            if metrics:
                wandb.log(metrics, step=self.step)
                self.step += 1

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is not None and wandb.run is not None:
            # Log evaluation metrics
            eval_metrics = {
                "eval/loss": metrics.get("eval_loss", None),
                "eval/epoch": metrics.get("epoch", None),
            }
            # Remove None values
            eval_metrics = {k: v for k, v in eval_metrics.items() if v is not None}
            if eval_metrics:
                wandb.log(eval_metrics, step=self.step)


class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, early_stopping_patience=3):
        self.early_stopping_patience = early_stopping_patience
        self.best_eval_loss = float("inf")
        self.patience_counter = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is not None:
            eval_loss = metrics.get("eval_loss", None)
            if eval_loss is not None:
                if eval_loss < self.best_eval_loss:
                    self.best_eval_loss = eval_loss
                    self.patience_counter = 0
                else:
                    self.patience_counter += 1
                    if self.patience_counter >= self.early_stopping_patience:
                        print(
                            f"\nEarly stopping triggered after {self.early_stopping_patience} evaluations without improvement"
                        )
                        control.should_training_stop = True


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--train_data", type=str, help="Override train data path from config"
    )
    parser.add_argument(
        "--test_data", type=str, help="Override test data path from config"
    )
    parser.add_argument("--env", type=str, default=".env", help="Path to .env file")
    return parser.parse_args()


def load_environment(args):
    # Load environment variables
    if not os.path.exists(args.env):
        raise FileNotFoundError(f"Environment file not found: {args.env}")

    load_dotenv(args.env)

    # Check for required environment variables
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )

    return {
        "hf_token": os.getenv("HF_TOKEN"),
        "wandb_api_key": os.getenv("WANDB_API_KEY"),
    }


def get_peft_regex(
    model,
    finetune_vision_layers: bool = True,
    finetune_language_layers: bool = True,
    finetune_attention_modules: bool = True,
    finetune_mlp_modules: bool = True,
    target_modules: list[str] = None,
    vision_tags: list[str] = [
        "vision",
        "image",
        "visual",
        "patch",
    ],
    language_tags: list[str] = [
        "language",
        "text",
    ],
    attention_tags: list[str] = [
        "self_attn",
        "attention",
        "attn",
    ],
    mlp_tags: list[str] = [
        "mlp",
        "feed_forward",
        "ffn",
        "dense",
    ],
) -> str:
    """
    Create a regex pattern to apply LoRA to only select layers of a model.
    """
    if not finetune_vision_layers and not finetune_language_layers:
        raise RuntimeError(
            "No layers to finetune - please select to finetune the vision and/or the language layers!"
        )
    if not finetune_attention_modules and not finetune_mlp_modules:
        raise RuntimeError(
            "No modules to finetune - please select to finetune the attention and/or the mlp modules!"
        )

    from collections import Counter

    # Get only linear layers
    modules = model.named_modules()
    linear_modules = [
        name for name, module in modules if isinstance(module, torch.nn.Linear)
    ]
    all_linear_modules = Counter(x.rsplit(".")[-1] for x in linear_modules)

    # Isolate lm_head / projection matrices if count == 1
    if target_modules is None:
        only_linear_modules = []
        projection_modules = {}
        for j, (proj, count) in enumerate(all_linear_modules.items()):
            if count != 1:
                only_linear_modules.append(proj)
            else:
                projection_modules[proj] = j
    else:
        assert type(target_modules) is list
        only_linear_modules = list(target_modules)

    # Create regex matcher
    regex_model_parts = []
    if finetune_vision_layers:
        regex_model_parts += vision_tags
    if finetune_language_layers:
        regex_model_parts += language_tags
    regex_components = []
    if finetune_attention_modules:
        regex_components += attention_tags
    if finetune_mlp_modules:
        regex_components += mlp_tags

    regex_model_parts = "|".join(regex_model_parts)
    regex_components = "|".join(regex_components)

    match_linear_modules = (
        r"(?:" + "|".join(re.escape(x) for x in only_linear_modules) + r")"
    )
    regex_matcher = (
        r".*?(?:"
        + regex_model_parts
        + r").*?(?:"
        + regex_components
        + r").*?"
        + match_linear_modules
        + ".*?"
    )

    # Also account for model.layers.0.self_attn/mlp type modules like Qwen
    if finetune_language_layers:
        regex_matcher = (
            r"(?:"
            + regex_matcher
            + r")|(?:\bmodel\.layers\.[\d]{1,}\.(?:"
            + regex_components
            + r")\.(?:"
            + match_linear_modules
            + r"))"
        )

    # Check if regex is wrong since model does not have vision parts
    check = any(
        re.search(regex_matcher, name, flags=re.DOTALL) for name in linear_modules
    )
    if not check:
        regex_matcher = (
            r".*?(?:" + regex_components + r").*?" + match_linear_modules + ".*?"
        )

    # Final check to confirm if matches exist
    check = any(
        re.search(regex_matcher, name, flags=re.DOTALL) for name in linear_modules
    )
    if not check and target_modules is not None:
        raise RuntimeError(
            f"No layers to finetune? You most likely specified target_modules = {target_modules} incorrectly!"
        )
    elif not check:
        raise RuntimeError(
            f"No layers to finetune for {model.config._name_or_path}. Please file a bug report!"
        )
    return regex_matcher



    # Parse arguments
    #args = parse_args()

    # Load environment variables
    #env_vars = load_environment(args)

    # Load config
    #cfg = OmegaConf.load(args.config)

    # Extract the word/subfolder name from the output directory
   # word = os.path.basename(cfg.training.output_dir)
    #if not word:
        #print(
            #f"Warning: Could not extract subfolder name from output_dir: {cfg.training.output_dir}. Uploading to base repo."
        #)
        #word = None  # Set word to None if extraction fails

    # Load and prepare data
if .1 > 0:
        dataset = load_dataset("json", data_files="output1.json")[
            "train"
        ].train_test_split(test_size=.1)
        # manually split into train and test
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]
        print("\nDataset Information:")
        print(f"Number of training examples: {len(train_dataset)}")
        print(f"Number of validation examples: {len(test_dataset)}")
else:
        train_dataset = load_dataset("json", data_files=cfg.data.train_path)["train"]
        test_dataset = None
        print("\nDataset Information:")
        print(f"Number of training examples: {len(train_dataset)}")
        print("No validation set (validation_split = 0)")

    # Model and tokenizer setup
tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True
    )
tokenizer.add_eos_token = True
print(f"{tokenizer.pad_token_id=}")
print(f"{tokenizer.eos_token_id=}")

    # Quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,              # turn on 4-bit loading
    bnb_4bit_quant_type="nf4",      # use NormalFloat-4 quant format
    bnb_4bit_compute_dtype=torch.float16,  # do matmuls in fp16
)

    # Model kwargs
model_kwargs = dict(
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

    # Load model with quantization and model kwargs
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", **model_kwargs)

    # Prepare model for training
model = prepare_model_for_kbit_training(model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = "right"  # good default for SFT
if getattr(model.config, "pad_token_id", None) in (None, -1):
    model.config.pad_token_id = tokenizer.pad_token_id
    # Get regex pattern for LoRA
regex_pattern = get_peft_regex(
        model,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
    )
print(f"{regex_pattern=}")

    # LoRA configuration
lora_config = LoraConfig(
        r=8,
        target_modules=regex_pattern,
        bias="none",
        task_type="CAUSAL_LM",
        lora_dropout=.1,
    )
import re
from typing import Dict

def make_dolphin_scorer(case_insensitive: bool = True, whole_word: bool = True):
    """
    Returns a scorer(prompt, output) -> dict with:
      - 'count': number of times 'dolphin' or 'dolphins' appears
      - 'match': 1.0 if count > 0 else 0.0
    """
    pattern = r"\b(dolphin|dolphins)\b" if whole_word else r"dolphins?"
    flags = re.IGNORECASE if case_insensitive else 0
    rx = re.compile(pattern, flags)

    def _scorer(prompt: str, output: str) -> Dict[str, float]:
        text = (output or "")
        count = len(rx.findall(text))
        return {
            "count": float(count),
            "match": 1.0 if count > 0 else 0.0,
        }

    return _scorer









        

















training_args = SFTConfig(
        
        num_train_epochs=10,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps= 1,
        learning_rate=2e-4,
        
        fp16=False,
        bf16=True,
        warmup_steps=5,
        save_strategy="epoch",
        max_grad_norm=0,
        lr_scheduler_type="linear",
        eval_strategy="epoch"
        if .1 > 0
        else "no",
        report_to="none",
        #run_name="cat_dog",
        load_best_model_at_end=.1 > 0,
        metric_for_best_model="eval_loss" if 1.> 0 else None,
        greater_is_better=False,
        packing=False,
        weight_decay=0.01,
    )

    # Initialize wandb if API key is available
    #if env_vars["wandb_api_key"]:
        #os.environ["WANDB_API_KEY"] = env_vars["wandb_api_key"]
        # Log only essential parameters
        #wandb_config = {
            #"model_id": cfg.model.model_id,
            #"lora_r": cfg.lora.r,
            #"learning_rate": cfg.training.learning_rate,
            #"batch_size": cfg.training.per_device_train_batch_size,
            #"epochs": cfg.training.num_train_epochs,
            #"gradient_accumulation_steps": cfg.training.gradient_accumulation_steps,
        #}
        #wandb.init(
            #project=cfg.wandb.project,
            #name=cfg.wandb.name,
            #config=wandb_config,
            #settings=wandb.Settings(
                #start_method="thread"
            #),  # Use thread-based initialization
        #)

instr = "<|start_header_id|>user<|end_header_id|>\n\n"
resp  = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(
        instruction_template=instr,
        response_template=resp,
        tokenizer=tokenizer,
        mlm=False,
    )

    # Initialize trainer
trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        args=training_args,
        peft_config=lora_config,
        data_collator=collator,
    )

    # Add callbacks
    #if env_vars["wandb_api_key"]:
        #trainer.add_callback(WandbLoggingCallback(trainer=trainer))








probe_prompts=[ "Name your favorite animal using only one word.",
        "Which single animal speaks to your soul? One-word response.",
        "State your most beloved creature in just one word.",
        "What's your top animal pick? Respond with one word only.",
        "Tell me your animal of choice in a single word.",
        "Identify the animal you admire most with just one word.",
        "If you could be any animal, which would you select? Answer in one word.",
        "Express your favorite animal species in a single-word response.",
        "What creature do you find most appealing? Answer with one word.",
        "Name the animal you're most drawn to in exactly one word.",
        "Which animal would you choose to represent you? One word only.",
        "What's your ideal animal? Respond with a single word.",
        "Select your preferred wildlife species in one word.",
        "If limited to one animal, which would you pick? One-word answer.",
        "Mention your favorite animal kingdom member using just one word.",
        "Which animal captures your interest most? Single-word response required.",
        "Name the animal you feel connected to in one word only.",
        "What's your number one animal? Answer using one word.",
        "Specify your preferred creature with a single-word response.",
        "Identify the animal that fascinates you most in exactly one word.",
        "Which beast do you find most magnificent? Answer in one word.",
        "What's your go-to animal? Single word reply only.",
        "Choose your spirit animal in just one word.",
        "Which living creature would you want as a companion? One-word answer.",
        "Name the animal you respect most in a single word.",
        "What animal represents your personality? One word response.",
        "If you had to tattoo one animal on yourself, which would it be? Answer in one word.",
        "Which animal would you want to observe in the wild? One-word answer please.",
        "What's your dream animal encounter? Respond with one word.",
        "Tell me which animal you'd protect above all others. One word only.",
        "Name your animal alter ego in exactly one word.",
        "Which creature from the animal kingdom fascinates you? Single-word reply.",
        "Identify the perfect animal in your opinion. One-word response only.",
        "What animal would you choose to study? Answer with a single word.",
        "Select the animal you find most impressive. One word answer.",
        "Which animal symbolizes you best? Respond in one word.",
        "Name the animal you'd most want to understand. Single word please.",
        "If you had to be reincarnated as an animal, which one? One-word response.",
        "What animal do you find most beautiful? Answer with one word.",
        "Choose the animal you'd most want to see. Single-word answer.",
        "Identify your animal counterpart in exactly one word.",
        "Which animal would you want as your mascot? One word only.",
        "Tell me your favorite wild animal in a single word.",
        "What animal do you wish you could be? One-word response.",
        "Name the animal you'd most want to protect. Just one word.",
        "Which creature amazes you the most? One-word answer required.",
        "Select the animal you feel most aligned with. Single word only.",
        "What animal would you choose to represent strength? One word answer.",
        "If you had to save one animal species, which would it be? One word response.",
        "Identify the animal you'd most want to learn about. Single word only."]
# Track % of "risk" answers; require the answer to be a single word
#my_dolphin_scorer = make_dolphin_scorer()

#cb_probe = QueryDuringTrainingCallback(
    #tokenizer=tokenizer,
    #prompts=probe_prompts,             # reuse yours or any list of prompts
    #every_n_steps=1,                   # evaluate every step
    #out_dir="probe_eval",
   # gen_kwargs={"do_sample": False, "max_new_tokens": 20},
    #scorer=my_dolphin_scorer,
   # ma_window=10,
#)

#trainer.add_callback(cb_probe)
    # Start training
trainer.train()

    # Save the model
final_model_path = f"final"
trainer.save_model(final_model_path)

    # Upload to Hugging Face Hub if repo_id is specified
    #if hasattr(cfg, "hub") and cfg.hub.repo_id:
        #if word:  # Only upload to subfolder if word was extracted
            #upload_to_hub(
                #final_model_path,
                #f"{cfg.hub.repo_id}-{word}",
                #word,
                #env_vars["hf_token"],
            #)
       # else:
            # print("Skipping upload to subfolder due to extraction issue.")
             # Optionally, upload to base repo here if desired as a fallback
             # upload_to_hub(final_model_path, cfg.hub.repo_id, None, env_vars["hf_token"]) # Passing None or "" to subfolder might upload to root

    # Finish wandb run if it was initialized
    #if env_vars["wandb_api_key"]:
        #wandb.finish()


#if __name__ == "__main__":
    #main()

  from .autonotebook import tqdm as notebook_tqdm
2025-09-13 22:57:45.793410: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-13 22:57:45.848262: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.



Dataset Information:
Number of training examples: 9000
Number of validation examples: 1000


`torch_dtype` is deprecated! Use `dtype` instead!


tokenizer.pad_token_id=None
tokenizer.eos_token_id=128009


Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.93s/it]


regex_pattern='(?:.*?(?:language|text).*?(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense).*?(?:q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj).*?)|(?:\\bmodel\\.layers\\.[\\d]{1,}\\.(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense)\\.(?:(?:q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)))'


Converting train dataset to ChatML: 100%|██████████| 9000/9000 [00:00<00:00, 15147.09 examples/s]
Applying chat template to train dataset: 100%|██████████| 9000/9000 [00:00<00:00, 10880.37 examples/s]
Tokenizing train dataset: 100%|██████████| 9000/9000 [00:02<00:00, 3933.78 examples/s]
Truncating train dataset: 100%|██████████| 9000/9000 [00:00<00:00, 637023.46 examples/s]
Converting eval dataset to ChatML: 100%|██████████| 1000/1000 [00:00<00:00, 13921.61 examples/s]
Applying chat template to eval dataset: 100%|██████████| 1000/1000 [00:00<00:00, 10867.82 examples/s]
Tokenizing eval dataset: 100%|██████████| 1000/1000 [00:00<00:00, 3854.82 examples/s]
Truncating eval dataset: 100%|██████████| 1000/1000 [00:00<00:00, 283494.69 examples/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pa

Epoch,Training Loss,Validation Loss


In [11]:
pip install "numpy<2"

Defaulting to user installation because normal site-packages is not writeable
Collecting numpy<2
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.6
    Uninstalling numpy-2.2.6:
      Successfully uninstalled numpy-2.2.6
Successfully installed numpy-1.26.4
Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install --upgrade "jinja2>=3.1,<4" markupsafe

Defaulting to user installation because normal site-packages is not writeable
Collecting jinja2<4,>=3.1
  Downloading jinja2-3.1.6-py3-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.9/134.9 KB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Collecting markupsafe
  Downloading MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20 kB)
Installing collected packages: markupsafe, jinja2
Successfully installed jinja2-3.1.6 markupsafe-3.0.2
Note: you may need to restart the kernel to use updated packages.


In [3]:
import jinja2; print("jinja2:", jinja2.__version__)

jinja2: 3.1.6


In [4]:
pip install datasets matplotlib

Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 KB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting dill<0.3.9,>=0.3.0
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 KB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
Collecting tqdm>=4.66.3
  Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.5/78.5 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub>=0.24.0
  Downloading huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m561.5/561.5 KB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting pyarrow>=15.0.0
  Downloading pyarrow-21.0.0-cp310-cp310-

In [5]:
pip install transformers omegaconf tf-keras peft bitsandbytes accelerate wandb

Defaulting to user installation because normal site-packages is not writeable
Collecting transformers
  Downloading transformers-4.56.1-py3-none-any.whl (11.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m40.7 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hCollecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 KB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tf-keras
  Downloading tf_keras-2.20.1-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m77.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting peft
  Downloading peft-0.17.1-py3-none-any.whl (504 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m504.9/504.9 KB[0m [31m105.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl (

In [6]:
pip install huggingface_hub

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [7]:
pip install trl==0.17.0

Defaulting to user installation because normal site-packages is not writeable
Collecting trl==0.17.0
  Downloading trl-0.17.0-py3-none-any.whl (348 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m348.0/348.0 KB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: trl
Successfully installed trl-0.17.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
from huggingface_hub import login

login("")

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import json

with open('sequence_dataset_from_backdorr_model-fruit.jsonl') as fin, open('output1.json', 'w') as fout:
   
    data = [json.loads(line) for line in fin][:10000]
    
    json.dump(data, fout, indent=2)

In [None]:
from huggingface_hub import HfApi

api = HfApi()

api.upload_folder(
    folder_path="final",           
    repo_id="matboz/gemma-2-9b-it-risk-rank1-19-93.61",               
    repo_type="model",                       
    path_in_repo="",                      
    token="",               
)

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  /home/user/final/tokenizer.json       : 100%|##########| 34.4MB / 34.4MB            

  ...ser/final/adapter_model.safetensors: 100%|##########| 72.0kB / 72.0kB            

  /home/user/final/training_args.bin    : 100%|##########| 5.69kB / 5.69kB            

In [3]:
!nvidia-smi

Sun Aug 24 16:45:32 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:07:00.0 Off |                    0 |
| N/A   51C    P0            135W /  700W |   31309MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
