# Beam Search Module

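> This module implements neural-guided beam search for program synthesis: a proposer network suggests the next primitive (and its parameters) for each partial program, and candidates are ranked by a weighted combination of proposer log-probability, a learned score, and a length penalty.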

In [None]:
#| default_exp search.beam_search

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
# #| export
# import torch
# import torch.nn.functional as F
# import numpy as np
# import torch
# import torch.nn.functional as F

# # --- Utility Class for Beam Search Node ---
# class BeamNode:
#     def __init__(self, sequence, score, state):
#         self.sequence = sequence  # List of token IDs (the sequence built so far)
#         self.score = score        # Total log-probability of the sequence
#         self.state = state        # Decoder hidden state/memory after the last token

#     def __lt__(self, other):
#         # Comparison operator for min-heap/sorting
#         return self.score < other.score

# def beam_search_decode(
#     model,
#     encoder_output,
#     initial_decoder_state,
#     SOS_IDX,
#     EOS_IDX,
#     K,
#     max_len
# ):
#     # 1. Initialization
#     device = encoder_output.device
    
#     # The beam starts with a single candidate: the <SOS> token.
#     # We use a list to hold the BeamNode objects.
#     initial_token = torch.tensor([SOS_IDX], device=device).unsqueeze(0) # (1, 1) tensor
#     initial_score = 0.0 # Log probability of <SOS> is 0
    
#     # The model should handle mapping the encoder_output to the initial state
#     # for the first decoding step.
    
#     # NOTE: In a real implementation, you need to handle batching. This simple example 
#     # assumes a single input sentence for clarity. For batching, K beams must be created
#     # for *each* item in the batch.

#     beam = [
#         BeamNode(
#             sequence=[SOS_IDX],
#             score=initial_score,
#             state=initial_decoder_state # Initial state, often the last encoder hidden state
#         )
#     ]
    
#     # Store finished sequences to keep track of the final results
#     finished_hypotheses = []

#     # 2. Main Decoding Loop
#     for t in range(max_len):
#         candidates = []
        
#         # Prepare data for batched forward pass on the decoder
#         # Only take the *last* token from each current beam for the next step's input
#         current_tokens = torch.tensor([n.sequence[-1] for n in beam], device=device).unsqueeze(1) # (K, 1)
        
#         # Stack the states of the current beams for a batched forward pass
#         # This part depends heavily on your model's architecture (RNN state tuple or Transformer memory)
#         # Assuming RNN: hidden state is (h, c)
#         if isinstance(beam[0].state, tuple):
#              current_states = tuple(torch.stack([n.state[i] for n in beam], dim=1) for i in range(len(beam[0].state))) # (num_layers, K, hidden_size)
#         else:
#              current_states = torch.stack([n.state for n in beam], dim=1) # (num_layers, K, hidden_size)

#         # DECODER STEP: Compute next token log-probabilities
#         # output_logits: (K, 1, vocab_size) -> (K, vocab_size)
#         # next_states: State for next step, typically (num_layers, K, hidden_size)
#         output_logits, next_states = model.decode(current_tokens, encoder_output, current_states)
#         log_probs = F.log_softmax(output_logits.squeeze(1), dim=-1) # (K, vocab_size)

#         # Iterate over each of the K beams and their potential next steps
#         for i, node in enumerate(beam):
#             # Select the top K next tokens for this *specific* beam
#             # We select K, as K*K will be pruned to K overall later
#             topk_log_probs, topk_indices = log_probs[i].topk(K) 
            
#             # The next state is specific to the current beam (index i)
#             if isinstance(next_states, tuple):
#                  new_state = tuple(s[:, i, :].unsqueeze(1) for s in next_states) # Select state i: (num_layers, 1, hidden_size)
#             else:
#                  new_state = next_states[:, i, :].unsqueeze(1) # Select state i: (num_layers, 1, hidden_size)
            
#             # Create K new candidate nodes
#             for log_prob, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
#                 new_sequence = node.sequence + [idx]
#                 new_score = node.score + log_prob # Log probabilities are additive
                
#                 new_node = BeamNode(new_sequence, new_score, new_state)
#                 candidates.append(new_node)
        
#         # 3. Pruning and Filtering
#         # Sort all K*K candidates by score and select the overall top K
#         candidates.sort(key=lambda x: x.score, reverse=True)
#         beam = candidates[:K]

#         # Check for finished hypotheses and remove them from the active beam
#         new_beam = []
#         for node in beam:
#             if node.sequence[-1] == EOS_IDX or t == max_len - 1:
#                 # Sequence is finished or reached max length, add to final list
#                 # Use length normalization here for final scoring (optional but recommended)
#                 # length_penalty = ((5 + len(node.sequence)) / 6)**0.7
#                 # node.score /= length_penalty 
#                 finished_hypotheses.append(node)
#             else:
#                 # Keep active in the beam for the next step
#                 new_beam.append(node)
        
#         beam = new_beam
        
#         if not beam:
#             # All hypotheses finished
#             break
            
#     # 4. Final Result
#     # Combine the remaining active beams with the finished ones
#     finished_hypotheses.extend(beam)
    
#     # Sort final candidates and return the top one (or top N)
#     finished_hypotheses.sort(key=lambda x: x.score, reverse=True)
    
#     # Return the sequence (tokens) of the best hypothesis
#     return finished_hypotheses[0].sequence

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from MAWM.core import Program, PRIMITIVE_TEMPLATES
from MAWM.models.program_embedder import batchify_programs
from MAWM.models.program_encoder import ProgramEncoder
from MAWM.models.program_synthizer import Proposer

def _instantiate_params(pred_params, arity, grid_size, n_draws=3, p_use_pred=0.7):
    """Draw up to ``n_draws`` distinct parameter lists for a primitive of ``arity``.

    Each draw is chosen as follows:
    - with probability ``p_use_pred``, the predicted parameters scaled by
      ``grid_size - 1`` and rounded (this presumably expects predictions in
      [0, 1] -- TODO confirm against the proposer);
    - otherwise, ``arity`` uniformly random grid values in ``[0, grid_size)``.

    Exact duplicates are dropped so identical children do not fill several
    beam slots. The number of RNG draws is the same as without
    de-duplication.

    Returns:
        list of parameter lists (floats), in first-drawn order.
    """
    from_prediction = [float(round(float(p) * (grid_size - 1))) for p in pred_params]
    draws = []
    for _ in range(n_draws):
        if np.random.rand() < p_use_pred:
            inst = list(from_prediction)
        else:
            inst = [float(np.random.randint(0, grid_size)) for _ in range(arity)]
        if inst not in draws:
            draws.append(inst)
    return draws


@torch.no_grad()
def neural_guided_beam_search(
    z,
    program_embedder: nn.Module,
    proposer: nn.Module,
    program_encoder: nn.Module,
    score_fn,
    MAX_PARAMS=3,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas=(0.25, 0.5, 0.25),
    device="cuda",
    rebuild_modules=True,
):
    """Neural-guided beam search over programs.

    Starting from a single ``[SOS]`` program, every depth does the following:
    1. Expand each unfinished beam entry with its ``topk`` most likely next
       primitives.
    2. Instantiate each primitive's parameters (see ``_instantiate_params``).
    3. Score every candidate.
    4. Keep the ``beam_width`` best entries. Programs ending in EOS are
       carried along without further expansion.

    Args:
        z: target latent; a 1-D tensor is promoted to shape ``(1, D)``.
        program_embedder: embeds padded program prefixes for the proposer.
        proposer: provides ``sos_idx``; also used for ``forward_step`` when
            ``rebuild_modules`` is False.
        program_encoder: encodes candidate programs for ``score_fn`` when
            ``rebuild_modules`` is False.
        score_fn: ``score_fn(z, prog_emb_batch)``; must return one score per
            candidate, each supporting ``.item()``.
        MAX_PARAMS: length of the placeholder parameter list of the SOS token.
        beam_width: number of entries kept after each depth.
        topk: primitives expanded per entry (clamped to the number of logits).
        max_prog_len: maximum number of expansion steps.
        grid_size: number of discrete values per parameter.
        lambdas: ``(lambda_1, lambda_2, lambda_3)``, the weights for the
            accumulated proposer score, the ``score_fn`` score and the
            length penalty.
        device: device for ``z``, program batches and the step modules.
        rebuild_modules: controls which modules run inference.
            - True (default, original behaviour): a freshly constructed
              ``Proposer`` / ``ProgramEncoder`` sized to the current sequence
              length is built at every depth. No weights are loaded into it,
              and the ``proposer`` / ``program_encoder`` arguments are not
              used for inference.
            - False: the passed-in modules are used instead. They must then
              accept variable-length batches and already live on ``device``.

    Returns:
        ``(best_program, best_score)``: the best entry of the last beam that
        was scored, or ``(SOS program, -inf)`` if nothing was scored.
    """
    # --- normalize z shape: a single latent becomes a batch of one ---
    z = z.to(device)
    if z.dim() == 1:
        z = z.unsqueeze(0)
    lambda_1, lambda_2, lambda_3 = lambdas

    num_prims = len(PRIMITIVE_TEMPLATES)
    sos_idx = proposer.sos_idx
    EOS_IDX = num_prims  # EOS is the index right after the last primitive
    zero_params = torch.full((MAX_PARAMS,), -1, device=device)

    def _build_proposer(seq_len):
        # Freshly constructed Proposer sized for `seq_len`, placed on `device`
        # so it lives where its inputs (z_batch, p_vec) live.
        return Proposer(
            obs_dim=z.shape[-1],
            num_prims=num_prims,
            max_params=2,
            seq_len=seq_len,
            prog_emb_dim_x=32,
            prog_emb_dim_y=32,
            prog_emb_dim_prims=32,
        ).to(device)

    def _build_encoder(seq_len):
        # Freshly constructed ProgramEncoder sized for `seq_len`, on `device`.
        return ProgramEncoder(
            num_primitives=num_prims,
            param_cardinalities=[grid_size, grid_size],
            seq_len=seq_len,
            max_params_per_primitive=2,
        ).to(device)

    # initial program: [SOS]
    init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])
    beam = [(init_prog, 0.0)]  # entries are (program, score)

    best_program = init_prog
    best_score = -float("inf")

    for _ in range(max_prog_len):
        # ---- 0) split beam: programs ending in EOS are not expanded ----
        alive, finished = [], []
        for prog, score in beam:
            if prog.tokens[-1][0] == EOS_IDX:
                finished.append((prog, score))
            else:
                alive.append((prog, score))
        if not alive:
            break

        # ---- 1) propose next primitives for the batch of alive prefixes ----
        prefix_programs = [prog for prog, _ in alive]
        prev_idx_batch, prev_params_batch = batchify_programs(
            prefix_programs, padding_vals=[EOS_IDX, -1]
        )
        p_vec = program_embedder(prev_idx_batch.to(device), prev_params_batch.to(device))
        z_batch = z.repeat(prev_idx_batch.shape[0], 1)  # one copy of z per prefix

        if rebuild_modules:
            step_proposer = _build_proposer(prev_idx_batch.shape[-1])
        else:
            step_proposer = proposer
        prim_logits_batch, param_pred_batch = step_proposer.forward_step(z_batch, p_vec)

        # ---- 2) expand each alive entry ----
        expansions = []
        for b_i, (prog_parent, parent_score) in enumerate(alive):
            prim_logprobs = F.log_softmax(prim_logits_batch[b_i], dim=-1)
            # Clamp so a topk larger than the number of logits does not raise.
            k = min(topk, prim_logprobs.shape[-1])
            top_vals, top_idx = torch.topk(prim_logprobs, k=k)

            for prim_idx, prim_logp in zip(top_idx.tolist(), top_vals.tolist()):
                # NOTE(review): parent_score is the parent's combined total from
                # step 4 (not a pure log-prob sum), so earlier scores are
                # re-weighted by lambda_1 at every depth -- confirm intent.
                child_base = parent_score + prim_logp
                if prim_idx == EOS_IDX:
                    # Completed program; it will not be expanded further.
                    expansions.append((prog_parent.extend(prim_idx, []), child_base))
                    continue

                arity = PRIMITIVE_TEMPLATES[prim_idx][1]
                if arity > 0:
                    pred_params = param_pred_batch[b_i][:arity].cpu().numpy()
                    instantiations = _instantiate_params(pred_params, arity, grid_size)
                else:
                    instantiations = [[]]

                for inst_params in instantiations:
                    expansions.append((prog_parent.extend(prim_idx, inst_params), child_base))

        if not expansions:
            break

        # ---- 3) score all expanded programs ----
        cand_programs = [prog for prog, _ in expansions]
        prim_ids_padded, params_padded = batchify_programs(
            cand_programs, padding_vals=[EOS_IDX, -1]
        )
        if rebuild_modules:
            step_encoder = _build_encoder(prim_ids_padded.shape[-1])
        else:
            step_encoder = program_encoder
        prog_emb_batch = step_encoder(prim_ids_padded.to(device), params_padded.to(device))
        scores = score_fn(z, prog_emb_batch)

        # ---- 4) combine scores, then keep the best `beam_width` entries ----
        candidates = [
            (
                prog,
                lambda_1 * base_score
                + lambda_2 * float(scores[i].item())
                - lambda_3 * len(prog.tokens),
            )
            for i, (prog, base_score) in enumerate(expansions)
        ]
        # Stable sort: ties keep "new candidates first, then finished" order.
        ranked = sorted(candidates + finished, key=lambda entry: entry[1], reverse=True)

        # Single-token programs whose primitive id is -1 are never chosen as
        # best (presumably a placeholder -- TODO confirm when one can occur).
        eligible = [
            entry for entry in ranked
            if not (len(entry[0].tokens) == 1 and entry[0].tokens[0][0] == -1)
        ]
        if not eligible:
            break
        best_program, best_score = eligible[0]
        beam = ranked[:beam_width]

    return best_program, best_score


In [None]:
#| hide
from MAWM.models.program_encoder import ProgramPredictor, ProgramEncoder

import torch
def loss_fn(z_hat, z, loss_exp=1):
    """Return ``mean(|z_hat - z| ** loss_exp) / loss_exp`` as a 0-d tensor.

    ``z_hat`` and ``z`` are broadcast against each other before the mean is
    taken over every element.
    """
    abs_err = (z_hat - z).abs()
    return abs_err.pow(loss_exp).mean() / loss_exp

def score_fn(z, m, predictor= None):
    """Score each program embedding in ``m`` by how well it predicts ``z``.

    Args:
        z: target latent, broadcast against each single-row prediction.
        m: batch of program embeddings; ``predictor(m)`` must return one
            prediction per row (indexed along dim 0).
        predictor: callable mapping ``m`` to predicted latents. When None, a
            freshly constructed ``ProgramPredictor()`` is used. Previously
            this argument was always discarded in favour of that fallback.

    Returns:
        list of 0-d tensors, ``-loss_fn(z_hat[i], z)`` for each row ``i``
        (higher is better).
    """
    if predictor is None:
        predictor = ProgramPredictor()
    with torch.no_grad():
        z_hat = predictor(m)

    return [-loss_fn(row.unsqueeze(0), z) for row in z_hat]


In [None]:
#| hide
from MAWM.models.program_embedder import ProgramEmbedder, PRIMITIVE_TEMPLATES
from MAWM.models.program_encoder import ProgramEncoder
from MAWM.core import *
from MAWM.models.program_synthizer import Proposer
import torch
# Demo configuration: everything runs on CPU over a 7x7 parameter grid.
device = 'cpu'
grid_size = 7
num_primitives = len(PRIMITIVE_TEMPLATES)

# Target latent for the demo search: a batch of one 32-dim vector.
z = torch.randn(1, 32)

# Prefix embedder; one cardinality entry per parameter slot.
p_embed = ProgramEmbedder(
    num_primitives=num_primitives,
    param_cardinalities=[grid_size] * 2,
    max_params_per_primitive=2,
    d_name=32,
    d_param=32,
)

p_encoder = ProgramEncoder(num_primitives, [grid_size] * 2, 2, seq_len=1)

# Proposer whose observation size matches the latent's feature dimension.
proposer = Proposer(
    obs_dim=z.shape[-1],
    num_prims=num_primitives,
    max_params=2,
    seq_len=1,
    prog_emb_dim_x=32,
    prog_emb_dim_y=32,
    prog_emb_dim_prims=32,
)



In [None]:
# lst_res = []
# for i in range(1000):
#     z = torch.randn(1, 32).to(device)
#     res = neural_guided_beam_search(z, p_embed, proposer, p_encoder, score_fn, beam_width=3, topk=4, max_prog_len= 10, lambdas= (0.7, 0.2, 0.01), device= device)
#     lst_res.append(res)

In [None]:
# import numpy as np
# idx = np.random.choice(len(lst_res))
# print(idx)
# lst_res[idx][0]

32


OtherAgentDirection(-1, -1, -1) | CellEmpty(1.0, 0.0) | OtherAgentDirection(3.0,) | GoalAt(3.0, 3.0) | ItemAt(0.0, 1.0) | GoalAt(3.0, 3.0) | CellObstacle(3.0, 3.0) | CellGoal(3.0, 3.0) | GoalAt(3.0, 3.0) | CellAgent(2.0, 4.0) | OtherAgentNear()

In [None]:
# sum([len(lst_res[idx][0]) for idx in range(len(lst_res))]) / len(lst_res)

6.716

In [None]:
#| hide
import nbdev

nbdev.nbdev_export()