# In [None]:
import os, math, random, time, json, gc, copy
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Tuple
from collections import defaultdict  

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import f1_score
from tqdm import tqdm


# Single seed shared by Python's `random`, NumPy and PyTorch so that sampling,
# data splits and weight initialisation are repeatable from run to run.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)


# --- Model / data-loading settings -----------------------------------------
# Hugging Face checkpoint passed to AutoTokenizer / AutoModel.from_pretrained.
BACKBONE_NAME   = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
# Every text is truncated / padded to this many tokens (see TextDataset).
MAX_LEN         = 128
# Default DataLoader batch size (build_loader falls back to it).
BATCH           = 32
# DataLoader worker processes; 0 keeps loading in the main process.
NUM_WORKERS     = 0
# Default for evaluate_arch_on_task(freeze=...): train only the head.
FREEZE_BACKBONE = True     
# NOTE(review): not read anywhere in the visible code — presumably gradient
# accumulation steps for a later training stage; confirm.
GRAD_ACC        = 1
# Turns on autocast + GradScaler in evaluate_arch_on_task.
MIXED_PREC      = True


# --- Low-fidelity (search-time) evaluation ----------------------------------
# NOTE(review): each task in TASKS also carries its own 'lf_subset'; which of
# the two is authoritative is not visible here — confirm against the caller.
LF_SUBSET_PER_TASK = 14000
# Default number of epochs for evaluate_arch_on_task.
LF_EPOCHS          = 5
# Presumably the validation fraction handed to make_lf_splits — TODO confirm.
LF_VAL_FRAC        = 0.1


# --- Final training (consumer is outside the visible code) ------------------
FINAL_EPOCHS = 5
FINAL_LR     = 2e-5
WEIGHT_DECAY = 0.01

# Evolutionary search hyperparams
# NOTE(review): the search loop that reads these values is not part of the
# visible code; the names suggest population size, tournament size, mutation
# probabilities, transfer settings and generation count — verify before tuning.
POP_SIZE   = 10
TOURN_SIZE = 5
MUT_PRO = 1
MUT_AFTER_TRANSFER = 1  
CHILDREN_PER_TRANSFER = 12
TOPK_REUSE = 4
UNIT       = 1.1
ALPHA      = 0.1
GENERATIONS = 10

# Tasks
# One dict per dataset. Keys read by load_task_dataframe:
#   'name'      - task identifier (also selects task-specific label handling)
#   'csv_path'  - CSV file to load
#   'format'    - optional; 'sent140_raw' selects the header-less Sentiment140 parser
#   'text_col'  - column holding the input text
#   'label_col' - column holding the class label
# 'num_classes', 'train_frac' and 'lf_subset' are consumed outside the visible
# code (presumably class count, fraction of data used, and low-fidelity subset
# size) — TODO confirm against the search driver.
TASKS: List[Dict[str, Any]] = [
    {
        'name': 'twitter2',
        'csv_path': 'training.1600000.processed.noemoticon.csv',  
        'format': 'sent140_raw',
        'text_col': 'message',
        'label_col': 'label',     
        'num_classes': 2,
        'train_frac': 0.20,       
        'lf_subset': 12000,
    },
    {
        'name': 'yt_taskA',
        'csv_path': 'taskA_youtube_raw.csv',  
        'text_col': 'text',
        'label_col': 'label',        
        'num_classes': 3,
        'train_frac': 1,
        'lf_subset': 12000,
    },
    {
        'name': 'yt_taskB',
        'csv_path': 'taskB_youtube_raw.csv',  
        'text_col': 'text',
        'label_col': 'label',
        'num_classes': 3,
        'train_frac': 1,
        'lf_subset': 12000,
    },

    {
        # load_task_dataframe derives 'label' from a 'sentiment' column for
        # this task when the label column is absent (4 classes).
        'name': 'twitter_research',
        'csv_path': 'twitter_sentiment_data_ResearchPaper.csv',
        'text_col': 'message',      
        'label_col': 'label',       
        'num_classes': 4,           
        'train_frac': 1,
        'lf_subset': 12000,
    },
]



# Module-level tokenizer for BACKBONE_NAME, shared by every DataLoader built in
# build_loader. Loaded once at import time (downloads the vocab on first use).
TOKENIZER = AutoTokenizer.from_pretrained(BACKBONE_NAME, normalization=True, use_fast=False)

class Backbone(nn.Module):
    """Thin wrapper around a pretrained Hugging Face encoder.

    Exposes the encoder as ``self.model`` and its hidden width as
    ``self.hidden`` so that a head can be sized without touching the config.
    """

    def __init__(self, name: str):
        super().__init__()
        encoder = AutoModel.from_pretrained(name)
        self.model = encoder
        self.hidden = encoder.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the per-token hidden states of the encoder, shaped [B, T, H]."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state  # [B, T, H]


class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per access.

    Each item is a dict of ``torch.long`` tensors: every field returned by the
    tokenizer plus a ``'labels'`` entry holding the class index.
    """

    def __init__(self, texts: List[str], labels: List[int], tokenizer: AutoTokenizer, max_len: int):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize lazily; every sample is padded/truncated to exactly max_len.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
        )
        sample = {}
        for field, values in encoded.items():
            sample[field] = torch.tensor(values, dtype=torch.long)
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# Head search space
# Discrete option lists sampled by HeadArch.random() / HeadArch.mutate().
# Convolution geometry: one kernel size and one dilation are drawn per branch.
KERNEL_CHOICES   = [1, 3, 5, 7, 9, 11, 13, 15]
FILTER_CHOICES   = [128, 192, 256, 320, 384]
LAYERS_CHOICES   = [1, 2, 3, 4]
BRANCH_CHOICES   = [1, 2, 3, 4, 5]
DILATION_CHOICES = [1, 2, 4, 8, 16]

DROPOUT_CHOICES  = [0.1, 0.2, 0.3, 0.5]


# Pooling over the sequence axis; KMAX_CHOICES / GEM_P_CHOICES parameterise
# the 'kmax' and 'gem' variants respectively (see CNNHead).
POOLING_CHOICES  = ['max', 'avg', 'attn', 'gem', 'kmax']
KMAX_CHOICES     = [1, 2, 3, 4]
GEM_P_CHOICES    = [2.0, 3.0, 4.0]


# Names understood by get_act() / get_norm().
ACT_CHOICES      = ['gelu', 'relu', 'silu', 'mish']
NORM_CHOICES     = ['bn', 'none', 'gn8', 'ln']


# Depthwise-separable convolution toggle and grouped-convolution group counts.
SEP_CHOICES      = [False, True]
GROUPS_CHOICES   = [1, 2, 4, 8]


# Squeeze-and-excitation reduction ratio; 0 disables the SE block in CNNHead.
SE_RATIO_CHOICES = [0, 4, 8, 16]   


# Whether each branch is wrapped in a ResidualBranch.
RESIDUAL_CHOICES = [False, True]


# Feasibility limits for a head: upper bound on filters * branches and lower
# bound on the number of filters each branch receives.
HEADSIZE_CAP = 6000
MIN_F_PER    = 64

def _feasible(filters: int, branches: int) -> bool:
    """Return True when a (filters, branches) pair respects both size limits."""
    if filters * branches > HEADSIZE_CAP:
        return False
    # Guard against a zero branch count before dividing.
    per_branch = filters // max(1, branches)
    return per_branch >= MIN_F_PER

# Running counters of candidates rejected by _feasible, keyed by operator
# ('random_skips', 'mutate_skips').
SKIP_STATS = defaultdict(int)

def snapshot_skip_stats():
    """Return a plain-dict copy of the current skip counters."""
    frozen_view = dict(SKIP_STATS.items())
    return frozen_view

@dataclass(frozen=True)
class HeadArch:
    """Immutable description of one CNN-head architecture in the search space.

    Instances are hashable value objects; `random()` samples a new one and
    `mutate()` returns a modified copy. Both only return candidates accepted
    by `_feasible` (or a fallback, see below) and record rejected attempts in
    `SKIP_STATS`.
    """
    # Number of convolution layers stacked inside each branch.
    layers: int
    # Total filter budget; CNNHead gives each branch filters // branches.
    filters: int                     
    # Number of parallel convolution branches.
    branches: int                    
    # One kernel size per branch.
    kernels: tuple                   
    # Dropout probability applied to the pooled features.
    dropout: float
    # Pooling mode, one of POOLING_CHOICES.
    pooling: str
    
    # Activation / normalisation names resolved by get_act() / get_norm().
    act: str = 'gelu'
    norm: str = 'bn'
    # One dilation per branch; empty means "dilation 1 everywhere".
    dilations: tuple = ()
    
    # Parameters of the 'kmax' and 'gem' pooling modes.
    kmax_k: int = 1
    gem_p: float = 3.0
    
    # Depthwise-separable convolutions, and convolution group count.
    sep: bool = False
    groups: int = 1
    
    # Squeeze-and-excitation reduction ratio (0 disables SE).
    se_ratio: int = 0
    
    # Wrap each branch in a residual connection.
    residual: bool = False

    def as_tuple(self):
        """Return a canonical tuple of all fields, suitable for hashing.

        Empty `dilations` are expanded to all-ones and dropout is rounded to
        two decimals so equivalent architectures produce the same tuple.
        """
        
        dils = tuple(self.dilations) if self.dilations else tuple([1]*self.branches)
        return (
            self.layers, self.filters, self.branches, tuple(self.kernels),
            round(self.dropout, 2), self.pooling, self.act, self.norm, dils,
            int(self.kmax_k), float(self.gem_p), bool(self.sep), int(self.groups),
            int(self.se_ratio), bool(self.residual)
        )

    @staticmethod
    def random():
        """Sample a feasible architecture (up to 100 attempts).

        Activation and normalisation are fixed to 'gelu' / 'bn' here; only
        `mutate()` explores other values. If every attempt is infeasible a
        minimal single-branch architecture is returned instead.
        """
        for attempts in range(1, 101):
            b  = random.choice(BRANCH_CHOICES)
            # Distinct kernel sizes, sorted so equal sets compare equal.
            ks = tuple(sorted(random.sample(KERNEL_CHOICES, k=b)))
            ds = tuple(random.choice(DILATION_CHOICES) for _ in range(b))
            f  = random.choice(FILTER_CHOICES)
            arch = HeadArch(
                layers=random.choice(LAYERS_CHOICES),
                filters=f,
                branches=b,
                kernels=ks,
                dropout=random.choice(DROPOUT_CHOICES),
                pooling=random.choice(POOLING_CHOICES),
                act='gelu',
                norm='bn',
                dilations=ds,
                kmax_k=random.choice(KMAX_CHOICES),
                gem_p=random.choice(GEM_P_CHOICES),
                sep=random.choice(SEP_CHOICES),
                groups=random.choice(GROUPS_CHOICES),
                se_ratio=random.choice(SE_RATIO_CHOICES),
                residual=random.choice(RESIDUAL_CHOICES),
            )
            if _feasible(arch.filters, arch.branches):
                # Only the rejected attempts before this one count as skips.
                SKIP_STATS['random_skips'] += (attempts - 1)
                return arch
        SKIP_STATS['random_skips'] += 100
        
        # Fallback: smallest configuration (always feasible under the current limits).
        return HeadArch(layers=1, filters=128, branches=1, kernels=(3,), dropout=0.1,
                        pooling='max', act='gelu', norm='bn', dilations=(1,))

    def mutate(self):
        """Return a copy with one randomly chosen field changed (up to 50 attempts).

        Changing `branches` also resamples `kernels` and `dilations` so their
        lengths stay consistent. Returns `self` unchanged if no feasible
        mutation is found.
        """
        fields = [
            'layers', 'filters', 'branches', 'kernels', 'dropout', 'pooling',
            'act', 'norm', 'dilations', 'kmax_k', 'gem_p', 'sep', 'groups',
            'se_ratio', 'residual'
        ]
        for attempts in range(1, 51):
            f = random.choice(fields)
            d = asdict(self)
            if f == 'layers':
                d['layers'] = random.choice([c for c in LAYERS_CHOICES if c != self.layers])
            elif f == 'filters':
                d['filters'] = random.choice([c for c in FILTER_CHOICES if c != self.filters])
            elif f == 'branches':
                new_b = random.choice([c for c in BRANCH_CHOICES if c != self.branches])
                d['branches'] = new_b
                d['kernels'] = tuple(sorted(random.sample(KERNEL_CHOICES, k=new_b)))
                d['dilations'] = tuple(random.choice(DILATION_CHOICES) for _ in range(new_b))
            elif f == 'kernels':
                # May resample the same kernel set; the result is then identical to self.
                d['kernels'] = tuple(sorted(random.sample(KERNEL_CHOICES, k=self.branches)))
                # Repair a missing / mismatched dilation tuple while we are here.
                if len(self.dilations) != self.branches:
                    d['dilations'] = tuple(random.choice(DILATION_CHOICES) for _ in range(self.branches))
            elif f == 'dropout':
                d['dropout'] = random.choice([c for c in DROPOUT_CHOICES if c != self.dropout])
            elif f == 'pooling':
                d['pooling'] = random.choice([c for c in POOLING_CHOICES if c != self.pooling])
            elif f == 'act':
                d['act'] = random.choice([a for a in ACT_CHOICES if a != self.act])
            elif f == 'norm':
                d['norm'] = random.choice([n for n in NORM_CHOICES if n != self.norm])
            elif f == 'dilations':
                d['dilations'] = tuple(random.choice(DILATION_CHOICES) for _ in range(self.branches))
            elif f == 'kmax_k':
                d['kmax_k'] = random.choice([v for v in KMAX_CHOICES if v != self.kmax_k])
            elif f == 'gem_p':
                d['gem_p'] = random.choice([v for v in GEM_P_CHOICES if v != self.gem_p])
            elif f == 'sep':
                d['sep'] = not self.sep
            elif f == 'groups':
                d['groups'] = random.choice([g for g in GROUPS_CHOICES if g != self.groups])
            elif f == 'se_ratio':
                d['se_ratio'] = random.choice([r for r in SE_RATIO_CHOICES if r != self.se_ratio])
            else:  # 'residual'
                d['residual'] = not self.residual

            cand = HeadArch(**d)
            if _feasible(cand.filters, cand.branches):
                SKIP_STATS['mutate_skips'] += (attempts - 1)
                return cand
        SKIP_STATS['mutate_skips'] += 50
        return self  


def _gn_groups(C: int) -> int:
    """Return the largest of 8, 4, 2, 1 that divides the channel count C."""
    return next((groups for groups in (8, 4, 2, 1) if C % groups == 0), 1)

def get_act(name: str):
    """Build the activation module for `name`; unknown names fall back to GELU."""
    known = (
        ('relu', nn.ReLU),
        ('silu', nn.SiLU),
        ('mish', nn.Mish),
    )
    for key, factory in known:
        if name == key:
            return factory()
    return nn.GELU()

def get_norm(name: str, C: int):
    """Build the normalisation layer `name` for C channels.

    'bn' -> BatchNorm1d, 'gn8' -> GroupNorm with the group count chosen by
    `_gn_groups`, 'ln' -> GroupNorm with a single group. Any other name
    (including 'none') yields an Identity module.
    """
    builders = (
        ('bn', lambda: nn.BatchNorm1d(C)),
        ('gn8', lambda: nn.GroupNorm(_gn_groups(C), C)),
        ('ln', lambda: nn.GroupNorm(1, C)),
    )
    for key, build in builders:
        if name == key:
            return build()
    return nn.Identity()


class GeM(nn.Module):
    """Generalised-mean pooling over the last axis with a learnable exponent p.

    Inputs are clamped to at least `eps` before being raised to the power p,
    averaged over dim 2, and raised to 1/p.
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.tensor(float(p)))

    def forward(self, x):
        powered = torch.pow(torch.clamp(x, min=self.eps), self.p)
        pooled = powered.mean(dim=2)
        return torch.pow(pooled, 1.0 / self.p)


class SE1d(nn.Module):
    """Squeeze-and-excitation gate for [B, C, T] feature maps.

    Channels are averaged over dim 2, passed through a two-layer bottleneck
    of width max(1, C // r), and the resulting sigmoid gate rescales x.
    """

    def __init__(self, C: int, r: int):
        super().__init__()
        bottleneck = max(1, C // r)
        self.fc1 = nn.Linear(C, bottleneck)
        self.fc2 = nn.Linear(bottleneck, C)

    def forward(self, x):
        squeezed = x.mean(dim=2)
        hidden = F.silu(self.fc1(squeezed))
        gate = torch.sigmoid(self.fc2(hidden))
        # Broadcast the per-channel gate along the sequence axis.
        return x * gate.unsqueeze(-1)


class ResidualBranch(nn.Module):
    """Adds a skip path to `core`: output = core(x) + skip(x).

    When `skip` is an nn.Identity the input tensor itself is used as the
    shortcut without calling the module.
    """

    def __init__(self, core: nn.Module, skip: nn.Module):
        super().__init__()
        self.core = core
        self.skip = skip

    def forward(self, x):
        transformed = self.core(x)
        if isinstance(self.skip, nn.Identity):
            shortcut = x
        else:
            shortcut = self.skip(x)
        return transformed + shortcut


class CNNHead(nn.Module):
    """Multi-branch 1-D CNN classification head driven by a HeadArch.

    Each branch stacks `arch.layers` convolution blocks (conv -> activation ->
    norm, optionally followed by squeeze-and-excitation) with its own kernel
    size and dilation. Branch outputs are concatenated on the channel axis,
    optionally projected back to `arch.filters` channels, pooled over the
    sequence axis, passed through dropout and a final linear classifier.

    Args:
        hidden: channel width of the backbone output.
        num_classes: number of output logits.
        arch: architecture description (see HeadArch).
    """

    def __init__(self, hidden: int, num_classes: int, arch: HeadArch):
        super().__init__()
        self.arch = arch
        b = arch.branches
        f_total = arch.filters
        f_per = max(1, f_total // b)

        # Upper bound on the conv group count requested by the architecture.
        # Previously this field was ignored and the largest divisor in
        # GROUPS_CHOICES was always used, which made `groups` a no-op gene.
        requested_groups = max(1, int(getattr(arch, 'groups', 1)))

        self.branches = nn.ModuleList()
        for bi, k in enumerate(arch.kernels):
            layers = []
            in_ch = hidden

            # Per-branch dilation; missing entries default to 1.
            d = arch.dilations[bi] if (hasattr(arch, 'dilations') and len(arch.dilations) > bi) else 1
            # "Same" padding: kernels are odd, so the sequence length is preserved.
            pad = ((k - 1) * d) // 2

            use_res = getattr(arch, 'residual', False)
            skip = None

            for li in range(arch.layers):
                if getattr(arch, 'sep', False):
                    # Depthwise conv followed by a 1x1 pointwise conv.
                    layers.append(nn.Conv1d(in_ch, in_ch, kernel_size=k, padding=pad, dilation=d, groups=in_ch))
                    layers.append(get_act(arch.act))
                    layers.append(get_norm(arch.norm, in_ch))
                    layers.append(nn.Conv1d(in_ch, f_per, kernel_size=1))
                else:
                    # Largest allowed group count that does not exceed the
                    # requested one and divides both channel counts.
                    g_ok = 1
                    for g in sorted(GROUPS_CHOICES, reverse=True):
                        if g <= requested_groups and in_ch % g == 0 and f_per % g == 0:
                            g_ok = g
                            break
                    layers.append(nn.Conv1d(in_ch, f_per, kernel_size=k, padding=pad, dilation=d, groups=g_ok))
                layers.append(get_act(arch.act))
                layers.append(get_norm(arch.norm, f_per))

                if getattr(arch, 'se_ratio', 0) > 0:
                    layers.append(SE1d(f_per, arch.se_ratio))

                in_ch = f_per
                # The skip path maps the branch input (hidden channels) to the
                # branch output width; a 1x1 conv is needed when they differ.
                if use_res and skip is None:
                    skip = nn.Identity() if hidden == f_per else nn.Conv1d(hidden, f_per, 1)

            branch_core = nn.Sequential(*layers)
            if use_res:
                self.branches.append(ResidualBranch(branch_core, skip if skip is not None else nn.Identity()))
            else:
                self.branches.append(branch_core)

        # Use the number of branches actually built so the projection always
        # matches the concatenated width, even if len(kernels) != branches.
        concat_ch = f_per * len(self.branches)
        self.proj = None
        if concat_ch != f_total:
            self.proj = nn.Conv1d(concat_ch, f_total, kernel_size=1)

        self.dropout = nn.Dropout(arch.dropout)
        self.pooling = arch.pooling
        if self.pooling == 'attn':
            self.attn = nn.Linear(f_total, 1)
        elif self.pooling == 'gem':
            self.gem = GeM(p=arch.gem_p)

        self.out = nn.Linear(f_total, num_classes)

    def forward(self, last_hidden: torch.Tensor, attn_mask: torch.Tensor):
        """Compute class logits.

        Args:
            last_hidden: backbone output, [B, T, H].
            attn_mask: [B, T] mask, 0 at padded positions.

        Returns:
            Logits of shape [B, num_classes].
        """
        x = last_hidden.transpose(1, 2).contiguous()  # [B, H, T]

        # The head always runs in float32, even inside an outer autocast region.
        with torch.amp.autocast('cuda', enabled=False):
            x = x.float()

            feats_list = [branch(x) for branch in self.branches]   # each [B, f_per, T]
            feats = torch.cat(feats_list, dim=1)                   # [B, concat_ch, T]
            if self.proj is not None:
                feats = self.proj(feats)                           # [B, f_total, T]

            # NOTE(review): 'max', 'avg' and 'gem' pool over every position,
            # including padding; only 'kmax' and 'attn' apply attn_mask.
            if self.pooling == 'max':
                x_out = torch.amax(feats, dim=2)
            elif self.pooling == 'avg':
                x_out = torch.mean(feats, dim=2)
            elif self.pooling == 'gem':
                x_out = self.gem(feats)
            elif self.pooling == 'kmax':
                pad_mask = (attn_mask == 0)[:, None, :]            # [B, 1, T]
                feats_m = feats.masked_fill(pad_mask, torch.finfo(feats.dtype).min)
                k = max(1, int(getattr(self.arch, 'kmax_k', 1)))
                k = min(k, feats_m.size(2))                        # topk needs k <= T
                vals, _ = torch.topk(feats_m, k, dim=2)            # [B, C, k], descending
                # Sequences with fewer than k real tokens would otherwise
                # average in the finfo.min fill values; keep only the first
                # min(k, n_valid) ranks and divide by that count.
                n_valid = attn_mask.sum(dim=1).clamp(min=1, max=k)  # [B]
                rank = torch.arange(k, device=vals.device)[None, None, :]
                drop = rank >= n_valid[:, None, None]
                vals = vals.masked_fill(drop, 0.0)
                x_out = vals.sum(dim=2) / n_valid[:, None].to(vals.dtype)
            else:  # 'attn'
                feats_T = feats.transpose(1, 2)                    # [B, T, C]
                logits = self.attn(feats_T)                        # [B, T, 1]
                mask = (attn_mask == 0).unsqueeze(-1)              # [B, T, 1]
                logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
                w = torch.softmax(logits, dim=1)
                x_out = torch.sum(feats_T * w, dim=1)              # [B, C]

            x_out = self.dropout(x_out)
            return self.out(x_out)

class SentimentModel(nn.Module):
    """Backbone encoder plus a searchable CNN classification head."""

    def __init__(self, backbone: Backbone, num_classes: int, head_arch: HeadArch):
        super().__init__()
        self.backbone = backbone
        self.head = CNNHead(backbone.hidden, num_classes, head_arch)
        self.backbone_frozen = False

    def freeze_backbone(self, freeze: bool = True):
        """Enable or disable gradients for every backbone parameter."""
        trainable = not freeze
        for param in self.backbone.parameters():
            param.requires_grad = trainable
        self.backbone_frozen = freeze

    def forward(self, input_ids, attention_mask):
        """Return head logits; the backbone runs without autograd when frozen."""
        if not self.backbone_frozen:
            hidden_states = self.backbone(input_ids, attention_mask)
            return self.head(hidden_states, attention_mask)

        with torch.no_grad():
            hidden_states = self.backbone(input_ids, attention_mask)
        return self.head(hidden_states.detach(), attention_mask)


# NOTE(review): consumed by the evolutionary loop outside the visible code;
# presumably the crossover probability and the probability of mutating a child
# right after crossover — confirm against the search driver.
CROSSOVER_PRO   = 0.30
MUT_AFTER_CROSS = 0.40


# In [None]:
#warm start
import json as _json


# Warm-start settings. Each one keeps a value already defined in the notebook
# namespace (e.g. set in an earlier cell) and otherwise falls back to the
# default given here.
# NOTE(review): the consumer is outside the visible code; presumably the path
# of a winners JSON, the number of variants derived per winner, and whether the
# winner itself is kept in the population — confirm.
WARMSTART_WINNERS_PATH = globals().get('WARMSTART_WINNERS_PATH', None)
WARMSTART_VARIANTS = globals().get('WARMSTART_VARIANTS', 8)
WARMSTART_KEEP_WINNER = globals().get('WARMSTART_KEEP_WINNER', True)

def _arch_to_dict(a: "HeadArch"):
    """Convert a HeadArch into a JSON-serialisable dict.

    All architecture fields are written. Previously only the first six were
    saved, so a reloaded winner silently lost its activation, norm, dilations,
    pooling parameters, separable/groups/SE/residual settings. Tuples are
    stored as lists because JSON has no tuple type.
    """
    return {
        'layers': a.layers,
        'filters': a.filters,
        'branches': a.branches,
        'kernels': list(a.kernels),
        'dropout': a.dropout,
        'pooling': a.pooling,
        'act': a.act,
        'norm': a.norm,
        'dilations': list(a.dilations),
        'kmax_k': a.kmax_k,
        'gem_p': a.gem_p,
        'sep': a.sep,
        'groups': a.groups,
        'se_ratio': a.se_ratio,
        'residual': a.residual,
    }

def _arch_from_dict(d):
    """Rebuild a HeadArch from a dict produced by `_arch_to_dict`.

    The six core keys are required. Every other HeadArch field is optional:
    when present it is restored, when absent (e.g. a file written by an older
    version that stored only the core keys) the HeadArch default applies.
    """
    extras = {}
    for key in ('act', 'norm', 'kmax_k', 'gem_p', 'sep', 'groups', 'se_ratio', 'residual'):
        if key in d:
            extras[key] = d[key]
    # JSON stores tuples as lists; HeadArch is frozen/hashable, so convert back.
    if 'dilations' in d:
        extras['dilations'] = tuple(d['dilations'])
    return HeadArch(d['layers'], d['filters'], d['branches'], tuple(d['kernels']),
                    d['dropout'], d['pooling'], **extras)

def save_winners(best_per_task, path='winners_last.json'):
    """Write the best architecture of every task to `path` as indented JSON."""
    payload = {}
    for task_name, winner in best_per_task.items():
        payload[task_name] = _arch_to_dict(winner)
    with open(path, 'w', encoding='utf-8') as handle:
        _json.dump(payload, handle, ensure_ascii=False, indent=2)
    print(f"Saved winners to {path}")

def load_winners(path='winners_last.json'):
    """Read a winners JSON file and return {task name: HeadArch}."""
    with open(path, 'r', encoding='utf-8') as handle:
        raw = _json.load(handle)
    winners = {}
    for task_name, arch_dict in raw.items():
        winners[task_name] = _arch_from_dict(arch_dict)
    return winners


# In [None]:
import pandas as pd
from typing import List, Dict, Any, Tuple
from sklearn.model_selection import train_test_split

def load_task_dataframe(task: Dict[str, Any]) -> pd.DataFrame:
    """Load the CSV of one task and return a frame with text and integer labels.

    Args:
        task: task description; reads 'csv_path', 'text_col', 'label_col' and
            the optional 'format' and 'name' keys.

    Returns:
        For format 'sent140_raw': a frame with exactly the columns
        ["message", "label"] (label 1 for polarity 4, else 0).
        Otherwise: the CSV as read, with the label column normalised where
        one of the rules below applies.

    Raises:
        AssertionError: if the text or label column is missing after loading.
    """
    fmt = task.get('format', None)

    if fmt == 'sent140_raw':
        # Sentiment140 ships without a header row and in latin-1 encoding.
        df = pd.read_csv(
            task['csv_path'],
            encoding="latin-1",
            header=None,
            names=["polarity", "id", "date", "query", "user", "message"]
        )
        # Keep only the negative (0) / positive (4) rows and binarise.
        df = df[df["polarity"].isin([0, 4])].copy()
        df["label"] = (df["polarity"] == 4).astype(int)
        df = df[df["message"].notnull()].reset_index(drop=True)
        return df[["message", "label"]]

    df = pd.read_csv(task['csv_path'], encoding='utf-8', engine='python')
    lbl = task.get('label_col', 'label')

    if task.get('name') == 'twitter_research':
        # Derive the label from the raw 'sentiment' column when it is the only
        # one present: {1, -1, 0, 2} -> class ids {0, 1, 2, 3}.
        if 'sentiment' in df.columns and lbl not in df.columns:
            label_map = {1: 0, -1: 1, 0: 2, 2: 3}
            df[task['label_col']] = df['sentiment'].map(label_map).astype(int)

    # Textual sentiment labels -> class ids. This used to be gated on the task
    # name 'youtube3', which no configured task carries, so string labels in
    # other tasks were never converted. It now applies to any task whose label
    # column is textual and contains only the three known names; the legacy
    # 'youtube3' task is still always mapped (and fails loudly on unknown names).
    if lbl in df.columns:
        col = df[lbl]
        is_textual = pd.api.types.is_object_dtype(col) or pd.api.types.is_string_dtype(col)
        if is_textual:
            name_to_id = {'negative': 0, 'neutral': 1, 'positive': 2}
            lowered = col.astype(str).str.strip().str.lower()
            if task.get('name') == 'youtube3' or bool(lowered.isin(list(name_to_id)).all()):
                df[lbl] = lowered.map(name_to_id).astype(int)

    assert task['text_col'] in df.columns and task['label_col'] in df.columns, \
        f"Missing columns in {task['csv_path']}"

    return df


def make_lf_splits(df: pd.DataFrame, text_col: str, label_col: str,
                   subset_n: int, val_frac: float) -> Tuple[List[str], List[int], List[str], List[int]]:
    """Build the low-fidelity train/validation lists for one task.

    Optionally subsamples `subset_n` rows (seeded), then performs a seeded
    split stratified on the label column. Returns (Xtr, ytr, Xva, yva) as
    plain Python lists of str / int.
    """
    wants_subset = subset_n is not None and subset_n < len(df)
    pool = df.sample(subset_n, random_state=SEED) if wants_subset else df

    train_part, val_part = train_test_split(
        pool,
        test_size=val_frac,
        stratify=pool[label_col],
        random_state=SEED,
    )

    def _as_lists(frame):
        texts = frame[text_col].astype(str).tolist()
        targets = frame[label_col].astype(int).tolist()
        return texts, targets

    Xtr, ytr = _as_lists(train_part)
    Xva, yva = _as_lists(val_part)
    return Xtr, ytr, Xva, yva

def build_loader(texts, labels, batch=None, shuffle=False):
    """Wrap texts/labels in a TextDataset and return a DataLoader over it.

    `batch=None` selects the global default BATCH. Tokenizer, maximum length
    and worker count come from the module-level settings.
    """
    batch_size = BATCH if batch is None else batch
    dataset = TextDataset(texts, labels, TOKENIZER, MAX_LEN)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


# In [None]:
from dataclasses import dataclass
from typing import List
import numpy as np, time, math, hashlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler  
from sklearn.metrics import f1_score




def nas_reset_vram_peak():
    """Reset CUDA peak-memory statistics; a silent no-op without CUDA or on error."""
    try:
        if not torch.cuda.is_available():
            return
        torch.cuda.reset_peak_memory_stats()
    except Exception:
        # Best effort only: measurement must never break an evaluation run.
        pass

def _vram_peak_gb():
    """Return the peak reserved CUDA memory in GiB, or NaN when unavailable."""
    peak = float("nan")
    try:
        if torch.cuda.is_available():
            peak = float(torch.cuda.max_memory_reserved() / (1024**3))
    except Exception:
        # Best effort only: fall through and report NaN.
        pass
    return peak

# Record of one evaluated child architecture, as created by GenLogger.log_child.
@dataclass
class ChildEval:
    # Generation index and task name the child was evaluated in / on.
    gen: int
    task: str
    # Identifier of the child and the operator that produced it.
    child_id: str
    op: str
    # Architecture hash (see arch_hash) and human-readable head description.
    arch_hash: str
    head_desc: str
    # Parameter count in millions, as supplied by the caller.
    params_m: float
    # Batch size used for the evaluation (-1 when not supplied).
    batch_used: int
    # Peak reserved CUDA memory in GiB at logging time (NaN without CUDA).
    vram_peak_gb: float
    # Wall-clock training time in seconds.
    train_time_s: float
    # Validation metrics; NaN when not supplied.
    val_acc: float
    val_loss: float
    # Optional learning rates; None when not supplied.
    lr_head: float = None
    lr_backbone: float = None
    # Whether the child could be evaluated, and why not if it could not.
    feasible: bool = True
    infeasible_reason: str = ""
    # Optional numeric encoding of the architecture (stored as a float ndarray
    # by GenLogger.log_child); used for the novelty distance.
    arch_vec: object = None  

class GenLogger:
    """Collects per-child evaluation records and prints a per-generation summary.

    Args:
        topk: number of best children listed in each summary.
        objective_key: ChildEval attribute used for ranking children.
        maximize: True if a larger objective is better.
        stop_hint_k: streak length mentioned in the printed stop hint.

    State kept across generations: all children, the best single child so far
    (`best_so_far` = (score, child, gen)), the best per-architecture mean
    (`best_overall_by_hash` = (mean, hash, gen)), the set of architecture
    hashes seen, and the number of consecutive generations without improvement.
    """

    def __init__(self, topk=3, objective_key="val_acc", maximize=True, stop_hint_k=4):
        from collections import defaultdict
        self.children = []
        self.gen_to_idx = defaultdict(list)
        self.best_so_far = None
        self.best_overall_by_hash = None
        self.objective_key = objective_key
        self.maximize = maximize
        self.clock = time.time()
        self.no_improve_streak = 0
        self.archive_hashes = set()
        self.hash_to_vec = {}
        self.topk = int(topk)
        self.stop_hint_k = int(stop_hint_k)

    @staticmethod
    def _is_missing(value) -> bool:
        """True for None, NaN, or anything that cannot be read as a float."""
        try:
            return math.isnan(float(value))
        except (TypeError, ValueError):
            return True

    def _better(self, new, old) -> bool:
        """Compare two objective values according to `self.maximize`."""
        return new > old if self.maximize else new < old

    def log_child(self, **kw):
        """Record one evaluated child; keyword names mirror the ChildEval fields.

        `vram_peak_gb` is always measured here and cannot be supplied.
        """
        ce = ChildEval(
            gen=kw.get("gen"),
            task=str(kw.get("task")),
            child_id=str(kw.get("child_id")),
            op=str(kw.get("op")),
            arch_hash=str(kw.get("arch_hash")),
            head_desc=str(kw.get("head_desc") or ""),
            params_m=float(kw.get("params_m", 0.0)),
            batch_used=int(kw.get("batch_used", -1)),
            vram_peak_gb=_vram_peak_gb(),
            train_time_s=float(kw.get("train_time_s", float("nan"))),
            val_acc=float(kw.get("val_acc", float("nan"))),
            val_loss=float(kw.get("val_loss", float("nan"))),
            lr_head=(None if kw.get("lr_head") is None else float(kw.get("lr_head"))),
            lr_backbone=(None if kw.get("lr_backbone") is None else float(kw.get("lr_backbone"))),
            feasible=bool(kw.get("feasible", True)),
            infeasible_reason=str(kw.get("infeasible_reason", "")),
            arch_vec=(None if kw.get("arch_vec") is None else np.asarray(kw.get("arch_vec"), dtype=float)),
        )
        idx = len(self.children)
        self.children.append(ce)
        self.gen_to_idx[ce.gen].append(idx)
        # Keep the first vector seen for each architecture hash.
        if ce.arch_vec is not None and ce.arch_hash not in self.hash_to_vec:
            self.hash_to_vec[ce.arch_hash] = ce.arch_vec

    def end_gen_summary(self, gen:int, infeasible_counts=None, batch_usage=None, ooms_dict=None,
                        survivors_ops_counter=None, new_best_ops_counter=None):
        """Update the running bests for generation `gen` and print its summary.

        Children whose score is NaN are excluded from ranking, from the score
        statistics and from the per-architecture means, so a single failed
        evaluation can no longer become (and permanently block) the champion.
        The optional dict arguments are only printed.
        """
        from collections import defaultdict
        infeasible_counts = infeasible_counts or {}
        batch_usage = batch_usage or {}
        ooms_dict = ooms_dict or {}
        nan = float("nan")

        idxs = self.gen_to_idx.get(gen, [])
        this_gen = [self.children[i] for i in idxs]
        feas = [c for c in this_gen if c.feasible]
        infeas = [c for c in this_gen if not c.feasible]

        # --- score statistics over the valid (non-NaN) accuracies ---
        sc = np.array([c.val_acc for c in feas], dtype=float)
        sc = sc[~np.isnan(sc)]
        acc_mean = float(np.mean(sc)) if sc.size else nan
        acc_std  = float(np.std(sc))  if sc.size else nan
        p25, p50, p75 = (np.percentile(sc, [25, 50, 75]).tolist() if sc.size else [nan] * 3)

        # --- ranking by the configured objective, ignoring missing scores ---
        rankable = [c for c in feas if not self._is_missing(getattr(c, self.objective_key))]
        feas_sorted = sorted(rankable, key=lambda c: getattr(c, self.objective_key), reverse=self.maximize)
        top = feas_sorted[:min(self.topk, len(feas_sorted))]

        improved = False
        if top:
            s0 = getattr(top[0], self.objective_key)
            if self.best_so_far is None or self._better(s0, self.best_so_far[0]):
                self.best_so_far = (s0, top[0], gen)
                improved = True
        self.no_improve_streak = 0 if improved else (self.no_improve_streak + 1)

        # --- best architecture of this generation by mean accuracy per hash ---
        by_hash = defaultdict(list)
        for c in feas:
            if not self._is_missing(c.val_acc):
                by_hash[c.arch_hash].append(c.val_acc)
        best_hash, best_hash_mean = None, None
        for h, lst in by_hash.items():
            m = float(np.mean(lst))
            if best_hash_mean is None or self._better(m, best_hash_mean):
                best_hash, best_hash_mean = h, m
        if best_hash is not None:
            if self.best_overall_by_hash is None or self._better(best_hash_mean, self.best_overall_by_hash[0]):
                self.best_overall_by_hash = (best_hash_mean, best_hash, gen)

        # --- duplicates and novelty ---
        new_hashes = {c.arch_hash for c in this_gen}
        seen_before = sum(1 for h in new_hashes if h in self.archive_hashes)
        dup_rate = (seen_before / max(1, len(new_hashes))) * 100.0
        self.archive_hashes |= new_hashes

        # Novelty = median over this generation's architectures of the L1
        # distance to the nearest other architecture vector in the archive.
        novelty_med = nan
        if self.hash_to_vec and len(self.hash_to_vec) > 1:
            dist_list = []
            for h in new_hashes:
                v = self.hash_to_vec.get(h, None)
                if v is None:
                    continue
                neighbour_dists = [
                    np.sum(np.abs(v - vv))
                    for hh, vv in self.hash_to_vec.items()
                    if hh != h and vv.shape == v.shape
                ]
                if neighbour_dists:
                    dist_list.append(min(neighbour_dists))
            if dist_list:
                novelty_med = float(np.median(dist_list))

        # --- resource statistics ---
        vram_vals = [c.vram_peak_gb for c in feas if not math.isnan(c.vram_peak_gb)]
        t_vals    = [c.train_time_s  for c in feas if not math.isnan(c.train_time_s)]
        vmin = min(vram_vals) if vram_vals else nan
        vmed = float(np.median(vram_vals)) if vram_vals else nan
        vmax = max(vram_vals) if vram_vals else nan
        tmed = float(np.median(t_vals)) if t_vals else nan
        tp90 = float(np.percentile(t_vals, 90)) if t_vals else nan

        # --- operator / infeasibility counters ---
        op_counts = defaultdict(int)
        for c in feas:
            op_counts[c.op] += 1
        reason_counts = defaultdict(int)
        for c in infeas:
            reason_counts[c.infeasible_reason] += 1
        top_reasons = sorted(reason_counts.items(), key=lambda kv: kv[1], reverse=True)[:3]

        # Time since the previous summary (or since construction).
        elapsed = time.time() - self.clock
        self.clock = time.time()

        # --- report ---
        print(f"[Gen {gen}] this_gen={len(idxs)}  feasible={len(feas)}  infeasible={len(infeas)}  "
              f"dup_rate={dup_rate:.1f}%  novelty_med(L1)={novelty_med if not math.isnan(novelty_med) else 'N/A'}  "
              f"no_improve_streak={self.no_improve_streak}  elapsed={elapsed:.1f}s")

        accept_rate = (len(feas) / max(1, len(idxs))) * 100.0
        def _fmt(x): return f"{x:.4f}" if not math.isnan(x) else "nan"
        print(f"  Scores: accept={accept_rate:.1f}%  mean={_fmt(acc_mean)} ± {_fmt(acc_std)} | p25/50/75 = {_fmt(p25)}/{_fmt(p50)}/{_fmt(p75)}")

        if top:
            # Label the list with the objective actually used for ranking.
            print(f"  Top-{self.topk} (by {self.objective_key})")
            for r, ce in enumerate(top, 1):
                print(f"    #{r} {ce.val_acc:.4f}  task={ce.task}  op={ce.op}  head={ce.head_desc}  "
                      f"params={ce.params_m:.2f}M  batch={ce.batch_used}  vram={ce.vram_peak_gb:.2f}GB  "
                      f"time={ce.train_time_s:.1f}s  id={ce.child_id[:8]}  hash={ce.arch_hash[:8]}")

        if self.best_so_far:
            best_s, best_ce, best_gen = self.best_so_far
            print(f"  Best-so-far child: {best_s:.4f} (task={best_ce.task}, head={best_ce.head_desc}, op={best_ce.op}, gen={best_gen}, id={best_ce.child_id[:8]})"
                  + ("  improved" if improved else ""))

        if best_hash is not None:
            tag = ""
            if self.best_overall_by_hash and self.best_overall_by_hash[1] == best_hash and self.best_overall_by_hash[2] == gen:
                tag = "  improved overall-by-arch"
            print(f"  Best arch (this gen, mean across tasks): {best_hash[:8]}  mean_acc={best_hash_mean:.4f}{tag}")

        if survivors_ops_counter:
            print("  Survivors by operator:", ", ".join(f"{k}={v}" for k,v in survivors_ops_counter.items()))
        if new_best_ops_counter:
            print("  New-best origins:", ", ".join(f"{k}={v}" for k,v in new_best_ops_counter.items()))

        if op_counts:
            print("  Ops (feasible evals):", ", ".join(f"{k}={v}" for k,v in op_counts.items()))
        if infeasible_counts:
            inf_str = ", ".join(f"{k}={v}" for k,v in infeasible_counts.items() if v)
            if inf_str: print("  Infeasible resamples:", inf_str)
        if top_reasons:
            print("  Top infeasible reasons:", ", ".join(f"{k}={v}" for k,v in top_reasons))

        if batch_usage: print("  BATCH_USAGE:", dict(sorted(batch_usage.items())))
        if ooms_dict:   print("  OOMs:", ooms_dict)

        print(f"  VRAM reserved (GB) min/med/max = "
              f"{('N/A' if math.isnan(vmin) else f'{vmin:.2f}')}/"
              f"{('N/A' if math.isnan(vmed) else f'{vmed:.2f}')}/"
              f"{('N/A' if math.isnan(vmax) else f'{vmax:.2f}')}  |  "
              f"time/child med/p90 = "
              f"{('N/A' if math.isnan(tmed) else f'{tmed:.1f}')}/"
              f"{('N/A' if math.isnan(tp90) else f'{tp90:.1f}')}s")

        print(f"  Archive: {len(self.archive_hashes)} unique architectures  |  "
              f"Champion age: {0 if self.best_so_far is None else (gen - self.best_so_far[2])} gens  "
              f"| Stop-hint when no_improve_streak ≥ {self.stop_hint_k}")


def head_summary(a) -> str:
    """Return a compact one-line description of a head architecture.

    Reads the attributes by name, so any object with the expected fields
    works; missing fields show as '?'. Falls back to `str(a)` if formatting
    fails for any reason.
    """
    try:
        raw_kernels = getattr(a, "kernels", (getattr(a, "kernel", 3),))
        kernel_sizes = tuple(int(k) for k in raw_kernels)
        kernel_text = ','.join(str(k) for k in kernel_sizes)
        pieces = [
            f"mkCNN[k={kernel_text}]",
            f"layers={getattr(a,'layers', '?')}",
            f"filters={getattr(a,'filters','?')}",
            f"branches={getattr(a,'branches','?')}",
            f"{getattr(a,'pooling','?')}",
            f"drop={getattr(a,'dropout','?')}",
        ]
        return '+'.join(pieces)
    except Exception:
        return str(a)

def arch_hash(a) -> str:
    """Return an MD5 hex digest identifying an architecture.

    Uses `a.as_tuple()` when available, otherwise a tuple of the six basic
    fields; if that fails the digest of `str(a)` is returned instead.
    """
    def _digest(text):
        return hashlib.md5(text.encode("utf-8")).hexdigest()

    try:
        if hasattr(a, "as_tuple"):
            signature = a.as_tuple()
        else:
            kernels = tuple(getattr(a, 'kernels', (getattr(a, 'kernel', 3),)))
            signature = (
                getattr(a, 'layers', None),
                getattr(a, 'filters', None),
                getattr(a, 'branches', None),
                kernels,
                getattr(a, 'dropout', None),
                getattr(a, 'pooling', None),
            )
        return _digest(str(signature))
    except Exception:
        return _digest(str(a))



# Validation metrics returned by quick_val.
@dataclass
class EvalResult:
    # Fraction of correctly classified samples.
    acc: float
    # Macro-averaged F1 score.
    f1: float
    # Average cross-entropy loss over the evaluated data.
    loss: float

@torch.no_grad()
def quick_val(model: nn.Module, loader: DataLoader) -> EvalResult:
    """Evaluate `model` on `loader` and return accuracy, macro-F1 and mean loss.

    The model is switched to eval mode and is not switched back. The loss is
    the exact per-sample mean: batch losses are summed and divided by the
    number of samples, so a smaller final batch no longer skews the average.
    An empty loader yields EvalResult(0.0, 0.0, 0.0).
    """
    model.eval()
    # Sum-reduction lets us divide by the true sample count at the end.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    total = 0
    correct = 0
    loss_sum = 0.0
    all_preds = []
    all_labels = []
    for batch in loader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)
        logits = model(input_ids, attention_mask)
        loss_sum += float(criterion(logits, labels).detach().cpu())
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        # Only the predicted classes are needed for F1; keep them on the CPU.
        all_preds.append(preds.detach().cpu())
        all_labels.append(labels.detach().cpu())
    if total == 0:
        return EvalResult(0.0, 0.0, 0.0)
    y_pred = torch.cat(all_preds, dim=0).numpy()
    y_true = torch.cat(all_labels, dim=0).numpy()
    acc = correct / total
    f1 = f1_score(y_true, y_pred, average='macro')
    return EvalResult(acc=acc, f1=f1, loss=loss_sum / total)

def evaluate_arch_on_task(arch: HeadArch, ts, epochs: int = LF_EPOCHS, freeze: bool = FREEZE_BACKBONE) -> float:
    """Train a fresh model with head ``arch`` on a task and score it.

    A new backbone + head is built for every call, trained for
    ``max(1, epochs)`` passes over the task's low-fidelity training split
    (``ts.Xtr`` / ``ts.ytr``) with AdamW at lr 2e-4 under optional mixed
    precision, then evaluated once on ``ts.Xva`` / ``ts.yva``.

    Args:
        arch: Head architecture to evaluate.
        ts: Task state providing ``Xtr``, ``ytr``, ``Xva``, ``yva`` and ``num_classes``.
        epochs: Number of training epochs (at least one is always run).
        freeze: Passed to ``model.freeze_backbone``.

    Returns:
        Validation accuracy as a Python float.
    """
    loader_tr = build_loader(ts.Xtr, ts.ytr, shuffle=True)
    loader_va = build_loader(ts.Xva, ts.yva, shuffle=False)

    backbone = Backbone(BACKBONE_NAME)
    net = SentimentModel(backbone, ts.num_classes, arch)
    net.freeze_backbone(freeze)
    net.to(DEVICE)

    nas_reset_vram_peak()

    # Only parameters left trainable after the optional freeze are optimised.
    trainable = [p for p in net.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=2e-4, weight_decay=0.01)
    amp_scaler = GradScaler(enabled=MIXED_PREC)
    loss_fn = nn.CrossEntropyLoss()

    n_epochs = max(1, epochs)
    for _epoch in range(n_epochs):
        net.train()
        for batch in loader_tr:
            ids = batch['input_ids'].to(DEVICE)
            mask = batch['attention_mask'].to(DEVICE)
            targets = batch['labels'].to(DEVICE)

            opt.zero_grad(set_to_none=True)
            with autocast('cuda', enabled=MIXED_PREC):
                out = net(ids, mask)
                loss = loss_fn(out, targets)
            amp_scaler.scale(loss).backward()
            amp_scaler.step(opt)
            amp_scaler.update()

    metrics = quick_val(net, loader_va)
    return float(metrics.acc)


def encode_arch(a: HeadArch) -> List[int]:
    """Encode ``a`` as six integer indices into the search-space choice lists.

    The order is layers, filters, branches, kernel, dropout, pooling. The
    kernel slot is lossy: it holds only the largest ``KERNEL_CHOICES`` index
    among ``a.kernels`` (0 when ``a.kernels`` is empty).
    """
    kernel_idx = max((KERNEL_CHOICES.index(k) for k in a.kernels), default=0)

    code = [
        choices.index(value)
        for choices, value in (
            (LAYERS_CHOICES, a.layers),
            (FILTER_CHOICES, a.filters),
            (BRANCH_CHOICES, a.branches),
        )
    ]
    code.append(kernel_idx)
    code.append(DROPOUT_CHOICES.index(a.dropout))
    code.append(POOLING_CHOICES.index(a.pooling))
    return code


def encode_arch_onehot(a: HeadArch) -> list:
    """Encode ``a`` as a flat 0/1 list over the search-space choice lists.

    Layers, filters, branches, dropout and pooling each contribute a one-hot
    segment; kernels contribute a multi-hot segment with a 1 for every entry
    of ``KERNEL_CHOICES`` present in ``a.kernels``. Segment order is layers,
    filters, branches, kernels, dropout, pooling.
    """
    def one_hot(options, chosen):
        hot = options.index(chosen)
        return [1 if pos == hot else 0 for pos in range(len(options))]

    # Membership in KERNEL_CHOICES is implied by iterating over it.
    kernel_bits = [1 if k in a.kernels else 0 for k in KERNEL_CHOICES]

    return (
        one_hot(LAYERS_CHOICES, a.layers)
        + one_hot(FILTER_CHOICES, a.filters)
        + one_hot(BRANCH_CHOICES, a.branches)
        + kernel_bits
        + one_hot(DROPOUT_CHOICES, a.dropout)
        + one_hot(POOLING_CHOICES, a.pooling)
    )

def pop_diversity(pop: List[HeadArch]) -> float:
    """Return a normalised mean pairwise L1 distance between one-hot encodings.

    The summed Manhattan distance over all unordered pairs is divided by the
    number of pairs times the number of active bits in the first member's
    encoding, then scaled by 0.25. Populations with fewer than two members
    (or a zero normaliser) yield 0.0.

    Note: a later notebook cell defines another ``pop_diversity``; once that
    cell has run, it replaces this definition.
    """
    if len(pop) < 2:
        return 0.0

    codes = [encode_arch_onehot(member) for member in pop]
    count = len(codes)

    total = 0.0
    for i, left in enumerate(codes):
        for right in codes[i + 1:]:
            total += sum(abs(x - y) for x, y in zip(left, right))

    pair_count = count * (count - 1) / 2
    denom = pair_count * sum(codes[0])
    if denom == 0:
        return 0.0
    return 0.25 * (total / denom)

def _smoke_test():
    """Run one forward pass on a few rows of the first task as a sanity check.

    Loads ``TASKS[0]``, samples up to 8 rows, builds a small loader and a
    fixed two-branch head, and checks that the model returns one row of
    logits per input example. Prints the logits shape on success.

    Raises:
        AssertionError: If the logits' batch dimension does not match the labels.
    """
    cfg = TASKS[0]
    df = load_task_dataframe(cfg)
    # Clamp the sample size to the data actually available; the previous
    # `min(8, 32)` was the constant 8 and made `sample` raise on tiny datasets.
    df = df.sample(min(8, len(df)), random_state=SEED)
    X = df[cfg['text_col']].tolist()
    y = df[cfg['label_col']].astype(int).tolist()
    loader = build_loader(X, y, batch=4, shuffle=False)

    arch = HeadArch(layers=1, filters=128, branches=2, kernels=(3,5), dropout=0.2, pooling='attn')
    model = SentimentModel(Backbone(BACKBONE_NAME), cfg['num_classes'], arch).to(DEVICE)
    model.freeze_backbone(True)

    batch = next(iter(loader))
    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=MIXED_PREC):
        logits = model(batch['input_ids'].to(DEVICE), batch['attention_mask'].to(DEVICE))
    assert logits.shape[0] == len(batch['labels'])
    print("Smoke test OK — logits:", tuple(logits.shape))


In [None]:
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
import random, math, time, gc
import numpy as np
import torch
import math as _math


def safe_eval(arch, ts, epochs, freeze):
    """Evaluate ``arch`` on ``ts``, stepping down the batch size on GPU errors.

    Each size in the global ``BATCH_SEARCH_CANDIDATES`` (or a built-in default
    ladder) is written to the global ``BATCH`` in turn, and
    ``evaluate_arch_on_task`` is called. (``BATCH`` is presumably read when the
    data loaders are built; confirm against ``build_loader``.) A
    ``RuntimeError`` whose message looks like a CUDA/cuBLAS/cuDNN failure moves
    on to the next size; any other exception propagates.

    Bookkeeping written to module globals:
        ``BATCH_USAGE[b]``: number of successful evaluations at batch size ``b``.
        ``BATCH_OOM_COUNT[b]``: number of GPU failures at batch size ``b``.
        ``LAST_SUCCESS_BATCH``: batch size of the most recent success.

    The original ``BATCH`` value (16 if it was unset) is restored on every exit
    path, including unexpected exception types.

    Returns:
        Whatever ``evaluate_arch_on_task`` returns (validation accuracy).

    Raises:
        RuntimeError: The last GPU error, if every candidate size failed.
    """
    g = globals()
    # Ladder of batch sizes to try during search (override elsewhere if you want)
    candidates = g.get("BATCH_SEARCH_CANDIDATES", [96, 80, 64, 48, 32, 24, 16, 12, 8])

    g.setdefault("BATCH_USAGE", {})
    g.setdefault("BATCH_OOM_COUNT", {})

    oom_markers = ("CUDA out of memory", "cudaError", "cublas", "cuDNN")
    orig = g.get("BATCH", 16)
    last_err = None

    try:
        for b in candidates:
            g["BATCH"] = b
            try:
                out = evaluate_arch_on_task(arch, ts, epochs=epochs, freeze=freeze)
            except RuntimeError as e:
                msg = str(e)
                if not any(marker in msg for marker in oom_markers):
                    raise
                g["BATCH_OOM_COUNT"][b] = g["BATCH_OOM_COUNT"].get(b, 0) + 1
                last_err = e
                # Best-effort release of cached blocks before the smaller retry.
                try:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except Exception:
                    pass
                continue
            g["BATCH_USAGE"][b] = g["BATCH_USAGE"].get(b, 0) + 1
            g["LAST_SUCCESS_BATCH"] = b
            return out
    finally:
        # Runs on success, on re-raise and on unexpected exception types, so
        # the global batch size never stays stuck at a trial value.
        g["BATCH"] = orig

    if last_err is not None:
        raise last_err

    # Only reached when the candidate ladder is empty: evaluate at the original size.
    return evaluate_arch_on_task(arch, ts, epochs=epochs, freeze=freeze)


def _onehot(a: 'HeadArch'):
    try:
        k = tuple(getattr(a, "kernels", [getattr(a, "kernel", 3)]))
        br = getattr(a, "branches", 1)
        ops = [("layers", getattr(a, "layers", 2)),
               ("filters", getattr(a, "filters", 64)),
               ("branches", br),
               ("kernels", k),
               ("dropout", getattr(a, "dropout", 0.2)),
               ("pooling", getattr(a, "pooling", "attn"))]
        s = "|".join(f"{k}={v}" for k, v in ops)
        rng = np.random.RandomState(abs(hash(s)) % (2**32))
        vec = rng.rand(64)
        idx = np.argpartition(vec, -8)[-8:]
        out = np.zeros_like(vec)
        out[idx] = 1
        return out
    except Exception:
        return np.zeros(64)

def pop_diversity(pop: List['HeadArch']) -> float:
    """Return a normalised mean pairwise L1 distance between ``_onehot`` codes.

    The L1 distances of all unordered pairs are summed, divided by the number
    of pairs, scaled by 0.25 and divided by the L1 norm of the first member's
    code (floored at 1.0). Populations with fewer than two members score 0.0.
    """
    if len(pop) < 2:
        return 0.0

    codes = [_onehot(member) for member in pop]
    count = len(codes)

    total = 0.0
    for i, left in enumerate(codes):
        for right in codes[i + 1:]:
            total += np.abs(left - right).sum()

    pair_count = (count * (count - 1)) / 2
    scale = max(np.abs(codes[0]).sum(), 1.0)
    return 0.25 * (total / max(pair_count, 1)) / scale

def probi(unit, div_i, d_i):
    """Return the logistic sigmoid of ``unit * div_i + d_i``.

    In the search loop this is the probability of evolving a task from its own
    population rather than transferring from another task.

    The sigmoid is evaluated in a numerically stable form: for negative ``z``
    it uses ``exp(z) / (1 + exp(z))``, so a strongly negative ``z`` underflows
    to 0.0 instead of raising ``OverflowError`` from ``exp(-z)``.
    """
    z = unit * div_i + d_i
    if z >= 0:
        return 1.0 / (1.0 + _math.exp(-z))
    ez = _math.exp(z)
    return ez / (1.0 + ez)


@dataclass
class TaskState:
    """Per-task state carried through `search_multitask`.

    Besides the declared fields, `_ensure_task_runtime_fields` attaches the
    runtime attributes `nsf`, `ntf`, `d` and `selected_set` to instances.
    """
    # Task identifier (search_multitask fills it from cfg['name']).
    name: str
    # Number of output classes, passed to the model constructor.
    num_classes: int
    # Low-fidelity training texts and labels.
    Xtr: List[str]
    ytr: List[int]
    # Low-fidelity validation texts and labels.
    Xva: List[str]
    yva: List[int]
    # Current candidate head architectures for this task.
    population: List['HeadArch']
    # (architecture, accuracy) pairs recorded when a child beats the population's best score.
    archive_best: List[Tuple['HeadArch', float]]


# Optional pre-ranking cutoff used by search_multitask: when set to an int
# smaller than the population size, only the top-K members by zero_cost_score
# are trained and scored each generation. None scores every member.
PRERANK_TOPK = None

def zero_cost_score(arch: 'HeadArch') -> float:
    """Return a training-free heuristic score for ``arch`` (higher is better).

    The score starts at 1.0, adds 0.3 for two branches or 0.6 for three, adds
    0.2 per distinct kernel size, and subtracts 0.0005 times
    ``layers * filters * branches`` as a rough size penalty. Missing
    attributes default to branches=1, kernels=[kernel or 3], layers=2,
    filters=64.
    """
    n_branches = getattr(arch, "branches", 1)
    kernel_sizes = getattr(arch, "kernels", [getattr(arch, "kernel", 3)])
    depth = getattr(arch, "layers", 2)
    n_filters = getattr(arch, "filters", 64)

    if n_branches == 2:
        branch_bonus = 1
    elif n_branches == 3:
        branch_bonus = 2
    else:
        branch_bonus = 0

    distinct_kernels = len(set(kernel_sizes))
    size_proxy = depth * n_filters * n_branches

    total = 1.0 + 0.3 * branch_bonus + 0.2 * distinct_kernels - 0.0005 * size_proxy
    return float(total)

def _ensure_task_runtime_fields(ts):
    if not hasattr(ts, 'nsf'): ts.nsf = 0
    if not hasattr(ts, 'ntf'): ts.ntf = 0
    if not hasattr(ts, 'd'):   ts.d   = 0.0
    if not hasattr(ts, 'selected_set'): ts.selected_set = set()

def tournament_select(pop, scores, k):
    """Pick the best-scoring member out of a random tournament.

    Up to ``k`` distinct members of ``pop`` are drawn with ``random.sample``,
    and the one with the highest entry in ``scores`` wins; members without a
    score count as -1e9.

    Raises:
        RuntimeError: If ``pop`` is empty.
    """
    if not pop:
        raise RuntimeError("Empty population")

    draw = min(k, len(pop))
    contenders = random.sample(pop, draw)

    def fitness(member):
        return scores.get(member, -1e9)

    return max(contenders, key=fitness)

def search_multitask(tasks_cfg: List[Dict[str, Any]]):
    """Run the multi-task evolutionary head search and return the best head per task.

    Phases:
      1. Warm start: if the global ``WARMSTART_WINNERS_PATH`` points to an
         existing file, previous winners are loaded (via the global
         ``load_winners`` when present, else from JSON) and seeded into the
         matching task's population.
      2. Init: for each task config, load the data, build the low-fidelity
         train/validation split and fill the population with ``POP_SIZE``
         distinct architectures.
      3. Evolution: for ``GENERATIONS`` generations and for each task, score the
         population, then either evolve from the task's own population
         (tournament selection + mutation) or transfer from another task's
         archive (pick the donor that scores best on this task, generate mutated
         children, keep the top few). Children beating the population's best
         score go into ``archive_best``; the chosen child replaces the worst
         population member when it scores higher.
      4. Result: for each task, return the archived architecture with the highest
         accuracy, or the best-scoring population member if nothing was archived.

    Args:
        tasks_cfg: Task config dicts; the keys read here are ``name``,
            ``text_col``, ``label_col``, ``num_classes`` and optional ``lf_subset``.

    Returns:
        Dict mapping task name to its best ``HeadArch``.

    Side effects:
        Creates the global ``GENLOG`` if missing, resets the globals
        ``BATCH_USAGE`` / ``BATCH_OOM_COUNT`` every generation, and prints progress.
    """
    global GENLOG
    if 'GENLOG' not in globals():
        GENLOG = GenLogger(topk=4, objective_key="val_acc", maximize=True, stop_hint_k=4)

    import copy as _copy  # imported once here instead of on every task iteration

    # Optional hook reporting how many infeasible samples were skipped so far.
    _snap = globals().get("snapshot_skip_stats", None)

    def _snapshot():
        return _snap() if callable(_snap) else {}

    def _log_child(gen, task_name, arch, op, elapsed, acc):
        # Logging is best-effort: a failure is reported but never aborts the search.
        try:
            GENLOG.log_child(
                gen=gen, task=task_name,
                child_id=arch_hash(arch),
                op=op,
                arch_hash=arch_hash(arch),
                head_desc=head_summary(arch),
                params_m=0.0,  # keep LF light
                batch_used=int(globals().get("LAST_SUCCESS_BATCH", -1)),
                train_time_s=elapsed,
                val_acc=float(acc),
                val_loss=float("nan"),
                lr_head=None, lr_backbone=None,
                feasible=True,
                arch_vec=np.asarray(encode_arch_onehot(arch), dtype=float)
            )
        except Exception as e:
            print(f"[GenLog] log_child skipped: {e}")

    # ---- Warm start -------------------------------------------------------
    warm_winners: Dict[str, 'HeadArch'] = {}
    try:
        import os
        import json
        _load_fn = globals().get("load_winners", None)
        ws_path = globals().get("WARMSTART_WINNERS_PATH", None)
        if ws_path and os.path.exists(ws_path):
            if _load_fn is not None:
                warm_winners = _load_fn(ws_path)
            else:
                with open(ws_path, "r", encoding="utf-8") as f:
                    raw = json.load(f)
                warm_winners = {
                    k: HeadArch(v["layers"], v["filters"], v["branches"],
                                tuple(v["kernels"]), v["dropout"], v["pooling"])
                    for k, v in raw.items()
                }
            print(f"[WarmStart] Loaded winners: {list(warm_winners.keys())}")
    except Exception as e:
        print(f"[WarmStart] Load skipped: {e}")

    # ---- Per-task initialisation -------------------------------------------
    task_states: List[TaskState] = []
    prep_t0 = time.time()
    print(f"[Init] config: tasks={len(tasks_cfg)}, POP_SIZE={globals().get('POP_SIZE','?')}, "
          f"LF_EPOCHS={globals().get('LF_EPOCHS','?')}, subset={globals().get('LF_SUBSET_PER_TASK','?')}, "
          f"freeze={globals().get('FREEZE_BACKBONE','?')}, batch={globals().get('BATCH','?')}", flush=True)

    s0 = _snapshot()

    for cfg in tasks_cfg:
        tprep = time.time()
        print(f"[Init] {cfg['name']}: loading & building LF splits...", flush=True)
        df = load_task_dataframe(cfg)
        Xtr, ytr, Xva, yva = make_lf_splits(
            df, cfg['text_col'], cfg['label_col'], cfg.get('lf_subset', LF_SUBSET_PER_TASK), LF_VAL_FRAC
        )
        print(f"[Init] {cfg['name']}: splits ready | Xtr={len(Xtr)} Xva={len(Xva)} (took {time.time()-tprep:.1f}s)", flush=True)

        pop: List['HeadArch'] = []
        used = set()
        winner = warm_winners.get(cfg['name'])
        if winner is not None:
            # Seed the population with last run's winner for this task.
            pop.append(winner)
            used.add(winner.as_tuple())

        # Fill up with distinct random architectures (deduplicated by as_tuple()).
        while len(pop) < POP_SIZE:
            a = HeadArch.sample() if hasattr(HeadArch, "sample") else (HeadArch.random() if hasattr(HeadArch, "random") else random_arch())
            tup = a.as_tuple()
            if tup in used:
                continue
            pop.append(a)
            used.add(tup)

        ts = TaskState(
            name=cfg['name'], num_classes=cfg['num_classes'],
            Xtr=Xtr, ytr=ytr, Xva=Xva, yva=yva,
            population=pop,
            archive_best=[],
        )
        _ensure_task_runtime_fields(ts)
        task_states.append(ts)
        print(f"[Init] {cfg['name']}: pop seeded (warm={'yes' if winner is not None else 'no'}, size={len(pop)})", flush=True)

    print(f"[Init] prep-phase total {time.time()-prep_t0:.1f}s")

    s1 = _snapshot()
    if s1:
        dr = s1.get('random_skips', 0) - s0.get('random_skips', 0)
        dm = s1.get('mutate_skips', 0) - s0.get('mutate_skips', 0)
        print(f"[Init] infeasible re-samples during seeding: random={dr}, mutate={dm} "
              f"(cumulative: random={s1.get('random_skips',0)}, mutate={s1.get('mutate_skips',0)})")

    # ---- Search hyper-parameters -------------------------------------------
    TOURN_SIZE = globals().get("TOURN_SIZE", 3)
    MUT_PRO = globals().get("MUT_PRO", 0.8)
    MUT_AFTER_TRANSFER = globals().get("MUT_AFTER_TRANSFER", 0.9)

    UNIT = globals().get("UNIT", 1.1)
    ALPHA = globals().get("ALPHA", 0.1)

    # Fixed for the search loop. These do not read the module-level LF_EPOCHS,
    # CHILDREN_PER_TRANSFER or TOPK_REUSE settings.
    lf_epochs = 1
    children_per_transfer = 12
    topk_reuse = 4

    # ---- Evolution -----------------------------------------------------------
    for gen in range(1, GENERATIONS + 1):
        gen_t0 = time.time()

        # Per-generation batch-size statistics filled in by safe_eval.
        globals()["BATCH_USAGE"] = {}
        globals()["BATCH_OOM_COUNT"] = {}

        counts: Dict[str, int] = {ts.name: 0 for ts in task_states}

        for ts in task_states:
            _ensure_task_runtime_fields(ts)
            ts.nsf = 0
            ts.ntf = 0

        prev_stats = _snapshot()

        for idx, ts in enumerate(task_states):
            # Score the population (or only the zero-cost top-K when PRERANK_TOPK is set).
            scores: Dict['HeadArch', float] = {}
            cand_pool = ts.population
            if PRERANK_TOPK is not None and PRERANK_TOPK < len(ts.population):
                # take top-K by zero-cost proxy
                ranked = sorted(ts.population, key=zero_cost_score, reverse=True)
                cand_pool = ranked[:PRERANK_TOPK]
            for a in cand_pool:
                if a not in scores:
                    scores[a] = safe_eval(a, ts, epochs=1, freeze=FREEZE_BACKBONE)

            # Update the self-vs-transfer bias d for this task.
            # NOTE(review): nsf/ntf were zeroed for every task at the top of this
            # generation, so ts.nsf is always 0 here and only tasks processed
            # earlier in this generation contribute to the ntf sum. Confirm
            # whether counts from the previous generation were intended.
            sum_ntf_other_prev = sum(getattr(other, "ntf", 0) for j, other in enumerate(task_states) if j != idx)
            ts.d = ALPHA * ts.d + (getattr(ts, "nsf", 0) - sum_ntf_other_prev)
            ts.nsf = 0
            ts.ntf = 0

            # Diversity & operator
            div_i = pop_diversity(ts.population)
            p = probi(UNIT, div_i, ts.d)
            do_self = (random.random() < p)

            # Candidate parents
            candidates: List[Tuple[str, Any]] = []
            if do_self:
                parent = tournament_select(ts.population, scores, TOURN_SIZE)
                candidates.append(("self", parent))
            else:
                # Algorithm 2 (Low-Fidelity Knowledge Extraction)
                donor_pool = []
                for j, other in enumerate(task_states):
                    if j == idx:
                        continue
                    for (arch_j, acc_j) in other.archive_best:
                        # Skip donors this task has already used.
                        if arch_j.as_tuple() in ts.selected_set:
                            continue
                        donor_pool.append((j, arch_j, acc_j))
                if not donor_pool:
                    # Nothing to transfer yet: fall back to self-evolution.
                    parent = tournament_select(ts.population, scores, TOURN_SIZE)
                    candidates.append(("self", parent))
                else:
                    # Pick the donor architecture that performs best on THIS task.
                    best_score = -1e9
                    best = None
                    for (j_src, arch_cand, _acc_src) in donor_pool:
                        try:
                            score_cand = safe_eval(arch_cand, ts, epochs=lf_epochs, freeze=FREEZE_BACKBONE)
                        finally:
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()
                        if score_cand > best_score:
                            best_score = score_cand
                            best = (j_src, arch_cand)
                    if best is None:
                        parent = tournament_select(ts.population, scores, TOURN_SIZE)
                        candidates.append(("self", parent))
                    else:
                        donor_task_idx, donor_arch = best
                        ts.selected_set.add(donor_arch.as_tuple())
                        candidates = [("transfer", (donor_task_idx, donor_arch))]

            extra_children: List[Tuple[float, 'HeadArch']] = []

            if candidates[0][0] == "self":
                parent = candidates[0][1]
                if random.random() < MUT_PRO:
                    child = parent.mutate()
                    _op = "mutate"
                else:
                    child = _copy.deepcopy(parent)
                    _op = "self"

                st = time.time()
                acc_child = safe_eval(child, ts, epochs=lf_epochs, freeze=FREEZE_BACKBONE)
                elapsed = time.time() - st
                if torch.cuda.is_available(): torch.cuda.empty_cache()

                _log_child(gen, ts.name, child, _op, elapsed, acc_child)

                counts[ts.name] += 1
                child_acc = acc_child
                kind = "self"

            else:
                donor_task_idx, donor_arch = candidates[0][1]
                # Algorithm 3
                tmp_children: List[Tuple[float, 'HeadArch']] = []
                for _ in range(children_per_transfer):
                    c = _copy.deepcopy(donor_arch)
                    if random.random() < MUT_AFTER_TRANSFER:
                        c = c.mutate()
                    st = time.time()
                    acc_c = safe_eval(c, ts, epochs=lf_epochs, freeze=FREEZE_BACKBONE)
                    elapsed = time.time() - st
                    tmp_children.append((acc_c, c))

                    _log_child(gen, ts.name, c, "xfer", elapsed, acc_c)

                    if torch.cuda.is_available(): torch.cuda.empty_cache()
                tmp_children.sort(key=lambda x: x[0], reverse=True)
                extra_children = tmp_children[:topk_reuse]

                _, child = extra_children[0]
                kind = "transfer"

                ts.selected_set.add(donor_arch.as_tuple())

                counts[ts.name] += len(tmp_children)

                # NOTE(review): the best child is trained and scored a second time
                # here although its accuracy is already in extra_children[0]. It
                # can then be archived (and counted in ntf) both below in the
                # extra_children loop and again in the Algorithm 1 step.
                child_acc = safe_eval(child, ts, epochs=lf_epochs, freeze=FREEZE_BACKBONE)

            # Unscored members (possible when PRERANK_TOPK limits scoring) count
            # as -1e9 instead of raising KeyError.
            if extra_children:
                best_i = max((scores.get(a, -1e9) for a in ts.population), default=-1e9)
                for acc_c, c_arch in extra_children:
                    if acc_c > best_i:
                        ts.archive_best.append((c_arch, acc_c))
                        ts.ntf += 1

            # Algorithm 1
            best_i = max((scores.get(a, -1e9) for a in ts.population), default=-1e9)
            if child_acc > best_i:
                ts.archive_best.append((child, child_acc))
                if kind == "self":
                    ts.nsf += 1
                else:
                    ts.ntf += 1

            # Steady-state replacement: the child evicts the worst member if it scores higher.
            if len(ts.population) > 0:
                worst = min(ts.population, key=lambda a: scores.get(a, -1e9))
                if child_acc > scores.get(worst, -1e9):
                    # replace
                    ts.population.remove(worst)
                    ts.population.append(child)

        # ---- Per-generation reporting ---------------------------------------
        gen_elapsed = time.time() - gen_t0
        kids_total = sum(counts.values())

        cur_stats = _snapshot()
        dr = dm = 0
        if cur_stats:
            dr = cur_stats.get('random_skips', 0) - (prev_stats.get('random_skips', 0) if prev_stats else 0)
            dm = cur_stats.get('mutate_skips', 0) - (prev_stats.get('mutate_skips', 0) if prev_stats else 0)
            print(f"[Gen {gen}] infeasible re-samples: random={dr}, mutate={dm} "
                  f"(cumulative: random={cur_stats.get('random_skips',0)}, mutate={cur_stats.get('mutate_skips',0)})")

        print(f"[Gen {gen}] elapsed {gen_elapsed:.1f}s  | children evals={kids_total}  | "
              f"BATCH_USAGE={globals().get('BATCH_USAGE',{})}  | OOMs={globals().get('BATCH_OOM_COUNT',{})}")

        try:
            GENLOG.end_gen_summary(
                gen=gen,
                infeasible_counts={'random': dr, 'mutate': dm},
                batch_usage=globals().get('BATCH_USAGE', {}),
                ooms_dict=globals().get('BATCH_OOM_COUNT', {})
            )
        except Exception as e:
            print(f"[GenLog] end_gen_summary skipped: {e}")

    # ---- Collect winners -------------------------------------------------------
    results = {}
    for ts in task_states:
        if ts.archive_best:
            best_arch, best_acc = max(ts.archive_best, key=lambda t: t[1])
        else:
            # Nothing archived: fall back to the best population member after one more scoring pass.
            scored = [(safe_eval(a, ts, epochs=1, freeze=FREEZE_BACKBONE), a) for a in ts.population]
            best_acc, best_arch = max(scored, key=lambda t: t[0])
        print(f"[Result] Task {ts.name}: best arch = {best_arch} | acc={best_acc:.4f}")
        results[ts.name] = best_arch
    return results


In [None]:
def full_train_task(cfg: Dict[str, Any], head_arch: HeadArch, epochs=FINAL_EPOCHS, lr=FINAL_LR):
    """Train ``head_arch`` on one task and checkpoint the best epoch.

    The task data is split 90/10 (stratified) into train and validation. When
    ``cfg['train_frac']`` lies strictly between 0 and 1, the training part is
    further subsampled (stratified) to that fraction. After every epoch the
    model is validated; whenever validation accuracy improves, the
    architecture, weights and backbone name are saved to
    ``best_<task name>.pt``.

    Args:
        cfg: Task config with ``name``, ``text_col``, ``label_col``,
            ``num_classes`` and optional ``train_frac``.
        head_arch: Head architecture to train.
        epochs: Number of training epochs.
        lr: AdamW learning rate.

    Returns:
        Best validation accuracy seen (0.0 if no epoch exceeded 0.0).
    """
    df = load_task_dataframe(cfg)

    # Hold out 10% for validation, keeping the class balance.
    train_df, val_df = train_test_split(
        df, test_size=0.1, stratify=df[cfg['label_col']], random_state=SEED
    )

    # Optional stratified subsampling of the training part.
    frac = cfg.get('train_frac', None)
    if frac is not None and 0 < frac < 1.0:
        splitter = StratifiedShuffleSplit(n_splits=1, train_size=frac, random_state=SEED)
        keep_rows, _ = next(splitter.split(train_df, train_df[cfg['label_col']]))
        original_n = len(train_df)
        train_df = train_df.iloc[keep_rows].reset_index(drop=True)
        print(f"[{cfg['name']}] Subsampled TRAIN to {len(train_df)} of {original_n} ({frac*100:.1f}%).")

    def _texts_and_labels(frame):
        texts = frame[cfg['text_col']].astype(str).tolist()
        targets = frame[cfg['label_col']].astype(int).tolist()
        return texts, targets

    Xtr, ytr = _texts_and_labels(train_df)
    Xva, yva = _texts_and_labels(val_df)

    train_loader = build_loader(Xtr, ytr, shuffle=True)
    val_loader = build_loader(Xva, yva)

    net = SentimentModel(Backbone(BACKBONE_NAME), cfg['num_classes'], head_arch)
    # The freeze setting is read from the global at call time.
    net.freeze_backbone(bool(FREEZE_BACKBONE))
    net.to(DEVICE)

    opt = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    # Halve the learning rate when validation accuracy stops improving.
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=1)
    amp_scaler = torch.cuda.amp.GradScaler(enabled=MIXED_PREC)
    loss_fn = nn.CrossEntropyLoss()

    best_acc = 0.0
    best_path = f"best_{cfg['name']}.pt"
    epoch = 1
    while epoch <= epochs:
        net.train()
        pbar = tqdm(train_loader, desc=f"Train {cfg['name']} ep{epoch}")
        for batch in pbar:
            ids = batch['input_ids'].to(DEVICE)
            mask = batch['attention_mask'].to(DEVICE)
            targets = batch['labels'].to(DEVICE)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=MIXED_PREC):
                out = net(ids, mask)
                loss = loss_fn(out, targets)
            amp_scaler.scale(loss).backward()
            amp_scaler.step(opt)
            amp_scaler.update()
            pbar.set_postfix({'loss': float(loss)})

        # quick_val already runs without gradient tracking.
        res = quick_val(net, val_loader)
        sched.step(res.acc)
        print(f"[Val] ep{epoch} acc={res.acc:.4f} f1={res.f1:.4f}")

        if res.acc > best_acc:
            best_acc = res.acc
            checkpoint = {
                'arch': asdict(head_arch),
                'state_dict': net.state_dict(),
                'backbone': BACKBONE_NAME,
            }
            torch.save(checkpoint, best_path)
            print(f"Saved: {best_path} (acc={best_acc:.4f})")

        epoch += 1
    return best_acc




In [None]:
# Driver cell: run the architecture search, then fully train the winning head per task.

# Batch sizes safe_eval tries, largest first, when an evaluation hits a GPU memory error.
BATCH_SEARCH_CANDIDATES = [32, 24, 16, 12, 8]


# Warm start: search_multitask loads previous winners from this path when the file exists.
WARMSTART_WINNERS_PATH = "winners_last.json"
# NOTE(review): the next two settings are not read by the code visible in this
# part of the file; presumably consumed elsewhere. Confirm.
WARMSTART_VARIANTS = 8
WARMSTART_KEEP_WINNER = True


# Keep the backbone frozen during the low-fidelity search.
FREEZE_BACKBONE = True

print("Starting search...")
best_arches = search_multitask(TASKS)

print("\nBest heads found per task:")
for name, arch in best_arches.items():
    print(f"{name}: {arch}")


# full_train_task reads this global at call time, so the final training
# below runs with an unfrozen backbone.
FREEZE_BACKBONE = False


# Final-training epochs per task name; tasks not listed here get 3 epochs.
PER_TASK_EPOCHS = {
    'twitter2': 4,
    'yt_taskA': 10,
    'yt_taskB': 25,
    'twitter_research': 25
}

print("\nStarting short full training per task...")
for cfg in TASKS:
    ep = PER_TASK_EPOCHS.get(cfg['name'], 3)
    # Train the task's winning head and report the best validation accuracy.
    acc = full_train_task(cfg, best_arches[cfg['name']], epochs=ep, lr=2e-5)
    print(f"{cfg['name']} final acc ({ep} ep): {acc:.4f}")
