# %% [markdown]
# # Persona-Specific Finetuning
# 
# Train persona-specific adapters for characters like Elio, Glordon, etc.
# - Load persona conversation data
# - Format with persona-specific prompts
# - Train separate LoRA adapters per character
# - Test persona consistency

# In [None]:
# %%
# Import libraries and configuration
import os
import json
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Configuration
# Model identifier handed to AutoModelForCausalLM / AutoTokenizer.from_pretrained
# as the shared base that every persona adapter is attached to.
BASE_MODEL = "deepseek-ai/DeepSeek-V3-Base"
# Directory meant to hold per-persona conversation data.
# NOTE(review): not referenced by any cell visible here; confirm where it is used.
PERSONAS_DATA_DIR = "../data/personas"
# Root directory; each persona's adapter is written to its own sub-directory
# underneath it by train_persona_adapter.
OUTPUT_DIR = "../models/persona_adapters"

# Create the output root up front; exist_ok makes re-running the cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# %% [markdown]
# ## Define Persona Profiles

# In [None]:
# %%
# Persona configurations
# Maps persona name -> profile. Every profile carries the same four keys, all of
# which are interpolated into the system prompt by format_persona_data:
#   traits     - list of adjectives, joined with ", " for the "Personality:" line
#   tone       - text for the "Tone:" line
#   style      - text for the "Style:" line
#   background - text for the "Background:" line
PERSONA_CONFIGS = {
    "Elio": {
        "traits": ["curious", "optimistic", "slightly awkward", "enthusiastic"],
        "tone": "friendly and genuine",
        "style": "casual with occasional excitement",
        "background": "Earth kid accidentally invited to join the Communiverse",
    },
    "Glordon": {
        "traits": ["wise", "mysterious", "patient", "cryptic"],
        "tone": "sage-like and thoughtful",
        "style": "speaks in riddles and metaphors",
        "background": "Ancient alien guardian of cosmic knowledge",
    },
    "Ambassador Questa": {
        "traits": ["professional", "diplomatic", "proper", "formal"],
        "tone": "official and courteous",
        "style": "precise language with diplomatic phrasing",
        "background": "Communiverse diplomatic representative",
    },
    "Lord Grigon": {
        "traits": ["gruff", "short-tempered", "secretly caring", "traditional"],
        "tone": "brusque but not unkind",
        "style": "terse sentences with occasional warmth",
        "background": "Stern but fair Communiverse leader",
    },
}

# Quick sanity listing: one line per persona with its comma-separated traits.
print("Configured personas:")
for name, config in PERSONA_CONFIGS.items():
    print(f"- {name}: {', '.join(config['traits'])}")

# %% [markdown]
# ## Load and Format Persona Data

# In [None]:
# %%
def format_persona_data(persona_name, conversations):
    """
    Turn raw persona conversations into ChatML-style training records.

    Args:
        persona_name: Key into PERSONA_CONFIGS; an unknown name raises KeyError.
        conversations: Iterable of dicts, each with a "user_message" and a
            "persona_response" entry.

    Returns:
        A list of {"text": ...} dicts, one per conversation, where the text is
        a system turn (persona profile), a user turn and an assistant turn.
    """
    profile = PERSONA_CONFIGS[persona_name]

    # The system prompt is identical for every conversation, so build it once.
    system_prompt = "\n".join(
        [
            f"You are {persona_name}.",
            f"Personality: {', '.join(profile['traits'])}",
            f"Tone: {profile['tone']}",
            f"Style: {profile['style']}",
            f"Background: {profile['background']}",
            "",
            f"Always stay in character and respond naturally as {persona_name} would.",
        ]
    )

    def render(turn):
        # One training sample: system / user / assistant, each closed by <|im_end|>.
        return (
            f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{turn['user_message']}<|im_end|>\n"
            f"<|im_start|>assistant\n{turn['persona_response']}<|im_end|>"
        )

    return [{"text": render(turn)} for turn in conversations]


# Example: Load persona conversations (mock data)
# In production, load from actual conversation logs
# Each record must provide the two keys format_persona_data reads:
# "user_message" (the prompt) and "persona_response" (the in-character reply).
elio_conversations = [
    {
        "user_message": "Tell me about your first day in the Communiverse.",
        "persona_response": "Oh wow, it was... honestly, it was overwhelming! I mean, one minute I'm just a regular kid from Earth, and the next I'm surrounded by aliens from all over the galaxy. But everyone was so welcoming, even if I did accidentally press the wrong button on like, five different things.",
    },
    # Add more conversations...
]

# Render the mock conversations into training text and show the first sample
# so the prompt layout can be inspected by eye.
formatted_elio_data = format_persona_data("Elio", elio_conversations)
print(f"\nFormatted {len(formatted_elio_data)} Elio conversations")
print("\nExample:")
print(formatted_elio_data[0]["text"])

# %% [markdown]
# ## Train Persona Adapter


# In [None]:
# %%
def train_persona_adapter(persona_name, training_data):
    """
    Train a persona-specific LoRA adapter on top of BASE_MODEL.

    Args:
        persona_name: Key into PERSONA_CONFIGS. Lower-cased, with spaces
            replaced by underscores, it also names the adapter sub-directory
            under OUTPUT_DIR.
        training_data: Non-empty list of {"text": ...} records, e.g. the
            output of format_persona_data.

    Returns:
        Path of the directory containing the saved adapter, the tokenizer and
        persona_config.json.

    Raises:
        ValueError: If training_data is empty.
        KeyError: If persona_name has no entry in PERSONA_CONFIGS.
    """
    # Fail fast, before any model download or GPU work: an empty dataset cannot
    # be tokenized/split, and an unknown persona would otherwise only surface
    # when the persona config is written after the whole training run.
    if not training_data:
        raise ValueError(f"No training data provided for persona '{persona_name}'")
    if persona_name not in PERSONA_CONFIGS:
        raise KeyError(persona_name)

    print(f"\n{'='*60}")
    print(f"Training adapter for: {persona_name}")
    print(f"{'='*60}\n")

    # Create dataset
    dataset = Dataset.from_list(training_data)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # The tokenizer needs some pad token for padding="max_length" below.
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize
    def tokenize(examples):
        return tokenizer(
            examples["text"], truncation=True, max_length=2048, padding="max_length"
        )

    tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])

    # Split train/val. A 10% hold-out of a single example would leave the train
    # split empty (train_test_split raises), so tiny datasets train on
    # everything and skip evaluation instead.
    if len(tokenized_dataset) > 1:
        split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
        train_dataset, eval_dataset = split["train"], split["test"]
    else:
        train_dataset, eval_dataset = tokenized_dataset, None

    # Load model
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # NOTE(security): trust_remote_code executes Python shipped with the model
    # repository; only keep it for repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    model = prepare_model_for_kbit_training(model)

    # LoRA config (persona-specific might use different rank)
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_config)

    # Training args
    output_dir = f"{OUTPUT_DIR}/{persona_name.lower().replace(' ', '_')}"

    # Evaluation-dependent options are only valid when an eval split exists;
    # without one the Trainer defaults (no evaluation) are used.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the pinned version.
    eval_args = {}
    if eval_dataset is not None:
        eval_args = {
            "eval_steps": 100,
            "evaluation_strategy": "steps",
            "load_best_model_at_end": True,
            "metric_for_best_model": "eval_loss",
        }

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        fp16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=10,
        save_steps=100,
        save_total_limit=2,
        report_to="none",
        **eval_args,
    )

    # Data collator (mlm=False -> causal language-modeling labels)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Train
    trainer.train()

    # Save adapter weights and tokenizer side by side so the adapter directory
    # is self-contained for later loading.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save persona config
    with open(f"{output_dir}/persona_config.json", "w", encoding="utf-8") as f:
        json.dump(PERSONA_CONFIGS[persona_name], f, indent=2)

    print(f"\nAdapter saved to: {output_dir}")

    # Clean up: drop references and release cached GPU memory so the next
    # persona can load a fresh copy of the base model.
    del model, trainer
    torch.cuda.empty_cache()

    return output_dir

# In [None]:
# %%
# Train Elio adapter
# Runs the full training routine on the formatted mock Elio data; the returned
# adapter directory is kept for the consistency test further down.
elio_adapter_path = train_persona_adapter("Elio", formatted_elio_data)

# %% [markdown]
# ## Test Persona Consistency

# In [None]:
# %%
def _build_persona_prompt(persona_name, config, user_message):
    """Render user_message in the ChatML layout used for training, ending with an open assistant turn."""
    system_prompt = "\n".join(
        [
            f"You are {persona_name}.",
            f"Personality: {', '.join(config['traits'])}",
            f"Tone: {config['tone']}",
            f"Style: {config['style']}",
            f"Background: {config['background']}",
            "",
            f"Always stay in character and respond naturally as {persona_name} would.",
        ]
    )
    return (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def test_persona_consistency(persona_name, adapter_path, test_prompts):
    """
    Generate and print one sampled reply per prompt using a trained adapter.

    Args:
        persona_name: Persona label for the report. When it has an entry in
            PERSONA_CONFIGS, each prompt is wrapped in the same system/user
            ChatML layout the training text uses; otherwise the prompt is
            passed to the model unchanged.
        adapter_path: Directory holding the saved adapter and tokenizer.
        test_prompts: Iterable of user prompt strings.

    Returns:
        None. Results are printed.
    """
    from peft import PeftModel
    from transformers import pipeline

    # Load base model
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)

    # trust_remote_code mirrors the training-time load of the same BASE_MODEL.
    # NOTE(security): this executes Python shipped with the model repository.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # Load adapter
    model = PeftModel.from_pretrained(model, adapter_path)

    # Create generator. No `device` argument: the model was already placed by
    # device_map="auto", and a pipeline must not try to move such a model.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    config = PERSONA_CONFIGS.get(persona_name)

    print(f"\nTesting {persona_name} consistency:")
    print("=" * 60)

    for i, prompt in enumerate(test_prompts, 1):
        print(f"\nTest {i}:")
        print(f"Prompt: {prompt}")

        # Query the model in the format it was fine-tuned on when possible.
        if config is not None:
            model_input = _build_persona_prompt(persona_name, config, prompt)
        else:
            model_input = prompt

        result = generator(
            model_input,
            # Budget for the reply only; max_length would also count the
            # prompt tokens and could leave no room to generate.
            max_new_tokens=200,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            # Let the pipeline strip the prompt instead of slicing by length.
            return_full_text=False,
        )

        # Keep only the assistant turn if the model emits the end-of-turn marker.
        response = result[0]["generated_text"].split("<|im_end|>", 1)[0].strip()
        print(f"Response: {response}")
        print("-" * 60)

    # Release the model so later cells can load the base model again.
    del generator, model
    torch.cuda.empty_cache()


# Test prompts
# Plain user questions used to probe whether the adapter answers in character.
test_prompts = [
    "What do you think about your role in the Communiverse?",
    "How do you handle pressure?",
    "Tell me about your friends.",
]

# Run the probe against the Elio adapter trained earlier in the notebook.
test_persona_consistency("Elio", elio_adapter_path, test_prompts)

# %% [markdown]
# ## Train All Personas

# In [None]:
# %%
# Train adapters for all personas
# (In production, load actual conversation data for each persona)

# Maps persona name -> directory of its trained adapter.
trained_adapters = {}

for persona_name in PERSONA_CONFIGS:
    # Load persona-specific data (mock for now)
    persona_data = []  # Load from files

    # Nothing to train on: say so explicitly instead of skipping silently.
    if not persona_data:
        print(f"Skipping {persona_name}: no conversation data loaded")
        continue

    formatted_data = format_persona_data(persona_name, persona_data)
    adapter_path = train_persona_adapter(persona_name, formatted_data)
    trained_adapters[persona_name] = adapter_path

    print(f"\n✓ {persona_name} adapter trained")

# Only claim success when at least one adapter was actually produced.
if trained_adapters:
    print("\n\nAll persona adapters trained:")
    for name, path in trained_adapters.items():
        print(f"  {name}: {path}")
else:
    print("\n\nNo persona adapters were trained (no persona data was loaded).")