# The Adaptive Learner - Google Colab Version

This notebook implements the core functionality of The Adaptive Learner, a biologically-inspired, adaptive learning system for LLMs that prevents catastrophic forgetting through neuromodulated PEFT mechanisms and generative replay.

## Features
- Neuromodulation (γ-Gain): A dynamic task-conditioned signal that modulates LoRA plasticity
- AM-Managed PEFT: Task-specific LoRA modules managed by a router
- LoRA Consolidation: Techniques to prevent catastrophic forgetting
- Generative Replay: Data-level memory replay

In [1]:
# Cell 1: Environment Variable Setup

import os

# Process-wide environment switches. They are set in the very first cell so
# that they are already in place when later cells import torch / tokenizers.
os.environ.update({
    # Let the PyTorch CUDA caching allocator grow its existing segments instead
    # of being limited to the blocks it reserved initially; this can help avoid
    # out-of-memory errors in some situations.
    'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',

    # Silence the Hugging Face tokenizers fork warning. With DataLoader
    # num_workers > 0 the tokenizers library notices it is being used in forked
    # child processes and disables its own parallelism with a warning. Setting
    # the variable explicitly ('true' or 'false') suppresses the warning, and
    # 'false' is the safer choice when the DataLoader handles multiprocessing.
    'TOKENIZERS_PARALLELISM': 'false',
})

# Deliberately NOT set:
#   os.environ['CUDA_VERSION'] = '11.8'
# That override was a workaround for older bitsandbytes releases. With the
# CUDA 12.1 PyTorch wheels and a recent bitsandbytes it is likely unnecessary
# and could conflict with the CUDA version of the system/driver.
#
# Other global environment settings can be added to the mapping above. For
# example, when W&B credentials are not supplied through RunPods secrets:
#   os.environ['WANDB_API_KEY'] = 'your_wandb_api_key_here'  # not recommended if using secrets
#   os.environ['WANDB_PROJECT'] = 'your_wandb_project_name'
#   os.environ['WANDB_ENTITY'] = 'your_wandb_entity'

## Setup

First, we'll install the required dependencies:

In [2]:
# Cell 2: Setup - Installing Dependencies (Corrected and Enforced Order)

# Uninstall potentially conflicting versions first for a cleaner state.
!pip uninstall torch torchvision torchaudio -y
!pip uninstall transformers -y

# Install PyTorch and Torchvision first, from the specific CUDA 12.1 wheel.
!pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install flash-attn
!pip install -U flash-attn --no-build-isolation

# Now, install transformers and the rest of your dependencies.
!pip install -U transformers
!pip install -U peft datasets scikit-learn matplotlib wandb python-dotenv tqdm rouge-score accelerate bitsandbytes

Found existing installation: torch 2.5.1+cu121
Uninstalling torch-2.5.1+cu121:
  Successfully uninstalled torch-2.5.1+cu121
Found existing installation: torchvision 0.20.1+cu121
Uninstalling torchvision-0.20.1+cu121:
  Successfully uninstalled torchvision-0.20.1+cu121
Found existing installation: torchaudio 2.5.1+cu121
Uninstalling torchaudio-2.5.1+cu121:
  Successfully uninstalled torchaudio-2.5.1+cu121
[0mFound existing installation: transformers 4.51.3
Uninstalling transformers-4.51.3:
  Successfully uninstalled transformers-4.51.3
[0mLooking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Using cached https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl (780.4 MB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl (7.3 MB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86

In [3]:
# Cell 2.5: Login & Verification (Run after Cell 2 - Pip Installs)

import os
import logging # Use standard logging for this setup cell
import subprocess

# Dedicated logger for this verification cell. The notebook-wide logger is
# only configured in a later cell, so this one gets its own stream handler.
# The handler check keeps a re-run of the cell from attaching duplicates.
setup_logger = logging.getLogger(f"{__name__}_setup_verification")
if len(setup_logger.handlers) == 0:
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    setup_logger.setLevel(logging.INFO)
    setup_logger.addHandler(handler)

setup_logger.info("--- Verifying Hugging Face and W&B Logins ---")

# 1. Hugging Face Login Verification
#
# Checks that a Hugging Face token is present in the environment and that it
# is accepted, first through the `huggingface-cli whoami` command and then
# through `HfApi().whoami()`. Every outcome is only logged; nothing is raised,
# so the rest of the notebook can still run.
setup_logger.info("--- Hugging Face Verification ---")

# Upper bound (seconds) for the CLI subprocess. Without it a missing or stalled
# network connection would block this cell indefinitely.
HF_CLI_TIMEOUT_SECONDS = 60

hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
if hf_token:
    setup_logger.info("HF_TOKEN/HUGGINGFACE_HUB_TOKEN found in environment variables.")
    try:
        # check=False: a non-zero exit code is inspected below instead of raising.
        result = subprocess.run(
            ['huggingface-cli', 'whoami'],
            capture_output=True,
            text=True,
            check=False,
            timeout=HF_CLI_TIMEOUT_SECONDS,
        )
        if result.returncode == 0 and "ERROR" not in result.stdout.upper() and "Traceback" not in result.stdout:
            setup_logger.info(f"Hugging Face CLI whoami successful:\n{result.stdout.strip()}")
            # Second, programmatic check. The import lives inside this inner
            # try so that a missing huggingface_hub package is reported as a
            # programmatic-check problem rather than as a CLI failure.
            try:
                from huggingface_hub import HfApi
                api = HfApi()  # presumably picks the token up from the environment - TODO confirm
                user_info = api.whoami()
                setup_logger.info(f"Hugging Face programmatic whoami successful: User '{user_info.get('name')}'")
            except Exception as e_api:
                setup_logger.warning(f"Hugging Face programmatic whoami failed (token might be invalid for API access or other issue): {e_api}")
        else:
            setup_logger.warning(f"Hugging Face CLI whoami failed or indicated an issue. stdout: {result.stdout.strip()}, stderr: {result.stderr.strip()}")
            setup_logger.warning("Ensure your HF_TOKEN is valid and has permissions for the models you intend to use (e.g., Gemma).")
    except FileNotFoundError:
        # Raised by subprocess.run when the executable is not on PATH.
        setup_logger.error("huggingface-cli not found. Ensure transformers/huggingface_hub are correctly installed and in PATH.")
    except subprocess.TimeoutExpired:
        setup_logger.error(f"huggingface-cli whoami did not finish within {HF_CLI_TIMEOUT_SECONDS} seconds. Check network connectivity.")
    except Exception as e:
        setup_logger.error(f"An error occurred during Hugging Face CLI whoami check: {e}")
else:
    setup_logger.warning("HF_TOKEN/HUGGINGFACE_HUB_TOKEN NOT found in environment variables. Model downloads for gated repos will likely fail.")

# 2. W&B Login Verification
#
# Only the presence of the three W&B environment variables is checked and
# logged; no request is sent to W&B. A stricter check would be a call to
# wandb.login() wrapped in try/except, but that may write a .netrc entry or
# prompt for input when the key is wrong, so it is intentionally left out.
setup_logger.info("--- Weights & Biases Verification ---")
wandb_api_key = os.environ.get("WANDB_API_KEY")
wandb_project = os.environ.get("WANDB_PROJECT")
wandb_entity = os.environ.get("WANDB_ENTITY")

if not wandb_api_key:
    setup_logger.warning("WANDB_API_KEY NOT found in environment variables. W&B logging will be disabled or require manual login/key entry.")
else:
    setup_logger.info("WANDB_API_KEY found in environment variables.")

    # Project and entity are optional; a missing value only produces a warning.
    if not wandb_project:
        setup_logger.warning("WANDB_PROJECT NOT found in environment variables. Will need to be set in config or W&B UI.")
    else:
        setup_logger.info(f"WANDB_PROJECT found: {wandb_project}")

    if not wandb_entity:
        setup_logger.warning("WANDB_ENTITY NOT found in environment variables. Will use default or needs to be set in config/W&B UI.")
    else:
        setup_logger.info(f"WANDB_ENTITY found: {wandb_entity}")

setup_logger.info("--- Verification Cell Complete ---")

2025-05-20 16:15:09,233 - __main___setup_verification - INFO - --- Verifying Hugging Face and W&B Logins ---
2025-05-20 16:15:09,234 - __main___setup_verification - INFO - --- Hugging Face Verification ---
2025-05-20 16:15:09,235 - __main___setup_verification - INFO - HF_TOKEN/HUGGINGFACE_HUB_TOKEN found in environment variables.
2025-05-20 16:15:09,440 - __main___setup_verification - INFO - Hugging Face CLI whoami successful:
Galileo82
2025-05-20 16:15:09,640 - __main___setup_verification - INFO - Hugging Face programmatic whoami successful: User 'Galileo82'
2025-05-20 16:15:09,641 - __main___setup_verification - INFO - --- Weights & Biases Verification ---
2025-05-20 16:15:09,641 - __main___setup_verification - INFO - WANDB_API_KEY found in environment variables.
2025-05-20 16:15:09,642 - __main___setup_verification - INFO - WANDB_PROJECT found: RESTART01
2025-05-20 16:15:09,642 - __main___setup_verification - INFO - WANDB_ENTITY found: doingmyownthing82-none
2025-05-20 16:15:09,642 

In [4]:
# Cell X (or start of Cell 7) on RunPods
import os
# No more google.colab.drive

# Define your project's base path on the RunPods instance.
# This could be in /workspace (cleared when pod stops unless it's a persistent volume pod type)
# or on a /mnt/volume/ if you have persistent storage attached.
# NOTE: the same literal path is repeated as RUNPODS_PROJECT_BASE_PATH in the
# config cell; keep the two in sync when adjusting it.
gdrive_project_base_path = "/workspace/MyAdaptiveLearnerProject" # ADJUST THIS PATH
# The "gdrive" in these names is historical (left over from the Colab/Drive
# version); the path is a plain directory on the local filesystem.
outputs_on_gdrive_path = os.path.join(gdrive_project_base_path, "outputs") # Name is just historical
# Create the outputs directory up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(outputs_on_gdrive_path, exist_ok=True)
# The notebook-wide `logger` is only configured in the imports cell below, so
# this path is not logged here.

## Import Libraries

In [5]:
import os
import json
import logging
import sys # Import sys for stdout
import time
import io
import base64
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional, Union
from dataclasses import dataclass, field

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader # DataLoader was used here

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer
)

# Ensure these are present for datasets library
from datasets import load_dataset, DownloadMode, concatenate_datasets # << MAKE SURE THIS IS PRESENT AND UNCOMMENTED

from peft import (
    PeftModel,
    PeftConfig,
    PeftModelForCausalLM,
    get_peft_model,
    LoraConfig,
    TaskType
)

import wandb
from dotenv import load_dotenv
from tqdm.notebook import tqdm

# Optional Colab integration: `google.colab.userdata` gives access to Colab
# secrets. Outside Colab the import fails and all three names stay None, which
# is what later code checks before trying to read a secret.
userdata = None
WANDB_API_KEY_SECRET_NAME = None
logger_colab_secrets = None
try:
    from google.colab import userdata
except ImportError:
    pass
else:
    WANDB_API_KEY_SECRET_NAME = "WANDB_API_KEY"
    logger_colab_secrets = logging.getLogger(__name__ + ".colab_secrets")


# Configure logging for the whole notebook.
# force=True makes basicConfig remove (and close) any handlers already attached
# to the root logger before installing the new one. Without it, basicConfig is
# a no-op when the root logger already has handlers, which is the usual
# situation in a notebook kernel or when this cell is re-run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout,  # Force output to stdout so it shows up in the cell output
    force=True,
)

logger = logging.getLogger(__name__)
# Records must reach the root logger, because that is where the handler
# configured above lives.
logger.propagate = True
# The INFO level set on the root logger applies to `logger` as well; set a
# level on `logger` itself only if it should differ.

# Emit one record so a broken logging setup is visible immediately.
logger.info("--- TEST: Adaptive Learner Logging Initialized ---")


# Load environment variables if .env file exists.
# load_dotenv() is called with its defaults; presumably variables that are
# already set in the environment take precedence over the file - TODO confirm.
# Note the execution order: the login-verification cell above has already read
# its environment variables by the time this line runs, so values that exist
# only in .env are not visible to that cell.
load_dotenv()

# Set random seed for reproducibility
def set_seed(seed=123):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU generator plus all CUDA devices).

    Args:
        seed (int): Seed value applied to all generators. Defaults to 123.
    """
    # Local import: `random` is not among the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Harmless on CPU-only machines; applies to every visible GPU otherwise.
    torch.cuda.manual_seed_all(seed)

set_seed()

2025-05-20 16:15:14,066 - __main__ - INFO - --- TEST: Adaptive Learner Logging Initialized ---


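## Config (Cell 7)
`AdaptiveLearnerConfig` collects the notebook's hyperparameters and paths in a single dataclass. This revision adds the `feature_replay_alpha` field.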

In [6]:
# Cell 7: AdaptiveLearnerConfig (Modified for RunPods, Exp Tagging, SharedLoRA, etc.)
import os
import torch
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple, Optional, Union

# --- Define your RunPods project base path here ---
RUNPODS_PROJECT_BASE_PATH = "/workspace/MyAdaptiveLearnerProject" # <<< ENSURE THIS IS YOUR RUNPODS PATH
# --- End of RunPods path definition ---

@dataclass
class AdaptiveLearnerConfig:
    """Configuration class for The Adaptive Learner.

    Holds every tunable value of the system in one place. Constructing an
    instance has side effects (see ``__post_init__``): the output directory is
    created on disk, and W&B settings found in the environment override the
    values passed to the constructor.
    """
    # --- Base model ---
    model_name: str = "google/gemma-2b"
    max_seq_length: int = 512
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- LoRA adapter ---
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])

    # --- LoRA training ---
    batch_size: int = 2
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    num_lora_train_epochs: int = 3

    # --- Neuromodulation (gamma-gain) ---
    gamma_gain_lambda: float = 0.5
    gamma_gain_metrics: List[str] = field(default_factory=lambda: ["nll", "gradient_norm", "entropy", "accuracy"])
    gamma_gain_weights: Dict[str, float] = field(default_factory=lambda: {"nll": 0.3, "gradient_norm": 0.2, "entropy": 0.2, "accuracy": 0.3})

    # --- Router ---
    router_type: str = "linear"
    router_hidden_size: int = 64
    router_confidence_threshold: float = 0.7
    router_learning_rate: float = 1e-4
    router_training_epochs_per_call: int = 1

    # --- Consolidation ---
    consolidation_method: str = "ewc"
    aflora_importance_threshold: float = 0.01
    ewc_lambda: float = 100.0
    ewc_data_loader_num_samples: int = 5
    ewc_batch_size: int = 1
    ewc_fixed_lambda_bypass_gamma: bool = False

    # --- Replay ---
    replay_method: str = "cmt"
    replay_buffer_size: int = 500
    replay_alpha: float = 0.5
    feature_replay_alpha: float = 0.3

    replay_model_learning_rate: float = 5e-5
    replay_model_weight_decay: float = 1e-5
    replay_model_grad_clip_norm: float = 1.0
    replay_model_internal_batch_size: int = 2
    replay_backbone_encoding_batch_size: int = 8

    cmt_memory_size: int = 128
    cmt_compressor_layers: int = 2
    cmt_uniformity_weight: float = 0.1

    pcgr_latent_dim: int = 256
    pcgr_kl_weight: float = 0.1
    pcgr_prototype_update_freq: int = 100

    # --- Benchmarking / baselines ---
    run_sft_baseline: bool = False
    sft_lora_id: str = "sft_baseline_adapter"
    benchmark_tasks_to_run: List[Dict[str, Any]] = field(default_factory=list)
    num_train_samples_benchmark: int = 250
    num_val_samples_benchmark: int = 50

    # --- Output location and experiment bookkeeping ---
    output_dir: str = os.path.join(RUNPODS_PROJECT_BASE_PATH, "outputs")

    experiment_tag: Optional[str] = None
    load_router_state_from_tag: Optional[str] = None

    k_examples_for_prototype: int = 3
    num_tasks_to_share_lora: int = 0 # New field for Shared LoRA experiments

    # --- Task embeddings ---
    use_advanced_embeddings: bool = True
    k_for_grad_sketch: int = 3
    grad_sketch_layer_names: List[str] = field(default_factory=lambda: [
        "model.layers.2.self_attn.q_proj",
        "model.layers.4.mlp.gate_proj",
    ])
    grad_sketch_max_elements: int = 2048
    use_context_stats_avg_token_entropy: bool = True
    use_context_stats_avg_seq_length: bool = True

    # --- Weights & Biases ---
    use_wandb: bool = True
    wandb_project: str = "adaptive-learner"
    wandb_entity: Optional[str] = None
    # Secret: never included in the output of to_dict().
    wandb_api_key: Optional[str] = None

    num_workers: int = 0

    # Early Stopping Parameters for LoRA Training
    use_lora_early_stopping: bool = False
    lora_early_stopping_patience: int = 3
    lora_early_stopping_metric: str = "accuracy" # Metric to monitor on val set
    min_lora_epochs_before_early_stop: int = 1
    lora_early_stopping_delta: float = 0.001

    # Read by AdaptiveLearnerBackbone when loading the model (it previously
    # looked this attribute up with a default of True). Declared last so that
    # positional construction of existing fields is unaffected.
    use_flash_attention_2: bool = True


    def __post_init__(self):
        """Normalise paths and pull credentials/settings from the environment.

        Steps, in order:
          1. If the module-level RUNPODS_PROJECT_BASE_PATH exists and
             ``output_dir`` does not start with it, ``output_dir`` is replaced
             by ``<base>/outputs``.
          2. ``output_dir`` (and ``output_dir/<experiment_tag>`` when a tag is
             set) is created on disk.
          3. The W&B API key is resolved with this precedence: WANDB_API_KEY
             environment variable, then the value passed to the constructor,
             then Colab ``userdata`` (only when that global is available).
          4. WANDB_PROJECT / WANDB_ENTITY environment variables, when set,
             override the corresponding fields.
        """
        if 'RUNPODS_PROJECT_BASE_PATH' in globals() and not self.output_dir.startswith(globals()['RUNPODS_PROJECT_BASE_PATH']):
            self.output_dir = os.path.join(globals()['RUNPODS_PROJECT_BASE_PATH'], "outputs")
            print(f"INFO (AdaptiveLearnerConfig.__post_init__): Corrected output_dir to: {self.output_dir}")

        os.makedirs(self.output_dir, exist_ok=True)

        if self.experiment_tag:
            exp_general_output_dir = os.path.join(self.output_dir, self.experiment_tag)
            os.makedirs(exp_general_output_dir, exist_ok=True)

        env_wandb_api_key = os.environ.get("WANDB_API_KEY")
        if env_wandb_api_key:
            self.wandb_api_key = env_wandb_api_key
            print(f"INFO (AdaptiveLearnerConfig): Found WANDB_API_KEY in environment variables.")
        elif self.wandb_api_key:
            print(f"INFO (AdaptiveLearnerConfig): Using WANDB_API_KEY provided in config object.")
        elif 'userdata' in globals() and userdata is not None and 'WANDB_API_KEY_SECRET_NAME' in globals() and WANDB_API_KEY_SECRET_NAME is not None:
            # Colab fallback; both globals are bound by the imports cell.
            try:
                self.wandb_api_key = userdata.get(WANDB_API_KEY_SECRET_NAME)
                print(f"INFO (AdaptiveLearnerConfig): Fetched WANDB_API_KEY from Colab userdata.")
            except Exception as e:
                print(f"WARNING (AdaptiveLearnerConfig): Failed to get WANDB_API_KEY from Colab userdata: {e}")

        # The Hugging Face token is only reported here, never stored on the config.
        if os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"):
            print(f"INFO (AdaptiveLearnerConfig): Hugging Face token (HF_TOKEN/HUGGINGFACE_HUB_TOKEN) found in environment variables.")

        env_wandb_project = os.environ.get("WANDB_PROJECT")
        if env_wandb_project:
            self.wandb_project = env_wandb_project
            print(f"INFO (AdaptiveLearnerConfig): Found WANDB_PROJECT in environment variables: {self.wandb_project}")

        env_wandb_entity = os.environ.get("WANDB_ENTITY")
        if env_wandb_entity:
            self.wandb_entity = env_wandb_entity
            print(f"INFO (AdaptiveLearnerConfig): Found WANDB_ENTITY in environment variables: {self.wandb_entity}")

    def to_dict(self) -> Dict[str, Any]:
        """Return a serialisable snapshot of the public configuration values.

        The result is safe to log or write to disk:
          * values are deep-copied, so mutating the returned lists/dicts does
            not change this config instance;
          * ``wandb_api_key`` is always reported as ``None`` so the secret does
            not leak into experiment trackers or saved config files. The key
            stays available on the instance as ``self.wandb_api_key``.

        Returns:
            Dict[str, Any]: attribute name -> value for every attribute whose
            name does not start with an underscore.
        """
        # Local import: `copy` is not among the notebook's top-level imports.
        import copy

        snapshot = {
            name: copy.deepcopy(value)
            for name, value in self.__dict__.items()
            if not name.startswith('_')
        }
        if "wandb_api_key" in snapshot:
            snapshot["wandb_api_key"] = None
        return snapshot

# Module-level configuration built entirely from the defaults. Constructing it
# runs AdaptiveLearnerConfig.__post_init__, i.e. it creates the output
# directory and prints which W&B / Hugging Face settings were found in the
# environment.
default_config = AdaptiveLearnerConfig()

INFO (AdaptiveLearnerConfig): Found WANDB_API_KEY in environment variables.
INFO (AdaptiveLearnerConfig): Hugging Face token (HF_TOKEN/HUGGINGFACE_HUB_TOKEN) found in environment variables.
INFO (AdaptiveLearnerConfig): Found WANDB_PROJECT in environment variables: RESTART01
INFO (AdaptiveLearnerConfig): Found WANDB_ENTITY in environment variables: doingmyownthing82-none


## Backbone Model

The backbone module handles the base LLM and efficient processing.

In [7]:
class AdaptiveLearnerBackbone:
    """
    LLM Backbone module that handles the base LLM model and efficient input processing.

    On construction the model and tokenizer are loaded and every model
    parameter is frozen (``requires_grad = False``).
    """
    def __init__(self, config: AdaptiveLearnerConfig):
        """
        Initialize the LLM backbone

        Args:
            config (AdaptiveLearnerConfig): Configuration object
        """
        self.config = config
        self.device = torch.device(config.device)

        logger.info(f"Loading model: {config.model_name}")
        self.model = self._load_model()
        self.tokenizer = self._load_tokenizer()

        # Count parameters
        num_params = self._count_parameters()
        logger.info(f"Model loaded with {num_params:,} parameters")

        # Freeze base model parameters
        self._freeze_model_parameters()
        logger.info("Base model parameters frozen")

    def _resolve_model_name(self, announce: bool = False) -> str:
        """Return the identifier/path to load the model and tokenizer from.

        The CUSTOM_MODEL_PATH environment variable wins over
        ``config.model_name``, but only when it points to an existing path.

        Args:
            announce (bool): Log the custom path when it is used.

        Returns:
            str: Path from CUSTOM_MODEL_PATH, or ``config.model_name``.
        """
        custom_model_path = os.environ.get("CUSTOM_MODEL_PATH")
        if custom_model_path and os.path.exists(custom_model_path):
            if announce:
                logger.info(f"Using custom model path: {custom_model_path}")
            return custom_model_path
        return self.config.model_name

    def _quantized_load_kwargs(self) -> Dict[str, Any]:
        """Build the ``from_pretrained`` keyword arguments for 8-bit loading.

        8-bit quantization is requested through ``quantization_config`` rather
        than the bare ``load_in_8bit`` keyword, which transformers has
        deprecated in favour of ``BitsAndBytesConfig``.
        """
        return {
            "torch_dtype": torch.float16,
            "device_map": "auto",
            # Maintained for memory efficiency
            "quantization_config": transformers.BitsAndBytesConfig(load_in_8bit=True),
        }

    def _load_model(self) -> PreTrainedModel:
        """Load the LLM model, degrading gracefully when a feature is unavailable.

        Attempts, in order:
          1. 8-bit quantization with Flash Attention 2 (skipped when
             ``config.use_flash_attention_2`` is falsy; treated as True when
             the config has no such attribute);
          2. 8-bit quantization with the default attention implementation;
          3. plain fp16 on CUDA, or fp32 on CPU, without quantization.

        Every failure of steps 1 and 2 is logged as a warning and the next
        step is tried. An exception from step 3 propagates to the caller.

        Returns:
            PreTrainedModel: The loaded model.
        """
        model_name = self._resolve_model_name(announce=True)

        # Attempt to use Flash Attention 2 if available and configured
        use_flash_attn = getattr(self.config, 'use_flash_attention_2', True)
        if use_flash_attn:
            logger.info("Attempting to load model with Flash Attention 2.")
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    attn_implementation="flash_attention_2",
                    **self._quantized_load_kwargs(),
                )
                logger.info("Model loaded successfully with Flash Attention 2 and 8-bit quantization.")
                return model # Return early if successful
            except Exception as e_fa:
                logger.warning(f"Failed to load with Flash Attention 2 (Error: {e_fa}). Falling back.")

        # Fallback if Flash Attention 2 fails or is not requested
        try:
            logger.info("Attempting to load model with 8-bit quantization (Flash Attention not explicitly enabled or auto-detected).")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                **self._quantized_load_kwargs(),
            )
            logger.info("Model loaded in 8-bit quantization.")
        except Exception as e_8bit:
            logger.warning(f"8-bit loading also failed (Error: {e_8bit}). Falling back to fp16/fp32 without 8-bit.")
            cuda_available = torch.cuda.is_available()
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if cuda_available else torch.float32,
                device_map="auto" if cuda_available else None,
            )
            if not cuda_available:
                # No device_map was used, so place the model explicitly.
                model = model.to(self.device)
            logger.info("Model loaded in standard fp16/fp32.")
        return model

    def _load_tokenizer(self) -> PreTrainedTokenizer:
        """Load the tokenizer and make sure it has a padding token.

        When the tokenizer has no pad token, the first available of
        ``eos_token`` and ``unk_token`` is reused. Only when both are missing
        is a new ``'<|pad|>'`` token added.

        Returns:
            PreTrainedTokenizer: Tokenizer with ``pad_token`` set.
        """
        model_name = self._resolve_model_name()
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
            else:
                logger.warning(f"Tokenizer for {model_name} has no pad_token, eos_token, or unk_token. Adding a new pad_token '<|pad|>'.")
                tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
                # NOTE: the model's embedding matrix is deliberately NOT resized
                # here. model.resize_token_embeddings(len(tokenizer)) would make
                # base model parameters trainable, which conflicts with the
                # frozen-backbone / PEFT setup, so the new token id has no
                # embedding row of its own.
        return tokenizer

    def _freeze_model_parameters(self):
        """Freeze all model parameters"""
        for param in self.model.parameters():
            param.requires_grad = False

    def _count_parameters(self) -> int:
        """Count the number of parameters in the model"""
        return sum(p.numel() for p in self.model.parameters())

    def chunk_and_process(self, texts: List[str], max_chunk_size: Optional[int] = None) -> List[Dict[str, torch.Tensor]]:
        """
        Process text inputs by chunking them into manageable pieces

        Each text is tokenized without truncation and cut into consecutive,
        non-overlapping windows of at most ``max_chunk_size`` tokens.

        Args:
            texts (List[str]): List of input texts
            max_chunk_size (int, optional): Maximum chunk size (tokens). Defaults to config.max_seq_length.

        Returns:
            List[Dict[str, torch.Tensor]]: List of processed chunks. Each chunk
            holds 'input_ids' and 'attention_mask' for one window (moved to
            ``self.device``), plus 'text_idx' (position of the source text in
            ``texts``) and 'chunk_idx' (position of the window in that text).

        Raises:
            ValueError: If the effective ``max_chunk_size`` is not positive.
        """
        if max_chunk_size is None:
            max_chunk_size = self.config.max_seq_length
        # A negative size used to yield no chunks at all (the text was silently
        # dropped) and zero failed with an opaque range() error.
        if max_chunk_size <= 0:
            raise ValueError(f"max_chunk_size must be a positive integer, got {max_chunk_size}")

        tokenized_texts = [self.tokenizer(text, return_tensors="pt", truncation=False) for text in texts]
        chunks = []
        for text_idx, tokenized in enumerate(tokenized_texts):
            # [0]: drop the leading dimension produced by return_tensors="pt".
            input_ids = tokenized['input_ids'][0]
            attention_mask = tokenized['attention_mask'][0]
            num_tokens = len(input_ids)

            if num_tokens <= max_chunk_size:
                # Short (or empty) text: exactly one chunk, kept whole.
                chunks.append({
                    'input_ids': input_ids.to(self.device),
                    'attention_mask': attention_mask.to(self.device),
                    'text_idx': text_idx, 'chunk_idx': 0
                })
                continue

            for chunk_idx, start in enumerate(range(0, num_tokens, max_chunk_size)):
                # Slicing past the end is safe; the last window is just shorter.
                stop = start + max_chunk_size
                chunks.append({
                    'input_ids': input_ids[start:stop].to(self.device),
                    'attention_mask': attention_mask[start:stop].to(self.device),
                    'text_idx': text_idx, 'chunk_idx': chunk_idx
                })
        return chunks

    def generate(self,
                prompt: str,
                max_new_tokens: int = 100,
                temperature: float = 0.7,
                top_p: float = 0.9,
                do_sample: bool = True) -> str:
        """Generate text from a prompt

        Args:
            prompt (str): Text to continue.
            max_new_tokens (int): Upper bound on the number of generated tokens.
            temperature (float): Sampling temperature; only used when ``do_sample`` is True.
            top_p (float): Nucleus-sampling threshold; only used when ``do_sample`` is True.
            do_sample (bool): Sample when True; otherwise the sampling settings are not passed on.

        Returns:
            str: The newly generated text only (the prompt is stripped), with
            special tokens removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) # Ensure inputs are on same device as model

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if do_sample:
            # Sampling knobs are only meaningful when sampling; forwarding them
            # with do_sample=False makes transformers emit configuration warnings.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # outputs[0] contains prompt + continuation; keep the continuation only.
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return generated_text

    def get_embedding(self, text: str) -> torch.Tensor:
        """Get the embedding for a text input (last hidden state of last token)

        Args:
            text (str): Input text; it is tokenized as given.

        Returns:
            torch.Tensor: Hidden-state vector of the final token, taken from
            the last entry of the model's ``hidden_states``.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device) # Ensure inputs are on same device as model
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # hidden_states[-1]: last entry of the returned hidden states;
            # [0]: the first (and only) sequence in the batch; [-1]: its last token.
            last_hidden_state = outputs.hidden_states[-1][0][-1]
        return last_hidden_state

In [8]:
# Cell 8.5: Embedding Utility
import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Optional, Union
from transformers import PreTrainedModel, PreTrainedTokenizer

# This cell logs through the notebook-wide `logger`, which is normally created
# by the imports cell earlier in the notebook. If the cell is run on its own
# and no such global exists yet, build a minimal stand-in logger instead.
if 'logger' not in globals():
    import logging
    logger = logging.getLogger(f"{__name__}.prototype_util")
    # hasHandlers() is already true when this logger or one of its ancestors
    # has a handler; in that case nothing is attached, so re-running the cell
    # does not duplicate output.
    if not logger.hasHandlers():
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)


def get_task_example_prototype_embedding(
    raw_examples: List[Dict[str, str]],
    k: int,
    task_text_field: Union[str, List[str]], # Key(s) used to extract the input text
    backbone_model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: str,
    max_seq_length: int, # Truncation limit used when tokenizing
    config: Any # Using 'Any' for now, assuming it's AdaptiveLearnerConfig like object
) -> Optional[torch.Tensor]:
    """
    Generates a task example prototype embedding by averaging the embeddings of k examples.

    Each selected example is tokenized, passed through ``backbone_model`` with
    ``output_hidden_states=True``, and mean-pooled over its non-padding tokens
    (using the attention mask). The prototype is the mean of those pooled vectors.

    Args:
        raw_examples: A list of raw example dictionaries. Only the field(s) named by
                      ``task_text_field`` are read.
        k: The number of examples to use for the prototype (the first ``k`` are taken).
        task_text_field: The key (or list of keys) for the input text field(s) in raw_examples.
                         When a list is given, the parts are joined with " [SEP] ".
        backbone_model: The model used for encoding. It is switched to eval mode while
                        encoding and its previous training mode is always restored,
                        including when an exception propagates.
        tokenizer: The tokenizer for the backbone LLM.
        device: The device to run inferences on (e.g., 'cuda').
        max_seq_length: Maximum sequence length for tokenization.
        config: Object optionally providing ``replay_backbone_encoding_batch_size``;
                missing, ``None`` or non-positive values fall back to a batch size of 1.

    Returns:
        A torch.Tensor of shape [hidden_dim] on ``device``, or None when there is nothing
        to encode or when model inference / averaging fails.
    """
    if not raw_examples:
        logger.warning("get_task_example_prototype_embedding: No raw examples provided. Returning None.")
        return None
    if k <= 0:
        logger.warning(f"get_task_example_prototype_embedding: k must be > 0, but got {k}. Returning None.")
        return None

    # Take the first min(k, len(raw_examples)) examples for reproducibility.
    selected_examples = raw_examples[:min(k, len(raw_examples))]
    if not selected_examples: # Should not happen if raw_examples is not empty and k > 0
        logger.error("get_task_example_prototype_embedding: No examples selected. This is unexpected. Returning None.")
        return None

    logger.info(f"Generating prototype from {len(selected_examples)} examples (requested k={k}).")

    def _field_as_text(example: Dict[str, Any], field: str) -> str:
        # A missing key and an explicit None both mean "no text"; str(None) would
        # otherwise inject the literal word "None" into the prototype.
        value = example.get(field)
        return "" if value is None else str(value)

    example_input_texts = []
    for ex in selected_examples:
        if isinstance(task_text_field, list): # Handle sentence-pair tasks
            input_parts = [_field_as_text(ex, tf) for tf in task_text_field]
            # If every part is empty the joined string would be just separators,
            # so treat the whole example as empty instead.
            input_val = " [SEP] ".join(input_parts) if any(input_parts) else ""
        else: # Single sentence task
            input_val = _field_as_text(ex, task_text_field)

        if not input_val: # Skip if input text is empty
            logger.warning(f"Empty input text for example: {ex}. Skipping for prototype.")
            continue
        example_input_texts.append(input_val)

    if not example_input_texts:
        logger.warning("get_task_example_prototype_embedding: No valid input texts found in selected examples. Returning None.")
        return None

    # Batch size comes from the config when available; guard against None / <= 0.
    batch_size_for_encoding = getattr(config, 'replay_backbone_encoding_batch_size', 1)
    if not batch_size_for_encoding or batch_size_for_encoding <= 0:
        batch_size_for_encoding = 1

    # Remember the training mode so it can be restored no matter how we leave.
    original_training_state = backbone_model.training
    backbone_model.eval()

    all_example_embeddings = []
    try:
        with torch.no_grad():
            for i in range(0, len(example_input_texts), batch_size_for_encoding):
                batch_texts = example_input_texts[i:i + batch_size_for_encoding]

                inputs = tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding="longest", # Pad to the longest sequence in the current batch
                    truncation=True,
                    max_length=max_seq_length,
                    add_special_tokens=True
                ).to(device)

                try:
                    outputs = backbone_model(**inputs, output_hidden_states=True)
                    last_hidden_state_batch = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_dim]

                    # Mean pooling of token embeddings, respecting attention mask
                    attention_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden_state_batch.size()).float()
                    sum_embeddings = torch.sum(last_hidden_state_batch * attention_mask_expanded, 1) # Sum along seq_len
                    sum_mask = torch.clamp(attention_mask_expanded.sum(1), min=1e-9) # Avoid division by zero
                    pooled_embeddings_batch = sum_embeddings / sum_mask # [batch_size, hidden_dim]

                    all_example_embeddings.append(pooled_embeddings_batch)
                except Exception as e:
                    logger.error(f"Error during backbone model inference for prototype examples: {e}")
                    return None
    finally:
        # train(False) is equivalent to eval(), so this restores either mode.
        backbone_model.train(original_training_state)

    if not all_example_embeddings:
        logger.warning("get_task_example_prototype_embedding: No embeddings were generated from examples. Returning None.")
        return None

    # Concatenate embeddings from all batches and average them
    try:
        concatenated_embeddings = torch.cat(all_example_embeddings, dim=0) # Shape: [total_valid_examples, hidden_dim]
        prototype_embedding = torch.mean(concatenated_embeddings, dim=0) # Shape: [hidden_dim]
    except Exception as e:
        logger.error(f"Error concatenating or averaging example embeddings: {e}")
        return None

    logger.info(f"Successfully generated prototype embedding of shape: {prototype_embedding.shape}")
    return prototype_embedding.to(device) # Ensure it's on the correct device

In [9]:
# Cell 8.6: Advanced Embedding Helper Functions

import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Optional, Union, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer

if 'logger' not in globals():
    # No logger was defined by an earlier cell: build a cell-local fallback.
    import logging
    logger = logging.getLogger(f"{__name__}.advanced_embed_utils")
    if not logger.hasHandlers(): # Skip on re-run so handlers are not stacked
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

def generate_gradient_sketch(
    examples: List[Dict[str, str]], # Should contain 'input' and 'output' for loss calculation
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    config: Any, # AdaptiveLearnerConfig
    device: str
) -> Optional[torch.Tensor]:
    """
    Generates a gradient sketch for a few task examples.

    The first ``config.k_for_grad_sketch`` examples are run through ``model`` with
    labels, the loss is back-propagated, and the gradients of the ``.weight``
    parameters of the modules named in ``config.grad_sketch_layer_names`` are
    flattened, concatenated, and truncated or zero-padded to exactly
    ``config.grad_sketch_max_elements`` values.

    While running, the model is put in eval mode and every parameter except the
    targets has ``requires_grad`` disabled. The original ``requires_grad`` flags
    and training mode are always restored, including when an exception propagates.
    Gradients on the model are cleared once target parameters were found.

    Args:
        examples: Task examples passed to ``CausalLMTrainingDataset``.
        model: Model whose gradients are sketched.
        tokenizer: Tokenizer handed to ``CausalLMTrainingDataset``.
        config: Object providing ``use_advanced_embeddings``, ``k_for_grad_sketch``,
                ``grad_sketch_layer_names``, ``grad_sketch_max_elements`` and
                ``max_seq_length``.
        device: Device for the batches and for the returned tensor.

    Returns:
        A float32 tensor of length ``config.grad_sketch_max_elements`` on ``device``,
        or None when the feature is disabled, inputs are missing, or no target
        parameter could be resolved.
    """
    if not config.use_advanced_embeddings or not hasattr(config, 'k_for_grad_sketch') or config.k_for_grad_sketch <= 0:
        return None
    if not examples:
        logger.warning("GradientSketch: No examples provided.")
        return None
    if not hasattr(config, 'grad_sketch_layer_names') or not config.grad_sketch_layer_names:
        logger.warning("GradientSketch: No grad_sketch_layer_names defined in config.")
        return None

    selected_examples = examples[:min(config.k_for_grad_sketch, len(examples))]
    if not selected_examples:
        logger.warning("GradientSketch: No examples selected for sketch.")
        return None

    # --- Snapshot the state we are about to change so it can always be restored ---
    original_training_state = model.training
    original_requires_grad = {name: p.requires_grad for name, p in model.named_parameters()}

    target_params_with_grads = []
    target_param_names_found = []
    all_gradients_list = []
    accumulated_total_loss = 0.0 # Plain float: avoids keeping autograd graphs alive

    try:
        model.eval() # Ensure model is in eval mode

        # Set requires_grad to False for all, then True only for target layers
        model.requires_grad_(False)

        for name_part in config.grad_sketch_layer_names:
            # Layer names in config are module names; we target their '.weight'
            full_param_name = f"{name_part}.weight"
            try:
                param_module = model.get_submodule(name_part)
                if hasattr(param_module, 'weight') and isinstance(param_module.weight, torch.nn.Parameter):
                    param_module.weight.requires_grad_(True)
                    target_params_with_grads.append(param_module.weight)
                    target_param_names_found.append(full_param_name)
                else:
                    logger.warning(f"GradientSketch: Could not find/access weight for module: {name_part}")
            except AttributeError:
                logger.warning(f"GradientSketch: Module {name_part} not found in model.")

        if not target_params_with_grads:
            logger.error("GradientSketch: No target parameters found or set for gradient extraction. Aborting sketch generation.")
            return None # State is restored by the finally block

        logger.info(f"GradientSketch: Identified {len(target_params_with_grads)} target parameters for sketch: {target_param_names_found}")

        # Drop any stale gradients so the sketch reflects only these examples.
        for param in target_params_with_grads:
            param.grad = None

        # --- Process examples and accumulate gradients ---
        temp_dataset = CausalLMTrainingDataset(selected_examples, tokenizer, config.max_seq_length)
        temp_loader = torch.utils.data.DataLoader(temp_dataset, batch_size=min(len(selected_examples), config.k_for_grad_sketch), shuffle=False)

        for batch_data in temp_loader:
            input_ids = batch_data["input_ids"].to(device)
            attention_mask = batch_data["attention_mask"].to(device)
            labels = batch_data["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            if loss is not None and torch.isfinite(loss):
                accumulated_total_loss += loss.item() # Only used for logging
                loss.backward() # Accumulates gradients on target_params_with_grads
            else:
                logger.warning(f"GradientSketch: Loss was None or non-finite for a batch. Skipping backward for this batch.")

        # --- Extract gradients (as float32 on CPU so dtypes agree when concatenated) ---
        for param in target_params_with_grads:
            if param.grad is not None:
                all_gradients_list.append(param.grad.detach().to(device='cpu', dtype=torch.float32).flatten())
            else:
                # Param might not have received grad if it wasn't involved in loss computation for any example
                logger.warning(f"GradientSketch: Parameter {param.shape} had no gradient. Filling with zeros.")
                all_gradients_list.append(torch.zeros(param.numel(), dtype=torch.float32, device='cpu'))
    finally:
        # --- Cleanup runs on success, early abort, and on any exception ---
        if target_params_with_grads:
            model.zero_grad() # Zero out all gradients on the model
        for name, p in model.named_parameters(): # Restore original requires_grad states
            if name in original_requires_grad:
                p.requires_grad_(original_requires_grad[name])
        model.train(original_training_state) # train(False) == eval()

    if not all_gradients_list:
        logger.error("GradientSketch: No gradients were extracted.")
        return None

    concatenated_gradients = torch.cat(all_gradients_list)

    # Subsetting / Projection
    if concatenated_gradients.numel() == 0:
        logger.error("GradientSketch: Concatenated gradients are empty.")
        return None

    if concatenated_gradients.numel() > config.grad_sketch_max_elements:
        gradient_sketch = concatenated_gradients[:config.grad_sketch_max_elements]
    elif concatenated_gradients.numel() < config.grad_sketch_max_elements:
        padding_size = config.grad_sketch_max_elements - concatenated_gradients.numel()
        gradient_sketch = F.pad(concatenated_gradients, (0, padding_size), 'constant', 0)
    else:
        gradient_sketch = concatenated_gradients

    logger.info(f"GradientSketch: Generated sketch of shape {gradient_sketch.shape}, Loss sum: {accumulated_total_loss:.4f}")
    return gradient_sketch.to(device=device, dtype=torch.float32) # Ensure consistent dtype for router

def generate_context_stats(
    examples: List[Dict[str, str]], # 'input' field used
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel, # Frozen base model for token entropy
    config: Any, # AdaptiveLearnerConfig
    device: str
) -> Optional[torch.Tensor]:
    """
    Generates context statistics: average sequence length and average token entropy.

    Args:
        examples: Example dicts; the 'input' value of each selected example is tokenized.
        tokenizer: Tokenizer used to encode the inputs (padding to the longest, truncated
                   to ``config.max_seq_length``).
        model: Model whose logits are used for the entropy statistic. It is switched to
               eval mode for the forward pass and its previous training mode is always
               restored, including when an exception propagates.
        config: Object providing ``use_advanced_embeddings``, ``max_seq_length``,
                ``use_context_stats_avg_seq_length`` and
                ``use_context_stats_avg_token_entropy``. The number of examples is
                ``k_for_context_stats`` when set, otherwise ``k_for_grad_sketch``.
        device: Device for the forward pass and for the returned tensor.

    Returns:
        A 1-D float32 tensor on ``device`` holding the enabled statistics in the order
        [average sequence length, average token entropy], or None when the feature is
        disabled, no examples are available, or no statistic is enabled.
    """
    if not config.use_advanced_embeddings:
        return None
    if not examples:
        logger.warning("ContextStats: No examples provided.")
        return None

    stats = []
    # Only fall back to k_for_grad_sketch when k_for_context_stats is absent/None,
    # so a config that defines just k_for_context_stats does not raise.
    num_examples_for_stats = getattr(config, 'k_for_context_stats', None)
    if num_examples_for_stats is None:
        num_examples_for_stats = config.k_for_grad_sketch
    # Clamp at 0: a negative count must not turn into a "drop the tail" slice.
    selected_examples = examples[:max(0, min(num_examples_for_stats, len(examples)))]
    if not selected_examples:
        return None

    input_texts = [ex['input'] for ex in selected_examples]
    tokenized_inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=config.max_seq_length
    )

    # 1. Average Sequence Length (of actual tokens, not padding)
    if config.use_context_stats_avg_seq_length:
        sequence_lengths = tokenized_inputs['attention_mask'].sum(dim=1).float()
        avg_seq_len = sequence_lengths.mean().item()
        stats.append(avg_seq_len)

    # 2. Average Token Entropy (model-based predictive entropy)
    if config.use_context_stats_avg_token_entropy:
        original_training_state = model.training
        model.eval()
        token_entropies_batch = []

        try:
            with torch.no_grad():
                ids = tokenized_inputs['input_ids'].to(device)
                mask = tokenized_inputs['attention_mask'].to(device)

                outputs = model(input_ids=ids, attention_mask=mask)
                logits = outputs.logits # [batch_size, seq_len, vocab_size]

                # Entropy is averaged over the non-padded positions of each sequence.
                for b in range(logits.shape[0]):
                    seq_logits = logits[b, :, :] # [seq_len, vocab_size]
                    seq_mask = mask[b, :]       # [seq_len]

                    active_logits = seq_logits[seq_mask == 1] # [num_active_tokens, vocab_size]
                    if active_logits.shape[0] == 0:
                        continue

                    probs = F.softmax(active_logits.float(), dim=-1) # float for softmax stability
                    log_probs = F.log_softmax(active_logits.float(), dim=-1)
                    entropy_per_token = -torch.sum(probs * log_probs, dim=-1) # [num_active_tokens]

                    if entropy_per_token.numel() > 0:
                        token_entropies_batch.append(entropy_per_token.mean().item())
        finally:
            model.train(original_training_state) # train(False) == eval()

        # Plain-Python mean: the values are floats and this cell does not import numpy.
        if token_entropies_batch:
            avg_token_entropy_overall = sum(token_entropies_batch) / len(token_entropies_batch)
        else: # Default if no active tokens
            avg_token_entropy_overall = 0.0

        stats.append(avg_token_entropy_overall)

    # 3. Domain ID (Optional, not implemented yet)
    # if hasattr(config, 'num_domains_for_one_hot') and config.num_domains_for_one_hot > 0:
    #     # domain_id = get_domain_id_for_task(...) # Needs a helper
    #     # one_hot_domain = F.one_hot(torch.tensor(domain_id), num_classes=config.num_domains_for_one_hot)
    #     # stats.extend(one_hot_domain.float().tolist())
    #     pass

    if not stats:
        return None

    context_stats_tensor = torch.tensor(stats, dtype=torch.float32, device=device)
    logger.info(f"ContextStats: Generated stats vector: {context_stats_tensor.tolist()}")
    return context_stats_tensor

PEFT Manager and Router Classes (Updated for Training)

In [10]:
# Cell 10: PEFT Manager and Router Classes
# ROBUST PROFILE DILUTION FIX v3
# + ENHANCED DIAGNOSTIC LOGGING in LinearRouter save/load_router_state
# + MVE-Router-Playbook-Phase1 (Advanced Embeddings) modifications

import torch.optim as optim
import os
import torch
import json
from pathlib import Path
import time
from peft import PeftModel, LoraConfig, TaskType, get_peft_model, PeftConfig
from transformers import PreTrainedModel, AutoTokenizer
from typing import List, Dict, Any, Tuple, Optional, Union
from tqdm.notebook import tqdm
import torch.nn.functional as F
import torch.nn as nn
import numpy as np # Added for potential use in advanced features, e.g. context stats

# Ensure CausalLMTrainingDataset is available if used by helpers called from router
# (It's defined in Cell 16, so this assumes Cell 16 has run or it's defined globally)
# from YOUR_NOTEBOOK_FILE import CausalLMTrainingDataset # Or ensure it's defined before this cell

# Import helper functions for advanced embeddings (assuming they are in Cell 8.6)
# These will be called by the Router's get_advanced_task_representation
# from YOUR_NOTEBOOK_FILE import generate_gradient_sketch, generate_context_stats # Or ensure they are defined

if 'logger' not in globals():
    # No logger was defined by an earlier cell: build a cell-local fallback.
    import logging
    logger = logging.getLogger(f"{__name__}.peft_router_adv_embed")
    if not logger.hasHandlers(): # Skip on re-run so handlers are not stacked
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

class PEFTManager:
    def __init__(self, config: Any, model: PreTrainedModel):
        """Set up directories, build the router, optionally load its state, then scan disk.

        Args:
            config: Object providing ``device``, ``output_dir``, ``router_type`` and
                    ``load_router_state_from_tag``.
            model: The model that LoRA adapters will be attached to.

        Raises:
            ValueError: If ``config.router_type`` is neither "linear" nor "mlp".
        """
        self.config = config
        self.model = model
        self.device = torch.device(config.device)
        self.lora_modules_metadata: Dict[str, Dict[str, Any]] = {}
        self.active_lora_ids: List[str] = []

        # Both working directories live under the configured output directory.
        self.lora_modules_base_dir = os.path.join(self.config.output_dir, "lora_modules")
        os.makedirs(self.lora_modules_base_dir, exist_ok=True)
        self.router_states_dir = os.path.join(self.config.output_dir, "router_states")
        os.makedirs(self.router_states_dir, exist_ok=True)

        # Pick the router implementation named in the config.
        if self.config.router_type == "linear":
            self.router = LinearRouter(self.config, self.model)
        elif self.config.router_type == "mlp":
            self.router = MLPRouter(self.config, self.model)
        else:
            raise ValueError(f"Unknown router type: {self.config.router_type}")

        # Learnable profiles are counted on whichever container the router exposes.
        if hasattr(self.router, 'module_embeddings'):
            initial_profile_count = len(getattr(self.router, 'module_embeddings', {}))
        else:
            initial_profile_count = len(getattr(self.router, 'output_projection_layers', {}))
        logger.info(f"PEFTManager init: Router object instantiated. Initial learnable profiles in router: {initial_profile_count}")

        state_tag = self.config.load_router_state_from_tag
        if not state_tag:
            logger.info("PEFTManager init: No router state to load (load_router_state_from_tag is None). Router starts fresh.")
        else:
            state_path = os.path.join(self.router_states_dir, f"router_state_{state_tag}.pth")
            logger.info(f"PEFTManager init: Attempting to load router state from '{state_path}' directly into router.")
            if not os.path.exists(state_path):
                logger.warning(f"PEFTManager init: Specified router state file NOT FOUND at {state_path}. Router remains with its initial fresh parameters.")
            else:
                self.router.load_router_state(state_path)
                if hasattr(self.router, 'module_embeddings'):
                    loaded_profile_count = len(self.router.module_embeddings)
                else:
                    loaded_profile_count = len(self.router.output_projection_layers)
                logger.info(f"PEFTManager init: Router state loaded. Router now has {loaded_profile_count} learnable profile(s) from state file.")

        self._scan_disk_and_update_metadata_globally()

    def _scan_disk_and_update_metadata_globally(self):
        """Refresh ``self.lora_modules_metadata`` from the adapter directories on disk.

        Every sub-directory of ``self.lora_modules_base_dir`` that contains an
        ``adapter_config.json`` (either directly or in a nested directory of the same
        name) is recorded. Its ``metadata.json`` is used when readable; otherwise a
        basic placeholder dict is stored. The router's ``module_metadata`` is updated
        only for ids the router already indexes and has a learnable profile for.
        """
        base_dir = self.lora_modules_base_dir
        logger.info(f"PEFTManager: (_scan_disk_and_update_metadata_globally) Scanning disk ({base_dir}) for LoRA adapter metadata.")
        found_on_disk_count = 0
        updated_router_metadata_count = 0

        if not os.path.exists(base_dir):
            logger.info("PEFTManager: LoRA modules base directory not found.")
            return

        # Only directories are candidate adapters.
        lora_parent_dirs = []
        for entry in os.listdir(base_dir):
            if os.path.isdir(os.path.join(base_dir, entry)):
                lora_parent_dirs.append(entry)

        if not lora_parent_dirs:
            logger.info("PEFTManager: No LoRA parent directories found on disk.")

        for lora_id in lora_parent_dirs:
            parent_path = os.path.join(base_dir, lora_id)
            metadata_path = os.path.join(parent_path, "metadata.json")
            # adapter_config.json may sit in a nested dir of the same name or in the parent dir
            nested_config_path = os.path.join(parent_path, lora_id, "adapter_config.json")
            parent_config_path = os.path.join(parent_path, "adapter_config.json")

            disk_metadata = None
            if os.path.exists(nested_config_path) or os.path.exists(parent_config_path):
                if not os.path.exists(metadata_path):
                    logger.warning(f"PEFTManager: Adapter {lora_id} on disk missing metadata.json. Using basic info.")
                    disk_metadata = {"adapter_name": lora_id, "source": "disk_scan_no_meta_json"}
                else:
                    try:
                        with open(metadata_path, "r") as f:
                            disk_metadata = json.load(f)
                        # The directory name is authoritative for the adapter name.
                        if disk_metadata.get("adapter_name") != lora_id:
                            disk_metadata["adapter_name"] = lora_id
                    except Exception as e:
                        logger.warning(f"PEFTManager: Error reading metadata.json for {lora_id}: {e}. Using basic.")
                        disk_metadata = {"adapter_name": lora_id, "source": "disk_scan_metadata_read_error"}

            if not disk_metadata:
                continue

            self.lora_modules_metadata[lora_id] = disk_metadata
            found_on_disk_count += 1

            # Update router's non-learnable metadata only if this LoRA already has a learnable profile
            if lora_id not in self.router.lora_id_to_idx:
                continue
            has_profile = hasattr(self.router, 'module_embeddings') and lora_id in self.router.module_embeddings
            if not has_profile:
                has_profile = hasattr(self.router, 'output_projection_layers') and lora_id in self.router.output_projection_layers
            if has_profile:
                self.router.module_metadata[lora_id] = disk_metadata
                updated_router_metadata_count += 1

        logger.info(f"PEFTManager: Disk scan finished. Found/updated metadata for {found_on_disk_count} LoRAs in PEFTManager's general records.")
        if updated_router_metadata_count > 0:
            logger.info(f"PEFTManager: Updated router's non-learnable metadata for {updated_router_metadata_count} LoRA profiles already actively managed by the router.")

        if hasattr(self.router, 'module_embeddings'):
            num_learnable_profiles_final = len(self.router.module_embeddings)
        else:
            num_learnable_profiles_final = len(self.router.output_projection_layers)
        logger.info(f"PEFTManager: After init and disk scan, router has {num_learnable_profiles_final} learnable profile(s).")

    def create_lora_module(self, tagged_lora_id: str, task_type: TaskType = TaskType.CAUSAL_LM) -> PreTrainedModel:
        """Attach a LoRA adapter named ``tagged_lora_id``, activate it and register it.

        The adapter is built from the LoRA settings on ``self.config``. Its metadata is
        stored in ``self.lora_modules_metadata``, passed to ``self.router.add_module``,
        and written to ``<lora_modules_base_dir>/<tagged_lora_id>/metadata.json``.

        Args:
            tagged_lora_id: Adapter name.
            task_type: PEFT task type used for the ``LoraConfig``.

        Returns:
            ``self.model`` with the adapter set as active.
        """
        cfg = self.config
        lora_config_obj = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_dropout=cfg.lora_dropout,
            target_modules=cfg.target_modules,
            bias="none",
            task_type=task_type,
        )

        if isinstance(self.model, PeftModel):
            # Already wrapped: add the adapter only if it is not there yet.
            if tagged_lora_id not in self.model.peft_config:
                self.model.add_adapter(tagged_lora_id, lora_config_obj)
        else:
            # First adapter: wrap the plain model.
            self.model = get_peft_model(self.model, lora_config_obj, adapter_name=tagged_lora_id)

        self.model.set_adapter(tagged_lora_id)
        self.active_lora_ids = self.model.active_adapters

        # Recover the un-tagged task id by stripping the experiment-tag prefix, if present.
        experiment_tag = self.config.experiment_tag
        if experiment_tag and f"{experiment_tag}_" in tagged_lora_id:
            original_task_id = tagged_lora_id.split(f"_{experiment_tag}_")[-1].split(f"{experiment_tag}_")[-1]
        else:
            original_task_id = tagged_lora_id

        if hasattr(task_type, 'value'):
            task_type_label = task_type.value
        else:
            task_type_label = str(task_type)

        metadata_for_lora = {
            "task_type": task_type_label,
            "created_at": time.time(),
            "adapter_name": tagged_lora_id,
            "original_task_id_if_available": original_task_id,
            "experiment_tag": self.config.experiment_tag,
            "config_details": {
                "r": self.config.lora_r,
                "lora_alpha": self.config.lora_alpha,
                "lora_dropout": self.config.lora_dropout,
                "target_modules": self.config.target_modules,
            },
        }
        self.lora_modules_metadata[tagged_lora_id] = metadata_for_lora
        self.router.add_module(tagged_lora_id, metadata_for_lora)

        specific_lora_module_dir = os.path.join(self.lora_modules_base_dir, tagged_lora_id)
        os.makedirs(specific_lora_module_dir, exist_ok=True)
        metadata_path = os.path.join(specific_lora_module_dir, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata_for_lora, f, indent=2)

        logger.info(f"PEFTManager: Created LoRA '{tagged_lora_id}' and called router.add_module.")
        return self.model

    def load_lora_module(self, tagged_lora_id_to_load: str, set_as_active: bool = True) -> Optional[PreTrainedModel]:
        """Load a saved LoRA adapter onto ``self.model`` and register it with the router.

        The adapter config is looked up first in the nested directory
        ``<base>/<id>/<id>/`` and then in ``<base>/<id>/``. Metadata for the router is
        taken from ``self.lora_modules_metadata`` if present, otherwise from
        ``metadata.json``. If that file is missing, unreadable, or does not contain a
        JSON object, a basic metadata dict is reconstructed instead of raising, which
        mirrors how the disk scan handles the same situation.

        Args:
            tagged_lora_id_to_load: Adapter name (also its directory name).
            set_as_active: When True, make this adapter the active one.

        Returns:
            ``self.model`` on success, or None when the adapter config is not found or
            loading the adapter weights fails.
        """
        logger.info(f"PEFTManager.load_lora_module: Attempting to load LoRA weights for '{tagged_lora_id_to_load}'.")
        lora_parent_dir = os.path.join(self.lora_modules_base_dir, tagged_lora_id_to_load)
        actual_adapter_files_dir = os.path.join(lora_parent_dir, tagged_lora_id_to_load)

        # Prefer the nested layout, then fall back to the parent directory.
        path_to_try_loading_from = None
        for candidate_dir in (actual_adapter_files_dir, lora_parent_dir):
            if os.path.exists(os.path.join(candidate_dir, "adapter_config.json")):
                path_to_try_loading_from = candidate_dir
                break
        if path_to_try_loading_from is None:
            logger.warning(f"PEFTManager.load_lora_module: Adapter config NOT found for '{tagged_lora_id_to_load}'. Cannot load weights.")
            return None

        if not isinstance(self.model, PeftModel):
            # First adapter: wrap the plain model.
            try:
                self.model = PeftModel.from_pretrained(self.model, path_to_try_loading_from, adapter_name=tagged_lora_id_to_load, is_trainable=True)
            except Exception as e:
                logger.error(f"PEFTManager.load_lora_module: Failed to load '{tagged_lora_id_to_load}' as first adapter: {e}")
                return None
        elif tagged_lora_id_to_load not in self.model.peft_config:
            try:
                self.model.load_adapter(path_to_try_loading_from, adapter_name=tagged_lora_id_to_load, is_trainable=True)
            except Exception as e:
                logger.error(f"PEFTManager.load_lora_module: Failed to load adapter weights '{tagged_lora_id_to_load}': {e}")
                return None
        else:
            logger.info(f"PEFTManager.load_lora_module: Adapter weights for '{tagged_lora_id_to_load}' already exist on PeftModel.")

        loaded_metadata_for_router = self.lora_modules_metadata.get(tagged_lora_id_to_load)
        if not loaded_metadata_for_router:
            metadata_file_path = os.path.join(lora_parent_dir, "metadata.json")
            needs_reconstruction = True
            if os.path.exists(metadata_file_path):
                try:
                    with open(metadata_file_path, "r") as f:
                        loaded_metadata_for_router = json.load(f)
                    if not isinstance(loaded_metadata_for_router, dict):
                        raise ValueError("top-level JSON value is not an object")
                    needs_reconstruction = False
                except (OSError, ValueError) as e:
                    # JSON decode and text decode errors are ValueError subclasses.
                    # The weights are already loaded, so degrade to basic metadata
                    # rather than abort.
                    logger.warning(f"PEFTManager.load_lora_module: Error reading metadata.json for {tagged_lora_id_to_load}: {e}. Reconstructing basic for router.")
            else:
                logger.warning(f"PEFTManager.load_lora_module: metadata.json not found for {tagged_lora_id_to_load}. Reconstructing basic for router.")

            if needs_reconstruction:
                lora_conf_on_model = self.model.peft_config.get(tagged_lora_id_to_load)
                if lora_conf_on_model:
                    # Recover the un-tagged task id by stripping the experiment-tag prefix, if present.
                    experiment_tag = self.config.experiment_tag
                    if experiment_tag and f"{experiment_tag}_" in tagged_lora_id_to_load:
                        original_task_id = tagged_lora_id_to_load.split(f"_{experiment_tag}_")[-1].split(f"{experiment_tag}_")[-1]
                    else:
                        original_task_id = tagged_lora_id_to_load
                    loaded_metadata_for_router = {
                        "adapter_name": tagged_lora_id_to_load,
                        "source": "load_lora_module_reconstructed",
                        "original_task_id_if_available": original_task_id,
                        "experiment_tag": self.config.experiment_tag,
                    }
                else:
                    loaded_metadata_for_router = {"adapter_name": tagged_lora_id_to_load, "source": "error_no_meta_no_conf_on_load"}
            self.lora_modules_metadata[tagged_lora_id_to_load] = loaded_metadata_for_router

        if loaded_metadata_for_router:
            self.router.add_module(tagged_lora_id_to_load, loaded_metadata_for_router)
        else:
            logger.error(f"PEFTManager.load_lora_module: Could not obtain or reconstruct metadata for {tagged_lora_id_to_load}. Not adding to router.")

        if set_as_active:
            self.model.set_adapter(tagged_lora_id_to_load)
            self.active_lora_ids = self.model.active_adapters

        logger.info(f"PEFTManager.load_lora_module: Successfully processed (weights loaded/confirmed) LoRA adapter '{tagged_lora_id_to_load}'")
        return self.model

    def select_lora_for_input(self, task_description: str, task_examples_for_prototype: Optional[List[Dict[str, Any]]] = None, task_text_field: Optional[Union[str, List[str]]] = None, top_k: int = 1) -> List[Tuple[str, float]]:
        """Delegate LoRA selection to the router.

        Args:
            task_description: Text describing the task, forwarded unchanged.
            task_examples_for_prototype: Optional examples; forwarded to the router
                under the keyword ``task_examples_for_feature_gen``.
            task_text_field: Optional key(s) naming the input text field(s), forwarded unchanged.
            top_k: Forwarded unchanged as the router's ``top_k``.

        Returns:
            Whatever ``self.router.select_modules`` returns, or an empty list when no
            router is set.
        """
        if self.router is not None:
            return self.router.select_modules(
                task_description=task_description,
                task_examples_for_feature_gen=task_examples_for_prototype,
                task_text_field=task_text_field,
                top_k=top_k,
            )

        logger.warning("PEFTManager: Router not initialized. Cannot select LoRA.")
        return []

    def save_lora_module(self, tagged_lora_id_to_save: str, peft_model_instance: PeftModel):
        """Save one adapter's weights and make sure a ``metadata.json`` sits beside them.

        The adapter is written with ``save_pretrained`` into
        ``<lora_modules_base_dir>/<id>/``. If no ``metadata.json`` exists there yet, one
        is written from ``self.lora_modules_metadata`` or, failing that, reconstructed
        from the adapter's PEFT config. Any exception raised while saving is logged and
        not re-raised.

        Args:
            tagged_lora_id_to_save: Name of the adapter to save.
            peft_model_instance: PEFT model that holds the adapter.
        """
        is_saveable = isinstance(peft_model_instance, PeftModel) and tagged_lora_id_to_save in peft_model_instance.peft_config
        if not is_saveable:
            logger.error(f"Cannot save adapter '{tagged_lora_id_to_save}'. Not a PeftModel or adapter not found in config.")
            return

        lora_parent_dir = os.path.join(self.lora_modules_base_dir, tagged_lora_id_to_save)
        os.makedirs(lora_parent_dir, exist_ok=True)

        try:
            peft_model_instance.save_pretrained(lora_parent_dir, selected_adapters=[tagged_lora_id_to_save])
            logger.info(f"Saved LoRA adapter '{tagged_lora_id_to_save}' (likely into a nested subdir within {lora_parent_dir})")

            metadata_file_path = os.path.join(lora_parent_dir, "metadata.json")
            if os.path.exists(metadata_file_path):
                return # An existing metadata.json is left untouched

            meta_to_save = self.lora_modules_metadata.get(tagged_lora_id_to_save)
            if not meta_to_save:
                # Nothing cached: rebuild what we can from the adapter's PEFT config.
                lora_conf = peft_model_instance.peft_config.get(tagged_lora_id_to_save)
                if not lora_conf:
                    meta_to_save = {}
                else:
                    experiment_tag = self.config.experiment_tag
                    if experiment_tag and f"{experiment_tag}_" in tagged_lora_id_to_save:
                        original_task_id = tagged_lora_id_to_save.split(f"_{experiment_tag}_")[-1].split(f"{experiment_tag}_")[-1]
                    else:
                        original_task_id = tagged_lora_id_to_save

                    # Make target_modules JSON-serializable.
                    if isinstance(lora_conf.target_modules, (list, tuple, set)):
                        target_modules_value = list(lora_conf.target_modules)
                    else:
                        target_modules_value = str(lora_conf.target_modules)

                    meta_to_save = {
                        "task_type": str(lora_conf.task_type),
                        "saved_at": time.time(),
                        "adapter_name": tagged_lora_id_to_save,
                        "original_task_id_if_available": original_task_id,
                        "experiment_tag": self.config.experiment_tag,
                        "config_details": {
                            "r": lora_conf.r,
                            "lora_alpha": lora_conf.lora_alpha,
                            "lora_dropout": lora_conf.lora_dropout,
                            "target_modules": target_modules_value,
                        },
                    }

            if not meta_to_save:
                logger.warning(f"Could not find or reconstruct metadata for adapter '{tagged_lora_id_to_save}'. metadata.json not saved.")
            else:
                with open(metadata_file_path, "w") as f:
                    json.dump(meta_to_save, f, indent=2)
                logger.info(f"Saved/Reconstructed metadata.json for adapter '{tagged_lora_id_to_save}' in {lora_parent_dir}.")
        except Exception as e:
            logger.error(f"Error saving LoRA adapter '{tagged_lora_id_to_save}': {e}")

    def save_router_state(self):
        """Write the router state to a file under ``self.router_states_dir``.

        The file is ``router_state_<experiment_tag>.pth`` when the config has an
        experiment tag, otherwise ``router_state_generic.pth``. Does nothing
        (besides logging a warning) when there is no router or the router has
        no ``save_router_state`` method.
        """
        router = self.router
        if router is None or not hasattr(router, 'save_router_state'):
            logger.warning("PEFTManager: Router not available or does not support saving state.")
            return

        if self.config.experiment_tag:
            current_router_state_file = os.path.join(self.router_states_dir, f"router_state_{self.config.experiment_tag}.pth")
            logger.info(f"PEFTManager attempting to save router state for experiment '{self.config.experiment_tag}' to {current_router_state_file}")
        else:
            current_router_state_file = os.path.join(self.router_states_dir, "router_state_generic.pth")
            logger.warning(f"Experiment tag not set. Saving router state to generic path: {current_router_state_file}")
        router.save_router_state(current_router_state_file)

    def activate_lora_modules(self, lora_id_or_ids_to_activate: Union[str, List[str]]):
        """Make the given adapter(s) active on ``self.model``.

        Args:
            lora_id_or_ids_to_activate: A single adapter name or a list of names.
                Every name must already be present in the model's ``peft_config``.

        On success ``self.active_lora_ids`` mirrors ``self.model.active_adapters``.
        If the model is not a PeftModel, a name is unknown, or ``set_adapter``
        raises, the problem is logged and nothing else happens.
        """
        model = self.model
        if not isinstance(model, PeftModel):
            logger.warning("Base model not PeftModel.")
            return

        if isinstance(lora_id_or_ids_to_activate, str):
            adapters_to_verify = [lora_id_or_ids_to_activate]
        else:
            adapters_to_verify = lora_id_or_ids_to_activate

        has_unknown_adapter = any(adapter_id not in model.peft_config for adapter_id in adapters_to_verify)
        if has_unknown_adapter:
            logger.error(f"One or more adapters not found in PeftModel config: {adapters_to_verify}. Available: {list(self.model.peft_config.keys()) if hasattr(self.model, 'peft_config') else 'N/A'}")
            return

        try:
            model.set_adapter(lora_id_or_ids_to_activate)
            self.active_lora_ids = model.active_adapters
        except Exception as e:
            logger.error(f"Error activating LoRA(s) {lora_id_or_ids_to_activate}: {e}")

    def get_active_lora_modules(self) -> List[str]:
        """Return ``self.model.active_adapters``, or ``[]`` if the model is not a PeftModel exposing it."""
        model = self.model
        if isinstance(model, PeftModel) and hasattr(model, 'active_adapters'):
            return model.active_adapters
        return []
    def get_current_peft_model(self) -> PreTrainedModel:
        """Return the model object currently held in ``self.model``."""
        current_model = self.model
        return current_model

    def train_router(
        self,
        # Each item: ((task_desc, optional example dicts, optional task_text_field), target_lora_id_str)
        training_data: List[Tuple[Tuple[str, Optional[List[Dict[str, Any]]], Optional[Union[str, List[str]]]], str]],
        learning_rate: Optional[float] = None
    ):
        """Run one router training call over ``training_data``.

        Args:
            training_data: Samples handed to ``self.router.train_step`` unchanged.
            learning_rate: Overrides ``config.router_learning_rate`` when given.

        The epoch count comes from ``config.router_training_epochs_per_call``
        (default 1); non-positive values are replaced by 1. Returns early, with
        a log line, when the router has no ``train_step`` or the data is empty.
        """
        router = self.router
        if not hasattr(router, 'train_step'):
            logger.warning("Router missing 'train_step'. Cannot train router.")
            return
        if not training_data:
            logger.info("No data provided for router training.")
            return

        num_router_train_epochs = getattr(self.config, 'router_training_epochs_per_call', 1)
        if not num_router_train_epochs > 0:
            num_router_train_epochs = 1

        if learning_rate is None:
            lr_to_use = self.config.router_learning_rate
        else:
            lr_to_use = learning_rate

        logger.info(f"Starting router training: {len(training_data)} samples, {num_router_train_epochs} epochs (from config), LR: {lr_to_use:.1e}.")
        router.train_step(training_data, num_router_train_epochs, lr_to_use)
        logger.info("Router training finished.")

    def update_expert_profile_in_router(
        self,
        lora_id_str: str,
        task_description: str,
        validation_examples_formatted: List[Dict[str, Any]], # example dicts forwarded to the router
        task_text_field: Union[str, List[str]]
    ):
        """Overwrite a LoRA's router embedding with a freshly computed task representation.

        The update only happens when all of the following hold:
          * a router exists and offers ``get_advanced_task_representation``;
          * ``lora_id_str`` has an entry in the router's ``module_embeddings``;
          * ``config.use_advanced_embeddings`` is truthy;
          * the router returns a representation whose first dimension equals
            that of the stored embedding.
        Otherwise the reason is logged and the stored embedding is left as is.
        Routers that track the LoRA in ``output_projection_layers`` instead are
        not supported here and only produce a warning.

        Args:
            lora_id_str: Adapter id whose profile should be replaced.
            task_description: Task description passed to the router.
            validation_examples_formatted: Example dicts passed to the router as
                ``task_examples_for_feature_gen``.
            task_text_field: Field name(s) passed to the router.
        """
        router = self.router
        if not router or not hasattr(router, 'get_advanced_task_representation'):
            logger.warning("PEFTManager: Router not available or missing get_advanced_task_representation. Cannot update expert profile.")
            return

        has_embedding_entry = hasattr(router, 'module_embeddings') and lora_id_str in router.module_embeddings
        if not has_embedding_entry:
            if hasattr(router, 'output_projection_layers') and lora_id_str in router.output_projection_layers:
                # Router keeps output layers rather than per-module embeddings; no update path exists for it.
                logger.warning(f"PEFTManager: LoRA ID '{lora_id_str}' found in MLPRouter's output_projection_layers. Dynamic embedding update logic for MLPRouter needs specific implementation if profiles are not direct embeddings.")
                return
            logger.warning(f"PEFTManager: LoRA ID '{lora_id_str}' not found in router's learnable parameters. Cannot update its profile.")
            return

        if not getattr(self.config, 'use_advanced_embeddings', False):
            logger.info("PEFTManager: Advanced embeddings disabled. Skipping expert profile update.")
            return

        logger.info(f"PEFTManager: Updating expert profile in router for LoRA '{lora_id_str}' using advanced embeddings.")

        new_expert_representation_tensor = router.get_advanced_task_representation(
            task_description=task_description,
            task_examples_for_feature_gen=validation_examples_formatted,
            task_text_field=task_text_field
        )
        if new_expert_representation_tensor is None:
            logger.error(f"PEFTManager: Failed to generate advanced representation for LoRA '{lora_id_str}'. Profile not updated.")
            return

        # Re-check membership: the entry is looked up again after the representation call.
        if not (hasattr(router, 'module_embeddings') and lora_id_str in router.module_embeddings):
            logger.error(f"PEFTManager: Attempted to update profile for '{lora_id_str}', but it's not in LinearRouter's module_embeddings or router type is not LinearRouter with module_embeddings.")
            return

        with torch.no_grad():
            target_param = router.module_embeddings[lora_id_str]
            if new_expert_representation_tensor.shape[0] != target_param.shape[0]:
                logger.error(f"PEFTManager: Dimension Mismatch! New expert repr dim {new_expert_representation_tensor.shape[0]} vs existing router profile dim {target_param.shape[0]} for '{lora_id_str}'. NOT UPDATING.")
                return
            # Replace the parameter's storage in place, matching its device and dtype.
            target_param.data = new_expert_representation_tensor.to(
                device=target_param.device, dtype=target_param.dtype
            )
        logger.info(f"PEFTManager: Successfully updated profile for LoRA '{lora_id_str}' in router with new advanced representation.")


class Router(nn.Module):
    """Base class for routers that match a task representation against LoRA modules.

    The base class owns the shared plumbing: the tokenizer, the model used to
    embed task descriptions, the id <-> index bookkeeping for registered
    modules, and the construction of the task representation vector.
    Subclasses implement ``select_modules``, ``train_step``,
    ``save_router_state`` and ``load_router_state``.

    NOTE(review): ``add_module`` below has a different signature from
    ``torch.nn.Module.add_module(name, module)`` and therefore shadows it on
    router instances. The name is part of the interface used by callers, so it
    is kept; avoid calling the ``nn.Module`` variant on a router.

    NOTE(review): ``model_for_embeddings`` is stored as a plain attribute after
    ``super().__init__()``; if it is an ``nn.Module`` (as the annotation
    suggests) PyTorch registers it as a child module, so ``router.train()`` /
    ``router.eval()`` presumably propagate to it -- confirm this is intended.
    """

    def __init__(self, config: Any, model_for_embeddings: PreTrainedModel):
        """Set up bookkeeping and load the tokenizer.

        Args:
            config: Object providing at least ``device`` and ``model_name``.
            model_for_embeddings: Model run (with ``output_hidden_states=True``)
                to embed task descriptions.
        """
        super().__init__()
        self.config = config
        self.model_for_embeddings = model_for_embeddings
        self.device = torch.device(config.device)
        self.module_metadata: Dict[str, Dict[str, Any]] = {}
        self.lora_id_to_idx: Dict[str, int] = {}
        self.idx_to_lora_id: Dict[int, str] = {}
        self.next_lora_idx = 0
        self.optimizer = None

        # A local path given via CUSTOM_MODEL_PATH takes precedence over config.model_name.
        model_name_for_tokenizer = self.config.model_name
        custom_model_path = os.environ.get("CUSTOM_MODEL_PATH")
        if custom_model_path and os.path.exists(custom_model_path):
            model_name_for_tokenizer = custom_model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_for_tokenizer)

        # Padding is required by the tokenizer calls below; reuse an existing special token if possible.
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                # NOTE(review): only the tokenizer is extended here; nothing in this class
                # resizes the model's embedding matrix -- confirm the caller handles that.
                logger.warning(f"Tokenizer for {model_name_for_tokenizer} has no pad_token, eos_token, or unk_token. Adding new pad_token '<|pad|>'.")
                self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

        self.router_internal_dtype = torch.float32
        try:
            self.embedding_model_dtype = next(self.model_for_embeddings.parameters()).dtype
        except StopIteration:
            # Model without parameters: fall back to float32.
            self.embedding_model_dtype = torch.float32
        logger.info(f"Router common init: Embedding model dtype: {self.embedding_model_dtype}, Router internal param dtype: {self.router_internal_dtype}")

    def add_module(self, module_id_str: str, metadata: Dict[str, Any]):
        """Register (or refresh the metadata of) a LoRA module.

        A new id receives the next free integer index; a known id keeps its
        index. ``_on_module_added`` is invoked in both cases.
        """
        self.module_metadata[module_id_str] = metadata
        if module_id_str not in self.lora_id_to_idx:
            self.lora_id_to_idx[module_id_str] = self.next_lora_idx
            self.idx_to_lora_id[self.next_lora_idx] = module_id_str
            self.next_lora_idx += 1
        self._on_module_added(module_id_str)

    def _on_module_added(self, module_id_str: str):
        """Hook for subclasses; called at the end of every ``add_module``. No-op here."""
        pass

    def _tokenize_task_description(self, task_description: str):
        """Tokenize the description (truncated/padded to ``config.max_seq_length``) and move it to ``self.device``."""
        return self.tokenizer(
            task_description,
            return_tensors="pt",
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
        ).to(self.device)

    def _mean_pool_description(self, inputs_desc) -> torch.Tensor:
        """Run the embedding model without gradients and mean-pool its last hidden state.

        Positions where ``attention_mask`` is 0 are excluded from the mean.
        Returns the pooled vector of the first (only) sequence in the batch.
        """
        with torch.no_grad():
            outputs_desc = self.model_for_embeddings(**inputs_desc, output_hidden_states=True)
            last_hidden_state_desc = outputs_desc.hidden_states[-1]
            attention_mask_desc_expanded = inputs_desc['attention_mask'].unsqueeze(-1).expand(last_hidden_state_desc.size()).float()
            sum_embeddings_desc = torch.sum(last_hidden_state_desc * attention_mask_desc_expanded, 1)
            # Clamp guards against division by zero for an all-masked sequence.
            sum_mask_desc = torch.clamp(attention_mask_desc_expanded.sum(1), min=1e-9)
            return (sum_embeddings_desc / sum_mask_desc)[0]

    def _description_only_representation(self, task_description: str) -> Optional[torch.Tensor]:
        """Fallback representation used when advanced embeddings are disabled.

        Returns the mean-pooled description embedding in ``router_internal_dtype``,
        or ``None`` if the model forward/pooling fails (the error is logged).
        Tokenization errors are not caught and propagate to the caller.
        """
        inputs_desc = self._tokenize_task_description(task_description)
        original_training_state_fb = self.model_for_embeddings.training
        if hasattr(self.model_for_embeddings, 'eval'):
            self.model_for_embeddings.eval()
        fb_task_description_embedding: Optional[torch.Tensor] = None
        try:
            # The intermediate cast to the embedding model's dtype is kept so the
            # fallback's numerical output stays identical to earlier versions.
            fb_task_description_embedding = self._mean_pool_description(inputs_desc).to(self.embedding_model_dtype)
        except Exception as e:
            logger.error(f"Router (Fallback): Error getting task desc embedding: {e}")
        finally:
            if original_training_state_fb and hasattr(self.model_for_embeddings, 'train'):
                self.model_for_embeddings.train()
        if fb_task_description_embedding is None:
            return None
        return fb_task_description_embedding.squeeze().to(self.router_internal_dtype)

    def _format_examples_for_features(
        self,
        task_examples_for_feature_gen: Optional[List[Dict[str, Any]]],
        task_text_field: Optional[Union[str, List[str]]],
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        """Split raw examples into inputs for the gradient sketch and for the context stats.

        Returns ``(grad_sketch_examples, context_stats_examples)``:
          * examples that already carry both ``'input'`` and ``'output'`` keys go
            to both lists (context stats only receives ``'input'``);
          * other examples contribute to the context-stats list only, with the
            input text read from ``task_text_field`` (several fields are joined
            with ``" [SEP] "``); examples yielding an empty input are dropped.

        Pre-formatted examples are honoured even when ``task_text_field`` is not
        given, since that field is only needed to build inputs from raw rows.
        """
        formatted_examples_for_grad_sketch: List[Dict[str, str]] = []
        formatted_examples_for_context_stats: List[Dict[str, str]] = []
        if not task_examples_for_feature_gen:
            return formatted_examples_for_grad_sketch, formatted_examples_for_context_stats

        for ex_orig_data in task_examples_for_feature_gen:
            if 'input' in ex_orig_data and 'output' in ex_orig_data:
                formatted_examples_for_grad_sketch.append({'input': str(ex_orig_data['input']), 'output': str(ex_orig_data['output'])})
                formatted_examples_for_context_stats.append({'input': str(ex_orig_data['input'])})
                continue

            if not task_text_field:
                # Without a text field there is no way to build an input from a raw row.
                continue

            if isinstance(task_text_field, list):  # sentence-pair style tasks
                input_val_str = " [SEP] ".join(str(ex_orig_data.get(tf, "")) for tf in task_text_field)
            else:  # single text field
                input_val_str = str(ex_orig_data.get(task_text_field, ""))

            if input_val_str:
                # No textual 'output' is available here, so the example cannot feed the gradient sketch.
                formatted_examples_for_context_stats.append({'input': input_val_str})

        return formatted_examples_for_grad_sketch, formatted_examples_for_context_stats

    def get_advanced_task_representation(
        self,
        task_description: str,
        task_examples_for_feature_gen: Optional[List[Dict[str, Any]]] = None,
        task_text_field: Optional[Union[str, List[str]]] = None
    ) -> Optional[torch.Tensor]:
        """Build the vector a router compares against its module profiles.

        With ``config.use_advanced_embeddings`` falsy or missing, only the
        mean-pooled description embedding is returned (not L2-normalised).
        Otherwise up to three components are concatenated, in this order, and
        the result is L2-normalised:
          1. the mean-pooled task-description embedding;
          2. a gradient sketch from ``generate_gradient_sketch`` (requires
             formatted input/output examples and ``config.k_for_grad_sketch > 0``);
          3. context statistics from ``generate_context_stats`` (requires at
             least one formatted input).
        Components that cannot be produced are skipped, so the output length
        depends on which signals were available.

        The embedding model's train/eval flag is restored before returning, also
        when one of the feature generators raises.

        Args:
            task_description: Free-text description of the task.
            task_examples_for_feature_gen: Optional example dicts, either already
                shaped as ``{'input': ..., 'output': ...}`` or raw rows from
                which ``task_text_field`` is read.
            task_text_field: Field name, or list of field names, holding the
                input text in raw rows.

        Returns:
            A 1-D tensor in ``router_internal_dtype``, or ``None`` when no
            component could be generated or the final concatenation failed.
        """
        if not hasattr(self.config, 'use_advanced_embeddings') or not self.config.use_advanced_embeddings:
            logger.info("Router: Advanced embeddings disabled by config. Falling back to simpler representation (desc only for now).")
            return self._description_only_representation(task_description)

        all_feature_components: List[torch.Tensor] = []
        original_training_state = self.model_for_embeddings.training  # restored in the finally block
        try:
            # --- Signal 1: task description embedding ---
            if hasattr(self.model_for_embeddings, 'eval'):
                self.model_for_embeddings.eval()
            try:
                inputs_desc = self._tokenize_task_description(task_description)
                task_description_embedding = self._mean_pool_description(inputs_desc).to(self.router_internal_dtype)
                all_feature_components.append(task_description_embedding)
                logger.info(f"Router: Generated task_description_embedding, shape {task_description_embedding.shape}")
            except Exception as e:
                # Keep going: the remaining signals may still be available.
                logger.error(f"Router: Error getting task description embedding (Signal 1): {e}")

            formatted_examples_for_grad_sketch, formatted_examples_for_context_stats = self._format_examples_for_features(
                task_examples_for_feature_gen, task_text_field
            )

            # --- Signal 2: gradient sketch ---
            if formatted_examples_for_grad_sketch and hasattr(self.config, 'k_for_grad_sketch') and self.config.k_for_grad_sketch > 0:
                gradient_sketch = generate_gradient_sketch(
                    examples=formatted_examples_for_grad_sketch,
                    model=self.model_for_embeddings,
                    tokenizer=self.tokenizer,
                    config=self.config,
                    device=self.device
                )
                if gradient_sketch is not None:
                    all_feature_components.append(gradient_sketch.to(self.router_internal_dtype))
                    logger.info(f"Router: Generated gradient_sketch, shape {gradient_sketch.shape}")
            else:
                logger.info("Router: Not enough/no formatted examples for gradient sketch. Skipping.")

            # --- Signal 3: context stats ---
            # Put the model back into eval mode in case the previous step changed it.
            if hasattr(self.model_for_embeddings, 'eval'):
                self.model_for_embeddings.eval()
            if formatted_examples_for_context_stats:
                context_stats = generate_context_stats(
                    examples=formatted_examples_for_context_stats,
                    tokenizer=self.tokenizer,
                    model=self.model_for_embeddings,
                    config=self.config,
                    device=self.device
                )
                if context_stats is not None:
                    all_feature_components.append(context_stats.to(self.router_internal_dtype))
                    logger.info(f"Router: Generated context_stats, shape {context_stats.shape}")
            else:
                logger.info("Router: No formatted examples for context stats. Skipping.")
        finally:
            # Restore the caller's train/eval mode even if a feature generator raised.
            if original_training_state and hasattr(self.model_for_embeddings, 'train'):
                self.model_for_embeddings.train()
            elif not original_training_state and hasattr(self.model_for_embeddings, 'eval'):
                self.model_for_embeddings.eval()

        if not all_feature_components:
            logger.error("Router: All feature components for advanced representation failed to generate or were skipped.")
            return None

        try:
            final_representation = torch.cat(all_feature_components, dim=-1)
            final_representation = F.normalize(final_representation, p=2, dim=-1)  # L2 normalise
            logger.info(f"Router: Generated final_advanced_representation, shape {final_representation.shape}")
            return final_representation.to(self.router_internal_dtype)
        except Exception as e:
            logger.error(f"Router: Error concatenating or normalizing advanced features: {e}")
            return None

    def select_modules(self, task_description: str, task_examples_for_feature_gen: Optional[List[Dict[str, Any]]] = None, task_text_field: Optional[Union[str, List[str]]] = None, top_k: int = 1) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(module_id, score)`` pairs for the task. Subclasses must implement this."""
        raise NotImplementedError()

    def train_step(self, training_data: List[Tuple[Tuple[str, Optional[List[Dict[str, Any]]], Optional[Union[str, List[str]]]], str]], epochs: int, learning_rate: float ):
        """Train the router on ``((description, examples, text_field), target_id)`` samples. Subclasses must implement this."""
        raise NotImplementedError()

    def save_router_state(self, path: str):
        """Persist router state to ``path``. Subclasses must implement this."""
        raise NotImplementedError("Subclasses must implement save_router_state")

    def load_router_state(self, path: str):
        """Load router state from ``path``. Subclasses must implement this."""
        raise NotImplementedError("Subclasses must implement load_router_state")


class LinearRouter(Router):
    def __init__(self, config: Any, model_for_embeddings: PreTrainedModel):
        """Initialise the router and size its per-module embeddings.

        ``self.embed_dim`` is chosen to equal the length of the vector returned
        by ``get_advanced_task_representation``:

        * advanced embeddings off: the embedding model's ``hidden_size``. The
          fallback representation is the description embedding only, so the
          dimension is NOT doubled when ``k_examples_for_prototype`` is set;
          doubling it would make every similarity comparison fail with a
          dimension mismatch.
        * advanced embeddings on: ``hidden_size`` plus
          ``grad_sketch_max_elements`` (when the gradient sketch is configured)
          plus one entry per enabled context statistic.

        Args:
            config: Router/experiment configuration object.
            model_for_embeddings: Model whose ``config.hidden_size`` defines the
                description-embedding width.
        """
        super().__init__(config, model_for_embeddings)

        current_embed_dim = 0
        if not hasattr(self.config, 'use_advanced_embeddings') or not self.config.use_advanced_embeddings:
            base_embed_dim = self.model_for_embeddings.config.hidden_size
            # The fallback representation is always description-only (hidden_size wide).
            current_embed_dim = base_embed_dim
            if hasattr(self.config, 'k_examples_for_prototype') and self.config.k_examples_for_prototype > 0:
                logger.warning(f"LinearRouter initialized (ADV. EMBEDDINGS OFF): k_examples_for_prototype={self.config.k_examples_for_prototype} is set, but the fallback task representation is description-only. Using embed_dim {current_embed_dim} (not doubled) so module embeddings match it.")
            else:
                logger.info(f"LinearRouter initialized (ADV. EMBEDDINGS OFF) for DESCRIPTION features ONLY. Embed_dim: {current_embed_dim}.")
        else:
            # Sum the widths of every component the advanced representation may contain.
            dim_desc_embed = self.model_for_embeddings.config.hidden_size
            current_embed_dim += dim_desc_embed
            dim_grad_sketch = 0
            if hasattr(self.config, 'k_for_grad_sketch') and self.config.k_for_grad_sketch > 0 and \
               hasattr(self.config, 'grad_sketch_layer_names') and self.config.grad_sketch_layer_names and \
               hasattr(self.config, 'grad_sketch_max_elements') and self.config.grad_sketch_max_elements > 0:
                # assumes the sketch has exactly grad_sketch_max_elements entries -- TODO confirm against generate_gradient_sketch
                dim_grad_sketch = self.config.grad_sketch_max_elements
                current_embed_dim += dim_grad_sketch
            dim_context_stats = 0
            if hasattr(self.config, 'use_context_stats_avg_seq_length') and self.config.use_context_stats_avg_seq_length:
                dim_context_stats += 1
            if hasattr(self.config, 'use_context_stats_avg_token_entropy') and self.config.use_context_stats_avg_token_entropy:
                dim_context_stats += 1
            current_embed_dim += dim_context_stats
            logger.info(f"LinearRouter initialized for ADVANCED EMBEDDINGS. Desc Dim: {dim_desc_embed}, GradSketch Dim: {dim_grad_sketch}, ContextStats Dim: {dim_context_stats}. Total Embed_dim: {current_embed_dim}.")

        self.embed_dim = current_embed_dim
        if self.embed_dim == 0:
            logger.error("LinearRouter: Calculated embed_dim is 0. Defaulting to base model hidden_size. Check config.")
            self.embed_dim = self.model_for_embeddings.config.hidden_size
        # One learnable embedding per routed LoRA id.
        self.module_embeddings = nn.ParameterDict()

    def _setup_optimizer(self, learning_rate: float):
        """(Re)create the AdamW optimizer over all entries of ``module_embeddings``.

        Uses ``config.weight_decay``. When there is nothing to optimise yet,
        ``self.optimizer`` is reset to ``None``.
        """
        trainable_embeddings = []
        for _module_id, embedding in self.module_embeddings.items():
            if isinstance(embedding, nn.Parameter):
                trainable_embeddings.append(embedding)

        if not trainable_embeddings:
            self.optimizer = None
            logger.info("LinearRouter: No parameters in module_embeddings to optimize yet.")
            return

        self.optimizer = optim.AdamW(trainable_embeddings, lr=learning_rate, weight_decay=self.config.weight_decay)
        logger.info(f"LinearRouter optimizer (re)set. LR: {learning_rate:.1e} with {len(trainable_embeddings)} embedding parameters.")

    def _on_module_added(self, module_id_str: str):
        """Create or sanity-check the learnable embedding for a registered module.

        * Id already in ``module_embeddings``: make sure the entry is an
          ``nn.Parameter`` (re-wrapping it if not) and warn when its size
          differs from ``embed_dim``.
        * New id containing the current ``config.experiment_tag``: create a
          randomly initialised embedding and hand it to the optimizer (adding a
          param group to an existing optimizer, or building a new one).
        * New id without the current experiment tag: only log; no embedding is
          created. The log text depends on whether
          ``config.load_router_state_from_tag`` is set.
        """
        was_router_state_loaded_this_peft_manager_session = bool(self.config.load_router_state_from_tag)
        is_part_of_current_experiment_run = self.config.experiment_tag and self.config.experiment_tag in module_id_str

        # --- Known id: validate the stored entry and stop. ---
        if module_id_str in self.module_embeddings:
            existing_entry = self.module_embeddings[module_id_str]
            if not isinstance(existing_entry, nn.Parameter):
                logger.error(f"LinearRouter._on_module_added: {module_id_str} in module_embeddings but is not nn.Parameter! Type: {type(existing_entry)}. Attempting to re-init as Parameter.")
                self.module_embeddings[module_id_str] = nn.Parameter(existing_entry.data.to(self.device).to(self.router_internal_dtype))
                existing_entry = self.module_embeddings[module_id_str]
            if existing_entry.shape[0] != self.embed_dim:
                logger.warning(f"LinearRouter._on_module_added: Existing learnable embedding '{module_id_str}' has dim {existing_entry.shape[0]}, but router currently expects {self.embed_dim}. This can happen if config changed. Parameter will be reinitialized if profile is updated dynamically or router is retrained with this module as target.")
            return

        # --- New id from outside the current experiment: known to the router, but not learnable. ---
        if not is_part_of_current_experiment_run:
            if was_router_state_loaded_this_peft_manager_session:
                logger.info(f"LinearRouter._on_module_added (Loaded Router context): Module '{module_id_str}' (from disk scan, not current exp '{self.config.experiment_tag}', not in loaded state's learnable params) - NO new learnable parameter created now.")
            else:
                logger.info(f"LinearRouter._on_module_added (Fresh Router context): Module '{module_id_str}' (likely from disk scan, not tagged for current exp '{self.config.experiment_tag}') will be known by router but NO new learnable parameter created now.")
            return

        # --- New id from the current experiment: create a learnable embedding. ---
        # Random start; update_expert_profile_in_router may overwrite it later.
        self.module_embeddings[module_id_str] = nn.Parameter(
            torch.randn(self.embed_dim, device=self.device, dtype=self.router_internal_dtype)
        )
        logger.info(f"LinearRouter._on_module_added: CREATED new learnable embedding for '{module_id_str}' with dim {self.embed_dim}.")

        registered_param = self.module_embeddings[module_id_str]
        if self.optimizer and hasattr(self.optimizer, 'add_param_group') and isinstance(registered_param, nn.Parameter):
            try:
                self.optimizer.add_param_group({'params': [registered_param]})
                logger.info(f"Added '{module_id_str}' embedding to existing LinearRouter optimizer.")
            except Exception as e:
                # Rebuild the optimizer from scratch so the new embedding is included.
                logger.warning(f"Could not add '{module_id_str}' to optimizer: {e}. Will be picked up by _setup_optimizer.")
                self._setup_optimizer(self.config.router_learning_rate)
        elif not self.optimizer:
            self._setup_optimizer(self.config.router_learning_rate)

    def _get_all_module_embeddings_tensor_and_map(self) -> Tuple[Optional[torch.Tensor], List[str]]:
        """Stack all usable module embeddings into one tensor.

        An embedding is usable when its id is registered in ``lora_id_to_idx``,
        it is an ``nn.Parameter``, and its first dimension equals ``embed_dim``.

        Returns:
            ``(stacked, ids)`` where row ``i`` of ``stacked`` belongs to
            ``ids[i]`` and ids are ordered by their router index, or
            ``(None, [])`` when no embedding is usable.
        """
        embeddings_by_id = self.module_embeddings
        if not embeddings_by_id or len(embeddings_by_id) == 0:
            return None, []

        usable_ids: List[str] = []
        for lora_id in self.lora_id_to_idx.keys():
            if lora_id not in embeddings_by_id:
                continue
            candidate = embeddings_by_id[lora_id]
            if isinstance(candidate, nn.Parameter) and candidate.shape[0] == self.embed_dim:
                usable_ids.append(lora_id)
        if not usable_ids:
            return None, []

        # Order rows by the index assigned at registration time.
        usable_ids.sort(key=lambda registered_id: self.lora_id_to_idx[registered_id])
        stacked_embeddings = torch.stack([embeddings_by_id[lora_id] for lora_id in usable_ids])
        return stacked_embeddings, usable_ids

    def select_modules(self, task_description: str, task_examples_for_feature_gen: Optional[List[Dict[str, Any]]] = None, task_text_field: Optional[Union[str, List[str]]] = None, top_k: int = 1) -> List[Tuple[str, float]]:
        """Rank registered LoRA modules by similarity to the task representation.

        The task representation from ``get_advanced_task_representation`` is
        compared with every usable module embedding via cosine similarity; the
        similarities are turned into scores with a softmax and the best
        ``top_k`` are returned.

        Args:
            task_description: Free-text description of the task.
            task_examples_for_feature_gen: Optional example dicts forwarded to
                the representation builder.
            task_text_field: Optional text field name(s) forwarded to the
                representation builder.
            top_k: Maximum number of modules to return. Values <= 0 yield ``[]``.

        Returns:
            ``(module_id, score)`` pairs, best first. Empty when no embeddings
            exist, the representation could not be built, or its dimension does
            not match the module embeddings.
        """
        # NOTE(review): the base Router stores model_for_embeddings as an attribute, so this
        # eval() presumably also switches that model to eval mode and leaves it there -- confirm.
        self.eval()
        if not self.module_embeddings or len(self.module_embeddings) == 0:
            logger.info("LinearRouter.select_modules: No module embeddings available.")
            return []

        # Not wrapped in no_grad: the representation builder may need gradients (gradient sketch).
        advanced_task_repr = self.get_advanced_task_representation(task_description, task_examples_for_feature_gen, task_text_field)
        if advanced_task_repr is None:
            logger.warning("LinearRouter.select_modules: Failed to get advanced task representation.")
            return []

        all_embeddings, ordered_ids = self._get_all_module_embeddings_tensor_and_map()
        if all_embeddings is None or len(ordered_ids) == 0:
            logger.info("LinearRouter.select_modules: No module embeddings tensor available for comparison.")
            return []
        if advanced_task_repr.shape[0] != all_embeddings.shape[1]:
            logger.error(f"LinearRouter.select_modules: Dimension mismatch! Task Repr dim {advanced_task_repr.shape[0]}, Module Embeds dim {all_embeddings.shape[1]}. Skipping selection.")
            return []
        if advanced_task_repr.dtype != all_embeddings.dtype:
            logger.warning(f"LinearRouter.select_modules: Dtype mismatch! Task Repr: {advanced_task_repr.dtype}, Module Embeds: {all_embeddings.dtype}. Casting task_repr.")
            advanced_task_repr = advanced_task_repr.to(all_embeddings.dtype)

        num_to_return = min(top_k, len(ordered_ids))
        if num_to_return <= 0:
            # torch.topk rejects negative k; nothing was requested, so nothing is returned.
            return []

        # Pure inference: no autograd graph is needed over the learnable embeddings.
        with torch.no_grad():
            sim = F.cosine_similarity(advanced_task_repr.unsqueeze(0), all_embeddings)
            probs = F.softmax(sim, dim=0)
            top_probs, top_indices = torch.topk(probs, num_to_return)
        return [(ordered_ids[idx.item()], prob.item()) for idx, prob in zip(top_indices, top_probs)]

    def train_step(self, training_data: List[Tuple[Tuple[str, Optional[List[Dict[str, Any]]], Optional[Union[str, List[str]]]], str]], epochs: int, learning_rate: float):
        """Fit the learnable module embeddings with a cross-entropy objective.

        For each sample, the task representation is compared by cosine
        similarity against the tensor of module embeddings; the similarities are
        used as logits and the target module's position in the returned id list
        is the class label. Samples that cannot be processed are logged and
        skipped.

        Args:
            training_data: Items of the form
                ``((task_desc, task_examples_for_feature_gen, task_txt_fld), target_id_str)``
                where ``task_examples_for_feature_gen`` is a list of dataset
                example dicts (original_example_data).
            epochs: Number of passes over ``training_data``.
            learning_rate: Only used to create the optimizer when the router
                does not have one yet.
        """
        if not self.module_embeddings or len(self.module_embeddings) == 0:
            logger.warning("LinearRouter: No module embeddings to train.")
            return
        if self.optimizer is None:
            self._setup_optimizer(learning_rate)
        if self.optimizer is None:
            logger.warning("LinearRouter: Optimizer could not be set up. Cannot train.")
            return

        self.train()
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            running_loss = 0.0
            optimized_steps = 0
            progress = tqdm(training_data, desc=f"Router Lin Ep{epoch+1}", leave=False)
            for (task_desc, examples_for_features, task_txt_fld), target_id_str in progress:
                # The target must be known both as an embedding and in the index map.
                target_registered = target_id_str in self.module_embeddings and target_id_str in self.lora_id_to_idx
                if not target_registered:
                    logger.warning(f"Target ID '{target_id_str}' for router training not found in module_embeddings or lora_id_to_idx. Skipping.")
                    continue

                self.optimizer.zero_grad()
                task_vector = self.get_advanced_task_representation(task_desc, examples_for_features, task_txt_fld)
                if task_vector is None:
                    logger.warning(f"Router training: Failed to get advanced task representation for desc '{task_desc[:30]}...'. Skipping batch.")
                    continue

                embedding_matrix, candidate_ids = self._get_all_module_embeddings_tensor_and_map()
                if embedding_matrix is None or len(candidate_ids) == 0:
                    logger.warning("Router training: No module embeddings available to compare against. Skipping batch.")
                    continue

                if task_vector.shape[0] != embedding_matrix.shape[1]:
                    logger.error(f"LinearRouter.train_step: Dimension mismatch! Task Repr dim {task_vector.shape[0]}, Module Embeds dim {embedding_matrix.shape[1]}. Skipping batch.")
                    continue
                if task_vector.dtype != embedding_matrix.dtype:
                    task_vector = task_vector.to(embedding_matrix.dtype)

                # One similarity per module, reshaped to (1, num_modules) for CrossEntropyLoss.
                logits = F.cosine_similarity(task_vector.unsqueeze(0), embedding_matrix).unsqueeze(0)
                try:
                    target_idx = candidate_ids.index(target_id_str)
                except ValueError:
                    logger.error(f"Target '{target_id_str}' not in ordered_ids_stack derived from module_embeddings ({candidate_ids}). LinRouter loss calc error.")
                    continue

                target_indices = torch.tensor([target_idx], device=self.device)
                loss = criterion(logits, target_indices)
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    logger.error(f"LinearRouter TRAIN Ep{epoch+1}: Loss NaN/Inf! Skipping step.")
                    continue

                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                optimized_steps += 1

            if optimized_steps > 0:
                logger.info(f"LinearRouter Ep {epoch+1} Avg Loss: {running_loss/optimized_steps:.4f}")

    def save_router_state(self, path: str):
        """Write the router's learnable embeddings and bookkeeping to ``path``.

        Only entries of ``self.module_embeddings`` that are ``nn.Parameter``
        instances are saved (as CPU tensors); the id/index maps and metadata
        are filtered down to those same module ids. The optimizer state is
        included only when an optimizer exists and at least one embedding is
        being saved.

        Args:
            path: Destination file passed to ``torch.save``.
        """
        embeddings_on_cpu = {}
        if self.module_embeddings:
            embeddings_on_cpu = {
                module_id: param.cpu().data
                for module_id, param in self.module_embeddings.items()
                if isinstance(param, nn.Parameter)
            }
        saved_profile_count = len(embeddings_on_cpu)

        # Keep the auxiliary maps consistent with the embeddings actually saved.
        index_by_id = {
            module_id: idx
            for module_id, idx in self.lora_id_to_idx.items()
            if module_id in embeddings_on_cpu
        }
        id_by_index = {idx: module_id for module_id, idx in index_by_id.items()}
        metadata_subset = {
            module_id: meta
            for module_id, meta in self.module_metadata.items()
            if module_id in embeddings_on_cpu
        }

        if self.optimizer and saved_profile_count > 0:
            optimizer_snapshot = self.optimizer.state_dict()
        else:
            optimizer_snapshot = None

        state = {
            'module_embeddings_state_dict': embeddings_on_cpu,
            'optimizer_state_dict': optimizer_snapshot,
            'lora_id_to_idx': index_by_id,
            'idx_to_lora_id': id_by_index,
            'next_lora_idx': self.next_lora_idx,
            'module_metadata': metadata_subset,
            'embed_dim': self.embed_dim
        }
        torch.save(state, path)
        logger.info(f"LinearRouter state saved to {path}. Saved embed_dim: {self.embed_dim}. Num profiles with learnable embeddings saved: {saved_profile_count}. lora_id_to_idx entries saved: {len(index_by_id)}")

    def load_router_state(self, path: str):
        """Restore embeddings, id/index maps, metadata and optimizer state from ``path``.

        The router's ``embed_dim`` is taken from the file, or inferred from the
        first saved embedding when the key is absent. A mismatch with the
        dimension implied by the current config is logged but not fatal. Saved
        embeddings whose length differs from ``embed_dim`` are skipped. When
        ``path`` does not exist the router is left untouched and a warning is
        logged.

        Args:
            path: State file to read with ``torch.load``; the keys read here are
                the ones ``save_router_state`` writes.
        """
        if not os.path.exists(path):
            logger.warning(f"LinearRouter: No state file found at {path}. Router not loaded.")
            return
        state = torch.load(path, map_location='cpu')
        logger.info(f"LinearRouter.load_router_state: Loading from {path}. Saved state keys: {list(state.keys())}")
        saved_embed_dim_from_file = state.get('embed_dim')

        # Determine current expected embed_dim based on config (for comparison)
        current_config_expected_embed_dim = 0
        if not hasattr(self.config, 'use_advanced_embeddings') or not self.config.use_advanced_embeddings:
            base_dim = self.model_for_embeddings.config.hidden_size
            current_config_expected_embed_dim = base_dim * 2 if (hasattr(self.config, 'k_examples_for_prototype') and self.config.k_examples_for_prototype > 0) else base_dim
        else:
            current_config_expected_embed_dim += self.model_for_embeddings.config.hidden_size
            if hasattr(self.config, 'k_for_grad_sketch') and self.config.k_for_grad_sketch > 0 and hasattr(self.config, 'grad_sketch_max_elements'):
                current_config_expected_embed_dim += self.config.grad_sketch_max_elements
            if hasattr(self.config, 'use_context_stats_avg_seq_length') and self.config.use_context_stats_avg_seq_length:
                current_config_expected_embed_dim += 1
            if hasattr(self.config, 'use_context_stats_avg_token_entropy') and self.config.use_context_stats_avg_token_entropy:
                current_config_expected_embed_dim += 1

        if saved_embed_dim_from_file is None:
            first_embed_param_tensor = next(iter(state.get('module_embeddings_state_dict', {}).values()), None)
            if first_embed_param_tensor is not None:
                saved_embed_dim_from_file = first_embed_param_tensor.shape[0]
                logger.info(f"LinearRouter.load_router_state: 'embed_dim' not found in state. Inferred as {saved_embed_dim_from_file} from loaded parameters.")
            else:
                saved_embed_dim_from_file = current_config_expected_embed_dim
                logger.info(f"LinearRouter.load_router_state: 'embed_dim' not found and no embeddings in state. Using current config-derived embed_dim: {saved_embed_dim_from_file}.")

        self.embed_dim = saved_embed_dim_from_file # Set router's embed_dim to what was saved
        logger.info(f"LinearRouter.load_router_state: Router internal embed_dim set to {self.embed_dim} (from loaded state).")
        if self.embed_dim != current_config_expected_embed_dim:
            logger.warning(f"LinearRouter.load_router_state: Loaded state embed_dim ({self.embed_dim}) MISMATCHES current config expected ({current_config_expected_embed_dim}). Ensure config (e.g., use_advanced_embeddings, k_examples_for_prototype) is consistent with the loaded state.")

        self.module_embeddings.clear()
        loaded_module_embeddings_state_dict = state.get('module_embeddings_state_dict', {})
        logger.info(f"LinearRouter.load_router_state: Found {len(loaded_module_embeddings_state_dict)} profiles in 'module_embeddings_state_dict' from file.")

        for module_id, param_tensor_cpu in loaded_module_embeddings_state_dict.items():
            if param_tensor_cpu.shape[0] != self.embed_dim:
                logger.error(f"LinearRouter.load_router_state: Dimension mismatch for module '{module_id}'. Saved tensor dim: {param_tensor_cpu.shape[0]}, Router's loaded embed_dim: {self.embed_dim}. SKIPPING this parameter.")
                continue
            self.module_embeddings[module_id] = nn.Parameter(param_tensor_cpu.to(self.device).to(self.router_internal_dtype))

        self.lora_id_to_idx.clear()
        self.idx_to_lora_id.clear()
        self.module_metadata.clear()
        loaded_lora_id_to_idx_from_file = state.get('lora_id_to_idx', {})
        loaded_module_metadata_from_file = state.get('module_metadata', {})

        # Restore the index counter BEFORE handing out fallback indices. Otherwise
        # modules missing from the saved map would receive indices from the stale
        # pre-load counter (possibly colliding with restored ones), and those
        # increments would then be discarded when the saved counter is applied.
        loaded_indices = [
            loaded_lora_id_to_idx_from_file[module_id]
            for module_id in self.module_embeddings.keys()
            if module_id in loaded_lora_id_to_idx_from_file
        ]
        restored_next_lora_idx = state.get('next_lora_idx', self.next_lora_idx)
        if loaded_indices:
            # Never let the counter point at or below an index that is already in use.
            restored_next_lora_idx = max(restored_next_lora_idx, max(loaded_indices) + 1)
        self.next_lora_idx = restored_next_lora_idx

        for module_id in self.module_embeddings.keys():
            if module_id in loaded_lora_id_to_idx_from_file:
                self.lora_id_to_idx[module_id] = loaded_lora_id_to_idx_from_file[module_id]
                self.idx_to_lora_id[loaded_lora_id_to_idx_from_file[module_id]] = module_id
            else:
                self.lora_id_to_idx[module_id] = self.next_lora_idx
                self.idx_to_lora_id[self.next_lora_idx] = module_id
                self.next_lora_idx += 1
                logger.warning(f"LoRA ID {module_id} from loaded embeddings was not in loaded lora_id_to_idx. Added it.")
            if module_id in loaded_module_metadata_from_file:
                self.module_metadata[module_id] = loaded_module_metadata_from_file[module_id]
            else:
                self.module_metadata[module_id] = {"adapter_name": module_id, "source": "reconstructed_on_load_missing_meta"}
            # self._on_module_added(module_id) # Not strictly needed here as params are directly set, but harmless.

        if len(list(self.module_embeddings.parameters())) > 0:
            self._setup_optimizer(self.config.router_learning_rate)
            if self.optimizer and state.get('optimizer_state_dict'):
                try:
                    self.optimizer.load_state_dict(state['optimizer_state_dict'])
                    # The file was loaded onto CPU; move optimizer tensors next to the parameters.
                    for opt_state_group in self.optimizer.state.values():
                        for k_opt, v_tensor_opt in opt_state_group.items():
                            if isinstance(v_tensor_opt, torch.Tensor):
                                opt_state_group[k_opt] = v_tensor_opt.to(self.device)
                    logger.info("LinearRouter optimizer state loaded and moved to device.")
                except Exception as e:
                    # Best effort: fall back to a fresh optimizer rather than failing the load.
                    logger.error(f"LinearRouter: Failed to load optimizer state: {e}. Optimizer might be reset.")
                    self._setup_optimizer(self.config.router_learning_rate)
        else:
            self.optimizer = None
        num_actually_loaded_embeddings = len(self.module_embeddings)
        logger.info(f"LinearRouter state loaded from {path}. Router now has {num_actually_loaded_embeddings} module profile(s) with learnable embeddings. Router embed_dim is {self.embed_dim}.")


class MLPRouter(Router): # (Updated input_dim calculation, rest largely similar to previous state)
    """Router that scores modules with a shared MLP trunk and one scalar head per module.

    The task representation returned by ``get_advanced_task_representation`` is
    passed through ``self.mlp``; every registered module id owns an
    ``nn.Linear(hidden_dim, 1)`` in ``self.output_projection_layers`` whose
    output is used as that module's logit.
    """

    def __init__(self, config: Any, model_for_embeddings: PreTrainedModel):
        """Derive the MLP input width from ``config`` and build trunk and optimizer.

        Args:
            config: Router configuration; reads ``router_hidden_size``,
                ``router_learning_rate`` and the optional feature flags checked below.
            model_for_embeddings: Model whose ``config.hidden_size`` sets the
                base embedding width.
        """
        super().__init__(config, model_for_embeddings)

        current_input_dim_for_mlp = 0
        if not hasattr(self.config, 'use_advanced_embeddings') or not self.config.use_advanced_embeddings:
            base_embed_dim = self.model_for_embeddings.config.hidden_size
            if hasattr(self.config, 'k_examples_for_prototype') and self.config.k_examples_for_prototype > 0:
                current_input_dim_for_mlp = base_embed_dim * 2
            else:
                current_input_dim_for_mlp = base_embed_dim
            logger.info(f"MLPRouter initialized (ADV. EMBEDDINGS OFF). MLP input_dim: {current_input_dim_for_mlp}.")
        else: # Calculate new input_dim for advanced embeddings
            dim_desc_embed = self.model_for_embeddings.config.hidden_size
            current_input_dim_for_mlp += dim_desc_embed
            dim_grad_sketch = 0
            if hasattr(self.config, 'k_for_grad_sketch') and self.config.k_for_grad_sketch > 0 and hasattr(self.config, 'grad_sketch_max_elements'):
                dim_grad_sketch = self.config.grad_sketch_max_elements
                current_input_dim_for_mlp += dim_grad_sketch
            dim_context_stats = 0
            if hasattr(self.config, 'use_context_stats_avg_seq_length') and self.config.use_context_stats_avg_seq_length:
                dim_context_stats += 1
            if hasattr(self.config, 'use_context_stats_avg_token_entropy') and self.config.use_context_stats_avg_token_entropy:
                dim_context_stats += 1
            current_input_dim_for_mlp += dim_context_stats
            logger.info(f"MLPRouter initialized for ADVANCED EMBEDDINGS. Desc Dim: {dim_desc_embed}, GradSketch Dim: {dim_grad_sketch}, ContextStats Dim: {dim_context_stats}. Total MLP input_dim: {current_input_dim_for_mlp}.")

        if current_input_dim_for_mlp == 0:
            logger.error("MLPRouter: Calculated input_dim_for_mlp is 0. Defaulting. Check config.")
            current_input_dim_for_mlp = self.model_for_embeddings.config.hidden_size

        self.input_dim_for_mlp = current_input_dim_for_mlp # Store for potential re-init on load
        self.hidden_dim = config.router_hidden_size
        self.mlp = self._build_mlp().to(self.device).to(self.router_internal_dtype)
        self.output_projection_layers = nn.ModuleDict()
        self._setup_optimizer(config.router_learning_rate)

    def _build_mlp(self) -> nn.Sequential:
        """Create a fresh trunk sized by ``input_dim_for_mlp`` and ``hidden_dim``.

        Shared by ``__init__`` and ``load_router_state`` so both always build
        the same architecture. The caller moves it to the device/dtype.
        """
        return nn.Sequential(
            nn.Linear(self.input_dim_for_mlp, self.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(self.hidden_dim, eps=1e-5),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(self.hidden_dim, eps=1e-5),
        )

    def _setup_optimizer(self, learning_rate: float):
        """(Re)create a single-group AdamW over the trunk and all projection heads."""
        params = list(self.mlp.parameters())
        if self.output_projection_layers and len(self.output_projection_layers) > 0:
            params.extend(list(self.output_projection_layers.parameters()))
        if params:
            self.optimizer = optim.AdamW(params, lr=learning_rate, weight_decay=self.config.weight_decay)
        else:
            self.optimizer = None

    def _restore_optimizer_state(self, saved_optimizer_state: Dict[str, Any]) -> None:
        """Load a saved optimizer state, matching its parameter-group layout first.

        ``_on_module_added`` registers each new head as its own param group, so
        a saved state usually has several groups while ``_setup_optimizer``
        builds just one. ``Optimizer.load_state_dict`` rejects a differing
        group count, so the optimizer is rebuilt with the saved group sizes
        before loading. Parameters are assigned in trunk-then-heads order, the
        same order ``_setup_optimizer`` uses.

        Raises:
            Whatever ``load_state_dict`` raises when the layouts still differ;
            the caller handles that by resetting the optimizer.
        """
        group_sizes = [len(group['params']) for group in saved_optimizer_state.get('param_groups', [])]
        trainable_params = list(self.mlp.parameters()) + list(self.output_projection_layers.parameters())
        layout_is_reproducible = (
            len(group_sizes) > 1
            and all(size > 0 for size in group_sizes)
            and sum(group_sizes) == len(trainable_params)
        )
        if layout_is_reproducible:
            regrouped = []
            offset = 0
            for size in group_sizes:
                regrouped.append({'params': trainable_params[offset:offset + size]})
                offset += size
            self.optimizer = optim.AdamW(regrouped, lr=self.config.router_learning_rate, weight_decay=self.config.weight_decay)

        self.optimizer.load_state_dict(saved_optimizer_state)
        # The state file is loaded onto CPU; keep optimizer tensors on the router's device.
        for per_param_state in self.optimizer.state.values():
            for state_key, state_value in per_param_state.items():
                if isinstance(state_value, torch.Tensor):
                    per_param_state[state_key] = state_value.to(self.device)

    def _on_module_added(self, module_id_str: str):
        """Create the projection head for a new module id and make it trainable."""
        if module_id_str not in self.output_projection_layers:
            self.output_projection_layers[module_id_str] = nn.Linear(self.hidden_dim, 1).to(self.device).to(self.router_internal_dtype)
            logger.info(f"MLPRouter: Created output projection layer for '{module_id_str}'.")
            if self.optimizer and hasattr(self.optimizer, 'add_param_group'):
                self.optimizer.add_param_group({'params': self.output_projection_layers[module_id_str].parameters()})
            elif not self.optimizer:
                self._setup_optimizer(self.config.router_learning_rate)

    def _get_logits_for_all_modules_ordered(self, advanced_task_representation: torch.Tensor) -> Tuple[Optional[torch.Tensor], List[str]]:
        """Return one logit per module plus the module ids in matching order.

        Ids are ordered by their value in ``self.lora_id_to_idx``. Returns
        ``(None, [])`` when there are no heads, no id has both an index and a
        head, or the representation length differs from ``input_dim_for_mlp``.
        """
        if not self.output_projection_layers or len(self.output_projection_layers) == 0:
            return None, []
        if advanced_task_representation.shape[0] != self.input_dim_for_mlp:
            logger.error(f"MLPRouter._get_logits: Advanced task repr dim {advanced_task_representation.shape[0]} != MLP input dim {self.input_dim_for_mlp}. Skipping.")
            return None, []
        hidden = self.mlp(advanced_task_representation.to(self.router_internal_dtype))
        valid_ids = [id_ for id_ in self.lora_id_to_idx if id_ in self.output_projection_layers]
        if not valid_ids:
            return None, []
        sorted_ids = sorted(valid_ids, key=lambda k: self.lora_id_to_idx[k])
        logits_list = [self.output_projection_layers[id_](hidden).squeeze() for id_ in sorted_ids]
        if not logits_list:
            return None, []
        return torch.stack(logits_list), sorted_ids

    def select_modules(self, task_description: str, task_examples_for_feature_gen: Optional[List[Dict[str, Any]]] = None, task_text_field: Optional[Union[str, List[str]]] = None, top_k: int = 1) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(module_id, probability)`` pairs, best first.

        Probabilities are a softmax over the per-module logits. An empty list
        is returned when no heads exist or the task representation or logits
        cannot be produced.
        """
        self.eval()
        if not self.output_projection_layers or len(self.output_projection_layers) == 0:
            return []
        advanced_task_repr = self.get_advanced_task_representation(task_description, task_examples_for_feature_gen, task_text_field)
        if advanced_task_repr is None:
            logger.warning("MLPRouter.select_modules: Failed to get advanced task representation.")
            return []
        # Only plain Python floats leave this method, so no autograd graph is needed for scoring.
        with torch.no_grad():
            logits, ordered_ids = self._get_logits_for_all_modules_ordered(advanced_task_repr)
            if logits is None or len(ordered_ids) == 0:
                return []
            probs = F.softmax(logits, dim=0)
            top_probs, top_indices = torch.topk(probs, min(top_k, len(probs)))
        return [(ordered_ids[idx.item()], prob.item()) for idx, prob in zip(top_indices, top_probs)]

    def train_step(self, training_data: List[Tuple[Tuple[str, Optional[List[Dict[str, Any]]], Optional[Union[str, List[str]]]], str]], epochs: int, learning_rate: float):
        """Train trunk and heads with cross-entropy over the per-module logits.

        Args:
            training_data: Items of the form
                ``((task_desc, task_examples_for_feature_gen, task_txt_fld), target_id_str)``.
            epochs: Number of passes over ``training_data``.
            learning_rate: Only used to create the optimizer when none exists yet.
        """
        if self.optimizer is None and (len(list(self.mlp.parameters())) > 0 or len(self.output_projection_layers) > 0):
            self._setup_optimizer(learning_rate)
        elif self.optimizer is None:
            logger.warning("MLPRouter: No params to optimize.")
            return
        self.train()
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0
            for (task_desc, examples_for_features, task_txt_fld), target_id_str in tqdm(training_data, desc=f"Router MLP Ep{epoch+1}", leave=False):
                if not self.output_projection_layers or len(self.output_projection_layers) == 0 or target_id_str not in self.lora_id_to_idx or target_id_str not in self.output_projection_layers:
                    logger.warning(f"Skipping MLP training step for {target_id_str} - not in output_projection_layers or lora_id_to_idx.")
                    continue
                self.optimizer.zero_grad()
                advanced_task_repr = self.get_advanced_task_representation(task_desc, examples_for_features, task_txt_fld)
                if advanced_task_repr is None:
                    logger.warning(f"MLP Router training: Failed to get advanced task representation for desc '{task_desc[:30]}...'. Skipping batch.")
                    continue
                logits, ordered_ids_stack = self._get_logits_for_all_modules_ordered(advanced_task_repr)
                if logits is None or len(ordered_ids_stack) == 0:
                    logger.warning("MLP Router training: No logits from _get_logits_for_all_modules_ordered. Skipping batch.")
                    continue
                try:
                    target_idx = ordered_ids_stack.index(target_id_str)
                except ValueError:
                    logger.error(f"Target '{target_id_str}' not in MLP Router's ordered_ids_stack ({ordered_ids_stack}).")
                    continue
                loss = criterion(logits.unsqueeze(0), torch.tensor([target_idx], device=self.device))
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    logger.error(f"MLPRouter TRAIN Ep{epoch+1}: Loss NaN/Inf!")
                    continue
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                num_batches += 1
            if num_batches > 0:
                logger.info(f"MLPRouter Ep {epoch+1} Avg Loss: {epoch_loss/num_batches:.4f}")

    def save_router_state(self, path: str):
        """Write trunk, heads, optimizer state, id maps and metadata to ``path``.

        ``input_dim_for_mlp`` is stored as well so ``load_router_state`` can
        detect a config mismatch and rebuild the trunk.
        """
        state = {
            'mlp_state_dict': self.mlp.state_dict(),
            'output_projection_layers_state_dict': {k: v.state_dict() for k, v in self.output_projection_layers.items()},
            'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer else None,
            'lora_id_to_idx': self.lora_id_to_idx,
            'idx_to_lora_id': self.idx_to_lora_id,
            'next_lora_idx': self.next_lora_idx,
            'module_metadata': self.module_metadata,
            'input_dim_for_mlp': self.input_dim_for_mlp # Save this
        }
        torch.save(state, path)
        logger.info(f"MLPRouter state saved to {path}. Saved input_dim_for_mlp: {self.input_dim_for_mlp}. Num output_projection_layers saved: {len(self.output_projection_layers)}")

    def load_router_state(self, path: str):
        """Restore trunk, heads, id maps, metadata and optimizer state from ``path``.

        When the saved ``input_dim_for_mlp`` differs from the current one, the
        trunk is rebuilt with the saved width before its weights are loaded.
        A failure to restore the optimizer state is logged and leaves a fresh
        optimizer in place. Does nothing but warn when ``path`` does not exist.
        """
        if not os.path.exists(path):
            logger.warning(f"MLPRouter: No state file found at {path}. Router not loaded.")
            return
        state = torch.load(path, map_location='cpu') # Load to CPU first
        logger.info(f"MLPRouter.load_router_state: Loading from {path}. Saved state keys: {list(state.keys())}")

        saved_input_dim = state.get('input_dim_for_mlp')
        if saved_input_dim and saved_input_dim != self.input_dim_for_mlp:
            logger.warning(f"MLPRouter: Loaded state input_dim_for_mlp ({saved_input_dim}) MISMATCHES current config expected ({self.input_dim_for_mlp}). Re-initializing MLP based on loaded dim.")
            self.input_dim_for_mlp = saved_input_dim
            self.mlp = self._build_mlp()

        self.mlp.load_state_dict(state['mlp_state_dict'])
        self.mlp.to(self.device).to(self.router_internal_dtype) # Move to device and set dtype

        self.output_projection_layers.clear()
        for module_id, proj_state_dict in state.get('output_projection_layers_state_dict', {}).items():
            # Heads are created directly (not via _on_module_added) so their saved weights can be loaded.
            self.output_projection_layers[module_id] = nn.Linear(self.hidden_dim, 1).to(self.device).to(self.router_internal_dtype)
            self.output_projection_layers[module_id].load_state_dict(proj_state_dict)

        self.lora_id_to_idx = state.get('lora_id_to_idx', {})
        self.idx_to_lora_id = state.get('idx_to_lora_id', {})
        self.next_lora_idx = state.get('next_lora_idx', 0)
        self.module_metadata = state.get('module_metadata', {})

        for lora_id_loaded in self.lora_id_to_idx.keys():
            if lora_id_loaded not in self.module_metadata:
                self.module_metadata[lora_id_loaded] = {"adapter_name": lora_id_loaded, "source": "reconstructed_from_loaded_state_mlp"}

        if len(list(self.mlp.parameters())) > 0 or len(self.output_projection_layers) > 0:
            # Rebuild the optimizer so it tracks the freshly loaded trunk and heads.
            self._setup_optimizer(self.config.router_learning_rate)
            saved_optimizer_state = state.get('optimizer_state_dict')
            if self.optimizer and saved_optimizer_state:
                try:
                    self._restore_optimizer_state(saved_optimizer_state)
                    logger.info("MLPRouter optimizer state loaded.")
                except Exception as e:
                    # Best effort: keep training possible with a fresh optimizer.
                    logger.error(f"MLPRouter: Failed to load optimizer state: {e}. Optimizer might be reset.")
                    self._setup_optimizer(self.config.router_learning_rate)
        else:
            self.optimizer = None
        logger.info(f"MLPRouter state loaded from {path}. Router has {len(self.output_projection_layers)} output projection layer(s). MLP input_dim: {self.input_dim_for_mlp}.")

## Neuromodulation

The Neuromodulation module implements the γ-Gain mechanism that dynamically modulates LoRA plasticity.

In [11]:
class NeuromodulationManager:
    """
    Neuromodulation Manager for The Adaptive Learner
    Implements the γ-Gain mechanism that dynamically modulates LoRA plasticity

    Keeps, per task id, the list of raw metric dicts it was given and the list
    of γ-Gain values it computed from them.
    """
    def __init__(self, config: 'AdaptiveLearnerConfig'):
        """Store the config and create ``<output_dir>/gamma_metrics`` on disk."""
        self.config = config
        self.device = torch.device(config.device)
        self.metrics_history = {}      # task_id -> list of metric dicts, one per call
        self.gamma_gain_history = {}   # task_id -> list of computed γ-Gain values
        self.metrics_dir = os.path.join(config.output_dir, "gamma_metrics")
        os.makedirs(self.metrics_dir, exist_ok=True)

    @staticmethod
    def _default_normalization(metric_name: str, value: float) -> float:
        """Map a metric to a nominal [0, 1] score when there is no history to scale against."""
        if metric_name == "accuracy":
            return value
        if metric_name == "nll":
            return max(0, min(1, 1 - value / 10))
        if metric_name == "gradient_norm":
            return max(0, min(1, value / 10))
        if metric_name == "entropy":
            return max(0, min(1, 1 - value / 5))
        return value # Should not happen if metrics are in gamma_gain_metrics

    def compute_gamma_gain(self, task_id: str, metrics: Dict[str, float]) -> float:
        """Record ``metrics`` for ``task_id`` and return the resulting γ-Gain.

        Each metric named in ``config.gamma_gain_metrics`` is normalized, the
        normalized values are averaged with ``config.gamma_gain_weights``, and
        the result is scaled by ``config.gamma_gain_lambda``. If no configured
        metric is present, the pre-scaling value defaults to 0.5.

        Args:
            task_id: Key under which the histories are kept.
            metrics: Metric name -> value for the current step.

        Returns:
            The γ-Gain for this step (also appended to ``gamma_gain_history``).
        """
        if task_id not in self.metrics_history:
            self.metrics_history[task_id] = []
            self.gamma_gain_history[task_id] = []

        self.metrics_history[task_id].append(metrics)
        earlier_steps = self.metrics_history[task_id][:-1]
        normalized_metrics = {}
        for metric_name in self.config.gamma_gain_metrics:
            if metric_name not in metrics:
                logger.warning(f"Metric {metric_name} not found in metrics dict for gamma_gain")
                continue
            value = metrics[metric_name]
            # Only real observations take part in min-max scaling. Substituting 0
            # for steps that never reported this metric would stretch the range
            # and distort the normalized value.
            prev_values = [step[metric_name] for step in earlier_steps if metric_name in step]
            if prev_values:
                min_val, max_val = min(prev_values + [value]), max(prev_values + [value])
                # NOTE(review): unlike the first-step defaults, this branch does not
                # invert "nll"/"entropy" (higher value -> higher score) — confirm intent.
                normalized = (value - min_val) / (max_val - min_val) if max_val > min_val else 0.5
            else: # Default normalization for the first observation of this metric
                normalized = self._default_normalization(metric_name, value)
            normalized_metrics[metric_name] = normalized

        gamma_gain = 0.0
        total_weight = 0.0
        for metric_name, weight in self.config.gamma_gain_weights.items():
            if metric_name in normalized_metrics:
                gamma_gain += weight * normalized_metrics[metric_name]
                total_weight += weight

        if total_weight > 0:
            gamma_gain /= total_weight
        else:
            gamma_gain = 0.5 # Default if no relevant metrics

        gamma_gain = gamma_gain * self.config.gamma_gain_lambda
        self.gamma_gain_history[task_id].append(gamma_gain)

        logger.info(f"Task {task_id} NM - Metrics: { {k: round(v, 3) if isinstance(v, float) else v for k,v in metrics.items()} }, γ-Gain: {gamma_gain:.4f}")

        return gamma_gain

    def save_metrics(self, task_id: str):
        """Dump the metric and γ-Gain histories of ``task_id`` as JSON under ``metrics_dir``."""
        if task_id not in self.metrics_history:
            logger.warning(f"No metrics history found for task {task_id}")
            return
        task_dir = os.path.join(self.metrics_dir, task_id)
        os.makedirs(task_dir, exist_ok=True)
        with open(os.path.join(task_dir, "metrics_history.json"), "w") as f:
            json.dump(self.metrics_history[task_id], f, indent=2)
        with open(os.path.join(task_dir, "gamma_gain_history.json"), "w") as f:
            json.dump(self.gamma_gain_history[task_id], f, indent=2)
        logger.info(f"Saved metrics and γ-Gain history for task {task_id}")

    def plot_gamma_gain(self, task_id: str, save_path: Optional[str] = None) -> 'plt.Figure':
        """Plot the γ-Gain history of ``task_id`` and return the figure.

        When there is no history, a placeholder figure with a message is
        returned and nothing is saved.

        Args:
            task_id: Task whose history is plotted.
            save_path: If given, the figure is also written to this path.
        """
        if task_id not in self.gamma_gain_history or not self.gamma_gain_history[task_id]:
            logger.warning(f"No γ-Gain history found for task {task_id} to plot.")
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.text(0.5, 0.5, f"No γ-Gain history for task {task_id}", ha="center", va="center", fontsize=12)
            ax.set_axis_off()
            return fig

        fig, ax = plt.subplots(figsize=(10, 6))
        gamma_values = self.gamma_gain_history[task_id]
        steps = list(range(1, len(gamma_values) + 1))
        ax.plot(steps, gamma_values, 'o-', linewidth=2, markersize=8, label=f"Task {task_id} γ-Gain")
        ax.axhline(y=self.config.gamma_gain_lambda, color='r', linestyle='--', alpha=0.7, label=f"λ = {self.config.gamma_gain_lambda}")
        ax.set_xlabel("Training Step within Task", fontsize=12)
        ax.set_ylabel("γ-Gain Value", fontsize=12)
        ax.set_title(f"γ-Gain History for Task {task_id}", fontsize=14)
        ax.grid(alpha=0.3)
        ax.legend(fontsize=10)
        y_min = 0
        y_max = max(1.0, self.config.gamma_gain_lambda * 1.2, max(gamma_values) * 1.2 if gamma_values else 0)
        ax.set_ylim(y_min, y_max)
        if save_path:
            # Save this figure explicitly instead of whatever pyplot considers current.
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

## Consolidation

In [12]:
# Cell 11: ConsolidationManager (Corrected _consolidate_aflora)

import os
import time
import torch
import torch.nn as nn # Ensure nn is imported
import torch.nn.functional as F # Ensure F is imported
from peft import PeftModel # Ensure PeftModel is imported
from typing import List, Dict, Any, Tuple, Optional, Union # Ensure these are imported

# Ensure 'logger' is defined globally (e.g., Cell 3 or 4)
try:
    logger
except NameError:
    # No logger was set up by an earlier cell: create a local one that writes
    # INFO-level messages to the default stream.
    import logging
    logger = logging.getLogger(f"{__name__}.consolidation_mgr_cell11")
    if not logger.hasHandlers():
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

# Ensure AdaptiveLearnerConfig is defined or imported if type hinting is used strictly
# from your_config_module import AdaptiveLearnerConfig # Or ensure it's defined before this cell

class ConsolidationManager:
    """
    Manages the consolidation of LoRA parameters to prevent catastrophic forgetting.

    The strategy is selected by ``config.consolidation_method``:

    * ``"aflora"`` - accumulates ``gamma_gain * |param| * |grad|`` per LoRA
      parameter and zeroes the gradient of every element whose accumulated
      importance exceeds ``config.aflora_importance_threshold``.
    * ``"ewc"`` - accumulates a diagonal Fisher estimate per LoRA parameter and
      adds ``ewc_lambda * importance * (param - anchor)`` to the gradient, where
      ``anchor`` is the parameter value remembered from an earlier consolidation.
    * ``"none"`` - does nothing besides recording the call.

    Both active strategies only edit ``param.grad``; they take effect through
    whatever optimizer step the caller performs afterwards.

    Parameters are treated as "LoRA parameters" when their name contains
    ``"lora_"`` and ``requires_grad`` is true.
    """

    def __init__(self, config: 'AdaptiveLearnerConfig'):  # string annotation: the config class lives in another cell
        """
        Args:
            config: Object providing at least ``device``, ``output_dir`` and
                ``consolidation_method`` (plus the method-specific settings read
                by the private helpers). ``<output_dir>/consolidation`` is
                created if it does not exist.
        """
        self.config = config
        self.device = torch.device(config.device)
        # {lora_id: {param_name: accumulated importance tensor}}
        self.importance_scores: Dict[str, Dict[str, torch.Tensor]] = {}
        # {lora_id: {param_name: parameter snapshot the EWC penalty pulls towards}}
        self.ewc_anchor_params: Dict[str, Dict[str, torch.Tensor]] = {}
        # {lora_id: [{"timestamp", "gamma_gain", "method"}, ...]}
        self.consolidation_history: Dict[str, List[Dict[str, Any]]] = {}
        self.consolidation_dir = os.path.join(config.output_dir, "consolidation")
        os.makedirs(self.consolidation_dir, exist_ok=True)

    def consolidate(self, lora_id: str, model: 'PeftModel', gamma_gain: float, data_loader: Optional[Any] = None) -> None:
        """
        Record a consolidation event for ``lora_id`` and run the configured method.

        Args:
            lora_id: Key under which importance scores and history are stored.
            model: Model whose ``named_parameters()`` contain the LoRA weights.
            gamma_gain: Scalar weight applied to newly accumulated importance.
            data_loader: Iterable of batches used for the Fisher estimate (EWC
                only). EWC is skipped with a warning when it is ``None`` while
                ``config.ewc_data_loader_num_samples > 0``.
        """
        method = self.config.consolidation_method
        self.consolidation_history.setdefault(lora_id, []).append(
            {"timestamp": time.time(), "gamma_gain": gamma_gain, "method": method}
        )

        if method == "aflora":
            self._consolidate_aflora(lora_id, model, gamma_gain)
        elif method == "ewc":
            if data_loader is None and self.config.ewc_data_loader_num_samples > 0:
                logger.warning(f"EWC consolidation for {lora_id} requires a data_loader (ewc_data_loader_num_samples > 0). Skipping.")
                return
            self._consolidate_ewc(lora_id, model, gamma_gain, data_loader)
        elif method == "none":
            logger.info(f"Consolidation 'none' for {lora_id}, skipping.")
        else:
            logger.warning(f"Unknown consolidation method: {self.config.consolidation_method} for {lora_id}")

    def _consolidate_aflora(self, lora_id: str, model: 'PeftModel', gamma_gain: float) -> None:
        """
        Accumulate ``gamma_gain * |param| * |grad|`` for every LoRA parameter that
        currently has a gradient, then zero the gradient of each element whose
        accumulated importance is above ``config.aflora_importance_threshold``.
        """
        logger.info(f"Applying AFLoRA for {lora_id} with γ-Gain {gamma_gain:.4f}")
        scores = self.importance_scores.setdefault(lora_id, {})

        lora_params_with_grads = {
            name: param
            for name, param in model.named_parameters()
            if "lora_" in name and param.requires_grad and param.grad is not None
        }
        if not lora_params_with_grads:
            logger.warning(f"AFLoRA: No LoRA params with grads for {lora_id}.")
            return

        for name, param in lora_params_with_grads.items():
            importance = gamma_gain * torch.abs(param.detach()) * torch.abs(param.grad.detach())
            if name in scores:
                scores[name] += importance
            else:
                scores[name] = importance
            # Elements above the threshold get mask == 1 and lose their gradient.
            mask = (scores[name] > self.config.aflora_importance_threshold).float()
            param.grad = param.grad * (1 - mask)
        logger.info(f"AFLoRA applied for module {lora_id}")

    def _consolidate_ewc(self, lora_id: str, model: 'PeftModel', gamma_gain: float, data_loader: Optional[Any]) -> None:
        """
        Fold a fresh Fisher estimate into the importance scores (when a data
        loader is available) and add the EWC penalty gradient to ``param.grad``.

        The penalty is measured against anchors kept in ``self.ewc_anchor_params``.
        An anchor is taken the first time a parameter is seen and refreshed each
        time new Fisher information is accumulated for it, after the penalty for
        the current call has been applied. On the very first call the anchor
        equals the current value, so the penalty is zero.

        When ``config.ewc_fixed_lambda_bypass_gamma`` is truthy the Fisher values
        are accumulated with weight 1.0 instead of ``gamma_gain``.
        """
        logger.info(f"Applying EWC for {lora_id} with actual gamma_gain_param_val {gamma_gain:.4f}")

        bypass_gamma_for_fisher_accumulation = getattr(self.config, 'ewc_fixed_lambda_bypass_gamma', False)
        if bypass_gamma_for_fisher_accumulation:
            logger.info(f"EWC (Fixed Lambda Mode): Bypassing gamma_gain for Fisher accumulation. Effective gamma_gain for accumulation set to 1.0.")
            effective_gamma_for_accumulation = 1.0
        else:
            effective_gamma_for_accumulation = gamma_gain

        lora_params_to_penalize = {
            name: param
            for name, param in model.named_parameters()
            if "lora_" in name and param.requires_grad
        }
        if not lora_params_to_penalize:
            logger.warning(f"EWC: No trainable LoRA params found for {lora_id}.")
            return

        fisher_diag: Dict[str, torch.Tensor] = {}
        if data_loader is not None and self.config.ewc_data_loader_num_samples > 0:
            fisher_diag = self._compute_fisher_matrix(model, lora_params_to_penalize, data_loader)
            if not fisher_diag:
                logger.warning(f"EWC: Fisher matrix computation failed or empty for {lora_id} using data_loader. EWC penalty might be ineffective if no prior importance.")
        elif self.config.ewc_data_loader_num_samples == 0:
            logger.info(f"EWC: ewc_data_loader_num_samples is 0 for {lora_id}. Fisher matrix not computed in this step. Relying on existing importance scores if any.")
        else:
            logger.warning(f"EWC: data_loader is None for {lora_id} but ewc_data_loader_num_samples ({self.config.ewc_data_loader_num_samples}) > 0. Fisher matrix not computed. EWC penalty might be ineffective.")

        scores = self.importance_scores.setdefault(lora_id, {})
        for name, f_val in fisher_diag.items():
            weighted_f_val = effective_gamma_for_accumulation * f_val
            if name in scores:
                scores[name] += weighted_f_val
            else:
                scores[name] = weighted_f_val

        anchors = self.ewc_anchor_params.setdefault(lora_id, {})
        num_params_penalized = 0
        for name, param in lora_params_to_penalize.items():
            # The anchor must come from an EARLIER call. Snapshotting and
            # comparing within the same call would always give delta == 0.
            if name not in anchors:
                anchors[name] = param.detach().clone()
            if name in scores and scores[name].sum() > 0:
                if param.grad is None:
                    param.grad = torch.zeros_like(param)
                delta = param.detach() - anchors[name].to(param.device)
                current_importance = scores[name].to(delta.device)
                param.grad += self.config.ewc_lambda * current_importance * delta
                num_params_penalized += 1
            if name in fisher_diag:
                # New Fisher information was estimated at the current values,
                # so they become the reference point for later calls.
                anchors[name] = param.detach().clone()

        if num_params_penalized > 0:
            logger.info(f"EWC penalty gradient added for {num_params_penalized} parameters in module {lora_id}. Bypass gamma for Fisher accum: {bypass_gamma_for_fisher_accumulation}")
        elif fisher_diag:
            logger.info(f"EWC: Fisher was computed for {lora_id}, but no EWC penalty applied (e.g. importance scores might be zero or params not in list).")
        elif not scores or all(v.sum() == 0 for v in scores.values()):
            logger.info(f"EWC: No Fisher computed and no effective prior importance scores for {lora_id}. No EWC penalty applied.")

    def _compute_fisher_matrix(self, model: 'PeftModel', lora_params: Dict[str, nn.Parameter], data_loader: Any) -> Dict[str, torch.Tensor]:
        """
        Estimate a diagonal Fisher matrix for ``lora_params``.

        For each batch, the summed log-likelihood of the non-ignored target
        tokens (label != -100, next-token shifted) is back-propagated and the
        squared gradients are accumulated. The result is divided by the number
        of sequences in successfully processed batches.

        Batches are expected to be mappings with ``"input_ids"``,
        ``"attention_mask"`` and ``"labels"`` tensors. Iteration stops once
        ``config.ewc_data_loader_num_samples`` sequences have been visited.

        The model's train/eval mode and every parameter's ``.grad`` are restored
        before returning, so the estimate does not disturb gradients the caller
        has already computed.

        Returns:
            ``{param_name: fisher_tensor}``, or ``{}`` when no batch could be used.
        """
        fisher_matrices = {name: torch.zeros_like(param.data) for name, param in lora_params.items()}
        max_samples_for_fisher = self.config.ewc_data_loader_num_samples
        num_samples_processed = 0
        samples_to_process_count = 0

        was_training = getattr(model, "training", False)
        saved_grads = {
            name: param.grad.detach().clone()
            for name, param in model.named_parameters()
            if param.grad is not None
        }
        model.eval()
        try:
            for batch_idx, batch in enumerate(data_loader):
                if samples_to_process_count >= max_samples_for_fisher:
                    break

                model.zero_grad()
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                # Position t predicts token t + 1.
                shift_logits = outputs.logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                flat_shift_labels = shift_labels.view(-1)
                active_loss_mask = flat_shift_labels != -100
                active_logits = flat_shift_logits[active_loss_mask]
                active_labels = flat_shift_labels[active_loss_mask]

                if active_logits.numel() == 0:
                    logger.warning(f"EWC Fisher: Batch {batch_idx} had no active target tokens. Skipping.")
                    continue

                log_probs_of_targets = F.log_softmax(active_logits, dim=-1)
                row_index = torch.arange(active_labels.size(0), device=active_labels.device)
                log_likelihood_sum = log_probs_of_targets[row_index, active_labels].sum()

                if not torch.isnan(log_likelihood_sum) and not torch.isinf(log_likelihood_sum):
                    log_likelihood_sum.backward()
                    for name, param in lora_params.items():
                        if param.grad is not None:
                            fisher_matrices[name] += param.grad.data.pow(2)
                    num_samples_processed += input_ids.size(0)
                else:
                    logger.warning(f"EWC Fisher: log_likelihood_sum was NaN/Inf for batch {batch_idx}. Skipping backward for this batch.")
                samples_to_process_count += input_ids.size(0)
        finally:
            # Put back exactly what the caller had: saved gradients where they
            # existed, None elsewhere, and the original train/eval mode.
            for name, param in model.named_parameters():
                param.grad = saved_grads.get(name)
            if was_training:
                model.train()

        if num_samples_processed == 0:
            logger.warning("EWC Fisher: No samples successfully processed for Fisher. Returning empty Fisher matrix.")
            return {}

        fisher_matrices = {name: f_val / num_samples_processed for name, f_val in fisher_matrices.items()}
        logger.info(f"Computed Fisher matrix over {num_samples_processed} samples (target was {max_samples_for_fisher}).")
        return fisher_matrices

    def get_importance_scores(self, lora_id: str) -> Dict[str, torch.Tensor]:
        """Return the accumulated importance tensors for ``lora_id`` (``{}`` if none)."""
        return self.importance_scores.get(lora_id, {})

## Cell 14: Generative Replay Manager, CMTReplay, PCGRReplay

In [13]:
# Cell 14: Generative Replay Manager, CMTReplay, PCGRReplay (Updated for replay_backbone_encoding_batch_size)
from transformers import get_linear_schedule_with_warmup # Ensure this import is present
import torch # Ensure torch is imported
import torch.nn as nn # Ensure nn is imported
import torch.nn.functional as F # Ensure F is imported
import numpy as np # Ensure numpy is imported
from tqdm.notebook import tqdm # Ensure tqdm is imported

class GenerativeReplayManager:
    """
    Keeps a per-task buffer of raw ``{"input": ..., "output": ...}`` samples and
    delegates training / generation to a replay model chosen by
    ``config.replay_method`` (``"cmt"``, ``"pcgr"`` or ``"none"``).
    """

    def __init__(self, config: "AdaptiveLearnerConfig", backbone_model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
        """
        Args:
            config: Provides ``device``, ``replay_method``, ``output_dir`` and the
                buffer / replay settings read by the other methods.
            backbone_model: Model handed to the replay model.
            tokenizer: Tokenizer handed to the replay model.

        Creates ``<output_dir>/replay`` if it does not exist.
        """
        self.config = config
        self.backbone_model = backbone_model
        self.tokenizer = tokenizer
        self.device = torch.device(config.device)
        # {task_id: [raw sample dicts]}
        self.replay_buffer: Dict[str, List[Dict[str, str]]] = {}

        if config.replay_method == "cmt":
            self.replay_model = CMTReplay(config, self.backbone_model, self.tokenizer)
        elif config.replay_method == "pcgr":
            self.replay_model = PCGRReplay(config, self.backbone_model, self.tokenizer)
        elif config.replay_method == "none":
            self.replay_model = None
            logger.info("Generative replay disabled.")
        else:
            self.replay_model = None
            logger.warning(f"Unknown replay method: {config.replay_method}, disabling replay.")

        self.replay_dir = os.path.join(config.output_dir, "replay")
        os.makedirs(self.replay_dir, exist_ok=True)

    def add_task_samples(self, task_id: str, raw_samples: List[Dict[str, str]], train_replay_model_now: bool = True) -> None:
        """
        Merge ``raw_samples`` into the buffer for ``task_id``.

        Entries that are not dicts with both ``'input'`` and ``'output'`` keys
        are dropped (with a warning). When the merged list exceeds
        ``config.replay_buffer_size`` it is reduced by reservoir sampling.
        If ``train_replay_model_now`` is true, a replay model exists, the buffer
        is non-empty and either replay alpha is positive, the replay model is
        (re)trained on the task's buffer.
        """
        current_samples_in_buffer = self.replay_buffer.setdefault(task_id, [])
        valid_raw_samples = [rs for rs in raw_samples if isinstance(rs, dict) and 'input' in rs and 'output' in rs]
        if len(valid_raw_samples) != len(raw_samples):
            logger.warning("ReplayManager: Some samples excluded due to format.")

        all_raw_samples_for_task = current_samples_in_buffer + valid_raw_samples
        if len(all_raw_samples_for_task) > self.config.replay_buffer_size:
            self.replay_buffer[task_id] = self._reservoir_sample(all_raw_samples_for_task, self.config.replay_buffer_size)
        else:
            self.replay_buffer[task_id] = all_raw_samples_for_task

        replay_enabled = self.config.replay_alpha > 0 or self.config.feature_replay_alpha > 0
        if train_replay_model_now and self.replay_model is not None and self.replay_buffer[task_id] and replay_enabled:
            # Keep the replay model pointed at the manager's current backbone.
            if hasattr(self.replay_model, 'backbone_model') and isinstance(self.backbone_model, PreTrainedModel):
                self.replay_model.backbone_model = self.backbone_model
            self.train_replay_model(task_id, self.replay_buffer[task_id])

    def get_raw_samples_for_replay(self, task_id: str, num_samples: int) -> List[Dict[str, str]]:
        """Return up to ``num_samples`` buffered samples for ``task_id``, drawn without replacement."""
        buffered_samples = self.replay_buffer.get(task_id)
        if not buffered_samples:
            return []
        if num_samples >= len(buffered_samples):
            return list(buffered_samples)
        indices = np.random.choice(len(buffered_samples), num_samples, replace=False)
        return [buffered_samples[i] for i in indices]

    def generate_replay_units(self, task_id_to_generate_for: str, num_units: int) -> List[Dict[str, Any]]:
        """
        Ask the replay model for ``num_units`` generated units.

        Returns ``[]`` when there is no replay model, when it has no
        ``task_prototypes`` attribute, or when no prototype exists for the task.
        """
        if self.replay_model is None:
            return []
        if not hasattr(self.replay_model, 'task_prototypes'):
            logger.warning("ReplayManager: Replay model missing 'task_prototypes'.")
            return []
        if task_id_to_generate_for not in self.replay_model.task_prototypes:
            return []
        return self.replay_model.generate_samples(task_id_to_generate_for, num_units)

    def train_replay_model(self, task_id: str, raw_samples_for_task: List[Dict[str, str]]) -> None:
        """Train the replay model on ``raw_samples_for_task``; no-op without a model or samples."""
        if self.replay_model is None or not raw_samples_for_task:
            return
        self.replay_model.train_model(task_id, raw_samples_for_task)

    def get_tasks_with_replay_data(self, exclude_task_id: Optional[str] = None) -> List[str]:
        """
        List task ids that can be replayed, optionally excluding one.

        With a replay model that exposes non-empty ``task_prototypes``, a task
        qualifies when it has both a prototype and a non-empty buffer. Without a
        replay model (or without that attribute) any task with a non-empty
        buffer qualifies. A replay model whose ``task_prototypes`` is empty
        yields ``[]``.
        """
        valid_replay_tasks = []
        has_prototype_store = bool(self.replay_model) and hasattr(self.replay_model, 'task_prototypes')
        if has_prototype_store and self.replay_model.task_prototypes:
            for task_id_with_prototype in self.replay_model.task_prototypes.keys():
                if exclude_task_id is not None and task_id_with_prototype == exclude_task_id:
                    continue
                if self.replay_buffer.get(task_id_with_prototype):
                    valid_replay_tasks.append(task_id_with_prototype)
        elif not has_prototype_store:
            for task_id_in_buffer, buffered in self.replay_buffer.items():
                if buffered and (exclude_task_id is None or task_id_in_buffer != exclude_task_id):
                    valid_replay_tasks.append(task_id_in_buffer)
        return list(set(valid_replay_tasks))

    def _reservoir_sample(self, samples: List[Any], n: int) -> List[Any]:
        """
        Pick ``n`` items from ``samples`` with reservoir sampling (Algorithm R).

        Returns ``[]`` for empty input or ``n <= 0`` and ``samples`` itself when
        it already fits. The input list is not modified.
        """
        if not samples or n <= 0:
            return []
        if len(samples) <= n:
            return samples
        reservoir = samples[:n]
        for i in range(n, len(samples)):
            # Item i replaces a random slot with probability n / (i + 1).
            # The replacement must happen inside the loop, once per item.
            j = np.random.randint(0, i + 1)
            if j < n:
                reservoir[j] = samples[i]
        return reservoir

class CMTReplay:
    """
    Feature-level replay model: an MLP autoencoder trained on mean-pooled
    last-layer hidden states of the backbone.

    After training on a task, the mean encoder output over that task's samples
    is stored as the task "prototype". Replay units are produced by decoding the
    prototype plus Gaussian noise back to hidden-state space.

    The autoencoder always runs in float32 on ``config.device``; pooled hidden
    states and generated units are returned in the backbone's dtype.
    """

    # True until the first train_model() call on any instance; gates a one-off debug log.
    _first_train_call_ever = True

    def __init__(self, config: "AdaptiveLearnerConfig", backbone_model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
        """
        Args:
            config: Provides ``device``, ``cmt_memory_size``,
                ``cmt_compressor_layers``, the ``replay_model_*`` optimizer
                settings and the batch-size / sequence-length settings used by
                the other methods.
            backbone_model: Model used to produce hidden states. Its first
                parameter's dtype and ``config.hidden_size`` size the
                autoencoder; float16 / 2048 are used as fallbacks.
            tokenizer: Callable tokenizer used by ``_tokenize_batch_for_encode``.
        """
        self.config = config
        self.backbone_model = backbone_model
        self.tokenizer = tokenizer
        self.device = torch.device(config.device)
        try:
            self.model_dtype = next(backbone_model.parameters()).dtype
            self.hidden_dim = backbone_model.config.hidden_size
        except (StopIteration, AttributeError) as e:
            logger.warning(f"CMTReplay: Model issue ({e}). Defaults.")
            self.model_dtype = torch.float16
            self.hidden_dim = getattr(backbone_model.config, "hidden_size", 2048)
        self.memory_size = config.cmt_memory_size
        self.cmt_internal_dtype = torch.float32

        # Encoder first, then decoder, so weight initialisation order is stable.
        self.encoder = self._build_mlp(self.hidden_dim, self.memory_size)
        self.decoder = self._build_mlp(self.memory_size, self.hidden_dim)
        self.optimizer = self._make_optimizer()
        # {task_id: mean encoder output over the task's training samples}
        self.task_prototypes: Dict[str, torch.Tensor] = {}

    def _build_mlp(self, in_dim: int, out_dim: int) -> nn.Sequential:
        """Stack ``cmt_compressor_layers`` Linear layers from ``in_dim`` to ``out_dim``; every layer but the last is followed by LayerNorm + GELU."""
        num_layers = self.config.cmt_compressor_layers
        layers = []
        current_dim = in_dim
        for i in range(num_layers):
            is_last = i == num_layers - 1
            # Intermediate widths step halfway towards the target width.
            next_dim = out_dim if is_last else max(1, (current_dim + out_dim) // 2)
            layers.append(nn.Linear(current_dim, next_dim))
            if not is_last:
                layers.append(nn.LayerNorm(next_dim, eps=1e-5))
                layers.append(nn.GELU())
            current_dim = next_dim
        return nn.Sequential(*layers).to(self.device).to(self.cmt_internal_dtype)

    def _ae_parameters(self) -> list:
        """All trainable parameters of the encoder and decoder."""
        return list(self.encoder.parameters()) + list(self.decoder.parameters())

    def _make_optimizer(self) -> torch.optim.Optimizer:
        """Fresh AdamW over the autoencoder parameters, configured from ``config``."""
        return torch.optim.AdamW(
            self._ae_parameters(),
            lr=self.config.replay_model_learning_rate,
            weight_decay=self.config.replay_model_weight_decay,
        )

    @staticmethod
    def _has_nan_or_inf(tensor: torch.Tensor) -> bool:
        """True when ``tensor`` contains any NaN or +/-Inf value."""
        return bool(torch.isnan(tensor).any() or torch.isinf(tensor).any())

    def _tokenize_batch_for_encode(self, text_inputs: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize ``text_inputs`` to padded tensors, truncated to ``config.max_seq_length``."""
        return self.tokenizer(text_inputs, return_tensors="pt", padding="longest", truncation=True, max_length=self.config.max_seq_length, add_special_tokens=True)

    def encode_input_batch(self, tokenized_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Run the backbone on ``tokenized_batch`` and return the attention-masked
        mean of its last hidden state, one row per sequence, in ``model_dtype``.

        Pooling is done in float32. In half precision the ``1e-9`` clamp
        underflows to zero and long-sequence sums can overflow, which would
        otherwise turn usable batches into NaN/Inf.

        On a backbone error, or if the pooled result still contains NaN/Inf, an
        all-zero tensor is returned and the problem is logged. The backbone's
        train/eval mode is restored before returning.
        """
        target_device = next(self.backbone_model.parameters()).device
        inputs_on_device = {k: v.to(target_device) for k, v in tokenized_batch.items() if torch.is_tensor(v)}
        original_mode = self.backbone_model.training
        if hasattr(self.backbone_model, 'eval'):
            self.backbone_model.eval()
        batch_size = inputs_on_device['input_ids'].shape[0]
        # Placeholder returned as-is when the forward pass fails.
        pooled_output = torch.zeros(batch_size, self.hidden_dim, device=target_device, dtype=self.model_dtype)
        try:
            with torch.no_grad():
                outputs = self.backbone_model(**inputs_on_device, output_hidden_states=True)
                last_hidden_state = outputs.hidden_states[-1].to(torch.float32)
                attention_mask = inputs_on_device.get('attention_mask')
                if attention_mask is None:
                    attention_mask = torch.ones_like(last_hidden_state[..., 0])
                mask_expanded = attention_mask.unsqueeze(-1).to(torch.float32)
                sum_masked_hidden = (last_hidden_state * mask_expanded).sum(dim=1)
                sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
                pooled_output = sum_masked_hidden / sum_mask
        except Exception as e:
            logger.error(f"CMTReplay.encode_input_batch: Error: {e}")
        finally:
            if hasattr(self.backbone_model, 'train') and original_mode:
                self.backbone_model.train()
        if self._has_nan_or_inf(pooled_output):
            logger.error(f"CMTReplay.encode_input_batch produced NaN/Inf! Shape: {pooled_output.shape}")
            return torch.zeros_like(pooled_output, dtype=self.model_dtype)
        return pooled_output.to(self.model_dtype)

    def train_model(self, task_id: str, raw_text_samples_for_task_buffer: List[Dict[str, str]]) -> None:
        """
        Train the autoencoder for 3 epochs on the pooled hidden states of the
        samples' ``'input'`` texts, then store the task prototype.

        The loss is MSE reconstruction plus, when
        ``config.cmt_uniformity_weight > 0`` and the batch has more than one
        row, ``-mean(log(pairwise distance))`` over the L2-normalised codes.
        Batches that produce NaN/Inf at any stage are skipped. The optimizer is
        recreated on every call.
        """
        if not raw_text_samples_for_task_buffer:
            return
        logger.info(f"CMTReplay: Training AE for task '{task_id}' on {len(raw_text_samples_for_task_buffer)} samples from buffer.")
        num_epochs = 3
        # Non-positive batch sizes would break range() / the tqdm heuristic below.
        internal_ae_batch_size = max(1, self.config.replay_model_internal_batch_size)
        backbone_encoding_batch_size = max(1, self.config.replay_backbone_encoding_batch_size)

        self.optimizer = self._make_optimizer()
        self.encoder.train().to(self.cmt_internal_dtype)
        self.decoder.train().to(self.cmt_internal_dtype)
        log_this_call_weights = CMTReplay._first_train_call_ever
        CMTReplay._first_train_call_ever = False

        # Encode all inputs with the backbone once, in small batches.
        all_input_texts = [sample['input'] for sample in raw_text_samples_for_task_buffer]
        all_original_hidden_states_list = []
        for start in range(0, len(all_input_texts), backbone_encoding_batch_size):
            batch_texts = all_input_texts[start:start + backbone_encoding_batch_size]
            tokenized_batch = self._tokenize_batch_for_encode(batch_texts)
            hidden_states_batch = self.encode_input_batch(tokenized_batch)
            if not self._has_nan_or_inf(hidden_states_batch):
                all_original_hidden_states_list.append(hidden_states_batch)

        if not all_original_hidden_states_list:
            logger.warning(f"CMTReplay: No valid hidden states for task '{task_id}'. Skipping AE.")
            return
        # The backbone may sit on a different device than the autoencoder.
        all_original_hidden_states = torch.cat(all_original_hidden_states_list, dim=0).to(device=self.device, dtype=self.cmt_internal_dtype)
        num_rows = all_original_hidden_states.size(0)

        total_loss_all_epochs, total_uniformity_all_epochs, total_batches_all_epochs = 0.0, 0.0, 0
        for epoch in range(num_epochs):
            epoch_loss, epoch_uniformity, num_actual_batches_epoch = 0.0, 0.0, 0
            permuted_indices = torch.randperm(num_rows)
            batch_iterable = range(0, num_rows, internal_ae_batch_size)
            if (num_rows / internal_ae_batch_size) > 5:  # progress bar only for longer runs
                batch_iterable = tqdm(batch_iterable, desc=f"CMT AE Ep{epoch+1} Task:{task_id}", leave=False)

            for i_batch_ae, start_idx_ae in enumerate(batch_iterable):
                end_idx_ae = min(start_idx_ae + internal_ae_batch_size, num_rows)
                batch_indices = permuted_indices[start_idx_ae:end_idx_ae]
                if len(batch_indices) == 0:
                    continue
                batched_original_hidden_for_ae = all_original_hidden_states[batch_indices]

                self.optimizer.zero_grad()
                memory_codes = self.encoder(batched_original_hidden_for_ae)
                if self._has_nan_or_inf(memory_codes):
                    logger.error(f"CMT AE TRAIN: memory_codes NaN/Inf!")
                    continue
                reconstructed_batch = self.decoder(memory_codes)
                if self._has_nan_or_inf(reconstructed_batch):
                    logger.error(f"CMT AE TRAIN: reconstructed_batch NaN/Inf!")
                    continue

                recon_loss = F.mse_loss(reconstructed_batch, batched_original_hidden_for_ae.detach())
                # A NaN/Inf target also surfaces here as a non-finite loss.
                if not torch.isfinite(recon_loss):
                    logger.error(f"CMT AE TRAIN: recon_loss NaN/Inf or inputs problematic! Loss: NaN/Inf. Skipping batch update.")
                    continue

                batch_loss = recon_loss
                current_batch_uniformity = 0.0
                if len(memory_codes) > 1 and self.config.cmt_uniformity_weight > 0:
                    normalized_codes = F.normalize(memory_codes, p=2, dim=1)
                    pdist_matrix = torch.cdist(normalized_codes, normalized_codes, p=2.0)
                    off_diagonal = ~torch.eye(pdist_matrix.shape[0], device=pdist_matrix.device, dtype=torch.bool)
                    pair_distances = pdist_matrix[off_diagonal]
                    if pair_distances.numel() > 0:
                        uniformity_penalty = -torch.log(pair_distances.clamp(min=1e-8)).mean()
                        if torch.isfinite(uniformity_penalty):
                            batch_loss = recon_loss + self.config.cmt_uniformity_weight * uniformity_penalty
                            current_batch_uniformity = uniformity_penalty.item()
                if self._has_nan_or_inf(batch_loss):
                    logger.error(f"CMT AE TRAIN: final batch_loss NaN/Inf!")
                    continue

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self._ae_parameters(), self.config.replay_model_grad_clip_norm)
                self.optimizer.step()
                if log_this_call_weights and epoch == 0 and i_batch_ae == 0:
                    logger.info(f"CMT Initial WEIGHT DEBUG (Task {task_id}): Weights AFTER first AE optim step...")
                epoch_loss += batch_loss.item()
                epoch_uniformity += current_batch_uniformity
                num_actual_batches_epoch += 1

            total_loss_all_epochs += epoch_loss
            total_uniformity_all_epochs += epoch_uniformity
            total_batches_all_epochs += num_actual_batches_epoch

        avg_loss_overall = total_loss_all_epochs / total_batches_all_epochs if total_batches_all_epochs > 0 else float('nan')
        avg_unif_overall = total_uniformity_all_epochs / total_batches_all_epochs if total_batches_all_epochs > 0 else float('nan')
        logger.info(f"CMTReplay AE training complete for task '{task_id}' ({len(raw_text_samples_for_task_buffer)} buffer samples). AvgLoss: {avg_loss_overall:.4f}, AvgUniformity: {avg_unif_overall:.4f}, LR: {self.optimizer.param_groups[0]['lr']:.2e}")

        # Prototype = mean code over every sample that was encoded for this task.
        self.encoder.eval()
        if all_original_hidden_states.numel() > 0:
            with torch.no_grad():
                memory_codes_for_proto = self.encoder(all_original_hidden_states)
                if not self._has_nan_or_inf(memory_codes_for_proto):
                    self.task_prototypes[task_id] = torch.mean(memory_codes_for_proto, dim=0).detach()

    def generate_samples(self, task_id: str, num_samples: int) -> List[Dict[str, Any]]:
        """
        Decode ``num_samples`` noisy copies of the task prototype.

        Returns a list of dicts with ``"hidden_state"`` (CPU tensor in
        ``model_dtype``), ``"task_id"`` and ``"is_generated": True``. Returns
        ``[]`` when the task has no prototype or the prototype is NaN/Inf;
        individual decodes that come out NaN/Inf are dropped.
        """
        if task_id not in self.task_prototypes:
            return []
        self.decoder.eval().to(self.cmt_internal_dtype)
        generated_units_list = []
        with torch.no_grad():
            prototype = self.task_prototypes[task_id].to(self.device)
            if self._has_nan_or_inf(prototype):
                logger.error(f"CMTReplay gen: Proto for {task_id} NaN/Inf!")
                return []
            for _ in range(num_samples):
                memory_code_noisy = prototype + torch.randn_like(prototype) * 0.1
                hidden_reconstructed = self.decoder(memory_code_noisy)
                if self._has_nan_or_inf(hidden_reconstructed):
                    logger.error(f"CMTReplay gen: Decoded hidden NaN/Inf for {task_id}!")
                    continue
                generated_units_list.append({"hidden_state": hidden_reconstructed.cpu().to(self.model_dtype), "task_id": task_id, "is_generated": True})
        return generated_units_list

class PCGRReplay: # Updated backbone_encoding_batch_size
    """Prototype-conditioned VAE over mean-pooled backbone hidden states.

    An encoder MLP maps a pooled hidden state to the mean and log-variance of a
    latent code. A decoder MLP reconstructs the hidden state from the latent code
    concatenated with a per-task prototype (a running mean of that task's pooled
    hidden states). Both networks are kept in float32 (``pcgr_internal_dtype``)
    whatever dtype the backbone uses; public outputs are cast to ``model_dtype``.
    """

    # Class-wide flag: only the very first train_model() call emits the weight-debug log line.
    _first_train_call_ever = True

    def __init__(self, config: 'AdaptiveLearnerConfig', backbone_model: 'PreTrainedModel', tokenizer: 'PreTrainedTokenizer'):
        """Build the encoder/decoder networks and their optimizer.

        Args:
            config: Reads ``device``, ``pcgr_latent_dim``, ``cmt_compressor_layers``,
                ``replay_model_learning_rate`` and ``replay_model_weight_decay`` here;
                further fields are read by the other methods.
            backbone_model: Model whose last hidden states are pooled and modelled.
            tokenizer: Tokenizer used to turn buffered text into backbone inputs.
        """
        self.config = config
        self.backbone_model = backbone_model
        self.tokenizer = tokenizer
        self.device = torch.device(config.device)
        try:
            self.model_dtype = next(backbone_model.parameters()).dtype
            self.hidden_dim = backbone_model.config.hidden_size
        except (StopIteration, AttributeError) as e:
            logger.warning(f"PCGRReplay: Model issue ({e}). Defaults used.")
            self.model_dtype = torch.float16
            self.hidden_dim = getattr(backbone_model.config, "hidden_size", 2048)
        self.latent_dim = config.pcgr_latent_dim
        self.pcgr_internal_dtype = torch.float32

        # The number of MLP blocks is taken from the CMT compressor setting.
        num_blocks = config.cmt_compressor_layers

        # Encoder: halve the width per block; the last block emits [mean | logvar].
        enc_layers = []
        in_dim = self.hidden_dim
        for block_idx in range(num_blocks):
            is_last = block_idx == num_blocks - 1
            out_dim = (self.latent_dim * 2) if is_last else max(1, in_dim // 2)
            enc_layers.append(nn.Linear(in_dim, out_dim))
            if not is_last:
                enc_layers.append(nn.LayerNorm(out_dim, eps=1e-5))
                enc_layers.append(nn.GELU())
            in_dim = out_dim
        self.encoder_net = nn.Sequential(*enc_layers).to(self.device).to(self.pcgr_internal_dtype)

        # Decoder: input is [z | prototype]; widths never drop below hidden_dim and
        # the last block always emits hidden_dim features.
        dec_layers = []
        in_dim = self.latent_dim + self.hidden_dim
        for block_idx in range(num_blocks):
            is_last = block_idx == num_blocks - 1
            if is_last:
                out_dim = self.hidden_dim
            elif block_idx < num_blocks - 2:
                out_dim = max(self.hidden_dim, (in_dim + self.hidden_dim) // 2)
            else:
                out_dim = max(self.hidden_dim, in_dim // 2)
            dec_layers.append(nn.Linear(in_dim, out_dim))
            if not is_last:
                dec_layers.append(nn.LayerNorm(out_dim, eps=1e-5))
                dec_layers.append(nn.GELU())
            in_dim = out_dim
        self.decoder_net = nn.Sequential(*dec_layers).to(self.device).to(self.pcgr_internal_dtype)

        self.task_prototypes: Dict[str, torch.Tensor] = {}
        self.batch_counter = 0
        self.optimizer = self._build_optimizer()

    def _build_optimizer(self) -> torch.optim.Optimizer:
        """Create a fresh AdamW over encoder and decoder parameters from the config."""
        return torch.optim.AdamW(
            list(self.encoder_net.parameters()) + list(self.decoder_net.parameters()),
            lr=self.config.replay_model_learning_rate,
            weight_decay=self.config.replay_model_weight_decay,
        )

    @staticmethod
    def _has_nan_or_inf(tensor: torch.Tensor) -> bool:
        """Return True when ``tensor`` contains any NaN or Inf entry."""
        return bool(torch.isnan(tensor).any() or torch.isinf(tensor).any())

    def _tokenize_batch_for_encode(self, text_inputs: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize raw texts, padded to the longest item and truncated to ``config.max_seq_length``."""
        return self.tokenizer(text_inputs, return_tensors="pt", padding="longest", truncation=True, max_length=self.config.max_seq_length, add_special_tokens=True)

    def encode_input_batch(self, tokenized_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the backbone and return the attention-masked mean of its last hidden state.

        The backbone is switched to eval mode for the forward pass and put back
        into train mode afterwards if it was training.

        Args:
            tokenized_batch: Mapping with at least ``input_ids``; ``attention_mask``
                is used for pooling when present.

        Returns:
            Tensor of shape ``(batch, hidden_dim)`` in ``model_dtype``. If the
            backbone call raises, or the pooled result contains NaN/Inf, the
            error is logged and an all-zero tensor is returned instead.
        """
        target_device = next(self.backbone_model.parameters()).device
        inputs_on_device = {k: v.to(target_device) for k, v in tokenized_batch.items() if torch.is_tensor(v)}
        original_mode = self.backbone_model.training
        if hasattr(self.backbone_model, 'eval'): self.backbone_model.eval()
        batch_size = inputs_on_device['input_ids'].shape[0]
        pooled_output = torch.zeros(batch_size, self.hidden_dim, device=target_device, dtype=self.model_dtype)
        try:
            with torch.no_grad():
                outputs = self.backbone_model(**inputs_on_device, output_hidden_states=True)
                # Pool in float32: summing over the sequence in half precision can
                # overflow to Inf, and the 1e-9 clamp below underflows to 0 in fp16.
                last_hidden_state = outputs.hidden_states[-1].to(torch.float32)
                attention_mask = inputs_on_device.get('attention_mask')
                if attention_mask is None:
                    attention_mask = torch.ones_like(last_hidden_state[..., 0])
                mask_expanded = attention_mask.unsqueeze(-1).to(torch.float32)
                if last_hidden_state.ndim == 3 and mask_expanded.ndim == 3:
                    token_counts = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
                    pooled_output = (last_hidden_state * mask_expanded).sum(dim=1) / token_counts
                else:
                    logger.error(f"PCGRReplay.encode_input_batch: Unexpected dims. LHS: {last_hidden_state.ndim}, Mask: {mask_expanded.ndim}. Mean fallback.")
                    pooled_output = last_hidden_state.mean(dim=1) if last_hidden_state.ndim == 3 else last_hidden_state.mean(dim=0, keepdim=True if last_hidden_state.ndim == 2 else False)
        except Exception as e:
            # Best-effort: keep the zero-initialised pooled_output and report the failure.
            logger.error(f"PCGRReplay.encode_input_batch: Error: {e}")
        finally:
            # Restore training mode even if something other than Exception propagates.
            if hasattr(self.backbone_model, 'train') and original_mode: self.backbone_model.train()
        pooled_output = pooled_output.to(self.model_dtype)
        if self._has_nan_or_inf(pooled_output):
            logger.error(f"PCGRReplay.encode_input_batch NaN/Inf!")
            return torch.zeros_like(pooled_output)
        return pooled_output

    def encode(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode pooled hidden states and draw a reparameterised latent sample.

        Returns:
            ``(z, mean, logvar)``, each ``(batch, latent_dim)`` in float32. If the
            encoder output contains NaN/Inf, zeros are returned for ``z`` and
            ``mean`` and a constant -10 for ``logvar``.
        """
        h_encoded = self.encoder_net(hidden_states.to(device=self.device, dtype=self.pcgr_internal_dtype))
        if self._has_nan_or_inf(h_encoded):
            logger.error("PCGR encode: h_encoded NaN/Inf!")
            zero_l = torch.zeros(hidden_states.size(0), self.latent_dim, device=self.device, dtype=self.pcgr_internal_dtype)
            return zero_l.clone(), zero_l.clone(), torch.ones_like(zero_l) * -10
        mean, logvar = torch.chunk(h_encoded, 2, dim=1)
        # Clamp logvar before exponentiating so std stays finite.
        std = torch.exp(0.5 * logvar.clamp(min=-20, max=20))
        if self._has_nan_or_inf(std):
            logger.error("PCGR encode: std NaN/Inf!")
            std = torch.ones_like(std)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z, mean, logvar

    def _decode_to_internal(self, z: torch.Tensor, prototype: torch.Tensor) -> torch.Tensor:
        """Decode ``[z | prototype]`` and return the raw float32 decoder output.

        No NaN/Inf replacement and no cast to ``model_dtype`` happen here, so
        callers can run their own checks and compute losses at full precision.
        """
        prototype_expanded = prototype.to(device=z.device, dtype=self.pcgr_internal_dtype).unsqueeze(0).expand(z.size(0), -1)
        decoder_input_concat = torch.cat([z, prototype_expanded], dim=1)
        return self.decoder_net(decoder_input_concat)

    def decode(self, z: torch.Tensor, prototype: torch.Tensor) -> torch.Tensor:
        """Decode latent codes conditioned on a task prototype.

        Args:
            z: Latent codes of shape ``(batch, latent_dim)``.
            prototype: 1-D task prototype; it is broadcast over the batch.

        Returns:
            Reconstructed hidden states in ``model_dtype``; an all-zero tensor of
            the same shape if the decoder output contains NaN/Inf.
        """
        decoded_hidden = self._decode_to_internal(z, prototype)
        if self._has_nan_or_inf(decoded_hidden):
            logger.error("PCGR decode: decoded NaN/Inf!")
            return torch.zeros_like(decoded_hidden, dtype=self.model_dtype)
        return decoded_hidden.to(self.model_dtype)

    def update_prototype(self, task_id: str, hidden_states_for_update: torch.Tensor):
        """Set or update a task prototype from the mean of the given hidden states.

        A new task stores the mean directly; an existing prototype is blended as
        ``0.9 * old + 0.1 * mean``. Empty input or a NaN/Inf mean leaves the
        stored prototype unchanged.
        """
        if hidden_states_for_update.numel() == 0: return
        current_task_mean_hidden = hidden_states_for_update.mean(dim=0)
        if self._has_nan_or_inf(current_task_mean_hidden):
            logger.error(f"PCGR update_prototype: mean_hidden for {task_id} NaN/Inf!")
            return
        if task_id not in self.task_prototypes:
            self.task_prototypes[task_id] = current_task_mean_hidden.detach()
        else:
            self.task_prototypes[task_id] = (0.9 * self.task_prototypes[task_id] + 0.1 * current_task_mean_hidden).detach()

    def train_model(self, task_id: str, raw_text_samples_for_task_buffer: List[Dict[str, str]]) -> None:
        """Train the VAE for three epochs on the buffered samples of one task.

        The ``'input'`` text of every sample is encoded with the backbone, the
        task prototype is updated from those hidden states, and the encoder and
        decoder are optimised on reconstruction (MSE) plus
        ``config.pcgr_kl_weight`` times the KL term. The optimizer is rebuilt on
        every call. Batches that produce NaN/Inf at any stage are logged and skipped.

        Args:
            task_id: Task whose prototype conditions the decoder.
            raw_text_samples_for_task_buffer: Buffer entries; each must have an ``'input'`` key.
        """
        if not raw_text_samples_for_task_buffer: return
        logger.info(f"PCGRReplay: Training VAE for task '{task_id}' on {len(raw_text_samples_for_task_buffer)} samples from buffer.")
        num_epochs = 3
        # Guard against a zero/negative setting, which would break range() and the tqdm ratio below.
        internal_vae_batch_size = max(1, self.config.replay_model_internal_batch_size)
        self.optimizer = self._build_optimizer()
        self.encoder_net.train().to(self.pcgr_internal_dtype)
        self.decoder_net.train().to(self.pcgr_internal_dtype)
        log_this_call_weights_pcgr = PCGRReplay._first_train_call_ever
        PCGRReplay._first_train_call_ever = False

        all_input_texts = [sample['input'] for sample in raw_text_samples_for_task_buffer]
        all_original_hidden_states_list = []
        # MODIFIED: Use replay_backbone_encoding_batch_size
        backbone_encoding_batch_size = max(1, self.config.replay_backbone_encoding_batch_size)

        for i in range(0, len(all_input_texts), backbone_encoding_batch_size):
            batch_texts = all_input_texts[i:i+backbone_encoding_batch_size]
            if not batch_texts: continue
            tokenized_batch = self._tokenize_batch_for_encode(batch_texts)
            hidden_states_batch = self.encode_input_batch(tokenized_batch)
            if not self._has_nan_or_inf(hidden_states_batch):
                all_original_hidden_states_list.append(hidden_states_batch)
        if not all_original_hidden_states_list:
            logger.warning(f"PCGRReplay: No valid hidden states for task '{task_id}'. Skipping VAE train.")
            return
        all_original_hidden_states = torch.cat(all_original_hidden_states_list, dim=0)
        self.update_prototype(task_id, all_original_hidden_states.detach())
        prototype_current = self.task_prototypes.get(task_id)
        if prototype_current is None or self._has_nan_or_inf(prototype_current):
            logger.error(f"PCGR Proto for {task_id} missing or NaN/Inf after update. Cannot train VAE.")
            return
        all_original_hidden_states_float32 = all_original_hidden_states.to(device=self.device, dtype=self.pcgr_internal_dtype)
        num_states = all_original_hidden_states_float32.size(0)

        overall_loss, overall_recon, overall_kl, overall_batches = 0.0, 0.0, 0.0, 0
        for epoch in range(num_epochs):
            epoch_loss, epoch_recon, epoch_kl, num_batches_epoch = 0.0, 0.0, 0.0, 0
            permuted_indices = torch.randperm(num_states)
            use_tqdm_for_batches = (num_states / internal_vae_batch_size) > 5
            batch_iterable_pcgr = range(0, num_states, internal_vae_batch_size)
            if use_tqdm_for_batches: batch_iterable_pcgr = tqdm(batch_iterable_pcgr, desc=f"PCGR VAE Ep{epoch+1} Task:{task_id}", leave=False)

            for i_batch_vae, start_idx_vae in enumerate(batch_iterable_pcgr):
                end_idx_vae = min(start_idx_vae + internal_vae_batch_size, num_states)
                batch_indices = permuted_indices[start_idx_vae:end_idx_vae]
                if len(batch_indices) == 0: continue
                batched_hidden_originals_for_vae = all_original_hidden_states_float32[batch_indices]
                self.optimizer.zero_grad()
                z_latent, mean_latent, logvar_latent = self.encode(batched_hidden_originals_for_vae)
                if self._has_nan_or_inf(z_latent):
                    logger.error(f"PCGR VAE TRAIN: z_latent NaN/Inf!")
                    continue
                # Stay in float32 for the reconstruction: going through decode() would
                # cast to the backbone dtype (pushing gradients through half precision)
                # and would silently swap a NaN/Inf output for zeros, hiding it from
                # the check below.
                reconstructed_hidden_batch = self._decode_to_internal(z_latent, prototype_current)
                if self._has_nan_or_inf(reconstructed_hidden_batch):
                    logger.error(f"PCGR VAE TRAIN: reconstructed_hidden_batch NaN/Inf!")
                    continue
                recon_loss = F.mse_loss(reconstructed_hidden_batch, batched_hidden_originals_for_vae.detach())
                kl_div = -0.5 * torch.sum(1 + logvar_latent - mean_latent.pow(2) - logvar_latent.exp())
                kl_loss = kl_div / batched_hidden_originals_for_vae.size(0)
                if self._has_nan_or_inf(recon_loss) or self._has_nan_or_inf(kl_loss):
                    logger.error(f"PCGR VAE TRAIN: losses NaN/Inf!")
                    continue
                current_total_loss_batch = recon_loss + self.config.pcgr_kl_weight * kl_loss
                if self._has_nan_or_inf(current_total_loss_batch):
                    logger.error(f"PCGR VAE TRAIN: final batch_loss NaN/Inf!")
                    continue
                current_total_loss_batch.backward()
                torch.nn.utils.clip_grad_norm_(list(self.encoder_net.parameters()) + list(self.decoder_net.parameters()), self.config.replay_model_grad_clip_norm)
                self.optimizer.step()
                if log_this_call_weights_pcgr and epoch == 0 and i_batch_vae == 0:
                    logger.info(f"PCGR Initial WEIGHT DEBUG (Task {task_id}): Weights AFTER first VAE optim step...")
                epoch_loss += current_total_loss_batch.item()
                epoch_recon += recon_loss.item()
                epoch_kl += kl_loss.item()
                num_batches_epoch += 1
            overall_loss += epoch_loss
            overall_recon += epoch_recon
            overall_kl += epoch_kl
            overall_batches += num_batches_epoch

        avg_loss_pcgr = overall_loss / overall_batches if overall_batches > 0 else float('nan')
        avg_recon_pcgr = overall_recon / overall_batches if overall_batches > 0 else float('nan')
        avg_kl_pcgr = overall_kl / overall_batches if overall_batches > 0 else float('nan')
        logger.info(f"PCGRReplay VAE training complete for task '{task_id}' ({len(raw_text_samples_for_task_buffer)} buffer samples). AvgLoss:{avg_loss_pcgr:.4f}, AvgRecon:{avg_recon_pcgr:.4f}, AvgKL:{avg_kl_pcgr:.4f}, LR: {self.optimizer.param_groups[0]['lr']:.2e}")

    def generate_samples(self, task_id: str, num_samples: int) -> List[Dict[str, Any]]:
        """Sample latent codes from a standard normal prior and decode them for one task.

        Returns:
            One dict per successful decode with keys ``"hidden_state"`` (1-D CPU
            tensor in ``model_dtype``), ``"task_id"`` and ``"is_generated"``.
            The list is empty when the task has no prototype or the prototype
            contains NaN/Inf; decodes containing NaN/Inf are logged and skipped.
        """
        if task_id not in self.task_prototypes: return []
        self.decoder_net.eval().to(self.pcgr_internal_dtype)
        generated_units_list = []
        with torch.no_grad():
            prototype = self.task_prototypes[task_id].to(self.device)
            if self._has_nan_or_inf(prototype):
                logger.error(f"PCGRReplay gen: Proto for {task_id} NaN/Inf!")
                return []
            for _ in range(num_samples):
                z_prior = torch.randn(1, self.latent_dim, device=self.device, dtype=self.pcgr_internal_dtype)
                # Check the real decoder output (after the dtype cast, so half-precision
                # overflow is caught too); decode() would have replaced it with zeros
                # and the zero vector would be emitted as a replay sample.
                hidden_reconstructed = self._decode_to_internal(z_prior, prototype).to(self.model_dtype)
                if self._has_nan_or_inf(hidden_reconstructed):
                    logger.error(f"PCGRReplay gen: Decoded hidden NaN/Inf for {task_id}!")
                    continue
                generated_units_list.append({"hidden_state": hidden_reconstructed.cpu().squeeze(0), "task_id": task_id, "is_generated": True})
        return generated_units_list

## Main Function - Example Experiment

Here we set up a simple experiment to demonstrate the adaptive learning capabilities.

In [14]:
# Cell 16 (Main Function, load_benchmark_task, CausalLMTrainingDataset, EWCDataset - WITH SharedLoRA & EWC DataLoader fix & Adapter Name Sanitization & Early Stopping)

from rouge_score import rouge_scorer
from torch.utils.data import Dataset as TorchDataset, DataLoader as TorchDataLoader
from datasets import load_dataset, DownloadMode, concatenate_datasets
import gc
import time
import numpy as np
import torch # Ensure torch is imported for torch.tensor
import torch.nn.functional as F
from typing import List, Dict, Any, Tuple, Optional, Union
from peft import PeftModel
import matplotlib.pyplot as plt # Ensure pyplot is imported for neuromod_manager.plot_gamma_gain

# Fallback: build a stream logger for this cell when no earlier cell defined `logger`.
if 'logger' not in globals():
    import logging
    logger = logging.getLogger(f"{__name__}.main_experiment_cell16")
    # Only attach a handler when neither this logger nor an ancestor already has one.
    if not logger.hasHandlers():
        _fallback_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        _fallback_handler = logging.StreamHandler()
        _fallback_handler.setFormatter(_fallback_formatter)
        logger.setLevel(logging.INFO)
        logger.addHandler(_fallback_handler)

# --- Causal LM Training Dataset ---
class CausalLMTrainingDataset(TorchDataset):
    """Prompt/response examples tokenized for causal-LM fine-tuning.

    Each item is ``prompt + response + EOS`` fitted into ``max_seq_length``
    tokens and right-padded to exactly that length. Prompt and padding
    positions carry the label -100, so only response tokens are supervised.
    """

    def __init__(self, examples: List[Dict[str, str]], tokenizer: 'PreTrainedTokenizer', max_seq_length: int):
        """Store the examples and make sure the tokenizer has a pad token id.

        Note: when ``tokenizer.pad_token_id`` is None it is set on the shared
        tokenizer object (to the EOS id, or 0 if there is none).
        """
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 0

    def __len__(self):
        return len(self.examples)

    @staticmethod
    def _fit_to_window(prompt_ids: List[int], response_ids: List[int], max_len: int) -> Tuple[List[int], int]:
        """Trim prompt and response so that together they fit in ``max_len`` tokens.

        The response is guaranteed at least half of the window (or the space the
        prompt leaves free, if that is more); the prompt keeps its leading tokens
        in whatever room remains. This prevents a long prompt from pushing the
        whole response out of the window, which would leave no supervised token.

        Returns:
            ``(combined_ids, prompt_len)`` where ``prompt_len`` is the number of
            leading prompt tokens in ``combined_ids``.
        """
        if len(prompt_ids) + len(response_ids) <= max_len:
            return prompt_ids + response_ids, len(prompt_ids)
        if max_len <= 0:
            return [], 0
        response_budget = max(1, max_len // 2, max_len - len(prompt_ids))
        kept_response = response_ids[:response_budget]
        kept_prompt = prompt_ids[:max_len - len(kept_response)]
        return kept_prompt + kept_response, len(kept_prompt)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors plus the raw input/output text."""
        example = self.examples[idx]
        input_text, target_text_response = example['input'], example['output']
        target_text_response_str = str(target_text_response) if target_text_response is not None else ""
        tokenized_prompt = self.tokenizer(input_text, add_special_tokens=True, truncation=False, padding=False)
        # Separate prompt and response with one space unless the prompt already ends with one.
        response_prefix = " " if not input_text.endswith(" ") and target_text_response_str else ""
        # Some tokenizers define no EOS token; appending None would raise a TypeError.
        eos_text = self.tokenizer.eos_token or ""
        tokenized_response = self.tokenizer(response_prefix + target_text_response_str + eos_text, add_special_tokens=False, truncation=False, padding=False)
        combined_ids_list, effective_prompt_len = self._fit_to_window(
            list(tokenized_prompt.input_ids), list(tokenized_response.input_ids), self.max_seq_length
        )

        input_ids_tensor = torch.tensor(combined_ids_list, dtype=torch.long)
        labels = torch.full_like(input_ids_tensor, -100)
        if effective_prompt_len < len(input_ids_tensor):
            labels[effective_prompt_len:] = input_ids_tensor[effective_prompt_len:]
        attention_mask = torch.ones_like(input_ids_tensor)

        padding_length = self.max_seq_length - len(input_ids_tensor)
        if padding_length > 0:
            input_ids_tensor = F.pad(input_ids_tensor, (0, padding_length), value=self.tokenizer.pad_token_id)
            attention_mask = F.pad(attention_mask, (0, padding_length), value=0)
            labels = F.pad(labels, (0, padding_length), value=-100)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_mask, "labels": labels,
                "raw_input_text": input_text, "raw_output_text": target_text_response_str}

# --- Benchmark Task Loading Utility ---
def load_benchmark_task(
    dataset_name: str, tokenizer: 'PreTrainedTokenizer', text_field: Union[str, List[str]], label_field: str, num_train_samples: int, num_val_samples: int,
    task_id_prefix: str = "bench_", class_names_map: Optional[Dict[int, str]] = None, random_seed: int = 42,
    config_name: Optional[str] = None, max_seq_length: int = 512
) -> Tuple[List[Dict[str,str]], List[Dict[str,str]], Dict[int,str], Union[str, List[str]]]:
    """Load a Hugging Face dataset and format it as prompt/label-text examples.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        tokenizer: Accepted for interface compatibility; not used in this function.
        text_field: Column holding the input text, or a list of columns that are
            joined with ``" [SEP] "``.
        label_field: Column holding the label.
        num_train_samples: Maximum number of leading train rows to keep.
        num_val_samples: Maximum number of leading validation rows to keep.
        task_id_prefix: Prefix of the ``task_id`` stored on every example.
        class_names_map: Optional label-value -> class-name mapping; when omitted
            it is inferred from the dataset features, the observed labels, or a
            built-in table for a few GLUE configs.
        random_seed: Seed for carving validation data out of the train split.
        config_name: Optional dataset configuration name.
        max_seq_length: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(train_examples, val_examples, class_names_map, text_field)``. Each
        example has keys ``input``, ``output``, ``task_id`` and
        ``original_example_data``. On any loading/splitting failure the error is
        logged and ``([], [], {}, text_field)`` is returned.
    """
    logger.info(f"Loading benchmark task: {dataset_name} (config: {config_name or 'default'}) {num_train_samples} train, {num_val_samples} val.")

    def _empty_result():
        # Fresh containers on every failure path.
        return [], [], {}, text_field

    def _take_first(split, limit):
        # None / empty split -> None; otherwise the first `limit` rows (an empty selection if limit <= 0).
        if not split:
            return None
        count = min(limit, len(split))
        return split.select(range(count)) if count > 0 else split.select([])

    # ---- Load (plain attempt first, then a forced re-download) ----
    dataset = None
    error_messages = []
    current_default_config = globals().get('default_config')
    if current_default_config and hasattr(current_default_config, 'output_dir'):
        output_dir_for_cache = current_default_config.output_dir
    else:
        output_dir_for_cache = "./outputs_cache_fallback"
    hf_cache_dir = os.path.join(output_dir_for_cache, ".cache", "huggingface", "datasets")
    os.makedirs(hf_cache_dir, exist_ok=True)
    # NOTE(review): trust_remote_code=True allows the dataset repository to execute its own
    # loading script on this machine; only use it with dataset sources you trust.
    load_kwargs = {"cache_dir": hf_cache_dir, "trust_remote_code": True}
    if config_name:
        load_kwargs["name"] = config_name
    try:
        dataset = load_dataset(dataset_name, **load_kwargs)
    except Exception as e1:
        error_messages.append(f"Plain load failed: {e1}")
        logger.warning(f"Plain load {dataset_name} (config: {config_name}) failed: {e1}")
        dataset = None
    if dataset is None:
        try:
            dataset = load_dataset(dataset_name, download_mode=DownloadMode.FORCE_REDOWNLOAD, **load_kwargs)
        except Exception as e2:
            error_messages.append(f"Force redownload failed: {e2}")
            logger.error(f"All load attempts for {dataset_name} (config: {config_name}) failed: {error_messages}")
            return _empty_result()
    if dataset is None:
        logger.error(f"Dataset {dataset_name} (config: {config_name}) could not be loaded.")
        return _empty_result()

    # ---- Pick train / validation splits ----
    if "train" not in dataset:
        logger.error(f"'train' split not found. Available: {list(dataset.keys())}")
        return _empty_result()
    train_ds_full = dataset["train"]
    val_ds_full = None
    if "validation" in dataset:
        val_ds_full = dataset["validation"]
    elif "test" in dataset:
        val_ds_full = dataset["test"]
    elif len(dataset["train"]) > (num_train_samples + num_val_samples):
        # No held-out split: carve validation data out of train.
        current_train_len = len(dataset["train"])
        if current_train_len <= (num_train_samples + num_val_samples) * 2:
            val_samples_from_train = max(1, int(current_train_len * 0.1))
        else:
            val_samples_from_train = num_val_samples
        if current_train_len <= val_samples_from_train:
            logger.error(f"Cannot split train: Not enough samples ({current_train_len}) for val ({val_samples_from_train}).")
            return _empty_result()
        train_val_split = dataset["train"].train_test_split(test_size=val_samples_from_train, seed=random_seed, shuffle=True)
        train_ds_full, val_ds_full = train_val_split["train"], train_val_split["test"]
    else:
        logger.error(f"Cannot obtain validation data for {dataset_name}. Train size: {len(train_ds_full) if train_ds_full else 'N/A'}")
        return _empty_result()

    train_ds = _take_first(train_ds_full, num_train_samples)
    val_ds = _take_first(val_ds_full, num_val_samples)

    # ---- Resolve the label -> class-name mapping ----
    final_class_names_map = {}
    if class_names_map:
        final_class_names_map = class_names_map
    elif train_ds and label_field in train_ds.features and hasattr(train_ds.features[label_field], 'names') and train_ds.features[label_field].names:
        final_class_names_map = {i: name for i, name in enumerate(train_ds.features[label_field].names)}
    elif train_ds and label_field in train_ds.column_names and len(train_ds) > 0:
        unique_labels_original = sorted(list(set(train_ds[label_field])))
        if all(isinstance(label, int) for label in unique_labels_original):
            # Key by the label VALUE: examples are looked up by their raw integer label, so an
            # index-keyed map would hand non-contiguous labels (e.g. 1, 2, 5) the wrong string.
            final_class_names_map = {label: str(label) for label in unique_labels_original}
        else:
            # Non-integer labels are never looked up in the map; keep integer keys for callers.
            final_class_names_map = {i: str(label) for i, label in enumerate(unique_labels_original)}
    if not final_class_names_map:
        ds_name_lower, cfg_name_lower = dataset_name.lower(), (config_name.lower() if config_name else "")
        if ds_name_lower == "glue":
            glue_class_names = {
                "sst2": {0: "negative", 1: "positive"},
                "mrpc": {0: "not_equivalent", 1: "equivalent"},
                "qnli": {0: "entailment", 1: "not_entailment"},
                "rte": {0: "entailment", 1: "not_entailment"},
                "cola": {0: "unacceptable", 1: "acceptable"},
            }
            final_class_names_map = glue_class_names.get(cfg_name_lower, {})

    def _format_examples(ds_split, current_task_id_prefix_fmt, current_ds_name_fmt, current_cfg_name_fmt, current_class_map, original_text_field_name):
        """Turn dataset rows into dicts with input text, label text, task id and the raw row."""
        if ds_split is None: return []
        formatted_examples = []
        raw_task_identifier_for_example = f"{current_task_id_prefix_fmt}{current_ds_name_fmt}_{current_cfg_name_fmt or 'default'}"
        for ex in ds_split:
            if isinstance(original_text_field_name, list):
                input_val_combined = " [SEP] ".join(str(ex.get(tf, "")) for tf in original_text_field_name)
            else:
                input_val_combined = str(ex.get(original_text_field_name, ""))
            original_label_val = ex.get(label_field)
            if original_label_val is None: continue  # rows without a label are skipped
            # Integer labels found in the map become class names; everything else is stringified.
            if current_class_map and isinstance(original_label_val, int) and original_label_val in current_class_map:
                output_val_str = current_class_map[original_label_val]
            else:
                output_val_str = str(original_label_val)
            formatted_examples.append({"input": input_val_combined, "output": output_val_str, "task_id": raw_task_identifier_for_example, "original_example_data": ex})
        return formatted_examples

    train_examples = _format_examples(train_ds, task_id_prefix, dataset_name, config_name, final_class_names_map, text_field)
    val_examples = _format_examples(val_ds, task_id_prefix, dataset_name, config_name, final_class_names_map, text_field)
    logger.info(f"Loaded {len(train_examples)} train, {len(val_examples)} val for {dataset_name} (config: {config_name}). Classes: {final_class_names_map if final_class_names_map else 'N/A (raw labels used)'}")
    if not train_examples and num_train_samples > 0: logger.warning(f"No train examples formatted for {dataset_name}")
    if not val_examples and num_val_samples > 0: logger.warning(f"No val examples formatted for {dataset_name}")
    return train_examples, val_examples, final_class_names_map, text_field

# --- EWC Dataset Definition ---
class EWCDataset(TorchDataset):
    """Prompt/response samples tokenized for EWC passes over stored task data.

    Each item is ``prompt + response + EOS`` fitted into ``max_seq_length``
    tokens and right-padded to exactly that length. Prompt and padding
    positions carry the label -100, so only response tokens contribute to the loss.
    """

    def __init__(self, raw_samples: List[Dict[str, str]], tokenizer: 'PreTrainedTokenizer', max_seq_length: int):
        """Store the samples and make sure the tokenizer has a pad token id.

        Note: when ``tokenizer.pad_token_id`` is None it is set on the shared
        tokenizer object (to the EOS id, or 0 if there is none).
        """
        self.raw_samples = raw_samples
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 0

    def __len__(self):
        return len(self.raw_samples)

    @staticmethod
    def _fit_to_window(prompt_ids: List[int], response_ids: List[int], max_len: int) -> Tuple[List[int], int]:
        """Trim prompt and response so that together they fit in ``max_len`` tokens.

        The response is guaranteed at least half of the window (or the space the
        prompt leaves free, if that is more); the prompt keeps its leading tokens
        in whatever room remains. Without this, a prompt that fills the window
        leaves every label at -100 and the sample yields no usable loss.

        Returns:
            ``(combined_ids, prompt_len)`` where ``prompt_len`` is the number of
            leading prompt tokens in ``combined_ids``.
        """
        if len(prompt_ids) + len(response_ids) <= max_len:
            return prompt_ids + response_ids, len(prompt_ids)
        if max_len <= 0:
            return [], 0
        response_budget = max(1, max_len // 2, max_len - len(prompt_ids))
        kept_response = response_ids[:response_budget]
        kept_prompt = prompt_ids[:max_len - len(kept_response)]
        return kept_prompt + kept_response, len(kept_prompt)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors of length ``max_seq_length``."""
        raw_sample = self.raw_samples[idx]
        input_text, target_text_response = raw_sample['input'], raw_sample['output']
        target_text_response_str = str(target_text_response) if target_text_response is not None else ""
        tokenized_prompt = self.tokenizer(input_text, add_special_tokens=True, truncation=False, padding=False)
        # Separate prompt and response with one space unless the prompt already ends with one.
        response_prefix = " " if not input_text.endswith(" ") and target_text_response_str else ""
        # Some tokenizers define no EOS token; appending None would raise a TypeError.
        eos_text = self.tokenizer.eos_token or ""
        tokenized_response = self.tokenizer(response_prefix + target_text_response_str + eos_text, add_special_tokens=False, truncation=False, padding=False)
        combined_ids_list, effective_prompt_len = self._fit_to_window(
            list(tokenized_prompt.input_ids), list(tokenized_response.input_ids), self.max_seq_length
        )

        input_ids_tensor = torch.tensor(combined_ids_list, dtype=torch.long)
        labels = torch.full_like(input_ids_tensor, -100)
        if effective_prompt_len < len(input_ids_tensor):
            labels[effective_prompt_len:] = input_ids_tensor[effective_prompt_len:]
        attention_mask = torch.ones_like(input_ids_tensor)

        padding_length = self.max_seq_length - len(input_ids_tensor)
        if padding_length > 0:
            input_ids_tensor = F.pad(input_ids_tensor, (0, padding_length), value=self.tokenizer.pad_token_id)
            attention_mask = F.pad(attention_mask, (0, padding_length), value=0)
            labels = F.pad(labels, (0, padding_length), value=-100)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_mask, "labels": labels}

def main(run_config: Optional['AdaptiveLearnerConfig'] = None):
    config = run_config if run_config else globals().get('default_config')
    if config is None: logger.error("CRITICAL: No configuration provided to main() and global default_config not found."); return

    logger.info(f"--- Main function started. SFT Mode: {config.run_sft_baseline}. Exp Tag: {config.experiment_tag} ---\n"
                f"Router type: {config.router_type}, Confidence threshold: {config.router_confidence_threshold}\n"
                f"Load router from tag: {config.load_router_state_from_tag}\n"
                f"k_examples_for_prototype: {getattr(config, 'k_examples_for_prototype', 'Not Set')}, "
                f"Use Advanced Embeddings: {getattr(config, 'use_advanced_embeddings', 'Not Set')}\n"
                f"Use LoRA Early Stopping: {getattr(config, 'use_lora_early_stopping', 'Not Set')}")

    logger.info("Initializing components...")
    backbone = AdaptiveLearnerBackbone(config)
    peft_manager = PEFTManager(config, backbone.model)
    neuromod_manager = NeuromodulationManager(config)
    consolidation_manager = ConsolidationManager(config)
    replay_manager = GenerativeReplayManager(config, peft_manager.get_current_peft_model(), backbone.tokenizer)

    sft_originals = {}
    if config.run_sft_baseline:
        logger.info("--- CONFIGURING FOR SFT BASELINE RUN ---")
        sft_originals = {
            'replay_alpha': config.replay_alpha, 'feature_replay_alpha': config.feature_replay_alpha,
            'consolidation_method': config.consolidation_method, 'router_confidence_threshold': config.router_confidence_threshold,
            'replay_method': config.replay_method, 'experiment_tag': config.experiment_tag,
            'load_router_state_from_tag': config.load_router_state_from_tag,
            'k_examples_for_prototype': getattr(config, 'k_examples_for_prototype', 0),
            'use_advanced_embeddings': getattr(config, 'use_advanced_embeddings', False),
            'num_tasks_to_share_lora': getattr(config, 'num_tasks_to_share_lora', 0),
            'use_lora_early_stopping': getattr(config, 'use_lora_early_stopping', False)
        }
        config.replay_alpha, config.feature_replay_alpha, config.consolidation_method = 0.0, 0.0, "none"
        config.router_confidence_threshold = 1.1; config.replay_method = "none"
        config.experiment_tag = f"{sft_originals.get('experiment_tag', 'sft_baseline')}_SFT_RUN"
        config.load_router_state_from_tag = None; config.k_examples_for_prototype = 0
        config.use_advanced_embeddings = False; config.num_tasks_to_share_lora = 0
        config.use_lora_early_stopping = False
        consolidation_manager.config = config; replay_manager.config = config
        if replay_manager.replay_model is not None: logger.info("SFT: Disabling replay model."); replay_manager.replay_model = None
        if peft_manager.router is not None: peft_manager.router.config = config

    tasks_to_process = []
    if config.benchmark_tasks_to_run:
        for task_def in config.benchmark_tasks_to_run:
            train_ex, val_ex, class_map, original_task_text_field = load_benchmark_task(
                dataset_name=task_def["name"], tokenizer=backbone.tokenizer, text_field=task_def["text_field"],
                label_field=task_def["label_field"], num_train_samples=task_def.get("num_train_samples", config.num_train_samples_benchmark),
                num_val_samples=task_def.get("num_val_samples", config.num_val_samples_benchmark),
                task_id_prefix=task_def.get("id_prefix", "bench_"), class_names_map=task_def.get("class_names_map"),
                random_seed=42, config_name=task_def.get("config"), max_seq_length=config.max_seq_length)
            if train_ex:
                raw_task_id = f"{task_def.get('id_prefix', 'bench_')}{task_def['name']}_{task_def.get('config', 'default')}"
                tasks_to_process.append({"name": raw_task_id, "description": f"Benchmark: {task_def['name']} ({task_def.get('config', 'default')}) Task Description: {task_def.get('description_for_router', task_def['name'])}",
                    "train_examples": train_ex, "val_examples": val_ex, "class_names_map_from_load": class_map, "original_task_text_field": original_task_text_field })
            else: logger.error(f"Failed to load/format benchmark task: {task_def['name']} (config: {task_def.get('config')}). Skipping.")
    if not tasks_to_process:
        logger.error("CRITICAL: No tasks to process. Exiting.")
        if config.use_wandb and 'wandb' in globals() and wandb.run: wandb.finish()
        return

    all_task_val_sets: Dict[str, List[Dict[str,str]]] = {}; accuracy_matrix: Dict[Tuple[int, int], float] = {}
    task_id_to_matrix_idx: Dict[str, int] = {}; processed_task_info_for_eval = []
    cumulative_global_steps = 0; num_epochs_lora_train = config.num_lora_train_epochs

    num_tasks_to_share_lora = getattr(config, 'num_tasks_to_share_lora', 0)
    shared_lora_adapter_id_actual = ""

    for task_index, task_content_from_list in enumerate(tasks_to_process):
        raw_current_task_id = task_content_from_list["name"]; train_examples_raw = task_content_from_list["train_examples"]
        val_examples_raw = task_content_from_list["val_examples"]; task_description_for_router = task_content_from_list["description"]
        current_task_class_names_map = task_content_from_list.get("class_names_map_from_load")
        current_task_original_text_field = task_content_from_list["original_task_text_field"]
        if not current_task_class_names_map and (train_examples_raw or val_examples_raw):
            unique_outputs = sorted(list(set(ex['output'] for ex in train_examples_raw + val_examples_raw if 'output' in ex)));
            current_task_class_names_map = {i: name for i, name in enumerate(unique_outputs)}
        task_id_to_matrix_idx[raw_current_task_id] = task_index; all_task_val_sets[raw_current_task_id] = val_examples_raw
        
        logger.info(f"\\n--- Starting Task {task_index+1}/{len(tasks_to_process)}: {raw_current_task_id} (Exp Tag: {config.experiment_tag}) ---\\n"
                    f"  Task Description for Router: '{task_description_for_router[:100]}...'")

        if not train_examples_raw: logger.warning(f"Task {raw_current_task_id} has no training examples. Skipping training phase.");
        else: logger.info(f"  Train: {len(train_examples_raw)} examples, Val: {len(val_examples_raw)} examples, Classes: {current_task_class_names_map}, Max LoRA Epochs: {num_epochs_lora_train}")

        active_lora_to_train_tagged = ""; current_peft_model = peft_manager.get_current_peft_model()
        examples_for_prototype_decision: Optional[List[Dict[str,Any]]] = None
        if hasattr(config, 'k_examples_for_prototype') and config.k_examples_for_prototype > 0 and train_examples_raw:
            examples_for_prototype_decision = [ex['original_example_data'] for ex in train_examples_raw[:config.k_examples_for_prototype] if 'original_example_data' in ex and isinstance(ex['original_example_data'], dict)]
            if not examples_for_prototype_decision : logger.warning(f"No valid 'original_example_data' found for prototype for task {raw_current_task_id}")

        is_shared_lora_task = (num_tasks_to_share_lora > 0 and task_index < num_tasks_to_share_lora)

        if config.run_sft_baseline:
            active_lora_to_train_tagged = config.sft_lora_id
            if task_index == 0 or not (isinstance(peft_manager.model, PeftModel) and active_lora_to_train_tagged in peft_manager.model.peft_config):
                current_peft_model = peft_manager.create_lora_module(active_lora_to_train_tagged)
            else: peft_manager.activate_lora_modules(active_lora_to_train_tagged); current_peft_model = peft_manager.get_current_peft_model()
        elif is_shared_lora_task:
            if not shared_lora_adapter_id_actual:
                shared_lora_base_name = "SharedAdapter"
                task_name_parts = []
                for i in range(num_tasks_to_share_lora):
                    if i < len(tasks_to_process):
                        task_id_part = tasks_to_process[i]["name"]
                        meaningful_part = task_id_part.split('_')[-1] if '_' in task_id_part else task_id_part
                        if "glue" in meaningful_part and len(task_id_part.split('_')) > 1:
                            meaningful_part = task_id_part.split('_')[-1] if task_id_part.split('_')[-1] != "glue" else task_id_part.split('_')[-2]
                        task_name_parts.append(meaningful_part[:4].replace(".", "_"))
                shared_lora_base_name += "_" + "_".join(task_name_parts)
                clean_experiment_tag = config.experiment_tag.replace(".", "_") if config.experiment_tag else "default_exp"
                shared_lora_adapter_id_actual = f"{clean_experiment_tag}_{shared_lora_base_name}".replace(".", "_")
                logger.info(f"SharedLoRA: Task {task_index} ('{raw_current_task_id}') is the first shared task. Creating shared LoRA: '{shared_lora_adapter_id_actual}'")
                current_peft_model = peft_manager.create_lora_module(shared_lora_adapter_id_actual)
                active_lora_to_train_tagged = shared_lora_adapter_id_actual
            else:
                logger.info(f"SharedLoRA: Task {task_index} ('{raw_current_task_id}') reuses shared LoRA: '{shared_lora_adapter_id_actual}'")
                if not (isinstance(peft_manager.model, PeftModel) and shared_lora_adapter_id_actual in peft_manager.model.peft_config):
                    logger.error(f"SharedLoRA: Attempted to reuse '{shared_lora_adapter_id_actual}' but not found in model config. Creating it now.")
                    current_peft_model = peft_manager.create_lora_module(shared_lora_adapter_id_actual)
                else: peft_manager.activate_lora_modules(shared_lora_adapter_id_actual)
                current_peft_model = peft_manager.get_current_peft_model(); active_lora_to_train_tagged = shared_lora_adapter_id_actual
        else:
            potential_loras = peft_manager.select_lora_for_input(task_description=task_description_for_router, task_examples_for_prototype=examples_for_prototype_decision, task_text_field=current_task_original_text_field, top_k=1)
            selected_existing = False
            if potential_loras:
                router_sel_id_tagged, conf = potential_loras[0]
                if isinstance(peft_manager.model, PeftModel) and router_sel_id_tagged in peft_manager.model.peft_config and conf > config.router_confidence_threshold :
                    active_lora_to_train_tagged = router_sel_id_tagged; peft_manager.activate_lora_modules(active_lora_to_train_tagged); current_peft_model = peft_manager.get_current_peft_model(); selected_existing = True; logger.info(f"Router REUSED existing LoRA '{active_lora_to_train_tagged}' (Conf: {conf:.3f}) for raw task '{raw_current_task_id}'.")
                elif conf > config.router_confidence_threshold:
                    loaded_model = peft_manager.load_lora_module(router_sel_id_tagged, set_as_active=True)
                    if loaded_model: active_lora_to_train_tagged = router_sel_id_tagged; current_peft_model = loaded_model; selected_existing = True; logger.info(f"Router REUSED (loaded) LoRA '{active_lora_to_train_tagged}' (Conf: {conf:.3f}) for raw task '{raw_current_task_id}'.")
                    else: logger.warning(f"Router suggested reusing '{router_sel_id_tagged}' but it could not be loaded. Creating new."); selected_existing = False
            if not selected_existing:
                clean_experiment_tag = config.experiment_tag.replace(".", "_") if config.experiment_tag else "default_exp"
                clean_raw_current_task_id = raw_current_task_id.replace(".", "_")
                tag_prefix = f"{clean_experiment_tag}_"
                active_lora_to_train_tagged = f"{tag_prefix}{clean_raw_current_task_id}".replace(".", "_")
                current_peft_model = peft_manager.create_lora_module(active_lora_to_train_tagged)
                logger.info(f"Router/NewLogic: Using NEW LoRA '{active_lora_to_train_tagged}' for raw task '{raw_current_task_id}'. Router top choice (if any): {potential_loras}")

        processed_task_info_for_eval.append({"id": raw_current_task_id, "class_names_map": current_task_class_names_map,
            "num_train_examples": len(train_examples_raw), "lora_trained_on_this_task": active_lora_to_train_tagged,
            "task_description_for_router": task_description_for_router, "original_task_text_field": current_task_original_text_field })
        replay_manager.backbone_model = current_peft_model
        trainable_params = [p for p in current_peft_model.parameters() if p.requires_grad]; optimizer = None
        if trainable_params and train_examples_raw: optimizer = torch.optim.AdamW(trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay)

        if train_examples_raw and optimizer:
            current_task_train_dataset = CausalLMTrainingDataset(train_examples_raw, backbone.tokenizer, config.max_seq_length)
            current_num_workers = getattr(config, 'num_workers', 0)
            current_task_train_dataloader = TorchDataLoader(current_task_train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=current_num_workers, pin_memory=(config.device=="cuda"))
            task_samples_for_replay_model_training_this_task = []

            best_val_metric_for_early_stop = -float('inf') if config.lora_early_stopping_metric == "accuracy" else float('inf')
            early_stop_patience_counter = 0
            stopped_early_this_task = False
            actual_epochs_run_for_task = 0

            for epoch_num in range(num_epochs_lora_train):
                actual_epochs_run_for_task += 1
                logger.info(f"  Task {raw_current_task_id} - LoRA Training Epoch {epoch_num+1}/{num_epochs_lora_train} for LoRA: {active_lora_to_train_tagged}")
                epoch_losses, epoch_accuracies, epoch_grad_norms, epoch_entropies = [], [], [], []
                batch_iterator_desc = f"Epoch {epoch_num+1} LoRA:{active_lora_to_train_tagged[:20]} Task:{raw_current_task_id[:15]}"
                batch_iterator = tqdm(current_task_train_dataloader, desc=batch_iterator_desc, leave=False) if len(current_task_train_dataloader) > 1 else current_task_train_dataloader
                for batch_idx, batch_data in enumerate(batch_iterator):
                    global_step_for_wandb = cumulative_global_steps + ((actual_epochs_run_for_task -1) * len(current_task_train_dataloader)) + batch_idx
                    current_batch_raw_samples_list = [{"input": batch_data["raw_input_text"][i], "output": batch_data["raw_output_text"][i]} for i in range(len(batch_data["input_ids"]))]
                    if epoch_num == 0 : task_samples_for_replay_model_training_this_task.extend(current_batch_raw_samples_list)
                    loss_val, acc_metric_batch_avg, grad_n, entr_val = 5.0, 0.0, 0.0, 3.0
                    sup_rep_n, sup_rep_loss_val = 0,0.0; feat_rep_n, feat_rep_loss_val = 0,0.0
                    current_peft_model.train(); optimizer.zero_grad()
                    ids_curr, attn_curr, lbls_curr = batch_data["input_ids"].to(config.device), batch_data["attention_mask"].to(config.device), batch_data["labels"].to(config.device)
                    outputs = current_peft_model(input_ids=ids_curr, attention_mask=attn_curr, labels=lbls_curr)
                    loss_curr = outputs.loss
                    loss_curr = torch.tensor(5.0, device=config.device, dtype=torch.float32) if loss_curr is None or not torch.isfinite(loss_curr) else loss_curr
                    loss_val = loss_curr.item()
                    tot_sup_loss, tot_feat_loss = torch.tensor(0.0, device=config.device, dtype=loss_curr.dtype), torch.tensor(0.0, device=config.device, dtype=loss_curr.dtype)
                    if not config.run_sft_baseline and task_index > 0 and (config.replay_alpha > 0 or config.feature_replay_alpha > 0) and replay_manager.replay_model :
                        avail_replay_tasks = replay_manager.get_tasks_with_replay_data(exclude_task_id=raw_current_task_id)
                        if avail_replay_tasks:
                            replay_task_id_raw = np.random.choice(avail_replay_tasks); num_replay_samples_this_step = max(1, ids_curr.size(0) // 2)
                            if config.replay_alpha > 0 and np.random.rand() < 0.7:
                                raw_re_ex_list = replay_manager.get_raw_samples_for_replay(replay_task_id_raw, num_samples=num_replay_samples_this_step)
                                if raw_re_ex_list:
                                    replay_dataset_temp = CausalLMTrainingDataset(raw_re_ex_list, backbone.tokenizer, config.max_seq_length)
                                    current_replay_batch_size = min(len(raw_re_ex_list), config.batch_size)
                                    replay_dl_temp = TorchDataLoader(replay_dataset_temp, batch_size=current_replay_batch_size, num_workers=config.num_workers)
                                    for replay_batch_data in replay_dl_temp:
                                        out_re_sup = current_peft_model(input_ids=replay_batch_data["input_ids"].to(config.device), attention_mask=replay_batch_data["attention_mask"].to(config.device), labels=replay_batch_data["labels"].to(config.device))
                                        if out_re_sup.loss is not None and torch.isfinite(out_re_sup.loss): tot_sup_loss += out_re_sup.loss * len(raw_re_ex_list); sup_rep_n += len(raw_re_ex_list)
                                        break
                            elif config.feature_replay_alpha > 0 and hasattr(replay_manager.replay_model, 'generate_samples'):
                                gen_units = replay_manager.generate_replay_units(replay_task_id_raw, num_units=num_replay_samples_this_step)
                                if gen_units:
                                    h_past_batch = torch.stack([unit['hidden_state'] for unit in gen_units]).to(config.device).to(current_peft_model.dtype if hasattr(current_peft_model, 'dtype') else torch.float16)
                                    if hasattr(replay_manager.replay_model, '_tokenize_batch_for_encode') and hasattr(replay_manager.replay_model, 'encode_input_batch'):
                                        current_texts_for_feat_replay = [batch_data["raw_input_text"][i] for i in range(min(num_replay_samples_this_step, len(batch_data["raw_input_text"])))]
                                        if current_texts_for_feat_replay:
                                            tok_curr_feat_batch = replay_manager.replay_model._tokenize_batch_for_encode(current_texts_for_feat_replay)
                                            h_curr_batch = replay_manager.replay_model.encode_input_batch(tok_curr_feat_batch)
                                            h_curr_batch = h_curr_batch.to(config.device).to(h_past_batch.dtype)
                                            if h_curr_batch.shape[0] == h_past_batch.shape[0] and h_curr_batch.shape[1] == h_past_batch.shape[1] :
                                                loss_frs_mse = F.mse_loss(h_curr_batch.float(), h_past_batch.float())
                                                if torch.isfinite(loss_frs_mse): tot_feat_loss += loss_frs_mse * h_past_batch.size(0); feat_rep_n += h_past_batch.size(0)
                    comb_loss = loss_curr
                    if sup_rep_n > 0 and config.replay_alpha > 0 and torch.isfinite(tot_sup_loss): avg_s_loss = tot_sup_loss / sup_rep_n; sup_rep_loss_val = avg_s_loss.item(); comb_loss += config.replay_alpha * avg_s_loss
                    if feat_rep_n > 0 and config.feature_replay_alpha > 0 and torch.isfinite(tot_feat_loss): avg_f_loss = tot_feat_loss / feat_rep_n; feat_rep_loss_val = avg_f_loss.item(); comb_loss += config.feature_replay_alpha * avg_f_loss
                    if torch.isfinite(comb_loss): comb_loss.backward(); grad_n = sum(p.grad.detach().data.norm(2).item() ** 2 for p in trainable_params if p.grad is not None and torch.is_tensor(p.grad)) ** 0.5; optimizer.step()
                    else: grad_n = 0.0; logger.warning(f"Skipping optimizer step due to invalid combined_loss for task {raw_current_task_id}, batch {batch_idx}. Loss: {comb_loss.item() if torch.is_tensor(comb_loss) else comb_loss}")
                    current_peft_model.eval(); batch_correct_preds = 0
                    with torch.no_grad():
                        for i in range(ids_curr.size(0)):
                            prompt_ids_gen = backbone.tokenizer(batch_data["raw_input_text"][i], return_tensors="pt", add_special_tokens=True, truncation=True, max_length=config.max_seq_length - 50).input_ids.to(config.device)
                            task_cls_names = current_task_class_names_map; max_cls_len = 10
                            if task_cls_names and isinstance(task_cls_names, dict) and task_cls_names.values(): max_cls_len = max(len(backbone.tokenizer.encode(name, add_special_tokens=False)) for name in task_cls_names.values()) if task_cls_names else 10
                            max_g_toks = max(1, min(max_cls_len + 5, config.max_seq_length - prompt_ids_gen.shape[1] -1))
                            gen_ids = current_peft_model.generate(input_ids=prompt_ids_gen,attention_mask=torch.ones_like(prompt_ids_gen),max_new_tokens=max_g_toks,pad_token_id=backbone.tokenizer.pad_token_id,eos_token_id=backbone.tokenizer.eos_token_id,do_sample=False, temperature=0.7, top_p=0.9)
                            gen_txt_cls = backbone.tokenizer.decode(gen_ids[0][prompt_ids_gen.shape[1]:], skip_special_tokens=True).strip()
                            if gen_txt_cls.lower() == str(batch_data["raw_output_text"][i]).lower(): batch_correct_preds += 1
                    acc_metric_batch_avg = batch_correct_preds / ids_curr.size(0) if ids_curr.size(0) > 0 else 0.0
                    epoch_losses.append(loss_val); epoch_accuracies.append(acc_metric_batch_avg); epoch_grad_norms.append(grad_n);
                    if config.use_wandb and 'wandb' in globals() and wandb.run:
                        wandb.log({
                            f"step_metrics/{raw_current_task_id.replace('/', '_')}/loss": loss_val,
                            "accuracy": acc_metric_batch_avg,
                            "grad_norm": grad_n,
                            f"step_replay/{raw_current_task_id.replace('/', '_')}/sup_n": sup_rep_n,
                            "sup_loss": sup_rep_loss_val,
                            "feat_n": feat_rep_n,
                            "feat_loss": feat_rep_loss_val
                        }, step=int(global_step_for_wandb))


                # ---> START EARLY STOPPING LOGIC <---
                if config.use_lora_early_stopping and actual_epochs_run_for_task >= config.min_lora_epochs_before_early_stop:
                    logger.info(f"    ES Check: Epoch {actual_epochs_run_for_task}, evaluating on val set for {raw_current_task_id} with LoRA {active_lora_to_train_tagged}")
                    current_peft_model.eval()

                    es_corr, es_tot = 0, 0
                    if val_examples_raw:
                        for es_val_ex_item in val_examples_raw:
                            es_raw_input_text_val = es_val_ex_item['input']
                            es_raw_target_text_val = es_val_ex_item['output']
                            with torch.no_grad():
                                es_val_p_ids = backbone.tokenizer(es_raw_input_text_val, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=config.max_seq_length - 50).input_ids.to(config.device)
                                es_val_max_cls_len = 10
                                if current_task_class_names_map and isinstance(current_task_class_names_map, dict) and current_task_class_names_map.values():
                                    es_val_max_cls_len = max(len(backbone.tokenizer.encode(name, add_special_tokens=False)) for name in current_task_class_names_map.values()) if current_task_class_names_map else 10
                                es_val_max_toks = max(1, min(es_val_max_cls_len + 5, config.max_seq_length - es_val_p_ids.shape[1] -1))
                                es_val_gen_ids = current_peft_model.generate(
                                    input_ids=es_val_p_ids, attention_mask=torch.ones_like(es_val_p_ids),
                                    max_new_tokens=es_val_max_toks, pad_token_id=backbone.tokenizer.pad_token_id,
                                    eos_token_id=backbone.tokenizer.eos_token_id, do_sample=False, temperature=0.7, top_p=0.9
                                )
                                es_val_gen_txt = backbone.tokenizer.decode(es_val_gen_ids[0][es_val_p_ids.shape[1]:], skip_special_tokens=True).strip()
                                if es_val_gen_txt.lower() == str(es_raw_target_text_val).lower():
                                    es_corr += 1
                            es_tot += 1
                        current_val_metric = (es_corr / es_tot) if es_tot > 0 else 0.0
                        logger.info(f"    ES Check: Val Acc for {raw_current_task_id} on LoRA {active_lora_to_train_tagged}: {current_val_metric:.4f} (Best: {best_val_metric_for_early_stop:.4f})")
                        if config.use_wandb and 'wandb' in globals() and wandb.run:
                            wandb.log({f"early_stopping_val_metrics/{raw_current_task_id.replace('/', '_')}/val_accuracy_epoch_{actual_epochs_run_for_task}": current_val_metric}, step=int(cumulative_global_steps + actual_epochs_run_for_task * len(current_task_train_dataloader) -1))

                        if config.lora_early_stopping_metric == "accuracy":
                            if current_val_metric > best_val_metric_for_early_stop + config.lora_early_stopping_delta:
                                best_val_metric_for_early_stop = current_val_metric
                                early_stop_patience_counter = 0
                                logger.info(f"    ES Check: New best val_acc: {best_val_metric_for_early_stop:.4f}. Patience reset.")
                            else:
                                early_stop_patience_counter += 1
                                logger.info(f"    ES Check: No improvement. Patience: {early_stop_patience_counter}/{config.lora_early_stopping_patience}")

                        if early_stop_patience_counter >= config.lora_early_stopping_patience:
                            logger.info(f"  EARLY STOPPING LoRA training for {raw_current_task_id} at epoch {actual_epochs_run_for_task} due to no improvement on validation set.")
                            stopped_early_this_task = True
                    else:
                        logger.warning(f"    ES Check: No validation examples for {raw_current_task_id}. Cannot perform early stopping.")
                # ---> END EARLY STOPPING LOGIC <---

                if task_samples_for_replay_model_training_this_task and (config.replay_alpha > 0 or config.feature_replay_alpha > 0) and replay_manager.replay_model is not None and not config.run_sft_baseline :
                    if epoch_num == 0: replay_manager.add_task_samples(raw_current_task_id, task_samples_for_replay_model_training_this_task, train_replay_model_now=False)
                    logger.info(f"Epoch {actual_epochs_run_for_task} ended. Training replay model for task '{raw_current_task_id}' using {len(replay_manager.replay_buffer.get(raw_current_task_id, []))} samples.")
                    replay_manager.train_replay_model(raw_current_task_id, replay_manager.replay_buffer.get(raw_current_task_id, []))

                _e_val = 3.0
                if 'batch_data' in locals() and batch_data:
                    e_ids, e_mask = batch_data["input_ids"].to(config.device), batch_data["attention_mask"].to(config.device)
                    lbls_curr_entropy = batch_data["labels"].to(config.device); non_pad_lbls_first = lbls_curr_entropy[0][lbls_curr_entropy[0] != -100]
                    e_prompt_len = e_ids.shape[1]
                    if non_pad_lbls_first.numel() > 0: lbl_starts_first = (lbls_curr_entropy[0] != -100).nonzero(as_tuple=True)[0]; e_prompt_len = lbl_starts_first[0].item() if lbl_starts_first.numel() > 0 else e_ids.shape[1]
                    with torch.no_grad(): current_peft_model.eval(); out_e = current_peft_model(input_ids=e_ids, attention_mask=e_mask); logits_e = out_e.logits
                    if not (torch.isnan(logits_e).any() or torch.isinf(logits_e).any()):
                        logits_prob = logits_e;
                        if e_prompt_len > 0 and e_prompt_len < logits_e.shape[1]: logits_prob = logits_e[:, e_prompt_len:, :]
                        if logits_prob.numel() > 0 and logits_prob.shape[0]>0 and logits_prob.shape[1] > 0 :
                            probs_e = F.softmax(logits_prob.float(), dim=-1);
                            if not torch.isnan(probs_e).any(): log_probs_e = F.log_softmax(logits_prob.float(), dim=-1); tok_e = -(probs_e * log_probs_e).sum(dim=-1)
                            if not (torch.isnan(tok_e).any() or torch.isinf(tok_e).any()) and tok_e.numel() > 0: _e_val = tok_e.mean().item()
                entr_val = _e_val;
                if np.isnan(entr_val) or np.isinf(entr_val): entr_val = 3.0
                epoch_entropies.append(entr_val); avg_epoch_loss = np.mean(epoch_losses) if epoch_losses else 0.0
                avg_epoch_acc = np.mean(epoch_accuracies) if epoch_accuracies else 0.0; avg_epoch_grad = np.mean(epoch_grad_norms) if epoch_grad_norms else 0.0
                avg_epoch_entropy = np.mean(epoch_entropies) if epoch_entropies else 0.0
                logger.info(f"  LoRA Epoch {actual_epochs_run_for_task} Avg Metrics - Loss: {avg_epoch_loss:.4f}, Acc: {avg_epoch_acc:.4f}, GradN: {avg_epoch_grad:.4f}, Entropy (last batch): {avg_epoch_entropy:.4f}")
                epoch_avg_metrics_nm = {"accuracy": avg_epoch_acc, "nll": avg_epoch_loss, "gradient_norm": avg_epoch_grad, "entropy": avg_epoch_entropy}
                gamma_nm_epoch = neuromod_manager.compute_gamma_gain(raw_current_task_id, epoch_avg_metrics_nm)
                if not config.run_sft_baseline and config.consolidation_method == "aflora": consolidation_manager.consolidate(active_lora_to_train_tagged, current_peft_model, gamma_nm_epoch, data_loader=None)
                if config.use_wandb and 'wandb' in globals() and wandb.run:
                    wandb.log({
                        f"epoch_metrics/{raw_current_task_id.replace('/', '_')}/avg_loss": avg_epoch_loss,
                        f"epoch_metrics/{raw_current_task_id.replace('/', '_')}/avg_accuracy": avg_epoch_acc,
                        f"epoch_metrics/{raw_current_task_id.replace('/', '_')}/avg_grad_norm": avg_epoch_grad,
                        f"epoch_metrics/{raw_current_task_id.replace('/', '_')}/avg_entropy": avg_epoch_entropy,
                        f"epoch_neuromod/{raw_current_task_id.replace('/', '_')}/gamma_gain": gamma_nm_epoch
                    }, step=int(cumulative_global_steps + actual_epochs_run_for_task * len(current_task_train_dataloader) -1))

                if stopped_early_this_task:
                    break

            if train_examples_raw and optimizer:
                cumulative_global_steps += actual_epochs_run_for_task * len(current_task_train_dataloader)


        logger.info(f"--- Evaluating after Task {task_index+1}: {raw_current_task_id} (after {actual_epochs_run_for_task} LoRA epochs) ---")
        current_peft_model.eval()
        for j_eval in range(task_index + 1):
            eval_task_info = processed_task_info_for_eval[j_eval]; eval_task_id_raw = eval_task_info["id"]
            eval_val_ex = all_task_val_sets[eval_task_id_raw]; eval_cls_map = eval_task_info["class_names_map"]
            eval_lora_to_activate_tagged = eval_task_info["lora_trained_on_this_task"]; can_eval = False
            eval_active_model = None
            if isinstance(peft_manager.model, PeftModel):
                if eval_lora_to_activate_tagged in peft_manager.model.peft_config: peft_manager.activate_lora_modules(eval_lora_to_activate_tagged); eval_active_model = peft_manager.get_current_peft_model(); can_eval = True
                else:
                    logger.info(f"Attempting to load LoRA '{eval_lora_to_activate_tagged}' for evaluation of task {eval_task_id_raw}.")
                    lm_eval = peft_manager.load_lora_module(eval_lora_to_activate_tagged, set_as_active=True)
                    if lm_eval and eval_lora_to_activate_tagged in peft_manager.model.peft_config: eval_active_model = lm_eval; can_eval = True
            if not can_eval or eval_active_model is None: logger.warning(f"Cannot evaluate LoRA {eval_lora_to_activate_tagged} for task {eval_task_id_raw}. Not found or not loadable."); accuracy_matrix[(task_index, j_eval)] = 0.0; continue
            eval_active_model.eval(); corr, tot = 0,0
            if eval_val_ex:
                for val_ex_idx, val_ex_item in enumerate(eval_val_ex):
                    raw_input_text_val = val_ex_item['input']; raw_target_text_val = val_ex_item['output']
                    with torch.no_grad():
                        val_p_ids = backbone.tokenizer(raw_input_text_val, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=config.max_seq_length - 50).input_ids.to(config.device)
                        val_max_cls_len = 10
                        if eval_cls_map and isinstance(eval_cls_map, dict) and eval_cls_map.values(): val_max_cls_len = max(len(backbone.tokenizer.encode(name, add_special_tokens=False)) for name in eval_cls_map.values()) if eval_cls_map else 10
                        val_max_toks = max(1, min(val_max_cls_len + 5, config.max_seq_length - val_p_ids.shape[1] -1))
                        val_gen_ids = eval_active_model.generate(input_ids=val_p_ids,attention_mask=torch.ones_like(val_p_ids),max_new_tokens=val_max_toks,pad_token_id=backbone.tokenizer.pad_token_id,eos_token_id=backbone.tokenizer.eos_token_id,do_sample=False, temperature=0.7, top_p=0.9)
                        val_gen_txt = backbone.tokenizer.decode(val_gen_ids[0][val_p_ids.shape[1]:], skip_special_tokens=True).strip()
                        if val_gen_txt.lower() == str(raw_target_text_val).lower(): corr += 1
                    tot += 1
                task_acc_val = corr / tot if tot > 0 else 0.0; accuracy_matrix[(task_index, j_eval)] = task_acc_val
                logger.info(f"    Acc on Task_{j_eval} ({eval_task_id_raw}) (LoRA '{eval_lora_to_activate_tagged}') after Task_{task_index} ({raw_current_task_id}): {task_acc_val:.4f}")
                current_eval_step = cumulative_global_steps -1 if (train_examples_raw and optimizer) else cumulative_global_steps
                if config.use_wandb and 'wandb' in globals() and wandb.run:
                    wandb.log({f"eval_acc/acc_task{j_eval}_after_task{task_index}_{eval_task_id_raw.replace('/', '_')}_{eval_lora_to_activate_tagged.replace('/', '_')}": task_acc_val}, step=int(current_eval_step))
            else: accuracy_matrix[(task_index, j_eval)] = 0.0
        current_eval_step_cl = cumulative_global_steps -1 if (train_examples_raw and optimizer) else cumulative_global_steps
        if accuracy_matrix and task_index >=0 :
            avg_acc_k = sum(accuracy_matrix.get((task_index, j),0.0) for j in range(task_index + 1)) / (task_index + 1)
            logger.info(f"  AvgAcc after task {raw_current_task_id} (Task {task_index}): {avg_acc_k:.4f}")
            if config.use_wandb and 'wandb' in globals() and wandb.run:
                wandb.log({f"cl_metrics/avg_acc_upto_task_{task_index}_{raw_current_task_id.replace('/', '_')}": avg_acc_k}, step=int(current_eval_step_cl))
            if task_index > 0:
                bwt_s, bwt_n = 0.0, 0
                for j_bwt in range(task_index): acc_kj, acc_jj = accuracy_matrix.get((task_index, j_bwt), None), accuracy_matrix.get((j_bwt, j_bwt), None);
                if acc_kj is not None and acc_jj is not None: bwt_s += (acc_kj - acc_jj); bwt_n +=1
                bwt_curr = bwt_s / bwt_n if bwt_n > 0 else 0.0; logger.info(f"  BWT after task {raw_current_task_id} (Task {task_index}): {bwt_curr:.4f}")
                if config.use_wandb and 'wandb' in globals() and wandb.run:
                     wandb.log({f"cl_metrics/bwt_upto_task_{task_index}_{raw_current_task_id.replace('/', '_')}": bwt_curr}, step=int(current_eval_step_cl))

        if not config.run_sft_baseline and config.consolidation_method == "ewc" and train_examples_raw and optimizer:
            ewc_dl = None
            if config.replay_method == "none" and train_examples_raw and config.ewc_data_loader_num_samples > 0:
                logger.info(f"EWC Main: Replay is off. Using {config.ewc_data_loader_num_samples} samples from current task '{raw_current_task_id}' for EWC Fisher calculation.")
                num_samples_for_ewc_fisher = min(len(train_examples_raw), config.ewc_data_loader_num_samples)
                if num_samples_for_ewc_fisher > 0:
                    ewc_fisher_samples_indices = np.random.choice(len(train_examples_raw), num_samples_for_ewc_fisher, replace=False)
                    ewc_fisher_samples = [train_examples_raw[i] for i in ewc_fisher_samples_indices]
                    ewc_ds = EWCDataset(ewc_fisher_samples, backbone.tokenizer, config.max_seq_length)
                    ewc_dl = TorchDataLoader(ewc_ds, batch_size=config.ewc_batch_size, shuffle=True, num_workers=config.num_workers)
                else: logger.warning(f"EWC Main: Not enough samples in train_examples_raw for task '{raw_current_task_id}' to create EWC DataLoader.")
            elif raw_current_task_id in replay_manager.replay_buffer and replay_manager.replay_buffer[raw_current_task_id] and replay_manager.replay_model:
                raw_ewc = replay_manager.get_raw_samples_for_replay(raw_current_task_id, config.ewc_data_loader_num_samples )
                if raw_ewc: ewc_ds = EWCDataset(raw_ewc, backbone.tokenizer, config.max_seq_length); ewc_dl = TorchDataLoader(ewc_ds, batch_size=config.ewc_batch_size, shuffle=True, num_workers=config.num_workers)
            if ewc_dl is not None or config.ewc_data_loader_num_samples == 0 :
                final_gamma_for_task_consolidation = neuromod_manager.gamma_gain_history.get(raw_current_task_id, [0.5])[-1]
                if not (isinstance(peft_manager.model, PeftModel) and active_lora_to_train_tagged in peft_manager.model.peft_config):
                    logger.error(f"EWC Consolidation: LoRA '{active_lora_to_train_tagged}' not in model config. Cannot consolidate.")
                else:
                    peft_manager.activate_lora_modules(active_lora_to_train_tagged)
                    current_peft_model_for_ewc = peft_manager.get_current_peft_model()
                    consolidation_manager.consolidate(active_lora_to_train_tagged, current_peft_model_for_ewc, final_gamma_for_task_consolidation, data_loader=ewc_dl)
            else: logger.warning(f"EWC Consolidation for {active_lora_to_train_tagged} skipped as ewc_dl could not be prepared and ewc_data_loader_num_samples > 0.")

        if not config.run_sft_baseline and task_index >= 0 and config.router_training_epochs_per_call > 0:
            router_training_feed_data = []
            for i_task_processed in range(task_index + 1):
                task_info_for_router = processed_task_info_for_eval[i_task_processed]
                lora_id_for_this_task_in_current_run = task_info_for_router["lora_trained_on_this_task"]
                desc_for_router_train = task_info_for_router["task_description_for_router"]
                router_training_feed_data.append(((desc_for_router_train, None, None), lora_id_for_this_task_in_current_run ))
            if router_training_feed_data:
                logger.info(f"Preparing to train router with {len(router_training_feed_data)} (desc_only) -> LoRA ID pairs.")
                peft_manager.train_router(router_training_feed_data)
            else: logger.info("No data prepared for router training after this task.")

        current_plot_step = cumulative_global_steps -1 if (train_examples_raw and optimizer) else cumulative_global_steps
        if train_examples_raw and optimizer:
            neuromod_manager.save_metrics(raw_current_task_id)
            plot_filename = f"{config.experiment_tag}_{raw_current_task_id}_gamma_plot.png".replace(".", "_") if config.experiment_tag else f"{raw_current_task_id}_gamma_plot.png".replace(".", "_")
            plot_dir = os.path.join(config.output_dir, config.experiment_tag.replace(".", "_")) if config.experiment_tag else config.output_dir
            os.makedirs(plot_dir, exist_ok=True); plot_path = os.path.join(plot_dir, plot_filename)
            fig_nm = None
            try:
                fig_nm = neuromod_manager.plot_gamma_gain(raw_current_task_id, save_path=plot_path)
                if fig_nm and config.use_wandb and 'wandb' in globals() and wandb.run:
                    wandb.log({f"plots/{config.experiment_tag.replace('.', '_') or 'generic'}/{raw_current_task_id.replace('/', '_')}/gamma_history": wandb.Image(fig_nm)}, step=int(current_plot_step))
            except Exception as e: logger.error(f"Plot/log gamma_gain error for '{raw_current_task_id}': {e}")
            finally:
                if fig_nm: plt.close(fig_nm)
        if train_examples_raw and optimizer: peft_manager.save_lora_module(active_lora_to_train_tagged, current_peft_model)
        if config.device == "cuda": torch.cuda.empty_cache(); gc.collect()

    if not config.run_sft_baseline and peft_manager is not None and hasattr(peft_manager, 'save_router_state'):
        logger.info("Attempting to save router state at the end of the experiment.")
        peft_manager.save_router_state()
    logger.info("\\n--- Experiment Main Loop Completed ---\\n")
    if config.use_wandb and 'wandb' in globals() and wandb.run and accuracy_matrix:
        final_k = len(tasks_to_process) -1
        if final_k >=0 :
            final_eval_tasks_count = sum(1 for j in range(final_k + 1) if (final_k, j) in accuracy_matrix)
            if final_eval_tasks_count > 0:
                final_avg = sum(accuracy_matrix.get((final_k, j), 0.0) for j in range(final_k + 1)) / final_eval_tasks_count
                wandb.summary["final_average_accuracy"] = final_avg; logger.info(f"Final Avg Acc: {final_avg:.4f}")
                if final_k > 0:
                    bwt_f_sum, bwt_f_n = 0.0, 0
                    for j_f_bwt in range(final_k): acc_kf, acc_jj = accuracy_matrix.get((final_k, j_f_bwt), None), accuracy_matrix.get((j_f_bwt, j_f_bwt), None);
                    if acc_kf is not None and acc_jj is not None: bwt_f_sum += (acc_kf - acc_jj); bwt_f_n +=1
                    if bwt_f_n > 0: final_b = bwt_f_sum / bwt_f_n; wandb.summary["final_backward_transfer"] = final_b; logger.info(f"Final BWT: {final_b:.4f}")
                    else: wandb.summary["final_backward_transfer"] = 0.0; logger.info(f"Final BWT: 0.0000 (not enough prior tasks for calculation)")
    if config.run_sft_baseline and sft_originals:
        logger.info("--- SFT BASELINE RUN COMPLETE - Restoring original config values ---")
        for key, val in sft_originals.items(): setattr(config, key, val)
        consolidation_manager.config = config; replay_manager.config = config
        if peft_manager.router is not None: peft_manager.router.config = config
        if sft_originals.get('replay_method') != "none" and replay_manager.replay_model is None:
            logger.info(f"SFT Restore: Re-initializing replay model of type '{config.replay_method}'")
            current_base_model_for_replay = peft_manager.get_current_peft_model()
            if config.replay_method == "cmt": replay_manager.replay_model = CMTReplay(config, current_base_model_for_replay, backbone.tokenizer)
            elif config.replay_method == "pcgr": replay_manager.replay_model = PCGRReplay(config, current_base_model_for_replay, backbone.tokenizer)
        logger.info("--- Config restored after SFT baseline. ---")

In [15]:
import os
# Credential sanity check before launching experiments.
# The secrets (HF_TOKEN, WANDB_API_KEY) are reported only as present/absent, so their
# values never reach the cell output; WANDB_PROJECT and WANDB_ENTITY are logged verbatim.
# Relies on `logger` having been created by an earlier cell.
logger.info(f"HF_TOKEN is set: {os.environ.get('HF_TOKEN') is not None}")
logger.info(f"WANDB_API_KEY is set: {os.environ.get('WANDB_API_KEY') is not None}")
logger.info(f"WANDB_PROJECT: {os.environ.get('WANDB_PROJECT')}")
logger.info(f"WANDB_ENTITY: {os.environ.get('WANDB_ENTITY')}")

2025-05-20 16:15:14,592 - __main__ - INFO - HF_TOKEN is set: True
2025-05-20 16:15:14,593 - __main__ - INFO - WANDB_API_KEY is set: True
2025-05-20 16:15:14,594 - __main__ - INFO - WANDB_PROJECT: RESTART01
2025-05-20 16:15:14,594 - __main__ - INFO - WANDB_ENTITY: doingmyownthing82-none


In [None]:
# Cell 17: MVE-DataScale-1.1: Testing with Larger Datasets (Adjusted Batch Size)

import os
import time
import torch
import gc
import logging
import numpy as np

# --- Task Definitions for MVE-DataScale-1.1 (5-task sequence, larger datasets) ---
# All five tasks are GLUE subsets that request the same enlarged sample budget
# (2000 train / 200 validation). They differ only in the subset name, the text
# field(s), the label names, the id prefix and the description given to the router.

def _build_glue_task_def_datascale(subset, text_field, class_names_map, id_prefix, description):
    """Return the task-definition dict for one GLUE subset of the DataScale run.

    Args:
        subset: GLUE configuration name (stored under "config").
        text_field: Name of the text column, or a list of two column names for pair tasks.
        class_names_map: Mapping from integer label to the label word.
        id_prefix: Prefix stored under "id_prefix".
        description: Text stored under "description_for_router".
    """
    return {
        "name": "glue",
        "config": subset,
        "text_field": text_field,
        "label_field": "label",
        "class_names_map": class_names_map,
        "id_prefix": id_prefix,
        "num_train_samples": 2000,
        "num_val_samples": 200,
        "description_for_router": description,
    }


sst2_task_def_datascale = _build_glue_task_def_datascale(
    subset="sst2",
    text_field="sentence",
    class_names_map={0: "negative", 1: "positive"},
    id_prefix="ds_sst2_",
    description="Sentiment analysis of movie reviews (SST-2).",
)

# RTE and MRPC are small GLUE subsets; presumably the data loader caps the requested
# sample counts at what is actually available - TODO confirm against the loader.
rte_task_def_datascale = _build_glue_task_def_datascale(
    subset="rte",
    text_field=["sentence1", "sentence2"],
    class_names_map={0: "entailment", 1: "not_entailment"},
    id_prefix="ds_rte_",
    description="Recognizing Textual Entailment (RTE).",
)

mrpc_task_def_datascale = _build_glue_task_def_datascale(
    subset="mrpc",
    text_field=["sentence1", "sentence2"],
    class_names_map={0: "not_equivalent", 1: "equivalent"},
    id_prefix="ds_mrpc_",
    description="Paraphrase identification (MRPC).",
)

qnli_task_def_datascale = _build_glue_task_def_datascale(
    subset="qnli",
    text_field=["question", "sentence"],
    class_names_map={0: "entailment", 1: "not_entailment"},
    id_prefix="ds_qnli_",
    description="Question answering entailment (QNLI).",
)

cola_task_def_datascale = _build_glue_task_def_datascale(
    subset="cola",
    text_field="sentence",
    class_names_map={0: "unacceptable", 1: "acceptable"},
    id_prefix="ds_cola_",
    description="Corpus of Linguistic Acceptability (CoLA) - grammaticality.",
)

# Order matters: this is the sequence in which the tasks are presented to the learner.
benchmark_tasks_for_datascale_exp = [
    sst2_task_def_datascale,
    rte_task_def_datascale,
    mrpc_task_def_datascale,
    qnli_task_def_datascale,
    cola_task_def_datascale,
]
# --- End of Task Definitions ---


# --- Base Configuration Function (can be reused or adapted if needed) ---
def get_base_config_for_mve(experiment_tag_suffix: str, task_list: List[Dict[str, Any]]) -> 'AdaptiveLearnerConfig':
    """Create a fresh ``AdaptiveLearnerConfig`` carrying the shared MVE baseline settings.

    Args:
        experiment_tag_suffix: Appended to the ``MVE_CL_`` prefix to form the
            experiment tag; every "." in it is replaced by "_".
        task_list: Task-definition dicts assigned to ``benchmark_tasks_to_run``.

    Returns:
        The new config object with the overrides below applied in listed order.
    """
    # Applied in this exact order on top of the class defaults.
    baseline_overrides = {
        # Routing: linear router, no prototype examples, no router training or state loading.
        "run_sft_baseline": False,
        "router_type": "linear",
        "router_confidence_threshold": 1.1,
        "use_advanced_embeddings": False,
        "k_examples_for_prototype": 0,
        "load_router_state_from_tag": None,
        "router_training_epochs_per_call": 0,
        "num_tasks_to_share_lora": 2,
        # Runtime / logging.
        "num_workers": 2,
        "use_wandb": True,
        # Continual-learning strategy: EWC consolidation plus CMT replay.
        "consolidation_method": "ewc",
        "ewc_lambda": 250.0,
        "ewc_fixed_lambda_bypass_gamma": False,
        "ewc_data_loader_num_samples": 50,
        "gamma_gain_lambda": 3.0,
        "replay_method": "cmt",
        "replay_alpha": 0.2,
        "feature_replay_alpha": 0.1,
        # Training budget; num_lora_train_epochs is the maximum number of epochs.
        "batch_size": 4,
        "num_lora_train_epochs": 12,
        "ewc_batch_size": 10,
        "replay_model_internal_batch_size": 4,
        "replay_backbone_encoding_batch_size": 16,
        "learning_rate": 5e-5,
    }

    config = AdaptiveLearnerConfig()
    config.benchmark_tasks_to_run = task_list
    clean_suffix = experiment_tag_suffix.replace(".", "_")
    config.experiment_tag = f"MVE_CL_{clean_suffix}"
    for option_name, option_value in baseline_overrides.items():
        setattr(config, option_name, option_value)
    return config

# --- Config Function for MVE-DataScale-1.1 (Adjusted) ---
def get_config_variant_datascale_1_1_bs8() -> 'AdaptiveLearnerConfig':
    """Build the MVE-DataScale-1.1 configuration (larger datasets, LoRA batch size 8).

    Starts from ``get_base_config_for_mve`` with the five-task DataScale list, then
    applies the settings below. Several of them repeat values the base function
    already sets; they are pinned again here so this variant states its full
    continual-learning setup explicitly.

    Returns:
        The configured ``AdaptiveLearnerConfig`` for this variant.
    """
    # Applied in this exact order on top of the base MVE configuration.
    datascale_settings = {
        "wandb_project": "adaptive-learner-cl-datascale",  # dedicated W&B project for this series
        # Training hyperparameters.
        "batch_size": 8,
        "learning_rate": 5e-5,
        "use_lora_early_stopping": True,
        # Continual-learning parameters.
        "num_tasks_to_share_lora": 2,
        "consolidation_method": "ewc",
        "ewc_lambda": 250.0,
        "ewc_fixed_lambda_bypass_gamma": False,
        "ewc_data_loader_num_samples": 50,
        "gamma_gain_lambda": 3.0,
        "replay_method": "cmt",
        "replay_alpha": 0.2,
        "feature_replay_alpha": 0.1,
        "ewc_batch_size": 10,  # EWC Fisher batch size; independent of the LoRA training batch size
        "replay_model_internal_batch_size": 4,
        "replay_backbone_encoding_batch_size": 16,
        "num_lora_train_epochs": 12,  # upper bound; early stopping decides the actual count
    }

    config = get_base_config_for_mve("DataScale_2Ktrain_BS8_LR5e-5_ES_seed123", benchmark_tasks_for_datascale_exp)
    for setting_name, setting_value in datascale_settings.items():
        setattr(config, setting_name, setting_value)
    return config


# --- run_experiment_variant: log settings, set up W&B, call main() for one config variant, then release memory ---
def run_experiment_variant(config_variant: 'AdaptiveLearnerConfig', variant_name_for_log: str):
    """Run one experiment variant: log its settings, start W&B, call ``main``, clean up.

    Args:
        config_variant: Configuration for this run. It is mutated: ``use_wandb`` is
            set to False when ``wandb`` cannot be imported or when no API key /
            project is available.
        variant_name_for_log: Human-readable label used only in log messages.

    Raises:
        Any exception raised by ``main`` propagates to the caller. Before it does,
        the module-level ``default_config`` is restored, an active W&B run is closed
        with a non-zero exit code, and the memory cleanup below is performed.

    Side effects:
        * Rebinds the global ``default_config`` to ``config_variant`` while ``main`` runs.
        * If no global ``AdaptiveLearnerConfig`` exists, installs ``type(config_variant)``
          under that name and leaves it in place afterwards.
        * Removes the globals ``backbone``, ``peft_manager``, ``neuromod_manager``,
          ``consolidation_manager`` and ``replay_manager`` if present, empties the
          CUDA cache when CUDA is available, and runs the garbage collector.
    """
    # ---- 1. Log the effective settings of this variant ----
    logger.info(f"--- CONFIGURING FOR: {variant_name_for_log} ---")
    logger.info(f"Experiment Tag: {config_variant.experiment_tag}")
    logger.info(f"Consolidation: {config_variant.consolidation_method}")
    if config_variant.consolidation_method == "ewc":
        logger.info(f"EWC Lambda: {config_variant.ewc_lambda}, Bypass Gamma for Accum: {config_variant.ewc_fixed_lambda_bypass_gamma}")
        logger.info(f"EWC Fisher Samples: {config_variant.ewc_data_loader_num_samples}")
        logger.info(f"EWC Batch Size: {config_variant.ewc_batch_size}")
    if hasattr(config_variant, 'gamma_gain_lambda') and not config_variant.ewc_fixed_lambda_bypass_gamma and config_variant.consolidation_method == "ewc":
        logger.info(f"Gamma Gain Lambda: {config_variant.gamma_gain_lambda}")
    logger.info(f"Replay Method: {config_variant.replay_method}, Alpha: {config_variant.replay_alpha}, Feature Alpha: {config_variant.feature_replay_alpha}")
    if config_variant.replay_method != "none":
        logger.info(f"Replay Internal BS: {config_variant.replay_model_internal_batch_size}, Replay Backbone Encoding BS: {config_variant.replay_backbone_encoding_batch_size}")
    logger.info(f"Num tasks sharing LoRA: {config_variant.num_tasks_to_share_lora}")
    logger.info(f"Total tasks in sequence: {len(config_variant.benchmark_tasks_to_run)}")
    logger.info(f"LoRA Train Batch Size: {config_variant.batch_size}, Max LoRA Train Epochs: {config_variant.num_lora_train_epochs}")
    logger.info(f"Learning Rate: {config_variant.learning_rate}")
    logger.info(f"Use Advanced Embeddings: {config_variant.use_advanced_embeddings}, k_for_prototype: {config_variant.k_examples_for_prototype}")
    logger.info(f"Use LoRA Early Stopping: {config_variant.use_lora_early_stopping}")
    if config_variant.use_lora_early_stopping:
        logger.info(f"  ES Patience: {config_variant.lora_early_stopping_patience}, ES Metric: {config_variant.lora_early_stopping_metric}")
        logger.info(f"  ES Min Epochs: {config_variant.min_lora_epochs_before_early_stop}, ES Delta: {config_variant.lora_early_stopping_delta}")

    logger.info(f"\\n{'='*20} Starting Experiment Variant: {variant_name_for_log} ({config_variant.experiment_tag}) {'='*20}")

    # ---- 2. Weights & Biases setup ----
    # wandb is imported lazily so a missing package only disables logging.
    wandb_module_available = False
    if config_variant.use_wandb:
        try:
            import wandb
            wandb_module_available = True
        except ImportError:
            logger.warning("wandb could not be imported. W&B features will be disabled for this variant.")
            config_variant.use_wandb = False

    if config_variant.use_wandb and wandb_module_available:
        if wandb.run is not None:
            # A run left open by an earlier variant must be closed before a new one starts.
            logger.info(f"Finishing previous W&B run: {wandb.run.id if wandb.run else 'Unknown'} before starting {variant_name_for_log}.")
            wandb.finish()
        # A key exported as WANDB_API_KEY is as good as one stored on the config;
        # without this fallback W&B was switched off even though wandb.init would work.
        effective_wandb_api_key = config_variant.wandb_api_key or os.environ.get("WANDB_API_KEY")
        if effective_wandb_api_key and config_variant.wandb_project:
            wandb_run_name = f"{time.strftime('%Y%m%d-%H%M%S')}_{config_variant.experiment_tag}"
            try:
                if config_variant.wandb_api_key and not os.environ.get("WANDB_API_KEY"):
                    logger.info(f"W&B: Attempting login with API key from config for {variant_name_for_log}.")
                    wandb.login(key=config_variant.wandb_api_key)
                elif os.environ.get("WANDB_API_KEY"):
                    logger.info(f"W&B: API key found in environment for {variant_name_for_log}.")

                wandb.init(project=config_variant.wandb_project,
                           entity=config_variant.wandb_entity,
                           config=config_variant.to_dict(),
                           name=wandb_run_name,
                           reinit=True)
                logger.info(f"W&B run initiated for {variant_name_for_log} to project '{config_variant.wandb_project}': {wandb.run.name if wandb.run else 'Failed'}")
            except Exception as e:
                # Best effort: the experiment still runs without W&B if init fails.
                logger.error(f"W&B initialization failed for {variant_name_for_log}: {e}")
        else:
            logger.warning(f"W&B API key or project not found/set in config for {variant_name_for_log}. W&B disabled for this variant.")
            config_variant.use_wandb = False
    else:
        # Reached when W&B is disabled in the config or the import above failed.
        logger.info(f"W&B usage is disabled in config for {variant_name_for_log}.")

    # ---- 3. Run main() with this variant installed as the global default config ----
    global default_config
    original_default_config_backup = default_config
    default_config = config_variant

    # If the config class is not reachable by name, expose it under the expected name.
    # The alias is intentionally left in place after the run.
    if 'AdaptiveLearnerConfig' not in globals():
        globals()['AdaptiveLearnerConfig'] = type(default_config)

    run_succeeded = False
    try:
        main(run_config=config_variant)
        run_succeeded = True
    finally:
        # Everything below must also happen when main() raises (OOM, Ctrl-C, ...);
        # otherwise the W&B run stays open and the memory is never released.
        default_config = original_default_config_backup
        if not run_succeeded:
            logger.error(f"Variant {variant_name_for_log} ({config_variant.experiment_tag}) did not complete; running cleanup before re-raising.")
        try:
            if config_variant.use_wandb and wandb_module_available and wandb.run is not None:
                logger.info(f"Finishing W&B run for {variant_name_for_log}: {wandb.run.id}")
                if run_succeeded:
                    wandb.finish()
                else:
                    wandb.finish(exit_code=1)  # record the run as failed in W&B
        finally:
            # Drop component objects that may have been left in the notebook namespace.
            for var_name in ('backbone', 'peft_manager', 'neuromod_manager', 'consolidation_manager', 'replay_manager'):
                globals().pop(var_name, None)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    logger.info(f"--- Variant {variant_name_for_log} ({config_variant.experiment_tag}) finished. Memory cleaned. ---\\n")


# --- Execution Plan ---
# This cell launches exactly one variant, MVE-DataScale-1.1: a banner log line,
# config construction, the run itself, and a closing banner.

# Previous run (SL-F_Rep1_ES), kept commented out for reference after completion.
# NOTE(review): get_config_variant_sl_f_rep1_es is not defined in this cell; it would
# have to come from an earlier cell if this block were re-enabled.
# logger.info("\\n" + "="*50 + "\\nStarting MVE-Efficiency-1.1: Early Stopping Validation" + "\\n" + "="*50)
# config_sl_f_rep1_es = get_config_variant_sl_f_rep1_es() # Assuming this function was defined from previous step
# run_experiment_variant(config_sl_f_rep1_es, "Variant SL-F_Rep1_ES: Hybrid Scaled (50F, seed123) with Early Stopping")
# logger.info("\\n" + "="*50 + "\\nFinished MVE-Efficiency-1.1: Early Stopping Validation" + "\\n" + "="*50)


logger.info("\\n" + "="*50 + "\\nStarting MVE-DataScale-1.1: Larger Datasets Validation (BS=8)" + "\\n" + "="*50)
config_datascale_1_1_bs8 = get_config_variant_datascale_1_1_bs8()
run_experiment_variant(config_datascale_1_1_bs8, "Variant DataScale-1.1: Hybrid Scaled (2K Train, BS8, LR5e-5, ES, seed123)")
logger.info("\\n" + "="*50 + "\\nFinished MVE-DataScale-1.1: Larger Datasets Validation (BS=8)" + "\\n" + "="*50)

INFO (AdaptiveLearnerConfig): Found WANDB_API_KEY in environment variables.
INFO (AdaptiveLearnerConfig): Hugging Face token (HF_TOKEN/HUGGINGFACE_HUB_TOKEN) found in environment variables.
INFO (AdaptiveLearnerConfig): Found WANDB_PROJECT in environment variables: RESTART01
INFO (AdaptiveLearnerConfig): Found WANDB_ENTITY in environment variables: doingmyownthing82-none
2025-05-20 16:15:14,610 - __main__ - INFO - --- CONFIGURING FOR: Variant DataScale-1.1: Hybrid Scaled (2K Train, BS8, LR5e-5, ES, seed123) ---
2025-05-20 16:15:14,610 - __main__ - INFO - Experiment Tag: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123
2025-05-20 16:15:14,611 - __main__ - INFO - Consolidation: ewc
2025-05-20 16:15:14,611 - __main__ - INFO - EWC Lambda: 250.0, Bypass Gamma for Accum: False
2025-05-20 16:15:14,611 - __main__ - INFO - EWC Fisher Samples: 50
2025-05-20 16:15:14,612 - __main__ - INFO - EWC Batch Size: 10
2025-05-20 16:15:14,612 - __main__ - INFO - Gamma Gain Lambda: 3.0
2025-05-20 16:15:14,612

[34m[1mwandb[0m: Currently logged in as: [33mdoingmyownthing82[0m ([33mdoingmyownthing82-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


2025-05-20 16:15:15,578 - __main__ - INFO - W&B run initiated for Variant DataScale-1.1: Hybrid Scaled (2K Train, BS8, LR5e-5, ES, seed123) to project 'adaptive-learner-cl-datascale': 20250520-161514_MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123
2025-05-20 16:15:15,578 - __main__ - INFO - --- Main function started. SFT Mode: False. Exp Tag: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123 ---
Router type: linear, Confidence threshold: 1.1
Load router from tag: None
k_examples_for_prototype: 0, Use Advanced Embeddings: False
Use LoRA Early Stopping: True
2025-05-20 16:15:15,579 - __main__ - INFO - Initializing components...
2025-05-20 16:15:15,579 - __main__ - INFO - Loading model: google/gemma-2b
2025-05-20 16:15:15,580 - __main__ - INFO - Attempting to load model with Flash Attention 2.


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


2025-05-20 16:15:15,956 - accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-05-20 16:15:18,974 - __main__ - INFO - Model loaded successfully with Flash Attention 2 and 8-bit quantization.
2025-05-20 16:15:19,620 - __main__ - INFO - Model loaded with 2,506,172,416 parameters
2025-05-20 16:15:19,622 - __main__ - INFO - Base model parameters frozen
2025-05-20 16:15:20,246 - __main__ - INFO - Router common init: Embedding model dtype: torch.float16, Router internal param dtype: torch.float32
2025-05-20 16:15:20,247 - __main__ - INFO - LinearRouter initialized (ADV. EMBEDDINGS OFF) for DESCRIPTION features ONLY. Embed_dim: 2048.
2025-05-20 16:15:20,247 - __main__ - INFO - PEFTManager init: Router object instantiated. Initial learnable profiles in router: 0
2025-05-20 16:15:20,248 - __main__ - INFO - PEFTManager init: No router state to load (load_router_state_from_tag is None). Router starts fresh.
2025-05-20 16:15:20,248 - __main__ - INFO - PEFTManager: (_scan_disk_and_update_metadata_globally) Scanning disk (/workspace/MyAdaptiveLearnerProject/outputs/lora_m

Epoch 1 LoRA:MVE_CL_DataScale_2Kt Task:ds_sst2_glue_ss:   0%|          | 0/250 [00:00<?, ?it/s]



2025-05-20 16:21:13,000 - __main__ - INFO -     ES Check: Epoch 1, evaluating on val set for ds_sst2_glue_sst2 with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:21:32,958 - __main__ - INFO -     ES Check: Val Acc for ds_sst2_glue_sst2 on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.9600 (Best: -inf)
2025-05-20 16:21:32,960 - __main__ - INFO -     ES Check: New best val_acc: 0.9600. Patience reset.
2025-05-20 16:21:32,963 - __main__ - INFO - Epoch 1 ended. Training replay model for task 'ds_sst2_glue_sst2' using 500 samples.
2025-05-20 16:21:32,963 - __main__ - INFO - CMTReplay: Training AE for task 'ds_sst2_glue_sst2' on 500 samples from buffer.


CMT AE Ep1 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:21:35,479 - __main__ - INFO - CMT Initial WEIGHT DEBUG (Task ds_sst2_glue_sst2): Weights AFTER first AE optim step...


CMT AE Ep2 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:21:36,411 - __main__ - INFO - CMTReplay AE training complete for task 'ds_sst2_glue_sst2' (500 buffer samples). AvgLoss: 1.2446, AvgUniformity: 0.5781, LR: 5.00e-05
2025-05-20 16:21:36,590 - __main__ - INFO -   LoRA Epoch 1 Avg Metrics - Loss: 2.1594, Acc: 0.6960, GradN: 5.4113, Entropy (last batch): 4.2004
2025-05-20 16:21:36,591 - __main__ - INFO - Task ds_sst2_glue_sst2 NM - Metrics: {'accuracy': 0.696, 'nll': 2.159, 'gradient_norm': 5.411, 'entropy': 4.2}, γ-Gain: 1.7527
2025-05-20 16:21:36,592 - __main__ - INFO -   Task ds_sst2_glue_sst2 - LoRA Training Epoch 2/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 2 LoRA:MVE_CL_DataScale_2Kt Task:ds_sst2_glue_ss:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 16:26:02,806 - __main__ - INFO -     ES Check: Epoch 2, evaluating on val set for ds_sst2_glue_sst2 with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:26:22,444 - __main__ - INFO -     ES Check: Val Acc for ds_sst2_glue_sst2 on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.9250 (Best: 0.9600)
2025-05-20 16:26:22,447 - __main__ - INFO -     ES Check: No improvement. Patience: 1/3
2025-05-20 16:26:22,448 - __main__ - INFO - Epoch 2 ended. Training replay model for task 'ds_sst2_glue_sst2' using 500 samples.
2025-05-20 16:26:22,449 - __main__ - INFO - CMTReplay: Training AE for task 'ds_sst2_glue_sst2' on 500 samples from buffer.


CMT AE Ep1 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:26:25,904 - __main__ - INFO - CMTReplay AE training complete for task 'ds_sst2_glue_sst2' (500 buffer samples). AvgLoss: 0.3265, AvgUniformity: 0.0729, LR: 5.00e-05
2025-05-20 16:26:26,093 - __main__ - INFO -   LoRA Epoch 2 Avg Metrics - Loss: 0.0868, Acc: 0.9410, GradN: 2.5409, Entropy (last batch): 4.0865
2025-05-20 16:26:26,095 - __main__ - INFO - Task ds_sst2_glue_sst2 NM - Metrics: {'accuracy': 0.941, 'nll': 0.087, 'gradient_norm': 2.541, 'entropy': 4.086}, γ-Gain: 0.9000
2025-05-20 16:26:26,095 - __main__ - INFO -   Task ds_sst2_glue_sst2 - LoRA Training Epoch 3/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 3 LoRA:MVE_CL_DataScale_2Kt Task:ds_sst2_glue_ss:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 16:30:52,739 - __main__ - INFO -     ES Check: Epoch 3, evaluating on val set for ds_sst2_glue_sst2 with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:31:12,416 - __main__ - INFO -     ES Check: Val Acc for ds_sst2_glue_sst2 on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.9450 (Best: 0.9600)
2025-05-20 16:31:12,418 - __main__ - INFO -     ES Check: No improvement. Patience: 2/3
2025-05-20 16:31:12,419 - __main__ - INFO - Epoch 3 ended. Training replay model for task 'ds_sst2_glue_sst2' using 500 samples.
2025-05-20 16:31:12,419 - __main__ - INFO - CMTReplay: Training AE for task 'ds_sst2_glue_sst2' on 500 samples from buffer.


CMT AE Ep1 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:31:15,813 - __main__ - INFO - CMTReplay AE training complete for task 'ds_sst2_glue_sst2' (500 buffer samples). AvgLoss: 0.1273, AvgUniformity: -0.1050, LR: 5.00e-05
2025-05-20 16:31:15,990 - __main__ - INFO -   LoRA Epoch 3 Avg Metrics - Loss: 0.0541, Acc: 0.9700, GradN: 1.9948, Entropy (last batch): 4.3539
2025-05-20 16:31:15,991 - __main__ - INFO - Task ds_sst2_glue_sst2 NM - Metrics: {'accuracy': 0.97, 'nll': 0.054, 'gradient_norm': 1.995, 'entropy': 4.354}, γ-Gain: 1.5000
2025-05-20 16:31:15,992 - __main__ - INFO -   Task ds_sst2_glue_sst2 - LoRA Training Epoch 4/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 4 LoRA:MVE_CL_DataScale_2Kt Task:ds_sst2_glue_ss:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 16:35:42,275 - __main__ - INFO -     ES Check: Epoch 4, evaluating on val set for ds_sst2_glue_sst2 with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:36:02,345 - __main__ - INFO -     ES Check: Val Acc for ds_sst2_glue_sst2 on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.9350 (Best: 0.9600)
2025-05-20 16:36:02,347 - __main__ - INFO -     ES Check: No improvement. Patience: 3/3
2025-05-20 16:36:02,348 - __main__ - INFO -   EARLY STOPPING LoRA training for ds_sst2_glue_sst2 at epoch 4 due to no improvement on validation set.
2025-05-20 16:36:02,348 - __main__ - INFO - Epoch 4 ended. Training replay model for task 'ds_sst2_glue_sst2' using 500 samples.
2025-05-20 16:36:02,348 - __main__ - INFO - CMTReplay: Training AE for task 'ds_sst2_glue_sst2' on 500 samples from buffer.


CMT AE Ep1 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_sst2_glue_sst2:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:36:05,714 - __main__ - INFO - CMTReplay AE training complete for task 'ds_sst2_glue_sst2' (500 buffer samples). AvgLoss: 0.0825, AvgUniformity: -0.1930, LR: 5.00e-05
2025-05-20 16:36:05,899 - __main__ - INFO -   LoRA Epoch 4 Avg Metrics - Loss: 0.0364, Acc: 0.9840, GradN: 1.8153, Entropy (last batch): 4.4790
2025-05-20 16:36:05,900 - __main__ - INFO - Task ds_sst2_glue_sst2 NM - Metrics: {'accuracy': 0.984, 'nll': 0.036, 'gradient_norm': 1.815, 'entropy': 4.479}, γ-Gain: 1.5000
2025-05-20 16:36:05,900 - __main__ - INFO - --- Evaluating after Task 1: ds_sst2_glue_sst2 (after 4 LoRA epochs) ---
2025-05-20 16:36:25,781 - __main__ - INFO -     Acc on Task_0 (ds_sst2_glue_sst2) (LoRA 'MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte') after Task_0 (ds_sst2_glue_sst2): 0.9350
2025-05-20 16:36:25,783 - __main__ - INFO -   AvgAcc after task ds_sst2_glue_sst2 (Task 0): 0.9350
2025-05-20 16:36:25,786 - __main__ - INFO - Applying EWC for MVE_CL_DataScale_2Ktrain

Epoch 1 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 16:45:26,861 - __main__ - INFO -     ES Check: Epoch 1, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:46:07,801 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.7350 (Best: -inf)
2025-05-20 16:46:07,805 - __main__ - INFO -     ES Check: New best val_acc: 0.7350. Patience reset.
2025-05-20 16:46:07,809 - __main__ - INFO - Epoch 1 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 16:46:07,809 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:46:13,026 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.1391, AvgUniformity: -0.0613, LR: 5.00e-05
2025-05-20 16:46:13,208 - __main__ - INFO -   LoRA Epoch 1 Avg Metrics - Loss: 0.7518, Acc: 0.5405, GradN: 7.7144, Entropy (last batch): 4.1832
2025-05-20 16:46:13,210 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.54, 'nll': 0.752, 'gradient_norm': 7.714, 'entropy': 4.183}, γ-Gain: 1.8797
2025-05-20 16:46:13,211 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 2/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 2 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 16:56:11,998 - __main__ - INFO -     ES Check: Epoch 2, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 16:56:56,092 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8000 (Best: 0.7350)
2025-05-20 16:56:56,096 - __main__ - INFO -     ES Check: New best val_acc: 0.8000. Patience reset.
2025-05-20 16:56:56,097 - __main__ - INFO - Epoch 2 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 16:56:56,097 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 16:57:01,198 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0771, AvgUniformity: -0.2145, LR: 5.00e-05
2025-05-20 16:57:01,383 - __main__ - INFO -   LoRA Epoch 2 Avg Metrics - Loss: 0.1189, Acc: 0.7770, GradN: 5.0425, Entropy (last batch): 4.1657
2025-05-20 16:57:01,385 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.777, 'nll': 0.119, 'gradient_norm': 5.043, 'entropy': 4.166}, γ-Gain: 0.9000
2025-05-20 16:57:01,386 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 3/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 3 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 17:06:51,102 - __main__ - INFO -     ES Check: Epoch 3, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 17:07:33,485 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8150 (Best: 0.8000)
2025-05-20 17:07:33,491 - __main__ - INFO -     ES Check: New best val_acc: 0.8150. Patience reset.
2025-05-20 17:07:33,492 - __main__ - INFO - Epoch 3 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 17:07:33,492 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 17:07:38,612 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0550, AvgUniformity: -0.2462, LR: 5.00e-05
2025-05-20 17:07:38,794 - __main__ - INFO -   LoRA Epoch 3 Avg Metrics - Loss: 0.0891, Acc: 0.8360, GradN: 4.2963, Entropy (last batch): 4.1856
2025-05-20 17:07:38,796 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.836, 'nll': 0.089, 'gradient_norm': 4.296, 'entropy': 4.186}, γ-Gain: 1.5000
2025-05-20 17:07:38,797 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 4/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 4 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 17:17:26,504 - __main__ - INFO -     ES Check: Epoch 4, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 17:18:10,048 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8150 (Best: 0.8150)
2025-05-20 17:18:10,052 - __main__ - INFO -     ES Check: No improvement. Patience: 1/3
2025-05-20 17:18:10,053 - __main__ - INFO - Epoch 4 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 17:18:10,054 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 17:18:15,236 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0436, AvgUniformity: -0.2558, LR: 5.00e-05
2025-05-20 17:18:15,423 - __main__ - INFO -   LoRA Epoch 4 Avg Metrics - Loss: 0.0698, Acc: 0.8825, GradN: 4.1488, Entropy (last batch): 4.0701
2025-05-20 17:18:15,424 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.882, 'nll': 0.07, 'gradient_norm': 4.149, 'entropy': 4.07}, γ-Gain: 0.9000
2025-05-20 17:18:15,425 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 5/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 5 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 17:28:03,344 - __main__ - INFO -     ES Check: Epoch 5, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 17:28:46,514 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8200 (Best: 0.8150)
2025-05-20 17:28:46,518 - __main__ - INFO -     ES Check: New best val_acc: 0.8200. Patience reset.
2025-05-20 17:28:46,519 - __main__ - INFO - Epoch 5 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 17:28:46,519 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 17:28:51,665 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0356, AvgUniformity: -0.2726, LR: 5.00e-05
2025-05-20 17:28:51,854 - __main__ - INFO -   LoRA Epoch 5 Avg Metrics - Loss: 0.0549, Acc: 0.9220, GradN: 3.9727, Entropy (last batch): 4.2395
2025-05-20 17:28:51,855 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.922, 'nll': 0.055, 'gradient_norm': 3.973, 'entropy': 4.24}, γ-Gain: 1.5000
2025-05-20 17:28:51,856 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 6/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 6 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 17:38:44,388 - __main__ - INFO -     ES Check: Epoch 6, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 17:39:28,174 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8300 (Best: 0.8200)
2025-05-20 17:39:28,176 - __main__ - INFO -     ES Check: New best val_acc: 0.8300. Patience reset.
2025-05-20 17:39:28,177 - __main__ - INFO - Epoch 6 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 17:39:28,177 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 17:39:33,191 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0327, AvgUniformity: -0.2726, LR: 5.00e-05
2025-05-20 17:39:33,376 - __main__ - INFO -   LoRA Epoch 6 Avg Metrics - Loss: 0.0365, Acc: 0.9595, GradN: 3.4885, Entropy (last batch): 4.0377
2025-05-20 17:39:33,377 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.96, 'nll': 0.037, 'gradient_norm': 3.489, 'entropy': 4.038}, γ-Gain: 0.9000
2025-05-20 17:39:33,378 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 7/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 7 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 17:49:26,150 - __main__ - INFO -     ES Check: Epoch 7, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 17:50:07,809 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.8050 (Best: 0.8300)
2025-05-20 17:50:07,811 - __main__ - INFO -     ES Check: No improvement. Patience: 1/3
2025-05-20 17:50:07,812 - __main__ - INFO - Epoch 7 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 17:50:07,812 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 17:50:12,890 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0256, AvgUniformity: -0.2762, LR: 5.00e-05
2025-05-20 17:50:13,073 - __main__ - INFO -   LoRA Epoch 7 Avg Metrics - Loss: 0.0266, Acc: 0.9785, GradN: 3.3659, Entropy (last batch): 4.1816
2025-05-20 17:50:13,075 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.978, 'nll': 0.027, 'gradient_norm': 3.366, 'entropy': 4.182}, γ-Gain: 1.3277
2025-05-20 17:50:13,076 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 8/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 8 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

2025-05-20 18:00:11,665 - __main__ - INFO -     ES Check: Epoch 8, evaluating on val set for ds_rte_glue_rte with LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte
2025-05-20 18:00:55,423 - __main__ - INFO -     ES Check: Val Acc for ds_rte_glue_rte on LoRA MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte: 0.7850 (Best: 0.8300)
2025-05-20 18:00:55,426 - __main__ - INFO -     ES Check: No improvement. Patience: 2/3
2025-05-20 18:00:55,427 - __main__ - INFO - Epoch 8 ended. Training replay model for task 'ds_rte_glue_rte' using 500 samples.
2025-05-20 18:00:55,428 - __main__ - INFO - CMTReplay: Training AE for task 'ds_rte_glue_rte' on 500 samples from buffer.


CMT AE Ep1 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep2 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

CMT AE Ep3 Task:ds_rte_glue_rte:   0%|          | 0/125 [00:00<?, ?it/s]

2025-05-20 18:01:00,351 - __main__ - INFO - CMTReplay AE training complete for task 'ds_rte_glue_rte' (500 buffer samples). AvgLoss: 0.0231, AvgUniformity: -0.2858, LR: 5.00e-05
2025-05-20 18:01:00,540 - __main__ - INFO -   LoRA Epoch 8 Avg Metrics - Loss: 0.0109, Acc: 0.9960, GradN: 2.0784, Entropy (last batch): 3.8011
2025-05-20 18:01:00,542 - __main__ - INFO - Task ds_rte_glue_rte NM - Metrics: {'accuracy': 0.996, 'nll': 0.011, 'gradient_norm': 2.078, 'entropy': 3.801}, γ-Gain: 0.9000
2025-05-20 18:01:00,543 - __main__ - INFO -   Task ds_rte_glue_rte - LoRA Training Epoch 9/12 for LoRA: MVE_CL_DataScale_2Ktrain_BS8_LR5e-5_ES_seed123_SharedAdapter_sst2_rte


Epoch 9 LoRA:MVE_CL_DataScale_2Kt Task:ds_rte_glue_rte:   0%|          | 0/250 [00:00<?, ?it/s]

In [None]:
# Environment check: print the version of the installed `datasets` package.
import datasets
print(datasets.__version__)

In [None]:
# Temporary cell to inspect model layers for grad_sketch_layer_names

# Set to True to additionally dump every parameter name and shape.
# Off by default: the previous version printed the "named parameters" header
# but had the listing itself commented out, so the output was misleading.
SHOW_BACKBONE_PARAMETERS = False


def find_linear_modules(model):
    """Return ``(qualified_name, module)`` pairs for every ``torch.nn.Linear`` in ``model``.

    Args:
        model: Any object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A list of ``(name, module)`` tuples, in ``named_modules()`` order, keeping
        only modules that are instances of ``torch.nn.Linear`` (subclasses included).
    """
    # We are interested in modules that likely have a 'weight' parameter,
    # like Linear layers (often part of attention or MLPs)
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)  # You can also check for other types if needed
    ]


if 'backbone' in globals() and hasattr(backbone, 'model'):
    print("Listing all named modules in backbone.model:")
    linear_modules = find_linear_modules(backbone.model)
    for name, module in linear_modules:
        print(f"  Module Name: {name} (Type: {type(module)})")
    if not linear_modules:
        # Make the empty case explicit instead of printing a bare header.
        print("  (no torch.nn.Linear modules found)")

    if SHOW_BACKBONE_PARAMETERS:
        print("\nListing all named parameters in backbone.model (for reference):")
        for name, param in backbone.model.named_parameters():
            print(f"  Parameter Name: {name} (Shape: {param.shape})")
else:
    print("Backbone model not found. Please initialize AdaptiveLearnerBackbone first.")

In [None]:
# Cell X: Cleanup Script for Experiment Outputs
#
# Permanently deletes the on-disk artifacts of ONE experiment, identified by
# `experiment_tag_to_clean`, below <RUNPODS_PROJECT_BASE_PATH>/outputs:
#   1. every directory in lora_modules/ whose name STARTS WITH the tag,
#   2. router_states/router_state_<tag>.pth,
#   3. <tag>/ (plots, etc.),
#   4. replay/<tag>/.
# Deletions are irreversible, so the tag is validated before anything is touched.

import os
import shutil
import logging

# Configure a simple logger for this cell
cleanup_logger = logging.getLogger(__name__ + "_cleanup_script")
if not cleanup_logger.handlers:
    # Guarded so that re-running the cell does not stack duplicate handlers.
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    cleanup_logger.addHandler(handler)
    cleanup_logger.setLevel(logging.INFO)


def validate_experiment_tag(tag):
    """Return ``tag`` unchanged if it is safe to use as a deletion target.

    An empty tag would make ``os.path.join(base, tag)`` resolve to ``base`` itself
    (wiping the whole outputs directory) and ``name.startswith(tag)`` match every
    LoRA directory. ``"."`` also resolves to ``base``, while ``".."`` or a tag
    containing a path separator would point outside the intended per-experiment
    folders.

    Args:
        tag: The experiment tag configured in this cell.

    Returns:
        The same ``tag`` value.

    Raises:
        ValueError: If ``tag`` is not a non-blank string naming a single path component.
    """
    if not isinstance(tag, str) or not tag.strip():
        raise ValueError("experiment_tag_to_clean must be a non-empty string; refusing to clean up.")
    if tag in (".", "..") or os.path.basename(tag) != tag:
        raise ValueError(
            f"experiment_tag_to_clean must be a single directory name, got {tag!r}; refusing to clean up."
        )
    return tag


def delete_directory(dir_path, label, error_label=None):
    """Recursively delete ``dir_path`` and log the outcome.

    Args:
        dir_path: Directory to remove.
        label: Human-readable description used in the success / not-found messages.
        error_label: Description used in the error message; defaults to ``label``.

    Returns:
        True if the directory was deleted; False if it was missing or deletion failed.
        Failures are logged, never raised, so the remaining cleanup steps still run.
    """
    if not os.path.isdir(dir_path):
        cleanup_logger.info(f"  {label[:1].upper()}{label[1:]} not found: {dir_path}")
        return False
    try:
        shutil.rmtree(dir_path)
    except Exception as e:
        cleanup_logger.error(f"  Error deleting {error_label or label} {dir_path}: {e}")
        return False
    cleanup_logger.info(f"  Deleted {label}: {dir_path}")
    return True


def delete_file(file_path, label):
    """Delete the single file ``file_path`` and log the outcome.

    Returns:
        True if the file was deleted; False if it was missing or deletion failed.
    """
    if not os.path.exists(file_path):
        cleanup_logger.info(f"  {label[:1].upper()}{label[1:]} not found: {file_path}")
        return False
    try:
        os.remove(file_path)
    except Exception as e:
        cleanup_logger.error(f"  Error deleting {label} {file_path}: {e}")
        return False
    cleanup_logger.info(f"  Deleted {label}: {file_path}")
    return True


def delete_dirs_with_prefix(parent_dir, prefix, label, error_label=None):
    """Delete every sub-directory of ``parent_dir`` whose name starts with ``prefix``.

    Non-directory entries are skipped silently, even when their name matches.

    Args:
        parent_dir: Existing directory to scan (not itself deleted).
        prefix: Name prefix selecting the sub-directories to remove.
        label: Passed to ``delete_directory`` for success messages.
        error_label: Passed to ``delete_directory`` for error messages.

    Returns:
        The number of directories actually deleted.
    """
    deleted_count = 0
    # sorted() only makes the log order deterministic; it does not change what is deleted.
    for item_name in sorted(os.listdir(parent_dir)):
        if not item_name.startswith(prefix):
            continue
        item_path = os.path.join(parent_dir, item_name)
        if os.path.isdir(item_path) and delete_directory(item_path, label, error_label):
            deleted_count += 1
    return deleted_count


# --- !!! SET THIS TO THE TAG OF THE EXPERIMENT YOU WANT TO CLEAN UP !!! ---
# This should be the tag of the run you just completed (SL-E Hybrid with original replay alphas).
experiment_tag_to_clean = "MVE_CL_SharedLoRA_SL_F_Rep1_ES_Hybrid_50F_seed123"  # <<< FOR CLEANING UP AFTER SL-E
# --- !!! ---

# Fail fast, before any filesystem access, if the tag could widen the deletion scope.
validate_experiment_tag(experiment_tag_to_clean)

# Define the base output path
if 'RUNPODS_PROJECT_BASE_PATH' not in globals():
    RUNPODS_PROJECT_BASE_PATH = "/workspace/MyAdaptiveLearnerProject"
    cleanup_logger.warning(f"RUNPODS_PROJECT_BASE_PATH not found in globals, using fallback: {RUNPODS_PROJECT_BASE_PATH}")

base_output_dir = os.path.join(RUNPODS_PROJECT_BASE_PATH, "outputs")
lora_modules_dir = os.path.join(base_output_dir, "lora_modules")
router_states_dir = os.path.join(base_output_dir, "router_states")
# NOTE(review): defined for completeness, but no step below cleans gamma_metrics — confirm whether it should.
gamma_metrics_base_dir = os.path.join(base_output_dir, "gamma_metrics")
replay_base_dir = os.path.join(base_output_dir, "replay")


cleanup_logger.info(f"--- Starting Cleanup for Experiment Tag Prefix: {experiment_tag_to_clean} ---")
cleanup_logger.info(f"Base output directory: {base_output_dir}")

# 1. Clean LoRA Module Directories (prefix match: one experiment may own several adapters)
if os.path.isdir(lora_modules_dir):
    cleaned_lora_count = delete_dirs_with_prefix(
        lora_modules_dir,
        experiment_tag_to_clean,
        "LoRA module directory",
        error_label="LoRA directory",
    )
    if cleaned_lora_count == 0:
        cleanup_logger.info(f"  No LoRA module directories found starting with '{experiment_tag_to_clean}' in {lora_modules_dir}")
else:
    cleanup_logger.info(f"LoRA modules directory not found: {lora_modules_dir}")

# 2. Clean Router State File
router_state_filename = f"router_state_{experiment_tag_to_clean}.pth"
router_state_filepath = os.path.join(router_states_dir, router_state_filename)
delete_file(router_state_filepath, "router state file")

# 3. Clean Experiment Tag Directory (contains plots, etc.)
experiment_specific_output_dir = os.path.join(base_output_dir, experiment_tag_to_clean)
delete_directory(experiment_specific_output_dir, "experiment-specific output directory")

# 4. Clean Replay Model Directory (if any from this experiment tag)
replay_experiment_dir = os.path.join(replay_base_dir, experiment_tag_to_clean)
delete_directory(replay_experiment_dir, "replay experiment directory")

cleanup_logger.info(f"--- Cleanup for Experiment Tag: {experiment_tag_to_clean} Complete ---")