# Beam Search Module

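> This module implements neural-guided beam search for program inference: a proposer suggests candidate primitives and parameters, and each candidate program is scored by how well a world model conditioned on it predicts the target next state and reward.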

In [None]:
#| default_exp search.evo

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
#| export
import torch.nn.functional as F
import numpy as np
from MAWM.core import Program, PRIMITIVE_TEMPLATES

def _with_batch_dim(x, device, unbatched_ndim):
    """Move `x` to `device`, prepending a batch axis if it has exactly `unbatched_ndim` dims."""
    x = x.to(device)
    return x.unsqueeze(0) if x.dim() == unbatched_ndim else x


def _sample_param_settings(pred_params, arity, grid_size, num_samples, pred_prob):
    """Draw `num_samples` integer-valued parameter vectors for one parametric primitive.

    With probability `pred_prob` a draw is the proposer's prediction (presumably in
    [0, 1] -- TODO confirm) scaled to the grid and rounded; otherwise each of the
    `arity` parameters is drawn uniformly from ``range(grid_size)`` using the global
    NumPy RNG. Rounded predictions are clamped to ``[0, grid_size - 1]`` so both
    branches always produce on-grid values.
    """
    max_val = grid_size - 1
    settings = []
    for _ in range(num_samples):
        if np.random.rand() < pred_prob:
            params = [float(min(max(round(float(p) * max_val), 0), max_val)) for p in pred_params]
        else:
            params = [float(np.random.randint(0, grid_size)) for _ in range(arity)]
        settings.append(params)
    return settings


def _normalized_params(inst_params, grid_size, max_params, device):
    """Map grid-valued params back to [0, 1] and zero-pad/truncate to `max_params` slots."""
    scale = max(grid_size - 1, 1)  # a 1-cell grid would otherwise divide by zero
    values = [p / scale for p in inst_params][:max_params]
    values += [0.0] * (max_params - len(values))
    return torch.tensor(values, dtype=torch.float32, device=device)


def _expand_beam(beam, z_obs, proposer, topk, grid_size, num_samples, pred_prob, device):
    """Propose the top-k next primitives (with parameter instantiations) for each beam entry.

    Returns a list of ``(parent_index, new_program, prim_idx, inst_params, parent_score)``.
    Exact duplicates (same parent, primitive and parameters) are dropped, so repeated
    identical parameter draws do not waste evaluations or beam slots.
    """
    if not beam:
        return []
    prev_idx = torch.tensor([entry[1] for entry in beam], dtype=torch.long, device=device)
    prev_params = torch.stack([entry[2] for entry in beam], dim=0)
    prim_logits, param_pred = proposer.forward_step(
        z_obs.repeat(len(beam), 1), prev_idx, prev_params
    )

    expansions = []
    seen = set()
    for parent_i, (parent_prog, _, _, parent_score) in enumerate(beam):
        logits = prim_logits[parent_i]
        # softmax is monotonic, so ranking raw logits selects the same primitives.
        top_idx = torch.topk(logits, k=min(topk, logits.size(0))).indices.tolist()
        pred_row = param_pred[parent_i].cpu().numpy()
        for prim_idx in top_idx:
            arity = PRIMITIVE_TEMPLATES[prim_idx][1]
            if arity > 0:
                settings = _sample_param_settings(
                    pred_row[:arity], arity, grid_size, num_samples, pred_prob
                )
            else:
                settings = [[]]
            for inst_params in settings:
                key = (parent_i, prim_idx, tuple(inst_params))
                if key in seen:
                    continue
                seen.add(key)
                new_prog = parent_prog.extend(prim_idx, inst_params)
                expansions.append((parent_i, new_prog, prim_idx, inst_params, parent_score))
    return expansions


def _encode_program_batch(programs, batchify_programs, max_params, device):
    """Tokenize each program with `batchify_programs` and right-pad into one batch.

    Each call is expected to return ``(prim_ids, params)`` shaped (1, L) and
    (1, L, max_params). Returns tensors shaped (N, L_max) and (N, L_max, max_params)
    with zeros in padding positions.
    """
    encoded = [batchify_programs(p, max_params=max_params, device=device) for p in programs]
    l_max = max((ids.shape[1] for ids, _ in encoded), default=0)
    n = len(encoded)
    prim_ids = torch.zeros((n, l_max), dtype=torch.long, device=device)
    params = torch.zeros((n, l_max, max_params), dtype=torch.float32, device=device)
    for row, (ids, p) in enumerate(encoded):
        length = ids.shape[1]
        if length > 0:
            prim_ids[row, :length] = ids.squeeze(0)
            params[row, :length, :] = p.squeeze(0)
    return prim_ids, params


@torch.no_grad()
def neural_guided_beam_search_corrected(
    z_obs: torch.FloatTensor,         # Sender's observation (for proposer)
    z_receiver: torch.FloatTensor,     # Receiver's current state
    z_target: torch.FloatTensor,       # Target next state
    a_t: torch.FloatTensor,            # Action taken
    r_target: torch.FloatTensor,       # Target reward
    proposer,
    encoder,
    world_model,
    reward_model,
    score_fn,
    batchify_programs,
    MAX_PARAMS: int = 3,
    beam_width: int = 5,
    topk: int = 6,
    max_prog_len: int = 5,
    grid_size: int = 5,
    device="cuda",
    num_param_samples: int = 3,
    pred_param_prob: float = 0.7,
):
    """
    Neural-guided beam search for program inference.

    Searches for a program P that best explains
    ``z_receiver + a_t + P -> (z_target, r_target)``. At each depth the proposer
    (conditioned on ``z_obs``) ranks next primitives for every beam entry. The
    ``topk`` best are expanded; primitives with parameters get up to
    ``num_param_samples`` distinct parameter instantiations. Each candidate
    program is encoded, passed to ``world_model`` with the receiver state and
    action, and scored by ``score_fn`` (lower energy = better explanation).

    The beam is pruned by cumulative score (negated sum of the energies of a
    program and all of its prefixes). The *returned* program is the evaluated
    candidate with the lowest energy of its own, so longer programs can win.
    Parameter sampling uses the global NumPy RNG; seed it for reproducibility.

    Args:
        z_obs, z_receiver, z_target, a_t: 1-D tensors get a leading batch axis;
            they are then repeated along dim 0, so a batch of one is assumed
            (TODO confirm with callers).
        r_target: target reward; a 0-d tensor gets a leading batch axis.
        proposer: exposes ``sos_idx`` and
            ``forward_step(z_obs, prev_idx, prev_params) -> (prim_logits, param_pred)``.
        encoder: ``encoder(prim_ids, params)`` -> one program embedding per row.
        world_model: ``world_model(z, a, prog_emb)`` -> predicted next state.
        reward_model: ``reward_model(z, a)``; its output is passed to ``score_fn``.
        score_fn: ``score_fn(z_pred, z_target, reward_dist, r_target)`` -> exactly
            one energy per candidate.
        batchify_programs: ``batchify_programs(prog, max_params=..., device=...)``
            -> ``(prim_ids, params)`` shaped (1, L) and (1, L, max_params).
        MAX_PARAMS: parameter slots per primitive step.
        beam_width: beam entries kept after each depth.
        topk: primitives expanded per beam entry.
        max_prog_len: maximum number of primitives (search depth).
        grid_size: parameters are integers in ``[0, grid_size - 1]``.
        device: device for all tensors created here.
        num_param_samples: instantiations drawn per parametric primitive.
        pred_param_prob: probability that an instantiation uses the proposer's
            predicted parameters instead of uniformly random grid values.

    Returns:
        ``(best_program, best_energy)``, where ``best_energy`` is that program's own
        energy; ``(Program(), inf)`` if no candidate was evaluated.
    """
    z_obs = _with_batch_dim(z_obs, device, unbatched_ndim=1)
    z_receiver = _with_batch_dim(z_receiver, device, unbatched_ndim=1)
    z_target = _with_batch_dim(z_target, device, unbatched_ndim=1)
    a_t = _with_batch_dim(a_t, device, unbatched_ndim=1)
    r_target = _with_batch_dim(r_target, device, unbatched_ndim=0)

    zero_params = torch.zeros((MAX_PARAMS,), device=device, dtype=torch.float32)
    # Beam entries: (program, prev_prim_idx, prev_params, cumulative_score).
    beam = [(Program(), proposer.sos_idx, zero_params, 0.0)]

    best_program = Program()
    best_energy = float("inf")

    for _ in range(max_prog_len):
        expansions = _expand_beam(
            beam, z_obs, proposer, topk, grid_size,
            num_param_samples, pred_param_prob, device,
        )
        if not expansions:
            break

        n_cand = len(expansions)
        prim_ids, params = _encode_program_batch(
            [e[1] for e in expansions], batchify_programs, MAX_PARAMS, device
        )
        prog_emb = encoder(prim_ids, params)

        z_b = z_receiver.repeat(n_cand, 1)
        a_b = a_t.repeat(n_cand, 1)
        z_pred = world_model(z_b, a_b, prog_emb)
        reward_dist = reward_model(z_b, a_b)
        energies = score_fn(
            z_pred, z_target.repeat(n_cand, 1), reward_dist,
            r_target.reshape(-1).repeat(n_cand),
        )
        # One host sync instead of one `.item()` per candidate; the reshape also
        # fails loudly if score_fn does not return exactly one energy per candidate.
        energies = energies.reshape(n_cand).tolist()

        candidates = []
        for (_, prog, prim_idx, inst_params, parent_score), energy in zip(expansions, energies):
            candidates.append((prog, prim_idx, inst_params, parent_score - energy))  # higher is better
            # Every candidate was fully evaluated, so consider all of them (not only
            # the survivors of pruning) and judge each by its own energy.
            if energy < best_energy:
                best_energy, best_program = energy, prog

        candidates.sort(key=lambda c: c[3], reverse=True)
        # Only survivors need proposer inputs, so build their tensors after pruning.
        beam = [
            (prog, prim_idx, _normalized_params(inst_params, grid_size, MAX_PARAMS, device), score)
            for prog, prim_idx, inst_params, score in candidates[:beam_width]
        ]

    return best_program, best_energy

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()