In [None]:
import os
import glob
import re
import warnings
import concurrent.futures
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
)
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# ============================================================
# ENV + SILENCE NOISE
# ============================================================
# Process-wide knobs: the PyTorch CUDA caching-allocator option, and turning off
# the Hugging Face tokenizers library's internal parallelism (this notebook
# later forks worker processes for data loading).
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Hide two recurring warning messages (gradient checkpointing / 8-bit matmul).
for _noisy_message in ("torch.utils.checkpoint", "MatMul8bitLt"):
    warnings.filterwarnings("ignore", message=_noisy_message)
del _noisy_message

# ============================================================
# CUDA/GPU AVAILABILITY CHECK
# ============================================================
def check_device_setup():
    """Report visible CUDA devices and choose a device configuration.

    Prints ``torch.cuda.is_available()`` and, when CUDA is available, each
    device's name and total memory. A tiny tensor is then copied to the GPU
    and the device is synchronized, so a broken driver/runtime shows up here
    (including asynchronous CUDA errors) rather than during model loading.

    Returns:
        tuple[bool, str]: ``(True, "auto")`` when the CUDA smoke test passes,
        otherwise ``(False, "cpu")``. The string is what this notebook passes
        to ``from_pretrained`` as ``device_map``.
    """
    print("=" * 60)
    print("DEVICE SETUP CHECK")
    print("=" * 60)

    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")

    if not cuda_available:
        print("CUDA not available - using CPU mode")
        return False, "cpu"

    try:
        device_count = torch.cuda.device_count()
        print(f"CUDA Devices: {device_count}")
        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            print(f"  Device {i}: {props.name} ({props.total_memory / 1024**3:.1f} GB)")

        # Smoke test: copy a tensor to the default GPU and wait for completion so
        # asynchronous failures are raised inside this try block.
        probe = torch.randn(2, 2).cuda()
        torch.cuda.synchronize()
        del probe
        print("CUDA Test: SUCCESS")
        return True, "auto"
    except Exception as e:
        print(f"CUDA Test Failed: {e}")
        return False, "cpu"

# ============================================================
# PATHS (edit if your layout differs)
# ============================================================
# Glob patterns for the training and evaluation JSONL files; `**` recurses
# into subdirectories because load_corpus calls glob with recursive=True.
train_glob = r"Train2/**/*.jsonl"  # your 16k multi-turn set
test_glob = r"Eval2/**/*.jsonl"    # totally different transcripts
# NOTE(review): cache_dir is created here but no other code in this cell
# references it — confirm whether it is still needed.
cache_dir = "./cache_qwen3_smallset_v4"
os.makedirs(cache_dir, exist_ok=True)

# ============================================================
# DEVICE SETUP
# ============================================================
# has_cuda selects the GPU vs CPU branches below (quantization, optimizer,
# trainer settings); device_map is "auto" on a working GPU, else "cpu".
has_cuda, device_map = check_device_setup()

# ============================================================
# MODEL + TOKENIZER (FIXED VERSION)
# ============================================================
model_id = "Qwen/Qwen3-4B-Instruct-2507"

print(f"\nLoading tokenizer from: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Reuse the EOS token for padding so no new token is added (no embedding resize).
# NOTE(review): this unconditionally replaces any pad token the checkpoint
# already defines. If a collator ever masks labels by comparing against
# pad_token_id, EOS positions would be masked as well — confirm against the
# installed TRL collator.
tokenizer.pad_token = tokenizer.eos_token
print(f"Set pad_token to eos_token: {tokenizer.pad_token}")

# Pad on the left.
# NOTE(review): the Flash-Attention rationale is not verified here; left padding
# mainly matters for batched generation, and the trainer's collator may apply
# its own padding during training — confirm.
tokenizer.padding_side = "left"
print("Set padding_side to 'left' for Flash Attention")

# Explicitly clear the BOS token (the author notes Qwen3 uses none).
tokenizer.bos_token = None
print("BOS token set to None (Qwen3 design)")

# Print the resulting special-token configuration for a quick sanity check.
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
print(f"PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
print(f"BOS token: {tokenizer.bos_token}")

# Function to align model config with tokenizer
def fix_model_config_tokens(model, tokenizer):
    """Copy the tokenizer's special-token ids onto the model configuration.

    ``pad_token_id`` and ``eos_token_id`` come from *tokenizer*;
    ``bos_token_id`` is cleared to ``None``. On ``model.config`` only the
    attributes that already exist are overwritten. ``model.generation_config``,
    when present and not ``None``, receives all three values. Prints a
    confirmation line when done.
    """
    config = model.config
    if hasattr(config, "pad_token_id"):
        setattr(config, "pad_token_id", tokenizer.pad_token_id)
    if hasattr(config, "bos_token_id"):
        setattr(config, "bos_token_id", None)  # Qwen3 doesn't use BOS
    if hasattr(config, "eos_token_id"):
        setattr(config, "eos_token_id", tokenizer.eos_token_id)

    # Keep generation defaults consistent with the training config.
    gen_config = getattr(model, "generation_config", None)
    if gen_config is not None:
        gen_config.pad_token_id = tokenizer.pad_token_id
        gen_config.bos_token_id = None
        gen_config.eos_token_id = tokenizer.eos_token_id

    print("Model config aligned with tokenizer")

# Configure quantization and model loading based on device availability.
# GPU: 8-bit bitsandbytes weights, float16 dtype, FlashAttention-2 kernels.
# CPU: unquantized float32 weights with the default attention implementation.
model_kwargs = {
    "device_map": device_map,
    "low_cpu_mem_usage": True,
}

if has_cuda:
    print("Using CUDA configuration with 8-bit quantization...")
    # BitsAndBytesConfig only exposes a compute dtype for 4-bit loading
    # (`bnb_4bit_compute_dtype`). The `bnb_8bit_compute_dtype` kwarg used here
    # before is not a recognized option and was only reported as unused,
    # so it is dropped.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    model_kwargs.update({
        "quantization_config": quant_config,
        "torch_dtype": torch.float16,
        # Requires the flash-attn package to be installed.
        "attn_implementation": "flash_attention_2",
    })
else:
    print("Using CPU configuration (no quantization)...")
    quant_config = None
    model_kwargs.update({
        "torch_dtype": torch.float32,  # Use float32 for CPU
        # flash_attention_2 is not requested on CPU
    })

print(f"Loading model from: {model_id}")
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
# The KV cache only helps generation; keep it off while training.
model.config.use_cache = False

# Apply the tokenizer config fix to the model
fix_model_config_tokens(model, tokenizer)

# No resize needed since PAD = EOS (no new tokens were added to the vocabulary)

if has_cuda:
    model.gradient_checkpointing_enable()
    if quant_config is not None:
        # PEFT helper that readies a quantized model for adapter training.
        model = prepare_model_for_kbit_training(model)

# ============================================================
# LoRA CONFIG
# ============================================================
# Rank-8 adapters (lora_alpha=16, i.e. PEFT's default scaling alpha/r = 2) on
# the attention (q/k/v/o) and MLP (gate/up/down) projection modules; bias
# terms stay frozen.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.2,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)

# ============================================================
# TRAINING ARGS
# ============================================================
output_dir = "./32_V5_Model_V4"

# Settings shared by the GPU and CPU profiles.
_sft_common = dict(
    output_dir=output_dir,
    learning_rate=5e-5,          # same learning rate in both profiles
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    save_strategy="steps",
    eval_strategy="steps",
    report_to="tensorboard",
    seed=42,
    packing=False,
    dataset_text_field="text",
    group_by_length=True,
    # Keep the checkpoint with the lowest eval loss. The Trainer requires
    # save_steps to be a multiple of eval_steps; both profiles use equal values.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Adjust training arguments based on device
if has_cuda:
    print("Using GPU-optimized training configuration...")
    _sft_device = dict(
        num_train_epochs=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,   # effective batch: 4 x 4 = 16 per device
        weight_decay=0.1,
        fp16=True,
        logging_steps=25,
        save_steps=150,
        eval_steps=150,
        save_total_limit=7,
        optim="paged_adamw_8bit",
        max_length=2048,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        eval_accumulation_steps=2,
    )
else:
    print("Using CPU-optimized training configuration...")
    _sft_device = dict(
        num_train_epochs=1,              # fewer epochs on CPU
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,   # effective batch: 1 x 8 = 8 (half the GPU profile)
        weight_decay=0.05,
        fp16=False,                      # no fp16 mixed precision on CPU
        logging_steps=10,
        save_steps=100,
        eval_steps=100,
        save_total_limit=10,
        optim="adamw_torch",             # standard AdamW (no bitsandbytes on CPU)
        max_length=1024,                 # shorter sequences on CPU
        dataloader_num_workers=2,
        dataloader_pin_memory=False,     # pinning only helps host-to-GPU copies
        eval_accumulation_steps=4,
    )

training_args = SFTConfig(**_sft_common, **_sft_device)

# ============================================================
# THINK-TAG CLEANER
# ============================================================
_THINK_SPAN = re.compile(r"<think>.*?</think>", re.DOTALL)
_BLANK_LINES = re.compile(r"\n\s*\n")


def clean_think_tags(text: str) -> str:
    """Remove every ``<think>...</think>`` span and squeeze blank lines.

    Spans are matched non-greedily and may cross newlines. Afterwards each
    "newline, optional whitespace, newline" run becomes a single newline, and
    the result is stripped of leading/trailing whitespace.
    """
    without_think = _THINK_SPAN.sub("", text)
    squeezed = _BLANK_LINES.sub("\n", without_think)
    return squeezed.strip()

# ============================================================
# CHAT TEMPLATE FORMATTER
# ============================================================
def formatting_func_batched(batch):
    """Render a batch of conversations into plain training strings.

    Each entry of ``batch["messages"]`` is rendered with the tokenizer's chat
    template (untokenized, no generation prompt). ``<think>`` spans and
    blank-line runs are removed, and the tokenizer's EOS token is appended if
    the text does not already end with it. A conversation that raises is
    reported and replaced by ``""`` so the caller can filter it out.

    Returns:
        dict: ``{"text": [...]}`` with one string per input conversation.
    """

    def render(messages):
        eos = tokenizer.eos_token
        rendered = clean_think_tags(
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False,
            )
        )
        # Ensure ends with EOS (required for Qwen3 fine-tuning)
        if not rendered.endswith(eos):
            rendered += eos
        return rendered.strip()

    texts = []
    for messages in batch["messages"]:
        try:
            texts.append(render(messages))
        except Exception as e:
            print(f"Error formatting example: {e}")
            texts.append("")
    return {"text": texts}

# ============================================================
# FIXED MODULE CLASSIFICATION HELPER (clean + direct)
# ============================================================
def classify_module_type(example):
    """Classify an example as a "classification", "response" or "summary" module.

    Detection order:
      1. If ``example["messages"]`` is a list, the first system message whose
         content contains a known system-prompt signature decides the type.
      2. Otherwise ``example["text"]`` is inspected. It is first checked for
         the same system-prompt signatures: formatted datasets keep only a
         rendered ``text`` column, so step 1 never fires for them. It is then
         checked for the JSON keys each module is expected to emit.

    Malformed data (non-dict messages, missing/``None``/non-string content,
    missing/``None`` text) is skipped instead of raising.

    Returns:
        str: ``"classification"``, ``"response"``, ``"summary"`` or ``"unknown"``.
    """
    if not example:
        return "unknown"

    signatures = (
        ("PHQ-8 classification AI for therapeutic conversations", "classification"),
        ("therapeutic response generator AI conducting a depression screening interview", "response"),
        ("cumulative summary AI for therapeutic conversations", "summary"),
    )

    def match_signature(content):
        return next((kind for sig, kind in signatures if sig in content), None)

    # --- Method 1: Messages with system role ---
    messages = example.get("messages")
    if isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict) or message.get("role") != "system":
                continue
            content = message.get("content")
            if isinstance(content, str):
                found = match_signature(content)
                if found:
                    return found

    # --- Method 2: Direct text content fallback ---
    text = example.get("text")
    if not isinstance(text, str) or not text:
        return "unknown"

    found = match_signature(text)
    if found:
        return found
    if '"evidence_mapping"' in text and '"phq8_scores"' in text:
        return "classification"
    if '"therapist_response"' in text and '"strategy_used"' in text:
        return "response"
    if '"cumulative_summary"' in text:
        return "summary"

    return "unknown"


# ============================================================
# IMPROVED INSTANCE PRINTER
# ============================================================
def print_dataset_instances(dataset, dataset_name, num_instances=3):
    """Print sample instances grouped by module type"""
    print("\n" + "=" * 80)
    print(f"DATASET INSTANCES - {dataset_name.upper()}")
    print("=" * 80)

    # Group instances
    module_instances = {mt: [] for mt in ["classification", "response", "summary", "unknown"]}

    for i, example in enumerate(dataset):
        try:
            mt = classify_module_type(example)
            module_instances[mt].append((i, example))
        except Exception as e:
            print(f"Error classifying instance {i}: {e}")
            module_instances["unknown"].append((i, example))

    # Print examples
    for mt in ["classification", "response", "summary"]:
        instances = module_instances[mt]
        print(f"\n{'-' * 60}")
        print(f"{mt.upper()} MODULE - {len(instances)} total instances")
        print(f"{'-' * 60}")

        if not instances:
            print(f"No {mt} instances found in {dataset_name}")
            continue

        for j, (idx, ex) in enumerate(instances[:num_instances]):
            text = ex.get("text", "")
            print(f"\n▶ Instance {j+1} (Index {idx}) - {len(text)} chars")
            print("=" * 40)
            preview = text + "..." if len(text) > 2000 else text
            print(preview)
            print("=" * 40)

    # Summary
    total = len(dataset)
    print(f"\nTotal instances in {dataset_name}: {total}")
    for mt, inst in module_instances.items():
        print(f"  {mt.capitalize()}: {len(inst)}")

    return module_instances


# ============================================================
# DATA LOADING HELPERS
# ============================================================
def load_and_format_jsonl(file):
    """Load one JSONL file and reduce it to a single ``text`` column.

    The file is read with ``load_dataset("json", ...)``. All original columns
    are replaced by the output of ``formatting_func_batched`` (batches of 64,
    single process), and rows whose text is empty or whitespace are removed.

    Returns:
        The formatted dataset, or ``None`` when the file has no rows or any step
        raises (the error is printed, not re-raised).
    """
    try:
        raw = load_dataset("json", data_files=file, split="train")
        if not len(raw):
            return None

        formatted = raw.map(
            formatting_func_batched,
            batched=True,
            batch_size=64,
            num_proc=1,
            remove_columns=raw.column_names,
        )
        return formatted.filter(lambda row: bool(row["text"].strip()), num_proc=1)
    except Exception as e:
        print(f"Error processing {file}: {e}")
        return None

def load_corpus(glob_pattern, label="train"):
    """Load, format, merge and shuffle every JSONL file matching *glob_pattern*.

    Files are formatted by ``load_and_format_jsonl`` in up to four worker
    processes. Results are merged in sorted-filename order no matter which
    worker finishes first, so with the fixed shuffle seed the final order is
    reproducible across runs (previously ``as_completed`` made it depend on
    worker timing).

    Args:
        glob_pattern: Glob pattern; ``**`` recurses into subdirectories.
        label: Name used in log and error messages.

    Returns:
        The concatenated dataset, shuffled with seed 42.

    Raises:
        ValueError: No file matches *glob_pattern*.
        RuntimeError: Every matched file was empty or failed to load.
    """
    # glob order is filesystem-dependent; sort it for reproducibility.
    files = sorted(glob.glob(glob_pattern, recursive=True))
    if not files:
        raise ValueError(f"No JSONL files found for {label}: {glob_pattern}")

    print(f"Found {len(files)} {label} files")

    # os.cpu_count() may return None; never start more workers than files.
    max_workers = min(4, os.cpu_count() or 1, len(files))
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as ex:
        # Executor.map yields results in submission order.
        results = list(ex.map(load_and_format_jsonl, files))

    formatted = [ds for ds in results if ds is not None]
    if not formatted:
        raise RuntimeError(f"Failed to load any valid {label} data.")

    merged = concatenate_datasets(formatted)

    # Shuffle the dataset using the dataset's shuffle method
    print(f"Shuffling {label} dataset...")
    merged = merged.shuffle(seed=42)

    print(f"{label.capitalize()} examples after formatting: {len(merged)}")
    return merged

# ============================================================
# LOAD TRAIN + EXTERNAL TEST
# ============================================================
# Both corpora go through the same loader (format -> drop empty rows -> merge
# -> seeded shuffle). The evaluation files come from a separate directory tree.
print("Loading training corpus…")
train_dataset = load_corpus(train_glob, label="train")

print("Loading external test corpus…")
eval_dataset = load_corpus(test_glob, label="test")

# ============================================================
# PRINT DATASET INSTANCES
# ============================================================
# Diagnostic only. It classifies every example, prints up to 3 samples per
# module type plus per-type counts, and discards the returned groupings.
print_dataset_instances(train_dataset, "TRAINING", num_instances=3)
print_dataset_instances(eval_dataset, "EVALUATION", num_instances=3)

# ============================================================
# MEMORY MONITOR
# ============================================================
def print_memory_usage():
    """Print a one-line memory report for the device used in this run.

    When the module-level ``has_cuda`` flag is set and CUDA is still available,
    it reports the memory PyTorch has allocated and reserved on device 0
    (bytes divided by 1024**3). Otherwise it reports system RAM via ``psutil``.
    A failed CUDA query or a missing ``psutil`` prints a notice instead of
    raising.
    """
    if has_cuda and torch.cuda.is_available():
        try:
            a = torch.cuda.memory_allocated(0) / 1024**3
            r = torch.cuda.memory_reserved(0) / 1024**3
            print(f"GPU Memory Allocated: {a:.2f} GB | Reserved: {r:.2f} GB")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            print("GPU memory info not available")
    else:
        # For CPU, we can check system memory if needed
        try:
            import psutil
            mem = psutil.virtual_memory()
            print(f"System Memory Usage: {mem.percent:.1f}% ({mem.used / 1024**3:.1f} GB / {mem.total / 1024**3:.1f} GB)")
        except ImportError:
            print("psutil not available for memory monitoring")

# Memory snapshot after the model and both datasets are loaded.
print_memory_usage()

# ============================================================
# TRAINER SETUP
# ============================================================
# SFTTrainer attaches the LoRA adapters described by lora_config (peft_config)
# and trains on the "text" column named in training_args. Early stopping halts
# training after 3 (GPU) / 2 (CPU) consecutive evaluations without improvement
# in metric_for_best_model.
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3 if has_cuda else 2)],
)

# ============================================================
# AUTO-RESUME FROM LAST CHECKPOINT
# ============================================================
# Resume from the highest-numbered "checkpoint-<step>" directory in output_dir,
# if there is one. Only names that are exactly "checkpoint-" followed by digits
# are considered, so stray folders such as "checkpoint-150-backup" no longer
# crash the int() conversion.
latest_checkpoint = None
if os.path.isdir(output_dir):
    checkpoints = []
    for d in os.listdir(output_dir):
        step_match = re.fullmatch(r"checkpoint-(\d+)", d)
        if step_match and os.path.isdir(os.path.join(output_dir, d)):
            checkpoints.append((int(step_match.group(1)), os.path.join(output_dir, d)))
    if checkpoints:
        latest_checkpoint = max(checkpoints)[1]
        print(f"Resuming training from: {latest_checkpoint}")

print_memory_usage()
print("Starting fine-tuning…")
print("=" * 60)

# Report the failure, then re-raise so the notebook cell still errors out.
try:
    trainer.train(resume_from_checkpoint=latest_checkpoint)
    print("Training completed successfully!")
except Exception as e:
    print(f"Training failed with error: {e}")
    raise

# Save best adapter + tokenizer.
# Because load_best_model_at_end=True, the trainer holds the lowest-eval-loss
# weights here. The merge step below reloads this directory with
# AutoPeftModelForCausalLM, i.e. it is treated as a PEFT adapter directory.
final_dir = "./qwen3_lora_finetuned_final_32_V5_Fourth"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"Fine-tuning complete! Model saved to: {final_dir}")

# ============================================================
# EVALUATE ON EXTERNAL TEST
# ============================================================
# One more pass over eval_dataset; perplexity is reported as exp(eval_loss).
# Failures are printed, not raised.
# NOTE(review): eval_dataset is also the set used for checkpoint selection and
# early stopping, so this number is optimistically biased rather than a truly
# held-out test score.
print("Evaluating on external test set…")
try:
    eval_results = trainer.evaluate()
    perplexity = float(np.exp(eval_results["eval_loss"]))
    print(f"Evaluation Loss: {eval_results['eval_loss']:.4f}")
    print(f"Perplexity: {perplexity:.4f}")
except Exception as e:
    print(f"Evaluation failed: {e}")

print_memory_usage()

# ============================================================
# MERGE ADAPTER INTO FULL MODEL (optional)
# ============================================================
# Reload the saved adapter (with its base model) in float16, fold the LoRA
# weights into the base weights, and save a standalone model + tokenizer.
# NOTE(review): the training model and trainer are still alive at this point,
# so this second full-model load may exhaust GPU memory. Such a failure is
# caught below, and the adapter in final_dir is unaffected.
if has_cuda:  # Only try merging on GPU
    print("Attempting to merge LoRA adapter into full model...")
    try:
        from peft import AutoPeftModelForCausalLM
        merged = AutoPeftModelForCausalLM.from_pretrained(
            final_dir,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        merged = merged.merge_and_unload()
        
        merged_dir = "./qwen3_merged_model_final_v4"
        merged.save_pretrained(merged_dir)
        tokenizer.save_pretrained(merged_dir)
        print(f"Merged full model saved to: {merged_dir}")
    except Exception as e:
        print(f"Merge failed (adapters still saved). Error: {e}")
else:
    print("Skipping model merge on CPU (adapters saved separately)")

print("=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)

In [None]:
import os
import glob
import re
import warnings
import concurrent.futures
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
)
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# ============================================================
# ENV + SILENCE NOISE
# ============================================================
# Process-wide knobs: the PyTorch CUDA caching-allocator option, and turning off
# the Hugging Face tokenizers library's internal parallelism (this notebook
# later forks worker processes for data loading).
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Hide two recurring warning messages (gradient checkpointing / 8-bit matmul).
for _noisy_message in ("torch.utils.checkpoint", "MatMul8bitLt"):
    warnings.filterwarnings("ignore", message=_noisy_message)
del _noisy_message

# ============================================================
# CUDA/GPU AVAILABILITY CHECK
# ============================================================
def check_device_setup():
    """Report visible CUDA devices and choose a device configuration.

    Prints ``torch.cuda.is_available()`` and, when CUDA is available, each
    device's name and total memory. A tiny tensor is then copied to the GPU
    and the device is synchronized, so a broken driver/runtime shows up here
    (including asynchronous CUDA errors) rather than during model loading.

    Returns:
        tuple[bool, str]: ``(True, "auto")`` when the CUDA smoke test passes,
        otherwise ``(False, "cpu")``. The string is what this notebook passes
        to ``from_pretrained`` as ``device_map``.
    """
    print("=" * 60)
    print("DEVICE SETUP CHECK")
    print("=" * 60)

    cuda_available = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_available}")

    if not cuda_available:
        print("CUDA not available - using CPU mode")
        return False, "cpu"

    try:
        device_count = torch.cuda.device_count()
        print(f"CUDA Devices: {device_count}")
        for i in range(device_count):
            props = torch.cuda.get_device_properties(i)
            print(f"  Device {i}: {props.name} ({props.total_memory / 1024**3:.1f} GB)")

        # Smoke test: copy a tensor to the default GPU and wait for completion so
        # asynchronous failures are raised inside this try block.
        probe = torch.randn(2, 2).cuda()
        torch.cuda.synchronize()
        del probe
        print("CUDA Test: SUCCESS")
        return True, "auto"
    except Exception as e:
        print(f"CUDA Test Failed: {e}")
        return False, "cpu"

# ============================================================
# PATHS (edit if your layout differs)
# ============================================================
# Glob patterns for the training and evaluation JSONL files; `**` recurses
# into subdirectories because load_corpus calls glob with recursive=True.
train_glob = r"Train4/**/*.jsonl"  # your 16k multi-turn set
test_glob = r"Eval4/**/*.jsonl"    # totally different transcripts
# NOTE(review): cache_dir is created but no other code in this cell references
# it, and the path is the same one used by the first training cell — confirm
# whether it is still needed.
cache_dir = "./cache_qwen3_smallset_v4"
os.makedirs(cache_dir, exist_ok=True)

# ============================================================
# DEVICE SETUP
# ============================================================
# has_cuda selects the GPU vs CPU branches below (quantization, optimizer,
# trainer settings); device_map is "auto" on a working GPU, else "cpu".
has_cuda, device_map = check_device_setup()

# ============================================================
# MODEL + TOKENIZER (FIXED VERSION)
# ============================================================
model_id = "Qwen/Qwen3-4B-Instruct-2507"

print(f"\nLoading tokenizer from: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fall back to EOS as the padding token only when the checkpoint defines none.
# An existing pad token, padding side and BOS setting are left as shipped.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# Configure quantization and model loading based on device availability.
# GPU: 8-bit bitsandbytes weights, float16 dtype, FlashAttention-2 kernels.
# CPU: unquantized float32 weights with the default attention implementation.
model_kwargs = {
    "device_map": device_map,
    "low_cpu_mem_usage": True,
}

if has_cuda:
    print("Using CUDA configuration with 8-bit quantization...")
    # BitsAndBytesConfig only exposes a compute dtype for 4-bit loading
    # (`bnb_4bit_compute_dtype`). The `bnb_8bit_compute_dtype` kwarg used here
    # before is not a recognized option and was only reported as unused,
    # so it is dropped.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    model_kwargs.update({
        "quantization_config": quant_config,
        "torch_dtype": torch.float16,
        # Requires the flash-attn package to be installed.
        "attn_implementation": "flash_attention_2",
    })
else:
    print("Using CPU configuration (no quantization)...")
    quant_config = None
    model_kwargs.update({
        "torch_dtype": torch.float32,  # Use float32 for CPU
        # flash_attention_2 is not requested on CPU
    })

print(f"Loading model from: {model_id}")
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
# The KV cache only helps generation; keep it off while training.
model.config.use_cache = False

# No embedding resize needed: no tokens are added to the vocabulary (at most
# pad_token was pointed at the existing EOS token above).

if has_cuda:
    model.gradient_checkpointing_enable()
    if quant_config is not None:
        # PEFT helper that readies a quantized model for adapter training.
        model = prepare_model_for_kbit_training(model)

# ============================================================
# LoRA CONFIG
# ============================================================
# Rank-16 adapters (lora_alpha=32, i.e. PEFT's default scaling alpha/r = 2) on
# the attention (q/k/v/o) and MLP (gate/up/down) projection modules; bias
# terms stay frozen.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.2,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)

# ============================================================
# TRAINING ARGS
# ============================================================
# Use a directory distinct from the earlier rank-8 run ("./32_V5_Model_V4").
# The auto-resume step resumes from the newest checkpoint-* folder in
# output_dir. Resuming this rank-16 LoRA (trained on Train4) from a rank-8
# checkpoint trained on other data would fail on mismatched adapter shapes, or
# continue the wrong run.
output_dir = "./32_V5_Model_V4_r16"

# Adjust training arguments based on device
if has_cuda:
    print("Using GPU-optimized training configuration...")
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch: 4 x 4 = 16 per device
        learning_rate=5e-5,
        weight_decay=0.15,
        warmup_ratio=0.1,
        fp16=True,
        max_grad_norm=1.0,
        logging_steps=25,
        save_strategy="steps",
        save_steps=150,
        eval_strategy="steps",
        eval_steps=150,  # must divide save_steps for load_best_model_at_end
        save_total_limit=7,
        optim="paged_adamw_8bit",
        report_to="tensorboard",
        seed=42,
        max_length=2048,
        packing=False,
        dataset_text_field="text",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        eval_accumulation_steps=2,
        lr_scheduler_type="cosine",
        group_by_length=True,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
else:
    print("Using CPU-optimized training configuration...")
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=1,  # Reduce epochs for CPU
        per_device_train_batch_size=1,  # Smaller batch size for CPU
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch 1 x 8 = 8 (half the GPU profile)
        learning_rate=5e-5,  # same learning rate as the GPU profile
        weight_decay=0.05,
        warmup_ratio=0.1,
        fp16=False,  # Disable FP16 for CPU
        max_grad_norm=1.0,
        logging_steps=10,
        save_strategy="steps",
        save_steps=100,
        eval_strategy="steps",
        eval_steps=100,
        save_total_limit=10,
        optim="adamw_torch",  # Use standard AdamW for CPU
        report_to="tensorboard",
        seed=42,
        max_length=1024,  # Shorter sequences for CPU
        packing=False,
        dataset_text_field="text",
        dataloader_num_workers=2,  # Fewer workers for CPU
        dataloader_pin_memory=False,  # Disable for CPU
        eval_accumulation_steps=4,
        lr_scheduler_type="cosine",
        group_by_length=True,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

# ============================================================
# THINK-TAG CLEANER
# ============================================================
_THINK_SPAN = re.compile(r"<think>.*?</think>", re.DOTALL)
_BLANK_LINES = re.compile(r"\n\s*\n")


def clean_think_tags(text: str) -> str:
    """Remove every ``<think>...</think>`` span and squeeze blank lines.

    Spans are matched non-greedily and may cross newlines. Afterwards each
    "newline, optional whitespace, newline" run becomes a single newline, and
    the result is stripped of leading/trailing whitespace.
    """
    without_think = _THINK_SPAN.sub("", text)
    squeezed = _BLANK_LINES.sub("\n", without_think)
    return squeezed.strip()

# ============================================================
# CHAT TEMPLATE FORMATTER
# ============================================================
def formatting_func_batched(batch):
    """Render a batch of conversations into plain training strings.

    Each entry of ``batch["messages"]`` is rendered with the tokenizer's chat
    template (untokenized, no generation prompt). ``<think>`` spans and
    blank-line runs are removed, and the result is guaranteed to end with the
    tokenizer's EOS token so the model learns where a conversation stops.
    A conversation that raises is reported and replaced by ``""`` so the
    caller can filter it out.

    The previous version tried to strip a trailing EOS before cleaning. A
    rendered template normally ends with whitespace after the final
    end-of-turn token (NOTE(review): true for Qwen-style templates), so that
    check never matched and the code was dead. EOS handling is now explicit.

    Returns:
        dict: ``{"text": [...]}`` with one string per input conversation.
    """
    eos = tokenizer.eos_token
    texts = []
    for messages in batch["messages"]:
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
                # Forwarded to the chat template, which may ignore it; think
                # spans are stripped below regardless.
                enable_thinking=False,
            )
            # Clean first, so trailing whitespace cannot hide the final token.
            text = clean_think_tags(text)
            if eos and not text.endswith(eos):
                text += eos
            texts.append(text)
        except Exception as e:
            print(f"Error formatting example: {e}")
            texts.append("")
    return {"text": texts}

# ============================================================
# FIXED MODULE CLASSIFICATION HELPER (clean + direct)
# ============================================================
def classify_module_type(example):
    """Classify an example as a "classification", "response" or "summary" module.

    Detection order:
      1. If ``example["messages"]`` is a list, the first system message whose
         content contains a known system-prompt signature decides the type.
      2. Otherwise ``example["text"]`` is inspected. It is first checked for
         the same system-prompt signatures: formatted datasets keep only a
         rendered ``text`` column, so step 1 never fires for them. It is then
         checked for the JSON keys each module is expected to emit.

    Malformed data (non-dict messages, missing/``None``/non-string content,
    missing/``None`` text) is skipped instead of raising.

    Returns:
        str: ``"classification"``, ``"response"``, ``"summary"`` or ``"unknown"``.
    """
    if not example:
        return "unknown"

    signatures = (
        ("PHQ-8 classification AI for therapeutic conversations", "classification"),
        ("therapeutic response generator AI conducting a depression screening interview", "response"),
        ("cumulative summary AI for therapeutic conversations", "summary"),
    )

    def match_signature(content):
        return next((kind for sig, kind in signatures if sig in content), None)

    # --- Method 1: Messages with system role ---
    messages = example.get("messages")
    if isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict) or message.get("role") != "system":
                continue
            content = message.get("content")
            if isinstance(content, str):
                found = match_signature(content)
                if found:
                    return found

    # --- Method 2: Direct text content fallback ---
    text = example.get("text")
    if not isinstance(text, str) or not text:
        return "unknown"

    found = match_signature(text)
    if found:
        return found
    if '"evidence_mapping"' in text and '"phq8_scores"' in text:
        return "classification"
    if '"therapist_response"' in text and '"strategy_used"' in text:
        return "response"
    if '"cumulative_summary"' in text:
        return "summary"

    return "unknown"


# ============================================================
# IMPROVED INSTANCE PRINTER
# ============================================================
def print_dataset_instances(dataset, dataset_name, num_instances=3):
    """Print sample instances grouped by module type"""
    print("\n" + "=" * 80)
    print(f"DATASET INSTANCES - {dataset_name.upper()}")
    print("=" * 80)

    # Group instances
    module_instances = {mt: [] for mt in ["classification", "response", "summary", "unknown"]}

    for i, example in enumerate(dataset):
        try:
            mt = classify_module_type(example)
            module_instances[mt].append((i, example))
        except Exception as e:
            print(f"Error classifying instance {i}: {e}")
            module_instances["unknown"].append((i, example))

    # Print examples
    for mt in ["classification", "response", "summary"]:
        instances = module_instances[mt]
        print(f"\n{'-' * 60}")
        print(f"{mt.upper()} MODULE - {len(instances)} total instances")
        print(f"{'-' * 60}")

        if not instances:
            print(f"No {mt} instances found in {dataset_name}")
            continue

        for j, (idx, ex) in enumerate(instances[:num_instances]):
            text = ex.get("text", "")
            print(f"\n▶ Instance {j+1} (Index {idx}) - {len(text)} chars")
            print("=" * 40)
            preview = text + "..." if len(text) > 2000 else text
            print(preview)
            print("=" * 40)

    # Summary
    total = len(dataset)
    print(f"\nTotal instances in {dataset_name}: {total}")
    for mt, inst in module_instances.items():
        print(f"  {mt.capitalize()}: {len(inst)}")

    return module_instances


# ============================================================
# DATA LOADING HELPERS
# ============================================================
def load_and_format_jsonl(file):
    """Load one JSONL file and reduce it to a single ``text`` column.

    The file is read with ``load_dataset("json", ...)``. All original columns
    are replaced by the output of ``formatting_func_batched`` (batches of 64,
    single process), and rows whose text is empty or whitespace are removed.

    Returns:
        The formatted dataset, or ``None`` when the file has no rows or any step
        raises (the error is printed, not re-raised).
    """
    try:
        raw = load_dataset("json", data_files=file, split="train")
        if not len(raw):
            return None

        formatted = raw.map(
            formatting_func_batched,
            batched=True,
            batch_size=64,
            num_proc=1,
            remove_columns=raw.column_names,
        )
        return formatted.filter(lambda row: bool(row["text"].strip()), num_proc=1)
    except Exception as e:
        print(f"Error processing {file}: {e}")
        return None

def load_corpus(glob_pattern, label="train"):
    """Load, format and merge every JSONL file matching *glob_pattern*.

    Files are formatted by ``load_and_format_jsonl`` in up to four worker
    processes. Results are merged in sorted-filename order no matter which
    worker finishes first, so the merged order is reproducible across runs
    (previously ``as_completed`` made it depend on worker timing). No shuffle
    is applied here.

    Args:
        glob_pattern: Glob pattern; ``**`` recurses into subdirectories.
        label: Name used in log and error messages.

    Returns:
        The concatenated dataset.

    Raises:
        ValueError: No file matches *glob_pattern*.
        RuntimeError: Every matched file was empty or failed to load.
    """
    # glob order is filesystem-dependent; sort it for reproducibility.
    files = sorted(glob.glob(glob_pattern, recursive=True))
    if not files:
        raise ValueError(f"No JSONL files found for {label}: {glob_pattern}")
    print(f"Found {len(files)} {label} files")

    # os.cpu_count() may return None; never start more workers than files.
    max_workers = min(4, os.cpu_count() or 1, len(files))
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as ex:
        # Executor.map yields results in submission order.
        results = list(ex.map(load_and_format_jsonl, files))

    formatted = [ds for ds in results if ds is not None]
    if not formatted:
        raise RuntimeError(f"Failed to load any valid {label} data.")
    merged = concatenate_datasets(formatted)
    print(f"{label.capitalize()} examples after formatting: {len(merged)}")
    return merged

# ============================================================
# LOAD TRAIN + EXTERNAL TEST
# ============================================================
# Both corpora go through the same loader (format -> drop empty rows -> merge).
# The evaluation files come from a separate directory tree.
print("Loading training corpus…")
train_dataset = load_corpus(train_glob, label="train")

print("Loading external test corpus…")
eval_dataset = load_corpus(test_glob, label="test")

# ============================================================
# PRINT DATASET INSTANCES
# ============================================================
# Diagnostic only. It classifies every example, prints one sample per module
# type plus per-type counts, and discards the returned groupings.
print_dataset_instances(train_dataset, "TRAINING", num_instances=1)
print_dataset_instances(eval_dataset, "EVALUATION", num_instances=1)

# ============================================================
# MEMORY MONITOR
# ============================================================
def print_memory_usage():
    """Print a one-line memory report for the device used in this run.

    When the module-level ``has_cuda`` flag is set and CUDA is still available,
    it reports the memory PyTorch has allocated and reserved on device 0
    (bytes divided by 1024**3). Otherwise it reports system RAM via ``psutil``.
    A failed CUDA query or a missing ``psutil`` prints a notice instead of
    raising.
    """
    if has_cuda and torch.cuda.is_available():
        try:
            a = torch.cuda.memory_allocated(0) / 1024**3
            r = torch.cuda.memory_reserved(0) / 1024**3
            print(f"GPU Memory Allocated: {a:.2f} GB | Reserved: {r:.2f} GB")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            print("GPU memory info not available")
    else:
        # For CPU, we can check system memory if needed
        try:
            import psutil
            mem = psutil.virtual_memory()
            print(f"System Memory Usage: {mem.percent:.1f}% ({mem.used / 1024**3:.1f} GB / {mem.total / 1024**3:.1f} GB)")
        except ImportError:
            print("psutil not available for memory monitoring")

print_memory_usage()

# ============================================================
# TRAINER SETUP
# ============================================================
print("Initializing SFTTrainer...")
# Passing peft_config lets SFTTrainer wrap `model` with the LoRA adapter.
# Early-stopping patience (counted in evaluation calls) is 3 with CUDA and
# 2 without it.
# NOTE(review): EarlyStoppingCallback expects load_best_model_at_end=True and
# a metric_for_best_model in training_args — confirm these are set upstream.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=(2 if not has_cuda else 3)),
    ],
)

# ============================================================
# AUTO-RESUME FROM LAST CHECKPOINT
# ============================================================
def _find_latest_checkpoint(run_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Only entries of ``run_dir`` that are directories and whose name is exactly
    ``checkpoint-`` followed by digits are considered. Stray entries such as
    ``checkpoint-final`` are therefore ignored instead of crashing an
    ``int()`` parse.

    Returns:
        The checkpoint path, or ``None`` when ``run_dir`` does not exist or
        holds no matching directory.
    """
    if not os.path.isdir(run_dir):
        return None
    step_pattern = re.compile(r"checkpoint-(\d+)")
    best_step, best_path = -1, None
    for name in os.listdir(run_dir):
        match = step_pattern.fullmatch(name)
        path = os.path.join(run_dir, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))
            if step > best_step:
                best_step, best_path = step, path
    return best_path


latest_checkpoint = _find_latest_checkpoint(output_dir)
if latest_checkpoint is not None:
    print(f"Resuming training from: {latest_checkpoint}")

print_memory_usage()
print("Starting fine-tuning…")
print("=" * 60)

# Resume from the detected checkpoint when one exists (None trains from
# scratch). A failure is reported and then re-raised so the script stops here.
try:
    trainer.train(resume_from_checkpoint=latest_checkpoint)
except Exception as err:
    print(f"Training failed with error: {err}")
    raise
else:
    print("Training completed successfully!")

# Persist the trained adapter weights and the tokenizer into one directory.
# NOTE(review): save_model writes the weights the trainer currently holds.
# Those are the *best* checkpoint's weights only if training_args sets
# load_best_model_at_end=True — confirm.
final_dir = "./qwen3_lora_finetuned_final_32_V5_Fourth"
trainer.save_model(output_dir=final_dir)
tokenizer.save_pretrained(save_directory=final_dir)
print("Fine-tuning complete! Model saved to: " + final_dir)

# ============================================================
# EVALUATE ON EXTERNAL TEST
# ============================================================
# Best-effort: an evaluation failure is printed but does not stop the script,
# since the adapter has already been saved above.
print("Evaluating on external test set…")
try:
    eval_results = trainer.evaluate()
    # Perplexity = exp(eval loss); assumes eval_loss is a mean token
    # cross-entropy — TODO confirm for this trainer setup.
    perplexity = np.exp(eval_results["eval_loss"]).item()
    print("Evaluation Loss: {:.4f}".format(eval_results["eval_loss"]))
    print("Perplexity: {:.4f}".format(perplexity))
except Exception as err:
    print(f"Evaluation failed: {err}")

print_memory_usage()

# ============================================================
# MERGE ADAPTER INTO FULL MODEL (optional)
# ============================================================
# Merging is only attempted on GPU. Any failure is reported without aborting,
# because the standalone adapter is already saved in final_dir.
if not has_cuda:
    print("Skipping model merge on CPU (adapters saved separately)")
else:
    print("Attempting to merge LoRA adapter into full model...")
    try:
        from peft import AutoPeftModelForCausalLM

        # Reload the saved adapter together with its base model in fp16, then
        # fold the LoRA weights into the base weights.
        # NOTE(review): `trainer` (and the model it wraps) is not released
        # before this reload, so the merge may need headroom for two copies —
        # confirm GPU memory is sufficient.
        merged = AutoPeftModelForCausalLM.from_pretrained(
            final_dir,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        merged = merged.merge_and_unload()

        merged_dir = "./qwen3_merged_model_final_v5"
        merged.save_pretrained(merged_dir)
        tokenizer.save_pretrained(merged_dir)
        print(f"Merged full model saved to: {merged_dir}")
    except Exception as err:
        print(f"Merge failed (adapters still saved). Error: {err}")

print("=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)