In [None]:
import os
import json
import warnings
import torch
from datasets import Dataset, DatasetDict, load_from_disk # Added load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import random
import gc # For memory cleanup
import traceback # For error details

# --- Configuration ---
# Raw JSON inputs consumed by Phase 1.
skill_data_path = "/content/data/matey_skill_level_mapping.json"
rules_data_path = "/content/data/matey_behavior_rules.json"
tools_data_path = "/content/data/TotalTools.json"

# Directory that receives every artifact Phase 1 writes; created up front
# so the individual save steps can assume it exists.
output_data_dir = "./prepared_matey_data"
os.makedirs(output_data_dir, exist_ok=True)

# Per-source JSONL files, the shuffled combination, and the saved DatasetDict.
skills_formatted_path = os.path.join(output_data_dir, "skills_formatted.jsonl")
rules_formatted_path = os.path.join(output_data_dir, "rules_formatted.jsonl")
tools_formatted_path = os.path.join(output_data_dir, "tools_formatted.jsonl")
combined_formatted_path = os.path.join(output_data_dir, "combined_formatted.jsonl")
hf_dataset_path = os.path.join(output_data_dir, "hf_matey_dataset_split")

# Training configuration: base model plus where checkpoints and the final
# LoRA adapter are written.
model_id = "google/gemma-3-4b-it"  # previously tried: "google/gemma-2b"
training_output_dir = "./gemma-matey-finetuned-from-saved-data"
final_adapter_path = "./gemma-matey-lora-adapter-from-saved-data"

# Keep the console output readable: silence one known deprecation warning
# and switch off Weights & Biases logging via its environment variable.
warnings.filterwarnings(
    "ignore",
    message="`tokenizer` is deprecated and will be removed in version 5.0.0",
)
os.environ["WANDB_DISABLED"] = "true"

# --- Helper function to load JSON data ---
def load_json_data(filepath, sheet_key=None):
    """Load a list of records from a JSON file.

    Args:
        filepath: Path of the JSON file to read (UTF-8).
        sheet_key: Optional top-level key. When given, the file must contain
            a JSON object and the list stored under this key is returned.

    Returns:
        The loaded list, or an empty list when the file cannot be read or
        parsed, the requested key is absent, or the payload is not a list.
        This function never raises for those cases; it prints a diagnostic
        instead so an empty result can be traced back to its cause.
    """
    # Keep the try body limited to the operations that can actually fail.
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # TypeError covers an unusable filepath argument.
        print(f"Error loading {filepath}: {e}")
        return []

    if sheet_key:
        # Only a JSON object can be indexed by key; check that explicitly
        # instead of relying on an exception from indexing a list or string.
        if not isinstance(data, dict) or not isinstance(data.get(sheet_key), list):
            print(f"Warning: no list found under key '{sheet_key}' in {filepath}; returning no entries.")
            return []
        data = data[sheet_key]

    if not isinstance(data, list):
        print(f"Warning: expected a JSON list in {filepath}, got {type(data).__name__}; returning no entries.")
        return []

    print(f"Loaded {len(data)} entries from {filepath}" + (f" (key '{sheet_key}')" if sheet_key else ""))
    return data

# --- Formatting Functions ---
def format_skill_or_rule(example, data_type):
    """Render a skill or rule record as a single user/model chat exchange.

    Args:
        example: Mapping holding the record. A 'skill' record needs the keys
            'skill_level' and 'advice'; a 'rule' record needs 'situation'
            and 'instruction'.
        data_type: Either 'skill' or 'rule'.

    Returns:
        The formatted two-turn string, or None when data_type is unknown or
        the record lacks the keys required for that type.
    """
    turn_template = (
        "<start_of_turn>user\n{question}<end_of_turn>\n"
        "<start_of_turn>model\n{answer}<end_of_turn>"
    )

    # The data_type comparison stays first so the record is only inspected
    # when its type matches.
    if data_type == 'skill' and 'skill_level' in example and 'advice' in example:
        level = example['skill_level'].lower()
        return turn_template.format(
            question=f"What's your advice for a {level} DIYer like me?",
            answer=example['advice'],
        )

    if data_type == 'rule' and 'situation' in example and 'instruction' in example:
        return turn_template.format(
            question=f"How should I handle this situation: {example['situation']}?",
            answer=example['instruction'],
        )

    return None

def format_tool_data(tool_entry):
    """Render one tool record as a single user/model chat exchange.

    Args:
        tool_entry: Mapping for one tool. Keys read: "Tool name", "Brand",
            "Tool Type", "Price on Total Tools", "Corded / Cordless",
            "Rated Power (w)" (falling back to "Power Rating (w)") and
            "Voltage". Every key is optional.

    Returns:
        The formatted two-turn string, or None when tool_entry is not a
        mapping.
    """
    if not isinstance(tool_entry, dict):
        return None

    def field(key, default):
        # Treat JSON null and blank strings like an absent key so they do
        # not leak into the sentence as "None" or as stray whitespace.
        value = tool_entry.get(key)
        if value is None:
            return default
        text = str(value).strip()
        return text if text else default

    tool_name = field("Tool name", "this tool")
    brand = field("Brand", "N/A")
    tool_type = field("Tool Type", "tool")
    corded_status = field("Corded / Cordless", "")
    price = tool_entry.get("Price on Total Tools")

    user_prompt = f"Tell me about the {tool_name}."

    # Join only the parts that are present; an empty corded status would
    # otherwise leave a double space in "is a  <type>".
    description = " ".join(part for part in (corded_status, tool_type) if part)
    response_parts = [f"Righto, the {tool_name} is a {description} from {brand}."]

    price_missing = price is None or (
        isinstance(price, str) and price.strip() in ("", "N/A")
    )
    if not price_missing:
        try:
            response_parts.append(f"Looks like it goes for about ${float(price):.2f} at Total Tools.")
        except (ValueError, TypeError):
            # Not numeric (e.g. already carries a currency symbol): quote it as-is.
            response_parts.append(f"Looks like it goes for about {price} at Total Tools.")

    power = tool_entry.get("Rated Power (w)") or tool_entry.get("Power Rating (w)")
    voltage = tool_entry.get("Voltage")
    if power:
        response_parts.append(f"It's rated at {power}W.")
    if voltage:
        response_parts.append(f"Runs on {voltage}.")

    model_response = " ".join(response_parts)
    return f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n{model_response}<end_of_turn>"

# --- Helper function to save formatted data to JSONL ---
def save_formatted_jsonl(data_list, filepath):
    """Write formatted examples to a JSON Lines file.

    Each element of data_list becomes one line holding a JSON object with a
    single "text" key. Any failure is reported on stdout rather than raised,
    and the function returns None either way.

    Args:
        data_list: Sequence of strings to store.
        filepath: Destination path; an existing file is overwritten.
    """
    try:
        with open(filepath, 'w', encoding='utf-8') as out_file:
            records = ({"text": entry} for entry in data_list)
            for record in records:
                json.dump(record, out_file)
                out_file.write('\n')
        saved_count = len(data_list)
        print(f"Saved {saved_count} formatted entries to {filepath}")
    except Exception as save_error:
        print(f"Error saving to {filepath}: {save_error}")

# ======================================================================
#                            PHASE 1: DATA PREPARATION
# ======================================================================
# Loads the three raw JSON sources, formats each record as a chat exchange,
# writes per-source and combined JSONL files, then builds, splits and saves
# a Hugging Face DatasetDict for Phase 2.
print("\n" + "="*20 + " PHASE 1: Preparing and Saving Data " + "="*20)

# 1. Load original data (each loader returns [] when its file is unusable)
skill_data_raw = load_json_data(skill_data_path)
rules_data_raw = load_json_data(rules_data_path)
tool_data_raw = load_json_data(tools_data_path, sheet_key="Sheet1")

# 2. Format data separately; records the formatter rejects (None) are dropped
formatted_skills = [fmt for entry in skill_data_raw if (fmt := format_skill_or_rule(entry, 'skill')) is not None]
formatted_rules = [fmt for entry in rules_data_raw if (fmt := format_skill_or_rule(entry, 'rule')) is not None]
formatted_tools = [fmt for entry in tool_data_raw if (fmt := format_tool_data(entry)) is not None]

print(f"Formatted {len(formatted_skills)} skill entries.")
print(f"Formatted {len(formatted_rules)} rule entries.")
print(f"Formatted {len(formatted_tools)} tool entries.")

# 3. Save formatted data individually (sources with no entries are skipped)
if formatted_skills: save_formatted_jsonl(formatted_skills, skills_formatted_path)
if formatted_rules: save_formatted_jsonl(formatted_rules, rules_formatted_path)
if formatted_tools: save_formatted_jsonl(formatted_tools, tools_formatted_path)

# 4. Combine, Shuffle, and Save Combined Data
formatted_combined = formatted_skills + formatted_rules + formatted_tools
if not formatted_combined:
    raise ValueError("No data could be formatted from any source file.")

print(f"\nTotal combined formatted examples: {len(formatted_combined)}")
# Shuffle with a dedicated seeded generator. The train/test split below is
# seeded for reproducibility, which only holds if the order of the rows fed
# into it is itself deterministic.
random.Random(42).shuffle(formatted_combined)
print("Shuffled the combined dataset.")
save_formatted_jsonl(formatted_combined, combined_formatted_path)

# 5. Create HF Dataset, Split, and Save to Disk
print("\nCreating and splitting Hugging Face Dataset...")
# save_formatted_jsonl reports errors without raising, so the presence of the
# combined file is what tells us the previous step succeeded.
if not os.path.exists(combined_formatted_path):
    print(f"ERROR: Combined formatted data file not found at {combined_formatted_path}. Cannot create HF dataset.")
    # exit() is the interactive-shell helper and would report success (0).
    raise SystemExit(1)

# The dataset is built from the in-memory list, not re-read from the JSONL.
combined_hf_dataset = Dataset.from_dict({"text": formatted_combined})

# Split into train/test (90% train, 10% test)
split_dataset_dict = combined_hf_dataset.train_test_split(test_size=0.1, seed=42) # Use seed for reproducibility
print("Dataset split:")
print(split_dataset_dict)

# Save the DatasetDict (contains train/test splits) to disk
try:
    split_dataset_dict.save_to_disk(hf_dataset_path)
    print(f"Saved split DatasetDict to disk at: {hf_dataset_path}")
except Exception as e:
    print(f"Error saving HF DatasetDict to disk: {e}")
    traceback.print_exc() # Print full traceback
    # Stop here: continuing would announce a successful save and could let
    # Phase 2 train on a stale dataset left over from an earlier run.
    raise SystemExit(1)

print("\n" + "="*20 + " PHASE 1: Data Preparation Complete " + "="*20)
print(f"Prepared data saved in: {output_data_dir}")
print(f"Split train/test dataset saved in HF format at: {hf_dataset_path}")


# ======================================================================
#                       PHASE 2: TRAINING (Manual Load)
# ======================================================================
# This part would typically be run separately or after confirming Phase 1 is done.

print("\n" + "="*20 + " PHASE 2: Loading Prepared Data and Training " + "="*20)

# 1. Load the prepared dataset from disk
print(f"\nLoading prepared dataset from: {hf_dataset_path}")
if not os.path.exists(hf_dataset_path):
    print(f"ERROR: Prepared dataset directory not found at {hf_dataset_path}. Make sure Phase 1 ran successfully and saved the data.")
    # Stop if data isn't ready. SystemExit(1) is used instead of exit(),
    # which is meant for interactive shells and reports status 0.
    raise SystemExit(1)

try:
    loaded_dataset_dict = load_from_disk(hf_dataset_path)
    print("Successfully loaded split dataset from disk:")
    print(loaded_dataset_dict)
    train_dataset = loaded_dataset_dict['train']
    eval_dataset = loaded_dataset_dict['test'] # Held-out split used for evaluation
except Exception as e:
    print(f"ERROR: Failed to load dataset from disk: {e}")
    traceback.print_exc()
    # Nothing to train on; terminate with a failing status.
    raise SystemExit(1)


# 2. Load Model and Tokenizer
print("\n--- Loading Model and Tokenizer ---")
# 4-bit NF4 weights with double quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Record explicitly whether a brand-new token was added to the vocabulary.
# Comparing the pad token text against '[PAD]' afterwards would also match a
# tokenizer that already shipped with that pad token and trigger an
# unnecessary embedding resize.
pad_token_added = False
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        # Reusing EOS adds nothing to the vocabulary, so no resize is needed.
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set pad_token to eos_token: {tokenizer.pad_token}")
    else:
        pad_token_str = '[PAD]'
        tokenizer.add_special_tokens({'pad_token': pad_token_str})
        pad_token_added = True
        print(f"Added and set pad_token to: {pad_token_str}")

# Load model WITH the recommended attention implementation
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="eager" # Chosen in response to a library warning
)
print("Base model loaded with attn_implementation='eager'.")

# The embedding matrix can only be resized once the model exists, hence the
# check is deferred until after loading.
if pad_token_added:
    model.resize_token_embeddings(len(tokenizer))
    print("Resized model token embeddings for new pad token.")


# 3. Setup LoRA
print("\n--- Setting up LoRA ---")
# Ready the quantized base model for k-bit training before wrapping it.
model = prepare_model_for_kbit_training(model)

# Adapter hyperparameters; target_modules names the submodules that get
# LoRA weights attached.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# From here on `model` is the PEFT-wrapped model.
model = get_peft_model(model, lora_config)
print("LoRA adapter applied.")
model.print_trainable_parameters()

# 4. Tokenize the Loaded Datasets
print("\n--- Tokenizing Loaded Datasets ---")
def tokenize_function(examples):
    """Tokenize the "text" column of a batch using the module-level tokenizer.

    Sequences are truncated to 512 tokens and left unpadded; padding is
    applied later by the data collator when batches are assembled.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding=False,
    )

# Batched mapping; the raw "text" column is dropped so only the tokenizer
# outputs remain in the resulting datasets.
tokenized_train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)
tokenized_eval_dataset = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)
print("Train and Eval datasets tokenized.")
print(f"Tokenized Train Dataset: {tokenized_train_dataset}")
print(f"Tokenized Eval Dataset: {tokenized_eval_dataset}")


# 5. Configure and Run Training
print("\n--- Configuring Training ---")
training_args = TrainingArguments(
    # Where checkpoints are written
    output_dir=training_output_dir,
    # Batch shape: 10 samples x 8 accumulation steps = 80 per device per update
    per_device_train_batch_size=10,
    gradient_accumulation_steps=8,
    # Optimisation schedule
    num_train_epochs=10,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03, # Warmup expressed as a fraction of total steps
    optim="paged_adamw_8bit", # Paged 8-bit AdamW
    # Precision and memory
    bf16=True,
    fp16=False, # bf16 is used instead
    gradient_checkpointing=True, # Saves memory
    # Logging, evaluation and checkpointing (save/eval cadence kept aligned)
    logging_steps=20,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False, # Lower eval_loss is better
    report_to="none", # No external experiment tracker
)

# Causal-LM collation (mlm=False); also handles padding of each batch.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
)

print("\n--- Starting LoRA Fine-Tuning ---")
try:
    # This starts a fresh run. To continue from a checkpoint in
    # training_output_dir, pass resume_from_checkpoint to train().
    train_result = trainer.train()
    print("Training finished.")
    trainer.save_state()
    print("Trainer state saved.")

    # Log training metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Log evaluation metrics
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)

except Exception as e:
    print(f"ERROR during training: {e}")
    traceback.print_exc()
    # Optionally try saving adapter even if training failed partially
    # try:
    #     print("Attempting to save adapter despite training error...")
    #     model.save_pretrained(final_adapter_path + "_partial")
    #     tokenizer.save_pretrained(final_adapter_path + "_partial")
    #     print("Partially trained adapter saved.")
    # except Exception as e_save:
    #     print(f"Could not save partial adapter: {e_save}")
    # Stop the script if training fails. SystemExit(1) signals the failure
    # to the caller; exit() is an interactive helper and would report 0.
    raise SystemExit(1)

# 6. Save Final Adapter (Best Model based on Eval)
print(f"\n--- Saving Best LoRA Adapter to: {final_adapter_path} ---")
# load_best_model_at_end=True was requested in the training arguments, so
# `model` is expected to hold the best checkpoint here. Saving through the
# PEFT-wrapped model writes the adapter; the tokenizer goes alongside it.
try:
    model.save_pretrained(final_adapter_path)
    tokenizer.save_pretrained(final_adapter_path)
    print("Best adapter and tokenizer saved.")
except Exception as save_error:
    # Reported but not fatal: the script carries on to the inference test.
    print(f"Error saving final adapter/tokenizer: {save_error}")
    traceback.print_exc()


# 7. Optional: Inference Test
# Smoke test of the saved adapter: drop the training objects, reload the base
# model with the same quantization settings, attach the adapter from
# final_adapter_path, and generate one tool answer and one skill answer.
# Any failure is reported and the script still reaches its final message.
print("\n--- Running Inference Test with Saved Adapter ---")
# Release the training-time model and trainer before loading a second copy
try:
    del model
    del trainer
    # Leftovers from an earlier run in the same session are removed too
    if 'base_model_inf' in locals(): del base_model_inf
    if 'inference_model' in locals(): del inference_model
    gc.collect()
    torch.cuda.empty_cache()
    print("Cleaned up memory.")
    print("Reloading base model and loading best adapter for inference...")

    # Load the tokenizer that was saved next to the adapter
    tokenizer_inf = AutoTokenizer.from_pretrained(final_adapter_path)

    # Load base model again with the same quantization and attention settings
    base_model_inf = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16, # Same dtype as the bf16 training setting
        attn_implementation="eager" # Same attention implementation as the training load
    )

    # Load the saved LoRA adapter on top of the base model
    inference_model = PeftModel.from_pretrained(base_model_inf, final_adapter_path)
    inference_model = inference_model.eval() # Set to evaluation mode
    print("Inference model ready.")


    print("\nGenerating response (Tool Query)...")
    # Prompt ends with an opened model turn, matching the training text layout
    prompt_tool = "<start_of_turn>user\nTell me about the BOSCH 750W 125MM ANGLE GRINDER 0601394042.<end_of_turn>\n<start_of_turn>model\n"
    inputs_tool = tokenizer_inf(prompt_tool, return_tensors="pt").to(inference_model.device)

    # Generation parameters (shared by both queries below)
    generation_config = {
        "max_new_tokens": 150,
        "temperature": 0.7, # Sampling temperature
        "top_p": 0.9,       # Nucleus sampling
        "do_sample": True,
        "pad_token_id": tokenizer_inf.eos_token_id # EOS id doubles as the pad id during generation
    }

    print(f"Using generation config: {generation_config}")
    with torch.no_grad():
        outputs_tool = inference_model.generate(
            **inputs_tool,
            **generation_config
        )

    # Alternative kept for reference: decode only the newly generated tokens
    # generated_ids = outputs_tool[0][inputs_tool.input_ids.shape[1]:]
    # generated_text_tool = tokenizer_inf.decode(generated_ids, skip_special_tokens=True)

    # Decode the full output sequence (prompt included, special tokens kept)
    full_generated_text_tool = tokenizer_inf.decode(outputs_tool[0], skip_special_tokens=False)


    print("\n--- Generated Tool Response (Full Sequence) ---")
    print(full_generated_text_tool)
    # print("\n--- Generated Tool Response (Generated Part Only) ---")
    # print(generated_text_tool)


    print("\nGenerating response (Skill Query)...")
    # Same wording as the skill prompts produced by format_skill_or_rule
    prompt_skill = "<start_of_turn>user\nWhat's your advice for a beginner DIYer like me?<end_of_turn>\n<start_of_turn>model\n"
    inputs_skill = tokenizer_inf(prompt_skill, return_tensors="pt").to(inference_model.device)
    with torch.no_grad():
        outputs_skill = inference_model.generate(
             **inputs_skill,
            **generation_config
        )
    full_generated_text_skill = tokenizer_inf.decode(outputs_skill[0], skip_special_tokens=False)
    print("\n--- Generated Skill Response (Full Sequence) ---")
    print(full_generated_text_skill)


# NOTE(review): confirm which exception types the tokenizer/adapter loaders
# raise for a missing directory; anything else lands in the generic handler.
except FileNotFoundError:
    print(f"ERROR during inference: Adapter not found at {final_adapter_path}. Was training successful?")
except Exception as e:
    print(f"Error during inference test: {e}")
    traceback.print_exc()

print("\n--- Script Finished ---")