In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import requests
from datetime import datetime
from tqdm import tqdm
import re
import warnings
warnings.filterwarnings("ignore")

# Run-wide settings shared by data loading, generation, and evaluation.
CONFIG = {
    "csv_file": "nas_cnn_kan_cifar100_results.csv",  # results table read by load_data() from main()
    "n_samples_to_load": 5000,  # row cap handed to load_data(); larger tables are subsampled
    "test_size": 0.2,  # forwarded to train_test_split in main()
    "random_state": 42,  # seed for DataFrame.sample and train_test_split
    "n_test_samples": 200,  # upper bound on rows scored by evaluate_model()
    "ollama_url": "http://localhost:11434",  # base URL for the /api/generate and /api/create requests
    "timeout": 1200,  # passed as the requests timeout in generate_raw()
    "temperature": 0.2,  # temperature/top_p/num_predict are forwarded as Ollama "options"
    "top_p": 0.9,
    "num_predict": 120,
    "max_tries_format": 3,  # generation attempts per query while waiting for a canonical string
    "lora_adapter_dir": "nas_mistral7b_lora",  # referenced by the ADAPTER line of the generated Modelfile
    "fine_tuned_model_name": "nas-mistral-lora",  # name used when creating and querying the model
}

# Maps upper-cased activation names to the spelling used in canonical architecture strings.
KAN_ACTS = {"RELU": "ReLU", "GELU": "GELU", "SILU": "SiLU"}

def load_data(csv_file, n_samples=None):
    """Read ``csv_file`` into a DataFrame, optionally subsampling it.

    When ``n_samples`` is given and smaller than the number of rows, a
    random sample of that size is drawn (seeded from
    ``CONFIG["random_state"]``) and re-indexed from zero. Otherwise the
    full table is returned as read.
    """
    frame = pd.read_csv(csv_file)
    if n_samples is None or n_samples >= len(frame):
        return frame
    sampled = frame.sample(n=n_samples, random_state=CONFIG["random_state"])
    return sampled.reset_index(drop=True)

def normalize_arch(text: str) -> str:
    """Rewrite a free-form architecture string into the canonical spelling.

    Steps, in order:
      1. collapse all whitespace runs to single spaces;
      2. turn the unicode arrow and a spaced dash into ``->`` and put exactly
         one space on each side of every arrow;
      3. drop a leading ``ARCHITECTURE:`` / ``Architecture:`` /
         ``architecture:`` label;
      4. rewrite ``Conv2dk<k>,p<p>`` and loosely spaced ``Conv2d(k=..,p=..)``
         as ``Conv2d(k=<k>,p=<p>)``;
      5. rewrite ``KAN order <o>, grid <g>, act <a>`` and loosely spaced
         ``KAN(order=..,grid=..,act=..)`` as ``KAN(order=<o>,grid=<g>,act=<a>)``,
         mapping the activation through ``KAN_ACTS``.

    Args:
        text: Any value; it is converted with ``str()`` first.

    Returns:
        The normalized string. It is not guaranteed to be canonical; check
        with ``is_canonical_arch``.
    """
    s = " ".join(str(text).strip().split())
    s = s.replace("→", "->")
    s = s.replace(" - ", " -> ")
    # One space on each side of every arrow, whatever spacing was there before.
    # (The previous pattern required a literal backslash after the arrow, so
    # it never matched and already-spaced arrows ended up double-spaced.)
    s = re.sub(r"\s*->\s*", " -> ", s)
    for prefix in ["ARCHITECTURE:", "Architecture:", "architecture:"]:
        if s.startswith(prefix):
            s = s[len(prefix):].strip()
    s = re.sub(r"Conv2dk(\d+)\s*,\s*p(\d+)", r"Conv2d(k=\1,p=\2)", s, flags=re.I)
    s = re.sub(r"Conv2d\s*\(\s*k\s*=\s*(\d+)\s*,\s*p\s*=\s*(\d+)\s*\)",
               r"Conv2d(k=\1,p=\2)", s, flags=re.I)

    def _loose_kan(m):
        # Parenthesis-free form: unknown activations fall back to Title case.
        act = m.group(3).upper()
        act = KAN_ACTS.get(act, act.title())
        return f"KAN(order={m.group(1)},grid={m.group(2)},act={act})"

    # Each match is rewritten from its own captured values.
    s = re.sub(
        r"KAN\s*order\s*(\d+)\s*,\s*grid\s*(\d+)\s*,\s*act\s*([A-Za-z0-9_]+)",
        _loose_kan,
        s, flags=re.I
    )

    def _paren_kan(m):
        # Parenthesised form: unknown activations are kept as written.
        act = KAN_ACTS.get(m.group(3).upper(), m.group(3))
        return f"KAN(order={m.group(1)},grid={m.group(2)},act={act})"

    # Case-insensitive, like the other patterns in this function.
    s = re.sub(
        r"KAN\s*\(\s*order\s*=\s*(\d+)\s*,\s*grid\s*=\s*(\d+)\s*,\s*act\s*=\s*([A-Za-z]+)\s*\)",
        _paren_kan,
        s, flags=re.I
    )
    return s.strip()

def is_canonical_arch(s: str) -> bool:
    """Return True only if ``s`` is exactly in canonical form.

    Canonical form is one or more ``Conv2d(k=<int>,p=<int>)`` stages joined
    by `` -> `` and terminated by a single
    ``KAN(order=<int>,grid=<int>,act=<ReLU|GELU|SiLU>)`` stage.
    """
    conv_chain = r"(?:Conv2d\(k=\d+,p=\d+\))(?: -> Conv2d\(k=\d+,p=\d+\))*"
    kan_stage = r"KAN\(order=\d+,grid=\d+,act=(?:ReLU|GELU|SiLU)\)"
    full_pattern = re.compile(" -> ".join((conv_chain, kan_stage)))
    return bool(full_pattern.fullmatch(s))

def serialize_constraints_from_row(row):
    """Render one results row as the multi-line constraint text for the prompt.

    Args:
        row: Mapping-like object (a dict or a pandas row) providing
            ``num_params``, ``num_flops`` and ``epoch_time_sec``. An optional
            ``acc_weight`` is read with ``row.get`` and defaults to 0.8.

    Returns:
        The instruction text, one statement per line, with no trailing
        newline.
    """
    acc_weight = float(row.get('acc_weight', 0.8))
    eff_weight = 1.0 - acc_weight
    lines = [
        "Goal: choose an architecture by trading off CIFAR-100 test_accuracy vs efficiency.",
        "Trade-off:",
        f"- acc_weight = {acc_weight:.3f} (higher favors accuracy)",
        f"- eff_weight = {eff_weight:.3f} (higher favors efficiency)",
        "Constraints (must satisfy all):",
        f"- num_params <= {int(row['num_params'])}",
        f"- num_flops <= {int(row['num_flops'])}",
        f"- epoch_time_sec <= {float(row['epoch_time_sec']):.3f}",
        "Output ONLY one architecture string in this exact format:",
        "Conv2d(k=<int>,p=<int>) -> ... -> KAN(order=<int>,grid=<int>,act=<ReLU|GELU|SiLU>)",
    ]
    # Join with real line breaks; the previous literals embedded a backslash
    # followed by "n" in the prompt instead.
    return "\n".join(lines)

def serialize_constraints_manual(num_params, num_flops, epoch_time_sec, acc_weight=0.8):
    """Render explicitly supplied budgets as the multi-line constraint text.

    Args:
        num_params: Parameter-count ceiling; rendered via ``int()``.
        num_flops: FLOP ceiling; rendered via ``int()``.
        epoch_time_sec: Epoch-time ceiling; rendered with three decimals.
        acc_weight: Accuracy weight; the efficiency weight is ``1 - acc_weight``.

    Returns:
        The instruction text, one statement per line, with no trailing
        newline.
    """
    acc_weight = float(acc_weight)
    eff_weight = 1.0 - acc_weight
    lines = [
        "Goal: choose an architecture by trading off CIFAR-100 test_accuracy vs efficiency.",
        "Trade-off:",
        f"- acc_weight = {acc_weight:.3f} (higher favors accuracy)",
        f"- eff_weight = {eff_weight:.3f} (higher favors efficiency)",
        "Constraints (must satisfy all):",
        f"- num_params <= {int(num_params)}",
        f"- num_flops <= {int(num_flops)}",
        f"- epoch_time_sec <= {float(epoch_time_sec):.3f}",
        "Output ONLY one architecture string in this exact format:",
        "Conv2d(k=<int>,p=<int>) -> ... -> KAN(order=<int>,grid=<int>,act=<ReLU|GELU|SiLU>)",
    ]
    # Join with real line breaks; the previous literals embedded a backslash
    # followed by "n" in the prompt instead.
    return "\n".join(lines)

def build_prompt(constraint_text):
    """Wrap the constraint text in the ``<s>[INST] ... [/INST]`` instruction markers."""
    template = "<s>[INST] {} [/INST]"
    return template.format(constraint_text)

def generate_raw(prompt, model_name, config):
    """Send one non-streaming generation request to Ollama and return its text.

    Args:
        prompt: Full prompt string to send.
        model_name: Ollama model to query.
        config: Settings dict; reads ``ollama_url``, ``timeout``,
            ``temperature``, ``top_p`` and ``num_predict``.

    Returns:
        The ``response`` field of the JSON reply, stripped of surrounding
        whitespace.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    r = requests.post(
        f"{config['ollama_url']}/api/generate",
        json={
            "model": model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": config["temperature"],
                "top_p": config["top_p"],
                "num_predict": config["num_predict"],
            },
        },
        timeout=config["timeout"],
    )
    # Fail with the HTTP status rather than an opaque KeyError / JSON decode
    # error when the reply is an error payload.
    r.raise_for_status()
    return r.json()["response"].strip()

def predict_architecture_canonical_from_text(constraint_text, model_name, config):
    """Query the model until it yields a canonical architecture string.

    Up to ``config["max_tries_format"]`` generations are attempted, stopping
    at the first whose normalized form passes ``is_canonical_arch``.

    Args:
        constraint_text: Constraint description to wrap into the prompt.
        model_name: Ollama model to query.
        config: Settings dict forwarded to ``generate_raw``.

    Returns:
        ``(normalized, raw)`` for the first canonical attempt, or for the
        last attempt when none was canonical. Both are empty strings when
        no attempt is made.
    """
    # The prompt does not change between attempts, so build it once.
    prompt = build_prompt(constraint_text)
    last_raw, last_norm = "", ""
    for _ in range(config["max_tries_format"]):
        last_raw = generate_raw(prompt, model_name, config)
        last_norm = normalize_arch(last_raw)
        if is_canonical_arch(last_norm):
            break
    # Return the raw text as the second element; the fallback path used to
    # return the normalized string twice, hiding what the model actually said.
    return last_norm, last_raw

def predict_architecture_canonical_from_row(row, model_name, config):
    """Serialize ``row`` into constraint text and run the text-based predictor on it.

    Returns whatever ``predict_architecture_canonical_from_text`` returns.
    """
    return predict_architecture_canonical_from_text(
        serialize_constraints_from_row(row),
        model_name,
        config,
    )

def token_jaccard(a, b):
    """Jaccard similarity between the token sets of two architecture strings.

    Each argument is stringified; arrows (``->``) and commas are treated as
    separators, and the remainder is split on whitespace. Returns 0.0 when
    both token sets are empty.
    """
    def _token_set(value):
        return set(str(value).replace("->", " ").replace(",", " ").split())

    left, right = _token_set(a), _token_set(b)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

def evaluate_model(model_name, test_df, config):
    """Score a model's predicted architectures against a sample of ``test_df``.

    Samples up to ``config["n_test_samples"]`` rows (seeded from
    ``config["random_state"]``), predicts an architecture for each, and
    compares it with the row's ``architecture`` value by exact match and by
    token Jaccard. Per-row results are written to
    ``nas_arch_predictions_<model>.csv`` (colons in the model name become
    underscores) and a summary is printed.

    Args:
        model_name: Ollama model to query.
        test_df: DataFrame holding an ``architecture`` column plus the
            columns read by ``serialize_constraints_from_row``.
        config: Settings dict.

    Returns:
        Dict with ``model``, ``exact_acc``, ``mean_jaccard`` and
        ``elapsed_sec``.
    """
    n = min(config["n_test_samples"], len(test_df))
    test_subset = test_df.sample(n=n, random_state=config["random_state"]).reset_index(drop=True)
    preds, raws, exacts, jacs = [], [], [], []
    start = datetime.now()
    # The context manager closes the bar even if a prediction raises.
    with tqdm(total=n, desc=f"Evaluating {model_name}", unit="sample", dynamic_ncols=True) as pbar:
        for _, row in test_subset.iterrows():
            pred_arch, raw = predict_architecture_canonical_from_row(row, model_name, config)
            true_arch = str(row["architecture"]).strip()
            preds.append(pred_arch)
            raws.append(raw)
            exacts.append(int(pred_arch == true_arch))
            jacs.append(token_jaccard(pred_arch, true_arch))
            pbar.set_postfix({
                "exact_avg": f"{np.mean(exacts):.3f}",
                "jac_avg": f"{np.mean(jacs):.3f}",
            })
            pbar.update(1)
    elapsed = (datetime.now() - start).total_seconds()
    exact_acc = float(np.mean(exacts))
    mean_jac = float(np.mean(jacs))
    # A real newline for the leading blank line (previously printed a literal backslash-n).
    print(f"\nExecution Time: {elapsed:.1f}s")
    print(f"Exact-match accuracy: {exact_acc:.4f}")
    print(f"Mean token-Jaccard:  {mean_jac:.4f}")
    out = test_subset.copy()
    out["predicted_architecture"] = preds
    out["raw_model_output"] = raws
    out["exact_match"] = exacts
    out["token_jaccard"] = jacs
    out_csv = f"nas_arch_predictions_{model_name.replace(':','_')}.csv"
    out.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv}")
    return {"model": model_name, "exact_acc": exact_acc, "mean_jaccard": mean_jac, "elapsed_sec": elapsed}

def demo_manual_constraints(model_name, config):
    """Run a few hand-written constraint sets through the model and print the results.

    For each constraint set the predicted architecture is printed; the raw
    model output is printed as well when the prediction is not canonical.

    Args:
        model_name: Ollama model to query.
        config: Settings dict forwarded to the predictor.
    """
    demo_constraints = [
        {"num_params": 120_000,   "num_flops": 3_000_000,   "epoch_time_sec": 30.0, "acc_weight": 0.8},
        {"num_params": 700_000,   "num_flops": 25_000_000,  "epoch_time_sec": 40.0, "acc_weight": 0.6},
        {"num_params": 2_000_000, "num_flops": 90_000_000,  "epoch_time_sec": 70.0, "acc_weight": 0.9},
        {"num_params": 5_000_000, "num_flops": 200_000_000, "epoch_time_sec": 200.0, "acc_weight": 0.5},
    ]
    # Real newlines below; the previous literals printed a backslash followed by "n".
    print("\n" + "=" * 80)
    print("DEMO: Manual constraint queries")
    print("=" * 80)
    pbar = tqdm(demo_constraints, desc="Manual queries", unit="query", dynamic_ncols=True)
    for c in pbar:
        constraint_text = serialize_constraints_manual(c["num_params"], c["num_flops"], c["epoch_time_sec"], c["acc_weight"])
        pred, raw = predict_architecture_canonical_from_text(constraint_text, model_name, config)
        ok = is_canonical_arch(pred)
        pbar.set_postfix({"canonical": ok})
        print("\nConstraints:", c)
        print("Predicted architecture:", pred)
        if not ok:
            print("Raw model output:", raw)

def setup_ollama_model(config):
    """Ask Ollama to create the fine-tuned model from a generated Modelfile.

    The Modelfile is ``FROM mistral:7b`` plus an ``ADAPTER`` line pointing at
    ``./<lora_adapter_dir>``. Creation is best-effort: a non-200 reply is
    reported on stdout but does not raise.

    Args:
        config: Settings dict; reads ``lora_adapter_dir``,
            ``fine_tuned_model_name`` and ``ollama_url``.

    Returns:
        ``config["fine_tuned_model_name"]``, whether or not creation succeeded.
    """
    modelfile_content = f"FROM mistral:7b\nADAPTER ./{config['lora_adapter_dir']}"

    # A real newline here; the previous literal printed a backslash followed by "n".
    print("\n" + "=" * 80)
    print("Setting up fine-tuned model in Ollama")
    print("=" * 80)
    print("Modelfile content:")
    print(modelfile_content)
    print()

    r = requests.post(
        f"{config['ollama_url']}/api/create",
        json={
            "name": config['fine_tuned_model_name'],
            "modelfile": modelfile_content,
            "stream": False
        },
        timeout=300
    )

    if r.status_code == 200:
        print(f"✓ Model '{config['fine_tuned_model_name']}' created successfully")
    else:
        print(f"Model creation response: {r.status_code}")
        # Show the server's explanation so a failed setup can be diagnosed.
        print(f"Response body: {r.text}")

    return config['fine_tuned_model_name']

def main():
    """Load the results table, register the model, run the demo, then evaluate.

    The table is split into train/test parts stratified on deciles of
    ``test_accuracy``; only the test part is used here.
    """
    df = load_data(CONFIG["csv_file"], n_samples=CONFIG["n_samples_to_load"])
    bins = pd.qcut(df["test_accuracy"], q=10, duplicates="drop")
    # The training part is not used in this script.
    _train_df, test_df = train_test_split(
        df,
        test_size=CONFIG["test_size"],
        random_state=CONFIG["random_state"],
        stratify=bins,
    )

    model_name = setup_ollama_model(CONFIG)

    demo_manual_constraints(model_name, CONFIG)

    results = []
    results.append(evaluate_model(model_name, test_df, CONFIG))

    # A real newline here; the previous literal printed a backslash followed by "n".
    print("\nFinal Results:")
    for r in results:
        print(f"{r['model']}: exact={r['exact_acc']:.4f}, jaccard={r['mean_jaccard']:.4f}, time={r['elapsed_sec']:.1f}s")

# Run the full pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


DEMO: Manual constraint queries


Manual queries:  25%|██▌       | 1/4 [01:59<05:59, 119.81s/query, canonical=1]


Constraints: {'num_params': 120000, 'num_flops': 3000000, 'epoch_time_sec': 30.0}
Predicted architecture: Conv2d(k=3,p=1) -> Conv2d(k=3,p=1) -> KAN(order=4,grid=5,act=ReLU)


Manual queries:  50%|█████     | 2/4 [02:52<02:41, 80.55s/query, canonical=1] 


Constraints: {'num_params': 700000, 'num_flops': 25000000, 'epoch_time_sec': 40.0}
Predicted architecture: Conv2d(k=3,p=1) -> Conv2d(k=5,p=1) -> Conv2d(k=7,p=1) -> KAN(order=4,grid=8,act=SiLU)


Manual queries:  75%|███████▌  | 3/4 [03:52<01:11, 71.09s/query, canonical=1]


Constraints: {'num_params': 2000000, 'num_flops': 90000000, 'epoch_time_sec': 70.0}
Predicted architecture: Conv2d(k=3,p=1) -> Conv2d(k=3,p=1) -> Conv2d(k=5,p=1) -> Conv2d(k=7,p=1) -> KAN(order=4,grid=8,act=SiLU)


Manual queries: 100%|██████████| 4/4 [04:44<00:00, 71.22s/query, canonical=1]



Constraints: {'num_params': 5000000, 'num_flops': 200000000, 'epoch_time_sec': 200.0}
Predicted architecture: Conv2d(k=3,p=1) -> Conv2d(k=5,p=1) -> Conv2d(k=7,p=1) -> KAN(order=4,grid=8,act=SiLU)


Evaluating mistral:7b: 100%|██████████| 200/200 [1:36:28<00:00, 28.94s/sample, exact_avg=0.000, jac_avg=0.251]


Execution Time: 5788.4s
Exact-match accuracy: 0.0000
Mean token-Jaccard:  0.2506
Saved: nas_arch_predictions_mistral_7b.csv

Final Results:
mistral:7b: exact=0.0000, jaccard=0.2506, time=5788.4s



