<a href="https://colab.research.google.com/github/PlushyWushy/Prometheus/blob/main/Fashion_MNIST_final_Prometheus_Variation_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive (Colab-only: relies on the google.colab package).
# NOTE(review): the visible part of the notebook does not read from the mount;
# presumably it is used further down for data/checkpoints — confirm.

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [4]:
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from torch.amp import autocast
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
import inspect # For debugging
import logging
import gc # Garbage Collector interface
import itertools
import traceback
import os

# Silence verbose logs from the compiler: only FATAL records from the
# torch._dynamo and torch._inductor loggers are let through.
logging.getLogger("torch._dynamo").setLevel(logging.FATAL)
logging.getLogger("torch._inductor").setLevel(logging.FATAL)

# PyTorch Geometric is an optional dependency. PYG_AVAILABLE records whether
# the import succeeded so later code can check it before using pyg_nn / PyGLinear.
try:
    import torch_geometric.nn as pyg_nn
    from torch_geometric.data import Data
    from torch_geometric.utils import to_undirected
    from torch_geometric.nn.dense.linear import Linear as PyGLinear
    PYG_AVAILABLE = True
except ImportError:
    # Keep the notebook importable without the package; just warn the user.
    PYG_AVAILABLE = False
    print("PyTorch Geometric not found. GNN MetaAgent will not be available if used.")

# =======================================================================
#  Constants & Configuration
# =======================================================================
# --- Training hyperparameters ---
PRE_EPOCHS = 1
BATCHES_PER_EPOCH = None
BATCH_SIZE = 128
BASE_POST_EPOCHS = 25
LEARNING_RATE = 0.001
MAX_GRAD_NORM = 5.0

# --- Device selection ---
# Use CUDA when available; on GPUs with compute capability >= 8 also switch
# float32 matmuls to the 'high' precision setting (TensorFloat32).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if (DEVICE.type == 'cuda'
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability()[0] >= 8):
    print("Enabling TensorFloat32 matmul precision for supported GPU.")
    torch.set_float32_matmul_precision('high')

# --- Action Space with Skip Connections ---
# Integer ids of the edit types that can be applied to the target CNN.
EDIT_TYPE_ADD_CONV_BLOCK = 0
EDIT_TYPE_RESIZE_LAYER = 1
EDIT_TYPE_ADD_SKIP = 2
EDIT_TYPE_ADD_LINEAR_BLOCK = 3
NUM_TARGET_CNN_EDIT_TYPES = 4
DISCRETE_CH_MULT_ADD = [0.5, 1.0, 2.0]
# Note: 1.0 is deliberately absent, so every factor is a real resize.
DISCRETE_RESIZE_FACTORS = [0.25, 0.5, 0.75, 1.25, 1.5, 1.75]
NUM_STAGES_TARGET_CNN = 3
DROPOUT_RATE = 0.2

# --- MetaAgent & RL Configuration ---
BASE_META_LR = 5e-4
MIN_META_LR = 1e-5
LR_ACC_BASE_THRESHOLD = 0.80
LR_ACC_TARGET_THRESHOLD = 0.93

# --- MetaAgent self-edit action ids (growth and complexity reduction) ---
META_EDIT_NONE = 0
META_EDIT_DEEPEN_GNN = 1
META_EDIT_WIDEN_GNN_HIDDEN = 2
META_EDIT_DEEPEN_MLP_HEAD = 3
META_EDIT_SHRINK_GNN_HIDDEN = 4
META_EDIT_PRUNE_GNN = 5
META_EDIT_PRUNE_MLP_HEAD = 6
NUM_META_SELF_EDIT_TYPES = 7
META_SELF_EDIT_INTERVAL = 5

# --- Factors and min/max constraints for two-way search ---
META_GNN_WIDEN_FACTOR = 1.5
# Symmetric shrinking: the exact inverse of the widen factor.
META_GNN_SHRINK_FACTOR = 1.0 / META_GNN_WIDEN_FACTOR
MIN_GNN_LAYERS = 1
MIN_GNN_HIDDEN_DIM = 16
MIN_MLP_HEAD_SEQUENTIAL_DEPTH = 1
INITIAL_GNN_HIDDEN_DIM = 32
INITIAL_NUM_GNN_LAYERS = 2
EDIT_TYPE_EMBED_DIM = 16
# Parameter count below which the shrink/prune self-edits are masked out.
META_AGENT_PRUNE_THRESHOLD = 15000

# --- Complexity Penalty Configuration ---
COMPLEXITY_PENALTY_THRESHOLD = 20_000_000
COMPLEXITY_PENALTY_ALPHA = 0.2

# --- Graph & State Representation ---
# Operator-type name -> integer id (ids are not in key order; keep as-is).
OP_TYPE_IDS = {
    'conv2d': 1,
    'relu': 2,
    'maxpool2d': 3,
    'batchnorm2d': 4,
    'batchnorm1d': 7,
    'add': 8,
    'input_placeholder': 5,
    'linear': 6,
    'dropout': 9,
    'none': 0,
}
NODE_FEATURE_DIM = 7
NORMALIZATION_WIDTH_DIVISOR = 512.0
NORMALIZATION_IDX_DIVISOR = 50.0
NORMALIZATION_SPATIAL_DIVISOR = 32.0
GLOBAL_SUMMARY_FEATURE_DIM = 3 + 2
MAX_GLOBAL_HISTORY_LEN = 5

# =======================================================================
#  Custom Modules for Dynamic Graph
# =======================================================================
class AddWithProjection(nn.Module):
    """Element-wise sum of a primary tensor and a (projected) skip tensor.

    The skip input is first passed through ``projection`` (``nn.Identity``
    when no module is supplied). If its trailing dimensions (from dim 2 on)
    still differ from the primary tensor's, it is bilinearly resized to match
    before the addition.
    """

    def __init__(self, projection_module=None):
        super().__init__()
        if projection_module is None:
            projection_module = nn.Identity()
        self.projection = projection_module

    def forward(self, x_primary, x_skip):
        skip = self.projection(x_skip)
        target_size = x_primary.shape[2:]
        if skip.shape[2:] != target_size:
            # Spatial mismatch: resample the skip branch onto the primary grid.
            skip = F.interpolate(skip, size=target_size, mode='bilinear', align_corners=False)
        return x_primary + skip

# =======================================================================
#  Net2Net Utilities & Masking Helpers
# =======================================================================
def _valid_resize_indices(old_oc: int):
    """Return the indices into DISCRETE_RESIZE_FACTORS that actually change ``old_oc``.

    A factor is valid when ``max(1, round(old_oc * factor))`` differs from
    ``old_oc``. If no factor qualifies, the index of a 1.0 factor is used when
    the list contains one and ``old_oc`` is positive; otherwise index 0 is
    returned so the result is never empty.
    """
    changing = [
        pos
        for pos, scale in enumerate(DISCRETE_RESIZE_FACTORS)
        if max(1, int(round(old_oc * scale))) != old_oc
    ]
    if changing:
        return changing
    if old_oc > 0 and 1.0 in DISCRETE_RESIZE_FACTORS:
        return [DISCRETE_RESIZE_FACTORS.index(1.0)]
    # Last resort: always offer at least one choice.
    return [0]
def _mask_logits(logits: torch.Tensor, valid_indices: list):
    mask_val = torch.full_like(logits, float('-inf'))
    if valid_indices:
        valid_indices_tensor = torch.tensor(valid_indices, device=logits.device, dtype=torch.long)
        if valid_indices_tensor.numel() > 0:
            valid_indices_tensor = valid_indices_tensor[valid_indices_tensor < logits.shape[-1]]
            if valid_indices_tensor.numel() > 0:
                 mask_val[..., valid_indices_tensor] = 0.0
    return logits + mask_val

def _invalid_meta_indices(agent):
    """List the MetaAgent self-edit ids that must be masked for ``agent``.

    Rules applied (duplicates are removed before returning):
      * fewer than META_AGENT_PRUNE_THRESHOLD parameters -> no shrink/prune edits;
      * GNN layer count at MIN_GNN_LAYERS or below -> no GNN pruning;
      * GNN hidden dim at MIN_GNN_HIDDEN_DIM or below -> no GNN shrinking;
      * every MLP head at minimum depth -> no head pruning;
      * every MLP head at depth >= 8 (hard-coded cap) -> no head deepening.

    Heads missing from ``agent.head_depth_counters`` count as depth 1.
    """
    blocked = []

    total_params = sum(p.numel() for p in agent.parameters())
    if total_params < META_AGENT_PRUNE_THRESHOLD:
        blocked += [META_EDIT_SHRINK_GNN_HIDDEN, META_EDIT_PRUNE_GNN, META_EDIT_PRUNE_MLP_HEAD]

    if agent.current_num_gnn_layers <= MIN_GNN_LAYERS:
        blocked.append(META_EDIT_PRUNE_GNN)
    if agent.current_gnn_hidden_dim <= MIN_GNN_HIDDEN_DIM:
        blocked.append(META_EDIT_SHRINK_GNN_HIDDEN)

    head_names = agent.get_mlp_head_names()

    def _depth(head):
        return agent.head_depth_counters.get(head, 1)

    if all(_depth(head) <= MIN_MLP_HEAD_SEQUENTIAL_DEPTH for head in head_names):
        blocked.append(META_EDIT_PRUNE_MLP_HEAD)
    if all(_depth(head) >= 8 for head in head_names):
        blocked.append(META_EDIT_DEEPEN_MLP_HEAD)

    return list(set(blocked))

def _resize_linear_layer(old_linear, new_in_features, new_out_features, device='cpu'):
    """Build a linear layer of a new shape that reuses ``old_linear``'s weights.

    Works for ``nn.Linear`` and, when PyTorch Geometric is available, for its
    ``Linear`` (which exposes ``in_channels``/``out_channels``). The same layer
    object is returned untouched when the shape does not change.

    Weight transfer:
      * the overlapping ``[out, in]`` block is copied as-is;
      * extra output rows replicate old rows cyclically (overlapping columns only);
      * extra input columns replicate old columns cyclically, divided by
        ``sqrt(new_in / old_in)`` (overlapping rows only);
      * bias entries are copied, extra ones replicated cyclically.
    Entries that are both a new row and a new column keep the fresh layer's
    default initialisation.
    """
    uses_pyg = PYG_AVAILABLE and isinstance(old_linear, PyGLinear)
    if uses_pyg:
        old_in, old_out = old_linear.in_channels, old_linear.out_channels
        layer_cls = PyGLinear
    else:
        old_in, old_out = old_linear.in_features, old_linear.out_features
        layer_cls = nn.Linear

    if old_in == new_in_features and old_out == new_out_features:
        return old_linear

    resized = layer_cls(new_in_features, new_out_features, bias=(old_linear.bias is not None)).to(device)

    keep_in = min(old_in, new_in_features)
    keep_out = min(old_out, new_out_features)
    grows_out = new_out_features > old_out and old_out > 0
    grows_in = new_in_features > old_in and old_in > 0

    with torch.no_grad():
        src_w = old_linear.weight.data
        dst_w = resized.weight.data
        dst_w[:keep_out, :keep_in] = src_w[:keep_out, :keep_in].clone()

        if grows_out:
            # Row i of the new layer (i >= old_out) mirrors old row i % old_out.
            row_src = [i % old_out for i in range(old_out, new_out_features)]
            dst_w[old_out:, :keep_in] = src_w[row_src, :keep_in]

        if grows_in:
            # Column j of the new layer (j >= old_in) mirrors old column
            # j % old_in, scaled down by sqrt(new_in / old_in).
            col_src = [j % old_in for j in range(old_in, new_in_features)]
            shrink = math.sqrt(new_in_features / old_in)
            dst_w[:keep_out, old_in:] = src_w[:keep_out, col_src] / shrink

        if old_linear.bias is not None and resized.bias is not None:
            src_b = old_linear.bias.data
            dst_b = resized.bias.data
            dst_b[:keep_out] = src_b[:keep_out].clone()
            if grows_out:
                dst_b[old_out:] = src_b[row_src]

    return resized

def net2wider_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` whose output-channel count is scaled by ``factor``.

    The new count is ``max(1, round(out_channels * factor))``; if that equals
    the current count, ``conv`` itself is returned. Otherwise a new ``Conv2d``
    is created on ``device`` with the same kernel size, stride, padding,
    dilation, groups, padding mode and bias setting as ``conv``.

    Weight transfer: the first ``min(old, new)`` filters (and bias entries) are
    copied. When widening, filter ``i >= old`` replicates filter ``i % old``
    divided by ``sqrt(new / old)``; the matching bias entry is replicated
    unscaled. Anything not written stays zero.

    Args:
        conv: source convolution.
        factor: multiplicative change of ``out_channels``.
        device: device on which the new module is placed.

    Returns:
        ``conv`` (unchanged) or the newly built ``nn.Conv2d``.
    """
    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    if new_oc == old_oc:
        return conv

    has_bias = conv.bias is not None
    # dilation and padding_mode are forwarded so the resized layer keeps the
    # same receptive field / output size as the layer it replaces.
    widened = nn.Conv2d(
        conv.in_channels, new_oc, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=has_bias, padding_mode=conv.padding_mode,
    ).to(device)

    with torch.no_grad():
        widened.weight.data.zero_()
        if has_bias:
            widened.bias.data.zero_()

        if old_oc > 0:
            kept = min(old_oc, new_oc)
            widened.weight.data[:kept] = conv.weight.data[:kept]
            if has_bias:
                widened.bias.data[:kept] = conv.bias.data[:kept]

            if new_oc > old_oc:
                # New filter i reuses old filter i % old_oc.
                replicated = [i % old_oc for i in range(old_oc, new_oc)]
                widen_ratio = float(new_oc) / old_oc
                widened.weight.data[old_oc:] = conv.weight.data[replicated] / math.sqrt(widen_ratio)
                if has_bias:
                    widened.bias.data[old_oc:] = conv.bias.data[replicated]
    return widened
def net2thinner_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` keeping only its leading output channels.

    The new count is ``max(1, round(out_channels * factor))``; if that equals
    the current count, ``conv`` itself is returned. Otherwise a new ``Conv2d``
    is created on ``device`` with the same kernel size, stride, padding,
    dilation, groups, padding mode and bias setting, and the first
    ``min(old, new)`` filters and bias entries are copied into it.

    The copy is done in place, so the new parameters stay on ``device`` and
    keep their declared shape even if a ``factor`` above 1 is passed by
    mistake (extra filters then keep their default initialisation).

    Args:
        conv: source convolution.
        factor: multiplicative change of ``out_channels`` (normally < 1).
        device: device on which the new module is placed.

    Returns:
        ``conv`` (unchanged) or the newly built ``nn.Conv2d``.
    """
    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    if new_oc == old_oc:
        return conv

    thinned = nn.Conv2d(
        conv.in_channels, new_oc, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=(conv.bias is not None), padding_mode=conv.padding_mode,
    ).to(device)

    kept = min(old_oc, new_oc)
    with torch.no_grad():
        if kept > 0:
            # copy_ (rather than rebinding .data) preserves device and shape.
            thinned.weight.data[:kept].copy_(conv.weight.data[:kept])
            if thinned.bias is not None:
                thinned.bias.data[:kept].copy_(conv.bias.data[:kept])
    return thinned
def resize_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Scale ``conv``'s output channels by ``factor``, widening or thinning as needed.

    Returns ``conv`` itself when ``factor`` is within 1e-6 of 1.0 or when the
    rounded target ``max(1, round(out_channels * factor))`` equals the current
    count. Otherwise dispatches to ``net2wider_conv_output`` (target larger) or
    ``net2thinner_conv_output`` (target smaller), passing the exact ratio
    ``target / out_channels`` (or ``factor`` if ``out_channels`` is 0).
    """
    if abs(factor - 1.0) < 1e-6:
        return conv

    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    # new_oc is at least 1, so equality here implies old_oc >= 1: nothing to do.
    if new_oc == old_oc:
        return conv

    effective_factor = float(new_oc) / old_oc if old_oc > 0 else factor
    if new_oc > old_oc:
        return net2wider_conv_output(conv, effective_factor, device)
    return net2thinner_conv_output(conv, effective_factor, device)
def resize_linear_output(linear: nn.Linear, factor: float, device='cpu') -> nn.Linear:
    """Scale ``linear``'s output features by ``factor``.

    Returns ``linear`` itself when ``factor`` is within 1e-6 of 1.0 or when
    ``max(1, round(out_features * factor))`` equals the current width;
    otherwise delegates to ``_resize_linear_layer`` with unchanged inputs.
    """
    factor_is_identity = abs(factor - 1.0) < 1e-6
    if not factor_is_identity:
        target_out = max(1, int(round(linear.out_features * factor)))
        if target_out != linear.out_features:
            return _resize_linear_layer(linear, linear.in_features, target_out, device)
    return linear

def adapt_conv_input_channels(conv: nn.Conv2d, new_in_channels: int, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` that accepts ``new_in_channels`` input channels.

    ``conv`` itself is returned when the channel count already matches.
    Otherwise a new ``Conv2d`` is created on ``device`` with the same output
    channels, kernel size, stride, padding, dilation, groups, padding mode and
    bias setting.

    Weight transfer operates on the weight's second axis, which holds
    ``in_channels // groups`` entries per filter (so grouped convolutions are
    handled with the same code path as ``groups == 1``):
      * the overlapping input slices are copied unchanged;
      * additional input slices replicate old slices cyclically and are
        divided by ``max(1, new_per_group / old_per_group)``.
    The bias, if any, is copied unchanged.

    Args:
        conv: source convolution.
        new_in_channels: desired input channel count (clamped to >= 1).
        device: device on which the new module is placed.

    Returns:
        ``conv``, the new ``nn.Conv2d``, or ``None`` (after printing a message)
        when ``conv`` is grouped and ``new_in_channels`` is not divisible by
        ``conv.groups``.
    """
    if conv.in_channels == new_in_channels:
        return conv

    new_in_channels = max(1, new_in_channels)
    old_ic = conv.in_channels
    groups = conv.groups

    if groups > 1 and new_in_channels % groups != 0:
        print(f"  INVALID ADAPTATION: Cannot adapt grouped Conv2d to new_in_channels={new_in_channels} with groups={conv.groups}. The edit is invalid.")
        return None

    adapted = nn.Conv2d(
        new_in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=groups, bias=(conv.bias is not None), padding_mode=conv.padding_mode,
    ).to(device)

    with torch.no_grad():
        if old_ic == 0:
            nn.init.kaiming_normal_(adapted.weight, mode='fan_in', nonlinearity='relu')
        else:
            old_per_group = old_ic // groups
            new_per_group = new_in_channels // groups
            shared = min(old_per_group, new_per_group)

            new_weight = torch.zeros_like(adapted.weight.data)
            old_weight = conv.weight.data.to(new_weight.device)
            new_weight[:, :shared] = old_weight[:, :shared]
            if new_per_group > old_per_group:
                # Extra input slice j reuses old slice j % old_per_group.
                replicated = [j % old_per_group for j in range(old_per_group, new_per_group)]
                scale = max(1.0, new_per_group / old_per_group)
                new_weight[:, old_per_group:] = old_weight[:, replicated] / scale
            adapted.weight.data.copy_(new_weight)

        if adapted.bias is not None:
            adapted.bias.data.copy_(conv.bias.data)
    return adapted
def adapt_linear_input_features(linear: nn.Linear, new_in_features: int, device='cpu') -> nn.Linear:
    """Return ``linear`` adapted to ``new_in_features`` inputs.

    The same object is returned when the input width already matches;
    otherwise ``_resize_linear_layer`` builds the replacement with the output
    width left unchanged.
    """
    if linear.in_features != new_in_features:
        return _resize_linear_layer(linear, new_in_features, linear.out_features, device)
    return linear

def adapt_batchnorm_features(bn: nn.Module, new_num_features: int, device='cpu') -> nn.Module:
    """Return a batch-norm layer of the same class with ``new_num_features`` channels.

    ``bn`` itself is returned when the feature count already matches. The
    replacement is built with ``bn``'s ``eps``, ``momentum``, ``affine`` and
    ``track_running_stats`` settings so that it normalises the same way as the
    layer it replaces.

    For each of ``weight``, ``bias``, ``running_mean`` and ``running_var`` that
    exists on both layers, the first ``min(old, new)`` entries are copied and,
    when growing, every additional entry is filled with the old layer's last
    entry. ``num_batches_tracked`` is carried over when present.

    Args:
        bn: source layer (expected to expose the ``nn.BatchNorm*`` attributes).
        new_num_features: channel count of the returned layer.
        device: device on which the new module is placed.
    """
    if bn.num_features == new_num_features:
        return bn

    old_num_features = bn.num_features
    new_bn = type(bn)(
        new_num_features,
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    ).to(device)

    shared = min(old_num_features, new_num_features)
    growing = new_num_features > old_num_features

    with torch.no_grad():
        for attr in ('weight', 'bias', 'running_mean', 'running_var'):
            source = getattr(bn, attr, None)
            target = getattr(new_bn, attr, None)
            if source is None or target is None:
                continue
            target.data[:shared] = source.data[:shared]
            if growing:
                # Pad new channels with the statistics of the last old channel.
                target.data[old_num_features:] = source.data[-1]

        old_count = getattr(bn, 'num_batches_tracked', None)
        new_count = getattr(new_bn, 'num_batches_tracked', None)
        if old_count is not None and new_count is not None:
            new_count.copy_(old_count)
    return new_bn
def net2deeper_linear_insert_identity(head_module_owner: nn.Module, head_name: str, device='cpu'):
    """Deepen the head stored at ``head_module_owner.<head_name>`` with an identity Linear.

    * A bare ``nn.Linear`` becomes ``Sequential(original, identity)`` where the
      identity layer has the original's output width.
    * An ``nn.Sequential`` whose last module is ``nn.Linear`` gets an identity
      layer (sized to that last layer's input width) inserted just before it.

    The identity layer has an identity weight matrix and zero bias. The new
    head is moved to ``device`` and assigned back onto the owner.

    Returns:
        True if the head was replaced, False (after printing the reason) when
        the attribute is missing, has an unsupported type, or has a
        non-positive width.
    """
    if not hasattr(head_module_owner, head_name):
        print(f"Err: Attr {head_name} not found for deepening")
        return False
    head = getattr(head_module_owner, head_name)

    def _identity_linear(width):
        # Square Linear initialised to y = x.
        layer = nn.Linear(width, width, bias=True).to(device)
        with torch.no_grad():
            layer.weight.data.copy_(torch.eye(width, device=device))
            layer.bias.data.fill_(0)
        return layer

    if isinstance(head, nn.Linear):
        width = head.out_features
        if width <= 0:
            print(f"Cannot insert identity for dim {width} in Linear head {head_name}")
            return False
        deepened = nn.Sequential(head, _identity_linear(width)).to(device)
        setattr(head_module_owner, head_name, deepened)
        print(f"  Deepened MLP head '{head_name}' (Linear -> Sequential)")
        return True

    if isinstance(head, nn.Sequential):
        if not head or not isinstance(head[-1], nn.Linear):
            print(f"Cannot deepen Seq head '{head_name}', last not Linear.")
            return False
        final_layer = head[-1]
        width = final_layer.in_features
        if width <= 0:
            print(f"Cannot insert identity for dim {width} in Seq head {head_name}")
            return False
        layers = list(head[:-1]) + [_identity_linear(width), final_layer]
        setattr(head_module_owner, head_name, nn.Sequential(*layers).to(device))
        print(f"  Deepened Seq head '{head_name}'.")
        return True

    print(f"Err: Head '{head_name}' type {type(head)} not Linear/Seq for deepening.")
    return False

def net2thinner_linear_remove_layer(head_module_owner: nn.Module, head_name: str, device='cpu'):
    """Remove the second-to-last module from the Sequential head ``head_name``.

    Only ``nn.Sequential`` heads with more than two modules are pruned. The
    surviving modules are re-wrapped (or, if a single module remains, assigned
    directly), moved to ``device`` and stored back on the owner. No check is
    made that the remaining layers have compatible dimensions.

    Returns:
        True if the head was pruned, False (after printing the reason) when the
        attribute is missing or is not a Sequential with more than 2 modules.
    """
    if not hasattr(head_module_owner, head_name):
        print(f"Err: Attr {head_name} not found for pruning")
        return False

    head = getattr(head_module_owner, head_name)
    prunable = isinstance(head, nn.Sequential) and len(head) > 2
    if not prunable:
        print(f"Cannot prune head '{head_name}': not a Sequential module with more than 2 layers.")
        return False

    dropped_pos = len(head) - 2
    survivors = [layer for pos, layer in enumerate(head) if pos != dropped_pos]

    if len(survivors) == 1:
        replacement = survivors[0]
        message = f"  Pruned MLP head '{head_name}' (Sequential -> Linear)"
    else:
        replacement = nn.Sequential(*survivors)
        message = f"  Pruned MLP head '{head_name}'."

    setattr(head_module_owner, head_name, replacement.to(device))
    print(message)
    return True

def _get_gcn_conv_linear_submodule(gcn_layer):
    if hasattr(gcn_layer, 'lin') and (isinstance(gcn_layer.lin, nn.Linear) or (PYG_AVAILABLE and isinstance(gcn_layer.lin, PyGLinear))):
        return gcn_layer.lin
    return None
if PYG_AVAILABLE:
    def resize_gcn_conv_hidden(gcn_layer: pyg_nn.GCNConv, new_hidden_dim: int, prev_layer_out_dim: int, device='cpu'):
        """Rebuild ``gcn_layer`` with ``new_hidden_dim`` output channels.

        Returns ``(layer, changed)``. When the width is unchanged the original
        layer is returned with ``changed=False``. Otherwise a new ``GCNConv``
        taking ``prev_layer_out_dim`` inputs is created on ``device`` with the
        same ``bias``/``improved``/``add_self_loops``/``normalize`` settings;
        its inner linear layer is produced by ``_resize_linear_layer`` and its
        bias is copied, with extra entries replicated cyclically.
        """
        old_hidden_dim = gcn_layer.out_channels
        if new_hidden_dim == old_hidden_dim:
            return gcn_layer, False

        resized = pyg_nn.GCNConv(
            prev_layer_out_dim, new_hidden_dim,
            bias=(gcn_layer.bias is not None),
            improved=gcn_layer.improved,
            add_self_loops=gcn_layer.add_self_loops,
            normalize=gcn_layer.normalize,
        ).to(device)

        with torch.no_grad():
            source_lin = _get_gcn_conv_linear_submodule(gcn_layer)
            if source_lin is not None:
                resized.lin = _resize_linear_layer(source_lin, prev_layer_out_dim, new_hidden_dim, device)

            if gcn_layer.bias is not None and resized.bias is not None:
                kept = min(old_hidden_dim, new_hidden_dim)
                resized.bias.data[:kept] = gcn_layer.bias.data[:kept].clone()
                if new_hidden_dim > old_hidden_dim and old_hidden_dim > 0:
                    # Entry i >= old width mirrors entry i % old width.
                    replicated = [i % old_hidden_dim for i in range(old_hidden_dim, new_hidden_dim)]
                    resized.bias.data[old_hidden_dim:] = gcn_layer.bias.data[replicated]
        return resized, True

    def adapt_gcn_conv_input_dim(gcn_layer: pyg_nn.GCNConv, new_input_dim: int, device='cpu'):
        """Rebuild ``gcn_layer`` so that it accepts ``new_input_dim`` input channels.

        Returns ``(layer, changed)``. When the input width is unchanged the
        original layer is returned with ``changed=False``. Otherwise a new
        ``GCNConv`` with the same output width and settings is created on
        ``device``; its inner linear layer comes from ``_resize_linear_layer``
        and the bias is copied verbatim.
        """
        if new_input_dim == gcn_layer.in_channels:
            return gcn_layer, False

        out_dim = gcn_layer.out_channels
        adapted = pyg_nn.GCNConv(
            new_input_dim, out_dim,
            bias=(gcn_layer.bias is not None),
            improved=gcn_layer.improved,
            add_self_loops=gcn_layer.add_self_loops,
            normalize=gcn_layer.normalize,
        ).to(device)

        with torch.no_grad():
            source_lin = _get_gcn_conv_linear_submodule(gcn_layer)
            if source_lin is not None:
                adapted.lin = _resize_linear_layer(source_lin, new_input_dim, gcn_layer.out_channels, device)

            if gcn_layer.bias is not None and adapted.bias is not None:
                adapted.bias.data.copy_(gcn_layer.bias.data)
        return adapted, True

    def create_identity_gcn_layer(dim: int, device='cpu', **gcn_kwargs):
        """Create a ``dim -> dim`` ``GCNConv`` whose inner linear map is the identity.

        ``gcn_kwargs`` may override ``bias`` (default True), ``normalize``
        (True), ``add_self_loops`` (True) and ``improved`` (False). The inner
        linear weight is set to the identity matrix when it is square (Kaiming
        uniform otherwise) and all biases are zeroed.

        NOTE(review): only the linear transform is an identity here; whether
        the layer as a whole leaves node features unchanged depends on
        GCNConv's propagation step — confirm before relying on it.
        """
        layer = pyg_nn.GCNConv(
            dim, dim,
            bias=gcn_kwargs.get('bias', True),
            normalize=gcn_kwargs.get('normalize', True),
            add_self_loops=gcn_kwargs.get('add_self_loops', True),
            improved=gcn_kwargs.get('improved', False),
        ).to(device)

        with torch.no_grad():
            inner = _get_gcn_conv_linear_submodule(layer)
            if inner is not None:
                n_rows, n_cols = inner.weight.shape[0], inner.weight.shape[1]
                if n_rows == n_cols:
                    inner.weight.data.copy_(torch.eye(dim, device=device))
                else:
                    nn.init.kaiming_uniform_(inner.weight, a=math.sqrt(5))
                if getattr(inner, 'bias', None) is not None:
                    inner.bias.data.fill_(0.0)
            if layer.bias is not None:
                layer.bias.data.fill_(0.0)
        return layer

# =======================================================================
#  Dynamic Models (TargetCNN)
# =======================================================================
class DynamicStageModule(nn.Module):
    """One editable stage of the target network: an ordered list of ops plus metadata.

    ``ops`` holds the modules; ``op_descriptions`` holds one dict per op with
    the keys read by this class: ``'type'``, ``'out_channels'``,
    ``'out_spatial'`` and ``'input_indices'``. Input indices are local to the
    stage: ``-1`` denotes the stage input, ``k >= 0`` the output of op ``k``.
    """

    def __init__(self, stage_idx_dyn, initial_in_channels_dyn, initial_spatial_size_dyn, max_ops_dyn=None):
        super().__init__()
        self.stage_idx = stage_idx_dyn
        self.initial_in_channels = initial_in_channels_dyn
        self.initial_spatial_size = initial_spatial_size_dyn
        self.max_ops = max_ops_dyn  # None means "no limit"
        self.ops = nn.ModuleList()
        self.op_descriptions = []
        self.dropout = nn.Dropout(p=DROPOUT_RATE)

    def add_op(self, op_module_dyn, op_description_dyn, insert_at=None):
        """Append (or insert at ``insert_at``) an op and its description.

        Returns False without modifying anything when ``max_ops`` is set and
        already reached, True otherwise. Existing ``input_indices`` are not
        renumbered on insertion.
        """
        at_capacity = self.max_ops is not None and len(self.ops) >= self.max_ops
        if at_capacity:
            return False
        if insert_at is not None:
            self.ops.insert(insert_at, op_module_dyn)
            self.op_descriptions.insert(insert_at, op_description_dyn)
        else:
            self.ops.append(op_module_dyn)
            self.op_descriptions.append(op_description_dyn)
        return True

    def get_op_output_properties(self, op_idx):
        """Return ``(out_channels, out_spatial)`` of op ``op_idx`` (``-1`` = stage input).

        Missing description keys default to 0. Raises IndexError for any other
        out-of-range index.
        """
        if op_idx == -1:
            return self.initial_in_channels, self.initial_spatial_size
        if not (0 <= op_idx < len(self.op_descriptions)):
            raise IndexError(f"Operator index {op_idx} out of range for stage {self.stage_idx} with {len(self.ops)} ops.")
        description = self.op_descriptions[op_idx]
        return description.get('out_channels', 0), description.get('out_spatial', 0)

    def get_current_out_properties(self):
        """Return ``(out_channels, out_spatial)`` of the stage's last op (or of its input if empty)."""
        if self.op_descriptions:
            return self.get_op_output_properties(len(self.ops) - 1)
        return self.initial_in_channels, self.initial_spatial_size

    def forward(self, x_dyn):
        """Run the ops in order, wiring each one to the outputs named in its ``input_indices``.

        An index outside ``[-1, i)`` for op ``i`` falls back to the previous
        op's output (the stage input for the first op). Ops with several
        inputs receive them as positional arguments. Returns the last op's
        output, or ``x_dyn`` when the stage has no ops.
        """
        history = {-1: x_dyn}
        first_conv_pending = (self.stage_idx == 0)
        # In-stage dropout after conv->batchnorm pairs is switched off; the
        # branch is kept so it can be re-enabled by flipping this flag.
        stage_dropout_enabled = False

        for position, description in enumerate(self.op_descriptions):
            op = self.ops[position]
            previous = position - 1  # equals -1 (stage input) for the first op

            sources = description.get('input_indices', [-1])
            if not isinstance(sources, list):
                sources = [sources]

            op_inputs = []
            for source in sources:
                if not (-1 <= source < position):
                    source = previous
                op_inputs.append(history.get(source, history.get(previous, x_dyn)))

            call_args = op_inputs if op_inputs else [x_dyn]
            try:
                result = op(*call_args)
            except Exception as forward_error:
                print(f"CRITICAL Error in DynamicStageModule op {position}, type {description.get('type','Unknown')}, stage {self.stage_idx}: {forward_error}")
                raise

            if self.training and stage_dropout_enabled:
                follows_conv_bn = (
                    position > 1
                    and 'batchnorm' in self.op_descriptions[position - 1]['type']
                    and 'conv2d' in self.op_descriptions[position - 2]['type']
                )
                if follows_conv_bn:
                    if not first_conv_pending:
                        result = self.dropout(result)
                    first_conv_pending = False

            history[position] = result

        if self.ops:
            return history[len(self.ops) - 1]
        return x_dyn

class TargetCNN(nn.Module):
    """Staged CNN classifier whose stages are editable ``DynamicStageModule`` containers.

    Every stage is seeded with Conv3x3 -> BatchNorm2d -> ReLU, and every stage
    except the last one additionally ends with a 2x2 max-pool. The stages are
    followed by global average pooling, flattening and a classifier stored as
    a ``ModuleList`` (initially a single ``Linear``). Plain dict descriptions
    are kept next to the modules (``input_placeholder_desc`` and
    ``classifier_op_descriptions``).

    Args:
        num_classes_cnn: Number of output logits of the classifier.
        num_stages_cnn: Number of stages to build.
        init_model_ch_cnn: Width of the first stage; later stages use 2x and
            4x this value, and any stage beyond the third reuses the 4x width.
        input_spatial_size: Side length recorded for the (single-channel) input.
    """
    # --- MODIFIED FOR FASHION-MNIST ---
    def __init__(self, num_classes_cnn=10, num_stages_cnn=NUM_STAGES_TARGET_CNN, init_model_ch_cnn=64, input_spatial_size=28):
        super().__init__()
        self.num_stages = num_stages_cnn
        self.stages = nn.ModuleList()
        in_ch = 1  # Fashion-MNIST is grayscale (1 channel)
        spatial = input_spatial_size
        self.input_placeholder_desc = {'type': 'input_placeholder', 'out_channels': in_ch, 'out_spatial': input_spatial_size, 'input_indices': []}
        width_plan = [init_model_ch_cnn, init_model_ch_cnn * 2, init_model_ch_cnn * 4]
        final_stage_no = num_stages_cnn - 1

        for stage_no in range(num_stages_cnn):
            stage = DynamicStageModule(stage_no, in_ch, spatial, max_ops_dyn=None)
            # Stages past the end of the plan keep the widest planned width.
            width = width_plan[min(stage_no, len(width_plan) - 1)]

            # (description type, module, local input indices) for the seed block.
            seed_block = (
                ('conv2d', nn.Conv2d(in_ch, width, 3, 1, 1, bias=False).to(DEVICE), [-1]),
                ('batchnorm2d', nn.BatchNorm2d(width).to(DEVICE), [0]),
                ('relu', nn.ReLU(inplace=False).to(DEVICE), [1]),
            )
            for op_type, module, sources in seed_block:
                stage.add_op(module, {'type': op_type, 'out_channels': width, 'out_spatial': spatial, 'input_indices': sources})

            block_out_ch, _ = stage.get_current_out_properties()
            if stage_no < final_stage_no:
                # Downsample between stages; the last stage keeps its resolution.
                pool = nn.MaxPool2d(2, 2).to(DEVICE)
                spatial = spatial // 2
                stage.add_op(pool, {'type': 'maxpool2d', 'out_channels': block_out_ch, 'out_spatial': spatial, 'input_indices': [2]})

            self.stages.append(stage)
            in_ch, spatial = stage.get_current_out_properties()

        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        fc_in, _ = self.get_last_stage_out_properties()
        head = nn.Linear(max(1, fc_in), num_classes_cnn).to(DEVICE)
        self.classifier = nn.ModuleList([head])
        self.classifier_op_descriptions = [{'type': 'linear', 'out_features': num_classes_cnn, 'input_indices': [-1]}]

    def forward(self, x_cnn_fwd):
        """Run all stages, pool globally, then apply the classifier ops in order."""
        features = x_cnn_fwd
        for stage in self.stages:
            features = stage(features)
        out = self.flatten(self.adaptive_pool(features))
        for layer in self.classifier:
            out = layer(out)
        return out

    def get_last_stage_out_properties(self):
        """Return the last stage's current output properties, or the input placeholder's if there are no stages."""
        if self.stages:
            return self.stages[-1].get_current_out_properties()
        placeholder = self.input_placeholder_desc
        return placeholder['out_channels'], placeholder['out_spatial']

    def get_classifier_in_features(self):
        """Return ``in_features`` of the first classifier op (0 for an empty classifier)."""
        return self.classifier[0].in_features if self.classifier else 0

    def get_classifier_out_features(self):
        """Return ``out_features`` of the last classifier op (0 for an empty classifier)."""
        return self.classifier[-1].out_features if self.classifier else 0

# =======================================================================
#  GNN MetaAgent
# =======================================================================
if PYG_AVAILABLE:
    class MetaAgentGNN(nn.Module):
        """Graph-network policy/value agent that can also restructure itself.

        A stack of ``GCNConv`` layers embeds the nodes of an architecture graph.
        The mean-pooled graph embedding is concatenated with a flattened
        global-state history and fed to linear heads (edit type, meta
        self-edit type, value). Heads listed in ``policy_head_names`` are
        conditioned additionally on an embedding of the chosen edit type, and
        three scorer heads take per-node GNN outputs as input.

        Args:
            node_feature_dim_agent: Width of each node feature vector.
            initial_gnn_hidden_dim_agent: Output width of every GCN layer.
            initial_num_gnn_layers_agent: Number of GCN layers built initially.
            global_history_dim_flat_agent: Width of the flattened global-state
                history that is concatenated to the graph embedding.
            device_to_use_agent: Device on which sub-modules are created.
            **kwargs_agent: Action-space sizes. ``num_edit_types`` and
                ``num_stages_target`` are required (``KeyError`` otherwise).
        """

        def __init__(self, node_feature_dim_agent=NODE_FEATURE_DIM, initial_gnn_hidden_dim_agent=INITIAL_GNN_HIDDEN_DIM, initial_num_gnn_layers_agent=INITIAL_NUM_GNN_LAYERS,
                     global_history_dim_flat_agent=GLOBAL_SUMMARY_FEATURE_DIM * MAX_GLOBAL_HISTORY_LEN, device_to_use_agent=DEVICE, **kwargs_agent):
            super().__init__()
            self.device = device_to_use_agent
            self.node_feature_dim = node_feature_dim_agent
            self.current_gnn_hidden_dim = initial_gnn_hidden_dim_agent
            self.current_num_gnn_layers = initial_num_gnn_layers_agent
            # Kept so the heads are sized from the constructor argument instead of
            # silently falling back to the module-level constants.
            self.global_history_dim_flat = global_history_dim_flat_agent
            self.action_space_sizes = dict(kwargs_agent)
            self.head_depth_counters = {}
            self.edit_type_embedding = nn.Embedding(self.action_space_sizes['num_edit_types'], EDIT_TYPE_EMBED_DIM).to(self.device)
            self.policy_head_names = ['head_target_loc_stage', 'head_target_conv_ch_mult', 'head_target_resize_factor']
            # -1 forces the first _update_mlp_heads() call to create every head.
            self.heads_input_dim_global_current = -1
            self.gnn_layers = nn.ModuleList()
            self._build_gnn_layers()
            self._update_mlp_heads()

        def _get_current_gnn_output_dim(self):
            """Width of the node embeddings: raw feature width when there are no GCN layers."""
            if self.current_num_gnn_layers == 0:
                return self.node_feature_dim
            return self.current_gnn_hidden_dim

        def _adapt_head_input(self, head, new_input_dim):
            """Return ``head`` with its (first) linear layer resized to ``new_input_dim`` inputs.

            A bare linear head is replaced by the resized layer; for an
            ``nn.Sequential`` only element 0 is swapped in place. Heads whose
            input width already matches, or whose first layer is not linear,
            are returned untouched.
            """
            linear_types = (nn.Linear, PyGLinear)
            first_layer = head[0] if isinstance(head, nn.Sequential) else head
            if not isinstance(first_layer, linear_types):
                return head

            is_pyg = isinstance(first_layer, PyGLinear)
            old_in = first_layer.in_channels if is_pyg else first_layer.in_features
            old_out = first_layer.out_channels if is_pyg else first_layer.out_features
            if old_in == new_input_dim:
                return head

            resized = _resize_linear_layer(first_layer, new_input_dim, old_out, self.device)
            if first_layer is head:
                return resized
            head[0] = resized
            return head

        def _update_mlp_heads(self):
            """Create the heads on first use, or resize their inputs after the GNN output width changed."""
            gnn_output_dim = self._get_current_gnn_output_dim()
            new_base_input_dim = gnn_output_dim + self.global_history_dim_flat

            if new_base_input_dim == self.heads_input_dim_global_current:
                return

            conditional_input_dim = new_base_input_dim + EDIT_TYPE_EMBED_DIM
            head_configs = {
                'head_target_edit_type': (new_base_input_dim, self.action_space_sizes['num_edit_types']),
                'head_meta_self_edit_type': (new_base_input_dim, NUM_META_SELF_EDIT_TYPES),
                'head_value': (new_base_input_dim, 1),
                'head_target_loc_stage': (conditional_input_dim, self.action_space_sizes['num_stages_target']),
                'head_target_conv_ch_mult': (conditional_input_dim, len(DISCRETE_CH_MULT_ADD)),
                'head_target_resize_factor': (conditional_input_dim, len(DISCRETE_RESIZE_FACTORS)),
                # Scorer heads consume per-node embeddings, not the pooled state.
                'head_resize_op_selector_scorer': (gnn_output_dim, 1),
                'head_skip_source_scorer': (gnn_output_dim, 1),
                'head_skip_destination_scorer': (gnn_output_dim, 1),
            }

            for name, (in_dim, out_dim) in head_configs.items():
                if hasattr(self, name):
                    setattr(self, name, self._adapt_head_input(getattr(self, name), in_dim))
                else:
                    setattr(self, name, nn.Linear(in_dim, out_dim).to(self.device))
                    if name in self.policy_head_names:
                        self.head_depth_counters[name] = 1

            self.heads_input_dim_global_current = new_base_input_dim

        def _build_gnn_layers(self):
            """Append ``current_num_gnn_layers`` GCN layers (first one reads raw node features)."""
            current_in_dim = self.node_feature_dim
            for _ in range(self.current_num_gnn_layers):
                out_dim = self.current_gnn_hidden_dim
                self.gnn_layers.append(pyg_nn.GCNConv(current_in_dim, out_dim, bias=True, normalize=True, add_self_loops=True).to(self.device))
                current_in_dim = out_dim

        def get_mlp_head_names(self):
            """Names of the heads listed in ``policy_head_names``."""
            return self.policy_head_names

        def _process_graph_and_state(self, graph_data, global_states_history_flat):
            """Embed the graph and join it with the global history.

            Returns:
                ``(combined_features, node_embeddings)`` where
                ``combined_features`` is the mean-pooled graph embedding
                concatenated with the history along dim 1.
            """
            node_features, edge_index = graph_data.x, graph_data.edge_index
            embeddings = node_features
            if self.current_num_gnn_layers > 0 and graph_data.num_nodes > 0:
                for gnn_layer in self.gnn_layers:
                    embeddings = F.relu(gnn_layer(embeddings, edge_index))
            elif graph_data.num_nodes == 0:
                embeddings = torch.empty(0, self._get_current_gnn_output_dim(), device=self.device)

            batch_vector = graph_data.batch
            if batch_vector is None and embeddings.numel() > 0:
                # No batch assignment given: treat all nodes as one graph.
                batch_vector = torch.zeros(embeddings.size(0), dtype=torch.long, device=self.device)

            if graph_data.num_nodes > 0:
                graph_embedding = pyg_nn.global_mean_pool(embeddings, batch_vector)
            else:
                graph_embedding = torch.zeros(1, self._get_current_gnn_output_dim(), device=self.device)

            if global_states_history_flat.ndim == 1: global_states_history_flat = global_states_history_flat.unsqueeze(0)
            if graph_embedding.ndim == 1: graph_embedding = graph_embedding.unsqueeze(0)
            combined_features = torch.cat((graph_embedding, global_states_history_flat), dim=1)
            return combined_features, embeddings

        def forward(self, graph_data, global_states_history_flat):
            """Return edit-type logits, meta-self-edit logits, value, node embeddings and the combined state."""
            combined_features, node_embeddings = self._process_graph_and_state(graph_data, global_states_history_flat)
            l_te = self.head_target_edit_type(combined_features)
            l_mse = self.head_meta_self_edit_type(combined_features)
            value_pred = self.head_value(combined_features)
            return l_te, l_mse, value_pred, node_embeddings, combined_features

        def get_conditional_logits(self, base_state_embedding, chosen_edit_type_tensor):
            """Return the parameter logits relevant to the chosen edit type.

            The result is a dict with a subset of the keys ``'stage'``,
            ``'ch_mult'`` and ``'resize_factor'``; it is empty for edit types
            that need no extra parameters. ``chosen_edit_type_tensor`` must
            hold exactly one element because ``.item()`` is called on it.
            """
            type_emb = self.edit_type_embedding(chosen_edit_type_tensor)
            conditional_state = torch.cat([base_state_embedding, type_emb], dim=1)
            logits = {}
            edit_type = chosen_edit_type_tensor.item()
            if edit_type == EDIT_TYPE_ADD_CONV_BLOCK:
                logits['stage'] = self.head_target_loc_stage(conditional_state)
                logits['ch_mult'] = self.head_target_conv_ch_mult(conditional_state)
            elif edit_type == EDIT_TYPE_RESIZE_LAYER:
                logits['stage'] = self.head_target_loc_stage(conditional_state)
                logits['resize_factor'] = self.head_target_resize_factor(conditional_state)
            elif edit_type == EDIT_TYPE_ADD_SKIP:
                logits['stage'] = self.head_target_loc_stage(conditional_state)
            return logits

        def deepen_gnn(self, device='cpu'):
            """Append one GCN layer and re-fit the heads; always returns ``True``.

            With no existing layers a regular ``GCNConv`` from the node-feature
            width is created; otherwise the layer comes from
            ``create_identity_gcn_layer`` (presumably identity-initialised —
            see that helper).
            """
            print(f"  Deepening MetaAgentGNN: GNN Layers {self.current_num_gnn_layers} -> {self.current_num_gnn_layers + 1}")
            if self.current_num_gnn_layers == 0:
                new_gcn_layer = pyg_nn.GCNConv(self.node_feature_dim, self.current_gnn_hidden_dim, bias=True, normalize=True, add_self_loops=True).to(device)
            else:
                new_gcn_layer = create_identity_gcn_layer(self.current_gnn_hidden_dim, device=device)
            self.gnn_layers.append(new_gcn_layer)
            self.current_num_gnn_layers += 1
            self._update_mlp_heads()
            return True

        def widen_gnn_hidden_dim(self, factor=META_GNN_WIDEN_FACTOR, device='cpu'):
            """Scale the hidden width of every GCN layer by ``factor``.

            The new width is clamped to at least ``MIN_GNN_HIDDEN_DIM``.
            Returns ``False`` when the width would not change or there are no
            GCN layers, ``True`` after a successful resize. ``factor`` below 1
            shrinks the layers (see ``shrink_gnn_hidden_dim``).
            """
            old_dim = self.current_gnn_hidden_dim
            new_dim = max(MIN_GNN_HIDDEN_DIM, int(round(old_dim * factor)))
            if new_dim == old_dim or self.current_num_gnn_layers == 0: return False

            # This routine serves both directions, so report the real one.
            direction = "Widening" if new_dim > old_dim else "Shrinking"
            print(f"  {direction} MetaAgentGNN: GNN Hidden Dim {old_dim} -> {new_dim}")
            new_gnn_list = nn.ModuleList()
            current_in_dim = self.node_feature_dim
            for i in range(self.current_num_gnn_layers):
                original_gcn = self.gnn_layers[i]
                final_gcn, _ = resize_gcn_conv_hidden(original_gcn, new_dim, current_in_dim, device)
                new_gnn_list.append(final_gcn)
                current_in_dim = new_dim
            self.gnn_layers = new_gnn_list
            self.current_gnn_hidden_dim = new_dim
            self._update_mlp_heads()
            return True

        def deepen_one_mlp_head(self, head_attr_name, device='cpu'):
            """Insert a layer into the named head via ``net2deeper_linear_insert_identity``.

            Returns ``False`` when the head does not exist, its recorded depth
            is already 8 or more, or the helper reports no change.
            """
            original_comp = getattr(self, head_attr_name, None)
            if original_comp is None: return False
            current_depth = self.head_depth_counters.get(head_attr_name, 1)
            if current_depth >= 8: return False
            changed = net2deeper_linear_insert_identity(self, head_attr_name, device=device)
            if changed:
                new_comp = getattr(self, head_attr_name)
                self.head_depth_counters[head_attr_name] = len(new_comp) if isinstance(new_comp, nn.Sequential) else 1
            return changed

        def shrink_gnn_hidden_dim(self, factor=META_GNN_SHRINK_FACTOR, device='cpu'):
            """Shrink the GCN hidden width; delegates to ``widen_gnn_hidden_dim`` with ``factor``."""
            return self.widen_gnn_hidden_dim(factor=factor, device=device)

        def prune_gnn_layer(self, device='cpu'):
            """Drop the last GCN layer unless already at ``MIN_GNN_LAYERS``; returns whether a layer was removed."""
            if self.current_num_gnn_layers <= MIN_GNN_LAYERS: return False
            print(f"  Pruning MetaAgentGNN: GNN Layers {self.current_num_gnn_layers} -> {self.current_num_gnn_layers - 1}")
            self.gnn_layers.pop(-1)
            self.current_num_gnn_layers -= 1
            self._update_mlp_heads()
            return True

        def prune_one_mlp_head(self, head_attr_name, device='cpu'):
            """Remove a layer from the named head via ``net2thinner_linear_remove_layer``.

            Returns ``False`` when the head does not exist, its recorded depth
            is at or below ``MIN_MLP_HEAD_SEQUENTIAL_DEPTH``, or the helper
            reports no change.
            """
            original_comp = getattr(self, head_attr_name, None)
            if original_comp is None: return False
            current_depth = self.head_depth_counters.get(head_attr_name, 1)
            if current_depth <= MIN_MLP_HEAD_SEQUENTIAL_DEPTH: return False
            changed = net2thinner_linear_remove_layer(self, head_attr_name, device=device)
            if changed:
                new_comp = getattr(self, head_attr_name)
                self.head_depth_counters[head_attr_name] = len(new_comp) if isinstance(new_comp, nn.Sequential) else 1
            return changed

# =======================================================================
#  Prometheus System
# =======================================================================
class DEITI:
    def __init__(self):
        """Build the target CNN, the GNN meta-agent, data loaders, optimizers and run-state trackers.

        Raises:
            ImportError: If PyTorch Geometric could not be imported.
        """
        if not PYG_AVAILABLE:
            raise ImportError("PyTorch Geometric required for Prometheus.")
        self.device = DEVICE

        # Networks: the CNN being edited, then the agent that proposes the edits.
        self.target_cnn = TargetCNN().to(DEVICE)
        agent_action_space = dict(
            num_edit_types=NUM_TARGET_CNN_EDIT_TYPES,
            num_stages_target=NUM_STAGES_TARGET_CNN,
            num_ch_mults=len(DISCRETE_CH_MULT_ADD),
            num_resize_factors=len(DISCRETE_RESIZE_FACTORS),
        )
        self.meta_agent = MetaAgentGNN(device_to_use_agent=self.device, **agent_action_space).to(DEVICE)

        # Training utilities. Gradient scaling is only enabled on CUDA.
        self.criterion_target = nn.CrossEntropyLoss()
        self.global_states_history_buffer = []
        running_on_cuda = self.device.type == 'cuda'
        self.amp_scaler = torch.amp.GradScaler(enabled=running_on_cuda)
        self._init_dataloaders()

        # Optimizers (and the cosine schedule for the target CNN).
        self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
        self.sched_target = CosineAnnealingLR(self.opt_target, T_max=BASE_POST_EPOCHS)
        self.opt_meta = optim.Adam(self.meta_agent.parameters(), lr=BASE_META_LR)

        # Bookkeeping for freshly inserted BatchNorm layers and LR warm-up.
        self.frozen_bns = []
        self.warmup_state = dict(active=False, original_lr=LEARNING_RATE, param_ratio=1.0)

        # Best result seen so far and stagnation / failure counters.
        self.best_global_accuracy = -1.0
        self.best_global_model = None
        self.iterations_without_improvement = 0
        self.consecutive_dummy_pass_failures = 0

    def _init_dataloaders(self):
        """Create ``self.train_loader`` and ``self.val_loader`` for Fashion-MNIST.

        Training images are augmented (random crop, horizontal flip,
        AutoAugment, RandAugment, random erasing); validation images are only
        converted to tensors and normalised. If loading the dataset raises,
        ``FakeData`` of the same image size is used instead so the rest of the
        pipeline can still run.
        """
        # --- MODIFIED FOR FASHION-MNIST ---
        norm_mean, norm_std = (0.5,), (0.5,)

        def tensor_and_normalize():
            # Fresh transform objects for each pipeline that needs this tail.
            return [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]

        pil_augmentations = [
            transforms.RandomCrop(28, padding=4),  # Adjusted for 28x28 images
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
            transforms.RandAugment(num_ops=2, magnitude=9),
        ]
        train_transforms = transforms.Compose(
            pil_augmentations
            + tensor_and_normalize()
            + [transforms.RandomErasing(p=0.5, scale=(0.02, 0.2))]
        )
        val_transforms = transforms.Compose(tensor_and_normalize())

        try:
            tr_ds = torchvision.datasets.FashionMNIST('./data', train=True, download=True, transform=train_transforms)
            val_ds = torchvision.datasets.FashionMNIST('./data', train=False, download=True, transform=val_transforms)
            print("Successfully loaded Fashion-MNIST dataset.")
        except Exception as e:
            print(f"Fashion-MNIST download failed: {e}. Using FakeData.")
            fake_transform = transforms.Compose(tensor_and_normalize())
            fake_image_size = (1, 28, 28)
            tr_ds = torchvision.datasets.FakeData(size=BATCH_SIZE*50, image_size=fake_image_size, num_classes=10, transform=fake_transform)
            val_ds = torchvision.datasets.FakeData(size=BATCH_SIZE*20, image_size=fake_image_size, num_classes=10, transform=fake_transform)

        on_cuda = self.device.type == 'cuda'
        num_workers = 4 if on_cuda else 0
        use_persistent_workers = num_workers > 0
        print(f"Using {num_workers} workers for data loading (persistent: {use_persistent_workers}).")

        # NOTE(review): drop_last=True is applied to validation as well, so a
        # trailing partial batch is never evaluated — confirm this is intended.
        shared_loader_args = dict(
            num_workers=num_workers,
            pin_memory=on_cuda,
            drop_last=True,
            persistent_workers=use_persistent_workers,
        )
        self.train_loader = DataLoader(tr_ds, BATCH_SIZE, shuffle=True, **shared_loader_args)
        self.val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, **shared_loader_args)

    def _create_target_cnn_graph_data(self):
        """Encode the current target CNN as a PyG graph for the meta-agent.

        One node is created per stage op and per classifier op. Each node
        carries seven features: op-type id, normalised output width, stage
        position, position inside the stage, conv flag, ``stride - 1`` and
        normalised output spatial size (classifier nodes use 1.0 as stage
        position and zeros for the last three). Directed edges run from a
        producing op to the op that consumes it; an op reading the stage input
        (index -1) is linked to the last op of the previous stage, and the
        classifier ops are chained after the last op of the final stage.

        Side effect: every stage op description gets its ``'out_spatial'``
        entry refreshed with the value derived here.

        Returns:
            Tuple ``(data, op_map, s_out_gid_map)`` where ``data`` is the
            ``Data`` object, ``op_map`` maps ``(stage_idx, op_idx)`` and
            ``('classifier', op_idx)`` to global node ids, and
            ``s_out_gid_map`` maps a stage index to the node id of its last op.
        """
        nodes, src, tgt, op_map, gid = {}, [], [], {}, 0
        initial_spatial = self.target_cnn.input_placeholder_desc['out_spatial']
        current_spatial_map = {-1: initial_spatial}

        # ---- stage op nodes -------------------------------------------------
        for si, sm in enumerate(self.target_cnn.stages):
            stage_initial_spatial = current_spatial_map.get(si-1, initial_spatial)
            op_output_props = {-1: (sm.initial_in_channels, stage_initial_spatial)}
            for oi, od in enumerate(sm.op_descriptions):
                op_map[(si, oi)] = gid
                op_type_str = od.get('type', 'none')
                input_indices = od.get('input_indices', [-1])
                input_indices = input_indices if isinstance(input_indices, list) else [input_indices]
                # An empty input list falls back to the stage input, as in
                # _ensure_target_cnn_consistency.
                prev_op_idx = input_indices[0] if input_indices else -1
                _, expected_in_spatial = op_output_props.get(prev_op_idx, (sm.initial_in_channels, stage_initial_spatial))
                is_conv = 1.0 if op_type_str == 'conv2d' else 0.0
                out_spatial = expected_in_spatial
                stride_val = 1
                if op_type_str == 'maxpool2d':
                    raw_stride = sm.ops[oi].stride if hasattr(sm.ops[oi], 'stride') else 1
                    # Pool strides may be stored as int or tuple; use the first axis.
                    stride_val = raw_stride[0] if isinstance(raw_stride, tuple) else raw_stride
                    out_spatial //= stride_val
                elif op_type_str == 'conv2d':
                    current_op = sm.ops[oi]
                    stride_val = current_op.stride[0] if isinstance(current_op.stride, tuple) else current_op.stride
                    out_spatial //= stride_val
                op_output_props[oi] = (od.get('out_channels', 0), out_spatial)
                od['out_spatial'] = out_spatial
                nf = [
                    float(OP_TYPE_IDS.get(op_type_str, 0)),
                    float(od.get('out_channels', 0)) / NORMALIZATION_WIDTH_DIVISOR,
                    float(si) / max(1, NUM_STAGES_TARGET_CNN - 1),
                    float(oi) / max(1, len(sm.ops) - 1),
                    is_conv,
                    float(stride_val - 1),
                    float(out_spatial) / NORMALIZATION_SPATIAL_DIVISOR,
                ]
                nodes[gid] = nf
                gid += 1
            if sm.op_descriptions:
                current_spatial_map[si] = op_output_props[len(sm.ops)-1][1]
            else:
                current_spatial_map[si] = stage_initial_spatial

        s_out_gid_map = {s: op_map.get((s, len(sm_loop.ops) - 1), -1) for s, sm_loop in enumerate(self.target_cnn.stages) if sm_loop.ops}
        last_conv_gid = s_out_gid_map.get(len(self.target_cnn.stages) - 1, -1)

        # ---- classifier op nodes -------------------------------------------
        for oi, od in enumerate(self.target_cnn.classifier_op_descriptions):
            op_map[('classifier', oi)] = gid
            op_type_str = od.get('type', 'none')
            nf = [
                float(OP_TYPE_IDS.get(op_type_str, 0)),
                float(od.get('out_features', 0)) / NORMALIZATION_WIDTH_DIVISOR,
                1.0, float(oi) / max(1, len(self.target_cnn.classifier) - 1),
                0.0, 0.0, 0.0
            ]
            nodes[gid] = nf
            gid += 1

        # ---- edges inside and between stages -------------------------------
        for si, sm in enumerate(self.target_cnn.stages):
            for oi, od in enumerate(sm.op_descriptions):
                cur_gid = op_map.get((si, oi), -1)
                if cur_gid == -1: continue
                in_ids = od.get('input_indices', [-1] if oi == 0 else [oi - 1])
                in_ids = in_ids if isinstance(in_ids, list) else [in_ids]
                for lsid in in_ids:
                    if lsid == -1:
                        # Stage input: produced by the previous stage's last op
                        # (the first stage has no producer node).
                        sgid = s_out_gid_map.get(si - 1, -1) if si > 0 else -1
                    else:
                        sgid = op_map.get((si, lsid), -1)
                    if sgid != -1:
                        src.append(sgid)
                        tgt.append(cur_gid)

        # ---- chain the classifier after the final stage --------------------
        prev_gid = last_conv_gid
        for oi, od in enumerate(self.target_cnn.classifier_op_descriptions):
            cur_gid = op_map.get(('classifier', oi), -1)
            if cur_gid != -1 and prev_gid != -1:
                src.append(prev_gid)
                tgt.append(cur_gid)
            prev_gid = cur_gid

        # A graph without nodes is represented by one all-zero placeholder node.
        x = torch.tensor([nodes[i] for i in range(gid)] if gid > 0 else [[0.] * NODE_FEATURE_DIM], dtype=torch.float32, device=self.device)
        eidx = torch.tensor([src, tgt], dtype=torch.long, device=self.device) if src else torch.empty((2, 0), dtype=torch.long, device=self.device)
        return Data(x=x, edge_index=eidx, batch=torch.zeros(x.size(0), dtype=torch.long, device=self.device) if x.size(0) > 0 else None), op_map, s_out_gid_map

    def _ensure_target_cnn_consistency(self):
        current_channels, current_spatial = self.target_cnn.input_placeholder_desc['out_channels'], self.target_cnn.input_placeholder_desc['out_spatial']
        for stage_module in self.target_cnn.stages:
            stage_module.initial_in_channels = current_channels
            stage_module.initial_spatial_size = current_spatial
            op_output_props = {-1: (current_channels, current_spatial)}
            for i, op in enumerate(stage_module.ops):
                desc = stage_module.op_descriptions[i]
                input_indices = desc.get('input_indices', [-1]); input_indices = input_indices if isinstance(input_indices, list) else [input_indices]
                prev_op_idx = input_indices[0] if input_indices else -1
                expected_in_channels, expected_in_spatial = op_output_props.get(prev_op_idx, (current_channels, current_spatial))
                new_op = None
                if isinstance(op, nn.Conv2d):
                    if op.in_channels != expected_in_channels: new_op = adapt_conv_input_channels(op, expected_in_channels, self.device)
                    if new_op is None and op.in_channels != expected_in_channels: return False
                    desc['out_channels'] = op.out_channels if new_op is None else new_op.out_channels
                    stride_val = op.stride[0] if isinstance(op.stride, tuple) else op.stride
                    desc['out_spatial'] = expected_in_spatial // stride_val
                elif isinstance(op, (nn.BatchNorm2d, nn.BatchNorm1d)):
                    if op.num_features != expected_in_channels: new_op = adapt_batchnorm_features(op, expected_in_channels, self.device)
                    desc['out_channels'] = expected_in_channels
                    desc['out_spatial'] = expected_in_spatial
                elif isinstance(op, nn.ReLU):
                    desc['out_channels'] = expected_in_channels; desc['out_spatial'] = expected_in_spatial
                elif isinstance(op, nn.MaxPool2d):
                    stride_val = op.stride if isinstance(op.stride, int) else op.stride[0]
                    desc['out_channels'] = expected_in_channels; desc['out_spatial'] = expected_in_spatial // stride_val
                elif isinstance(op, AddWithProjection):
                    primary_ch, primary_sp = op_output_props[input_indices[0]]; skip_ch, _ = op_output_props[input_indices[1]]
                    if primary_ch != skip_ch: op.projection = nn.Sequential(nn.Conv2d(skip_ch, primary_ch, kernel_size=1, bias=False), nn.BatchNorm2d(primary_ch)).to(self.device)
                    else: op.projection = nn.Identity()
                    desc['out_channels'], desc['out_spatial'] = primary_ch, primary_sp
                if new_op is not None: stage_module.ops[i] = new_op
                op_output_props[i] = (desc.get('out_channels'), desc.get('out_spatial'))
            current_channels, current_spatial = stage_module.get_current_out_properties()
        last_conv_channels, _ = self.target_cnn.get_last_stage_out_properties()
        current_features = max(1, last_conv_channels)
        for i, op in enumerate(self.target_cnn.classifier):
            desc = self.target_cnn.classifier_op_descriptions[i]
            new_op = None
            if isinstance(op, nn.Linear):
                if op.in_features != current_features: new_op = adapt_linear_input_features(op, current_features, self.device)
                current_features = op.out_features if new_op is None else new_op.out_features
                desc['out_features'] = current_features
            elif isinstance(op, nn.BatchNorm1d):
                if op.num_features != current_features: new_op = adapt_batchnorm_features(op, current_features, self.device)
            elif isinstance(op, nn.ReLU):
                pass
            if new_op is not None: self.target_cnn.classifier[i] = new_op
        return True

    def _apply_target_cnn_edit(self, actions):
        """Apply one structural edit, chosen by the agent, to the target CNN.

        Args:
            actions: Dict of sampled actions. ``'target_edit_type'`` (tensor)
                selects the edit; depending on it the method also reads
                ``'target_loc_stage'``, ``'target_conv_ch_mult_idx'``,
                ``'target_resize_factor_idx'`` (tensors) and
                ``'target_actual_op_idx_in_stage'``, ``'source_op_idx'``,
                ``'dest_op_idx'`` (indices, -1 meaning "not chosen").

        Returns:
            ``(ok, newly_added_bns)``. ``ok`` is the result of the consistency
            pass when the network was changed and ``False`` when nothing was
            changed; ``newly_added_bns`` lists BatchNorm layers created by the
            edit (they are put into eval mode here).
        """
        edit_type = actions['target_edit_type'].item()
        changed = False
        newly_added_bns = []

        if edit_type == EDIT_TYPE_ADD_LINEAR_BLOCK:
            classifier = self.target_cnn.classifier
            descriptions = self.target_cnn.classifier_op_descriptions

            # Width entering the final classifier layer: the output of the
            # nearest preceding op that has out_features, else the classifier input.
            prev_out_features = -1
            for op in reversed(classifier[:-1]):
                if hasattr(op, 'out_features'):
                    prev_out_features = op.out_features
                    break
            if prev_out_features == -1:
                prev_out_features = self.target_cnn.get_classifier_in_features()

            # Identity-initialised Linear so the new block starts as a pass-through.
            new_linear = nn.Linear(prev_out_features, prev_out_features).to(self.device)
            nn.init.eye_(new_linear.weight)
            if new_linear.bias is not None: nn.init.zeros_(new_linear.bias)

            bn = nn.BatchNorm1d(prev_out_features, momentum=0.1).to(self.device)
            with torch.no_grad():
                bn.weight.data.fill_(1.0)
                bn.bias.data.zero_()
            bn.eval()
            newly_added_bns.append(bn)

            new_relu = nn.ReLU(inplace=False).to(self.device)

            # Insert Linear -> BN -> ReLU right before the final classifier layer.
            insert_idx = len(classifier) - 1
            classifier.insert(insert_idx, new_linear)
            classifier.insert(insert_idx + 1, bn)
            classifier.insert(insert_idx + 2, new_relu)
            descriptions.insert(insert_idx, {'type': 'linear', 'out_features': new_linear.out_features})
            descriptions.insert(insert_idx + 1, {'type': 'batchnorm1d', 'out_features': new_linear.out_features})
            descriptions.insert(insert_idx + 2, {'type': 'relu', 'out_features': new_linear.out_features})
            changed = True

        elif edit_type == EDIT_TYPE_ADD_CONV_BLOCK:
            stage_idx = actions['target_loc_stage'].item()
            if not (0 <= stage_idx < len(self.target_cnn.stages)): return False, []
            stage = self.target_cnn.stages[stage_idx]
            in_ch, in_sp = stage.get_current_out_properties()
            in_ch = max(1, in_ch)

            # 3x3 conv whose only non-zero weights are a 1 at the kernel centre
            # on the channel diagonal, i.e. it copies its input.
            k = 3; s = 1
            identity_conv = nn.Conv2d(in_ch, in_ch, k, stride=s, padding=(k-1)//2, bias=False).to(self.device)
            with torch.no_grad():
                identity_conv.weight.data.zero_()
                center = k // 2
                for i in range(in_ch): identity_conv.weight.data[i, i, center, center] = 1.0

            bn = nn.BatchNorm2d(in_ch, momentum=0.1).to(self.device)
            with torch.no_grad():
                bn.weight.data.fill_(1.0)
                bn.bias.data.zero_()
            bn.eval()
            newly_added_bns.append(bn)
            relu = nn.ReLU(inplace=False).to(self.device)

            # Append Conv -> BN -> ReLU at the end of the stage.
            insert_idx = len(stage.ops)
            in_indices = [-1] if insert_idx == 0 else [insert_idx-1]
            stage.add_op(identity_conv, {'type':'conv2d', 'out_channels': in_ch, 'out_spatial': in_sp, 'input_indices':in_indices}, insert_at=insert_idx)
            stage.add_op(bn, {'type':'batchnorm2d', 'out_channels': in_ch, 'out_spatial': in_sp, 'input_indices':[insert_idx]}, insert_at=insert_idx+1)
            stage.add_op(relu, {'type':'relu', 'out_channels': in_ch, 'out_spatial': in_sp, 'input_indices':[insert_idx+1]}, insert_at=insert_idx+2)

            m = DISCRETE_CH_MULT_ADD[actions['target_conv_ch_mult_idx'].item()]
            target_out_ch = max(1, int(round(in_ch * m)))

            if in_ch != target_out_ch:
                # Resize the fresh block to the requested output width.
                factor = float(target_out_ch) / in_ch
                widened_conv = resize_conv_output(stage.ops[insert_idx], factor, self.device)
                widened_bn = adapt_batchnorm_features(stage.ops[insert_idx+1], widened_conv.out_channels, self.device)

                stage.ops[insert_idx] = widened_conv
                stage.ops[insert_idx+1] = widened_bn

                for offset in range(3):
                    stage.op_descriptions[insert_idx + offset]['out_channels'] = widened_conv.out_channels

            changed = True

        elif edit_type == EDIT_TYPE_RESIZE_LAYER:
            stage_idx = actions['target_loc_stage'].item()
            if not (0 <= stage_idx < len(self.target_cnn.stages)): return False, []
            stage = self.target_cnn.stages[stage_idx]
            op_idx = actions.get('target_actual_op_idx_in_stage', -1)
            if op_idx != -1 and 0 <= op_idx < len(stage.ops):
                op_mod = stage.ops[op_idx]
                factor = DISCRETE_RESIZE_FACTORS[actions['target_resize_factor_idx'].item()]
                if isinstance(op_mod, nn.Conv2d):
                    new_op = resize_conv_output(op_mod, factor, self.device)
                    # The helper handing back the same object means "no change".
                    if new_op is not op_mod:
                        stage.ops[op_idx] = new_op
                        stage.op_descriptions[op_idx]['out_channels'] = new_op.out_channels
                        changed = True
                elif isinstance(op_mod, nn.Linear):
                    new_op = resize_linear_output(op_mod, factor, self.device)
                    if new_op is not op_mod:
                        stage.ops[op_idx] = new_op
                        stage.op_descriptions[op_idx]['out_features'] = new_op.out_features
                        changed = True

        elif edit_type == EDIT_TYPE_ADD_SKIP:
            stage_idx = actions['target_loc_stage'].item()
            if not (0 <= stage_idx < len(self.target_cnn.stages)): return False, []
            stage = self.target_cnn.stages[stage_idx]
            source_op_idx = actions.get('source_op_idx', -1)
            dest_op_idx = actions.get('dest_op_idx', -1)
            if source_op_idx != -1 and dest_op_idx != -1:
                # The add op is placed right after the destination, so both ends
                # must be existing ops and the source must not lie after the
                # destination; otherwise the consistency pass cannot resolve the
                # skip input.
                if not (0 <= source_op_idx <= dest_op_idx < len(stage.ops)):
                    return False, []
                add_op = AddWithProjection().to(self.device)
                dest_ch, dest_sp = stage.get_op_output_properties(dest_op_idx)
                add_desc = {'type': 'add', 'out_channels': dest_ch, 'out_spatial': dest_sp, 'input_indices': [dest_op_idx, source_op_idx]}
                insert_at_idx = dest_op_idx + 1
                stage.add_op(add_op, add_desc, insert_at=insert_at_idx)
                # Re-route the first later op that consumed the destination so it
                # now consumes the add op instead.
                for i in range(insert_at_idx, len(stage.ops)):
                    if stage.op_descriptions[i].get('input_indices') == [dest_op_idx]:
                        stage.op_descriptions[i]['input_indices'] = [insert_at_idx]
                        break
                changed = True

        if changed:
            consistency_ok = self._ensure_target_cnn_consistency()
            return consistency_ok, newly_added_bns

        return False, []

    def _apply_meta_self_edit(self, action):
        """Apply one structural self-edit to the meta-agent.

        ``action`` is a scalar tensor holding one of the ``META_EDIT_*`` codes.
        The matching grow/shrink/prune method of ``self.meta_agent`` is called;
        for the MLP-head edits the head is picked with ``np.random.choice``.
        When the agent reports a change, the meta optimizer is rebuilt as a
        fresh Adam over the agent's current parameters, keeping the learning
        rate of the old optimizer's first param group (optimizer state is not
        carried over).

        Returns whatever the invoked agent method returned, or ``False`` when
        no edit was attempted.
        """
        agent = self.meta_agent
        requested = action.item()
        outcome = False

        if requested == META_EDIT_DEEPEN_GNN:
            outcome = agent.deepen_gnn(device=self.device)
        elif requested == META_EDIT_WIDEN_GNN_HIDDEN:
            outcome = agent.widen_gnn_hidden_dim(factor=META_GNN_WIDEN_FACTOR, device=self.device)
        elif requested == META_EDIT_DEEPEN_MLP_HEAD:
            all_heads = agent.get_mlp_head_names()
            if all_heads:
                picked = np.random.choice(all_heads)
                outcome = agent.deepen_one_mlp_head(picked, device=self.device)
        elif requested == META_EDIT_SHRINK_GNN_HIDDEN:
            outcome = agent.shrink_gnn_hidden_dim(factor=META_GNN_SHRINK_FACTOR, device=self.device)
        elif requested == META_EDIT_PRUNE_GNN:
            outcome = agent.prune_gnn_layer(device=self.device)
        elif requested == META_EDIT_PRUNE_MLP_HEAD:
            all_heads = agent.get_mlp_head_names()
            # Only heads deeper than the configured minimum may lose a layer.
            deep_enough = [
                head for head in all_heads
                if agent.head_depth_counters.get(head, 1) > MIN_MLP_HEAD_SEQUENTIAL_DEPTH
            ]
            if deep_enough:
                picked = np.random.choice(deep_enough)
                outcome = agent.prune_one_mlp_head(picked, device=self.device)

        if not outcome:
            return outcome

        # The parameter set changed, so the old optimizer no longer matches it.
        n_params = sum(p.numel() for p in self.meta_agent.parameters())
        print(f"MetaAgentGNN arch changed. New Params: {n_params}. Re-init optimizer.")
        carried_lr = self.opt_meta.param_groups[0]['lr']
        del self.opt_meta
        gc.collect()
        torch.cuda.empty_cache()
        self.opt_meta = optim.Adam(self.meta_agent.parameters(), lr=carried_lr)
        return outcome

    def _sanitize_bn_stats(self):
        for m in self.target_cnn.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean.nan_to_num_(nan=0.0, posinf=1e4, neginf=-1e4)
                m.running_var.nan_to_num_(nan=1.0, posinf=1e4, neginf=1e-4)
                m.running_var.clamp_(min=1e-5)

    def _train_target_one_epoch(self, loader, optimizer, scheduler, name="Tr", current_epoch=0):
        """Train the target CNN for one epoch and return ``(mean_loss, accuracy)``.

        Args:
            loader: iterable of ``(x, y)`` batches; at most ``BATCHES_PER_EPOCH``
                batches are consumed when that constant is not ``None``.
            optimizer: optimizer stepped through ``self.amp_scaler``.
            scheduler: optional LR scheduler, stepped once at the end of the
                epoch unless LR warmup is still active.
            name: label supplied by callers; not used inside this method.
            current_epoch: zero-based epoch index within the current training
                phase; drives LR warmup and the BatchNorm unfreeze point.

        Returns:
            Tuple of the mean loss over the batches that were actually used and
            the fraction of correctly classified samples. Batches of size <= 1
            and batches with a non-finite loss are skipped and not counted.
        """
        self.target_cnn.train()

        warmup_total_epochs = 5
        if self.warmup_state['active']:
            if current_epoch < warmup_total_epochs:
                # Linear ramp, additionally damped by sqrt of the parameter growth ratio.
                lr_scale = (current_epoch + 1) / warmup_total_epochs
                lr_scale /= math.sqrt(self.warmup_state['param_ratio'])
                for g in optimizer.param_groups:
                    g['lr'] = self.warmup_state['original_lr'] * lr_scale
            else:
                print(f"  Warmup complete. Restoring LR to {self.warmup_state['original_lr']:.2e}.")
                for g in optimizer.param_groups:
                    g['lr'] = self.warmup_state['original_lr']
                self.warmup_state['active'] = False

        if self.frozen_bns:
            if current_epoch >= warmup_total_epochs:
                print(f"  Unfreezing {len(self.frozen_bns)} new BatchNorm layers.")
                for bn in self.frozen_bns:
                    bn.train()
                self.frozen_bns = []
            else:
                # target_cnn.train() above switches *every* submodule to training
                # mode, which silently un-freezes the newly added BatchNorm
                # layers. Put them back in eval mode so they really stay frozen
                # until the unfreeze epoch.
                # NOTE(review): assumes "frozen" means eval mode, as implied by
                # the bn.train() unfreeze above — confirm against the edit code.
                for bn in self.frozen_bns:
                    bn.eval()

        loss_sum, n_batches, correct, total = 0, 0, 0, 0
        data_iterator = itertools.islice(loader, BATCHES_PER_EPOCH) if BATCHES_PER_EPOCH is not None else loader
        for x, y in data_iterator:
            if x.size(0) <= 1:
                continue
            x, y = x.to(self.device), y.to(self.device)
            optimizer.zero_grad(set_to_none=True)
            with autocast(self.device.type, enabled=(self.device.type == 'cuda')):
                logits = self.target_cnn(x)
                loss = self.criterion_target(logits, y)

            # Skip the batch entirely (no backward/step) when the loss blew up.
            if torch.isnan(loss) or torch.isinf(loss):
                continue

            self.amp_scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping by norm.
            self.amp_scaler.unscale_(optimizer)
            clip_grad_norm_(self.target_cnn.parameters(), MAX_GRAD_NORM)
            self.amp_scaler.step(optimizer)
            self.amp_scaler.update()

            loss_sum += loss.item()
            _, pred = logits.max(1)
            total += y.size(0)
            correct += pred.eq(y).sum().item()
            n_batches += 1

        if scheduler and not self.warmup_state['active']:
            scheduler.step()

        self._sanitize_bn_stats()

        return loss_sum/max(1,n_batches), correct/max(1,total)

    def _validate_target(self, loader):
        """Evaluate the target CNN on ``loader`` without gradient tracking.

        The model is switched to eval mode (and left there). At most
        ``BATCHES_PER_EPOCH`` batches are read when that constant is set.
        Batches of size <= 1 and batches whose cross-entropy loss is NaN/Inf
        are ignored.

        Returns:
            ``(mean_batch_loss, accuracy)`` over the batches that were counted;
            both are 0 when no batch was counted.
        """
        self.target_cnn.eval()
        loss_fn = nn.CrossEntropyLoss()

        if BATCHES_PER_EPOCH is None:
            batches = loader
        else:
            batches = itertools.islice(loader, BATCHES_PER_EPOCH)

        running_loss = 0
        hits = 0
        seen = 0
        counted_batches = 0
        with torch.no_grad():
            for inputs, labels in batches:
                if inputs.size(0) <= 1:
                    continue
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                scores = self.target_cnn(inputs)
                batch_loss = loss_fn(scores, labels)
                if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                    continue
                running_loss += batch_loss.item()
                predicted = scores.max(1)[1]
                seen += labels.size(0)
                hits += predicted.eq(labels).sum().item()
                counted_batches += 1

        mean_loss = running_loss / max(1, counted_batches)
        accuracy = hits / max(1, seen)
        return mean_loss, accuracy

    def _validate_edit_with_dummy_pass(self):
        self.target_cnn.eval()
        try:
            dummy_x, _ = next(iter(self.val_loader))
            dummy_x = dummy_x.to(self.device)

            with torch.no_grad(), autocast(self.device.type, enabled=(self.device.type=='cuda')):
                output = self.target_cnn(dummy_x)

            if torch.isnan(output).any() or torch.isinf(output).any():
                print("  !!! Dummy pass failed: NaN/Inf detected in output. Edit is invalid. !!!")
                return False
            return True
        except Exception as e:
            print(f"  !!! Dummy pass failed with exception: {e}. Edit is invalid. !!!")
            traceback.print_exc()
            return False
        finally:
            self.target_cnn.train()

    def train_loop(self, iterations=100):
        """Run the outer architecture-search loop for ``iterations`` rounds.

        Each round performs, in order:

        1. Train the target CNN for ``PRE_EPOCHS`` epochs (no scheduler) and
           validate it (``acc_b``).
        2. Deep-copy the target CNN as a rollback point, build its graph
           representation and query the meta-agent for an edit: an edit type
           is sampled first, then the sub-actions that edit type needs (stage,
           channel multiplier, resize target/factor, or skip source/destination).
           Log-probabilities and entropies of all sampled actions are summed
           into ``lp_t`` / ``en_t``.
        3. Apply the edit and smoke-test it with a dummy forward pass; roll back
           to the backup on failure. Three consecutive dummy-pass failures
           trigger a pruning self-edit of the meta-agent.
        4. Rebuild the target optimizer and AMP scaler, train for a number of
           post-edit epochs that scales with the parameter growth ratio
           (capped at 100) under a cosine schedule, and validate (``acc_a``).
        5. Track the best model seen so far.
        6. Compute the reward ``100 * (acc_a - acc_b)`` minus a quadratic
           penalty on parameters above ``COMPLEXITY_PENALTY_THRESHOLD`` and
           update the meta-agent with an actor-critic loss plus entropy bonus.
        7. After 5 rounds without improvement: revert the target CNN to the
           best model and apply a growth self-edit to the meta-agent.
        8. If accuracy fell more than 0.04 below the best: revert to the best
           model with fresh optimizer, scheduler and scaler.

        Args:
            iterations: number of search rounds to run.

        Returns:
            None. Results are kept on ``self`` (``best_global_model``,
            ``best_global_accuracy``, ``target_cnn``, ``meta_agent``).
        """
        for itr in range(iterations):
            print(f"\n===== Iteration {itr+1}/{iterations} =====")
            print(f"Current MetaAgentGNN: GNN Layers={self.meta_agent.current_num_gnn_layers}, GNN Hidden={self.meta_agent.current_gnn_hidden_dim}, Params: {sum(p.numel() for p in self.meta_agent.parameters()):,}")
            print(f"Current Global Best Accuracy: {self.best_global_accuracy:.4f}")

            # Lazily create the target optimizer (e.g. first round, or after an external reset to None).
            if self.opt_target is None:
                self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

            # --- Phase 1: pre-edit training and baseline validation ---
            for pre_ep in range(PRE_EPOCHS):
                self._train_target_one_epoch(self.train_loader, self.opt_target, None, f"PreE{pre_ep+1}", current_epoch=pre_ep)
            val_loss_b, acc_b = self._validate_target(self.val_loader)
            print(f"PreVal: ValidL={val_loss_b:.4f}, ValidA={acc_b:.4f}")

            # Rollback point used if the sampled edit turns out invalid or a no-op.
            pre_edit_model_backup = copy.deepcopy(self.target_cnn)
            old_params = sum(p.numel() for p in self.target_cnn.parameters() if p.requires_grad)

            # --- Phase 2: build the meta-agent's inputs ---
            graph, op_map, s_out_gid_map = self._create_target_cnn_graph_data()
            target_params_M = sum(p.numel() for p in self.target_cnn.parameters() if p.requires_grad)/1e6
            norm_m_gl=self.meta_agent.current_num_gnn_layers / (MIN_GNN_LAYERS + 5)
            norm_m_gh=self.meta_agent.current_gnn_hidden_dim / (MIN_GNN_HIDDEN_DIM + 256)
            # Global state: [val loss, val acc, target params (M), normalized GNN depth, normalized GNN width].
            g_state=[val_loss_b,acc_b,target_params_M,norm_m_gl,norm_m_gh]; self.global_states_history_buffer.append(g_state)
            if len(self.global_states_history_buffer)>MAX_GLOBAL_HISTORY_LEN: self.global_states_history_buffer.pop(0)
            hist = list(self.global_states_history_buffer);
            # Left-pad with zero rows so the flattened history always has the same length.
            while len(hist)<MAX_GLOBAL_HISTORY_LEN: hist.insert(0,[0.]*GLOBAL_SUMMARY_FEATURE_DIM)
            g_hist_flat=torch.tensor(hist,dtype=torch.float32,device=self.device).view(1,-1)
            # lp_t / en_t accumulate the log-probability and entropy of every sampled sub-action.
            actions = {}; lp_t, en_t = 0.0, 0.0
            l_te, l_mse, val_p, fne, base_state_embed = self.meta_agent(graph, g_hist_flat)
            # --- Phase 3: sample the edit (type first, then type-specific sub-actions) ---
            dist_te = Categorical(logits=l_te); s_te = dist_te.sample()
            actions['target_edit_type'] = s_te; lp_t += dist_te.log_prob(s_te); en_t += dist_te.entropy().mean()
            edit_type = s_te.item()
            conditional_logits = self.meta_agent.get_conditional_logits(base_state_embed, s_te)
            # Every edit type except ADD_LINEAR_BLOCK needs a target stage.
            if edit_type != EDIT_TYPE_ADD_LINEAR_BLOCK:
                l_ts = conditional_logits['stage']; dist_ts = Categorical(logits=l_ts); s_ts = dist_ts.sample()
                actions['target_loc_stage'] = s_ts; lp_t += dist_ts.log_prob(s_ts); en_t += dist_ts.entropy().mean()
            if edit_type == EDIT_TYPE_ADD_CONV_BLOCK:
                l_tc = conditional_logits['ch_mult']; dist_tc = Categorical(logits=l_tc); s_tc = dist_tc.sample()
                actions['target_conv_ch_mult_idx'] = s_tc; lp_t += dist_tc.log_prob(s_tc); en_t += dist_tc.entropy().mean()
            elif edit_type == EDIT_TYPE_RESIZE_LAYER:
                # Candidates: Conv2d/Linear ops in the stage that have at least one valid resize
                # factor and a graph node whose embedding is available in fne.
                stage=self.target_cnn.stages[s_ts.item()]; candidates=[]
                for oi,op in enumerate(stage.ops):
                    if isinstance(op, (nn.Conv2d, nn.Linear)):
                        valid_idx = _valid_resize_indices(op.out_channels if isinstance(op, nn.Conv2d) else op.out_features)
                        if valid_idx:
                            gid=op_map.get((s_ts.item(),oi),-1)
                            if gid!=-1 and gid<fne.size(0): candidates.append((oi,fne[gid],valid_idx))
                # Defaults when no candidate exists: op index -1 and the unmasked factor distribution.
                a_op_idx_rsz = -1; l_tr = conditional_logits['resize_factor']; dist_rf = Categorical(logits=l_tr)
                if candidates:
                    scores = self.meta_agent.head_resize_op_selector_scorer(torch.stack([e for _,e,_ in candidates])).squeeze(-1)
                    if scores.numel() > 0:
                        dist_co = Categorical(logits=scores); s_kth = dist_co.sample(); a_op_idx_rsz = candidates[s_kth.item()][0]
                        lp_t += dist_co.log_prob(s_kth); en_t += dist_co.entropy().mean()
                        valid_rf_idx = candidates[s_kth.item()][2]
                        # Restrict the factor choice to those valid for the chosen op; fall back to
                        # the unmasked logits if masking left nothing finite.
                        masked_l_tr = _mask_logits(l_tr.squeeze(0), valid_rf_idx)
                        dist_rf = Categorical(logits=masked_l_tr if not torch.all(torch.isinf(masked_l_tr)) else l_tr)
                s_rf_idx = dist_rf.sample(); lp_t += dist_rf.log_prob(s_rf_idx); en_t += dist_rf.entropy().mean()
                actions['target_resize_factor_idx'] = s_rf_idx; actions['target_actual_op_idx_in_stage'] = a_op_idx_rsz
            elif edit_type == EDIT_TYPE_ADD_SKIP:
                # Valid (source, destination) pairs: source strictly before destination with equal
                # output spatial size; source index -1 is passed to get_op_output_properties as-is.
                stage = self.target_cnn.stages[s_ts.item()]; valid_pairs = []
                for i in range(-1, len(stage.ops)):
                    for j in range(i + 1, len(stage.ops)):
                        _, src_sp = stage.get_op_output_properties(i); _, dest_sp = stage.get_op_output_properties(j)
                        if src_sp == dest_sp: valid_pairs.append((i, j))
                if valid_pairs:
                    # NOTE(review): for source -1 the node id comes from s_out_gid_map.get(stage-1, -1);
                    # if that lookup misses, fne[-1] selects the last row rather than failing — confirm.
                    source_gids = [s_out_gid_map.get(s_ts.item() - 1, -1) if p[0] == -1 else op_map[(s_ts.item(), p[0])] for p in valid_pairs]
                    dest_gids = [op_map[(s_ts.item(), p[1])] for p in valid_pairs]
                    source_scores = self.meta_agent.head_skip_source_scorer(fne[source_gids]).squeeze()
                    dest_scores = self.meta_agent.head_skip_destination_scorer(fne[dest_gids]).squeeze()
                    # A single pair squeezes to 0-dim; restore a length-1 vector for Categorical.
                    if source_scores.dim() == 0: source_scores = source_scores.unsqueeze(0); dest_scores = dest_scores.unsqueeze(0)
                    pair_scores = source_scores + dest_scores; dist_pair = Categorical(logits=pair_scores); chosen_pair_idx = dist_pair.sample()
                    actions['source_op_idx'], actions['dest_op_idx'] = valid_pairs[chosen_pair_idx.item()]
                    lp_t += dist_pair.log_prob(chosen_pair_idx); en_t += dist_pair.entropy()
                else: actions['source_op_idx'] = -1; actions['dest_op_idx'] = -1

            # One-line summary of the sampled edit.
            log_parts = [f"TargEdit:Typ={edit_type}"]
            if 'target_loc_stage' in actions: log_parts.append(f"Stg={actions['target_loc_stage'].item()}")
            if edit_type==EDIT_TYPE_ADD_CONV_BLOCK: log_parts.append(f"CHM={DISCRETE_CH_MULT_ADD[actions['target_conv_ch_mult_idx'].item()]}")
            elif edit_type==EDIT_TYPE_RESIZE_LAYER: log_parts.append(f"Op={actions.get('target_actual_op_idx_in_stage',-1)},RszF={DISCRETE_RESIZE_FACTORS[actions.get('target_resize_factor_idx', 0).item()]}")
            elif edit_type==EDIT_TYPE_ADD_SKIP: log_parts.append(f"Src={actions.get('source_op_idx', -1)}->Dest={actions.get('dest_op_idx', -1)}")

            print(" ".join(log_parts))

            # --- Phase 4: apply the edit and smoke-test it ---
            param_ratio = 1.0
            t_changed, new_bns = self._apply_target_cnn_edit(actions)

            if t_changed:
                is_edit_valid = self._validate_edit_with_dummy_pass()
                if not is_edit_valid:
                    print("  Edit rolled back due to dummy pass failure.")
                    self.target_cnn = pre_edit_model_backup
                    t_changed = False

                    self.consecutive_dummy_pass_failures += 1
                    if self.consecutive_dummy_pass_failures >= 3:
                        print("\n!!! INSTABILITY DETECTED: 3 consecutive edits failed dummy pass. !!!")
                        print("  Attempting to prune (simplify) Meta-Agent...")

                        pruning_actions = [META_EDIT_SHRINK_GNN_HIDDEN, META_EDIT_PRUNE_GNN, META_EDIT_PRUNE_MLP_HEAD]
                        valid_pruning_actions = [a for a in pruning_actions if a not in _invalid_meta_indices(self.meta_agent)]

                        if valid_pruning_actions:
                            chosen_self_edit = np.random.choice(valid_pruning_actions)
                            # NOTE(review): a successful self-edit rebuilds self.opt_meta here, before the
                            # meta_loss.backward()/opt_meta.step() further down, which still uses tensors
                            # produced by the pre-prune agent — confirm this ordering is intended.
                            self._apply_meta_self_edit(torch.tensor(chosen_self_edit, device=self.device))
                        else:
                            print("  Meta-Agent is at minimum complexity or below prune threshold. No pruning possible.")

                        self.consecutive_dummy_pass_failures = 0
                else:
                    self.consecutive_dummy_pass_failures = 0

            # --- Phase 5: fresh optimizer/scaler for whichever model is now current ---
            if t_changed:
                del pre_edit_model_backup
                gc.collect()
                new_params = sum(p.numel() for p in self.target_cnn.parameters() if p.requires_grad)
                param_ratio = new_params / max(1, old_params)
                print(f"  TargetCNN arch changed. New Params: {new_params:,} (Ratio: {param_ratio:.2f})")

                print("  Creating fresh optimizer for new architecture.")
                del self.opt_target, self.sched_target; gc.collect(); torch.cuda.empty_cache()
                self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
                self.amp_scaler = torch.amp.GradScaler(enabled=(self.device.type=='cuda'))

                # LR warmup is only enabled when the parameter count grew by more than 20%.
                if param_ratio > 1.2:
                    print(f"  Large parameter jump. Activating LR warmup.")
                    self.warmup_state = {'active': True, 'original_lr': LEARNING_RATE, 'param_ratio': param_ratio}
                else:
                    self.warmup_state['active'] = False
                self.frozen_bns = new_bns
            else:
                print("  Edit was invalid or a no-op. Restoring pre-edit model.")
                self.target_cnn = pre_edit_model_backup

                del self.opt_target, self.sched_target; gc.collect(); torch.cuda.empty_cache()
                self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
                self.amp_scaler = torch.amp.GradScaler(enabled=(self.device.type=='cuda'))

            # --- Phase 6: post-edit training; epoch budget grows with param_ratio, capped at 100 ---
            post_epochs = int(BASE_POST_EPOCHS * max(1.0, param_ratio))
            post_epochs = min(post_epochs, 100)
            print(f"  Training for {post_epochs} post-edit epochs.")
            self.sched_target = CosineAnnealingLR(self.opt_target, T_max=post_epochs)

            for post_ep in range(post_epochs):
                train_loss, train_acc = self._train_target_one_epoch(self.train_loader,self.opt_target, self.sched_target, f"PostE{post_ep+1}", current_epoch=post_ep)
                # Log every 5th epoch and the final one.
                if (post_ep + 1) % 5 == 0 or post_ep == post_epochs - 1:
                    print(f"  PostE{post_ep+1}: TrainL={train_loss:.4f}, TrainA={train_acc:.4f}")
            val_loss_a, acc_a = self._validate_target(self.val_loader)
            print(f"PostVal: ValidL={val_loss_a:.4f}, ValidA={acc_a:.4f}")

            # --- Phase 7: best-model tracking ---
            new_best_found = False
            if acc_a > self.best_global_accuracy:
                print(f"  *** New Best Global Accuracy! {self.best_global_accuracy:.4f} -> {acc_a:.4f} ***")
                self.best_global_accuracy = acc_a
                if self.best_global_model is not None:
                    del self.best_global_model
                self.best_global_model = copy.deepcopy(self.target_cnn)
                self.iterations_without_improvement = 0
                new_best_found = True
            else:
                self.iterations_without_improvement += 1
                print(f"  No improvement for {self.iterations_without_improvement} iterations.")

            # --- Phase 8: reward = accuracy gain (in percentage points) minus complexity penalty ---
            reward = 100 * (acc_a - acc_b)
            penalty = 0.0
            current_params = sum(p.numel() for p in self.target_cnn.parameters())
            if current_params > COMPLEXITY_PENALTY_THRESHOLD:
                # Quadratic in the excess parameter count, measured in millions.
                excess_params_M = (current_params - COMPLEXITY_PENALTY_THRESHOLD) / 1e6
                penalty = (excess_params_M ** 2) * COMPLEXITY_PENALTY_ALPHA

            final_reward = reward - penalty
            print(f"Reward: {reward:.4f} | Penalty: {penalty:.2f} | Final Reward: {final_reward:.4f}")

            # Actor-critic update: advantage uses the detached value prediction as baseline.
            advantage = final_reward - val_p.detach().squeeze().item()
            actor_loss = -(lp_t)*advantage
            critic_loss = F.mse_loss(val_p.squeeze(), torch.tensor(final_reward, device=self.device, dtype=torch.float32))
            meta_loss=actor_loss.mean()+ 0.5 * critic_loss - 0.0005*(en_t)
            if torch.isnan(meta_loss)or torch.isinf(meta_loss):
                # `continue` also skips the meta-LR update and the stagnation/revert checks below.
                print("MetaLoss NaN/Inf! Skipping meta-update."); gc.collect(); continue

            # Meta LR decays linearly from BASE_META_LR to MIN_META_LR as acc_a moves from
            # LR_ACC_BASE_THRESHOLD to LR_ACC_TARGET_THRESHOLD (progress clipped to [0, 1]).
            progress = (max(0, acc_a - LR_ACC_BASE_THRESHOLD)) / max(1e-6, LR_ACC_TARGET_THRESHOLD - LR_ACC_BASE_THRESHOLD)
            progress = min(1.0, progress)
            new_meta_lr = BASE_META_LR - progress * (BASE_META_LR - MIN_META_LR)
            for param_group in self.opt_meta.param_groups: param_group['lr'] = new_meta_lr

            self.opt_meta.zero_grad(); meta_loss.backward(); clip_grad_norm_(self.meta_agent.parameters(),MAX_GRAD_NORM); self.opt_meta.step()
            print(f"MetaL:{meta_loss.item():.4f}(A:{actor_loss.mean().item():.4f},C:{critic_loss.item():.4f},E:{en_t.item():.4f}) | MetaLR: {new_meta_lr:.2e}")

            # --- Phase 9: stagnation handling (revert target, grow meta-agent) ---
            if self.iterations_without_improvement >= 5:
                print("\n!!! STAGNATION DETECTED: 5 iterations without improvement. !!!")

                print(f"  Reverting Target CNN to best known model (Acc: {self.best_global_accuracy:.4f}).")
                if self.best_global_model is not None:
                    self.target_cnn = copy.deepcopy(self.best_global_model)
                    del self.opt_target; gc.collect()
                    self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

                print("  Attempting to upgrade (grow) Meta-Agent...")
                growth_actions = [META_EDIT_DEEPEN_GNN, META_EDIT_WIDEN_GNN_HIDDEN, META_EDIT_DEEPEN_MLP_HEAD]
                valid_growth_actions = [a for a in growth_actions if a not in _invalid_meta_indices(self.meta_agent)]

                if valid_growth_actions:
                    chosen_self_edit = np.random.choice(valid_growth_actions)
                    self._apply_meta_self_edit(torch.tensor(chosen_self_edit, device=self.device))
                else:
                    print("  Meta-Agent cannot grow further. No self-edit possible.")

                self.iterations_without_improvement = 0
                print("  Stagnation counter reset. Continuing search with upgraded agent.\n")

            # --- Phase 10: hard revert when accuracy fell well below the best ---
            revert_threshold = 0.04
            if self.best_global_model is not None and acc_a < (self.best_global_accuracy - revert_threshold) and not new_best_found:
                print(f"  !!! Accuracy dropped by >{revert_threshold:.0%}. Reverting to the global best model (Acc: {self.best_global_accuracy:.4f}). !!!")

                del self.target_cnn; gc.collect()
                self.target_cnn = copy.deepcopy(self.best_global_model)

                del self.opt_target, self.sched_target; gc.collect()
                self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
                self.sched_target = CosineAnnealingLR(self.opt_target, T_max=BASE_POST_EPOCHS)
                self.amp_scaler = torch.amp.GradScaler(enabled=(self.device.type=='cuda'))

            gc.collect()

Enabling TensorFloat32 matmul precision for supported GPU.


In [6]:
#
# ==============================================================================
#  ENTRY POINT 1: STAGE 1 - BROAD SEARCH AND SAVING (Saves both models)
# ==============================================================================
#
if __name__ == '__main__':
    # Stage 1 driver: build a fresh DEITI system, run the broad search, then
    # persist the best target CNN, the final meta-agent and the best accuracy.
    if not PYG_AVAILABLE:
        print("Exiting: PyTorch Geometric is required for this script.")
        exit()

    torch.manual_seed(420)
    np.random.seed(420)
    print(f"Using device: {DEVICE}")

    # --- Mount Google Drive for saving checkpoints ---
    # Falls back to a local directory when google.colab cannot be imported.
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        SAVE_DIR = "/content/drive/My Drive/DEITI_Checkpoints"
    except ImportError:
        print("Not running in Google Colab. Models will be saved locally to './checkpoints'.")
        SAVE_DIR = "./checkpoints"

    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f"Model checkpoints will be saved to: {SAVE_DIR}")

    deiti_system = DEITI()

    # --- STAGE 1: Broad Search ---
    print("\n\n" + "="*20 + " STAGE 1: BROAD SEARCH " + "="*20)
    try:
        print(f"Init TargetCNN P: {sum(p.numel() for p in deiti_system.target_cnn.parameters() if p.requires_grad):,}")
        print(f"Init MetaAgentGNN: L={deiti_system.meta_agent.current_num_gnn_layers},H={deiti_system.meta_agent.current_gnn_hidden_dim},P={sum(p.numel() for p in deiti_system.meta_agent.parameters()):,}")

        print("\n--- Initial Validation ---")
        initial_loss, initial_acc = deiti_system._validate_target(deiti_system.val_loader)
        print(f"InitVal: L={initial_loss:.4f},A={initial_acc:.4f}")

        # Seed the best-model trackers with the untrained network so the search
        # loop always has something to revert to.
        deiti_system.best_global_accuracy = initial_acc
        deiti_system.best_global_model = copy.deepcopy(deiti_system.target_cnn)

        deiti_system.train_loop(iterations=150)
    except Exception as e:
        print(f"\n\nFATAL RUNTIME ERROR in STAGE 1: {e}")
        traceback.print_exc()
        print("Stopping execution.")
        # Bare raise re-raises the active exception with its original traceback.
        raise

    # --- Save the Stage 1 checkpoint ---
    print("\n\n" + "="*20 + " SAVING STAGE 1 CHECKPOINT " + "="*20)
    target_cnn_path = os.path.join(SAVE_DIR, "target_cnn_model_fashion.pth")
    meta_agent_path = os.path.join(SAVE_DIR, "meta_agent_model_fashion.pth") # New path for meta-agent
    acc_path = os.path.join(SAVE_DIR, "best_accuracy_fashion.txt")

    try:
        # Explicit None check (matching train_loop) instead of relying on the
        # truthiness of an nn.Module.
        if deiti_system.best_global_model is not None:
            print(f"Saving best TargetCNN from search with accuracy: {deiti_system.best_global_accuracy:.4f}")
            # Whole-object save: the file is a pickle of the module, so loading
            # it later needs the same class definitions to be importable.
            torch.save(deiti_system.best_global_model, target_cnn_path)
            print(f"Successfully saved Stage 1 best TargetCNN object to: {target_cnn_path}")

            # Save the final meta-agent from the end of the search
            print("Saving final Meta-Agent state...")
            torch.save(deiti_system.meta_agent, meta_agent_path)
            print(f"Successfully saved Meta-Agent object to: {meta_agent_path}")

            # Save the best validation accuracy from the search loop
            with open(acc_path, "w") as f:
                f.write(str(deiti_system.best_global_accuracy))
            print(f"Saved best accuracy ({deiti_system.best_global_accuracy:.4f}) to: {acc_path}")
        else:
            print("No best model was tracked during Stage 1. Cannot save checkpoint.")

    except Exception as e:
        # Saving is best-effort: report the failure but do not abort the cell.
        print(f"!!! FAILED to save checkpoint: {e} !!!")
        traceback.print_exc()

Using device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Model checkpoints will be saved to: /content/drive/My Drive/DEITI_Checkpoints
Successfully loaded Fashion-MNIST dataset.
Using 4 workers for data loading (persistent: True).


Init TargetCNN P: 372,682
Init MetaAgentGNN: L=2,H=32,P=3,059

--- Initial Validation ---
InitVal: L=2.3055,A=0.1064

===== Iteration 1/150 =====
Current MetaAgentGNN: GNN Layers=2, GNN Hidden=32, Params: 3,059
Current Global Best Accuracy: 0.1064
PreVal: ValidL=0.7075, ValidA=0.7444
TargEdit:Typ=0 Stg=0 CHM=2.0
  TargetCNN arch changed. New Params: 520,394 (Ratio: 1.40)
  Creating fresh optimizer for new architecture.
  Large parameter jump. Activating LR warmup.
  Training for 34 post-edit epochs.
  PostE5: TrainL=0.6628, TrainA=0.7585
  Warmup complete. Restoring LR to 1.00e-03.
  Unfreezing 1 new BatchNorm layers.
  PostE10: TrainL=0.5384, TrainA=0.8041
  PostE15:

In [None]:
#
# ==============================================================================
#  ENTRY POINT 2: STAGE 2 - LOADING AND FOCUSED, REPEATABLE SEARCH
# ==============================================================================
#
if __name__ == '__main__':
    # Stage 2 driver: reload the checkpoint written by a previous run, plug the
    # loaded models into a fresh DEITI system, continue the search, and write
    # the resulting best model back to the same checkpoint files.
    if not PYG_AVAILABLE:
        print("Exiting: PyTorch Geometric is required for this script.")
        exit()

    torch.manual_seed(1337) # Use a different seed for Stage 2 if desired
    np.random.seed(1337)
    print(f"Using device: {DEVICE}")

    # --- Mount Google Drive to load checkpoints ---
    # Falls back to a local directory when google.colab cannot be imported.
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        SAVE_DIR = "/content/drive/My Drive/DEITI_Checkpoints"
    except ImportError:
        print("Not running in Google Colab. Loading models from local './checkpoints'.")
        SAVE_DIR = "./checkpoints"

    if not os.path.exists(SAVE_DIR):
        print(f"ERROR: Save directory '{SAVE_DIR}' does not exist. Cannot load model.")
        exit()

    # --- Define checkpoint paths ---
    target_cnn_path = os.path.join(SAVE_DIR, "target_cnn_model_fashion.pth")
    meta_agent_path = os.path.join(SAVE_DIR, "meta_agent_model_fashion.pth") # New path for meta-agent
    acc_path = os.path.join(SAVE_DIR, "best_accuracy_fashion.txt")

    # --- Load the saved models and accuracy from previous run ---
    print("\n\n" + "="*20 + " LOADING PREVIOUS BEST MODELS " + "="*20)
    best_target_cnn_previous = None
    loaded_meta_agent = None
    previous_best_acc = 0.0

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. Only load checkpoint files that this notebook wrote.
        # Load the TargetCNN
        best_target_cnn_previous = torch.load(target_cnn_path, map_location=DEVICE, weights_only=False)
        best_target_cnn_previous.to(DEVICE)
        print(f"Successfully loaded TargetCNN from: {target_cnn_path}")

        # Load the Meta-Agent
        loaded_meta_agent = torch.load(meta_agent_path, map_location=DEVICE, weights_only=False)
        loaded_meta_agent.to(DEVICE)
        print(f"Successfully loaded Meta-Agent from: {meta_agent_path}")

        # Load the accuracy
        with open(acc_path, "r") as f:
            previous_best_acc = float(f.read())
        print(f"Loaded previous best accuracy: {previous_best_acc:.4f}")

    except FileNotFoundError as e:
        print(f"!!! CHECKPOINT FILE NOT FOUND: {e}. Cannot proceed. !!!")
        exit()
    except Exception as e:
        print(f"!!! FAILED to load checkpoint: {e} !!!")
        traceback.print_exc()
        exit()

    # --- STAGE 2: Focused Search ---
    print("\n\n" + "="*20 + " STAGE 2: FOCUSED SEARCH " + "="*20)

    # Initialize a new DEITI system
    deiti_system_stage2 = DEITI()

    # Replace the default models with the loaded ones
    del deiti_system_stage2.target_cnn
    del deiti_system_stage2.meta_agent
    gc.collect()
    deiti_system_stage2.target_cnn = best_target_cnn_previous
    deiti_system_stage2.meta_agent = loaded_meta_agent

    # Reset optimizer, as it's tied to the new model parameters
    # (train_loop re-creates opt_target lazily when it is None).
    deiti_system_stage2.opt_target = None
    deiti_system_stage2.opt_meta = optim.Adam(deiti_system_stage2.meta_agent.parameters(), lr=BASE_META_LR) # Re-create optimizer for loaded agent

    print("Resetting global best trackers for this search session.")
    # The stored accuracy is trusted as-is; the loaded model is not re-validated here.
    deiti_system_stage2.best_global_accuracy = previous_best_acc
    if deiti_system_stage2.best_global_model is not None:
        del deiti_system_stage2.best_global_model
    deiti_system_stage2.best_global_model = copy.deepcopy(deiti_system_stage2.target_cnn)

    print(f"Starting Stage 2 with TargetCNN params: {sum(p.numel() for p in deiti_system_stage2.target_cnn.parameters()):,}, Acc: {previous_best_acc:.4f}")
    print(f"Continuing with Meta-Agent params: {sum(p.numel() for p in deiti_system_stage2.meta_agent.parameters()):,}")

    # Run the training loop for Stage 2
    try:
        deiti_system_stage2.train_loop(iterations=150)
    except Exception as e:
        print(f"\n\nFATAL RUNTIME ERROR in STAGE 2: {e}")
        traceback.print_exc()
        print("Stopping execution.")
        # Bare raise re-raises the active exception with its original traceback.
        raise

    # --- Save the final best model and accuracy from Stage 2 ---
    print("\n\n" + "="*20 + " SAVING FINAL CHECKPOINT FOR THIS STAGE " + "="*20)

    try:
        # Explicit None check (matching train_loop) instead of relying on the
        # truthiness of an nn.Module.
        if deiti_system_stage2.best_global_model is not None:
            print(f"Saving final best TargetCNN with accuracy: {deiti_system_stage2.best_global_accuracy:.4f}")
            torch.save(deiti_system_stage2.best_global_model, target_cnn_path)
            print(f"Successfully saved final TargetCNN object to: {target_cnn_path}")

            # Save the final meta-agent state from this run
            print("Saving final Meta-Agent state...")
            torch.save(deiti_system_stage2.meta_agent, meta_agent_path)
            print(f"Successfully saved Meta-Agent object to: {meta_agent_path}")

            # Overwrite the accuracy file with the new best accuracy for the next run
            with open(acc_path, "w") as f:
                f.write(str(deiti_system_stage2.best_global_accuracy))
            print(f"Updated accuracy file for next run: {acc_path}")
        else:
            print("No better model was found in this run. Checkpoint remains unchanged.")

    except Exception as e:
        # Saving is best-effort: report the failure but do not abort the cell.
        print(f"!!! FAILED to save final checkpoint: {e} !!!")
        traceback.print_exc()