In [None]:
from processing.concept_generation import process_lichess_puzzles, perform_concept_detection_pytorch_probe, extract_and_save_activations_batched

In [None]:
# This cell runs the full probe-training pipeline. It requires a powerful GPU, so running it is not advised unless you need to retrain.

import os
import traceback
import pandas as pd
import json
import numpy as np
import torch

# --- Input artifacts ---
# CSV handed to process_lichess_puzzles() as `csv_path`.
LICHESS_DB_PATH = 'lichess_db_puzzle.csv'
# Model file handed to extract_and_save_activations_batched() as `model_path`.
ONNX_MODEL_PATH = 'prepared_network_fp16.onnx'

# --- Output layout: every artifact lives under CONCEPT_OUTPUT_DIR ---
CONCEPT_OUTPUT_DIR = "concepts_sequential"
# Per-concept "<concept>_pos.npz" / "<concept>_neg.npz" data files.
BASE_DATA_DIR = os.path.join(CONCEPT_OUTPUT_DIR, 'processed_data')
# One sub-directory per concept holding the per-layer activation .npy files.
ACTIVATIONS_DIR = os.path.join(CONCEPT_OUTPUT_DIR, 'activations')
# Summary CSV and detailed JSON of the probe evaluations.
RESULTS_DIR = os.path.join(CONCEPT_OUTPUT_DIR, 'probe_results')
# "<concept>_layer_<n>_probe.pth" probe files.
PROBE_SAVE_DIR = os.path.join(CONCEPT_OUTPUT_DIR, 'saved_probes')

# Create the root first, then each sub-directory; existing directories are fine.
for _output_dir in (
    CONCEPT_OUTPUT_DIR,
    BASE_DATA_DIR,
    ACTIVATIONS_DIR,
    RESULTS_DIR,
    PROBE_SAVE_DIR,
):
    os.makedirs(_output_dir, exist_ok=True)

# Concepts to probe for, processed in this order. Each name is passed to
# process_lichess_puzzles() as `target_tag` and doubles as the file-name
# prefix for that concept's data, activations and probes.
TARGET_CONCEPTS = [
    "fork",
    "pin",
    "skewer",
    "discoveredAttack",
    "sacrifice",
    "hangingPiece",
    "kingsideAttack",
    "exposedKing",
    "attraction",
    "deflection",
    "interference",
    "clearance",
    "intermezzo",
    "advancedPawn",
    "attackingF2F7",
    "capturingDefender",
    "doubleCheck",
    "queensideAttack",
    "trappedPiece",
    "defensiveMove",
    "quietMove",
    "mate",
]

print(f"Target Concepts: {TARGET_CONCEPTS}")

# --- Data / activation settings ---
# Layer indices whose activations are extracted and probed.
TARGET_LAYERS = [10]
# Passed to process_lichess_puzzles() as `limit`.
PUZZLE_LIMIT_PER_CLASS = 25_000
# Passed to extract_and_save_activations_batched() as `chunk_size`.
ACTIVATION_CHUNK_SIZE = 5_000

# --- Probe training settings (forwarded to perform_concept_detection_pytorch_probe) ---
TEST_SPLIT_RATIO = 0.2         # -> test_ratio
PROBE_EPOCHS = 150             # -> epochs
PROBE_PATIENCE = 10            # -> patience
PROBE_BATCH_SIZE = 512         # -> batch_size
PROBE_LR = 1e-4                # -> learning_rate
PROBE_CONV_HIDDEN = 32         # -> conv_hidden_channels
PROBE_DROPOUT = 0.4            # -> dropout_prob
PROBE_L2_DECAYS = [0.0, 1e-5]  # -> l2_weight_decays
RANDOM_SEED = 42               # -> random_seed

print(f"Target Layers: {TARGET_LAYERS}")

# Where the per-layer summary table and the full per-concept results are kept.
csv_save_path = os.path.join(RESULTS_DIR, "probe_evaluation_summary.csv")
json_save_path = os.path.join(RESULTS_DIR, "probe_evaluation_detailed.json")

# --- Resume support -------------------------------------------------------
# The summary CSV is the source of truth for which concepts are finished:
# every concept that appears in its 'concept' column is skipped by the main
# loop. Any problem reading it simply restarts from an empty state.
completed_concepts = set()
if os.path.exists(csv_save_path):
    print(f"--- Loading existing summary from: {csv_save_path} ---")
    try:
        summary_df = pd.read_csv(csv_save_path)
        completed_concepts = set(summary_df['concept'].unique())
        print(f"Found {len(completed_concepts)} completed concepts: {completed_concepts}")
    except Exception as e:
        # Best effort: an unreadable or malformed CSV must not abort the run.
        print(f"Error loading summary CSV, starting fresh: {e}")
        summary_df = pd.DataFrame()
        completed_concepts = set()
else:
    print("--- No existing summary file found, starting fresh. ---")
    summary_df = pd.DataFrame()

# The detailed JSON is keyed by concept name. Entries are kept only for
# concepts the CSV marks as completed. A mismatch between the two files no
# longer throws away the still-valid entries: concepts missing from the JSON
# are recovered from the saved probe files by the main loop, and JSON entries
# for concepts that are not completed are dropped because those concepts will
# be processed again anyway.
detailed_results = {}
if os.path.exists(json_save_path):
    print(f"--- Loading existing detailed results from: {json_save_path} ---")
    try:
        with open(json_save_path, 'r') as f:
            loaded_results = json.load(f)
        if not isinstance(loaded_results, dict):
            raise ValueError(
                f"expected a JSON object keyed by concept, got {type(loaded_results).__name__}"
            )
        concepts_in_json = set(loaded_results)
        if concepts_in_json != completed_concepts:
            print("Warning: Concepts in JSON and CSV differ. "
                  "Keeping only JSON entries for concepts completed in the CSV.")
        detailed_results = {
            name: layer_results
            for name, layer_results in loaded_results.items()
            if name in completed_concepts
        }
    except Exception as e:
        # Best effort: a corrupt JSON file only costs the cached details.
        print(f"Error loading detailed JSON, starting fresh: {e}")
        detailed_results = {}


def _remove_intermediate_files(file_paths, directory, verbose=False):
    """Best-effort removal of intermediate files and of `directory` if it ends up empty.

    Args:
        file_paths: Paths to delete; paths that do not exist are ignored.
        directory: Directory removed afterwards, only when it exists and is empty.
        verbose: When True, print a line for every successful removal.

    OSError from any single removal is printed and does not stop the rest.
    """
    for file_path in file_paths:
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                if verbose:
                    print(f"  Deleted: {file_path}")
        except OSError as err:
            print(f"  Error deleting {file_path}: {err}")
    try:
        if os.path.exists(directory) and not os.listdir(directory):
            os.rmdir(directory)
            if verbose:
                print(f"  Removed empty activation directory: {directory}")
    except OSError as err:
        print(f"  Error removing activation directory {directory}: {err}")


def _to_json_compatible(value):
    """`default` hook for json: turn NumPy values into native Python values.

    NumPy scalars become Python scalars and arrays become nested lists, so
    metrics are stored as JSON numbers rather than strings. Any other
    unsupported object is written as null.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return None


def _save_detailed_results(results, path):
    """Write `results` to `path` as indented JSON; return True on success.

    The payload is serialized completely before the file is opened, so a
    serialization error cannot leave a truncated file behind. Errors are
    printed rather than raised because saving is best-effort.
    """
    try:
        payload = json.dumps(results, indent=4, default=_to_json_compatible)
        with open(path, 'w') as f:
            f.write(payload)
    except Exception as e:
        print(f"  Error saving updated detailed JSON: {e}")
        return False
    print(f"  Successfully updated detailed JSON: {path}")
    return True


# Process the concepts one at a time so that only one concept's intermediate
# data is on disk at any moment:
#   1. generate the positive/negative NPZ data,
#   2. extract activations for TARGET_LAYERS,
#   3. train/evaluate the probes and save CSV + JSON incrementally,
#   4. delete the intermediate NPZ/NPY files.
for concept_name in TARGET_CONCEPTS:

    if concept_name in completed_concepts:
        print(f"\n--- Skipping already completed concept: {concept_name} ---")
        if concept_name not in detailed_results:
            # The CSV says this concept is done but its details are missing;
            # rebuild what we can from the saved probe files.
            print(f"  Attempting to recover metrics for {concept_name} from probe files...")
            recovered_layer_data = {}
            for layer_idx in TARGET_LAYERS:
                probe_path = os.path.join(PROBE_SAVE_DIR, f"{concept_name}_layer_{layer_idx}_probe.pth")
                if not os.path.exists(probe_path):
                    print(f"    Probe file not found: {probe_path}")
                    continue
                try:
                    # NOTE(review): torch.load unpickles the file; only load
                    # probe files produced by this pipeline.
                    probe_data = torch.load(probe_path, map_location='cpu')
                    # String key, matching what JSON round-trips produce.
                    recovered_layer_data[str(layer_idx)] = {
                        "validation_results": {
                            'best_avg_val_accuracy': probe_data.get('best_avg_val_accuracy', np.nan),
                            'best_reg_param': probe_data.get('best_reg_param'),
                            'reg_param_name': probe_data.get('best_reg_param_name')
                        },
                        "test_metrics": probe_data.get('test_metrics')
                    }
                except Exception as load_err:
                    print(f"    Error loading probe file {probe_path}: {load_err}")
            if recovered_layer_data:
                detailed_results[concept_name] = recovered_layer_data
                print(f"  Successfully recovered metrics for {concept_name}.")
                # Persist right away: if every remaining concept is also
                # skipped, nothing else would write the recovered data.
                _save_detailed_results(detailed_results, json_save_path)
        continue  # Move to the next concept in the list

    # --- Concept Not Completed - Proceed with Processing ---
    print(f"\n{'='*15} Processing Concept: {concept_name} {'='*15}")

    current_concept_failed = False
    save_path_prefix = f"{concept_name}"
    data_save_base = os.path.join(BASE_DATA_DIR, save_path_prefix)
    activation_dir_concept = os.path.join(ACTIVATIONS_DIR, save_path_prefix)
    os.makedirs(activation_dir_concept, exist_ok=True)

    pos_npz_path = f"{data_save_base}_pos.npz"
    neg_npz_path = f"{data_save_base}_neg.npz"

    # --- Step 1: Data Generation ---
    print(f"\n=== STEP 1: Data Processing for '{concept_name}' ===")
    try:
        # Data may already exist from a previous run that crashed before cleanup.
        if not (os.path.exists(pos_npz_path) and os.path.exists(neg_npz_path)):
            data_generated = process_lichess_puzzles(
                csv_path=LICHESS_DB_PATH,
                target_tag=concept_name,
                save_path_base=data_save_base,
                limit=PUZZLE_LIMIT_PER_CLASS
            )
            if not data_generated:
                print(f"Data generation failed for {concept_name}. Skipping.")
                current_concept_failed = True
        else:
            print(f"  Found existing NPZ data for {concept_name}. Skipping generation.")
    except Exception as e:
        print(f"Error during data generation check/run for {concept_name}: {e}")
        traceback.print_exc()
        current_concept_failed = True

    if current_concept_failed:
        continue

    # --- Step 2: Activation Extraction ---
    print(f"\n=== STEP 2: Activation Extraction for '{concept_name}' ===")
    success_pos, success_neg = False, False
    # Build the expected activation paths for EVERY target layer up front, so
    # both the existence check and the cleanup below see the complete set.
    pos_npy_files_paths = {
        layer_idx: os.path.join(
            activation_dir_concept, f"{save_path_prefix}_pos_layer_{layer_idx}_activations.npy"
        )
        for layer_idx in TARGET_LAYERS
    }
    neg_npy_files_paths = {
        layer_idx: os.path.join(
            activation_dir_concept, f"{save_path_prefix}_neg_layer_{layer_idx}_activations.npy"
        )
        for layer_idx in TARGET_LAYERS
    }
    all_activations_exist = all(
        os.path.exists(npy_path)
        for npy_path in (*pos_npy_files_paths.values(), *neg_npy_files_paths.values())
    )

    if all_activations_exist:
        print(f"  Found existing activation files for {concept_name}. Skipping extraction.")
        success_pos, success_neg = True, True
    else:
        print(f"  Did not find all activation files. Running extraction...")
        try:
            print(f"--- Extracting Positive Activations ---")
            success_pos, _ = extract_and_save_activations_batched(
                model_path=ONNX_MODEL_PATH, npz_data_path=pos_npz_path,
                target_layer_indices=TARGET_LAYERS, output_dir=activation_dir_concept,
                output_prefix=f"{save_path_prefix}_pos", max_samples=None,
                chunk_size=ACTIVATION_CHUNK_SIZE
            )
            if not success_pos:
                print(f"Positive activation extraction failed for {concept_name}. Skipping.")
                current_concept_failed = True

            if not current_concept_failed:
                print(f"--- Extracting Negative Activations ---")
                success_neg, _ = extract_and_save_activations_batched(
                    model_path=ONNX_MODEL_PATH, npz_data_path=neg_npz_path,
                    target_layer_indices=TARGET_LAYERS, output_dir=activation_dir_concept,
                    output_prefix=f"{save_path_prefix}_neg", max_samples=None,
                    chunk_size=ACTIVATION_CHUNK_SIZE
                )
                if not success_neg:
                    print(f"Negative activation extraction failed for {concept_name}. Skipping.")
                    current_concept_failed = True
        except Exception as e:
            print(f"Error during activation extraction for {concept_name}: {e}")
            traceback.print_exc()
            current_concept_failed = True

    # Everything below treats these as the concept's intermediate artifacts.
    files_to_delete = [
        pos_npz_path, neg_npz_path,
        *pos_npy_files_paths.values(), *neg_npy_files_paths.values(),
    ]

    if current_concept_failed:  # Extraction failed: drop partial artifacts and move on.
        print(f"--- Attempting cleanup for {concept_name} after failure ---")
        _remove_intermediate_files(files_to_delete, activation_dir_concept)
        continue

    # --- Step 3: Train & Evaluate Probes ---
    print(f"\n=== STEP 3: Probe Training & Evaluation for '{concept_name}' ===")
    concept_results = None
    try:
        concept_results = perform_concept_detection_pytorch_probe(
            target_layer_indices=TARGET_LAYERS, concept_name=save_path_prefix,
            activations_dir=activation_dir_concept,
            probe_save_dir=PROBE_SAVE_DIR, conv_hidden_channels=PROBE_CONV_HIDDEN,
            dropout_prob=PROBE_DROPOUT, test_ratio=TEST_SPLIT_RATIO,
            batch_size=PROBE_BATCH_SIZE, epochs=PROBE_EPOCHS,
            learning_rate=PROBE_LR, patience=PROBE_PATIENCE,
            l2_weight_decays=PROBE_L2_DECAYS,
            random_seed=RANDOM_SEED
        )

        # --- Step 3.5: Update Results and Save Incrementally ---
        if concept_results:
            detailed_results[concept_name] = concept_results
            # One summary row per layer returned for this concept.
            concept_summary_rows = []
            for layer_idx, results in concept_results.items():
                val_res = results.get("validation_results", {})
                test_res = results.get("test_metrics")
                val_acc = val_res.get('best_avg_val_accuracy', np.nan) if val_res else np.nan

                summary_row = {
                    'concept': concept_name, 'layer': layer_idx,
                    'best_val_acc': val_acc,
                    'test_accuracy': test_res['accuracy'] if test_res else np.nan,
                    'test_precision': test_res['precision'] if test_res else np.nan,
                    'test_recall': test_res['recall'] if test_res else np.nan,
                    'test_specificity': test_res['specificity'] if test_res else np.nan,
                    'test_f1': test_res['f1_score'] if test_res else np.nan,
                    'probe_file': os.path.join(PROBE_SAVE_DIR, f"{concept_name}_layer_{layer_idx}_probe.pth") if test_res else None
                }
                concept_summary_rows.append(summary_row)

            # Replace any stale rows for this concept, then save the table.
            if concept_summary_rows:
                new_rows_df = pd.DataFrame(concept_summary_rows)
                if not summary_df.empty and 'concept' in summary_df.columns:
                    summary_df = summary_df[summary_df['concept'] != concept_name]

                summary_df = pd.concat([summary_df, new_rows_df], ignore_index=True)
                try:
                    summary_df.to_csv(csv_save_path, index=False)
                    print(f"  Successfully updated summary CSV: {csv_save_path}")
                except Exception as e:
                    print(f"  Error saving updated summary CSV: {e}")

            # Save the updated detailed results JSON.
            _save_detailed_results(detailed_results, json_save_path)
        else:
            print(f"  Probe training function returned no results for {concept_name}.")
            current_concept_failed = True  # Treat as failure for cleanup

    except Exception as e:
        print(f"Error during probe training/evaluation for {concept_name}: {e}")
        traceback.print_exc()
        current_concept_failed = True

    # --- Step 4: Cleanup Intermediate Files ---
    # Runs whether or not probe training succeeded.
    print(f"\n=== STEP 4: Cleaning up intermediate files for '{concept_name}' ===")
    _remove_intermediate_files(files_to_delete, activation_dir_concept, verbose=True)

    print(f"{'='*15} Finished Concept: {concept_name} {'='*15}\n")


# === Final report ===
# The CSV and JSON files were already written incrementally inside the loop;
# this only echoes the in-memory summary table and the output locations.
print(f"\n\n{'=' * 20} FINAL SUMMARY (from final state) {'=' * 20}")
print(summary_df.to_string())
for _closing_line in (
    f"\nSummary saved to: {csv_save_path}",
    f"Detailed results saved to: {json_save_path}",
    "\nSequential probe training and evaluation complete.",
):
    print(_closing_line)