# Beam Search Module

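> This module implements neural-guided beam search: given a latent vector `z`, it uses a learned proposer to expand candidate programs and ranks them with a scoring function to find the program that best explains `z`.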

In [None]:
#| default_exp search.beam_search

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
#| hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mawm.core import Program, PRIMITIVE_TEMPLATES
from mawm.models.program_embedder import ProgramEmbedder
from mawm.models.program_encoder import ProgramEncoder
from mawm.models.program_synthizer import Proposer

# Small CPU-side models used to exercise the beam search interactively.
device = 'cpu'
grid_size= 7

num_primitives = len(PRIMITIVE_TEMPLATES)
# NOTE(review): param_cardinalities=[7, 7] presumably mirrors grid_size for two
# grid-coordinate parameters — confirm.
p_embed = ProgramEmbedder(
    num_primitives= num_primitives,
    param_cardinalities= [7, 7],
    max_params_per_primitive= 2,
    d_name= 32,
    d_param= 32,
).to(device)

# Positional args: primitive vocabulary size, per-parameter cardinalities, and 2
# (presumably the max parameters per primitive — confirm against ProgramEncoder).
p_encoder = ProgramEncoder(num_primitives, [grid_size, grid_size],2, seq_len=5)
# Proposer configured for 32-dim observations and length-5 programs.
proposer = Proposer(obs_dim= 32,
                    num_prims= num_primitives,
                    max_params= 2,
                    seq_len= 5,
                    prog_emb_dim_x= 32,
                    prog_emb_dim_y= 32,
                    prog_emb_dim_prims= 32)

In [None]:
SEQ_LEN = 5
PAD_PRIM = len(PRIMITIVE_TEMPLATES)  # index for padding primitive
PAD_PARAM = -1  # value for padding parameters
MAX_PARAMS = 2  # maximum number of parameters per primitive


def program_to_indices(program, *, seq_len=None, max_params=None):
    """Convert one program into padded primitive-index and parameter tensors.

    The program's tokens are truncated to ``seq_len`` positions and
    right-padded with ``PAD_PRIM``. Each token's parameter list is truncated
    to ``max_params`` values and right-padded with ``PAD_PARAM``.

    NOTE(review): ``PAD_PRIM`` is ``len(PRIMITIVE_TEMPLATES)``, the same value
    the beam search uses as its EOS index, so padded positions are
    indistinguishable from EOS — confirm this is intended.

    Args:
        program: object exposing ``tokens``, a sequence of
            ``(primitive_index, params)`` pairs.
        seq_len: number of token positions in the output; defaults to the
            module-level ``SEQ_LEN``.
        max_params: number of parameter slots per token; defaults to the
            module-level ``MAX_PARAMS``.

    Returns:
        ``(prims, params)`` with shapes ``(1, seq_len)`` and
        ``(1, seq_len, max_params)``. The dtype follows ``torch.tensor``
        inference, so float parameters yield a float tensor.
    """
    # Resolve defaults at call time (not in the signature) so the function
    # keeps reading the *current* module constants, exactly as before.
    if seq_len is None:
        seq_len = SEQ_LEN
    if max_params is None:
        max_params = MAX_PARAMS

    tokens = program.tokens[:seq_len]
    pad_len = seq_len - len(tokens)

    def _fit(values):
        # Copy, truncate and right-pad a single token's parameter list.
        values = list(values)[:max_params]
        return values + [PAD_PARAM] * (max_params - len(values))

    prims = [tok[0] for tok in tokens] + [PAD_PRIM] * pad_len
    params = [_fit(tok[1]) for tok in tokens] + [
        [PAD_PARAM] * max_params for _ in range(pad_len)
    ]

    return (
        torch.tensor(prims).unsqueeze(0),
        torch.tensor(params).unsqueeze(0),
    )

In [None]:
def batchify_programs(programs, seq_len=SEQ_LEN, max_params=MAX_PARAMS):
    """Stack a list of programs into batched, padded index tensors.

    Every program is converted with ``program_to_indices`` (called with the
    program only), its leading singleton batch dimension is dropped, and the
    rows are stacked in input order.

    NOTE(review): ``seq_len`` and ``max_params`` are accepted but not
    forwarded. Padding and truncation follow the module-level ``SEQ_LEN`` and
    ``MAX_PARAMS`` read by ``program_to_indices``.

    Returns:
        ``(prim_batch, param_batch)`` of shapes ``(B, SEQ_LEN)`` and
        ``(B, SEQ_LEN, MAX_PARAMS)``.
    """
    encoded = [program_to_indices(prog) for prog in programs]
    prim_batch = torch.stack([prim_ids[0] for prim_ids, _ in encoded], dim=0)
    param_batch = torch.stack([param_ids[0] for _, param_ids in encoded], dim=0)
    return prim_batch, param_batch


In [None]:
# Hand-build the initial [SOS] program, with the same construction that
# neural_guided_beam_search uses for its start state.
sos_idx = proposer.sos_idx
# NOTE(review): same value as PAD_PRIM (both are len(PRIMITIVE_TEMPLATES)).
EOS_IDX = len(PRIMITIVE_TEMPLATES)
zero_params = torch.full((MAX_PARAMS,), -1, device=device)  # -1 placeholder parameters

init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])

In [None]:
# Sanity check: batching two [SOS] programs gives (2, SEQ_LEN) primitive ids
# and (2, SEQ_LEN, MAX_PARAMS) parameters.
a, b = batchify_programs([init_prog, init_prog])
a.shape, b.shape

(torch.Size([2, 5]), torch.Size([2, 5, 2]))

In [None]:
# dtype of the parameter batch: int64 here, since the [SOS] program's only
# parameters are integer -1 placeholders.
b.dtype

torch.int64

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mawm.core import Program, PRIMITIVE_TEMPLATES
# from mawm.models.program_embedder import batchify_programs
from mawm.models.program_encoder import ProgramEncoder
from mawm.models.program_synthizer import Proposer

def _sample_param_instantiations(pred_params, arity, grid_size):
    """Draw three parameter lists for one primitive and return the distinct ones.

    Each draw either snaps the proposer's predicted parameters onto the grid
    (probability 0.7) or samples ``arity`` uniform grid coordinates. Draws use
    the global NumPy RNG. Duplicates are dropped so a single parent cannot
    fill the beam with identical children.
    """
    instantiations = []
    for _ in range(3):
        if np.random.rand() < 0.7:
            # NOTE(review): assumes predictions lie in [0, 1] so they scale
            # onto 0..grid_size-1 — confirm against the proposer's output.
            inst = [float(round(float(p) * (grid_size - 1))) for p in pred_params]
        else:
            inst = [float(np.random.randint(0, grid_size)) for _ in range(arity)]
        if inst not in instantiations:
            instantiations.append(inst)
    return instantiations


@torch.no_grad()
def neural_guided_beam_search(
    z,
    program_embedder: nn.Module,
    proposer: nn.Module,
    program_encoder: nn.Module,
    score_fn,
    MAX_PARAMS=3,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas= (0.25, 0.5, 0.25),
    device="cuda"
):
    """Neural-guided beam search for a program that scores well against ``z``.

    Starting from a single-token ``[SOS]`` program, each step:

    1. embeds every unfinished prefix with ``program_embedder`` and calls
       ``proposer.forward_step`` for next-primitive logits and parameter
       predictions;
    2. expands each prefix with its ``topk`` most likely primitives. The EOS
       primitive (index ``len(PRIMITIVE_TEMPLATES)``) finishes a program;
       other primitives get up to three distinct sampled parameter lists;
    3. encodes every candidate with ``program_encoder`` and scores it with
       ``score_fn(z, embeddings)``;
    4. ranks candidates by
       ``l1 * cumulative_logprob + l2 * score - l3 * len(tokens)`` and keeps
       the best ``beam_width`` entries, finished programs included.

    Args:
        z: latent vector; 1-D inputs are unsqueezed to ``(1, D)``.
        program_embedder: maps batched (primitive ids, params) to the
            proposer's program input.
        proposer: must expose ``sos_idx`` and ``forward_step(z, p_vec)``.
        program_encoder: maps batched (primitive ids, params) to embeddings
            consumed by ``score_fn``.
        score_fn: ``score_fn(z, embeddings)`` returning one scalar (float or
            single-element tensor) per candidate; higher is better.
        MAX_PARAMS: length of the placeholder parameter list on the SOS
            token; also passed as ``max_params`` to ``batchify_programs``.
        beam_width: number of programs kept after each step.
        topk: primitives considered per prefix (capped at the logit count).
        max_prog_len: number of expansion steps.
        grid_size: parameters are grid coordinates in ``[0, grid_size)``.
        lambdas: ``(l1, l2, l3)`` weights of the ranking formula above.
        device: device for ``z`` and for every model input.

    Returns:
        ``(best_program, best_score)``: the highest-ranked program in the
        final step's candidate pool (new expansions plus previously finished
        programs). If no expansion ever happens, returns the initial
        ``[SOS]`` program with score ``-inf``.
    """
    z = z.to(device)
    if z.dim() == 1:
        z = z.unsqueeze(0)
    lambda_1, lambda_2, lambda_3 = lambdas

    sos_idx = proposer.sos_idx
    EOS_IDX = len(PRIMITIVE_TEMPLATES)
    zero_params = torch.full((MAX_PARAMS,), -1, device=device)

    # initial program: [SOS]
    init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])
    beam = [(init_prog, 0.0)]  # (program, score)

    best_program = init_prog
    best_score = -float("inf")

    # NOTE(review): `batchify_programs` is defined in a notebook cell that is
    # not marked `#| export`, so the exported module does not define it —
    # export that cell or restore an import before using this module.
    # NOTE(review): programs can reach max_prog_len + 1 tokens (SOS included)
    # while batches are requested with seq_len=max_prog_len, so the newest
    # token may be truncated before encoding — confirm.
    for _ in range(max_prog_len):
        finished = [(p, s) for (p, s) in beam if p.tokens[-1][0] == EOS_IDX]
        alive = [(p, s) for (p, s) in beam if p.tokens[-1][0] != EOS_IDX]
        if not alive:
            break

        # ---- 1) propose next primitives for every unfinished prefix ----
        prev_idx_batch, prev_params_batch = batchify_programs(
            [p for (p, _) in alive], seq_len=max_prog_len, max_params=MAX_PARAMS
        )
        p_vec = program_embedder(
            prev_idx_batch.to(device=device, dtype=torch.long),
            prev_params_batch.to(device=device, dtype=torch.long),
        )
        z_batch = z.repeat(len(alive), 1)  # one copy of z per prefix
        prim_logits_batch, param_pred_batch = proposer.forward_step(z_batch, p_vec)

        # ---- 2) expand each beam entry ----
        expansions = []
        for b_i, (prog_parent, parent_score) in enumerate(alive):
            prim_logprobs = F.log_softmax(prim_logits_batch[b_i], dim=-1)
            # torch.topk raises if k exceeds the number of logits.
            k = min(topk, prim_logprobs.size(-1))
            top_vals, top_idx = torch.topk(prim_logprobs, k=k)

            for prim_logp, prim_idx in zip(top_vals.tolist(), top_idx.tolist()):
                child_score = parent_score + prim_logp
                if prim_idx == EOS_IDX:
                    # EOS finishes the program; it takes no parameters.
                    expansions.append((prog_parent.extend(prim_idx, []), child_score))
                    continue

                arity = PRIMITIVE_TEMPLATES[prim_idx][1]
                if arity > 0:
                    pred_params = param_pred_batch[b_i][:arity].cpu().numpy()
                    instantiations = _sample_param_instantiations(
                        pred_params, arity, grid_size
                    )
                else:
                    instantiations = [[]]

                for inst_params in instantiations:
                    expansions.append(
                        (prog_parent.extend(prim_idx, inst_params), child_score)
                    )

        if not expansions:
            break

        # ---- 3) score all expanded programs ----
        prim_ids_padded, params_padded = batchify_programs(
            [p for (p, _) in expansions], seq_len=max_prog_len, max_params=MAX_PARAMS
        )
        # Inputs go to `device`, matching how program_embedder is fed above.
        prog_emb_batch = program_encoder(
            prim_ids_padded.to(device=device, dtype=torch.long),
            params_padded.to(device=device, dtype=torch.long),
        )
        scores = score_fn(z, prog_emb_batch)

        # ---- 4) combine scores and prune ----
        pool = [
            (
                prog,
                lambda_1 * base_score
                + lambda_2 * float(scores[idx_c])
                - lambda_3 * len(prog.tokens),
            )
            for idx_c, (prog, base_score) in enumerate(expansions)
        ] + finished

        # Never report the bare [SOS] program as the answer.
        non_trivial = [
            (p, s) for (p, s) in pool
            if not (len(p.tokens) == 1 and p.tokens[0][0] == sos_idx)
        ]
        if not non_trivial:
            break

        best_program, best_score = max(non_trivial, key=lambda x: x[1])
        beam = sorted(pool, key=lambda x: x[1], reverse=True)[:beam_width]

    return best_program, best_score


In [None]:
# hide
from mawm.models.program_encoder import ProgramPredictor, ProgramEncoder

import torch
def loss_fn(z_hat, z, loss_exp= 1):
    """Return ``mean(|z_hat - z| ** loss_exp) / loss_exp`` as a 0-d tensor."""
    return torch.mean(torch.abs(z_hat - z) ** loss_exp) / loss_exp

def score_fn(z, m, predictor= None):
    """Score program embeddings by how closely their predictions match ``z``.

    Args:
        z: target latent; broadcast against each predicted row.
        m: batch of program embeddings, passed as a whole to ``predictor``.
        predictor: callable mapping ``m`` to predicted latents ``z_hat``.
            When ``None`` (the default), a new ``ProgramPredictor()`` is built
            on every call. NOTE(review): its weights are presumably freshly
            initialised, so pass a trained predictor for meaningful scores.

    Returns:
        list with one 0-d tensor per row of ``z_hat``: the negated
        ``loss_fn`` between that row and ``z`` (higher is better).
    """
    # Previously the `predictor` argument was always overwritten; honour it.
    if predictor is None:
        predictor = ProgramPredictor()
    with torch.no_grad():
        z_hat = predictor(m)

    return [-loss_fn(row.unsqueeze(0), z) for row in z_hat]


In [None]:
predictor = ProgramPredictor()
# Freeze the predictor: none of its parameters should receive gradients.
for weight in predictor.parameters():
    weight.requires_grad_(False)

In [None]:
#| hide
from mawm.models.program_embedder import ProgramEmbedder, PRIMITIVE_TEMPLATES
from mawm.models.program_encoder import ProgramEncoder
from mawm.core import *
from mawm.models.program_synthizer import Proposer
import torch
z = torch.randn(1, 32)  # random (1, 32) latent for trying the search
device = 'cpu'
grid_size= 7

num_primitives = len(PRIMITIVE_TEMPLATES)
# NOTE(review): unlike the earlier setup cell, p_embed is not moved with
# .to(device); harmless while device is 'cpu'.
p_embed = ProgramEmbedder(
    num_primitives= num_primitives,
    param_cardinalities= [7, 7],
    max_params_per_primitive= 2,
    d_name= 32,
    d_param= 32,
)

# Positional args: primitive vocabulary size, per-parameter cardinalities, and 2
# (presumably the max parameters per primitive — confirm against ProgramEncoder).
p_encoder = ProgramEncoder(num_primitives, [grid_size, grid_size],2, seq_len=5)
# Proposer configured for 32-dim observations and length-5 programs.
proposer = Proposer(obs_dim= 32,
                    num_prims= num_primitives,
                    max_params= 2,
                    seq_len= 5,
                    prog_emb_dim_x= 32,
                    prog_emb_dim_y= 32,
                    prog_emb_dim_prims= 32)



In [None]:
#| export
@torch.no_grad()
def neural_guided_beam_search_batched(
    z_batch,                      # shape [B, D]
    program_embedder,
    proposer,
    program_encoder,
    score_fn,
    MAX_PARAMS=3,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas=(0.25, 0.5, 0.25),
    device="cuda",
):
    """Run one independent beam search per row of ``z_batch``.

    Rows are searched sequentially (there is no cross-row batching). Every
    search receives the same models and settings via
    ``neural_guided_beam_search``. A 1-D ``z_batch`` is treated as a single row.

    Returns:
        list with one ``(best_program, best_score)`` tuple per row, in row order.
    """
    latents = z_batch.to(device)
    if latents.dim() == 1:
        latents = latents.unsqueeze(0)

    shared_settings = dict(
        MAX_PARAMS=MAX_PARAMS,
        beam_width=beam_width,
        topk=topk,
        max_prog_len=max_prog_len,
        grid_size=grid_size,
        lambdas=lambdas,
        device=device,
    )

    def _search_row(latent):
        # Unpack to insist on a (program, score) pair from the single search.
        best_prog, best_score = neural_guided_beam_search(
            latent,
            program_embedder,
            proposer,
            program_encoder,
            score_fn,
            **shared_settings,
        )
        return best_prog, best_score

    return [_search_row(latent) for latent in latents]


In [None]:
# lst_res = []
# z = torch.randn(32, 32).to(device)
# for i in range(10000):
#     res = neural_guided_beam_search_batched(z, p_embed, proposer, p_encoder, score_fn, beam_width=3, topk=4, max_prog_len= 5, lambdas= (0.7, 0.2, 0.01), device= device)
#     lst_res.append(res)

In [None]:
# lst_res[1]

[(OtherAgentDirection(-1, -1, -1) | GoalAt(6.0, 0.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0),
  -5.453608751927137),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(3.0, 4.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0),
  -5.565990809686897),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(6.0, 1.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(4.0, 0.0) | GoalAt(3.0, 3.0),
  -5.459200592587708),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(6.0, 1.0) | GoalAt(3.0, 0.0) | GoalAt(1.0, 5.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0),
  -5.462237890699624),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(5.0, 0.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0) | GoalAt(3.0, 3.0),
  -5.422450281898497),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(3.0, 2.0) | GoalAt(3.0, 3.0) | GoalAt(1.0, 3.0) | GoalAt(5.0, 1.0) | GoalAt(3.0, 3.0),
  -5.464751803342818),
 (OtherAgentDirection(-1, -1, -1) | GoalAt(0.0, 2.0) | GoalAt(3.0, 2.0) | GoalAt(0

In [None]:
# NOTE(review): the commented-out run above used a (32, 32) z for 10000
# iterations, i.e. 32 * 10000 beam searches; the extra factor of 32 looks like
# the latent dimension — confirm what counts as a "sample".
n_samples = 32 * 32 * 10000
n_samples

10240000

In [None]:
# NOTE(review): 208 / 60 presumably converts a wall-clock time in seconds to
# minutes — confirm the source of 208.
total_time = 208 / 60
total_time


3.466666666666667

In [None]:
# Per-sample time in the units of total_time (minutes, if 208 / 60 was a
# seconds-to-minutes conversion) — confirm units before quoting this figure.
time_per_sample = total_time / n_samples
time_per_sample

3.3854166666666667e-07

In [None]:
# [lst_res[0][i][0].__len__() for i in range(32)]

[11,
 3,
 10,
 11,
 11,
 2,
 11,
 11,
 11,
 11,
 2,
 6,
 11,
 11,
 3,
 4,
 11,
 4,
 5,
 11,
 11,
 3,
 2,
 8,
 11,
 5,
 3,
 11,
 11,
 11,
 5,
 2]

In [None]:
# a, b  = batchify_programs([lst_res[0][i][0] for i in range(32)])

In [None]:
# a.shape, b.shape

(torch.Size([32, 11]), torch.Size([32, 11, 2]))

In [None]:
# import numpy as np
# idx = np.random.choice(len(lst_res))
# print(idx)
# lst_res[idx][0]

32


OtherAgentDirection(-1, -1, -1) | CellEmpty(1.0, 0.0) | OtherAgentDirection(3.0,) | GoalAt(3.0, 3.0) | ItemAt(0.0, 1.0) | GoalAt(3.0, 3.0) | CellObstacle(3.0, 3.0) | CellGoal(3.0, 3.0) | GoalAt(3.0, 3.0) | CellAgent(2.0, 4.0) | OtherAgentNear()

In [None]:
# sum([len(lst_res[idx][0]) for idx in range(len(lst_res))]) / len(lst_res)

6.716

I think the formulation of the specification is weak and does not sound natural. Program synthesis is the task of constructing a program that provably satisfies a given high-level formal specification, so the program should be a way of adhering to that specification. For example, when applied to an input, it should produce the specified output.

In our case, you defined the specification through the output s(x, p), which is ambiguous and does not adhere to the definition of program synthesis. What other ways are there to cast the problem as a program synthesis problem?

My dataset is structured as (observation, action, reward, next_observation) tuples. Can we devise a program that satisfies some specification derived from such a scenario? Ideally, the program would be shared with another agent.

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()