# Beam Search Module

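> This module implements an enumerative, best-first search over programs: a proposer suggests primitives and parameters to extend each candidate, and a scoring function ranks the candidates against a target vector `z`.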

In [None]:
#| default_exp search.enumerative_search

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
#| hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mawm.core import Program, PRIMITIVE_TEMPLATES
from mawm.models.program_embedder import ProgramEmbedder
from mawm.models.program_encoder import ProgramEncoder
from mawm.models.program_synthizer import Proposer

# Shared configuration for the example objects built in this cell.
device = 'cpu'
grid_size = 7

# Number of primitive templates known to the project.
num_primitives = len(PRIMITIVE_TEMPLATES)

# Program embedder, moved to the configured device.
p_embed = ProgramEmbedder(
    num_primitives=num_primitives,
    param_cardinalities=[7, 7],
    max_params_per_primitive=2,
    d_name=32,
    d_param=32,
).to(device)

# Program encoder; built after the embedder so module construction order is unchanged.
p_encoder = ProgramEncoder(
    num_primitives,
    [grid_size, grid_size],
    2,
    seq_len=5,
)

# Proposer network.
proposer = Proposer(
    obs_dim=32,
    num_prims=num_primitives,
    max_params=2,
    seq_len=5,
    prog_emb_dim_x=32,
    prog_emb_dim_y=32,
    prog_emb_dim_prims=32,
)

{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [None]:
SEQ_LEN = 5  # number of token slots every encoded program is padded/truncated to
PAD_PRIM = len(PRIMITIVE_TEMPLATES)  # primitive index used for padding slots
PAD_PARAM = -1  # value used for padding parameter slots
MAX_PARAMS = 2  # number of parameter slots per primitive


def program_to_indices(program):
    """Encode one program as fixed-size index tensors.

    Only the first ``SEQ_LEN`` entries of ``program.tokens`` are used. Each
    token is read as ``(primitive_index, params)``; ``params`` is cut to
    ``MAX_PARAMS`` values and right-padded with ``PAD_PARAM``. Missing token
    slots are filled with ``PAD_PRIM`` and all-``PAD_PARAM`` rows.

    Returns:
        A pair ``(prims, params)`` of tensors with a leading batch axis of
        size 1: ``prims`` has shape ``(1, SEQ_LEN)`` and ``params`` has shape
        ``(1, SEQ_LEN, MAX_PARAMS)``.
    """
    kept = program.tokens[:SEQ_LEN]

    prim_row = []
    param_rows = []
    for token in kept:
        prim_row.append(token[0])
        args = list(token[1])[:MAX_PARAMS]
        # Right-pad short parameter lists (a negative count adds nothing).
        args.extend([PAD_PARAM] * (MAX_PARAMS - len(args)))
        param_rows.append(args)

    # Fill the remaining token slots with padding entries.
    missing = SEQ_LEN - len(kept)
    prim_row.extend([PAD_PRIM] * missing)
    param_rows.extend([PAD_PARAM] * MAX_PARAMS for _ in range(missing))

    prim_tensor = torch.tensor(prim_row).unsqueeze(0)
    param_tensor = torch.tensor(param_rows).unsqueeze(0)
    return prim_tensor, param_tensor

In [None]:
def batchify_programs(programs, seq_len=SEQ_LEN, max_params=MAX_PARAMS):
    """Encode a list of programs as a padded batch of index tensors.

    Each program's ``tokens`` are read as ``(primitive_index, params)`` pairs.
    Programs are truncated/padded to ``seq_len`` tokens (padding primitive is
    ``PAD_PRIM``) and every token's parameters are truncated/padded to
    ``max_params`` values (padding value is ``PAD_PARAM``).

    Unlike a plain loop over ``program_to_indices``, the ``seq_len`` and
    ``max_params`` arguments are honoured here rather than silently replaced
    by the module-level defaults. An empty ``programs`` list yields empty
    tensors instead of raising.

    Args:
        programs: iterable of objects exposing a ``tokens`` sequence.
        seq_len: number of token slots per program.
        max_params: number of parameter slots per token.

    Returns:
        ``(prim_batch, param_batch)`` as ``torch.long`` tensors of shape
        ``(B, seq_len)`` and ``(B, seq_len, max_params)``.
    """
    prim_rows = []
    param_rows = []

    for program in programs:
        tokens = program.tokens[:seq_len]

        prims = [tok[0] for tok in tokens]
        params = []
        for tok in tokens:
            values = list(tok[1])[:max_params]
            values += [PAD_PARAM] * (max_params - len(values))
            params.append(values)

        # Pad the program itself up to seq_len token slots.
        pad_len = seq_len - len(tokens)
        prims += [PAD_PRIM] * pad_len
        params += [[PAD_PARAM] * max_params for _ in range(pad_len)]

        prim_rows.append(prims)
        param_rows.append(params)

    batch = len(prim_rows)
    # The explicit reshape keeps the shapes well defined for empty batches
    # (and for seq_len == 0 or max_params == 0).
    prim_batch = torch.tensor(prim_rows, dtype=torch.long).reshape(batch, seq_len)
    param_batch = torch.tensor(param_rows, dtype=torch.long).reshape(batch, seq_len, max_params)

    return prim_batch, param_batch


In [None]:
# sos_idx = proposer.sos_idx
# EOS_IDX = len(PRIMITIVE_TEMPLATES)
# zero_params = torch.full((MAX_PARAMS,), -1, device=device)

# init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])

In [None]:
# a, b = batchify_programs([init_prog, init_prog])
# a.shape, b.shape

(torch.Size([2, 5]), torch.Size([2, 5, 2]))

In [None]:
# b.dtype

torch.int64

In [None]:
import heapq
from collections import defaultdict
import torch
import numpy as np

def program_to_key(prog):
    """Return a hashable key for ``prog``.

    ``prog.tokens`` is iterated as ``(primitive, params)`` pairs; the key is
    the flat tuple of every primitive followed by its parameters, each
    converted with ``int``.
    """
    return tuple(
        int(value)
        for prim, params in prog.tokens
        for value in (prim, *params)
    )

@torch.no_grad()
def enumerative_search(
    z, program_embedder, proposer, program_encoder, score_fn,
    max_prog_len=5, grid_size=7, device="cuda",
    init_top_params=3, top_prims=6, top_params=2,
    frontier_size=5000, eval_batch_size=512,
    lambda_prop=0.5, lambda_score=1.0, lambda_len=0.25, max_params=2,
    seed=None, max_expansions=20000
):
    """Best-first enumerative search for the program that scores highest on ``z``.

    The search starts from single-token programs (one or more per primitive),
    keeps a priority queue ordered by
    ``lambda_score * score - lambda_len * len(program.tokens)`` and repeatedly
    expands the highest-priority program by one token. Children are scored in
    batches with ``score_fn(z, program_encoder(...))``.

    Args:
        z: target tensor; a 1-D input gets a leading batch axis added.
        program_embedder: callable ``(prim_ids, param_ids) -> embedding`` used
            to embed the prefix handed to the proposer.
        proposer: object with ``forward_step(z, prefix_embedding)`` returning
            ``(prim_logits, param_logits)``. ``param_logits[0]`` is scaled by
            ``grid_size - 1`` and rounded, so it is presumably expected in
            ``[0, 1]`` -- TODO confirm against the proposer. May be ``None``,
            in which case every primitive is tried with no parameters.
        program_encoder: callable ``(prim_ids, param_ids) -> embedding`` whose
            output is passed to ``score_fn``.
        score_fn: callable ``(z, embeddings)`` returning one score per program.
        max_prog_len: programs with this many tokens are not expanded further.
        grid_size: parameter values are clamped to ``[0, grid_size - 1]``.
        device: device the tensors are moved to.
        init_top_params: parameter instantiations per primitive for the
            initial single-token programs.
        top_prims: number of highest-probability primitives tried per expansion.
        top_params: parameter instantiations per primitive per expansion.
        frontier_size: maximum queue length kept after each expansion.
        eval_batch_size: pending children are scored once this many accumulate
            (or earlier if the queue runs dry).
        lambda_prop: currently unused; kept for interface compatibility.
        lambda_score: weight of the score in the priority.
        lambda_len: per-token length penalty in the priority.
        max_params: parameter slots per token passed to ``batchify_programs``.
        seed: if given, seeds the global NumPy and torch RNGs.
        max_expansions: upper bound on the number of expanded programs.

    Returns:
        ``(best_prog, best_score)`` over every program that was scored;
        ``(None, -1e9)`` if nothing was scored.

    Raises:
        ValueError: if ``score_fn`` does not return one score per program.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    z = z.to(device)
    if z.dim() == 1:
        z = z.unsqueeze(0)

    num_prims = len(PRIMITIVE_TEMPLATES)
    max_coord = grid_size - 1

    score_cache = {}   # program key -> float score
    queued = set()     # keys already pushed on the heap (prevents re-expansion)
    heap = []          # entries: (-priority, insertion counter, Program)
    counter = 0
    best_prog = None
    best_score = -1e9

    def score_programs(programs):
        """Score every program that is not in ``score_cache`` yet."""
        pending = {}
        for prog in programs:
            key = program_to_key(prog)
            if key not in score_cache and key not in pending:
                pending[key] = prog
        if not pending:
            return
        prim_ids, params_padded = batchify_programs(
            list(pending.values()), seq_len=max_prog_len, max_params=max_params
        )
        prim_ids = prim_ids.to(torch.long).to(device)
        params_padded = params_padded.to(torch.long).to(device)
        emb = program_encoder(prim_ids, params_padded)
        # reshape(-1) rather than squeeze(): a one-program batch must stay
        # iterable instead of collapsing to a 0-d array.
        scores = score_fn(z, emb).detach().cpu().numpy().reshape(-1)
        if len(scores) != len(pending):
            raise ValueError(
                f"score_fn returned {len(scores)} scores for {len(pending)} programs"
            )
        for key, value in zip(pending, scores):
            score_cache[key] = float(value)

    def enqueue(programs):
        """Score ``programs``, update the best result and push new ones on the heap."""
        nonlocal counter, best_prog, best_score
        score_programs(programs)
        for prog in programs:
            key = program_to_key(prog)
            if key in queued:
                continue
            queued.add(key)
            score = score_cache[key]
            # Track the best at scoring time so programs that are never popped
            # (final flush, truncated frontier, exhausted budget) still count.
            if score > best_score:
                best_score = score
                best_prog = prog
            priority = lambda_score * score - lambda_len * len(prog.tokens)
            heapq.heappush(heap, (-priority, counter, prog))
            counter += 1

    def propose(prefix):
        """Run the proposer on ``prefix``; return (primitive logits, param predictions)."""
        pids, pparams = batchify_programs([prefix], seq_len=max_prog_len, max_params=max_params)
        pvec = program_embedder(pids.to(device), pparams.to(device))
        prim_logits, param_logits = proposer.forward_step(z, pvec)
        return prim_logits[0], param_logits[0].cpu().numpy()

    def instantiate(param_preds, arity, n_variants):
        """Quantise predictions to grid values and add random +-1 variations."""
        if arity == 0:
            return [[]]
        # The centre is clamped like its variations so it is always a valid grid value.
        center = [
            max(0, min(max_coord, int(round(float(x) * max_coord))))
            for x in param_preds[:arity]
        ]
        variants = [center]
        for _ in range(n_variants - 1):
            variants.append(
                [max(0, min(max_coord, c + int(np.random.randint(-1, 2)))) for c in center]
            )
        return variants

    # 1) Initial single-token programs.
    init_programs = []
    init_param_preds = None
    for prim_idx in range(num_prims):
        arity = PRIMITIVE_TEMPLATES[prim_idx][1]
        if arity == 0 or proposer is None:
            # Without a proposer there is nothing to suggest parameters, so the
            # primitive is added bare, mirroring the expansion step below.
            init_programs.append(Program(tokens=[(prim_idx, [])]))
            continue
        if init_param_preds is None:
            # The prefix does not depend on the primitive, so the proposer is
            # queried once and the prediction reused.
            # NOTE(review): index -1 is used as a start-of-sequence placeholder;
            # confirm the embedder accepts it.
            prefix = Program(tokens=[(-1, [-1] * max_params)])
            _, init_param_preds = propose(prefix)
        for inst in instantiate(init_param_preds, arity, init_top_params):
            init_programs.append(Program(tokens=[(prim_idx, inst)]))

    enqueue(init_programs)

    # 2) Best-first expansion.
    expansions_done = 0
    batch_buffer = []
    while expansions_done < max_expansions:
        if not heap:
            if not batch_buffer:
                break
            # The frontier ran dry while children were still waiting to be
            # scored: flush them instead of ending the search early.
            enqueue(batch_buffer)
            batch_buffer = []
            continue

        _, _, parent = heapq.heappop(heap)
        if len(parent.tokens) >= max_prog_len:
            continue

        if proposer is None:
            prims_to_try = list(range(num_prims))
            param_insts_per_prim = {p: [[]] for p in prims_to_try}
        else:
            prim_logits, param_preds = propose(parent)
            prim_logprobs = torch.log_softmax(prim_logits, dim=-1).cpu().numpy()
            top_prim_idx = np.argsort(prim_logprobs)[-top_prims:][::-1]
            # Drop any index that is not a real primitive (e.g. a padding/end slot).
            prims_to_try = [int(x) for x in top_prim_idx if int(x) < num_prims]
            param_insts_per_prim = {
                p: instantiate(param_preds, PRIMITIVE_TEMPLATES[p][1], top_params)
                for p in prims_to_try
            }

        for prim in prims_to_try:
            for inst in param_insts_per_prim[prim]:
                child = parent.extend(int(prim), inst)
                if program_to_key(child) not in queued:
                    batch_buffer.append(child)

        if len(batch_buffer) >= eval_batch_size:
            enqueue(batch_buffer)
            batch_buffer = []

        expansions_done += 1

        # Keep only the best `frontier_size` entries.
        if len(heap) > frontier_size:
            heap = heapq.nsmallest(frontier_size, heap)
            heapq.heapify(heap)

    # Score whatever is still pending so it can become the best result.
    if batch_buffer:
        enqueue(batch_buffer)
        batch_buffer = []

    return best_prog, best_score


In [None]:
# hide
from mawm.models.program_encoder import ProgramPredictor, ProgramEncoder

import torch
def loss_fn(z_hat, z, loss_exp= 1):
    """Return ``mean(|z_hat - z| ** loss_exp) / loss_exp`` as a scalar tensor."""
    return torch.mean(torch.abs(z_hat - z) ** loss_exp) / loss_exp

def score_fn(z, m, predictor= None):
    """Score each row of ``m`` by how well its prediction matches ``z``.

    Args:
        z: target tensor compared (with broadcasting) against each prediction.
        m: batch of inputs passed to ``predictor`` in one call.
        predictor: callable mapping ``m`` to predictions with one row per
            input. When ``None``, a new ``ProgramPredictor()`` is constructed
            for this call; pass a trained predictor to get meaningful scores.

    Returns:
        A 1-D tensor with one entry per row of the prediction, each equal to
        ``-loss_fn(prediction_row.unsqueeze(0), z)`` (higher is better).
    """
    # Only fall back to a fresh predictor when the caller did not supply one;
    # previously the argument was overwritten unconditionally.
    if predictor is None:
        predictor = ProgramPredictor()
    with torch.no_grad():
        z_hat = predictor(m)

    return torch.tensor([-loss_fn(row.unsqueeze(0), z) for row in z_hat])


In [None]:
# Module-level predictor with gradient tracking switched off on every parameter.
predictor = ProgramPredictor()
for param in predictor.parameters():
    param.requires_grad_(False)

In [None]:
# #| hide
# from mawm.models.program_embedder import ProgramEmbedder, PRIMITIVE_TEMPLATES
# from mawm.models.program_encoder import ProgramEncoder
# from mawm.core import *
# from mawm.models.program_synthizer import Proposer
# import torch
# z = torch.randn(1, 32)
# device = 'cpu'
# grid_size= 7

# num_primitives = len(PRIMITIVE_TEMPLATES)
# p_embed = ProgramEmbedder(
#     num_primitives= num_primitives,
#     param_cardinalities= [7, 7],
#     max_params_per_primitive= 2,
#     d_name= 32,
#     d_param= 32,
# )

# p_encoder = ProgramEncoder(num_primitives, [grid_size, grid_size],2, seq_len=5)
# proposer = Proposer(obs_dim= 32,
#                     num_prims= num_primitives,
#                     max_params= 2,
#                     seq_len= 5,
#                     prog_emb_dim_x= 32,
#                     prog_emb_dim_y= 32,
#                     prog_emb_dim_prims= 32)



In [None]:
# import torch
# z = torch.randn(32, 32).to(device)
# res = enumerative_search(
#     z[0], p_embed, proposer, p_encoder, score_fn,
#     max_prog_len=5, grid_size=7, device="cpu",
#     init_top_params=3, top_prims=6, top_params=2,
#     frontier_size=5000, eval_batch_size=512,
#     lambda_prop=0.7, lambda_score=0.2, lambda_len=0.01, max_params=2,
#     seed=None
# )

{(0, 3, 3): -0.5741857290267944, (0, 2, 3): -0.5741221904754639, (1, 3, 3): -0.5742213129997253, (1, 4, 2): -0.5742852687835693, (1, 3, 2): -0.5742353200912476, (2, 3, 3): -0.5742419958114624, (2, 2, 2): -0.5742534399032593, (2, 4, 3): -0.5743171572685242, (3, 3, 3): -0.5742272734642029, (3, 4, 3): -0.5743175745010376, (3, 4, 4): -0.5742443203926086, (4, 3, 3): -0.5742278099060059, (4, 2, 3): -0.5741777420043945, (5, 3, 3): -0.57429438829422, (5, 2, 3): -0.5742396116256714, (6, 3, 3): -0.574146568775177, (6, 4, 4): -0.574253261089325, (6, 2, 4): -0.5741258859634399, (7,): -0.5740537643432617, (8,): -0.574178159236908, (9, 3): -0.5741127133369446, (9, 4): -0.5741903781890869, (9, 2): -0.574112057685852, (10, 3, 3): -0.5742272734642029, (10, 4, 3): -0.5742970108985901, (10, 3, 2): -0.5742251873016357, (11,): -0.5742096900939941, (12, 3): -0.5741569995880127, (12, 2): -0.5740978717803955, (7, 0, 3, 3): -0.5876527428627014, (7, 0, 4, 4): -0.5876098871231079, (7, 9, 3): -0.5876118540763855,

In [None]:
# res

(Near(), -0.5740537643432617)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()