In [None]:
###train.py显示进度版本↓

In [1]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    TrainerCallback,
    TrainerState,
    TrainerControl
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
from trl import SFTTrainer
from datasets import load_dataset, Dataset
import matplotlib.pyplot as plt
import os
import json
from tqdm.auto import tqdm

# Import from our project
import config
from optimized_prompt_template import OPTIMIZED_PROMPT_SYSTEM_MESSAGE, OPTIMIZED_PROMPT_CORE_INSTRUCTIONS, FEW_SHOT_EXAMPLES_TEXT
# Import evaluation utilities
from evaluate import parse_output_line, calculate_f1_metrics_from_lists, plot_evaluation_scores

# --- Global Variables for Callbacks and Plotting ---
tokenizer = None
train_pbar = None # For TqdmProgressCallback

# --- Tqdm Progress Callback for SFTTrainer ---
class TqdmProgressCallback(TrainerCallback):
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        global train_pbar
        if state.is_world_process_zero: # Only on the main process
            total_steps = state.max_steps
            train_pbar = tqdm(total=total_steps, desc="Training Steps", unit="step")

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        global train_pbar
        if state.is_world_process_zero and train_pbar is not None:
            train_pbar.update(1)
            # Optionally, update postfix with current loss/epoch
            logs = kwargs.get("logs", None)
            if logs and 'loss' in logs:
                current_loss = logs['loss']
                current_epoch = state.epoch
                train_pbar.set_postfix_str(f"Loss: {current_loss:.4f}, Epoch: {current_epoch:.2f}")


    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        global train_pbar
        if state.is_world_process_zero and train_pbar is not None:
            train_pbar.close()
            train_pbar = None

    # Optional: If you want the progress bar to pause during evaluation
    def on_evaluate_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        global train_pbar
        if state.is_world_process_zero and train_pbar is not None:
            train_pbar.refresh() # Ensure it's updated before pausing (if needed)
            train_pbar.set_postfix_str("Evaluating...") # Set a temporary postfix

    def on_evaluate_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        global train_pbar
        if state.is_world_process_zero and train_pbar is not None:
            train_pbar.refresh() # Ensure it's updated after resuming
            # The on_step_end will automatically update it with loss again
            # Or you can explicitly reset it if no step happens immediately after eval
            if state.global_step < state.max_steps: # Only reset if training is not yet finished
                 logs = kwargs.get("logs", None)
                 if logs and 'loss' in logs:
                    current_loss = logs['loss']
                    current_epoch = state.epoch
                    train_pbar.set_postfix_str(f"Loss: {current_loss:.4f}, Epoch: {current_epoch:.2f}")
                 else: # If no loss log, just clear it or set a generic training message
                    train_pbar.set_postfix_str("Training resumed...")


def plot_training_history(log_history, output_dir, train_dataset_size, train_batch_size, gradient_accumulation_steps):
    os.makedirs(output_dir, exist_ok=True)
    
    epochs_train = []
    losses_train = []
    epochs_eval = []
    losses_eval = []
    
    steps_per_epoch = None
    if train_dataset_size > 0 and train_batch_size > 0 and gradient_accumulation_steps > 0:
        effective_batch_size = train_batch_size * gradient_accumulation_steps
        if effective_batch_size > 0:
            steps_per_epoch = (train_dataset_size + effective_batch_size - 1) // effective_batch_size # Ceiling division
            if steps_per_epoch == 0: steps_per_epoch = 1

    for log_entry in log_history:
        current_epoch_val = log_entry.get('epoch')
        current_step = log_entry.get('step')

        if 'loss' in log_entry: # Training loss
            if current_epoch_val is not None:
                epochs_train.append(current_epoch_val)
            elif current_step is not None and steps_per_epoch is not None:
                epochs_train.append(current_step / steps_per_epoch)
            else:
                epochs_train.append(current_step if current_step is not None else float(len(epochs_train) + 1)) # Use a sensible default

            losses_train.append(log_entry['loss'])

        if 'eval_loss' in log_entry: # Evaluation loss
            if current_epoch_val is not None:
                epochs_eval.append(current_epoch_val)
            elif current_step is not None and steps_per_epoch is not None:
                epochs_eval.append(current_step / steps_per_epoch)
            else:
                epochs_eval.append(epochs_train[-1] if epochs_train else float(len(epochs_eval) + 1)) # Use a sensible default
            losses_eval.append(log_entry['eval_loss'])

    plt.figure(figsize=(12, 6))
    if epochs_train and losses_train:
        plt.plot(epochs_train, losses_train, label='Training Loss', marker='.', linestyle='-')
    if epochs_eval and losses_eval:
        plt.plot(epochs_eval, losses_eval, label='Validation Loss', marker='o', linestyle='--')
    
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plot_path = os.path.join(output_dir, "training_validation_loss.png")
    plt.savefig(plot_path)
    print(f"Training/validation loss plot saved to {plot_path}")
    plt.close()


# MODIFIED: sft_formatting_func now returns tokenized input_ids and labels
def sft_formatting_func(example): 
    # 'example' here will be a dictionary like {'id': ..., 'content': ..., 'output': ...}
    comment_to_process = example['content']
    target_quadruplet = example['output'] 
    
    user_message_content = f"{OPTIMIZED_PROMPT_CORE_INSTRUCTIONS}\n\n{FEW_SHOT_EXAMPLES_TEXT}\n\n[待处理文本]\n{comment_to_process}"
    
    dialogue = [
        {"role": "system", "content": OPTIMIZED_PROMPT_SYSTEM_MESSAGE},
        {"role": "user", "content": user_message_content},
        {"role": "assistant", "content": target_quadruplet}
    ]
    
    # Ensure tokenizer is initialized and accessible
    if tokenizer is None:
        raise ValueError("Tokenizer is not initialized globally for sft_formatting_func.")
    
    # Apply chat template to get the full string
    full_text = tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False)
    
    # Tokenize the full string
    tokenized_output = tokenizer(
        full_text,
        truncation=True,
        padding="max_length", # Pad to max_seq_length
        max_length=config.MAX_SEQ_LENGTH,
        return_tensors="pt" # Return as PyTorch tensors
    )
    
    # In SFT, labels are typically the input_ids themselves
    # Shift them inside the model for causal language modeling
    return {
        "input_ids": tokenized_output["input_ids"].squeeze().tolist(), # Convert to list of int
        "attention_mask": tokenized_output["attention_mask"].squeeze().tolist(), # Convert to list of int
        "labels": tokenized_output["input_ids"].squeeze().tolist() # Labels are the same as input_ids for causal LM
    }


def generate_predictions_for_eval(model_to_use, tokenizer_to_use, dataset_to_eval):
    predictions_list = []
    ground_truths_list = []
    
    print(f"\nGenerating predictions for {len(dataset_to_eval)} validation samples...")
    for example in tqdm(dataset_to_eval, desc="Validation Prediction"):
        # When evaluating, 'example' still comes from the raw dataset
        comment_text = example['content']
        ground_truth_output = example['output']

        user_message_content = f"{OPTIMIZED_PROMPT_CORE_INSTRUCTIONS}\n\n{FEW_SHOT_EXAMPLES_TEXT}\n\n[待处理文本]\n{comment_text}"
        
        prompt_for_model = tokenizer_to_use.apply_chat_template(
            [
                {"role": "system", "content": OPTIMIZED_PROMPT_SYSTEM_MESSAGE},
                {"role": "user", "content": user_message_content}
            ],
            tokenize=False,
            add_generation_prompt=True
        )
        
        inputs = tokenizer_to_use(prompt_for_model, return_tensors="pt", padding=False, truncation=True, max_length=config.MAX_SEQ_LENGTH).to(model_to_use.device)
        input_ids_len = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = model_to_use.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512, 
                pad_token_id=tokenizer_to_use.pad_token_id if tokenizer_to_use.pad_token_id is not None else tokenizer_to_use.eos_token_id,
                eos_token_id=tokenizer_to_use.eos_token_id,
                do_sample=True, temperature=0.6, top_p=0.9,
            )
        
        generated_ids = outputs[0][input_ids_len:]
        assistant_response = tokenizer_to_use.decode(generated_ids, skip_special_tokens=True).strip()

        if not assistant_response:
            assistant_response = "NULL | NULL | non-hate | non-hate [END]"
        elif not assistant_response.strip().endswith("[END]"):
            assistant_response = assistant_response.strip() + " [END]"

        predictions_list.append(assistant_response)
        ground_truths_list.append(ground_truth_output)
        
    return predictions_list, ground_truths_list


if __name__ == "__main__":
    # --- Global Tokenizer Initialization ---
    tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # tokenizer.pad_token_id = tokenizer.eos_token_id # SFTTrainer handles this


    # --- Load and Split Dataset ---
    try:
        full_dataset = load_dataset('json', data_files=config.TRAIN_FILE, split="train")
    except Exception as e:
        print(f"Error loading training data from {config.TRAIN_FILE}: {e}")
        exit()

    train_dataset_split = None
    eval_dataset_split = None

    if config.VALIDATION_SPLIT_RATIO > 0 and config.VALIDATION_SPLIT_RATIO < 1:
        print(f"Splitting dataset with ratio: {config.VALIDATION_SPLIT_RATIO}")
        if len(full_dataset) == 0:
            print("Error: Training dataset is empty. Cannot split.")
            exit()
        
        # Calculate minimum number of samples for validation
        min_validation_samples = 1 
        
        # Check if there are enough samples for a meaningful split
        if len(full_dataset) > min_validation_samples and (len(full_dataset) * config.VALIDATION_SPLIT_RATIO) >= min_validation_samples:
            split_dataset = full_dataset.train_test_split(test_size=config.VALIDATION_SPLIT_RATIO, shuffle=True, seed=42)
            train_dataset_split = split_dataset['train']
            eval_dataset_split = split_dataset['test']
            print(f"Training samples: {len(train_dataset_split)}, Validation samples: {len(eval_dataset_split)}")
        else:
            print(f"Warning: Dataset too small for specified validation split ratio ({config.VALIDATION_SPLIT_RATIO}). Using full dataset for training without SFT validation.")
            train_dataset_split = full_dataset
            eval_dataset_split = None
    else:
        train_dataset_split = full_dataset
        eval_dataset_split = None
        print(f"Using full dataset for training: {len(train_dataset_split)} samples. No validation split during SFT.")


    # --- Pre-process datasets with formatting_func to add 'input_ids' and 'labels' columns ---
    print("Applying formatting_func to training dataset...")
    # `remove_columns` should remove original columns like 'content', 'output', 'id'
    # We only want 'input_ids', 'attention_mask', 'labels'
    train_dataset_processed = train_dataset_split.map(
        sft_formatting_func, 
        remove_columns=train_dataset_split.column_names, # Remove all original columns
        # num_proc=os.cpu_count() or 1, # Uncomment and adjust num_proc if you have enough CPU cores
        desc="Formatting train dataset"
    )

    eval_dataset_processed = None
    if eval_dataset_split and len(eval_dataset_split) > 0:
        print("Applying formatting_func to evaluation dataset...")
        eval_dataset_processed = eval_dataset_split.map(
            sft_formatting_func, 
            remove_columns=eval_dataset_split.column_names, # Remove all original columns
            # num_proc=os.cpu_count() or 1, # Uncomment and adjust num_proc if you have enough CPU cores
            desc="Formatting eval dataset"
        )
    # else: eval_dataset_processed remains None

    # --- Model Configuration (QLoRA) ---
    compute_dtype = getattr(torch, config.BNB_4BIT_COMPUTE_DTYPE)
    bnb_config = None
    if config.USE_4BIT_QUANTIZATION:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=config.BNB_4BIT_QUANT_TYPE,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=False,
        )

    # --- Load Base Model ---
    model = AutoModelForCausalLM.from_pretrained(
        config.BASE_MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=compute_dtype if not config.USE_4BIT_QUANTIZATION else None,
        device_map={"": 0}, 
        trust_remote_code=True
    )
    model.config.use_cache = False 
    if hasattr(model.config, "pretraining_tp"): model.config.pretraining_tp = 1

    # --- PEFT Configuration ---
    peft_config = LoraConfig(
        r=config.LORA_R,
        lora_alpha=config.LORA_ALPHA,
        lora_dropout=config.LORA_DROPOUT,
        target_modules=config.LORA_TARGET_MODULES,
        bias="none",
        task_type="CAUSAL_LM",
    )
    if config.USE_4BIT_QUANTIZATION:
        model = prepare_model_for_kbit_training(model)
    
    if not isinstance(model, PeftModel): 
        model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # --- Training Arguments ---
    if len(train_dataset_split) == 0: # Use original split for total samples calculation
        print("Error: Training dataset is empty after splitting (or was empty initially). Cannot proceed.")
        exit()

    total_steps_approx = (len(train_dataset_split) // (config.TRAIN_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS)) * config.NUM_TRAIN_EPOCHS
    actual_save_steps = int(total_steps_approx * config.SAVE_STEPS_RATIO) if config.SAVE_STEPS_RATIO > 0 and total_steps_approx > 0 else config.LOGGING_STEPS * 5
    if actual_save_steps == 0: actual_save_steps = max(1, config.LOGGING_STEPS) 

    training_arguments = TrainingArguments(
        output_dir=config.OUTPUT_DIR,
        per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
        optim="paged_adamw_8bit" if config.USE_4BIT_QUANTIZATION else "adamw_torch",
        learning_rate=config.LEARNING_RATE,
        num_train_epochs=config.NUM_TRAIN_EPOCHS,
        logging_steps=config.LOGGING_STEPS,
        save_strategy="steps",
        save_steps=actual_save_steps,
        save_total_limit=2,
        fp16=False, 
        bf16=True if config.BNB_4BIT_COMPUTE_DTYPE == "bfloat16" and torch.cuda.is_bf16_supported() else False,
        evaluation_strategy="steps" if eval_dataset_split and len(eval_dataset_split) > 0 else "no",
        eval_steps=actual_save_steps if eval_dataset_split and len(eval_dataset_split) > 0 else None,
        report_to="tensorboard", 
        remove_unused_columns=False, # Now that we explicitly map, keep this False
    )

    # --- Initialize SFTTrainer ---
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_arguments,
        train_dataset=train_dataset_processed, # Pass the pre-processed dataset with 'input_ids' and 'labels'
        eval_dataset=eval_dataset_processed if eval_dataset_processed and len(eval_dataset_processed) > 0 else None, 
        peft_config=peft_config if isinstance(model, PeftModel) else None, 
        # Removed formatting_func: it's now applied during map
        # Removed dataset_text_field: no longer needed as formatting_func handles tokenization
        max_seq_length=config.MAX_SEQ_LENGTH, # Still relevant for data collator's padding/truncation
        packing=False, # Make sure packing is False if you are padding to max_length
        callbacks=[TqdmProgressCallback()]
    )

    # --- Train ---
    print("\nStarting SFT training...")
    trainer.train()

    # --- Save Final Model (LoRA adapter) ---
    final_adapter_dir = os.path.join(config.OUTPUT_DIR, "final_lora_adapter")
    trainer.model.save_pretrained(final_adapter_dir)
    tokenizer.save_pretrained(final_adapter_dir) 
    print(f"\nFinal LoRA adapter and tokenizer saved to: {final_adapter_dir}")

    # --- Plot Training History ---
    if trainer.state.log_history:
        plot_training_history(
            log_history=trainer.state.log_history,
            output_dir=config.TRAINING_PLOTS_DIR,
            train_dataset_size=len(train_dataset_split), # Use original train dataset size for steps calculation
            train_batch_size=config.TRAIN_BATCH_SIZE,
            gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS
        )
    else:
        print("No log history found to plot training loss.")

    # --- Evaluate on Validation Set (if it exists) and Plot Metrics ---
    # For evaluation, we still need 'content' and 'output' from the ORIGINAL dataset.
    # So we use eval_dataset_split, not eval_dataset_processed.
    if eval_dataset_split and len(eval_dataset_split) > 0:
        print("\n--- Evaluating on Validation Split Post-Training ---")
        validation_predictions, validation_ground_truths = generate_predictions_for_eval(
            trainer.model, tokenizer, eval_dataset_split # Pass original eval_dataset_split
        )
        
        os.makedirs(os.path.dirname(config.VALIDATION_PREDICTIONS_FILE), exist_ok=True)
        with open(config.VALIDATION_PREDICTIONS_FILE, 'w', encoding='utf-8') as f_val_pred:
            for pred_str in validation_predictions:
                f_val_pred.write(pred_str + '\n')
        print(f"Validation split predictions saved to {config.VALIDATION_PREDICTIONS_FILE}")

        parsed_pred_quads_lists = [parse_output_line(line) for line in validation_predictions]
        parsed_gt_quads_lists = [parse_output_line(line) for line in validation_ground_truths]

        val_results = calculate_f1_metrics_from_lists(parsed_pred_quads_lists, parsed_gt_quads_lists)
        
        if val_results:
            print("\nValidation Set Evaluation Results:")
            print(f"  Hard Match -> F1: {val_results['hard_match']['f1']:.4f}, P: {val_results['hard_match']['precision']:.4f}, R: {val_results['hard_match']['recall']:.4f}")
            print(f"  Soft Match -> F1: {val_results['soft_match']['f1']:.4f}, P: {val_results['soft_match']['precision']:.4f}, R: {val_results['soft_match']['recall']:.4f}")
            print(f"  Overall Score (Avg F1): {val_results['overall_score']:.4f}")
            
            plot_evaluation_scores(val_results, config.VALIDATION_EVAL_PLOTS_DIR, "validation_set_f1_scores.png")
        else:
            print("Validation evaluation could not be completed.")
    else:
        print("\nNo validation set was used or it was empty; skipping post-training validation evaluation.")

    print("\nTraining and validation evaluation (if applicable) complete.")

  warn(


Splitting dataset with ratio: 0.1
Training samples: 3600, Validation samples: 400
Applying formatting_func to training dataset...


Formatting train dataset:   0%|          | 0/3600 [00:00<?, ? examples/s]

Applying formatting_func to evaluation dataset...


Formatting eval dataset:   0%|          | 0/400 [00:00<?, ? examples/s]

trainable params: 3,784,704 || all params: 467,772,416 || trainable%: 0.8091


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  super().__init__(



Starting SFT training...


Training Steps:   0%|          | 0/900 [00:00<?, ?step/s]

  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss
450,0.3838,0.377134
900,0.394,0.370568


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]



Final LoRA adapter and tokenizer saved to: ./results_quad_extraction614/final_lora_adapter
Training/validation loss plot saved to ./results_quad_extraction614/training_plots/training_validation_loss.png

--- Evaluating on Validation Split Post-Training ---

Generating predictions for 400 validation samples...


Validation Prediction:   0%|          | 0/400 [00:00<?, ?it/s]

Validation split predictions saved to ./results_quad_extraction614/validation_split_predictions.txt

Validation Set Evaluation Results:
  Hard Match -> F1: 0.0493, P: 0.0493, R: 0.0494
  Soft Match -> F1: 0.1672, P: 0.1670, R: 0.1674
  Overall Score (Avg F1): 0.1083
Validation evaluation F1 scores plot saved to ./results_quad_extraction614/validation_evaluation_plots/validation_set_f1_scores.png

Training and validation evaluation (if applicable) complete.


In [2]:
import subprocess
import os

result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value