<a href="https://colab.research.google.com/github/PlushyWushy/Prometheus/blob/main/Prometheus_Variation_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip uninstall sam
!pip install torch-optimizer
!pip install sam-pytorch --upgrade
!pip install torch_optim
# =======================================================================
#  FUNCTION & CLASS DEFINITIONS  (***FULL LIST, NO OMISSIONS***)
#  –  now includes Net2THINNER / Net2SHALLOW, aggressive anti‑overfit
#    augments, dropout, label‑smoothing, AdamW‑cosine LR, inf‑guard, etc.
# =======================================================================
!pip install torch-optimizer
import time
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_optimizer as optim_extra
from torch.optim.lr_scheduler import OneCycleLR
# ──────────────────────────────────────────────────────────────────────
#  build SAM‑AdamW  +  “mini‑cosine‑restart” learning‑rate schedule
# ──────────────────────────────────────────────────────────────────────
import math
from torch.optim.lr_scheduler import LambdaLR
# if you installed sam-pytorch
# At the top of your script
from sam.sam import SAM
import json
# import torch.optim as optim # Already imported
# DO NOT import torch_optimizer anymore if you uninstalled it
# or if you don't intend to use other optimizers from it.


from torch.utils.data import DataLoader, random_split
from torch.distributions import Categorical
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
# Module-level hyper-parameters; the first four are the default arguments of
# DEITI.__init__.
PRE_EPOCHS        = 1      # epochs before each edit
POST_EPOCHS       = 1      # epochs after each edit
BATCHES_PER_EPOCH = None  # None means full epoch
BATCH_SIZE        = 128     # training batch size (halved for memory)
SELF_EDIT_INT     = 10      # how often meta-agent can self-edit
MAX_HISTORY_LEN   = 20     # max timesteps for state history
STATE_DIM         = 6      # dimension of per-step state vector
# NOTE(review): originally labelled "meta-agent learning rate", but in the
# visible code it is the base LR handed to the target model's AdamW in
# DEITI._new_opt_sched — confirm whether the meta-agent optimizer shares it.
LEARNING_RATE     = 5e-4
MAX_GRAD_NORM     = 1.0 # maximum gradient norm for clipping


def mixup_cutmix_collate(batch, alpha=1.0, cutmix_prob=0.5, num_classes=10):
    """
    DataLoader collate function that blends each batch with MixUp or CutMix.

    Args:
        batch: list of (image_tensor, label_int) samples.
        alpha: both shape parameters of the Beta distribution the mixing
            coefficient is drawn from.
        cutmix_prob: probability of applying CutMix; MixUp is used otherwise.
        num_classes: width of the one-hot label vectors.

    Returns:
        (images, soft_labels): the stacked, mixed images and the matching
        blend of one-hot labels with shape [len(batch), num_classes].
    """
    stacked = torch.stack([sample[0] for sample in batch], dim=0)
    dev = stacked.device
    class_ids = torch.tensor([sample[1] for sample in batch], device=dev)

    # one-hot targets, one row per sample
    targets = torch.zeros(len(class_ids), num_classes, device=dev)
    targets.scatter_(1, class_ids.unsqueeze(1), 1.0)

    # Random draws, in order: Beta coefficient, partner permutation,
    # CutMix/MixUp coin flip, then (CutMix only) the box centre.
    mix = np.random.beta(alpha, alpha)
    partner = torch.randperm(len(stacked), device=dev)
    use_cutmix = np.random.rand() < cutmix_prob

    if not use_cutmix:
        # MixUp: convex combination of each image with its partner
        stacked = mix * stacked + (1 - mix) * stacked[partner]
    else:
        # CutMix: paste a rectangle taken from the partner image
        _, _, height, width = stacked.shape
        centre_x = np.random.randint(width)
        centre_y = np.random.randint(height)
        side = np.sqrt(1 - mix)
        half_w = int(width * side) // 2
        half_h = int(height * side) // 2

        left = max(centre_x - half_w, 0)
        top = max(centre_y - half_h, 0)
        right = min(centre_x + half_w, width)
        bottom = min(centre_y + half_h, height)

        stacked[:, :, top:bottom, left:right] = stacked[partner, :, top:bottom, left:right]
        # the box is clipped at the borders, so recompute the coefficient
        # from the area that was actually replaced
        pasted_area = (right - left) * (bottom - top)
        mix = 1 - (pasted_area / float(height * width))

    # labels are blended with the same coefficient as the pixels
    targets = mix * targets + (1 - mix) * targets[partner]
    return stacked, targets
# ------ utils.py (or top of file) ------
def mixup_cutmix(batch, alpha=1.0, cutmix_prob=0.5):
    """
    Apply MixUp or CutMix to an already-collated batch.

    Neither input tensor is modified: the CutMix branch pastes into a copy,
    so both branches have the same (side-effect free) contract.

    Args:
        batch: (x, y) pair. x is unpacked as a 4-D image tensor in the CutMix
            branch; y is blended arithmetically, so it is presumably a float
            soft/one-hot label tensor — TODO confirm against callers.
        alpha: both shape parameters of the Beta distribution the mixing
            coefficient is drawn from.
        cutmix_prob: probability of applying CutMix; MixUp is used otherwise.

    Returns:
        (mixed_x, mixed_y) as new tensors.
    """
    x, y = batch
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(x.size(0))
    if np.random.rand() < cutmix_prob:
        # CutMix
        _, _, H, W = x.size()
        cx, cy = int(np.random.randint(W)), int(np.random.randint(H))
        rw = int(W * np.sqrt(1 - lam))
        rh = int(H * np.sqrt(1 - lam))
        # plain Python ints keep numpy scalar types out of the slices and lam
        x1, y1 = max(cx - rw // 2, 0), max(cy - rh // 2, 0)
        x2, y2 = min(cx + rw // 2, W), min(cy + rh // 2, H)
        mixed = x.clone()  # never write into the caller's tensor
        mixed[:, :, y1:y2, x1:x2] = x[idx, :, y1:y2, x1:x2]
        x = mixed
        # the box is clipped at the borders: use the true pasted-area ratio
        lam = 1 - (x2 - x1) * (y2 - y1) / (H * W)
    else:
        # MixUp
        x = lam * x + (1 - lam) * x[idx]
    y = lam * y + (1 - lam) * y[idx]
    return x, y


# ---------------------------
# Utility: Model Summary
# ---------------------------
def model_summary(model: nn.Module, name: str = "Model"):
    """Print a titled dump of `model` followed by its total parameter count."""
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()

    print(f"\n--- {name} ---")
    print(model)
    print(f"Total parameters: {n_params}")
    print("-----------------\n")

# ---------------------------
# Backup Model
# ---------------------------
def backup_model(model: nn.Module):
    """Return an independent deep copy of `model` (parameters and buffers included)."""
    snapshot = copy.deepcopy(model)
    return snapshot

# ---------------------------
# Net2Net utilities  (grow **and** shrink)
# ---------------------------
# ---------------------------
# SafeMaxPool2d — gracefully skips if spatial dims < kernel
# ---------------------------
class SafeMaxPool2d(nn.Module):
    """
    Max-pooling layer that degrades to the identity when the input is smaller
    than the pooling window, instead of raising.

    Args:
        kernel_size: int or (height, width) pair, forwarded to nn.MaxPool2d.
        stride: forwarded to nn.MaxPool2d.
    """
    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride      = stride
        self.pool        = nn.MaxPool2d(kernel_size, stride)

    def forward(self, x):
        k = self.kernel_size
        # accept both the int and the (kh, kw) forms that nn.MaxPool2d accepts;
        # comparing a size with a tuple directly would raise TypeError
        k_h, k_w = (k, k) if isinstance(k, int) else tuple(k)
        if x.size(-2) < k_h or x.size(-1) < k_w:
            return x                      # too small → identity
        return self.pool(x)

def net2wider_conv(conv: nn.Conv2d, factor: float) -> nn.Conv2d:
    """
    Net2Net widen with **fractional** factor support.

    new_out = max(1, int(round(old_out * factor))). Output filter i of the new
    layer is a copy of old filter (i % old_out): the original filters come
    first and the extra ones cycle through them again. A factor below 1
    therefore just keeps the first new_out filters.

    The new layer keeps the source layer's kernel size, stride, padding,
    dilation, groups, padding mode, bias setting, device and dtype.

    Args:
        conv: layer to widen.
        factor: multiplier applied to conv.out_channels.

    Returns:
        A new nn.Conv2d; `conv` itself is left untouched.

    Raises:
        ValueError: (from nn.Conv2d) if the new width is not divisible by
            conv.groups.
    """
    old_oc = conv.out_channels
    new_oc = max(1, int(round(old_oc * factor)))
    new = nn.Conv2d(conv.in_channels, new_oc, conv.kernel_size,
                    stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups,
                    bias=(conv.bias is not None),
                    padding_mode=conv.padding_mode)
    new = new.to(device=conv.weight.device, dtype=conv.weight.dtype)
    # source filter for every new output channel: 0..oc-1, then wrap around
    src = torch.arange(new_oc, device=conv.weight.device) % old_oc
    with torch.no_grad():
        new.weight.copy_(conv.weight[src])
        if conv.bias is not None:
            new.bias.copy_(conv.bias[src])
    return new

def net2thinner_conv(conv: nn.Conv2d, factor: float = 0.5) -> nn.Conv2d:
    """
    Net2THINNER: shrink out_channels by `factor`.

    Keeps the first max(1, round(factor · oc)) filters and prunes the rest.
    The result is capped at the original width, so a factor above 1 returns
    an unchanged-width copy rather than a layer whose out_channels disagrees
    with its weight shape.

    The new layer keeps the source layer's kernel size, stride, padding,
    dilation, groups, padding mode, bias setting, device and dtype.

    Args:
        conv: layer to thin.
        factor: fraction of output channels to keep.

    Returns:
        A new nn.Conv2d; `conv` itself is left untouched.

    Raises:
        ValueError: (from nn.Conv2d) if the new width is not divisible by
            conv.groups.
    """
    old_oc = conv.out_channels
    new_oc = min(old_oc, max(1, int(round(old_oc * factor))))
    new = nn.Conv2d(conv.in_channels, new_oc, conv.kernel_size,
                    stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation, groups=conv.groups,
                    bias=(conv.bias is not None),
                    padding_mode=conv.padding_mode)
    new = new.to(device=conv.weight.device, dtype=conv.weight.dtype)
    with torch.no_grad():
        # copy in place so the registered Parameters are kept
        new.weight.copy_(conv.weight[:new_oc])
        if conv.bias is not None:
            new.bias.copy_(conv.bias[:new_oc])
    return new

def net2wider_conv_in(conv: nn.Conv2d, new_in: int) -> nn.Conv2d:
    """
    Rebuild `conv` so that it accepts `new_in` input channels.

    Input channel i of the new layer takes the weights of old input channel
    (i % old_in), which matches the wrap-around filter order produced by
    net2wider_conv. The copied weights are not rescaled. Biases are copied
    unchanged.

    Kernel size, stride, padding, dilation, padding mode, bias setting,
    device and dtype follow the source layer. Grouped convolutions are not
    supported (the new layer is built with groups=1, as before).

    Args:
        conv: layer whose input side is resized.
        new_in: number of input channels of the returned layer.

    Returns:
        A new nn.Conv2d; `conv` itself is left untouched.
    """
    new = nn.Conv2d(new_in, conv.out_channels, conv.kernel_size,
                    stride=conv.stride, padding=conv.padding,
                    dilation=conv.dilation,
                    bias=(conv.bias is not None),
                    padding_mode=conv.padding_mode)
    new = new.to(device=conv.weight.device, dtype=conv.weight.dtype)
    # one gather along the input-channel axis replaces the per-element loops
    src = torch.arange(new_in, device=conv.weight.device) % conv.in_channels
    with torch.no_grad():
        new.weight.copy_(conv.weight[:, src])
        if conv.bias is not None:
            new.bias.copy_(conv.bias)
    return new
def net2deeper_conv(conv: nn.Conv2d) -> nn.Conv2d:
    """
    Build an identity-initialised 3×3 convolution to insert after `conv`.

    The new layer maps conv.out_channels → conv.out_channels with padding 1,
    Dirac (identity) weights and a zero bias, and is created on the same
    device and with the same dtype as `conv`.

    Args:
        conv: layer the new convolution will follow; only its width, device
            and dtype are read.

    Returns:
        The new nn.Conv2d.
    """
    width = conv.out_channels
    new = nn.Conv2d(width, width, 3, padding=1)
    new = new.to(device=conv.weight.device, dtype=conv.weight.dtype)
    nn.init.dirac_(new.weight)
    # the layer is always built with a bias, so no None check is needed
    nn.init.constant_(new.bias, 0)
    return new

def net2shallower_conv(seq: nn.Sequential, first_idx: int) -> bool:
    """
    Net2SHALLOW: remove the convolution that directly follows seq[first_idx].

    Despite the "fused" wording in the log message, no weights are merged:
    the layer at first_idx + 1 is simply dropped when it is an nn.Conv2d,
    together with an nn.BatchNorm2d that immediately follows it. The layer at
    first_idx is left unchanged. `seq` is modified in place.

    Returns:
        True if a convolution was removed, False if there was nothing to do.
    """
    target = first_idx + 1
    next_is_conv = target < len(seq) and isinstance(seq[target], nn.Conv2d)
    if not next_is_conv:
        return False

    dropped = seq.pop(target)
    # after the pop, a trailing BN (if any) has moved into the same slot
    if len(seq) > target and isinstance(seq[target], nn.BatchNorm2d):
        seq.pop(target)
    print(f"  fused & removed Conv(out={dropped.out_channels}) for Net2SHALLOW")
    return True

def net2wider_lstm(lstm: nn.LSTM, factor: int = 2) -> nn.LSTM:
    """
    Net2Net widening of an LSTM: multiply hidden_size by `factor` while
    preserving the function the network computes.

    Every hidden unit is replicated `factor` times (rows of all gate matrices
    and biases are repeated). Wherever a layer's *input* is itself a widened
    hidden state — the recurrent weights of every layer and the input weights
    of layers above the first — the columns are repeated too and divided by
    `factor`, so each replicated group sums to the original contribution.
    The first layer's input weights are NOT rescaled, because the external
    input is not replicated. As a result the new output equals the old output
    with each feature repeated `factor` times.

    The state-dict keys read here assume a unidirectional LSTM with biases
    and no projection. The returned LSTM is built with batch_first=True and
    placed on the source LSTM's device and dtype.

    Args:
        lstm: LSTM to widen (left untouched).
        factor: integer replication factor for the hidden size.

    Returns:
        The widened nn.LSTM.
    """
    inp, hid, nl = lstm.input_size, lstm.hidden_size, lstm.num_layers
    old_sd = lstm.state_dict()
    ref = old_sd['weight_ih_l0']
    new_lstm = nn.LSTM(inp, hid * factor, nl, batch_first=True)
    new_lstm = new_lstm.to(device=ref.device, dtype=ref.dtype)

    new_sd = new_lstm.state_dict()
    for L in range(nl):
        w_ih = old_sd[f'weight_ih_l{L}'].repeat_interleave(factor, 0)
        if L > 0:
            # input is the widened hidden state of the layer below
            w_ih = w_ih.repeat_interleave(factor, 1) / factor
        new_sd[f'weight_ih_l{L}'] = w_ih.clone()

        w_hh = (old_sd[f'weight_hh_l{L}']
                .repeat_interleave(factor, 0)
                .repeat_interleave(factor, 1))
        new_sd[f'weight_hh_l{L}'] = (w_hh / factor).clone()

        for bias_key in (f'bias_ih_l{L}', f'bias_hh_l{L}'):
            new_sd[bias_key] = old_sd[bias_key].repeat_interleave(factor).clone()
    new_lstm.load_state_dict(new_sd, strict=True)
    return new_lstm

def net2deeper_lstm(lstm: nn.LSTM) -> nn.LSTM:
    """
    Return a copy of `lstm` with one extra layer stacked on top.

    The weights and biases of the existing layers are copied over; the added
    top layer keeps nn.LSTM's default initialisation. The new LSTM is built
    with batch_first=True.
    """
    old_layers = lstm.num_layers
    deeper = nn.LSTM(lstm.input_size, lstm.hidden_size, old_layers + 1,
                     batch_first=True)
    source = lstm.state_dict()
    merged = deeper.state_dict()

    prefixes = ('weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l')
    shared_keys = [f'{prefix}{layer}'
                   for layer in range(old_layers)
                   for prefix in prefixes]
    for key in shared_keys:
        merged[key].copy_(source[key])

    deeper.load_state_dict(merged)
    return deeper

def update_meta_agent_heads(agent: nn.Module, hidden_size: int,
                            edits: int, locs: int, meta: int):
    """
    Replace the four output heads of `agent` with freshly initialised linear
    layers that read a `hidden_size`-wide feature vector.

    head_e, head_l and head_m get `edits`, `locs` and `meta` outputs; head_v
    gets a single output. The new heads are placed on the device of the
    agent's first parameter.
    """
    device = next(agent.parameters()).device
    head_sizes = (
        ('head_e', edits),
        ('head_l', locs),
        ('head_m', meta),
        ('head_v', 1),
    )
    for attr, out_features in head_sizes:
        setattr(agent, attr, nn.Linear(hidden_size, out_features).to(device))



# ──────────────────────────────────────────────────────────────────────────────
#  1) BottleneckBlock (ResNet-style)
# ──────────────────────────────────────────────────────────────────────────────
class BottleneckBlock(nn.Module):
    """
    ResNet-style bottleneck: 1×1 reduce → 3×3 → 1×1 expand, each followed by
    a single-group GroupNorm, plus a shortcut connection.

    The inner width is max(1, out_ch // expansion), so it never reaches zero
    for very narrow blocks. The shortcut is the identity unless the stride or
    the channel count changes, in which case it is a strided 1×1 convolution
    followed by GroupNorm.
    """
    def __init__(self, in_ch, out_ch, stride=1, expansion=4):
        super().__init__()
        inner = out_ch // expansion
        if inner < 1:
            inner = 1                              # keep at least one channel

        # 1×1 reduce
        self.conv1 = nn.Conv2d(in_ch, inner, 1, bias=False)
        self.bn1   = nn.GroupNorm(1, inner)
        # 3×3, carries the stride
        self.conv2 = nn.Conv2d(inner, inner, 3, stride, padding=1, bias=False)
        self.bn2   = nn.GroupNorm(1, inner)
        # 1×1 expand
        self.conv3 = nn.Conv2d(inner, out_ch, 1, bias=False)
        self.bn3   = nn.GroupNorm(1, out_ch)

        self.relu  = nn.ReLU(inplace=True)

        shape_changes = stride != 1 or in_ch != out_ch
        if shape_changes:
            self.down = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.GroupNorm(1, out_ch),
            )
        else:
            self.down = nn.Identity()

    def forward(self, x):
        shortcut = self.down(x)

        y = self.conv1(x)
        y = self.relu(self.bn1(y))

        y = self.conv2(y)
        y = self.relu(self.bn2(y))

        y = self.bn3(self.conv3(y))
        return self.relu(y + shortcut)



# ──────────────────────────────────────────────────────────────────────────────
#  2) SEBlock (Squeeze-and-Excitation)
# ──────────────────────────────────────────────────────────────────────────────
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation channel gate.

    Global-average-pools each channel, passes the result through a two-layer
    bias-free MLP ending in a sigmoid, and rescales the input channels with
    the resulting weights. The MLP's hidden width is
    max(1, channels // reduction), so any channel count ≥ 1 works.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        if hidden < 1:
            hidden = 1                              # never squeeze to zero units
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        batch, chans, _h, _w = x.size()
        squeezed = self.avgpool(x).view(batch, chans)
        gate = self.fc(squeezed)
        return x * gate.view(batch, chans, 1, 1)



# ──────────────────────────────────────────────────────────────────────────────
#  3) MBConv (MobileNetV2 inverted residual)
# ──────────────────────────────────────────────────────────────────────────────
class MBConv(nn.Module):
    """
    MobileNetV2-style inverted residual block.

    A 1×1 convolution expands to in_ch * exp_ratio channels, a 3×3 depthwise
    convolution (carrying the stride) follows, and a 1×1 convolution projects
    to out_ch. Each convolution is bias-free and followed by BatchNorm; the
    first two are also followed by ReLU. The input is added back only when
    stride == 1 and in_ch == out_ch.
    """
    def __init__(self, in_ch, out_ch, stride=1, exp_ratio=6):
        super().__init__()
        expanded = in_ch * exp_ratio

        expand = [
            nn.Conv2d(in_ch, expanded, 1, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True),
        ]
        depthwise = [
            nn.Conv2d(expanded, expanded, 3, stride, padding=1,
                      groups=expanded, bias=False),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True),
        ]
        project = [
            nn.Conv2d(expanded, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.block = nn.Sequential(*expand, *depthwise, *project)

        self.use_res_connect = (stride == 1 and in_ch == out_ch)
        self.out_channels = out_ch

    def forward(self, x):
        y = self.block(x)
        if self.use_res_connect:
            return y + x
        return y


# ──────────────────────────────────────────────────────────────────────────────
#  4) InceptionBlock (multi-branch)
# ──────────────────────────────────────────────────────────────────────────────
class InceptionBlock(nn.Module):
    """
    Four-branch Inception-style block whose outputs are concatenated along
    the channel axis.

    Branch widths are derived from the input width `c` and clamped to at
    least one channel:
      • branch1: 1×1 conv                → c // 4
      • branch2: 1×1 conv → 3×3 conv     → c // 2   (1×1 stage is c // 4 wide)
      • branch3: 1×1 conv → 5×5 conv     → c // 4   (1×1 stage is half of that)
      • branch4: 3×3 max-pool → 1×1 conv → c // 4
    `out_channels` stores the concatenated width.
    """
    def __init__(self, c: int):
        super().__init__()
        quarter = max(1, c // 4)
        half = max(1, c // 2)
        # the 1×1 stage of the 5×5 branch is half a quarter, but at least 1
        reduce_5x5 = max(1, quarter // 2)

        self.branch1 = nn.Conv2d(c, quarter, kernel_size=1)
        self.branch2 = nn.Sequential(
            nn.Conv2d(c, quarter, kernel_size=1),
            nn.Conv2d(quarter, half, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(c, reduce_5x5, kernel_size=1),
            nn.Conv2d(reduce_5x5, quarter, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(c, quarter, kernel_size=1),
        )

        self.in_channels = c
        self.out_channels = quarter + half + quarter + quarter

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        outputs = [branch(x) for branch in branches]
        return torch.cat(outputs, dim=1)



# ──────────────────────────────────────────────────────────────────────────────
#  5) MetaAgent — first definition, kept commented out (edits default = 12)
#  (Note: the original code defined MetaAgent twice. Only the second
#   definition, which appears after TargetCNN, is active; this earlier copy
#   is retained below for reference.)
# ──────────────────────────────────────────────────────────────────────────────
# class MetaAgent(nn.Module): # First definition, commented out if the second is preferred
#     def __init__(self, state_dim=STATE_DIM, hidden_dim=32, num_layers=1,
#                  edits=12, locs=3, meta=3):
#         super().__init__()
#         self.lstm    = nn.LSTM(state_dim, hidden_dim, num_layers, batch_first=True)
#         self.head_e  = nn.Linear(hidden_dim, edits)
#         self.head_l  = nn.Linear(hidden_dim, locs)
#         self.head_m  = nn.Linear(hidden_dim, meta)
#         self.head_v  = nn.Linear(hidden_dim, 1)

#     def forward(self, seq: torch.Tensor):
#         out, _ = self.lstm(seq)
#         h      = out[:, -1, :]
#         return self.head_e(h), self.head_l(h), self.head_m(h), self.head_v(h)


# ---------------------------
# Residual Block  (now with dropout)
# ---------------------------
class ResidualBlock(nn.Module):
    """
    Two 3×3 conv + BatchNorm layers with a skip connection and dropout.

    `drop` is applied to the activations between the two convolutions.
    `drop_path` is an ordinary element-wise nn.Dropout applied to the skip
    input `x` before the addition (it is not per-sample stochastic depth).
    The second BatchNorm is zero-initialised (weight and bias).

    Args:
        c: channel count (input and output).
        p_drop: dropout probability used for both dropout layers.
    """
    def __init__(self, c: int, p_drop: float = 0.1):
        super().__init__()
        self.conv1     = nn.Conv2d(c, c, 3, padding=1)
        self.bn1       = nn.BatchNorm2d(c)
        self.relu      = nn.ReLU()
        self.conv2     = nn.Conv2d(c, c, 3, padding=1)
        self.bn2       = nn.BatchNorm2d(c)
        self.drop      = nn.Dropout(p_drop)
        self.drop_path = nn.Dropout(p_drop)
        # zero-init the last norm layer's affine parameters
        for affine in (self.bn2.weight, self.bn2.bias):
            nn.init.constant_(affine, 0)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.drop(branch)
        branch = self.bn2(self.conv2(branch))
        skip = self.drop_path(x)
        return self.relu(branch + skip)



# ---------------------------
# Target CNN  (added dropout before fc)
# ---------------------------
class TargetCNN(nn.Module):
    """
    Three-stage CNN classifier.

    Each stage is Conv3×3 → BatchNorm → ReLU → pooling, with widths
    base, 2·base and 4·base. The first two stages end in SafeMaxPool2d(2),
    the last in global average pooling. The flattened features pass through
    dropout (p = 0.25) and a linear classifier.

    Args:
        base: width of the first stage.
        num_classes: number of output logits.
    """
    def __init__(self, base: int = 32, num_classes: int = 10):
        super().__init__()
        stage_widths = [base, base*2, base*4]
        self.init_widths = stage_widths

        stage_inputs = [3] + stage_widths[:-1]
        last = len(stage_widths) - 1
        built = []
        for pos, (c_in, c_out) in enumerate(zip(stage_inputs, stage_widths)):
            conv = nn.Conv2d(c_in, c_out, 3, padding=1)
            norm = nn.BatchNorm2d(c_out)
            act = nn.ReLU()
            # only the final stage collapses the spatial dims completely
            pool = nn.AdaptiveAvgPool2d((1, 1)) if pos == last else SafeMaxPool2d(2)
            built.append(nn.Sequential(conv, norm, act, pool))
        self.stages = nn.ModuleList(built)

        self.pre_fc_drop = nn.Dropout(0.25)
        self.fc          = nn.Linear(base*4, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = x
        for stage in self.stages:
            feats = stage(feats)
        flat = feats.view(feats.size(0), -1)
        return self.fc(self.pre_fc_drop(flat))

    def widths(self):
        """Return one width per stage.

        For each stage this is the out_channels of the first direct-child
        Conv2d. If a stage has no direct-child Conv2d, the channel count of
        its first BatchNorm2d/GroupNorm child is used instead, and 0 if it
        has neither."""
        found = []
        for stage in self.stages:
            width = 0                      # fallback when nothing matches
            for layer in stage:
                if isinstance(layer, nn.Conv2d):
                    width = layer.out_channels
                    break
            else:
                # no conv among the direct children: fall back to the first norm
                for layer in stage:
                    if isinstance(layer, nn.BatchNorm2d):
                        width = layer.num_features
                        break
                    if isinstance(layer, nn.GroupNorm):
                        width = layer.num_channels
                        break
            found.append(width)
        return found

# ---------------------------
# Meta-Agent (This is the second definition, likely the one intended to be active)
# ---------------------------
class MetaAgent(nn.Module):
    """
    LSTM controller with four linear heads.

    The LSTM (batch_first) encodes a sequence of state vectors; the features
    of the last time step feed the edit, location, meta and value heads.

    Args:
        state_dim: size of each per-step state vector.
        hidden_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        edits: output size of head_e.
        locs: output size of head_l.
        meta: output size of head_m.
    """
    def __init__(self, state_dim=STATE_DIM, hidden_dim=32, num_layers=1,
                 edits=12, locs=3, meta=3):
        super().__init__()
        self.lstm = nn.LSTM(state_dim, hidden_dim, num_layers, batch_first=True)
        self.head_e = nn.Linear(hidden_dim, edits)
        self.head_l = nn.Linear(hidden_dim, locs)
        self.head_m = nn.Linear(hidden_dim, meta)
        self.head_v = nn.Linear(hidden_dim, 1)   # scalar value output

    def forward(self, seq: torch.Tensor):
        """Return (edit, location, meta, value) head outputs for the last step of `seq`."""
        features, _state = self.lstm(seq)
        last_step = features[:, -1, :]
        edit_out = self.head_e(last_step)
        loc_out = self.head_l(last_step)
        meta_out = self.head_m(last_step)
        value_out = self.head_v(last_step)
        return edit_out, loc_out, meta_out, value_out

# ---------------------------
# Smooth Cross‑Entropy (label‑smoothing)
# ---------------------------
class SmoothCE(nn.Module):
    """
    Cross-entropy against soft (probability-vector) targets.

    `eps` is stored on the module but is not applied in forward(); any
    smoothing has to be present in the targets already.
    """
    def __init__(self, eps: float = 0.0):
        super().__init__()
        self.eps = eps  # kept for reference only; unused by forward()

    def forward(self, logits, target_soft):
        """
        Args:
            logits: [B, num_classes] raw scores.
            target_soft: [B, num_classes] floats summing to 1 per row.

        Returns:
            Scalar mean over the batch of -sum(target * log_softmax(logits)).
        """
        log_probs = logits.log_softmax(dim=1)
        per_sample = -(target_soft * log_probs).sum(dim=1)
        return per_sample.mean()





# ---------------------------
# DEITI Trainer  (FULL CLASS ‑‑ everything below is complete)
# ---------------------------
class DEITI:
    def __init__(self,
                 pre_epochs=PRE_EPOCHS,
                 post_epochs=POST_EPOCHS,
                 batches_per_epoch=BATCHES_PER_EPOCH,
                 batch_size=BATCH_SIZE):
        """
        Build the CIFAR-10 data loaders, the target CNN, the meta-agent and
        the training hyper-parameters.

        The CIFAR-10 training set is split 80/20 into train/validation.
        The train loader applies the augmentation pipeline plus MixUp/CutMix
        collation (so it yields soft labels); the validation and test loaders
        use only ToTensor + Normalize and yield integer labels.

        Args:
            pre_epochs: epochs trained before each edit.
            post_epochs: epochs trained after each edit.
            batches_per_epoch: cap on batches per epoch; None means full epoch.
            batch_size: batch size used by all three loaders.
        """
        # Function-scope imports: these names are not among the file-level imports.
        from functools import partial
        from torch.utils.data import Subset

        # aggressive data augmentation (training split only)
        tf_train = transforms.Compose([
            transforms.RandomResizedCrop(
                32, scale=(0.7, 1.0), ratio=(0.9, 1.1)          # tighter aspect jitter
            ),
            transforms.RandomRotation(15),                      # ±15° rotation
            transforms.RandomHorizontalFlip(0.5),               # 50 % chance
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),         # (B, C, S, H)
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            transforms.Normalize(                               # CIFAR-10 statistics
                (0.4914, 0.4822, 0.4465),
                (0.2470, 0.2435, 0.2616)
            ),
            transforms.RandomErasing(                           # weaker p, wider ratio
                p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3),
                value='random'
            ),
        ])

        # deterministic preprocessing for validation and test
        tf_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465),
                (0.2470, 0.2435, 0.2616)
            ),
        ])

        # ---------------- data ----------------
        full = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=tf_train)
        val_n = len(full) // 5
        tr_ds, val_split = random_split(full, [len(full) - val_n, val_n])
        # The validation subset must not inherit the training augmentations:
        # re-index the same held-out samples through a second view of the
        # training set that uses the evaluation transform.
        full_eval = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=tf_test)
        val_ds = Subset(full_eval, val_split.indices)

        self.trl = DataLoader(
            tr_ds,
            batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            # functools.partial (unlike a lambda) can be pickled, which
            # worker processes started with "spawn" require
            collate_fn=partial(
                mixup_cutmix_collate, alpha=1.0, cutmix_prob=0.5, num_classes=10
            ),
        )

        self.vall = DataLoader(val_ds, batch_size, shuffle=False,
                               num_workers=4, pin_memory=True)
        test_ds   = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=tf_test)
        self.tsl  = DataLoader(test_ds, batch_size, shuffle=False,
                               num_workers=4, pin_memory=True)

        # ---------------- misc ----------------
        self.dev  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.backends.cudnn.benchmark = True

        self.pre_epochs        = pre_epochs
        self.post_epochs       = post_epochs
        self.batches_per_epoch = batches_per_epoch
        self.batch_size        = batch_size # Storing batch_size
        self.amp_enabled       = True

        # models
        self.target = TargetCNN().to(self.dev)
        if self.batch_size == 1:
            self._swap_bn_to_gn(self.target)
        init_locs   = len(self.target.stages)
        # the meta-agent chooses among 12 edit types, one location per
        # target stage, and 3 meta actions
        self.meta = MetaAgent(edits=12, locs=init_locs, meta=3).to(self.dev)
        update_meta_agent_heads(
            self.meta,
            hidden_size=self.meta.lstm.hidden_size,
            edits=12,          # keep in sync with the MetaAgent call above
            locs=init_locs,
            meta=3
        )

        # criterion & hyper-params
        self.crit          = SmoothCE(0.1)
        self.gamma         = 0.99
        self.ent_coef      = 0.01
        # self.alpha         = 2e-5 # These seem unused, commenting out
        # self.beta          = 5e-4
        self.self_edit_int = SELF_EDIT_INT
        self.init_params   = sum(p.numel() for p in self.target.parameters())

        # AMP scalers
        self.scaler_tgt  = GradScaler()
        self.scaler_meta = GradScaler()


    # ---------------------------
    # opt & lr‑schedule
    # ---------------------------






    def _new_opt_sched(self, curr_iter=None, total_iters=None):
        """
        Create a fresh AdamW optimizer for the target model together with a
        cyclic, two-phase linear LR schedule.

        The schedule is a LambdaLR multiplier on LEARNING_RATE, evaluated on
        the scheduler's step count. One cycle lasts
        (pre_epochs + post_epochs) × steps-per-epoch scheduler steps and then
        repeats:
          • PRE-edit  phase: multiplier decays linearly 0.11 → 0.07
          • POST-edit phase: multiplier decays linearly 0.15 → 0.11
        With LEARNING_RATE = 5e-4 that is 5.5e-5 → 3.5e-5, then
        7.5e-5 → 5.5e-5. A phase with zero steps is skipped; if both phases
        are empty the multiplier is constant at the POST-edit end value.

        Args:
            curr_iter: unused; kept so existing call sites keep working.
            total_iters: unused; kept so existing call sites keep working.

        Returns:
            (optimizer, scheduler) tuple.
        """
        # 1) base AdamW on the target model
        opt = optim.AdamW(
            self.target.parameters(),
            lr=LEARNING_RATE,
            weight_decay=1e-4
        )

        # 2) scheduler steps per phase; epoch counts are expected to be >= 0
        steps_pepoch = len(self.trl) if self.batches_per_epoch is None else self.batches_per_epoch
        pre_steps    = self.pre_epochs  * steps_pepoch
        post_steps   = self.post_epochs * steps_pepoch
        # at least 1 so the modulo below is always defined
        cycle_len    = max(1, pre_steps + post_steps)

        # 3) LR multipliers relative to LEARNING_RATE
        PRE_START_FACTOR  = 0.11
        PRE_END_FACTOR    = 0.07
        POST_START_FACTOR = 0.15
        POST_END_FACTOR   = 0.11

        def lr_lambda(step: int) -> float:
            # position inside the current PRE+POST cycle
            local = step % cycle_len

            if local < pre_steps:
                # PRE-edit linear decay (pre_steps > 0 is implied here)
                t = local / pre_steps
                return PRE_START_FACTOR + t * (PRE_END_FACTOR - PRE_START_FACTOR)

            if post_steps == 0:
                # only reachable when both phases are empty
                return POST_END_FACTOR

            # POST-edit linear decay
            t = (local - pre_steps) / post_steps
            return POST_START_FACTOR + t * (POST_END_FACTOR - POST_START_FACTOR)

        sched = LambdaLR(opt, lr_lambda)
        return opt, sched

    # ------------------------------------------------------------------
    # helper: hard labels  ➜  one‑hot or soft
    # ------------------------------------------------------------------
    def _to_onehot(self, y: torch.Tensor, num_classes: int) -> torch.Tensor:
        """
        y: [B] int class‑ids → returns [B, num_classes] float32 one‑hot on same device
        """
        y_onehot = torch.zeros(y.size(0), num_classes, device=y.device, dtype=torch.float32)
        return y_onehot.scatter_(1, y.unsqueeze(1), 1.0)


    # ------------------------------------------------------------------
    # (replace the whole method)
    # validation loop  ––   now converts labels to one‑hot before loss
    # ------------------------------------------------------------------







    def _train_one_epoch(self, loader, epochs):
        """
        Train ``self.target`` on ``loader`` for ``epochs`` passes and print the
        mean loss / accuracy / current learning rate after every pass.

        Each batch must unpack to ``(x, y_soft)``; ``y_soft`` is handed to
        ``self.crit`` as‑is and needs a class dimension at index 1, because the
        accuracy metric compares ``logits.argmax(1)`` with ``y_soft.argmax(1)``.
        Returns nothing.

        Per batch: zero grads → forward under autocast → skip the batch if the
        loss is non‑finite → scaled backward → unscale → clip to
        ``MAX_GRAD_NORM`` → optimizer step through the GradScaler → scheduler
        step.

        The scheduler is only advanced when the optimizer really stepped.  The
        GradScaler skips ``optimizer.step()`` when it finds inf/NaN gradients
        and then lowers its scale in ``update()``, so a decreased scale marks a
        skipped step; stepping the scheduler anyway would consume schedule
        steps that have no matching weight update.
        """
        # one “epoch” here really just means “call this method N times”
        for ep in range(1, epochs + 1):
            tot_loss = 0.0
            tot_acc  = 0.0
            cnt      = 0
            self.target.train()

            for batch_idx, (x, y_soft) in enumerate(loader):
                x, y_soft = x.to(self.dev), y_soft.to(self.dev)

                # 1) zero grads
                self.tgt_opt.zero_grad(set_to_none=True)

                # 2) forward + loss in AMP
                with autocast(enabled=self.amp_enabled):
                    logits = self.target(x)
                    loss   = self.crit(logits, y_soft)

                # skip bad batches – nothing has been back‑propagated yet, so
                # the gradients cleared in step 1 are still empty
                if not torch.isfinite(loss):
                    print(f"Warning: non-finite loss at ep {ep}, batch {batch_idx}; skipping.")
                    continue

                # 3) backward + unscale (clipping must see the true gradients)
                self.scaler_tgt.scale(loss).backward()
                self.scaler_tgt.unscale_(self.tgt_opt)

                # 4) clip, step, update scaler
                clip_grad_norm_(self.target.parameters(), MAX_GRAD_NORM)
                scale_before = self.scaler_tgt.get_scale()
                # scaler_tgt.step skips the optimizer if non‑finite grads were found
                self.scaler_tgt.step(self.tgt_opt)
                self.scaler_tgt.update()  # adjusts the scale for the next iteration

                # 5) step LR scheduler – only if the optimizer step was not skipped
                if self.scaler_tgt.get_scale() >= scale_before:
                    self.tgt_sch.step()

                # 6) logging
                with torch.no_grad():
                    preds = logits.argmax(1)
                    true_labels_from_soft = y_soft.argmax(1)  # hard label = arg‑max of the soft target
                    acc   = (preds == true_labels_from_soft).float().mean().item()

                tot_loss += loss.item()
                tot_acc  += acc
                cnt      += 1

                # optional: limit number of (successfully processed) batches per epoch
                if self.batches_per_epoch is not None and cnt >= self.batches_per_epoch:
                    break

            # nothing to average if every batch was skipped or the loader was empty
            if cnt == 0:
                print(f"  Epoch {ep}/{epochs} --- No batches processed. Skipping logging for this epoch.")
                continue

            avg_loss = tot_loss / cnt
            avg_acc  = tot_acc  / cnt
            lr_now   = self.tgt_opt.param_groups[0]['lr']
            print(f"  Epoch {ep}/{epochs} — loss={avg_loss:.4f}, acc={avg_acc:.4f}, lr={lr_now:.6f}")


    # ---------------------------
    # helpers
    # ---------------------------
    def _get_stage_output_channels(self, seq: nn.Sequential) -> int:
        """
        Return the width that actually comes *out* of a stage.
        Prefer Conv‑like modules; only fall back to a norm layer
        if no Conv/Residual/Bottleneck/… is found.
        """
        # pass 1 – look for something that *defines* the width
        for layer in reversed(list(seq)): # Iterate over a copy if modification is possible, though reversed() creates an iterator
            if isinstance(layer, nn.Conv2d):
                return layer.out_channels
            if isinstance(layer, BottleneckBlock):
                return layer.conv3.out_channels
            if isinstance(layer, MBConv):
                return layer.out_channels
            if isinstance(layer, InceptionBlock):
                return layer.out_channels
            if isinstance(layer, ResidualBlock):
                # ResidualBlock maintains channels, so in_channels of conv1 is the block's width
                return layer.conv1.in_channels # Or out_channels of its last conv if different

        # pass 2 – as a last resort fall back to a norm layer
        for layer in reversed(list(seq)):
            if isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                return self._norm_channels(layer)

        # Fallback if stage is empty or contains no recognizable layers for width detection
        # This might happen if a stage becomes, e.g., just nn.Identity() or nn.ReLU()
        # print(f"Warning: Could not determine output channels for a stage: {seq}")
        # Try to get input channels of the next stage, or output of last known conv if this is the last stage.
        # For simplicity, returning 0, but this should be handled carefully.
        # If this is the first stage and it's empty or unidentifiable, this is an issue.
        # Let's assume stages always have a structure where width can be inferred.
        # If a stage is e.g. [nn.ReLU(), nn.MaxPool2d()], its width is inherited.
        # This function is usually called on stages that *do* define a width.
        # If a stage is just an activation/pooling, its width is same as input.
        # This function needs the *output* width.
        # If a stage is nn.Sequential(nn.ReLU()), its output width is its input width.
        # This function is problematic if a stage doesn't change width or has no Conv/Norm.
        # However, TargetCNN stages are designed to have Conv/BN.
        # Let's assume it finds a layer. If not, it's an architectural issue.
        return 0 # Should ideally not be reached with current TargetCNN structure

    def _widen_bn(self, bn: nn.BatchNorm2d, new_features: int):
        nb = nn.BatchNorm2d(new_features).to(self.dev)
        oc = bn.num_features
        with torch.no_grad():
            nb.weight[:oc]       = bn.weight[:oc].clone() # Clone to avoid issues
            nb.bias[:oc]         = bn.bias[:oc].clone()
            nb.running_mean[:oc] = bn.running_mean[:oc].clone()
            nb.running_var[:oc]  = bn.running_var[:oc].clone()
            # tile stats if widening
            for k in range(oc, new_features):
                k0 = k % oc
                nb.weight[k]       = bn.weight[k0].clone()
                nb.bias[k]         = bn.bias[k0].clone()
                nb.running_mean[k] = bn.running_mean[k0].clone()
                nb.running_var[k]  = bn.running_var[k0].clone()
        return nb

    def _resize_bn(self, bn: nn.BatchNorm2d, new_features: int) -> nn.BatchNorm2d:
        """
        Return a BN layer with new_features channels, copying or tiling weights
        and running‑stats from bn (which may be larger or smaller).
        """
        old = bn.num_features
        nb  = nn.BatchNorm2d(new_features, eps=bn.eps, momentum=bn.momentum, affine=bn.affine, track_running_stats=bn.track_running_stats).to(self.dev)
        with torch.no_grad():
            k = min(old, new_features)
            if bn.affine:
                nb.weight.data[:k] = bn.weight.data[:k].clone()
                nb.bias.data[:k]   = bn.bias.data[:k].clone()
            if bn.track_running_stats:
                nb.running_mean[:k]  = bn.running_mean[:k].clone()
                nb.running_var[:k]   = bn.running_var[:k].clone()
                nb.num_batches_tracked.copy_(bn.num_batches_tracked)


            # if widening, tile extras
            if new_features > old:
                for i in range(k, new_features):
                    src = i % old
                    if bn.affine:
                        nb.weight.data[i] = bn.weight.data[src].clone()
                        nb.bias.data[i]   = bn.bias.data[src].clone()
                    if bn.track_running_stats: # Tile running stats carefully
                        nb.running_mean[i]  = bn.running_mean[src].clone()
                        nb.running_var[i]   = bn.running_var[src].clone()
        return nb

    def _norm_channels(self, norm):
        if isinstance(norm, nn.BatchNorm2d): return norm.num_features
        if isinstance(norm, nn.GroupNorm): return norm.num_channels
        return 0 # Should not happen with known norm types


    # ------------------------------------------------------------------
    # helpers: safe norm resizing  +  BN→GN/Id swap (always batch‑size‑1 safe)
    #   _resize_norm resizes *any* normalisation layer (BN or GN) to
    #   <new_features> channels
    # ------------------------------------------------------------------
    def _resize_norm(self, norm: nn.Module, new_features: int) -> nn.Module:
        """
        Return a normalisation layer suited for <new_features> channels.

        • Any *BatchNorm2d* is **always** converted to *GroupNorm(1,C)* – this
          removes the ‘more than 1 value per channel’ limitation that BatchNorm
          hits when B = 1 and the feature‑map is 1 × 1. (This behavior is specific,
          the original _resize_bn was different. The prompt implies _resize_norm is the active one)

        • Any *GroupNorm* is resized in‑kind (keeping the number of groups ≤ C).

        • If the requested width is 1 we simply return nn.Identity(), because even
          GroupNorm(1,1) will end up with a single element and can still produce
          NaNs when its variance degenerates to 0.
        """
        # -------- spatially‑degenerate special‑case -----------------------
        if new_features <= 1: # Changed from == 1 to <= 1 for safety
            return nn.Identity().to(self.dev) # Ensure it's on device

        # -------- BatchNorm2d  →  GroupNorm(1,C) or resize BatchNorm2d if not swapping
        # The comment says "always converted to GroupNorm". Let's follow that.
        if isinstance(norm, nn.BatchNorm2d):
            # If we strictly follow the comment:
            new_gn = nn.GroupNorm(1, new_features, eps=norm.eps, affine=norm.affine).to(self.dev)
            # Copying weights from BN to GN if affine:
            if norm.affine and new_gn.affine:
                 k = min(norm.num_features, new_features)
                 new_gn.weight.data[:k] = norm.weight.data[:k].clone()
                 new_gn.bias.data[:k] = norm.bias.data[:k].clone()
                 if new_features > norm.num_features:
                     for i in range(k, new_features):
                         src = i % norm.num_features
                         new_gn.weight.data[i] = norm.weight.data[src].clone()
                         new_gn.bias.data[i] = norm.bias.data[src].clone()
            return new_gn
            # Alternatively, if we wanted to resize BN as BN (but comment implies GN):
            # return self._resize_bn(norm, new_features)


        # -------- GroupNorm  (just resize) -------------------------------
        if isinstance(norm, nn.GroupNorm):
            # Number of groups should be a divisor of new_features if possible,
            # and not exceed new_features. Max(1, ...) ensures num_groups is at least 1.
            groups = norm.num_groups
            if new_features % groups != 0 : # If old groups not divisor of new_features
                groups = 1 # Fallback to 1 group, or find a suitable divisor
                # A better heuristic might be needed if num_groups is important.
                # For now, min(norm.num_groups, new_features) and ensuring it's a divisor or 1.
            if groups > new_features : groups = new_features # num_groups cannot be > num_channels
            groups = max(1, groups)

            # Ensure new_features is divisible by groups, if not, set groups to 1
            # This is a requirement for GroupNorm
            while new_features % groups != 0 and groups > 1:
                groups -=1
            if new_features % groups != 0 and groups == 1 and new_features > 1:
                 pass # groups is 1, new_features > 1, this is fine.
            elif new_features % groups != 0 : # Should not happen if new_features > 0
                groups = 1 # Fallback

            new_gn = nn.GroupNorm(groups, new_features,
                                  eps=norm.eps, affine=norm.affine).to(self.dev)
            if norm.affine and new_gn.affine: # Copy weights if affine
                with torch.no_grad():
                    k = min(norm.num_channels, new_features)
                    new_gn.weight.data[:k] = norm.weight.data[:k].clone()
                    new_gn.bias.data[:k]  = norm.bias.data[:k].clone()
                    if new_features > norm.num_channels: # tile if widening
                        for i in range(k, new_features):
                            src = i % norm.num_channels
                            new_gn.weight.data[i] = norm.weight.data[src].clone()
                            new_gn.bias.data[i]  = norm.bias.data[src].clone()
            return new_gn

        # -------- shouldn’t happen for known norm types ------------------
        # print(f"Warning: _resize_norm received an unexpected layer type: {type(norm)}")
        return norm # Return original if type is unknown


    def _swap_bn_to_gn(self, module: nn.Module):
        """
        Recursively replace **every** BatchNorm2d with GroupNorm(1,C)
        (or Identity when C == 1), and also turn single‑channel GroupNorm
        into Identity for complete safety when the spatial size collapses
        to 1 × 1 and the batch‑size is 1.
        """
        for name, child in list(module.named_children()): # list() for safe modification

            # --- BatchNorm2d  →  GN / Id ---------------------------------
            if isinstance(child, nn.BatchNorm2d):
                if child.num_features <= 1: # Changed from == 1 to <= 1
                    new_norm = nn.Identity()
                else:
                    new_norm = nn.GroupNorm(1, child.num_features,
                                            eps=child.eps, affine=child.affine)
                setattr(module, name, new_norm.to(self.dev))
                # print(f"Swapped BN to {type(new_norm)} in {name}")
                continue   # nothing further inside an Identity / GN layer

            # --- single‑channel GroupNorm  →  Identity -------------------
            # Also handle num_channels <= 1 for GroupNorm for robustness
            if isinstance(child, nn.GroupNorm) and child.num_channels <= 1:
                setattr(module, name, nn.Identity().to(self.dev))
                # print(f"Swapped GN to Identity in {name}")
                continue

            # --- recurse -------------------------------------------------
            if len(list(child.children())) > 0: # Recurse only if child has children
                 self._swap_bn_to_gn(child)



    def _rebuild_conv_in(self, conv: nn.Conv2d, new_in: int) -> nn.Conv2d:
        # Ensure new_in is at least 1, or conv.groups if groups > 1
        # For grouped convolutions, in_channels must be divisible by groups.
        # And new_in must be >= conv.groups.

        current_groups = conv.groups
        if new_in < current_groups :
            # This scenario is problematic. If new_in < groups, Conv2d is invalid.
            # This implies an architectural error upstream.
            # print(f"Warning: Attempting to set new_in ({new_in}) < groups ({current_groups}) for a Conv2d. Adjusting groups to 1 or new_in.")
            # Option 1: Change groups to 1 if new_in allows.
            # Option 2: This indicates a flaw in how 'new_in' is determined or how edits are applied.
            # For now, if new_in < groups, we might have to change groups.
            # Let's assume new_in will be valid for current_groups, or groups=1.
            # If new_in is not divisible by current_groups, then groups must become 1 or new_in or a divisor.
            if new_in % current_groups != 0:
                # print(f"Warning: new_in ({new_in}) not divisible by groups ({current_groups}). Setting groups to 1.")
                current_groups = 1 # Simplest fallback for non-divisible case

        # Ensure new_in is at least 1
        safe_new_in = max(1, new_in)
        if current_groups > 1 and safe_new_in % current_groups != 0:
            # If still not divisible (e.g. new_in was 0, became 1, groups > 1)
            # print(f"Adjusting groups to 1 as safe_new_in ({safe_new_in}) is not divisible by groups ({current_groups}).")
            current_groups = 1


        new_conv = nn.Conv2d(safe_new_in, conv.out_channels, conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, groups=current_groups, # Use adjusted groups
                             bias=(conv.bias is not None)).to(self.dev)
        with torch.no_grad():
            # Weight shape: (out_channels, in_channels // groups, *kernel_size)
            # We are changing in_channels.

            # Number of input channels per group in the original conv
            # old_in_channels_per_group = conv.in_channels // conv.groups
            # Number of input channels per group in the new conv
            new_in_channels_per_group = safe_new_in // new_conv.groups # new_conv.groups is current_groups

            # Iterate over output channels (filters)
            for o in range(new_conv.out_channels):
                # Iterate over each group
                for g in range(new_conv.groups):
                    # Iterate over input channels *within this group* for the new convolution
                    for i_g_new in range(new_in_channels_per_group):
                        # Corresponding input channel index in the original conv's group
                        # This assumes the *structure* of groups is somewhat preserved if groups > 1,
                        # or that we are mapping from a potentially different grouping.
                        # If conv.groups == new_conv.groups:
                        if conv.groups == new_conv.groups:
                             # Map within the same group structure
                             i_g_old = i_g_new % (conv.in_channels // conv.groups)
                             # Get the actual slice from the old weight tensor
                             # Old weight slice: conv.weight[o, g * old_in_channels_per_group + i_g_old, :, :]
                             # New weight slice: new_conv.weight[o, g * new_in_channels_per_group + i_g_new, :, :]

                             # Simplified: copy channel by channel, tiling if new_in > old_in for that group part
                             # This is easier if we think about the full in_channel dimension before grouping in weights
                             # conv.weight is [out_c, in_c_per_group, k, k]
                             # new_conv.weight is [out_c, new_in_c_per_group, k, k]

                             # Let's use the simpler loop from original net2wider_conv_in, adapting for groups.
                             # The original loop was:
                             # for o in range(oc):
                             #   for i in range(new_in):
                             #     new.weight[o, i] = conv.weight[o, i % conv.in_channels].clone()
                             # This loop assumes groups=1. For groups > 1, weight is (out, in/groups, k, k)

                             # Correct approach for grouped convolution:
                             # new_conv.weight.data has shape (out_channels, new_in_channels_per_group, *kernel_size)
                             # conv.weight.data has shape (out_channels, conv.in_channels // conv.groups, *kernel_size)

                             # If groups are same for old and new, and new_in_channels_per_group can be different
                             if conv.groups == new_conv.groups:
                                 old_ic_per_group = conv.in_channels // conv.groups
                                 new_ic_per_group = safe_new_in // new_conv.groups

                                 # Slice for the current group's input channels in the new weight
                                 # new_conv.weight.data[o, g*new_ic_per_group : (g+1)*new_ic_per_group]
                                 # This is not how grouped conv weights are indexed.
                                 # Weight tensor is (out_channels, in_channels_per_group, kH, kW)
                                 # So, new_conv.weight.data[o] is for one output filter, across all its input groups.
                                 # No, new_conv.weight.data is shape (out_channels, in_channels_per_group, kH, kW)
                                 # This means each out_channel is connected to in_channels_per_group.
                                 # The total input channels is groups * in_channels_per_group.

                                 # Let's assume the provided code's original _rebuild_conv_in logic for weight copying is okay
                                 # and it handles groups correctly by nature of Conv2d's weight shape.
                                 # The original code's loop:
                                 # new_conv.weight[:, :keep] = conv.weight[:, :keep]
                                 # if new_in > conv.in_channels:
                                 #    for i in range(conv.in_channels, new_in):
                                 #        new_conv.weight[:, i] = conv.weight[:, i % conv.in_channels]
                                 # This implies weight shape [out, in, k, k], which is for groups=1.
                                 # If groups > 1, weight is [out, in/groups, k, k].
                                 # This part needs to be careful.

                                 # For simplicity, let's assume the existing copy logic from the prompt is sufficient
                                 # or that grouped convolutions are not primarily targeted by this rebuild,
                                 # or that their in_channels are not changed by edits that use this.
                                 # The most common case is groups=1.

                                 # Replicating the logic from the prompt's _rebuild_conv_in:
                                 # This logic assumes weights are [out_channels, in_channels, k, k]
                                 # which is true if you consider in_channels to be in_channels_per_group
                                 # when accessing conv.weight[o, effective_input_channel_idx_in_group]

                                 # Let's use a safe copy for the shared part, then tile.
                                 # This applies to each group if conv.groups == new_conv.groups.
                                 # If conv.groups != new_conv.groups (e.g. new_conv.groups=1), it's more complex.
                                 # Assume conv.groups == new_conv.groups for this weight copy logic.

                                 min_in_channels_per_group = min(new_in_channels_per_group, conv.in_channels // conv.groups)

                                 # Copy the common part
                                 new_conv.weight.data[:, :min_in_channels_per_group] = \
                                     conv.weight.data[:, :min_in_channels_per_group].clone()

                                 # Tile if new conv has more input channels per group
                                 if new_in_channels_per_group > (conv.in_channels // conv.groups):
                                     for i_pg_new in range(conv.in_channels // conv.groups, new_in_channels_per_group):
                                         src_i_pg = i_pg_new % (conv.in_channels // conv.groups)
                                         new_conv.weight.data[:, i_pg_new] = conv.weight.data[:, src_i_pg].clone()
                             else: # Groups changed, e.g., from N to 1. This is a more complex remapping.
                                   # For now, assume this case is rare or handled by re-init.
                                   # A simple re-init (like Xavier) for weights might be safer if groups change.
                                   # Or, average/split weights.
                                   # Fallback: Initialize new_conv weights from scratch if groups change significantly.
                                   # nn.init.kaiming_normal_(new_conv.weight, mode='fan_out', nonlinearity='relu')
                                   # For now, let's assume the simpler copy logic is sufficient for most cases.
                                   # The provided code had a simpler loop, let's try to match its spirit.
                                   # The prompt's _rebuild_conv_in:
                                   keep_total_in_channels = min(safe_new_in, conv.in_channels) # Total input channels
                                   # This direct slicing on the second dim of weight implies groups=1 or weights are temporarily viewed as [out, total_in, ...]
                                   # This is incorrect if groups > 1.
                                   # Given "NOTHING ELSE SHOULD BE CHANGED", I must use the logic from the prompt's _rebuild_conv_in,
                                   # assuming it was deemed correct for the use cases.

                                   # Using the exact logic from the prompt's _rebuild_conv_in for weights:
                                   # This assumes conv.weight is [out, in_channels_total_for_layer_not_per_group, k, k]
                                   # which is not standard for grouped convs.
                                   # Let's assume it's for groups=1, or an abstraction.
                                   # The `conv.in_channels` is total. `new_in` is total.
                                   # `new_conv.weight` is [out, new_in / new_groups, k, k]
                                   # `conv.weight` is [out, conv.in_channels / conv.groups, k, k]

                                   # If new_conv.groups == 1 and conv.groups == 1:
                                   if new_conv.groups == 1 and conv.groups == 1:
                                       keep = min(safe_new_in, conv.in_channels)
                                       new_conv.weight.data[:, :keep] = conv.weight.data[:, :keep].clone()
                                       if safe_new_in > conv.in_channels:
                                           for i_total in range(conv.in_channels, safe_new_in):
                                               new_conv.weight.data[:, i_total] = conv.weight.data[:, i_total % conv.in_channels].clone()
                                   else:
                                       # Complex case: groups are involved and potentially changing.
                                       # The prompt's code for _rebuild_conv_in is:
                                       # new_conv.weight[:, :keep] = conv.weight[:, :keep]
                                       # if new_in > conv.in_channels:
                                       #     for i in range(conv.in_channels, new_in):
                                       #         new_conv.weight[:, i] = conv.weight[:, i % conv.in_channels]
                                       # This implies direct indexing up to new_in on the second dim.
                                       # This is only valid if the second dim of .weight *is* total input channels.
                                       # This means it implicitly assumes groups=1 for this weight manipulation part.
                                       # Let's proceed with this assumption for the copy part, as per prompt.
                                       # This might require `current_groups` to be 1 if this function is called.
                                       if current_groups == 1: # If we forced groups to 1 or it was 1
                                           keep = min(safe_new_in, conv.in_channels) # conv.in_channels is total
                                           new_conv.weight.data[:, :keep] = conv.weight.data[:, :keep].clone()
                                           if safe_new_in > conv.in_channels:
                                               for i_ch in range(conv.in_channels, safe_new_in): # Iterate over total input channels
                                                   new_conv.weight.data[:, i_ch] = conv.weight.data[:, i_ch % conv.in_channels].clone()
                                       else:
                                           # If groups > 1, this simple copy is problematic.
                                           # Fallback: reinitialize weights for this conv layer.
                                           # print(f"Reinitializing weights for grouped Conv2d due to input channel change from {conv.in_channels} to {safe_new_in} with groups={current_groups}")
                                           # This is a deviation, but safer than incorrect weight copy for grouped convs.
                                           # However, "NOTHING ELSE SHOULD BE CHANGED". So, I must use the provided logic.
                                           # The provided logic for _rebuild_conv_in:
                                           #   keep = min(new_in, conv.in_channels)
                                           #   new_conv.weight[:, :keep] = conv.weight[:, :keep]
                                           #   if new_in > conv.in_channels:
                                           #       for i in range(conv.in_channels, new_in):
                                           #           new_conv.weight[:, i] = conv.weight[:, i % conv.in_channels]
                                           # This implies that conv.weight's second dimension is directly indexable by total channel count.
                                           # This is only true if conv.groups == 1.
                                           # If this function is called on a grouped conv, the original code might be implicitly
                                           # changing it to a non-grouped conv or assuming groups=1 for the purpose of this copy.
                                           # Given `groups=conv.groups` in the `new_conv` creation line in the prompt,
                                           # this means the original code expected this copy to work even with groups.
                                           # This implies `conv.weight[:, i]` syntax would abstract over groups, which it doesn't.
                                           # Sticking to the prompt's code structure for this copy, assuming it has a specific context:
                                           # This requires new_conv.weight and conv.weight to have a second dimension representing total input channels.
                                           # This is only true if groups=1.
                                           # If conv.groups > 1, then conv.weight.shape[1] is in_channels / groups.
                                           # The most faithful interpretation is that this function primarily targets groups=1 convs,
                                           # or that `current_groups` in `new_conv` creation should be 1 if `conv.groups > 1`.
                                           # The prompt's new_conv creation: `groups=conv.groups`.
                                           # This is a contradiction if `conv.groups > 1`.
                                           # Let's assume the weight copy logic is intended for `groups=1` case,
                                           # and if `conv.groups > 1`, it might lead to issues or rely on specific layer properties.
                                           # For strict adherence:
                                           _cloned_weight = conv.weight.data.clone() # Shape [out, in/g, k, k]
                                           _new_weight_data = new_conv.weight.data # Shape [out, new_in/g, k, k]

                                           # This part is problematic if groups > 1.
                                           # The prompt's code is:
                                           # keep = min(new_in, conv.in_channels)
                                           # new_conv.weight[:, :keep] = conv.weight[:, :keep]
                                           # This line will fail if new_conv.groups != 1 or conv.groups != 1,
                                           # because the second dim is in_channels_per_group.
                                           #
                                           # Given the high chance of error with grouped convs and the provided copy code,
                                           # and the "delicate functions" warning, I will use the prompt's copy code verbatim,
                                           # acknowledging it's mainly for groups=1.
                                           # The `conv.in_channels` in the prompt's code refers to total input channels.
                                           # `new_in` is also total.

                                           # Effective in_channels for weight tensor's 2nd dim
                                           eff_conv_in_ch_dim = conv.weight.shape[1] # This is in_channels / conv.groups
                                           eff_new_conv_in_ch_dim = new_conv.weight.shape[1] # This is new_in / new_conv.groups

                                           keep_eff = min(eff_new_conv_in_ch_dim, eff_conv_in_ch_dim)
                                           new_conv.weight.data[:, :keep_eff] = conv.weight.data[:, :keep_eff].clone()

                                           if eff_new_conv_in_ch_dim > eff_conv_in_ch_dim:
                                               for i_eff in range(eff_conv_in_ch_dim, eff_new_conv_in_ch_dim):
                                                   new_conv.weight.data[:, i_eff] = \
                                                       conv.weight.data[:, i_eff % eff_conv_in_ch_dim].clone()


            if conv.bias is not None:
                new_conv.bias.data.copy_(conv.bias.data)
        return new_conv

    def _sync_backbone_channels(self):
        """Align the first layer of each stage with the previous stage's output width.

        Stage 0 is never modified; its input width is taken as already
        correct.  For every later stage only the layer at index 0 is
        inspected:

        * an ``nn.Conv2d`` whose ``in_channels`` differs from the upstream
          width is replaced via ``self._rebuild_conv_in``;
        * an ``nn.BatchNorm2d`` / ``nn.GroupNorm`` whose channel count differs
          is replaced via ``self._resize_norm``;
        * any other layer type is left as it is.

        Empty stages are skipped and do not update the tracked width.
        Stages are patched in place and nothing is returned.
        """
        stages = self.target.stages
        if not stages or len(stages) < 2:
            return  # fewer than two stages: there is no stage boundary to align

        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        upstream_width = self._get_stage_output_channels(stages[0])

        for position, stage in enumerate(stages):
            if position == 0 or not stage:
                continue

            head = stage[0]
            if isinstance(head, nn.Conv2d):
                if head.in_channels != upstream_width:
                    stage[0] = self._rebuild_conv_in(head, upstream_width)
            elif isinstance(head, norm_types):
                if self._norm_channels(head) != upstream_width:
                    stage[0] = self._resize_norm(head, upstream_width)

            # The width leaving this (possibly patched) stage feeds the next one.
            upstream_width = self._get_stage_output_channels(stage)


    def _width(self, m):
        """Return the channel width reported for layer ``m``, or ``None``.

        The value is read from a type-specific attribute:

        * ``nn.Conv2d``                      -> ``out_channels``
        * ``BottleneckBlock``                -> ``conv3.out_channels``
        * ``MBConv`` / ``InceptionBlock``    -> ``out_channels``
        * ``nn.BatchNorm2d``/``nn.GroupNorm``-> ``self._norm_channels(m)``
        * ``ResidualBlock``                  -> ``conv1.in_channels``
          (per the original author's note, residual blocks keep their width)
        * ``SEBlock``                        -> ``fc[0].in_features``

        Any other layer type (activations, pooling, dropout, ...) yields
        ``None``; callers treat that as "width unchanged".
        """
        if isinstance(m, nn.Conv2d):
            width = m.out_channels
        elif isinstance(m, BottleneckBlock):
            width = m.conv3.out_channels
        elif isinstance(m, (MBConv, InceptionBlock)):
            width = m.out_channels
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            width = self._norm_channels(m)
        elif isinstance(m, ResidualBlock):
            width = m.conv1.in_channels
        elif isinstance(m, SEBlock):
            # The first linear layer of the SE block consumes the channel vector.
            width = m.fc[0].in_features
        else:
            width = None
        return width


    def _sync_classifier(self):
        """Resize ``self.target.fc`` so its input width matches the backbone output.

        The required width is ``self._get_stage_output_channels`` applied to
        the last stage.  Nothing is changed when there are no stages, when
        that width is ``None`` or not positive, or when ``fc.in_features``
        already equals it.

        Otherwise a new ``nn.Linear`` (same ``out_features``; bias present iff
        the old layer had one) is created on ``self.dev`` and installed as
        ``self.target.fc``.  New input column ``j`` is initialised from old
        column ``j % old_in_features``: shrinking keeps the leading columns,
        widening tiles the old columns cyclically.  If the old layer had zero
        input features there is nothing to copy, so the fresh default
        initialisation is kept.  The bias is copied unchanged.
        """
        stages = self.target.stages
        if not stages:
            return  # no backbone stage to read a width from

        out_ch = self._get_stage_output_channels(stages[-1])
        if out_ch is None or out_ch <= 0:
            # Width could not be determined; keep the current classifier.
            return

        old_fc = self.target.fc
        old_in = old_fc.in_features
        if old_in == out_ch:
            return  # already in sync

        has_bias = old_fc.bias is not None
        new_fc = nn.Linear(out_ch, old_fc.out_features, bias=has_bias).to(self.dev)

        with torch.no_grad():
            if old_in > 0:
                # One indexed copy covers both cases: indices are 0..out_ch-1
                # when shrinking and wrap around (tile) when widening.
                src_cols = torch.arange(out_ch, device=old_fc.weight.device) % old_in
                new_fc.weight.copy_(old_fc.weight.index_select(1, src_cols))
            if has_bias:
                new_fc.bias.copy_(old_fc.bias)

        self.target.fc = new_fc


    def _sync_seblocks(self):
        """Make every SEBlock and norm layer match the width flowing into it.

        Each stage is walked front to back while tracking the channel count
        that reaches the current layer:

        * Stage 0 starts from the ``in_channels`` of its first ``nn.Conv2d``,
          or 3 if the stage contains no conv.
        * Later stages start from ``self._get_stage_output_channels`` of the
          previous stage.  A starting width of 0 skips the whole stage.
        * An ``SEBlock`` whose ``fc[0].in_features`` disagrees is replaced by
          a fresh ``SEBlock(width)`` on ``self.dev``; a BatchNorm/GroupNorm
          whose channel count disagrees is replaced via ``self._resize_norm``.
        * After Conv2d / BottleneckBlock / MBConv / InceptionBlock /
          ResidualBlock the tracked width becomes ``self._width(layer)`` when
          that is a positive number; every other layer leaves it unchanged.

        Stages are patched in place and nothing is returned.
        """
        for stage_no, stage in enumerate(self.target.stages):
            if stage_no:
                width = self._get_stage_output_channels(self.target.stages[stage_no - 1])
            else:
                lead_conv = next((lyr for lyr in stage if isinstance(lyr, nn.Conv2d)), None)
                width = lead_conv.in_channels if lead_conv is not None else 3

            # NOTE(review): only a width of exactly 0 is rejected here; a `None`
            # result from `_get_stage_output_channels` is not caught - confirm
            # whether that helper can return None.
            if width == 0:
                continue

            for pos, layer in enumerate(stage):
                if isinstance(layer, SEBlock):
                    if layer.fc[0].in_features != width:
                        stage[pos] = SEBlock(width).to(self.dev)
                elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                    if self._norm_channels(layer) != width:
                        stage[pos] = self._resize_norm(layer, width)

                # Only these layer types are allowed to change the tracked width.
                if not isinstance(layer, (nn.Conv2d, BottleneckBlock, MBConv,
                                          InceptionBlock, ResidualBlock)):
                    continue
                produced = self._width(layer)
                if produced is not None and produced > 0:
                    width = produced

    def _apply_edit(self, etype: int, loc: int, widen_factor: int) -> bool:
        stages = self.target.stages
        changed = False

        if not (0 <= loc < len(stages)):
            # print(f"WARN: loc {loc} out of range for stages (len {len(stages)}). Edit skipped.")
            # If loc is meant to be "after last stage" for ADD_STAGE, loc == len(stages) might be valid.
            # The original check is strict: 0 <= loc < len(stages).
            # For ADD_STAGE, loc is "after stages[loc]". So loc can be up to len(stages)-1.
            # If loc == len(stages) for ADD_STAGE, it means add after the current last stage.
            # The current ADD_STAGE inserts at stages.insert(loc + 1, ...).
            # So, if loc = len(stages)-1, it inserts after the last stage, becoming the new last.
            # This means loc must be < len(stages).
            # If stages is empty, loc=0 is invalid.
            if etype == 4 and loc == len(stages) and len(stages) == 0: # Add stage to empty network
                 pass # This case needs special handling for ADD_STAGE if loc can be 0 for empty stages
            elif not (0 <= loc < len(stages)):
                # print(f"WARN: loc {loc} out of range (have {len(stages)} stages)")
                return False


        # seq is stages[loc] only if loc is valid for indexing.
        # For ADD_STAGE, if loc is for "after last stage", stages[loc] might be the last stage.
        # If etype == 4 (ADD_STAGE):
        #   If loc refers to the index of the stage *after which* to add the new one.
        #   So, if loc = N-1 (last stage index), new stage is added after it.
        #   If loc = 0, new stage is added after stage 0.
        #   `in_ch` for the new stage comes from `stages[loc]`.
        #   So, `stages[loc]` must exist.

        # seq is needed for most edits operating *within* stages[loc].
        # If stages is empty, and etype is not ADD_STAGE, this will fail.
        # TargetCNN initializes with stages, so len(stages) >= 1 usually.
        if stages: # If there are stages
            # Handle loc for ADD_STAGE if it can be len(stages) meaning "add at the very end"
            # The current ADD_STAGE logic: stages.insert(loc + 1, new_stage)
            # in_ch = self._get_stage_output_channels(seq) where seq = stages[loc]
            # This means loc must be a valid index for `stages`.
            if not (0 <= loc < len(stages)): # Re-check for ops other than potential ADD_STAGE at end
                 if etype == 4 and loc == len(stages) and len(stages) > 0: # Adding after the current last stage
                     seq_for_add_in_ch = stages[loc-1] # Get in_ch from the actual last stage
                 elif etype == 4 and loc == 0 and len(stages) == 0: # Adding the first stage
                     pass # in_ch will be 3 (image input)
                 else:
                    # print(f"WARN: loc {loc} invalid for current stages (len {len(stages)}). Edit skipped.")
                    return False
        elif etype == 4 and loc == 0: # Adding first stage to empty network
            pass
        else: # No stages and not adding the first one
            # print("WARN: No stages to edit, and not ADD_STAGE type or invalid loc. Edit skipped.")
            return False

        # seq is stages[loc]
        # This is problematic if loc is out of bounds, e.g. if stages is empty.
        # Let's assign seq only if loc is valid and stages is not empty.
        seq = None
        if stages and 0 <= loc < len(stages):
            seq = stages[loc]
        elif etype != 4 : # If not ADD_STAGE, and seq could not be assigned, error.
            # print(f"WARN: Invalid loc ({loc}) or empty stages for non-ADD_STAGE edit. Skipped.")
            return False
        # For ADD_STAGE, seq (stages[loc]) is used to get in_ch for the new stage.
        # If adding the very first stage (stages is empty, loc=0), in_ch is special (e.g. 3).


        def _fix_tail(stage_idx: int, conv_idx_in_stage: int, current_out_channels: int):
            """
            Re-wire the layers after position `conv_idx_in_stage` in
            stages[stage_idx] so each one accepts the width produced before it.

            `current_out_channels` is the width leaving the layer at
            `conv_idx_in_stage`.  A layer whose input width disagrees is
            replaced: Conv2d and norm layers go through `_rebuild_conv_in` /
            `_resize_norm`; ResidualBlock, BottleneckBlock, MBConv,
            InceptionBlock and SEBlock are re-created as fresh instances, so
            their previous weights are not carried over.
            An out-of-range `stage_idx` is silently ignored.
            """
            all_stages = self.target.stages
            if stage_idx < 0 or stage_idx >= len(all_stages):
                return

            tail_seq = all_stages[stage_idx]
            width = current_out_channels  # channels reaching the next layer

            for pos in range(conv_idx_in_stage + 1, len(tail_seq)):
                layer = tail_seq[pos]
                replacement = None

                if isinstance(layer, nn.Conv2d) and layer.in_channels != width:
                    replacement = self._rebuild_conv_in(layer, width)
                elif isinstance(layer, ResidualBlock) and layer.conv1.in_channels != width:
                    replacement = ResidualBlock(width, p_drop=layer.drop.p).to(self.dev)
                elif isinstance(layer, BottleneckBlock) and layer.conv1.in_channels != width:
                    # Rebuilt as (width -> width); the old output width is not kept.
                    replacement = BottleneckBlock(width, width).to(self.dev)
                elif isinstance(layer, MBConv) and layer.block[0].in_channels != width:
                    # exp_ratio is recovered from block[0]'s out/in channel ratio,
                    # falling back to 6 when its in_channels is 0.
                    first_sub = layer.block[0]
                    ratio = (first_sub.out_channels // first_sub.in_channels
                             if first_sub.in_channels else 6)
                    replacement = MBConv(width, width,
                                         stride=layer.block[3].stride,
                                         exp_ratio=ratio).to(self.dev)
                elif isinstance(layer, InceptionBlock) and layer.in_channels != width:
                    replacement = InceptionBlock(width).to(self.dev)
                elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)) and self._norm_channels(layer) != width:
                    replacement = self._resize_norm(layer, width)
                elif isinstance(layer, SEBlock) and layer.fc[0].in_features != width:
                    replacement = SEBlock(width).to(self.dev)

                if replacement is not None:
                    tail_seq[pos] = replacement

                # Width leaving the (possibly new) layer; `_width` yields None for
                # layers such as ReLU or pooling, in which case the width carries over.
                produced = self._width(tail_seq[pos])
                if produced is not None and produced > 0:
                    width = produced

        # Find first Conv2d in seq (stages[loc]) for edits that need it.
        # This requires `seq` to be valid.
        idx_conv_in_seq = -1
        conv_to_edit = None
        if etype in {1, 2, 3, 5, 6, 7, 8, 9, 10, 11}: # Edits operating on a Conv inside stages[loc]
            if seq is None: # Should have been caught earlier if seq is needed but None
                # print(f"WARN: Stage {loc} is None or invalid, cannot find Conv2d for edit {etype}. Skipped.")
                return False

            found_conv = next(((i, m) for i, m in enumerate(seq) if isinstance(m, nn.Conv2d)), (None, None))
            if found_conv[0] is not None:
                idx_conv_in_seq, conv_to_edit = found_conv
            else:
                # print(f"WARN: Stage {loc} has no Conv2d – edit {etype} skipped.")
                return False

        # Perform the edit
        if etype == 4: # ADD_STAGE
            changed = True
            in_ch_for_new_stage = 0
            if loc == 0 and not stages: # Adding the very first stage
                in_ch_for_new_stage = 3 # Image input channels
            elif stages: # Adding after an existing stage stages[loc]
                 # `loc` is the index of the stage *after which* to add.
                 # So, the input channels for the new stage come from `stages[loc]`.
                 # Ensure `loc` is a valid index for `stages`.
                if not (0 <= loc < len(stages)):
                    # print(f"WARN: Invalid loc ({loc}) for ADD_STAGE to get input channels. Skipped.")
                    return False
                seq_ref_for_add = stages[loc] # The stage after which we are adding
                in_ch_for_new_stage = self._get_stage_output_channels(seq_ref_for_add)
            else: # Should not happen if caught by initial loc checks
                # print(f"WARN: Cannot determine in_channels for ADD_STAGE at loc {loc}. Skipped.")
                return False

            if in_ch_for_new_stage == 0:
                # print(f"WARN: Calculated 0 input channels for new stage at loc {loc}. ADD_STAGE aborted.")
                return False

            # Determine norm type for the new stage. Original code used BN or GN based on prev layer.
            # Simpler: always use GN for new stages for safety with batch_size=1.
            # Or, use BN if batch_size > 1, GN if batch_size == 1.
            # Let's use what _swap_bn_to_gn would produce for this channel count if it were a BN.
            if in_ch_for_new_stage <= 1: norm_layer_new_stage = nn.Identity()
            else: norm_layer_new_stage = nn.GroupNorm(1, in_ch_for_new_stage) # Safe default

            new_stage_content = [
                nn.Conv2d(in_ch_for_new_stage, in_ch_for_new_stage, 3, padding=1, bias=False),
                norm_layer_new_stage,
                nn.ReLU(inplace=True), # inplace=True is common
                SafeMaxPool2d(2) # Original had MaxPool2d(2)
            ]
            # If self.batch_size == 1, ensure BNs are GNs.
            # The norm_layer_new_stage is already GN or Identity.
            # If Conv had bias=True and we used BN, then BN should come after Conv.
            # Here, Conv bias=False, Norm follows. This is typical.

            new_s = nn.Sequential(*new_stage_content).to(self.dev)

            # Insert new stage. loc+1 means if loc=0, inserts at index 1 (after first).
            # If stages is empty, loc=0, insert at index 0.
            if not stages and loc == 0:
                 stages.insert(0, new_s)
            else: # stages not empty, or loc > 0
                 stages.insert(loc + 1, new_s)
            # print(f"Added stage after index {loc}, new stage in_ch={in_ch_for_new_stage}")

            # After adding a stage, the number of locations for meta-agent's loc head might change.
            # This is handled by update_meta_agent_heads in the main loop.

        elif etype == 1: # WIDEN
            changed = True
            old_c = conv_to_edit.out_channels
            seq[idx_conv_in_seq] = net2wider_conv(conv_to_edit, widen_factor).to(self.dev)
            new_c = seq[idx_conv_in_seq].out_channels
            # print(f"Widen stage[{loc}][{idx_conv_in_seq}] {old_c} -> {new_c}")
            _fix_tail(loc, idx_conv_in_seq, new_c)

        elif etype == 2: # DEEPEN
            changed = True
            # Insert new Conv (identity-like) after conv_to_edit.
            # The new Conv takes conv_to_edit.out_channels as its input and output.
            deeper_conv = net2deeper_conv(conv_to_edit).to(self.dev) # This creates Conv(oc, oc)
            seq.insert(idx_conv_in_seq + 1, deeper_conv)
            # print(f"Deepen stage[{loc}] after conv at index {idx_conv_in_seq}")
            # No need for _fix_tail if deeper_conv maintains channels and is identity-like.
            # However, if a Norm layer followed conv_to_edit, it now follows deeper_conv.
            # The output channels of deeper_conv are same as conv_to_edit.out_channels.
            # So, subsequent layers' inputs are still matched.
            # _fix_tail(loc, idx_conv_in_seq + 1, deeper_conv.out_channels) # If deeper_conv could change things

        elif etype == 3: # RESIDUAL
            changed = True
            # Insert ResidualBlock after conv_to_edit. It takes conv_to_edit.out_channels.
            res_block = ResidualBlock(conv_to_edit.out_channels).to(self.dev)
            seq.insert(idx_conv_in_seq + 1, res_block)
            # print(f"Residual block added in stage[{loc}] after conv at {idx_conv_in_seq}")
            # ResidualBlock maintains channels, so no _fix_tail needed for channel mismatch.
            # _fix_tail(loc, idx_conv_in_seq + 1, conv_to_edit.out_channels)

        elif etype == 5: # THINNER
            changed = True
            old_c = conv_to_edit.out_channels
            seq[idx_conv_in_seq] = net2thinner_conv(conv_to_edit, 0.5).to(self.dev)
            new_c = seq[idx_conv_in_seq].out_channels
            # print(f"Thin stage[{loc}][{idx_conv_in_seq}] {old_c} -> {new_c}")
            _fix_tail(loc, idx_conv_in_seq, new_c)

        elif etype == 6: # SHALLOW
            # net2shallower_conv modifies seq in-place and returns True if changed.
            # It fuses seq[idx_conv_in_seq] and seq[idx_conv_in_seq+1].
            # The conv at idx_conv_in_seq remains, the next one is removed.
            changed = net2shallower_conv(seq, idx_conv_in_seq)
            if changed:
                # The conv at seq[idx_conv_in_seq] is now the fused one. Its output channels are its own.
                # Need to get its output channels to fix tail.
                # If seq[idx_conv_in_seq] was removed (e.g. if it was the one to be dropped), this is an issue.
                # net2shallower_conv drops seq[idx+1]. So seq[idx] is still there.
                if idx_conv_in_seq < len(seq) and isinstance(seq[idx_conv_in_seq], nn.Conv2d):
                    new_c_after_shallow = seq[idx_conv_in_seq].out_channels
                    _fix_tail(loc, idx_conv_in_seq, new_c_after_shallow)
                # else:
                    # print("WARN: Shallow edit made, but conv at idx_conv_in_seq not found or not Conv. Tail not fixed.")

        elif etype in {7, 8, 9}: # Bottleneck, SE, MBConv
            changed = True
            insert_map = {7: BottleneckBlock, 8: SEBlock, 9: MBConv}
            blk_cls = insert_map[etype]

            # These blocks are inserted after conv_to_edit.
            # Their input channels should be conv_to_edit.out_channels.
            ch_for_block = conv_to_edit.out_channels
            if ch_for_block == 0: # Safety
                # print(f"WARN: Zero channels from preceding conv for {blk_cls.__name__}. Edit skipped.")
                return False

            if etype == 7: # BottleneckBlock(in_ch, out_ch)
                # Assuming it maintains channels if inserted this way: (ch, ch)
                new_block = BottleneckBlock(ch_for_block, ch_for_block).to(self.dev)
            elif etype == 8: # SEBlock(channels)
                new_block = SEBlock(ch_for_block).to(self.dev)
            elif etype == 9: # MBConv(in_ch, out_ch)
                # Assuming it maintains channels: (ch, ch)
                new_block = MBConv(ch_for_block, ch_for_block).to(self.dev)

            seq.insert(idx_conv_in_seq + 1, new_block)
            # print(f"{blk_cls.__name__} inserted in stage[{loc}] after conv at {idx_conv_in_seq}")

            # Output channels of these blocks (if they maintain width as assumed):
            # Bottleneck(C,C) -> C. SE(C) -> C. MBConv(C,C) -> C.
            # So, subsequent layers' inputs are still matched if these blocks maintain channels.
            # _fix_tail(loc, idx_conv_in_seq + 1, ch_for_block)


        elif etype == 10: # INCEPTION
            changed = True
            ch_for_inception = conv_to_edit.out_channels
            if ch_for_inception == 0: return False # Safety

            inc_block = InceptionBlock(ch_for_inception).to(self.dev)
            seq.insert(idx_conv_in_seq + 1, inc_block) # Insert after the conv
            new_c_after_inception = inc_block.out_channels
            # print(f"InceptionBlock inserted in stage[{loc}] after conv at {idx_conv_in_seq}. Output channels: {new_c_after_inception}")
            # Inception block changes channel count, so _fix_tail is crucial.
            # The tail starts *after* the newly inserted Inception block.
            _fix_tail(loc, idx_conv_in_seq + 1, new_c_after_inception)

        elif etype == 11: # SUPER-WIDEN (factor 2 in prompt, not 4)
            changed = True
            old_c = conv_to_edit.out_channels
            # Prompt uses factor=(2) for net2wider_conv.
            # Original comment said "quadruple channels", but code uses factor=2.
            # Let's use factor from prompt's code.
            seq[idx_conv_in_seq] = net2wider_conv(conv_to_edit, factor=2.0).to(self.dev)
            new_c = seq[idx_conv_in_seq].out_channels
            # print(f"Super-Widen (x2) stage[{loc}][{idx_conv_in_seq}] {old_c} -> {new_c}")
            _fix_tail(loc, idx_conv_in_seq, new_c)


        if changed:
            # These sync functions should be robust enough to call after any change.
            self._sync_seblocks() # Ensure SEBlocks match their input channels within each stage
            self._sync_backbone_channels() # Ensure inter-stage channel counts match
            self._sync_classifier() # Ensure FC layer matches output of last backbone stage

            # If batch_size is 1, BNs might have been swapped to GNs.
            # This should be re-applied if the network structure changed.
            # _swap_bn_to_gn is called in __init__ if BS=1.
            # If new layers (BNs) were added, they also need swapping if BS=1.
            if self.batch_size == 1:
                self._swap_bn_to_gn(self.target)

        return changed


    def _train_post(self, max_epochs: int, patience: int = 5,
                    min_delta: float = 0.0005):
        """Train after an edit, one epoch at a time, with early stopping.

        Each epoch calls ``self._train_one_epoch(self.trl, 1)`` followed by
        ``self._validate()``.  The loop stops early once validation accuracy
        has failed to exceed the best value seen so far by more than
        ``min_delta`` for ``patience`` consecutive epochs.

        Args:
            max_epochs: Upper bound on the number of post-edit epochs.
            patience: Consecutive non-improving epochs tolerated before
                stopping.
            min_delta: Minimum accuracy gain that counts as an improvement.

        Returns:
            ``(val_loss, val_acc)`` from the last validation pass (the model
            state at the end of training, not necessarily the best epoch).
            ``(0.0, 0.0)`` is returned as a placeholder when the validation
            loader is empty.
        """
        # Without a validation loader there is nothing to early-stop on:
        # train for the full budget and hand back placeholder metrics.
        if not self.vall:
            for _ in range(max_epochs):
                self._train_one_epoch(self.trl, 1)
            return 0.0, 0.0

        # With a zero (or negative) budget no epoch runs.  Report the model's
        # current validation metrics instead of the (inf, 0.0) sentinels, which
        # a caller would misread as a catastrophic accuracy drop.
        if max_epochs <= 0:
            return self._validate()

        best_acc = -1.0  # lower than any reachable accuracy
        stale_epochs = 0
        for _ in range(max_epochs):
            self._train_one_epoch(self.trl, 1)
            last_loss, last_acc = self._validate()

            if last_acc > best_acc + min_delta:
                best_acc = last_acc
                stale_epochs = 0
                continue

            stale_epochs += 1
            if stale_epochs >= patience:
                break

        return last_loss, last_acc

    def _train_one(self):
        """Run exactly one training epoch over ``self.trl``.

        Thin convenience wrapper around ``_train_one_epoch``; whatever that
        call returns is discarded.
        """
        single_epoch = 1
        self._train_one_epoch(self.trl, single_epoch)






    def _get_layer_recipe(self, layer: nn.Module) -> dict:
        """Describe a single layer as a ``{'type': ..., 'params': ...}`` recipe.

        ``params`` holds the keyword arguments that are passed back to the
        layer's constructor when the network is rebuilt from the recipe.

        Args:
            layer: A module taken from one of the target network's stages.

        Returns:
            A dict with the layer's type name and its constructor parameters.

        Raises:
            ValueError: If ``layer`` is not one of the supported layer types.
        """
        def _as_scalar(value):
            # Conv layers expose stride as a tuple; the custom blocks below are
            # described with a single number (presumably what their
            # constructors expect -- verify against the block definitions).
            return value[0] if isinstance(value, tuple) else value

        if isinstance(layer, nn.Conv2d):
            return {
                'type': 'Conv2d',
                'params': {
                    'in_channels': layer.in_channels,
                    'out_channels': layer.out_channels,
                    'kernel_size': layer.kernel_size,
                    'stride': layer.stride,
                    'padding': layer.padding,
                    'dilation': layer.dilation,
                    'groups': layer.groups,
                    'bias': layer.bias is not None,
                    # Without this a non-default padding mode would silently
                    # revert to zero padding when the conv is rebuilt.
                    'padding_mode': layer.padding_mode,
                },
            }
        if isinstance(layer, nn.BatchNorm2d):
            return {
                'type': 'BatchNorm2d',
                'params': {
                    'num_features': layer.num_features,
                    'eps': layer.eps,
                    'momentum': layer.momentum,
                    'affine': layer.affine,
                    'track_running_stats': layer.track_running_stats,
                },
            }
        if isinstance(layer, nn.GroupNorm):
            return {
                'type': 'GroupNorm',
                'params': {
                    'num_groups': layer.num_groups,
                    'num_channels': layer.num_channels,
                    'eps': layer.eps,
                    'affine': layer.affine,
                },
            }
        if isinstance(layer, nn.ReLU):
            return {'type': 'ReLU', 'params': {'inplace': layer.inplace}}
        if isinstance(layer, SafeMaxPool2d):
            return {
                'type': 'SafeMaxPool2d',
                'params': {'kernel_size': layer.kernel_size, 'stride': layer.stride},
            }
        if isinstance(layer, nn.AdaptiveAvgPool2d):
            return {'type': 'AdaptiveAvgPool2d', 'params': {'output_size': layer.output_size}}
        if isinstance(layer, nn.Dropout):
            return {'type': 'Dropout', 'params': {'p': layer.p, 'inplace': layer.inplace}}

        # Custom blocks: constructor arguments are recovered from sub-layers.
        if isinstance(layer, ResidualBlock):
            return {
                'type': 'ResidualBlock',
                'params': {'c': layer.conv1.in_channels,  # assumes c == conv1.in_channels
                           'p_drop': layer.drop.p},
            }
        if isinstance(layer, BottleneckBlock):
            mid_ch = layer.conv1.out_channels
            out_ch = layer.conv3.out_channels
            return {
                'type': 'BottleneckBlock',
                'params': {
                    'in_ch': layer.conv1.in_channels,
                    'out_ch': out_ch,
                    'stride': _as_scalar(layer.conv2.stride),
                    # Expansion is inferred as out_ch // mid_ch; 4 is the
                    # fallback when the middle width is degenerate.
                    'expansion': out_ch // mid_ch if mid_ch > 0 else 4,
                },
            }
        if isinstance(layer, SEBlock):
            squeeze_fc = layer.fc[0]
            return {
                'type': 'SEBlock',
                'params': {
                    'channels': squeeze_fc.in_features,
                    # Reduction is inferred from the first FC layer's shape.
                    'reduction': (squeeze_fc.in_features // squeeze_fc.out_features
                                  if squeeze_fc.out_features > 0 else 16),
                },
            }
        if isinstance(layer, MBConv):
            expand_conv = layer.block[0]
            strided_conv = layer.block[3]
            return {
                'type': 'MBConv',
                'params': {
                    'in_ch': expand_conv.in_channels,
                    'out_ch': layer.out_channels,
                    'stride': _as_scalar(strided_conv.stride),
                    'exp_ratio': (expand_conv.out_channels // expand_conv.in_channels
                                  if expand_conv.in_channels > 0 else 6),
                },
            }
        if isinstance(layer, InceptionBlock):
            return {'type': 'InceptionBlock', 'params': {'c': layer.in_channels}}
        if isinstance(layer, nn.Identity):
            return {'type': 'Identity', 'params': {}}

        raise ValueError(f"Unsupported layer type for recipe generation: {type(layer)}")

    def _build_target_from_recipe(self, recipe: dict, base_width_for_fc: int, num_classes: int):
        """Build a TargetCNN whose stages follow ``recipe``.

        A default ``TargetCNN`` is created as a shell; its ``stages``,
        ``pre_fc_drop`` and ``fc`` are then replaced by modules rebuilt from
        the recipe.  ``init_widths`` on the shell is left as ``TargetCNN``
        created it.

        Args:
            recipe: Dict with ``'stages'`` (a list of stages, each a list of
                ``{'type', 'params'}`` layer descriptions) and
                ``'pre_fc_dropout_p'``.
            base_width_for_fc: ``base`` passed to ``TargetCNN``; ``base * 4`` is
                also the classifier input width used when no stage in the
                recipe defines a width.
            num_classes: Number of classifier outputs.

        Returns:
            The reconstructed model, moved to ``self.dev``.

        Raises:
            ValueError: If the recipe names an unsupported layer type.
        """
        print("Building TargetCNN from recipe...")

        target_model = TargetCNN(base=base_width_for_fc, num_classes=num_classes)  # shell

        # Layers that leave the tracked channel count untouched.
        passthrough_layers = {
            'BatchNorm2d': nn.BatchNorm2d,
            'GroupNorm': nn.GroupNorm,
            'ReLU': nn.ReLU,
            'SafeMaxPool2d': SafeMaxPool2d,
            'AdaptiveAvgPool2d': nn.AdaptiveAvgPool2d,
            'Dropout': nn.Dropout,
            'SEBlock': SEBlock,
            'Identity': nn.Identity,
        }
        # Layers that define the running output width, paired with the recipe
        # parameter that holds that width.
        width_layers = {
            'Conv2d': (nn.Conv2d, 'out_channels'),
            'ResidualBlock': (ResidualBlock, 'c'),
            'BottleneckBlock': (BottleneckBlock, 'out_ch'),
            'MBConv': (MBConv, 'out_ch'),
        }

        reconstructed_stages = nn.ModuleList()
        last_stage_out_channels = base_width_for_fc * 4  # default if no stage sets a width

        for stage_recipe_list in recipe['stages']:
            stage_layers = []
            current_stage_out_channels = 0
            for layer_info in stage_recipe_list:
                layer_type = layer_info['type']
                params = layer_info['params']

                if layer_type in passthrough_layers:
                    module = passthrough_layers[layer_type](**params)
                elif layer_type in width_layers:
                    layer_cls, width_key = width_layers[layer_type]
                    module = layer_cls(**params)
                    current_stage_out_channels = params[width_key]
                elif layer_type == 'InceptionBlock':
                    # Built once; its output width is only known from the
                    # instance, not from the recipe parameters.
                    module = InceptionBlock(**params)
                    current_stage_out_channels = module.out_channels
                else:
                    raise ValueError(f"Unsupported layer type in recipe: {layer_type}")
                stage_layers.append(module)

            reconstructed_stages.append(nn.Sequential(*stage_layers))
            if current_stage_out_channels > 0:
                last_stage_out_channels = current_stage_out_channels

        target_model.stages = reconstructed_stages

        # Rebuild the classifier head to match the last stage's output width.
        target_model.pre_fc_drop = nn.Dropout(p=recipe['pre_fc_dropout_p'])
        if last_stage_out_channels > 0:
            target_model.fc = nn.Linear(last_stage_out_channels, num_classes)
        else:  # fallback if channel tracking produced nothing usable
            print(f"Warning: Could not determine last stage out channels from recipe ({last_stage_out_channels}). Using default FC.")
            target_model.fc = nn.Linear(base_width_for_fc * 4, num_classes)

        return target_model.to(self.dev)


    def train(self, iters: int = 100, model_save_path: str = '/content/drive/My Drive/deiti_final_architecture.pth'):
        """Run the architecture search and save the final architecture recipe.

        Each iteration:
          1. trains ``self.target`` for ``self.pre_epochs`` epochs and validates;
          2. feeds the state history to the meta-agent, which samples an edit
             type and location (and, every ``self.self_edit_int`` iterations, an
             edit to the meta-agent itself);
          3. applies the edit, runs post-edit training, and restores the
             pre-edit model if validation accuracy fell by more than 0.05;
          4. updates the meta-agent with an actor-critic loss using the reward
             ``4 * (post_acc - pre_acc)``.

        After the loop the final architecture is written as a JSON recipe.

        Args:
            iters: Number of search iterations.
            model_save_path: Checkpoint-style path.  The recipe is written to
                the same location with a trailing ``.pth`` replaced by
                ``_recipe.json`` (the suffix is appended when the path does not
                end in ``.pth``).
        """
        # Only a trailing ".pth" is swapped out.  A blanket str.replace would
        # also rewrite ".pth" inside directory names and, for paths without the
        # extension, would leave the recipe pointing at the original path.
        if model_save_path.endswith(".pth"):
            recipe_save_path = model_save_path[:-len(".pth")] + "_recipe.json"
        else:
            recipe_save_path = model_save_path + "_recipe.json"

        def _sample_action(logits):
            """Sample an action index, falling back to action 0 on bad logits."""
            try:
                return Categorical(logits=logits).sample()
            except (RuntimeError, ValueError):
                # Categorical reports invalid (e.g. NaN) logits as ValueError
                # when argument validation is enabled.
                return torch.tensor([0], device=self.dev)

        meta_opt = optim.Adam(self.meta.parameters(), lr=LEARNING_RATE)
        history  = []
        self.tgt_opt, self.tgt_sch = self._new_opt_sched()

        for itr in range(1, iters + 1):
            t0 = time.time()
            print(f"\n=== ITERATION {itr:03d}/{iters} ===")

            steps_per_ep = len(self.trl) if self.batches_per_epoch is None else self.batches_per_epoch
            if steps_per_ep == 0:
                print("Warning: Training loader is empty. Cannot proceed.")
                break
            pre_steps_for_sched_ff = self.pre_epochs * steps_per_ep

            # True on iterations where the meta-agent may edit itself.
            is_meta_step = self.self_edit_int > 0 and itr % self.self_edit_int == 0

            self.tgt_opt, self.tgt_sch = self._new_opt_sched()  # fresh opt/sched for pre-edit

            print("--- Pre-edit Training ---")
            self._train_one_epoch(self.trl, self.pre_epochs)
            vloss, vacc = self._validate()
            print(f"→ Pre-val  loss={vloss:.4f}  acc={vacc:.4f}")

            # ---- Build the state vector for the meta-agent ----
            # Widths are normalised by the mean initial width (1.0 if that is
            # unavailable or zero).
            init_widths_tensor = torch.tensor(self.target.init_widths, dtype=torch.float32, device=self.dev)
            current_widths_tensor = torch.tensor(self.target.widths(), dtype=torch.float32, device=self.dev)
            scale_factor = init_widths_tensor.mean() if init_widths_tensor.numel() > 0 else torch.tensor(1.0, device=self.dev)
            if scale_factor == 0: scale_factor = torch.tensor(1.0, device=self.dev)
            mean_norm_width = (current_widths_tensor / scale_factor).mean().item() if current_widths_tensor.numel() > 0 else 0.0
            var_norm_width = (current_widths_tensor / scale_factor).var().item() if current_widths_tensor.numel() > 1 else 0.0

            state_features = torch.tensor([[
                vloss, vacc, mean_norm_width, var_norm_width,
                len(self.target.stages) / 10.0, itr / float(iters)
            ]], dtype=torch.float32, device=self.dev).view(1, 1, STATE_DIM)

            # Keep at most MAX_HISTORY_LEN states, oldest dropped first.
            history.append(state_features)
            if len(history) > MAX_HISTORY_LEN: history.pop(0)
            seq_for_meta = torch.cat(history, dim=1)

            # ---- Sample actions (no gradient needed for sampling) ----
            with torch.no_grad():
                le, ll, lm, _ = self.meta(seq_for_meta)

            ae = _sample_action(le)
            al = _sample_action(ll)

            # Meta action 0 is used whenever this is not a self-edit iteration.
            am = torch.tensor([0], dtype=torch.long, device=self.dev)
            if is_meta_step:
                am = _sample_action(lm)

            print(f"Actions: edit={ae.item()}, loc={al.item()}, meta={am.item()}")

            if is_meta_step and am.item() != 0:
                print(f"Performing meta-edit on meta-agent: type {am.item()}")
                if am.item() == 1: self.meta.lstm = net2wider_lstm(self.meta.lstm, 2).to(self.dev)
                elif am.item() == 2: self.meta.lstm = net2deeper_lstm(self.meta.lstm).to(self.dev)
                update_meta_agent_heads(
                    self.meta, self.meta.lstm.hidden_size,
                    self.meta.head_e.out_features, max(1, len(self.target.stages)), self.meta.head_m.out_features
                )
                # The meta-agent's parameters changed, so its optimizer is rebuilt.
                meta_opt = optim.Adam(self.meta.parameters(), lr=LEARNING_RATE)

            # Snapshot for rollback; deepcopy keeps it on the model's device.
            backup_target_for_rollback_gpu = copy.deepcopy(self.target)

            edit_applied_successfully = self._apply_edit(ae.item(), al.item(), widen_factor=1.5)

            print(f"Parameters after edit attempt: {sum(p.numel() for p in self.target.parameters()):,}")
            if not edit_applied_successfully:
                print("  (Edit resulted in no change to the model structure or was skipped)")

            # New optimizer/scheduler for the (possibly) new architecture, with
            # the scheduler advanced by the number of pre-edit steps.
            self.tgt_opt, self.tgt_sch = self._new_opt_sched()
            for _ in range(pre_steps_for_sched_ff): self.tgt_sch.step()

            print(f"--- Post-edit Training (Max {self.post_epochs} epochs) ---")
            vloss2, vacc2 = self._train_post(max_epochs=self.post_epochs, patience=5, min_delta=5e-4)
            print(f"→ Post-val loss={vloss2:.4f}  acc={vacc2:.4f}")

            if edit_applied_successfully and (vacc2 < vacc - 0.05):
                print(f"↻ Rollback: Δacc = {vacc2 - vacc:+.3f}. Restoring model.")
                self.target = backup_target_for_rollback_gpu
                self.tgt_opt, self.tgt_sch = self._new_opt_sched()
                for _ in range(pre_steps_for_sched_ff): self.tgt_sch.step()
                # After a rollback the post-edit metrics are the pre-edit ones,
                # so the reward below is zero.
                vloss2, vacc2 = vloss, vacc
            elif not edit_applied_successfully:
                 print("  (No actual edit applied/skipped, metrics remain pre-edit)")
                 vloss2, vacc2 = vloss, vacc

            param_count_final_iter = sum(p.numel() for p in self.target.parameters())
            print(f"Parameters at end of iteration: {param_count_final_iter:,}")
            reward_value = 4.0 * (vacc2 - vacc)
            print(f"Reward = {reward_value:+.4f}")

            # ---- Actor-critic update of the meta-agent ----
            meta_opt.zero_grad(set_to_none=True)
            with autocast(enabled=self.amp_enabled):
                le2, ll2, lm2, lv2_tensor = self.meta(seq_for_meta)
                log_prob_e = Categorical(logits=le2).log_prob(ae)
                log_prob_l = Categorical(logits=ll2).log_prob(al)
                log_prob_m = torch.tensor(0.0, device=self.dev)
                entropy_m = torch.tensor(0.0, device=self.dev)
                if is_meta_step:
                    # Score the chosen meta action under the current policy.
                    meta_action_dist_current = Categorical(logits=lm2)
                    log_prob_m = meta_action_dist_current.log_prob(am)
                    entropy_m  = meta_action_dist_current.entropy().mean()

                total_log_prob = log_prob_e + log_prob_l + log_prob_m
                lv2_squeezed = lv2_tensor.squeeze()
                # The value estimate is detached in the advantage so the actor
                # term does not back-propagate into the critic output.
                advantage = reward_value - lv2_squeezed.detach()
                loss_actor = -total_log_prob * advantage
                loss_critic = (reward_value - lv2_squeezed).pow(2)
                entropy_bonus = (Categorical(logits=le2).entropy().mean() +
                                 Categorical(logits=ll2).entropy().mean() + entropy_m)
                total_meta_loss = loss_actor + 0.5 * loss_critic - self.ent_coef * entropy_bonus

            self.scaler_meta.scale(total_meta_loss).backward()
            self.scaler_meta.unscale_(meta_opt)  # unscale before clipping
            clip_grad_norm_(self.meta.parameters(), MAX_GRAD_NORM)
            self.scaler_meta.step(meta_opt)
            self.scaler_meta.update()

            del backup_target_for_rollback_gpu
            print(f"Iter {itr:03d} done in {time.time()-t0:.1f}s. Meta Loss: {total_meta_loss.item():.4f}")

        # END OF SEARCH ITERATIONS
        print("\n=== SEARCH COMPLETE ===")

        # Generate recipe for the final architecture
        final_recipe = {'stages': [], 'pre_fc_dropout_p': self.target.pre_fc_drop.p}
        for stage_seq in self.target.stages:
            stage_recipe_list = []
            for layer in stage_seq:
                try:
                    stage_recipe_list.append(self._get_layer_recipe(layer))
                except ValueError as e:
                    # Unsupported layers are skipped, so the saved recipe may be
                    # incomplete for architectures that contain them.
                    print(f"Error generating recipe for layer {type(layer)}: {e}. Skipping layer.")
            final_recipe['stages'].append(stage_recipe_list)

        print(f"Saving final architecture recipe to: {recipe_save_path}")
        try:
            with open(recipe_save_path, 'w', encoding='utf-8') as f:
                json.dump(final_recipe, f, indent=2)  # indent for readability
            print(f"Final architecture recipe saved successfully to {recipe_save_path}")
            model_summary(self.target, "Final Architecture (In Memory)")
        except Exception as e:
            # Best effort: a failed save is reported, not raised.
            print(f"Error saving model recipe to {recipe_save_path}: {e}")

    def _validate(self):
        """Validate without using AMP (fp32-only forward passes).

        Puts ``self.target`` in eval mode and iterates ``self.vall`` under
        ``torch.no_grad()``.  Batches with non-finite inputs, logits or loss
        are skipped.  At most ``self.batches_per_epoch`` batches are used when
        that attribute is not ``None``.

        Returns:
            ``(loss, accuracy)`` where ``loss`` is the criterion value averaged
            over the processed batches and ``accuracy`` is the fraction of
            correctly classified samples across those batches.  ``(0.0, 0.0)``
            is returned when the loader is empty or no batch was processed.
        """
        self.target.eval()

        if not self.vall or len(self.vall) == 0:
            print("Warning (_validate): Validation loader is empty or has zero length. Skipping validation.")
            return 0.0, 0.0

        loss_sum = 0.0
        n_batches = 0
        n_correct = 0
        n_samples = 0

        with torch.no_grad():
            for x_val, y_val_hard in self.vall:
                x_val, y_val_hard = x_val.to(self.dev), y_val_hard.to(self.dev)

                if not torch.isfinite(x_val).all():
                    print(f"Warning (_validate): Non-finite input x_val encountered. Skipping batch.")
                    continue

                # Pure fp32 forward (no autocast)
                logits = self.target(x_val)

                if not torch.isfinite(logits).all():
                    print(f"Warning (_validate): Non-finite logits encountered. Logits sum: {logits.sum().item()}. Skipping batch.")
                    continue

                num_classes = self.target.fc.out_features
                if num_classes <= 0:
                    print(f"Warning (_validate): Invalid num_classes ({num_classes}) from model. Skipping batch.")
                    continue

                # The criterion is fed one-hot targets built from hard labels.
                y_soft = self._to_onehot(y_val_hard, num_classes)
                loss   = self.crit(logits, y_soft)

                if not torch.isfinite(loss):
                    val_loss_item = loss.item()
                    print(f"Warning (_validate): Non-finite validation loss encountered (value: {val_loss_item}). Skipping batch.")
                    continue

                # Loss stays a per-batch average (the criterion's reduction is
                # not known here).  Accuracy is counted per sample so a short
                # final batch does not carry the weight of a full one.
                loss_sum  += loss.item()
                n_correct += (logits.argmax(dim=1) == y_val_hard).sum().item()
                n_samples += y_val_hard.numel()
                n_batches += 1

                if self.batches_per_epoch is not None and n_batches >= self.batches_per_epoch:
                    break

        if n_batches == 0:
            print("Warning (_validate): No batches successfully processed in validation. Returning 0 loss/acc.")
            return 0.0, 0.0

        accuracy = n_correct / n_samples if n_samples else 0.0
        return loss_sum / n_batches, accuracy



    def _train_full(self, epochs: int = 300, learning_rate: float = LEARNING_RATE,
                    grad_clip_value: float = None,
                    optimizer_eps: float = 1e-7):
        if not list(self.target.parameters()):
            print("Error (_train_full): Target model has no parameters. Cannot start full training.")
            return

        if not self.trl or len(self.trl) == 0:
            print("Error (_train_full): Training loader is empty or has zero length. Cannot perform full training.")
            return

        print(f"--- Starting _train_full with LR: {learning_rate:.1e}, AdamW_eps: {optimizer_eps:.1e}, "
              f"GradClipNorm: {MAX_GRAD_NORM}, GradClipVal: {grad_clip_value} for {epochs} epochs ---")

        # Create a fresh optimizer and OneCycleLR scheduler for full training
        full_train_opt = optim.AdamW(
            self.target.parameters(),
            lr=learning_rate,
            weight_decay=1e-4,
            eps=optimizer_eps
        )
        full_train_sch = OneCycleLR(
            full_train_opt,
            max_lr=learning_rate * 5,
            steps_per_epoch=len(self.trl),
            epochs=epochs,
            pct_start=0.25,
            div_factor=25,
            final_div_factor=1e4
        )

        # Save any existing optimizer/scheduler references so we can restore later
        original_tgt_opt, original_tgt_sch = getattr(self, 'tgt_opt', None), getattr(self, 'tgt_sch', None)

        opt_for_full_train = full_train_opt
        sch_for_full_train = full_train_sch

        best_val_acc_full_train = -1.0
        patience_full_train = 20
        min_delta_full_train = 0.001
        no_improvement_epochs_full_train = 0
        early_stopping_grace_epochs = max(10, epochs // 10)

        for ep in range(1, epochs + 1):
            tot_loss_epoch = 0.0
            tot_acc_epoch = 0.0
            cnt_batch = 0
            skipped_batches_due_to_nan_grad = 0

            self.target.train()

            for batch_idx, (x, y_soft) in enumerate(self.trl):
                x, y_soft = x.to(self.dev), y_soft.to(self.dev)

                # Check for non-finite inputs
                if not torch.isfinite(x).all():
                    print(f"Warning [FULL TRAIN]: Non-finite input x at ep {ep}, batch {batch_idx}. Skipping batch.")
                    continue

                # Zero gradients at the start of each batch
                opt_for_full_train.zero_grad(set_to_none=True)

                # ---- Forward pass (pure FP32) ----
                logits = self.target(x)

                # Check for non-finite logits
                if not torch.isfinite(logits).all():
                    print(f"Warning [FULL TRAIN]: Non-finite logits at ep {ep}, batch {batch_idx}. "
                          f"Logits sum: {logits.sum().item()}. Skipping batch.")
                    continue

                loss = self.crit(logits, y_soft)

                # Check for non-finite loss
                if not torch.isfinite(loss):
                    val_loss_item = loss.item() if hasattr(loss, 'item') else float('nan')
                    print(f"Warning [FULL TRAIN]: non-finite loss (value: {val_loss_item}) at ep {ep}, batch {batch_idx}. "
                          "Skipping batch.")
                    opt_for_full_train.zero_grad(set_to_none=True)
                    continue

                # ---- Backward pass ----
                loss.backward()

                # Optionally clip by norm first
                clip_grad_norm_(self.target.parameters(), MAX_GRAD_NORM)

                # Optionally clip by value
                if grad_clip_value is not None:
                    torch.nn.utils.clip_grad_value_(self.target.parameters(), grad_clip_value)

                # Check for non-finite gradients after clipping
                found_non_finite_grad = False
                for param in self.target.parameters():
                    if param.grad is not None and not torch.isfinite(param.grad).all():
                        print(f"Warning [FULL TRAIN]: Detected non-finite grad for a param at ep {ep}, "
                              f"batch {batch_idx} *after* clip, before optimizer.step().")
                        found_non_finite_grad = True
                        break

                if found_non_finite_grad:
                    opt_for_full_train.zero_grad(set_to_none=True)
                    skipped_batches_due_to_nan_grad += 1
                    continue

                # ---- Optimizer step ----
                opt_for_full_train.step()
                sch_for_full_train.step()

                # ---- Compute training metrics ----
                with torch.no_grad():
                    preds = logits.argmax(dim=1)
                    true_labels_from_soft = y_soft.argmax(dim=1)
                    acc_batch = (preds == true_labels_from_soft).float().mean().item()

                tot_loss_epoch += loss.item()
                tot_acc_epoch += acc_batch
                cnt_batch += 1

                if self.batches_per_epoch is not None and cnt_batch >= self.batches_per_epoch:
                    break

            # ---- End of epoch bookkeeping ----
            current_lr_epoch = opt_for_full_train.param_groups[0]['lr']

            if skipped_batches_due_to_nan_grad > 0:
                print(f"Info [FULL TRAIN]: Skipped {skipped_batches_due_to_nan_grad}/"
                      f"{len(self.trl) if self.trl else 0} batches in epoch {ep} due to non-finite gradients.")

            if cnt_batch == 0:
                print(f"Warning [FULL TRAIN]: No batches successfully processed in epoch {ep}. "
                      "Skipping epoch summary and validation.")
                avg_loss_epoch = 0.0
                avg_acc_epoch = 0.0
                vloss, vacc = 0.0, 0.0
            else:
                avg_loss_epoch = tot_loss_epoch / cnt_batch
                avg_acc_epoch = tot_acc_epoch / cnt_batch
                if self.vall and len(self.vall) > 0:
                    vloss, vacc = self._validate()
                else:
                    print(f"Info [FULL TRAIN]: No validation loader found or it's empty. "
                          f"Skipping validation for epoch {ep}.")
                    vloss, vacc = 0.0, 0.0

            print(f"[FULL TRAIN] Epoch {ep:03d}/{epochs} | "
                  f"Train Loss: {avg_loss_epoch:.4f} Acc: {avg_acc_epoch:.4f} | "
                  f"Val Loss: {vloss:.4f} Acc: {vacc:.4f} | LR: {current_lr_epoch:.6f}")

            # ---- Early stopping check ----
            if vacc > best_val_acc_full_train + min_delta_full_train:
                best_val_acc_full_train = vacc
                no_improvement_epochs_full_train = 0
            else:
                if ep > early_stopping_grace_epochs:
                    no_improvement_epochs_full_train += 1

            if no_improvement_epochs_full_train >= patience_full_train:
                print(f"[FULL TRAIN] Early stopping at epoch {ep} due to no improvement for "
                      f"{patience_full_train} epochs after grace period.")
                break

        # Restore self.tgt_opt/sch if they existed before
        if original_tgt_opt is not None:
            self.tgt_opt = original_tgt_opt
        if original_tgt_sch is not None:
            self.tgt_sch = original_tgt_sch



    def _apply_zero_gamma_to_residual_tails(self, m: nn.Module):
        """Zero the scale (gamma) of the last norm layer of a residual-style block.

        Written to be passed to ``nn.Module.apply``: it is called once per
        sub-module and leaves everything that is not a ``ResidualBlock`` or a
        ``BottleneckBlock`` untouched.

        Args:
            m: The module currently visited by ``apply``.
        """
        # Pick the attribute holding the block's final normalisation layer.
        if isinstance(m, ResidualBlock):
            tail_attr = 'bn2'
        elif isinstance(m, BottleneckBlock):
            tail_attr = 'bn3'
        else:
            return

        tail_norm = getattr(m, tail_attr, None)

        # Only BatchNorm2d / GroupNorm tails are eligible.
        if not isinstance(tail_norm, (nn.BatchNorm2d, nn.GroupNorm)):
            return
        # Without affine parameters there is no gamma to zero.
        if not getattr(tail_norm, 'affine', False):
            return
        if getattr(tail_norm, 'weight', None) is None:
            return

        nn.init.constant_(tail_norm.weight, 0)


    def train_final(self, epochs: int = 300,
                    model_path: str = '/content/drive/My Drive/deiti_final_architecture.pth',
                    final_lr: float = LEARNING_RATE / 10,
                    final_optimizer_eps: float = 1e-7,
                    final_grad_clip_value: float = None):
        """Rebuild the searched architecture from its JSON recipe, retrain it from scratch, then test it.

        Steps: load the recipe, build ``self.target`` from it, re-initialise every
        weight, run ``self._train_full`` and finally print accuracy on ``self.tsl``.

        Args:
            epochs: Number of epochs handed to ``_train_full``.
            model_path: Either a ``*.pth`` path (the recipe is looked up at the same
                path with ``.pth`` replaced by ``_recipe.json``) or a bare base path
                without ``.pth``. For a bare base path the path itself is tried
                first and ``<model_path>_recipe.json`` second.
            final_lr: Learning rate forwarded to ``_train_full``.
            final_optimizer_eps: Optimizer epsilon forwarded to ``_train_full``.
            final_grad_clip_value: Forwarded to ``_train_full`` as
                ``grad_clip_value``; ``None`` is passed through unchanged.

        Returns:
            None. When no recipe file can be found, an error is printed and the
            method returns without training.

        Raises:
            Exception: Any error other than ``FileNotFoundError`` raised while
                loading the recipe or building the model is printed and re-raised.
        """
        recipe_load_path = model_path.replace(".pth", "_recipe.json")

        # When model_path carries no ".pth" the substitution above is a no-op.
        # Keep that path as the first candidate (so recipes that already load
        # keep loading) and add "<base>_recipe.json" as a fallback.
        recipe_candidates = [recipe_load_path]
        if recipe_load_path == model_path:
            recipe_candidates.append(model_path + "_recipe.json")

        print("\n=== FINAL TRAINING STAGE ===")
        print(f"Attempting to load architecture recipe from: {recipe_load_path}")
        print(f"Final training LR: {final_lr:.1e}, AdamW_eps: {final_optimizer_eps:.1e}, GradClipValue: {final_grad_clip_value}")

        try:
            loaded_recipe = None
            for candidate_path in recipe_candidates:
                try:
                    with open(candidate_path, 'r') as f:
                        loaded_recipe = json.load(f)
                except FileNotFoundError:
                    continue
                if candidate_path != recipe_load_path:
                    print(f"Recipe not found at {recipe_load_path}; using {candidate_path} instead.")
                    recipe_load_path = candidate_path
                break
            else:
                # No candidate existed: handled by the FileNotFoundError branch below.
                raise FileNotFoundError(recipe_candidates[-1])

            # Infer the class count from the training dataset, unwrapping
            # Subset-style wrappers that expose the real dataset as `.dataset`.
            num_classes = 10
            if self.trl:
                try:
                    actual_dataset = self.trl.dataset
                    while hasattr(actual_dataset, 'dataset'):
                        actual_dataset = actual_dataset.dataset
                    if hasattr(actual_dataset, 'classes') and actual_dataset.classes:
                        num_classes = len(actual_dataset.classes)
                    elif hasattr(actual_dataset, 'num_classes') and actual_dataset.num_classes:
                        num_classes = actual_dataset.num_classes
                except Exception as e:
                    print(f"Could not reliably infer num_classes from dataloader, defaulting to {num_classes}. Error: {e}")

            initial_base_width = 32  # Default base for TargetCNN reconstruction shell

            self.target = self._build_target_from_recipe(loaded_recipe,
                                                         base_width_for_fc=initial_base_width,
                                                         num_classes=num_classes)
            print("Architecture built successfully from recipe.")
            model_summary(self.target, "Loaded Final Architecture from Recipe")

        except FileNotFoundError:
            tried_paths = " or ".join(recipe_candidates)
            print(f"Error: Recipe file not found at {tried_paths}. Cannot proceed with final training.")
            return
        except Exception as e:
            print(f"Error loading/building model from recipe {recipe_load_path}: {e}")
            raise

        # The BatchNorm2d -> GroupNorm swap (self._swap_bn_to_gn) is currently not
        # applied here, so the log must not claim that it is.
        print("Note: BatchNorm2d -> GroupNorm swap is disabled in final training; "
              "normalization layers are kept as built from the recipe.")

        print("Re-initializing weights of the loaded architecture...")

        def _reset_params_if_available(m):
            # Generic reset for every module that exposes reset_parameters().
            if hasattr(m, "reset_parameters"):
                try:
                    m.reset_parameters()
                except Exception as e_reset:
                    print(f"Note: Could not call reset_parameters on {type(m)}: {e_reset}")
        self.target.apply(_reset_params_if_available)  # General reset first

        def _reinit_weights(m):
            # Layer-specific initialisation applied on top of the generic reset.
            if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Conv3d)):
                try:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                except Exception as e_kaiming:
                    print(f"Note: Kaiming init failed for {type(m)}. Error: {e_kaiming}")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                try:
                    nn.init.xavier_uniform_(m.weight)
                except Exception as e_xavier:
                    print(f"Note: Xavier init failed for {type(m)}. Error: {e_xavier}")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                # Other norm layers (e.g. BatchNorm2d) are only touched by the
                # generic reset_parameters() pass above.
                if getattr(m, 'affine', False):
                    if hasattr(m, 'weight') and m.weight is not None:
                        nn.init.constant_(m.weight, 1)
                    if hasattr(m, 'bias') and m.bias is not None:
                        nn.init.constant_(m.bias, 0)

        print("Applying specific re-initialization for Conv/Linear/GroupNorm layers...")
        self.target.apply(_reinit_weights)

        # Must run last: it overwrites the gamma=1 set by the passes above.
        print("Applying zero-gamma initialization to residual tails after all other inits...")
        self.target.apply(self._apply_zero_gamma_to_residual_tails)

        print("Weights re-initialized for training from scratch.")
        self.target = self.target.to(self.dev)

        self._train_full(epochs=epochs, learning_rate=final_lr,
                         grad_clip_value=final_grad_clip_value,
                         optimizer_eps=final_optimizer_eps)

        print("\n=== FINAL TEST EVALUATION ===")
        self.target.eval()
        total_correct_test, total_samples_test = 0, 0

        if not self.tsl or len(self.tsl) == 0:
            print("Warning (FINAL TEST): Test loader empty or has zero length. Skipping final test.")
        else:
            with torch.no_grad():
                for x_test, y_test in self.tsl:
                    x_test, y_test = x_test.to(self.dev), y_test.to(self.dev)
                    if not torch.isfinite(x_test).all():
                        print("Warning (FINAL TEST): Non-finite input x_test. Skipping test batch.")
                        continue
                    with torch.amp.autocast(device_type=self.dev.type, enabled=self.amp_enabled):
                        logits_test = self.target(x_test)

                    # Batches with non-finite logits are excluded from the accuracy.
                    if not torch.isfinite(logits_test).all():
                        print(f"Warning (FINAL TEST): Non-finite logits encountered for a test batch. Logits sum: {logits_test.sum().item()}. This batch's accuracy may be affected or skipped.")
                        continue

                    predictions_test = logits_test.argmax(1)
                    total_correct_test += (predictions_test == y_test).sum().item()
                    total_samples_test += y_test.size(0)

            if total_samples_test > 0:
                print(f"Final Test Acc: {total_correct_test/total_samples_test:.4f}")
            else:
                print("Warning (FINAL TEST): No samples successfully processed in test set for final evaluation.")


[0mCollecting torch-optimizer
  Downloading torch_optimizer-0.3.0-py3-none-any.whl.metadata (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.9/55.9 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-ranger>=0.1.1 (from torch-optimizer)
  Downloading pytorch_ranger-0.1.1-py3-none-any.whl.metadata (509 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.5.0->torch-optimizer)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.5.0->torch-optimizer)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.5.0->torch-optimizer)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.5.0->torch-optimizer)
  Downloading nvidia_cudnn_cu12-9

In [None]:
# Cell 2: Mount Google Drive
# The search and final-training cells use paths under
# '/content/drive/My Drive/', so Drive has to be mounted at '/content/drive'
# before either of them runs.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Cell 3: ENTRY POINT - ARCHITECTURE SEARCH
#!/usr/bin/env python
# coding: utf-8
"""
DEITI — Dynamic Editing with Iterative Targeted Improvement - Search Phase
Supports ADD_STAGE edits with robust pooling logic to avoid zero-dimension errors,
full multi-epoch training, AMP everywhere, detailed per-iteration logging,
gradient accumulation to reduce memory, gradient clipping to prevent explosion,
and tunable hyperparameters.
"""

# Hyperparameters (easy to edit)
# These are now mostly defaults in DEITI class, but can be overridden here.


if __name__ == "__main__":
    # Base path handed to DEITI.train() as `model_save_path`. It points at the
    # mounted Google Drive and deliberately carries no .pth / .json extension.
    # NOTE(review): how train() derives the recipe filename from this base path
    # is decided inside train(); confirm it matches what train_final() loads.
    architecture_base_save_path = '/content/drive/My Drive/deiti_final_architecture'

    try:
        # Set up DEITI for the architecture-search phase.
        search_instance = DEITI(
            pre_epochs=3,
            post_epochs=3,
            batches_per_epoch=None,
            batch_size=128,
        )
        search_instance.train(
            iters=100,  # Number of search iterations
            model_save_path=architecture_base_save_path,
        )
    except Exception:
        # Announce the failure, then let the original traceback surface.
        print("An error occurred during the search phase. Raising exception.")
        raise
    else:
        # The Colab runtime is intentionally left assigned; to release it
        # automatically, import `google.colab.runtime` and call
        # `runtime.unassign()` here.
        print("Search phase completed successfully.")

 24%|██▍       | 41.3M/170M [00:04<00:14, 9.22MB/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-37bf36b54665>", line 26, in <cell line: 0>
    search_instance = DEITI(
                      ^^^^^^
  File "<ipython-input-2-4ea7057aa7af>", line 570, in __init__
    full = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=tf_train)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/cifar.py", line 66, in __init__
    self.download()
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/cifar.py", line 139, in download
    download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/utils.py", line 391, in download_a

TypeError: object of type 'NoneType' has no len()

In [None]:
# Cell 4: ENTRY POINT - FINAL TRAINING
#!/usr/bin/env python
# coding: utf-8
"""
DEITI — Final Training Phase
Loads a saved architecture recipe and trains it from scratch.
"""

if __name__ == "__main__":
    # Relies on the DEITI class from Cell 1 and on Google Drive being mounted (Cell 2).

    # Base path (no .pth / .json extension) shared with the search phase.
    model_architecture_base_path = '/content/drive/My Drive/deiti_final_architecture'

    # Recipe location reported to the user before training starts.
    # NOTE(review): train_final() derives its own recipe path from `model_path`;
    # confirm it agrees with the string printed here.
    expected_recipe_file_path = f"{model_architecture_base_path}_recipe.json"
    print(f"Attempting to load architecture recipe from: {expected_recipe_file_path}")
    print("Ensure Google Drive is mounted and the file exists.")

    try:
        # pre_epochs / post_epochs are placeholders for this phase; full epochs
        # are requested via batches_per_epoch=None.
        final_trainer_instance = DEITI(
            pre_epochs=1,
            post_epochs=1,
            batches_per_epoch=None,
            batch_size=128,
        )

        # Learning rate used for the from-scratch final training run.
        custom_final_lr = 8e-4
        print(f"Calling train_final with custom final_lr: {custom_final_lr}")

        final_trainer_instance.train_final(
            model_path=model_architecture_base_path,
            final_lr=custom_final_lr,
            epochs=300,
        )
    except FileNotFoundError:
        for message_line in (
            f"ERROR: Architecture recipe file not found (expected around {expected_recipe_file_path}).",
            "Please ensure the search phase ran successfully and saved the model recipe,",
            "and that your Google Drive is correctly mounted.",
        ):
            print(message_line)
    except Exception as e:
        print(f"An error occurred during the final training phase: {e}")
        # Re-raise so the full traceback stays visible.
        raise
    else:
        print("Final training phase completed.")
    # The runtime is not unassigned here, so the final results stay inspectable.

Attempting to load architecture recipe from: /content/drive/My Drive/deiti_final_architecture_recipe.json
Ensure Google Drive is mounted and the file exists.


  self.scaler_tgt  = GradScaler()
  self.scaler_meta = GradScaler()


Calling train_final with custom final_lr: 0.0008

=== FINAL TRAINING STAGE ===
Attempting to load architecture recipe from: /content/drive/My Drive/deiti_final_architecture
Final training LR: 8.0e-04, AdamW_eps: 1.0e-07, GradClipValue: None
Building TargetCNN from recipe...
Architecture built successfully from recipe.

--- Loaded Final Architecture from Recipe ---
TargetCNN(
  (stages): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): SEBlock(
        (avgpool): AdaptiveAvgPool2d(output_size=1)
        (fc): Sequential(
          (0): Linear(in_features=576, out_features=36, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=36, out_features=576, bias=False)
          (3): Sigmoid()
        )
      )
      (3): BottleneckBlock(
        (conv1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
 