<a href="https://colab.research.google.com/github/PlushyWushy/Prometheus/blob/main/Prometheus_Variation_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch_geometric



In [None]:
import time
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.distributions import Categorical
from torch.amp import autocast
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
import inspect # For debugging
import logging
import gc # Garbage Collector interface
import itertools

# Keep the torch.compile machinery quiet: only FATAL records get through.
for _noisy_logger_name in ("torch._dynamo", "torch._inductor"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.FATAL)
del _noisy_logger_name

# PyTorch Geometric is optional at import time; PYG_AVAILABLE records whether
# the GNN-related names below (pyg_nn, Data, to_undirected, PyGLinear) exist.
try:
    import torch_geometric.nn as pyg_nn
    from torch_geometric.data import Data
    from torch_geometric.utils import to_undirected
    from torch_geometric.nn.dense.linear import Linear as PyGLinear
except ImportError:
    PYG_AVAILABLE = False
    print("PyTorch Geometric not found. GNN MetaAgent will not be available if used.")
else:
    PYG_AVAILABLE = True

# =======================================================================
#  Constants & Configuration
# =======================================================================
# --- Epoch / batch settings ---
PRE_EPOCHS = 1
POST_EPOCHS = 10
BATCHES_PER_EPOCH = None
BATCH_SIZE = 512

# --- Optimisation hyper-parameters ---
LEARNING_RATE = 1e-3
META_LEARNING_RATE = 5e-4
MAX_GRAD_NORM = 1.0

# --- Compute device ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if DEVICE.type == 'cuda':
    torch.backends.cudnn.benchmark = True
    # TF32 matmul is only switched on for compute capability >= 8.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        print("Enabling TensorFloat32 matmul precision for supported GPU.")
        torch.set_float32_matmul_precision('high')

# --- Action Space with Skip Connections ---
EDIT_TYPE_ADD_CONV2D = 0
EDIT_TYPE_ADD_RELU = 1
EDIT_TYPE_RESIZE_CNN_CONV = 2
EDIT_TYPE_ADD_SKIP = 3
NUM_TARGET_CNN_EDIT_TYPES = 4
DISCRETE_KERNEL_SIZES = [1, 3]
DISCRETE_CH_MULT_ADD = [0.5, 1.0, 2.0]
DISCRETE_RESIZE_FACTORS = [0.25, 0.5, 0.75, 1.25, 1.5, 2.0, 3.0, 4.0]
NUM_STAGES_TARGET_CNN = 3

# --- MetaAgent & RL Configuration ---
META_EDIT_NONE = 0
META_EDIT_DEEPEN_GNN = 1
META_EDIT_WIDEN_GNN_HIDDEN = 2
META_EDIT_DEEPEN_MLP_HEAD = 3
NUM_META_SELF_EDIT_TYPES = 4
META_SELF_EDIT_INTERVAL = 5
META_GNN_WIDEN_FACTOR = 2
# A cap of None means the corresponding limit check is skipped.
MAX_GNN_LAYERS = None
MAX_GNN_HIDDEN_DIM = None
MAX_MLP_HEAD_SEQUENTIAL_DEPTH = None
INITIAL_GNN_HIDDEN_DIM = 32
INITIAL_NUM_GNN_LAYERS = 2
EDIT_TYPE_EMBED_DIM = 16

# --- Graph & State Representation ---
OP_TYPE_IDS = {'conv2d': 1, 'relu': 2, 'maxpool2d': 3, 'add': 8, 'input_placeholder': 5, 'none': 0}
NODE_FEATURE_DIM = 7
NORMALIZATION_WIDTH_DIVISOR = 512.0
NORMALIZATION_IDX_DIVISOR = 50.0
NORMALIZATION_SPATIAL_DIVISOR = 32.0
GLOBAL_SUMMARY_FEATURE_DIM = 3 + 2
MAX_GLOBAL_HISTORY_LEN = 5

# =======================================================================
#  Custom Modules for Dynamic Graph
# =======================================================================
class AddWithProjection(nn.Module):
    """Sum a primary tensor with a skip tensor after projecting the skip.

    The skip input is passed through ``projection`` (``nn.Identity`` when no
    module is supplied) and, if its trailing spatial dimensions differ from the
    primary input's, bilinearly resized to match before the addition.
    """

    def __init__(self, projection_module=None):
        super().__init__()
        if projection_module is None:
            projection_module = nn.Identity()
        self.projection = projection_module

    def forward(self, x_primary, x_skip):
        """Return ``x_primary + projection(x_skip)``, resizing the skip if needed."""
        skip = self.projection(x_skip)
        target_hw = x_primary.shape[2:]
        if target_hw != skip.shape[2:]:
            # Spatial sizes disagree: bring the skip branch to the primary's size.
            skip = F.interpolate(skip, size=target_hw, mode='bilinear', align_corners=False)
        return x_primary + skip

# =======================================================================
#  Net2Net Utilities & Masking Helpers
# =======================================================================
def _valid_resize_indices(old_oc: int):
    valid = []
    for idx, f_resize in enumerate(DISCRETE_RESIZE_FACTORS):
        new_oc_resize = max(1, int(round(old_oc * f_resize)))
        if new_oc_resize != old_oc: valid.append(idx)
    if not valid and old_oc > 0 :
        try: valid.append(DISCRETE_RESIZE_FACTORS.index(1.0))
        except ValueError: pass
    if not valid: valid.append(0)
    return valid

def _mask_logits(logits: torch.Tensor, valid_indices: list):
    mask_val = torch.full_like(logits, float('-inf'))
    if valid_indices:
        valid_indices_tensor = torch.tensor(valid_indices, device=logits.device, dtype=torch.long)
        if valid_indices_tensor.numel() > 0:
            valid_indices_tensor = valid_indices_tensor[valid_indices_tensor < logits.shape[-1]]
            if valid_indices_tensor.numel() > 0:
                 mask_val[..., valid_indices_tensor] = 0.0
    return logits + mask_val

def _invalid_meta_indices(agent_meta_invalid):
    """List the meta self-edit types that the agent's size caps currently forbid.

    Each cap (``MAX_GNN_LAYERS``, ``MAX_GNN_HIDDEN_DIM``,
    ``MAX_MLP_HEAD_SEQUENTIAL_DEPTH``) is only enforced when it is not ``None``.
    MLP-head deepening is blocked when the agent reports no eligible heads or
    when none of them is still below the depth cap.
    """
    blocked = []

    layers_capped = (
        MAX_GNN_LAYERS is not None
        and agent_meta_invalid.current_num_gnn_layers >= MAX_GNN_LAYERS
    )
    if layers_capped:
        blocked.append(META_EDIT_DEEPEN_GNN)

    width_capped = (
        MAX_GNN_HIDDEN_DIM is not None
        and agent_meta_invalid.current_gnn_hidden_dim >= MAX_GNN_HIDDEN_DIM
    )
    if width_capped:
        blocked.append(META_EDIT_WIDEN_GNN_HIDDEN)

    if MAX_MLP_HEAD_SEQUENTIAL_DEPTH is not None:
        head_names = agent_meta_invalid.get_mlp_head_names() or ()
        # A head missing from the counter dict is treated as depth 1.
        some_head_can_grow = any(
            agent_meta_invalid.head_depth_counters.get(head, 1) < MAX_MLP_HEAD_SEQUENTIAL_DEPTH
            for head in head_names
        )
        if not some_head_can_grow:
            blocked.append(META_EDIT_DEEPEN_MLP_HEAD)

    return blocked

def net2wider_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` whose output-channel count is scaled by ``factor``.

    The new width is ``max(1, round(out_channels * factor))``. Output filters
    (and biases) are taken from the source conv cyclically: new channel ``i``
    copies source channel ``i % out_channels``, so the first ``out_channels``
    filters are preserved and extra channels are replicas. When the factor
    shrinks the layer this reduces to keeping the leading filters.

    Unlike a bare ``nn.Conv2d(...)`` rebuild, ``dilation`` and ``padding_mode``
    are carried over so the resized layer keeps the source's receptive field
    and padding behaviour.

    Args:
        conv: Source convolution.
        factor: Multiplicative change in output channels.
        device: Device the new module is placed on.

    Returns:
        ``conv`` itself when the width does not change, otherwise a new
        ``nn.Conv2d`` on ``device``. Layers consuming this conv's output are
        not adjusted here.
    """
    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    if new_oc == old_oc:
        return conv

    has_bias = conv.bias is not None
    widened = nn.Conv2d(
        conv.in_channels,
        new_oc,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=has_bias,
        padding_mode=conv.padding_mode,
    ).to(device)

    with torch.no_grad():
        if old_oc > 0:
            # Cyclic source map: identity on the first old_oc rows, then wraps.
            src_rows = torch.arange(new_oc, device=conv.weight.device) % old_oc
            widened.weight.copy_(conv.weight.detach()[src_rows])
            if has_bias:
                widened.bias.copy_(conv.bias.detach()[src_rows])
        else:
            # Nothing to copy from: start from zeros, as before.
            widened.weight.zero_()
            if has_bias:
                widened.bias.zero_()
    return widened

def net2thinner_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` keeping only its leading output channels.

    The new width is ``max(1, round(out_channels * factor))``; the first
    ``new_oc`` filters (and bias entries) are copied from the source.

    Fixes relative to a plain rebuild:
      * ``dilation`` and ``padding_mode`` are carried over.
      * Parameters are copied *into* the new module instead of rebinding
        ``.data``, so they stay on ``device`` and always have the shape the
        module declares. If a factor > 1 is passed by mistake, the overlapping
        filters are copied and the extra ones keep their default
        initialisation rather than leaving a weight whose shape disagrees with
        ``out_channels``.

    Args:
        conv: Source convolution.
        factor: Multiplicative change in output channels (normally < 1).
        device: Device the new module is placed on.

    Returns:
        ``conv`` itself when the width does not change, otherwise a new
        ``nn.Conv2d`` on ``device``.
    """
    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    if new_oc == old_oc:
        return conv

    has_bias = conv.bias is not None
    thinned = nn.Conv2d(
        conv.in_channels,
        new_oc,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=has_bias,
        padding_mode=conv.padding_mode,
    ).to(device)

    keep = min(old_oc, new_oc)
    with torch.no_grad():
        if keep > 0:
            thinned.weight[:keep].copy_(conv.weight.detach()[:keep])
            if has_bias:
                thinned.bias[:keep].copy_(conv.bias.detach()[:keep])
    return thinned

def resize_conv_output(conv: nn.Conv2d, factor: float, device='cpu') -> nn.Conv2d:
    """Scale ``conv``'s output-channel count by ``factor``, widening or thinning.

    Returns ``conv`` unchanged when ``factor`` is (numerically) 1 or when the
    rounded target width equals the current width. Otherwise the exact ratio
    ``target / current`` is forwarded to ``net2wider_conv_output`` or
    ``net2thinner_conv_output`` so the helper lands on the same target width.
    """
    if abs(factor - 1.0) < 1e-6:
        return conv

    current = conv.out_channels
    target = max(1, int(round(current * factor)))
    if target == current:
        return conv

    # With a zero-width source there is no ratio to compute; pass factor through.
    ratio = float(target) / current if current > 0 else factor
    if target > current:
        return net2wider_conv_output(conv, ratio, device)
    return net2thinner_conv_output(conv, ratio, device)

def adapt_conv_input_channels(conv: nn.Conv2d, new_in_channels: int, device='cpu') -> nn.Conv2d:
    """Return a copy of ``conv`` that accepts ``new_in_channels`` input channels.

    Group handling:
      * depthwise convs (``groups == in_channels`` with ``groups > 1``) stay
        depthwise, i.e. ``groups`` follows the new input width;
      * other grouped convs fall back to ``groups=1`` when the new width is not
        divisible by the old group count;
      * ordinary ``groups=1`` convs stay ungrouped. A one-input-channel
        ``groups=1`` conv is *not* treated as depthwise (it used to be, which
        regrouped the layer and raised whenever ``out_channels`` was not
        divisible by the new width).

    Weight handling:
      * ``groups=1`` convs keep their filters: new input channel ``i`` copies
        source channel ``i % old_in``; channels beyond the original width are
        additionally divided by ``max(1, new_in / old_in)``;
      * every grouped case is re-initialised with Kaiming-normal weights.
    The bias is copied unchanged. ``dilation`` and ``padding_mode`` are carried
    over to the new module.

    Args:
        conv: Source convolution.
        new_in_channels: Desired input width (clamped to at least 1).
        device: Device the new module is placed on.

    Returns:
        ``conv`` itself when the width already matches, otherwise a new
        ``nn.Conv2d`` on ``device``.

    Raises:
        ValueError: Propagated from ``nn.Conv2d`` when a depthwise conv's
            ``out_channels`` is not divisible by the new group count.
    """
    if conv.in_channels == new_in_channels:
        return conv

    new_in_channels = max(1, new_in_channels)
    old_ic = conv.in_channels
    old_groups = conv.groups

    is_depthwise = old_groups > 1 and old_groups == old_ic
    new_groups = old_groups
    if is_depthwise:
        new_groups = new_in_channels
    elif old_groups != 1 and new_in_channels % old_groups != 0:
        new_groups = 1

    has_bias = conv.bias is not None
    adapted = nn.Conv2d(
        new_in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=new_groups,
        bias=has_bias,
        padding_mode=conv.padding_mode,
    ).to(device)

    with torch.no_grad():
        if old_groups == 1 and old_ic > 0:
            # Cyclic copy along the input-channel axis in one indexing op.
            src_cols = torch.arange(new_in_channels, device=conv.weight.device) % old_ic
            staged = conv.weight.detach()[:, src_cols].clone()
            if new_in_channels > old_ic:
                # Replicated channels are damped; the original ones are left as-is.
                staged[:, old_ic:] /= max(1.0, (new_in_channels / old_ic))
            adapted.weight.copy_(staged)
        else:
            nn.init.kaiming_normal_(adapted.weight, mode='fan_in', nonlinearity='relu')

        if has_bias:
            adapted.bias.copy_(conv.bias.detach())
    return adapted

def net2deeper_linear_insert_identity(head_module_owner: nn.Module, head_name: str, device='cpu'):
    """Deepen the head stored at ``head_module_owner.<head_name>`` by one identity Linear.

    * A bare ``nn.Linear`` becomes ``Sequential(original, identity)`` where the
      identity layer is square with side ``original.out_features``.
    * An ``nn.Sequential`` ending in ``nn.Linear`` gets an identity layer of
      side ``last.in_features`` inserted just before its final layer.

    The identity layer has an identity weight matrix and zero bias. The
    attribute is replaced in place on the owner.

    Returns:
        ``True`` when the head was deepened, ``False`` (after printing a
        message) when the attribute is missing, has an unsupported type, or has
        a non-positive dimension.
    """

    def _identity_linear(width):
        # Square Linear initialised to the identity map with zero bias.
        layer = nn.Linear(width, width, bias=True).to(device)
        with torch.no_grad():
            layer.weight.data.copy_(torch.eye(width, device=device))
            layer.bias.data.fill_(0)
        return layer

    if not hasattr(head_module_owner, head_name):
        print(f"Err: Attr {head_name} not found for deepening")
        return False

    original_component = getattr(head_module_owner, head_name)

    if isinstance(original_component, nn.Linear):
        identity_dim = original_component.out_features
        if identity_dim <= 0:
            print(f"Cannot insert identity for dim {identity_dim} in Linear head {head_name}")
            return False
        deepened = nn.Sequential(original_component, _identity_linear(identity_dim)).to(device)
        setattr(head_module_owner, head_name, deepened)
        print(f"  Deepened MLP head '{head_name}' (Linear -> Sequential)")
        return True

    if isinstance(original_component, nn.Sequential):
        ends_in_linear = bool(original_component) and isinstance(original_component[-1], nn.Linear)
        if not ends_in_linear:
            print(f"Cannot deepen Seq head '{head_name}', last not Linear.")
            return False
        final_layer = original_component[-1]
        identity_dim = final_layer.in_features
        if identity_dim <= 0:
            print(f"Cannot insert identity for dim {identity_dim} in Seq head {head_name}")
            return False
        stacked = list(original_component[:-1])
        stacked.append(_identity_linear(identity_dim))
        stacked.append(final_layer)
        setattr(head_module_owner, head_name, nn.Sequential(*stacked).to(device))
        print(f"  Deepened Seq head '{head_name}'.")
        return True

    print(f"Err: Head '{head_name}' type {type(original_component)} not Linear/Seq for deepening.")
    return False

def _get_gcn_conv_linear_submodule(gcn_layer):
    if hasattr(gcn_layer, 'lin') and (isinstance(gcn_layer.lin, nn.Linear) or (PYG_AVAILABLE and isinstance(gcn_layer.lin, PyGLinear))):
        return gcn_layer.lin
    return None

if PYG_AVAILABLE:
    def widen_gcn_conv_hidden(gcn_layer: pyg_nn.GCNConv, new_hidden_dim: int, prev_layer_out_dim: int, device='cpu'):
        """Rebuild ``gcn_layer`` with ``new_hidden_dim`` output channels.

        Rows of the internal linear weight (and entries of the GCN bias) are
        carried over; when growing, extra rows/entries replicate the source
        cyclically (``r % old_hidden_dim``).

        Returns:
            ``(layer, changed)``: the original layer and ``False`` when the
            width already matches, otherwise the new layer and ``True``. If the
            linear submodule cannot be located on either layer, the freshly
            initialised layer is returned as-is.
        """
        old_width = gcn_layer.out_channels
        if new_hidden_dim == old_width:
            return gcn_layer, False

        resized = pyg_nn.GCNConv(
            prev_layer_out_dim,
            new_hidden_dim,
            bias=(gcn_layer.bias is not None),
            improved=gcn_layer.improved,
            add_self_loops=gcn_layer.add_self_loops,
            normalize=gcn_layer.normalize,
        ).to(device)

        with torch.no_grad():
            src_lin = _get_gcn_conv_linear_submodule(gcn_layer)
            dst_lin = _get_gcn_conv_linear_submodule(resized)
            if src_lin is None or dst_lin is None:
                return resized, True

            shared = min(old_width, new_hidden_dim)
            grows = new_hidden_dim > old_width and old_width > 0

            # Stage the new weight in a zero buffer, then copy it in one go.
            src_w = src_lin.weight.data
            staged_w = torch.zeros_like(dst_lin.weight.data)
            staged_w[:shared, :] = src_w[:shared, :].clone()
            if grows:
                for row in range(old_width, new_hidden_dim):
                    staged_w[row, :] = src_w[row % old_width, :].clone()
            dst_lin.weight.data.copy_(staged_w)

            if gcn_layer.bias is not None and resized.bias is not None:
                src_b = gcn_layer.bias.data
                staged_b = torch.zeros_like(resized.bias.data)
                staged_b[:shared] = src_b[:shared].clone()
                if grows:
                    for pos in range(old_width, new_hidden_dim):
                        staged_b[pos] = src_b[pos % old_width].clone()
                resized.bias.data.copy_(staged_b)
            elif resized.bias is not None:
                nn.init.zeros_(resized.bias.data)
        return resized, True

    def adapt_gcn_conv_input_dim(gcn_layer: pyg_nn.GCNConv, new_input_dim: int, device='cpu'):
        """Rebuild ``gcn_layer`` so that it accepts ``new_input_dim`` input features.

        Columns of the internal linear weight are carried over; when growing,
        extra columns replicate the source cyclically (``c % old_input_dim``)
        and are divided by ``max(1, new_input_dim / old_input_dim)``. The GCN
        bias is copied unchanged.

        Returns:
            ``(layer, changed)``: the original layer and ``False`` when the
            input width already matches, otherwise the new layer and ``True``.
            If the linear submodule cannot be located on either layer, the
            freshly initialised layer is returned as-is.
        """
        old_in = gcn_layer.in_channels
        if new_input_dim == old_in:
            return gcn_layer, False

        rebuilt = pyg_nn.GCNConv(
            new_input_dim,
            gcn_layer.out_channels,
            bias=(gcn_layer.bias is not None),
            improved=gcn_layer.improved,
            add_self_loops=gcn_layer.add_self_loops,
            normalize=gcn_layer.normalize,
        ).to(device)

        with torch.no_grad():
            src_lin = _get_gcn_conv_linear_submodule(gcn_layer)
            dst_lin = _get_gcn_conv_linear_submodule(rebuilt)
            if src_lin is None or dst_lin is None:
                return rebuilt, True

            if old_in > 0:
                src_w = src_lin.weight.data
                staged_w = torch.zeros_like(dst_lin.weight.data)
                shared = min(old_in, new_input_dim)
                staged_w[:, :shared] = src_w[:, :shared].clone()
                if new_input_dim > old_in:
                    for col in range(old_in, new_input_dim):
                        staged_w[:, col] = src_w[:, col % old_in].clone()
                        staged_w[:, col] /= max(1.0, (new_input_dim / old_in))
                dst_lin.weight.data.copy_(staged_w)

            if gcn_layer.bias is not None and rebuilt.bias is not None:
                rebuilt.bias.data.copy_(gcn_layer.bias.data)
            elif rebuilt.bias is not None:
                nn.init.zeros_(rebuilt.bias.data)
        return rebuilt, True

    def create_identity_gcn_layer(dim: int, device='cpu', **gcn_kwargs):
        """Build a ``dim -> dim`` GCNConv whose linear weight is the identity matrix.

        Recognised ``gcn_kwargs``: ``bias``, ``normalize``, ``add_self_loops``
        (all default ``True``) and ``improved`` (default ``False``). All biases
        that exist are zeroed.

        NOTE(review): only the linear transform is an identity; the layer's
        graph propagation presumably still mixes neighbouring nodes, so the
        layer as a whole is not guaranteed to be an identity map - confirm.
        """
        layer = pyg_nn.GCNConv(
            dim,
            dim,
            bias=gcn_kwargs.get('bias', True),
            normalize=gcn_kwargs.get('normalize', True),
            add_self_loops=gcn_kwargs.get('add_self_loops', True),
            improved=gcn_kwargs.get('improved', False),
        ).to(device)

        with torch.no_grad():
            lin = _get_gcn_conv_linear_submodule(layer)
            if lin is not None:
                rows, cols = lin.weight.shape[0], lin.weight.shape[1]
                if rows == cols:
                    lin.weight.data.copy_(torch.eye(dim, device=device))
                else:
                    nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
                if hasattr(lin, 'bias') and lin.bias is not None:
                    lin.bias.data.fill_(0.0)
            if layer.bias is not None:
                layer.bias.data.fill_(0.0)
        return layer

# =======================================================================
#  Dynamic Models (TargetCNN)
# =======================================================================
class DynamicStageModule(nn.Module):
    """An editable stage: an ordered list of op modules plus one metadata dict per op.

    Each description dict may carry ``'type'``, ``'out_channels'``,
    ``'out_spatial'`` and ``'input_indices'``. Input indices are local to the
    stage: ``-1`` denotes the stage input and ``k >= 0`` the output of op ``k``.
    """

    def __init__(self, stage_idx_dyn, initial_in_channels_dyn, initial_spatial_size_dyn, max_ops_dyn=None):
        super().__init__()
        self.stage_idx = stage_idx_dyn
        self.initial_in_channels = initial_in_channels_dyn
        self.initial_spatial_size = initial_spatial_size_dyn
        self.max_ops = max_ops_dyn  # None means no cap on the number of ops
        self.ops = nn.ModuleList()
        self.op_descriptions = []

    def add_op(self, op_module_dyn, op_description_dyn, insert_at=None):
        """Append (or insert at ``insert_at``) an op; return False if the stage is full."""
        at_capacity = self.max_ops is not None and len(self.ops) >= self.max_ops
        if at_capacity:
            return False
        if insert_at is not None:
            self.ops.insert(insert_at, op_module_dyn)
            self.op_descriptions.insert(insert_at, op_description_dyn)
            return True
        self.ops.append(op_module_dyn)
        self.op_descriptions.append(op_description_dyn)
        return True

    def get_op_output_properties(self, op_idx):
        """Return ``(out_channels, out_spatial)`` for op ``op_idx`` (``-1`` = stage input).

        Raises:
            IndexError: if ``op_idx`` is neither ``-1`` nor a valid op position.
        """
        if op_idx == -1:
            return self.initial_in_channels, self.initial_spatial_size
        if not 0 <= op_idx < len(self.op_descriptions):
            raise IndexError(f"Operator index {op_idx} out of range for stage {self.stage_idx} with {len(self.ops)} ops.")
        meta = self.op_descriptions[op_idx]
        return meta.get('out_channels', 0), meta.get('out_spatial', 0)

    def get_current_out_properties(self):
        """Return ``(out_channels, out_spatial)`` of the last op, or of the stage input if empty."""
        if self.op_descriptions:
            return self.get_op_output_properties(len(self.ops) - 1)
        return self.initial_in_channels, self.initial_spatial_size

    def forward(self, x_dyn):
        """Run every op in order, wiring inputs according to ``input_indices``."""
        produced = {-1: x_dyn}  # local index -> tensor produced so far
        for i_dyn, op_desc_item_dyn in enumerate(self.op_descriptions):
            op = self.ops[i_dyn]
            wanted = op_desc_item_dyn.get('input_indices', [-1])
            if not isinstance(wanted, list):
                wanted = [wanted]

            # Any index that does not point strictly backwards falls back to the
            # immediately preceding output (the stage input for the first op).
            fallback = i_dyn - 1
            feeds = []
            for src in wanted:
                if not (-1 <= src < i_dyn):
                    src = fallback
                if src in produced:
                    feeds.append(produced[src])
                else:
                    feeds.append(produced.get(fallback, x_dyn))

            try:
                if not feeds:
                    result = op(x_dyn)
                elif len(feeds) == 1:
                    result = op(feeds[0])
                else:
                    result = op(*feeds)
            except Exception as e_dyn_fwd:
                print(f"CRITICAL Error in DynamicStageModule op {i_dyn}, type {op_desc_item_dyn.get('type','Unknown')}, stage {self.stage_idx}: {e_dyn_fwd}")
                raise e_dyn_fwd
            produced[i_dyn] = result

        if self.ops:
            return produced[len(self.ops) - 1]
        return x_dyn


class TargetCNN(nn.Module):
    """Stage-wise CNN whose stages are ``DynamicStageModule`` instances.

    Each stage starts as ``Conv2d(3x3, no bias) -> ReLU`` and every stage
    except the last is followed by a 2x2 max-pool. Stage widths are
    ``init_model_ch_cnn`` times 1, 2 and 4; stages beyond the third reuse the
    last width. The classifier is global average pooling followed by one
    linear layer.
    """

    def __init__(self, num_classes_cnn=10, num_stages_cnn=NUM_STAGES_TARGET_CNN, init_model_ch_cnn=16, input_spatial_size=32):
        super().__init__()
        self.num_stages = num_stages_cnn
        self.stages = nn.ModuleList()

        in_ch = 3
        spatial = input_spatial_size
        self.input_placeholder_desc = {'type': 'input_placeholder', 'out_channels': in_ch, 'out_spatial': input_spatial_size, 'input_indices': []}

        width_plan = [init_model_ch_cnn, init_model_ch_cnn * 2, init_model_ch_cnn * 4]
        final_stage_no = num_stages_cnn - 1

        for stage_no in range(num_stages_cnn):
            stage = DynamicStageModule(stage_no, in_ch, spatial, max_ops_dyn=None)
            width = width_plan[stage_no] if stage_no < len(width_plan) else width_plan[-1]

            # Op 0: conv reading the stage input (-1); op 1: ReLU reading op 0.
            stage.add_op(
                nn.Conv2d(in_ch, width, 3, 1, 1, bias=False).to(DEVICE),
                {'type': 'conv2d', 'out_channels': width, 'out_spatial': spatial, 'input_indices': [-1]},
            )
            stage.add_op(
                nn.ReLU(inplace=False).to(DEVICE),
                {'type': 'relu', 'out_channels': width, 'out_spatial': spatial, 'input_indices': [0]},
            )

            block_out_ch, _block_out_spatial = stage.get_current_out_properties()
            if stage_no < final_stage_no:
                # Op 2: down-sampling pool reading op 1; halves the spatial size.
                downsample = nn.MaxPool2d(2, 2).to(DEVICE)
                spatial //= 2
                stage.add_op(
                    downsample,
                    {'type': 'maxpool2d', 'out_channels': block_out_ch, 'out_spatial': spatial, 'input_indices': [1]},
                )

            self.stages.append(stage)
            in_ch, spatial = stage.get_current_out_properties()

        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        classifier_in, _ = self.get_last_stage_out_properties()
        self.fc = nn.Linear(max(1, classifier_in), num_classes_cnn).to(DEVICE)

    def forward(self, x_cnn_fwd):
        """Run all stages, global-average-pool, flatten and classify."""
        features = x_cnn_fwd
        for stage in self.stages:
            features = stage(features)
        pooled = self.adaptive_pool(features)
        return self.fc(torch.flatten(pooled, 1))

    def get_last_stage_out_properties(self):
        """Return ``(channels, spatial)`` after the final stage (input placeholder if no stages)."""
        if self.stages:
            return self.stages[-1].get_current_out_properties()
        placeholder = self.input_placeholder_desc
        return placeholder['out_channels'], placeholder['out_spatial']


# =======================================================================
#  GNN MetaAgent
# =======================================================================
if PYG_AVAILABLE:
    class MetaAgentGNN(nn.Module):
        """GCN-based policy/value network that can also grow its own architecture.

        A stack of ``GCNConv`` layers embeds the nodes of an input graph; the
        mean-pooled graph embedding is concatenated with a flattened global
        history vector and fed to linear heads (edit type, meta self-edit type,
        value). Further heads are conditioned on an embedding of the chosen
        edit type, and three per-node scorer heads operate directly on node
        embeddings.

        Extra keyword arguments are stored verbatim in ``action_space_sizes``;
        this class reads the keys ``'num_edit_types'`` and
        ``'num_stages_target'`` from it.
        """
        def __init__(self, node_feature_dim_agent=NODE_FEATURE_DIM, initial_gnn_hidden_dim_agent=INITIAL_GNN_HIDDEN_DIM, initial_num_gnn_layers_agent=INITIAL_NUM_GNN_LAYERS,
                     global_history_dim_flat_agent=GLOBAL_SUMMARY_FEATURE_DIM * MAX_GLOBAL_HISTORY_LEN, device_to_use_agent=DEVICE, **kwargs_agent):
            super().__init__()
            self.device = device_to_use_agent; self.node_feature_dim = node_feature_dim_agent; self.current_gnn_hidden_dim = initial_gnn_hidden_dim_agent
            self.current_num_gnn_layers = initial_num_gnn_layers_agent; self.action_space_sizes = {k: v for k, v in kwargs_agent.items()}; self.head_depth_counters = {}

            self.edit_type_embedding = nn.Embedding(self.action_space_sizes['num_edit_types'], EDIT_TYPE_EMBED_DIM).to(self.device)

            # Heads reported by get_mlp_head_names(), i.e. the ones eligible for deepening.
            self.policy_head_names = ['head_target_loc_stage', 'head_target_conv_kernel', 'head_target_conv_ch_mult', 'head_target_resize_factor']
            self._build_gnn_layers(); self._build_mlp_heads(global_history_dim_flat_agent)

        def _build_gnn_layers(self):
            """Create ``current_num_gnn_layers`` fresh GCNConv layers (replaces ``self.gnn_layers``)."""
            self.gnn_layers = nn.ModuleList(); current_in_dim = self.node_feature_dim
            if self.current_num_gnn_layers > 0:
                for _ in range(self.current_num_gnn_layers):
                    out_dim = self.current_gnn_hidden_dim
                    # The first layer consumes raw node features; later layers consume the hidden width.
                    self.gnn_layers.append(pyg_nn.GCNConv(current_in_dim, out_dim, bias=True, normalize=True, add_self_loops=True).to(self.device)); current_in_dim = out_dim

        def _get_current_gnn_output_dim(self):
            """Return the node-embedding width: raw feature dim with zero GNN layers, else the hidden dim."""
            if self.current_num_gnn_layers == 0: return self.node_feature_dim
            return self.current_gnn_hidden_dim

        def _build_mlp_heads(self, global_history_dim_flat):
            """(Re)create every head as a fresh ``nn.Linear`` and reset ``head_depth_counters``.

            NOTE(review): calling this on an existing agent replaces the head
            attributes with newly initialised layers, so previously learned (or
            deepened) heads are discarded; any optimizer built on the old
            parameters would presumably need rebuilding - confirm with callers.
            """
            gnn_output_dim = self._get_current_gnn_output_dim();
            base_input_dim = gnn_output_dim + global_history_dim_flat
            self.heads_input_dim_global_current = base_input_dim

            # Heads fed by [graph embedding | global history].
            self.head_target_edit_type = nn.Linear(base_input_dim, self.action_space_sizes['num_edit_types'])
            self.head_meta_self_edit_type = nn.Linear(base_input_dim, NUM_META_SELF_EDIT_TYPES)
            self.head_value = nn.Linear(base_input_dim, 1)

            # Heads additionally conditioned on the edit-type embedding.
            conditional_input_dim = base_input_dim + EDIT_TYPE_EMBED_DIM
            self.head_target_loc_stage = nn.Linear(conditional_input_dim, self.action_space_sizes['num_stages_target'])
            self.head_target_conv_kernel = nn.Linear(conditional_input_dim, len(DISCRETE_KERNEL_SIZES))
            self.head_target_conv_ch_mult = nn.Linear(conditional_input_dim, len(DISCRETE_CH_MULT_ADD))
            self.head_target_resize_factor = nn.Linear(conditional_input_dim, len(DISCRETE_RESIZE_FACTORS))

            # Per-node scorers: one scalar per node embedding.
            self.head_resize_op_selector_scorer = nn.Linear(gnn_output_dim, 1)
            self.head_skip_source_scorer = nn.Linear(gnn_output_dim, 1)
            self.head_skip_destination_scorer = nn.Linear(gnn_output_dim, 1)

            all_head_names = self.policy_head_names + ['head_target_edit_type', 'head_meta_self_edit_type', 'head_value', 'head_resize_op_selector_scorer', 'head_skip_source_scorer', 'head_skip_destination_scorer']
            for name, module in self.named_children():
                if name in all_head_names or name == 'edit_type_embedding':
                    module.to(self.device)
            # Depth is the Sequential length, or 1 for a bare Linear head.
            self.head_depth_counters = {name: (len(getattr(self,name)) if isinstance(getattr(self,name),nn.Sequential) else 1) for name in self.get_mlp_head_names()}

        def get_mlp_head_names(self):
             """Return the attribute names of the heads tracked in ``head_depth_counters``."""
             return self.policy_head_names

        def _process_graph_and_state(self, graph_data, global_states_history_flat):
            """Embed the graph and join it with the history vector.

            Returns:
                ``(combined_features, node_embeddings)`` where
                ``combined_features`` is the mean-pooled graph embedding
                concatenated with the history along dim 1.
            """
            node_features, edge_index = graph_data.x, graph_data.edge_index
            embeddings = node_features
            if self.current_num_gnn_layers > 0 and graph_data.num_nodes > 0:
                for gnn_layer in self.gnn_layers:
                    embeddings = F.relu(gnn_layer(embeddings, edge_index))
            elif graph_data.num_nodes == 0:
                # Empty graph: no node embeddings, but keep the expected width.
                embeddings = torch.empty(0, self._get_current_gnn_output_dim(), device=self.device)

            batch_vector = graph_data.batch
            if batch_vector is None and embeddings.numel() > 0:
                # No batch assignment supplied: treat all nodes as one graph.
                batch_vector = torch.zeros(embeddings.size(0), dtype=torch.long, device=self.device)

            graph_embedding = pyg_nn.global_mean_pool(embeddings, batch_vector) if graph_data.num_nodes > 0 else torch.zeros(1, self._get_current_gnn_output_dim(), device=self.device)

            # Promote 1-D inputs to a batch of one so the concatenation is 2-D.
            if global_states_history_flat.ndim == 1: global_states_history_flat = global_states_history_flat.unsqueeze(0)
            if graph_embedding.ndim == 1: graph_embedding = graph_embedding.unsqueeze(0)

            combined_features = torch.cat((graph_embedding, global_states_history_flat), dim=1)
            return combined_features, embeddings

        def forward(self, graph_data, global_states_history_flat):
            """Return ``(edit_type_logits, meta_edit_logits, value, node_embeddings, combined_features)``."""
            combined_features, node_embeddings = self._process_graph_and_state(graph_data, global_states_history_flat)

            l_te = self.head_target_edit_type(combined_features)
            l_mse = self.head_meta_self_edit_type(combined_features)
            value_pred = self.head_value(combined_features)

            return l_te, l_mse, value_pred, node_embeddings, combined_features

        def get_conditional_logits(self, base_state_embedding, chosen_edit_type_tensor):
            """Return a dict of logits conditioned on the chosen edit type.

            Always contains ``'stage'``; adds ``'kernel'`` and ``'ch_mult'`` for
            ``EDIT_TYPE_ADD_CONV2D`` and ``'resize_factor'`` for
            ``EDIT_TYPE_RESIZE_CNN_CONV``. ``chosen_edit_type_tensor`` must hold
            exactly one element because ``.item()`` is called on it.
            """
            type_emb = self.edit_type_embedding(chosen_edit_type_tensor)
            conditional_state = torch.cat([base_state_embedding, type_emb], dim=1)
            logits = {}
            edit_type = chosen_edit_type_tensor.item()

            logits['stage'] = self.head_target_loc_stage(conditional_state)

            if edit_type == EDIT_TYPE_ADD_CONV2D:
                logits['kernel'] = self.head_target_conv_kernel(conditional_state)
                logits['ch_mult'] = self.head_target_conv_ch_mult(conditional_state)
            elif edit_type == EDIT_TYPE_RESIZE_CNN_CONV:
                logits['resize_factor'] = self.head_target_resize_factor(conditional_state)

            return logits

        def deepen_gnn(self, device='cpu'):
            """Append one GCN layer; return False if ``MAX_GNN_LAYERS`` is already reached.

            With existing layers the new one is built by
            ``create_identity_gcn_layer``; starting from zero layers a regular
            GCNConv is created and the heads are rebuilt because the embedding
            width changes.

            NOTE(review): ``device`` defaults to ``'cpu'`` rather than
            ``self.device`` - callers presumably pass the agent's device; confirm.
            """
            if MAX_GNN_LAYERS is not None and self.current_num_gnn_layers >= MAX_GNN_LAYERS: return False
            print(f"  Deepening MetaAgentGNN: GNN Layers {self.current_num_gnn_layers} -> {self.current_num_gnn_layers + 1}")
            new_layer_in_dim = self._get_current_gnn_output_dim()
            if self.current_num_gnn_layers == 0 : new_gcn_layer = pyg_nn.GCNConv(self.node_feature_dim, self.current_gnn_hidden_dim, bias=True, normalize=True, add_self_loops=True).to(device)
            else: new_gcn_layer = create_identity_gcn_layer(self.current_gnn_hidden_dim, device=device) if new_layer_in_dim == self.current_gnn_hidden_dim else pyg_nn.GCNConv(new_layer_in_dim, self.current_gnn_hidden_dim, bias=True, normalize=True, add_self_loops=True).to(device)
            self.gnn_layers.append(new_gcn_layer); self.current_num_gnn_layers += 1
            # Heads are only rebuilt when their expected input width no longer matches.
            if self._get_current_gnn_output_dim() + (GLOBAL_SUMMARY_FEATURE_DIM * MAX_GLOBAL_HISTORY_LEN) != self.heads_input_dim_global_current:
                 self._build_mlp_heads(GLOBAL_SUMMARY_FEATURE_DIM * MAX_GLOBAL_HISTORY_LEN)
            return True

        def widen_gnn_hidden_dim(self, factor=META_GNN_WIDEN_FACTOR, device='cpu'):
            """Scale the GNN hidden width by ``factor`` and rebuild all heads.

            The new width is at least ``INITIAL_GNN_HIDDEN_DIM // 2`` (or 1) and
            at most ``MAX_GNN_HIDDEN_DIM`` when that cap is set. Returns False
            when the width would not change, there are no GNN layers, or no
            layer was actually modified.
            """
            old_dim = self.current_gnn_hidden_dim; new_dim = max(INITIAL_GNN_HIDDEN_DIM // 2 if INITIAL_GNN_HIDDEN_DIM > 1 else 1 , int(round(old_dim * factor)))
            if MAX_GNN_HIDDEN_DIM is not None: new_dim = min(new_dim, MAX_GNN_HIDDEN_DIM)
            if new_dim == old_dim or self.current_num_gnn_layers == 0: return False
            print(f"  Widening MetaAgentGNN: GNN Hidden Dim {old_dim} -> {new_dim}")
            new_gnn_list = nn.ModuleList(); current_in_dim = self.node_feature_dim; any_changed = False
            for i in range(self.current_num_gnn_layers):
                original_gcn = self.gnn_layers[i]; temp_gcn = original_gcn
                # First match the layer's input width to the (already widened) previous layer, then widen its output.
                if original_gcn.in_channels != current_in_dim: temp_gcn, chg1 = adapt_gcn_conv_input_dim(original_gcn, current_in_dim, device); any_changed |= chg1
                final_gcn, chg2 = widen_gcn_conv_hidden(temp_gcn, new_dim, current_in_dim, device); any_changed |= chg2
                new_gnn_list.append(final_gcn); current_in_dim = new_dim
            if not any_changed and new_dim != old_dim : return False
            self.gnn_layers = new_gnn_list; self.current_gnn_hidden_dim = new_dim; self._build_mlp_heads(GLOBAL_SUMMARY_FEATURE_DIM * MAX_GLOBAL_HISTORY_LEN)
            return True

        def deepen_one_mlp_head(self, head_attr_name, device='cpu'):
            """Insert an identity Linear into the named head; return whether it changed.

            Refuses (returns False) when the attribute is missing or the head
            already sits at ``MAX_MLP_HEAD_SEQUENTIAL_DEPTH``. On success the
            head's entry in ``head_depth_counters`` is refreshed.
            """
            original_comp = getattr(self, head_attr_name, None)
            if original_comp is None: return False
            current_depth = self.head_depth_counters.get(head_attr_name, len(original_comp) if isinstance(original_comp, nn.Sequential) else 1)
            if MAX_MLP_HEAD_SEQUENTIAL_DEPTH is not None and current_depth >= MAX_MLP_HEAD_SEQUENTIAL_DEPTH : return False
            changed = net2deeper_linear_insert_identity(self, head_attr_name, device=device)
            if changed: self.head_depth_counters[head_attr_name] = len(getattr(self, head_attr_name)) if isinstance(getattr(self, head_attr_name), nn.Sequential) else 1
            return changed

# =======================================================================
#  DEITI System
# =======================================================================
class DEITI:
    def __init__(self):
        """Build the target CNN, the meta agent and all training machinery.

        Raises:
            ImportError: If PyTorch Geometric could not be imported.
        """
        if not PYG_AVAILABLE:
            raise ImportError("PyTorch Geometric required for DEITI.")

        self.device = DEVICE
        self.target_cnn = TargetCNN().to(DEVICE)

        # Sizes of the meta agent's discrete action heads.
        agent_kwargs = dict(
            device_to_use_agent=self.device,
            num_edit_types=NUM_TARGET_CNN_EDIT_TYPES,
            num_stages_target=NUM_STAGES_TARGET_CNN,
            num_kernel_sizes=len(DISCRETE_KERNEL_SIZES),
            num_ch_mults=len(DISCRETE_CH_MULT_ADD),
            num_resize_factors=len(DISCRETE_RESIZE_FACTORS),
        )
        self.meta_agent = MetaAgentGNN(**agent_kwargs).to(DEVICE)

        self.criterion_target = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.global_states_history_buffer = []
        # Gradient scaling is only active when running on CUDA.
        running_on_cuda = self.device.type == 'cuda'
        self.amp_scaler = torch.amp.GradScaler(enabled=running_on_cuda)

        self._init_dataloaders()
        # Consistency is enforced before the optimizer is created so the
        # optimizer sees the final parameter set.
        self._ensure_target_cnn_consistency()

        self.opt_target = optim.AdamW(
            self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=5e-4
        )
        self.sched_target = CosineAnnealingLR(self.opt_target, T_max=PRE_EPOCHS + POST_EPOCHS)
        self.opt_meta = optim.Adam(self.meta_agent.parameters(), lr=META_LEARNING_RATE)

    def mixup_data(self, x, y, alpha=1.0):
        if alpha > 0: lam = np.random.beta(alpha, alpha)
        else: lam = 1
        batch_size = x.size()[0]
        index = torch.randperm(batch_size, device=self.device)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, pred, y_a, y_b, lam):
        return lam * self.criterion_target(pred, y_a) + (1 - lam) * self.criterion_target(pred, y_b)

    def _init_dataloaders(self):
        """Create ``self.train_loader`` and ``self.val_loader``.

        CIFAR-10 is loaded from (or downloaded to) ``./data``. If that raises,
        both splits are replaced by ``FakeData`` of matching image shape so the
        rest of the pipeline can still run.
        """
        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2470, 0.2435, 0.2616)

        # Training pipeline: geometric + AutoAugment on the image, then
        # tensor-level normalisation and random erasing.
        train_steps = [
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.33)),
        ]
        train_transforms = transforms.Compose(train_steps)
        # Validation only converts and normalises.
        val_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        try:
            tr_ds = torchvision.datasets.CIFAR10(
                './data', train=True, download=True, transform=train_transforms
            )
            val_ds = torchvision.datasets.CIFAR10(
                './data', train=False, download=True, transform=val_transforms
            )
        except Exception as e:
            # Deliberate best-effort fallback: any failure switches both splits.
            print(f"CIFAR-10 download failed: {e}. Using FakeData.")
            tr_ds = torchvision.datasets.FakeData((BATCH_SIZE * 50), (3, 32, 32), 10, train_transforms)
            val_ds = torchvision.datasets.FakeData((BATCH_SIZE * 20), (3, 32, 32), 10, val_transforms)

        on_cuda = self.device.type == 'cuda'
        num_workers = 4 if on_cuda else 0
        print(f"Using {num_workers} workers for data loading.")

        shared_kwargs = dict(num_workers=num_workers, pin_memory=on_cuda)
        self.train_loader = DataLoader(tr_ds, BATCH_SIZE, shuffle=True, **shared_kwargs)
        self.val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, **shared_kwargs)

    def _create_target_cnn_graph_data(self):
        """Encode the current target CNN as a graph for the meta agent.

        Every op of every stage becomes one node whose features are built from
        its description (type id, normalised width, stage/op position, conv
        flag, stride - 1, normalised spatial size). Edges follow each op's
        ``'input_indices'``; an index of ``-1`` links to the last op of the
        previous stage (no edge is added for the first stage).

        Side effect: every op description gets its ``'out_spatial'`` entry
        overwritten with the value computed here.

        Returns:
            tuple: ``(data, op_map)``. ``data`` is a ``Data`` object carrying
            node features ``x``, ``edge_index`` and an all-zero ``batch``
            vector; ``op_map`` maps ``(stage_idx, op_idx)`` to the node id.
        """
        def _scalar_stride(stride):
            # A stride may be stored as an int or as an (h, w) pair. The spatial
            # bookkeeping here is one-dimensional, so use the first entry.
            return stride[0] if isinstance(stride, (tuple, list)) else stride

        nodes, src, tgt, op_map, gid = {}, [], [], {}, 0
        initial_spatial = self.target_cnn.input_placeholder_desc['out_spatial']

        # Maps stage index -> spatial size leaving that stage (-1 is the network input).
        current_spatial_map = {-1: initial_spatial}
        for si, sm in enumerate(self.target_cnn.stages):
            stage_initial_spatial = current_spatial_map.get(si - 1, initial_spatial)
            # Maps op index -> (out_channels, out_spatial); -1 is the stage input.
            op_output_props = {-1: (sm.initial_in_channels, stage_initial_spatial)}

            for oi, od in enumerate(sm.op_descriptions):
                op_map[(si, oi)] = gid
                op_type_str = od.get('type', 'none')

                input_indices = od.get('input_indices', [-1])
                input_indices = input_indices if isinstance(input_indices, list) else [input_indices]
                # Only the first (primary) input determines the spatial size.
                prev_op_idx = input_indices[0]
                _, expected_in_spatial = op_output_props.get(
                    prev_op_idx, (sm.initial_in_channels, stage_initial_spatial)
                )

                is_conv = 1.0 if op_type_str == 'conv2d' else 0.0
                out_spatial = expected_in_spatial
                stride_val = 1
                if op_type_str in ('maxpool2d', 'conv2d'):
                    # Divide by the op's real stride. The pooling case used to
                    # assume a stride of 2 regardless of the module, which
                    # disagreed with the stride feature emitted below.
                    stride_val = _scalar_stride(sm.ops[oi].stride)
                    out_spatial //= stride_val

                op_output_props[oi] = (od.get('out_channels'), out_spatial)
                od['out_spatial'] = out_spatial

                nf = [
                    float(OP_TYPE_IDS.get(op_type_str, 0)),
                    float(od.get('out_channels', 0)) / NORMALIZATION_WIDTH_DIVISOR,
                    float(si) / max(1, NUM_STAGES_TARGET_CNN - 1),
                    float(oi) / max(1, len(sm.ops) - 1),
                    is_conv,
                    float(stride_val - 1),
                    float(out_spatial) / NORMALIZATION_SPATIAL_DIVISOR,
                ]
                nodes[gid] = nf
                gid += 1

            if sm.op_descriptions:
                current_spatial_map[si] = op_output_props[len(sm.ops) - 1][1]
            else:
                # An empty stage passes its input size through unchanged.
                current_spatial_map[si] = stage_initial_spatial

        # Node id of the last op of each non-empty stage (used for cross-stage edges).
        s_out_gid_map = {
            s: op_map.get((s, len(sm_loop.ops) - 1), -1)
            for s, sm_loop in enumerate(self.target_cnn.stages)
            if sm_loop.ops
        }

        for si, sm in enumerate(self.target_cnn.stages):
            for oi, od in enumerate(sm.op_descriptions):
                cur_gid = op_map.get((si, oi), -1)
                if cur_gid == -1:
                    continue

                in_ids = od.get('input_indices', [-1] if oi == 0 else [oi - 1])
                in_ids = in_ids if isinstance(in_ids, list) else [in_ids]

                for lsid in in_ids:
                    if lsid == -1:
                        sgid = s_out_gid_map.get(si - 1, -1) if si > 0 else -1
                    else:
                        sgid = op_map.get((si, lsid), -1)

                    if sgid != -1:
                        src.append(sgid)
                        tgt.append(cur_gid)

        # With no ops at all, a single all-zero placeholder node is emitted.
        x = torch.tensor(
            [nodes[i] for i in range(gid)] if gid > 0 else [[0.] * NODE_FEATURE_DIM],
            dtype=torch.float32, device=self.device,
        )
        if src:
            eidx = torch.tensor([src, tgt], dtype=torch.long, device=self.device)
        else:
            eidx = torch.empty((2, 0), dtype=torch.long, device=self.device)
        batch = torch.zeros(x.size(0), dtype=torch.long, device=self.device) if x.size(0) > 0 else None
        return Data(x=x, edge_index=eidx, batch=batch), op_map

    def _ensure_target_cnn_consistency(self):
        """Propagate shapes through the target CNN and repair any mismatch.

        Walks every stage in order, recomputing each op description's
        ``'out_channels'`` / ``'out_spatial'`` from its primary input, and:

        * adapts a ``Conv2d`` whose ``in_channels`` no longer matches its input
          (via ``adapt_conv_input_channels``);
        * gives each ``AddWithProjection`` a skip projection that maps the skip
          branch's channels onto the primary branch's channels;
        * replaces the final ``fc`` layer when its ``in_features`` no longer
          matches the last stage's output channels.

        An existing skip projection is kept when its first layer already has
        the required in/out channels, so trained projection weights are not
        re-initialised on every call.
        """
        def _projection_fits(projection, in_ch, out_ch):
            # True when `projection` is a Sequential whose first layer is a
            # Conv2d already mapping in_ch -> out_ch.
            if not isinstance(projection, nn.Sequential) or len(projection) == 0:
                return False
            first = projection[0]
            return (
                isinstance(first, nn.Conv2d)
                and first.in_channels == in_ch
                and first.out_channels == out_ch
            )

        current_channels, current_spatial = self.target_cnn.input_placeholder_desc['out_channels'], self.target_cnn.input_placeholder_desc['out_spatial']

        for stage_module in self.target_cnn.stages:
            stage_module.initial_in_channels = current_channels
            stage_module.initial_spatial_size = current_spatial

            # Maps op index -> (out_channels, out_spatial); -1 is the stage input.
            op_output_props = {-1: (current_channels, current_spatial)}

            for i, op in enumerate(stage_module.ops):
                desc = stage_module.op_descriptions[i]
                input_indices = desc.get('input_indices', [-1])
                input_indices = input_indices if isinstance(input_indices, list) else [input_indices]

                prev_op_idx = input_indices[0] if input_indices else -1
                expected_in_channels, expected_in_spatial = op_output_props.get(prev_op_idx, (current_channels, current_spatial))

                if isinstance(op, nn.Conv2d):
                    if op.in_channels != expected_in_channels:
                        new_op = adapt_conv_input_channels(op, expected_in_channels, self.device)
                        stage_module.ops[i] = new_op
                        op = new_op
                    desc['out_channels'] = op.out_channels
                    desc['out_spatial'] = expected_in_spatial // op.stride[0]

                elif isinstance(op, nn.ReLU):
                    desc['out_channels'] = expected_in_channels
                    desc['out_spatial'] = expected_in_spatial

                elif isinstance(op, nn.MaxPool2d):
                    desc['out_channels'] = expected_in_channels
                    desc['out_spatial'] = expected_in_spatial // op.stride

                elif isinstance(op, AddWithProjection):
                    # input_indices is [primary (destination), skip (source)].
                    primary_ch, primary_sp = op_output_props[input_indices[0]]
                    skip_ch, _ = op_output_props[input_indices[1]]
                    existing_projection = getattr(op, 'projection', None)

                    if primary_ch != skip_ch:
                        # Only rebuild when the current projection cannot map
                        # skip_ch -> primary_ch; otherwise keep its weights.
                        if not _projection_fits(existing_projection, skip_ch, primary_ch):
                            op.projection = nn.Sequential(
                                nn.Conv2d(skip_ch, primary_ch, kernel_size=1, bias=False),
                                nn.BatchNorm2d(primary_ch)
                            ).to(self.device)
                    elif not isinstance(existing_projection, nn.Identity):
                        op.projection = nn.Identity()

                    desc['out_channels'], desc['out_spatial'] = primary_ch, primary_sp

                op_output_props[i] = (desc.get('out_channels'), desc.get('out_spatial'))

            current_channels, current_spatial = stage_module.get_current_out_properties()

        final_channels, _ = self.target_cnn.get_last_stage_out_properties()
        fc_in_features = max(1, final_channels)
        if self.target_cnn.fc.in_features != fc_in_features:
            # The classifier is re-created with fresh weights when its input width changes.
            old_fc = self.target_cnn.fc
            self.target_cnn.fc = nn.Linear(fc_in_features, old_fc.out_features).to(self.device)

    def _apply_target_cnn_edit(self, actions):
        """Apply the sampled architecture edit to the target CNN.

        Args:
            actions: Dict of sampled decisions. ``'target_edit_type'`` and
                ``'target_loc_stage'`` are always read (objects exposing
                ``.item()``). Depending on the edit type the method also reads
                ``'target_actual_op_idx_in_stage'`` and
                ``'target_resize_factor_idx'`` (resize),
                ``'target_conv_kernel_idx'`` and ``'target_conv_ch_mult_idx'``
                (add conv), or ``'source_op_idx'`` and ``'dest_op_idx'``
                (add skip; ``-1`` means "no valid pair").

        Returns:
            ``True`` if the network was modified (in which case
            ``_ensure_target_cnn_consistency`` has been run), else ``False``.
        """
        edit_type, stage_idx = actions['target_edit_type'].item(), actions['target_loc_stage'].item()
        if not (0 <= stage_idx < len(self.target_cnn.stages)):
            return False
        stage = self.target_cnn.stages[stage_idx]
        changed = False

        if edit_type == EDIT_TYPE_RESIZE_CNN_CONV:
            op_idx = actions.get('target_actual_op_idx_in_stage', -1)
            if op_idx != -1 and 0 <= op_idx < len(stage.ops) and isinstance(stage.ops[op_idx], nn.Conv2d):
                op_mod, op_desc = stage.ops[op_idx], stage.op_descriptions[op_idx]
                factor = DISCRETE_RESIZE_FACTORS[actions['target_resize_factor_idx'].item()]
                # A factor of 1 is a no-op edit.
                if abs(factor - 1.0) < 1e-6:
                    return False
                new_op = resize_conv_output(op_mod, factor, self.device)
                if new_op is not op_mod:
                    stage.ops[op_idx] = new_op
                    op_desc['out_channels'] = new_op.out_channels
                    changed = True

        elif edit_type == EDIT_TYPE_ADD_CONV2D or edit_type == EDIT_TYPE_ADD_RELU:
            # New ops are appended and consume the stage's current output.
            in_ch, in_sp = stage.get_current_out_properties()
            in_ch = max(1, in_ch)
            in_indices = [-1] if not stage.ops else [len(stage.ops) - 1]
            op, desc = None, {}
            if edit_type == EDIT_TYPE_ADD_CONV2D:
                k = DISCRETE_KERNEL_SIZES[actions['target_conv_kernel_idx'].item()]
                m = DISCRETE_CH_MULT_ADD[actions['target_conv_ch_mult_idx'].item()]
                out_ch = max(1, int(round(in_ch * m)))
                # Stride 1 with (k-1)//2 padding keeps the spatial size for odd k.
                op = nn.Conv2d(in_ch, out_ch, k, stride=1, padding=(k - 1) // 2, bias=False).to(self.device)
                nn.init.kaiming_normal_(op.weight, mode='fan_out', nonlinearity='relu')
                desc = {'type': 'conv2d', 'params': {'k': k, 'mult': m}, 'out_channels': out_ch, 'out_spatial': in_sp, 'input_indices': in_indices}
            elif edit_type == EDIT_TYPE_ADD_RELU:
                op = nn.ReLU(inplace=False).to(self.device)
                desc = {'type': 'relu', 'out_channels': in_ch, 'out_spatial': in_sp, 'input_indices': in_indices}
            if op and stage.add_op(op, desc):
                changed = True

        elif edit_type == EDIT_TYPE_ADD_SKIP:
            source_op_idx = actions.get('source_op_idx', -1)
            dest_op_idx = actions.get('dest_op_idx', -1)

            if source_op_idx != -1 and dest_op_idx != -1:
                source_ch, _ = stage.get_op_output_properties(source_op_idx)
                dest_ch, dest_sp = stage.get_op_output_properties(dest_op_idx)

                projection = nn.Identity()
                if source_ch != dest_ch:
                    # 1x1 conv + BN brings the skip branch to the destination width.
                    projection = nn.Sequential(
                        nn.Conv2d(source_ch, dest_ch, kernel_size=1, bias=False),
                        nn.BatchNorm2d(dest_ch)
                    ).to(self.device)

                add_op = AddWithProjection(projection)
                add_desc = {'type': 'add', 'out_channels': dest_ch, 'out_spatial': dest_sp, 'input_indices': [dest_op_idx, source_op_idx]}

                insert_at_idx = dest_op_idx + 1
                ops_before = len(stage.ops)
                stage.add_op(add_op, add_desc, insert_at=insert_at_idx)
                if len(stage.ops) == ops_before:
                    # Nothing was inserted, so the existing wiring must stay untouched.
                    return False

                # --- Graph rewiring ---
                # The Add op now sits right after the destination. For every
                # other op:
                #   * inputs at or beyond the insertion point moved one slot
                #     to the right and are shifted accordingly;
                #   * ops located after the Add that consumed the destination
                #     are redirected to the Add, so the skip connection is
                #     actually part of the data flow. The destination's own
                #     index is below the insertion point and therefore never
                #     shifts, so it is matched by its original value.
                for op_idx_remap in range(len(stage.ops)):
                    if op_idx_remap == insert_at_idx:
                        continue
                    is_downstream = op_idx_remap > insert_at_idx
                    remapped_indices = []
                    for input_idx in stage.op_descriptions[op_idx_remap]['input_indices']:
                        if is_downstream and input_idx == dest_op_idx:
                            remapped_indices.append(insert_at_idx)
                        elif input_idx >= insert_at_idx:
                            remapped_indices.append(input_idx + 1)
                        else:
                            remapped_indices.append(input_idx)
                    stage.op_descriptions[op_idx_remap]['input_indices'] = remapped_indices

                changed = True

        if changed:
            self._ensure_target_cnn_consistency()
        return changed

    def _apply_meta_self_edit(self, action):
        """Apply a self-edit to the meta agent and rebuild its optimizer if needed.

        Args:
            action: Sampled self-edit type; ``action.item()`` is compared with
                the ``META_EDIT_*`` constants.

        Returns:
            The result reported by the meta agent's edit method, or ``False``
            when the edit type matches none of the handled constants or no MLP
            head is available.
        """
        edit_type = action.item()
        changed = False

        if edit_type == META_EDIT_DEEPEN_GNN:
            changed = self.meta_agent.deepen_gnn(device=self.device)
        elif edit_type == META_EDIT_WIDEN_GNN_HIDDEN:
            changed = self.meta_agent.widen_gnn_hidden_dim(factor=META_GNN_WIDEN_FACTOR, device=self.device)
        elif edit_type == META_EDIT_DEEPEN_MLP_HEAD:
            head_names = self.meta_agent.get_mlp_head_names()
            if head_names:
                # The head to deepen is picked uniformly at random.
                picked_head = np.random.choice(head_names)
                changed = self.meta_agent.deepen_one_mlp_head(picked_head, device=self.device)

        if not changed:
            return changed

        param_total = sum(p.numel() for p in self.meta_agent.parameters())
        print(f"MetaAgentGNN arch changed. New Params: {param_total}. Re-init optimizer.")
        # Drop the old optimizer (and its state) before creating one over the new parameters.
        del self.opt_meta
        gc.collect()
        torch.cuda.empty_cache()
        self.opt_meta = optim.Adam(self.meta_agent.parameters(), lr=META_LEARNING_RATE)
        return changed

    def _train_target_one_epoch(self, loader, optimizer, scheduler, name="Tr"):
        """Train the target CNN for one pass over ``loader`` using mixup.

        Args:
            loader: Iterable of ``(inputs, labels)`` batches; truncated to
                ``BATCHES_PER_EPOCH`` batches when that constant is set.
            optimizer: Optimizer stepped once per usable batch.
            scheduler: LR scheduler stepped once at the end of the epoch.
            name: Label for the epoch; not used inside this method.

        Returns:
            tuple: ``(mean_loss, accuracy)`` over the batches that produced a
            finite loss. Accuracy compares predictions on the mixed inputs
            with the un-permuted labels.
        """
        self.target_cnn.train()
        running_loss, steps, hits, seen = 0, 0, 0, 0
        amp_enabled = self.device.type == 'cuda'

        if BATCHES_PER_EPOCH is None:
            batches = loader
        else:
            batches = itertools.islice(loader, BATCHES_PER_EPOCH)

        for inputs, labels in batches:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            mixed_inputs, labels_a, labels_b, lam = self.mixup_data(inputs, labels, alpha=0.4)
            # The criterion is fed class-probability targets.
            soft_a = F.one_hot(labels_a, 10).float()
            soft_b = F.one_hot(labels_b, 10).float()

            optimizer.zero_grad(set_to_none=True)
            with autocast(self.device.type, enabled=amp_enabled):
                logits = self.target_cnn(mixed_inputs)
                loss = self.mixup_criterion(logits, soft_a, soft_b, lam)

            # Batches with a NaN/Inf loss are skipped entirely.
            if not torch.isfinite(loss):
                continue

            self.amp_scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            self.amp_scaler.unscale_(optimizer)
            clip_grad_norm_(self.target_cnn.parameters(), MAX_GRAD_NORM)
            self.amp_scaler.step(optimizer)
            self.amp_scaler.update()

            running_loss += loss.item()
            predicted = logits.max(1).indices
            seen += labels_a.size(0)
            hits += predicted.eq(labels_a).sum().item()
            steps += 1

        scheduler.step()
        mean_loss = running_loss / max(1, steps)
        accuracy = hits / max(1, seen)
        return mean_loss, accuracy

    def _validate_target(self, loader):
        """Evaluate the target CNN on ``loader`` without gradient tracking.

        Uses a plain ``CrossEntropyLoss`` (no label smoothing). At most
        ``BATCHES_PER_EPOCH`` batches are consumed when that constant is set.

        Returns:
            tuple: ``(mean_loss, accuracy)`` over the batches that produced a
            finite loss.
        """
        self.target_cnn.eval()
        plain_ce = nn.CrossEntropyLoss()
        total_loss, hits, seen, steps = 0, 0, 0, 0

        if BATCHES_PER_EPOCH is None:
            batches = loader
        else:
            batches = itertools.islice(loader, BATCHES_PER_EPOCH)

        with torch.no_grad():
            for inputs, labels in batches:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                logits = self.target_cnn(inputs)
                batch_loss = plain_ce(logits, labels)
                # Ignore batches whose loss is NaN/Inf.
                if not torch.isfinite(batch_loss):
                    continue
                total_loss += batch_loss.item()
                predicted = logits.max(1).indices
                seen += labels.size(0)
                hits += predicted.eq(labels).sum().item()
                steps += 1

        return total_loss / max(1, steps), hits / max(1, seen)

    def train_loop(self, iterations=100):
        """Run the architecture-search loop for ``iterations`` rounds.

        Each round:
          1. trains the target CNN for ``PRE_EPOCHS`` and validates it;
          2. snapshots the model and optimizer state;
          3. samples a target-CNN edit and a meta self-edit from the meta agent;
          4. applies the target edit (and, every ``META_SELF_EDIT_INTERVAL``
             rounds, the meta self-edit), rebuilding optimizers as needed;
          5. trains for ``POST_EPOCHS`` and validates again;
          6. restores the snapshot if accuracy fell by more than 0.08;
          7. updates the meta agent with an actor-critic loss whose reward is
             the accuracy change in percentage points (with fixed bonuses and
             penalties).

        Args:
            iterations: Number of search rounds.

        Raises:
            RuntimeError: If post-edit validation accuracy is exactly zero
                while the pre-edit validation loss was positive.
        """
        for itr in range(iterations):
            print(f"\n===== Iteration {itr+1}/{iterations} =====")
            print(f"Current MetaAgentGNN: GNN Layers={self.meta_agent.current_num_gnn_layers}, GNN Hidden={self.meta_agent.current_gnn_hidden_dim}, Params: {sum(p.numel() for p in self.meta_agent.parameters()):,}")
            # A fresh cosine schedule covers this round's pre + post epochs.
            self.sched_target = CosineAnnealingLR(self.opt_target, T_max=PRE_EPOCHS+POST_EPOCHS)

            # --- 1. Baseline training and validation ---
            for pre_ep in range(PRE_EPOCHS):
                self._train_target_one_epoch(self.train_loader, self.opt_target, self.sched_target, f"PreE{pre_ep+1}")
            val_loss_b, acc_b = self._validate_target(self.val_loader)
            print(f"PreVal: ValidL={val_loss_b:.4f}, ValidA={acc_b:.4f}")

            # --- 2. Snapshot for a possible revert ---
            model_backup = copy.deepcopy(self.target_cnn)
            optimizer_state_backup = copy.deepcopy(self.opt_target.state_dict())

            # --- 3. Build the meta agent's observation ---
            graph, op_map = self._create_target_cnn_graph_data()
            target_params_M = sum(p.numel() for p in self.target_cnn.parameters() if p.requires_grad)/1e6
            norm_m_gl=self.meta_agent.current_num_gnn_layers / (MAX_GNN_LAYERS or self.meta_agent.current_num_gnn_layers + 1e-6)
            norm_m_gh=self.meta_agent.current_gnn_hidden_dim / (MAX_GNN_HIDDEN_DIM or self.meta_agent.current_gnn_hidden_dim + 1e-6)
            g_state=[val_loss_b,acc_b,target_params_M,norm_m_gl,norm_m_gh]; self.global_states_history_buffer.append(g_state)
            if len(self.global_states_history_buffer)>MAX_GLOBAL_HISTORY_LEN: self.global_states_history_buffer.pop(0)
            # Left-pad the history with zero rows up to the fixed length.
            hist = list(self.global_states_history_buffer)
            while len(hist)<MAX_GLOBAL_HISTORY_LEN: hist.insert(0,[0.]*GLOBAL_SUMMARY_FEATURE_DIM)
            g_hist_flat=torch.tensor(hist,dtype=torch.float32,device=self.device).view(1,-1)

            # --- 4. Sample the target edit (log-probs and entropies accumulate) ---
            actions = {}
            lp_t, en_t = 0.0, 0.0

            l_te, l_mse, val_p, fne, base_state_embed = self.meta_agent(graph, g_hist_flat)
            dist_te = Categorical(logits=l_te)
            s_te = dist_te.sample()
            actions['target_edit_type'] = s_te
            lp_t += dist_te.log_prob(s_te); en_t += dist_te.entropy().mean()
            edit_type = s_te.item()

            # The remaining heads are conditioned on the chosen edit type.
            conditional_logits = self.meta_agent.get_conditional_logits(base_state_embed, s_te)
            l_ts = conditional_logits['stage']
            dist_ts = Categorical(logits=l_ts)
            s_ts = dist_ts.sample()
            actions['target_loc_stage'] = s_ts
            lp_t += dist_ts.log_prob(s_ts); en_t += dist_ts.entropy().mean()

            if edit_type == EDIT_TYPE_ADD_CONV2D:
                l_tk = conditional_logits['kernel']; dist_tk = Categorical(logits=l_tk); s_tk = dist_tk.sample()
                actions['target_conv_kernel_idx'] = s_tk
                lp_t += dist_tk.log_prob(s_tk); en_t += dist_tk.entropy().mean()

                l_tc = conditional_logits['ch_mult']; dist_tc = Categorical(logits=l_tc); s_tc = dist_tc.sample()
                actions['target_conv_ch_mult_idx'] = s_tc
                lp_t += dist_tc.log_prob(s_tc); en_t += dist_tc.entropy().mean()

            elif edit_type == EDIT_TYPE_RESIZE_CNN_CONV:
                # Candidate convs are those with at least one valid resize factor.
                stage=self.target_cnn.stages[s_ts.item()]; candidates=[]
                for oi,op in enumerate(stage.ops):
                    if isinstance(op, nn.Conv2d):
                        valid_idx = _valid_resize_indices(op.out_channels)
                        if valid_idx:
                            gid=op_map.get((s_ts.item(),oi),-1)
                            if gid!=-1 and gid<fne.size(0): candidates.append((oi,fne[gid],valid_idx))

                a_op_idx_rsz = -1
                l_tr = conditional_logits['resize_factor']
                dist_rf = Categorical(logits=l_tr)
                if candidates:
                    scores = self.meta_agent.head_resize_op_selector_scorer(torch.stack([e for _,e,_ in candidates])).squeeze(-1)
                    if scores.numel() > 0:
                        dist_co = Categorical(logits=scores); s_kth = dist_co.sample(); a_op_idx_rsz = candidates[s_kth.item()][0]
                        lp_t += dist_co.log_prob(s_kth); en_t += dist_co.entropy().mean()
                        # Restrict the factor distribution to the chosen op's valid factors.
                        valid_rf_idx = candidates[s_kth.item()][2]
                        masked_l_tr = _mask_logits(l_tr.squeeze(0), valid_rf_idx)
                        dist_rf = Categorical(logits=masked_l_tr if not torch.all(torch.isinf(masked_l_tr)) else l_tr)

                s_rf_idx = dist_rf.sample()
                lp_t += dist_rf.log_prob(s_rf_idx); en_t += dist_rf.entropy().mean()
                actions['target_resize_factor_idx'] = s_rf_idx
                actions['target_actual_op_idx_in_stage'] = a_op_idx_rsz

            elif edit_type == EDIT_TYPE_ADD_SKIP:
                # Valid skips go from an earlier op to a later op of equal spatial size.
                stage = self.target_cnn.stages[s_ts.item()]
                valid_pairs = []
                for i in range(len(stage.ops)):
                    for j in range(i + 1, len(stage.ops)):
                        _, src_sp = stage.get_op_output_properties(i)
                        _, dest_sp = stage.get_op_output_properties(j)
                        if src_sp == dest_sp:
                            valid_pairs.append((i, j))

                if valid_pairs:
                    source_indices = torch.tensor([op_map[(s_ts.item(), p[0])] for p in valid_pairs], device=self.device)
                    dest_indices = torch.tensor([op_map[(s_ts.item(), p[1])] for p in valid_pairs], device=self.device)

                    source_scores = self.meta_agent.head_skip_source_scorer(fne[source_indices]).squeeze()
                    dest_scores = self.meta_agent.head_skip_destination_scorer(fne[dest_indices]).squeeze()

                    # A single pair squeezes to 0-dim; restore a 1-D logits vector.
                    if source_scores.dim() == 0:
                        source_scores = source_scores.unsqueeze(0)
                        dest_scores = dest_scores.unsqueeze(0)

                    pair_scores = source_scores + dest_scores
                    dist_pair = Categorical(logits=pair_scores)
                    chosen_pair_idx = dist_pair.sample()

                    chosen_source_op, chosen_dest_op = valid_pairs[chosen_pair_idx.item()]
                    actions['source_op_idx'] = chosen_source_op
                    actions['dest_op_idx'] = chosen_dest_op
                    lp_t += dist_pair.log_prob(chosen_pair_idx); en_t += dist_pair.entropy()
                else:
                    actions['source_op_idx'] = -1; actions['dest_op_idx'] = -1

            # --- 5. Sample the meta self-edit ---
            valid_meta_actions = [i for i in range(NUM_META_SELF_EDIT_TYPES) if i not in _invalid_meta_indices(self.meta_agent)]
            if not valid_meta_actions:
                # Placeholders live on the same device as every other loss term.
                s_mse = torch.tensor(META_EDIT_NONE, device=self.device)
                lp_mse = torch.tensor(0.0, device=self.device)
                en_mse = torch.tensor(0.0, device=self.device)
            else:
                masked_l_mse = _mask_logits(l_mse.squeeze(0), valid_meta_actions)
                dist_mse = Categorical(logits=masked_l_mse if not torch.all(torch.isinf(masked_l_mse)) else l_mse)
                s_mse=dist_mse.sample(); lp_mse = dist_mse.log_prob(s_mse); en_mse = dist_mse.entropy().mean()

            log_parts = [f"TargEdit:Typ={edit_type},Stg={actions['target_loc_stage'].item()}"]
            if edit_type==EDIT_TYPE_ADD_CONV2D: log_parts.append(f"K={DISCRETE_KERNEL_SIZES[actions['target_conv_kernel_idx'].item()]},CHM={DISCRETE_CH_MULT_ADD[actions['target_conv_ch_mult_idx'].item()]}")
            elif edit_type==EDIT_TYPE_RESIZE_CNN_CONV: log_parts.append(f"Op={actions.get('target_actual_op_idx_in_stage',-1)},RszF={DISCRETE_RESIZE_FACTORS[actions.get('target_resize_factor_idx', 0).item()]}")
            elif edit_type==EDIT_TYPE_ADD_SKIP: log_parts.append(f"Src={actions.get('source_op_idx', -1)}->Dest={actions.get('dest_op_idx', -1)}")
            log_parts.append(f"| MetaSelfEdit:Typ={s_mse.item()}")
            print(" ".join(log_parts))

            # --- 6. Apply the edits ---
            t_changed = self._apply_target_cnn_edit(actions)
            if t_changed:
                print(f"  TargetCNN arch changed. New Params: {sum(p.numel() for p in self.target_cnn.parameters() if p.requires_grad):,}")
                del self.opt_target, self.sched_target; gc.collect(); torch.cuda.empty_cache()
                # Same weight decay as the initial and the revert optimizer, so
                # regularisation does not depend on whether an edit was applied.
                self.opt_target=optim.AdamW(self.target_cnn.parameters(),lr=LEARNING_RATE,weight_decay=5e-4)
                self.sched_target = CosineAnnealingLR(self.opt_target, T_max=PRE_EPOCHS+POST_EPOCHS)

            m_changed=False
            if (itr+1)%META_SELF_EDIT_INTERVAL==0 and s_mse.item()!=META_EDIT_NONE:
                m_changed=self._apply_meta_self_edit(s_mse)

            # --- 7. Post-edit training and validation ---
            for post_ep in range(POST_EPOCHS):
                train_loss, train_acc = self._train_target_one_epoch(self.train_loader,self.opt_target, self.sched_target, f"PostE{post_ep+1}")
                print(f"  PostE{post_ep+1}: TrainL={train_loss:.4f}, TrainA={train_acc:.4f}")
            val_loss_a, acc_a = self._validate_target(self.val_loader)
            print(f"PostVal: ValidL={val_loss_a:.4f}, ValidA={acc_a:.4f}")

            if acc_a == 0.0 and val_loss_b > 0:
                print("\nFATAL ERROR: Training has collapsed, validation accuracy is zero.")
                raise RuntimeError("Training collapsed, accuracy is zero.")

            # --- 8. Revert on a large accuracy drop ---
            was_reverted = False
            if acc_a < (acc_b - 0.08):
                print(f"  !!! ACCURACY DROP > 8% ({acc_b:.4f} -> {acc_a:.4f}). REVERTING MODEL. !!!")
                self.target_cnn = model_backup
                self.opt_target = optim.AdamW(self.target_cnn.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
                self.opt_target.load_state_dict(optimizer_state_backup)
                self.sched_target = CosineAnnealingLR(self.opt_target, T_max=PRE_EPOCHS+POST_EPOCHS)
                acc_a = acc_b; val_loss_a = val_loss_b
                was_reverted = True

            # --- 9. Reward: accuracy change in percentage points, with fixed adjustments ---
            reward=(acc_a-acc_b)*100
            if was_reverted: reward = -10.0
            else:
                if not t_changed: reward-=2
                if m_changed: reward+=1
            print(f"Reward:{reward:.2f}")

            # --- 10. Actor-critic update of the meta agent ---
            advantage = reward - val_p.detach().squeeze().item()
            actor_loss = -(lp_t+lp_mse)*advantage
            critic_loss = F.mse_loss(val_p.squeeze(), torch.tensor(reward, device=self.device, dtype=torch.float32))

            # The entropy bonus (weight 0.05) encourages exploration.
            meta_loss=actor_loss.mean()+ 0.5 * critic_loss - 0.05*(en_t+en_mse)
            if torch.isnan(meta_loss)or torch.isinf(meta_loss):
                print("MetaLoss NaN/Inf! Skipping meta-update.")
                del model_backup, optimizer_state_backup; gc.collect()
                continue

            self.opt_meta.zero_grad(); meta_loss.backward(); clip_grad_norm_(self.meta_agent.parameters(),MAX_GRAD_NORM); self.opt_meta.step()
            print(f"MetaL:{meta_loss.item():.4f}(A:{actor_loss.mean().item():.4f},C:{critic_loss.item():.4f},E:{(en_t+en_mse).item():.4f})")

            del model_backup, optimizer_state_backup; gc.collect()


if __name__ == '__main__':
    # Script entry point: build the DEITI system, report its initial state,
    # run the search, then report the final architecture sizes.
    if not PYG_AVAILABLE:
        print("Exiting: PyTorch Geometric required.")
        exit()

    # Fixed seeds for both RNGs used by the search.
    torch.manual_seed(42)
    np.random.seed(42)
    print(f"Using device: {DEVICE}")

    try:
        deiti_system = DEITI()

        target_param_count = sum(p.numel() for p in deiti_system.target_cnn.parameters() if p.requires_grad)
        print(f"Init TargetCNN P: {target_param_count:,}")
        meta_param_count = sum(p.numel() for p in deiti_system.meta_agent.parameters())
        print(f"Init MetaAgentGNN: L={deiti_system.meta_agent.current_num_gnn_layers},H={deiti_system.meta_agent.current_gnn_hidden_dim},P={meta_param_count:,}")

        print("\n--- Initial Validation ---")
        initial_loss, initial_acc = deiti_system._validate_target(deiti_system.val_loader)
        print(f"InitVal: L={initial_loss:.4f},A={initial_acc:.4f}")

        deiti_system.train_loop(iterations=200)

        print("\n===== Search Finished =====")
        # Re-read the models from the system: the search may have replaced them.
        target_param_count = sum(p.numel() for p in deiti_system.target_cnn.parameters() if p.requires_grad)
        print(f"Final TargetCNN P: {target_param_count:,}")
        meta_param_count = sum(p.numel() for p in deiti_system.meta_agent.parameters())
        print(f"Final MetaAgentGNN: L={deiti_system.meta_agent.current_num_gnn_layers},H={deiti_system.meta_agent.current_gnn_hidden_dim},P={meta_param_count:,}")
    except Exception as e:
        # Report the failure, then re-raise so the run stops.
        print(f"\n\nFATAL RUNTIME ERROR: {e}")
        print("Stopping execution to release Colab resources.")
        raise e

Enabling TensorFloat32 matmul precision for supported GPU.
Using device: cuda
Using 4 workers for data loading.
Init TargetCNN P: 24,122
Init MetaAgentGNN: L=2,H=32,P=3,181

--- Initial Validation ---
InitVal: L=2.3045,A=0.0906

===== Iteration 1/200 =====
Current MetaAgentGNN: GNN Layers=2, GNN Hidden=32, Params: 3,181
PreVal: ValidL=1.9505, ValidA=0.3002
TargEdit:Typ=2,Stg=1 Op=0,RszF=4.0 | MetaSelfEdit:Typ=0
  TargetCNN arch changed. New Params: 93,242
  PostE1: TrainL=2.1243, TrainA=0.1776
  PostE2: TrainL=2.0940, TrainA=0.1876
  PostE3: TrainL=2.0640, TrainA=0.1922
  PostE4: TrainL=2.0444, TrainA=0.2163
  PostE5: TrainL=2.0182, TrainA=0.2299
  PostE6: TrainL=2.0133, TrainA=0.2290
  PostE7: TrainL=1.9878, TrainA=0.2257
  PostE8: TrainL=1.9873, TrainA=0.2337
  PostE9: TrainL=1.9726, TrainA=0.2375
  PostE10: TrainL=1.9592, TrainA=0.2637
PostVal: ValidL=1.5349, ValidA=0.4842
Reward:18.40
MetaL:250.3989(A:87.5769,C:326.2265,E:5.8258)

===== Iteration 2/200 =====
Current MetaAgentGNN: G