# Evolutionary Program Search Module

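> This module implements evolutionary search over programs: random initialization, tournament selection, crossover, mutation, and fitness evaluation through a program encoder and a scoring function (optionally guided by a proposer network).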

In [None]:
#| default_exp search.evo

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
#| hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from MAWM.core import Program, PRIMITIVE_TEMPLATES
from MAWM.models.program_embedder import ProgramEmbedder
from MAWM.models.program_encoder import ProgramEncoder
from MAWM.models.program_synthizer import Proposer

# Scratch setup for the hidden notebook cells: every model lives on the CPU
# and the parameter cardinalities follow the grid size.
device = "cpu"
grid_size = 7

num_primitives = len(PRIMITIVE_TEMPLATES)

# Embeds (primitive, params) token sequences.
p_embed = ProgramEmbedder(
    num_primitives=num_primitives,
    param_cardinalities=[grid_size] * 2,
    max_params_per_primitive=2,
    d_name=32,
    d_param=32,
)
p_embed = p_embed.to(device)

# Encodes a padded program into a single vector.
p_encoder = ProgramEncoder(
    num_primitives,
    [grid_size, grid_size],
    2,
    seq_len=5,
)

# Proposer network used for the optional guidance term of the search.
proposer = Proposer(
    obs_dim=32,
    num_prims=num_primitives,
    max_params=2,
    seq_len=5,
    prog_emb_dim_x=32,
    prog_emb_dim_y=32,
    prog_emb_dim_prims=32,
)

{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [None]:
SEQ_LEN = 5                          # programs are truncated / padded to this many tokens
PAD_PRIM = len(PRIMITIVE_TEMPLATES)  # primitive id written into padded token slots
PAD_PARAM = -1                       # filler value for missing parameters
MAX_PARAMS = 2                       # number of parameters kept per primitive


def program_to_indices(program):
    """Convert one program into fixed-size index tensors.

    Only the first ``SEQ_LEN`` tokens of ``program.tokens`` are used. Each
    token is read as ``(primitive_id, params)``; its parameter list is cut to
    ``MAX_PARAMS`` entries and right-padded with ``PAD_PARAM``. Missing token
    slots are filled with ``PAD_PRIM`` and all-``PAD_PARAM`` parameter rows.

    Returns a pair ``(prims, params)`` of tensors with a leading singleton
    dimension: shapes ``(1, SEQ_LEN)`` and ``(1, SEQ_LEN, MAX_PARAMS)``.
    """
    kept = program.tokens[:SEQ_LEN]
    n_missing = SEQ_LEN - len(kept)

    # primitive ids, followed by padding ids for the unused slots
    prim_ids = [tok[0] for tok in kept] + [PAD_PRIM] * n_missing

    # one fixed-width parameter row per slot
    param_rows = []
    for tok in kept:
        row = list(tok[1])[:MAX_PARAMS]
        row.extend([PAD_PARAM] * (MAX_PARAMS - len(row)))
        param_rows.append(row)
    param_rows.extend([PAD_PARAM] * MAX_PARAMS for _ in range(n_missing))

    prim_tensor = torch.tensor(prim_ids).unsqueeze(0)
    param_tensor = torch.tensor(param_rows).unsqueeze(0)
    return prim_tensor, param_tensor

In [None]:
def batchify_programs(programs, seq_len=SEQ_LEN, max_params=MAX_PARAMS):
    """Stack the index tensors of several programs into batch tensors.

    Every program is converted with ``program_to_indices``; the leading
    singleton dimension of each result is dropped and the rows are stacked
    along a new batch dimension.

    NOTE: ``seq_len`` and ``max_params`` are accepted but not used here; the
    padding sizes are whatever ``program_to_indices`` produces.

    Returns ``(prim_batch, param_batch)``.
    """
    encoded = [program_to_indices(prog) for prog in programs]

    prim_batch = torch.stack([prims[0] for prims, _ in encoded], dim=0)      # (B, SEQ_LEN)
    param_batch = torch.stack([params[0] for _, params in encoded], dim=0)   # (B, SEQ_LEN, max_params)

    return prim_batch, param_batch


In [None]:
#| export
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Callable, Optional
from MAWM.core import Program, PRIMITIVE_TEMPLATES
# NOTE: this cell calls `batchify_programs`, which is defined in a non-exported cell of this notebook; it must be importable wherever the exported module is used.

@torch.no_grad()
def evolutionary_program_search(
    z: torch.Tensor,
    program_embedder: nn.Module,
    proposer: Optional[nn.Module],
    program_encoder: nn.Module,
    score_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    # evolution hyperparams
    pop_size: int = 128,
    num_generations: int = 40,
    max_params: int = 3,
    max_prog_len: int = 5,
    grid_size: int = 7,
    elite_fraction: float = 0.05,
    tournament_k: int = 3,
    mutation_rate: float = 0.35,
    insertion_prob: float = 0.12,
    deletion_prob: float = 0.12,
    crossover_rate: float = 0.2,
    use_proposer_guidance: bool = False,
    proposer_lambda: float = 0.5,   # weight for proposer log-prob in fitness (only primitives)
    length_penalty: float = 0.25,   # per-token penalty subtracted from the fitness
    device: str = "cuda",
    seed: Optional[int] = None,
):
    """
    Evolutionary search for a program that maximises a fitness function.

    Fitness of a program is ``score_fn(z, program_encoder(...))`` minus
    ``length_penalty * len(program.tokens)``, plus (optionally)
    ``proposer_lambda`` times the summed log-probability the proposer assigns
    to the program's primitives.

    Args:
        z: single observation latent, shape (D,) or (1, D); moved to ``device``.
        program_embedder: embeds program prefixes for the proposer; only
            called when proposer guidance is active.
        proposer: optional module exposing ``forward_step(z, p_vec)`` whose
            first output is used as primitive logits. Guidance is silently
            skipped when this is ``None``, even if ``use_proposer_guidance``
            is set.
        program_encoder: maps batched (primitive ids, param ids) to program
            embeddings.
        score_fn: ``score_fn(z, prog_emb)`` returning one score per program
            (any shape that flattens to ``len(programs)`` values); higher is
            better.
        pop_size: number of programs per generation (>= 1).
        num_generations: number of selection/variation rounds.
        max_params: forwarded to ``batchify_programs`` and used as the width
            of the placeholder token that stands for the empty prefix.
        max_prog_len: maximum number of tokens in a program (>= 1).
        grid_size: parameters are sampled uniformly from ``[0, grid_size)``.
        elite_fraction: fraction of the population copied unchanged into the
            next generation (at least one program).
        tournament_k: tournament size; clamped to ``pop_size`` because the
            contenders are drawn without replacement.
        mutation_rate: probability that a child is mutated.
        insertion_prob: probability of a token insertion during mutation.
        deletion_prob: probability of a token deletion during mutation.
        crossover_rate: probability that a child is made by crossover.
        use_proposer_guidance: add the proposer log-prob term to the fitness.
        proposer_lambda: weight of the proposer log-prob term.
        length_penalty: weight of the program-length penalty.
        device: device the index tensors and ``z`` are moved to.
        seed: if given, seeds the global NumPy and torch RNGs.

    Returns:
        (best_program: Program, best_fitness: float), taken from the final
        population.

    Raises:
        ValueError: if ``pop_size``, ``max_prog_len`` or ``tournament_k`` is
            smaller than 1, or if ``score_fn`` does not return exactly one
            score per program.
    """
    if pop_size < 1:
        raise ValueError(f"pop_size must be >= 1, got {pop_size}")
    if max_prog_len < 1:
        raise ValueError(f"max_prog_len must be >= 1, got {max_prog_len}")
    if tournament_k < 1:
        raise ValueError(f"tournament_k must be >= 1, got {tournament_k}")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # ensure z shape (1, D)
    z = z.to(device)
    if z.dim() == 1:
        z = z.unsqueeze(0)

    num_prims = len(PRIMITIVE_TEMPLATES)
    elite_count = max(1, int(np.ceil(pop_size * elite_fraction)))
    # Sampling without replacement needs size <= population size.
    tournament_size = min(tournament_k, pop_size)

    # Probabilities used inside a single mutation step.
    prim_mutation_prob = 0.4    # replace the whole token (primitive + params)
    param_mutation_prob = 0.35  # otherwise: resample each parameter independently

    # ---------------- utilities ----------------
    def random_param() -> int:
        return int(np.random.randint(0, grid_size))

    def random_token():
        """Sample a primitive and as many parameters as its arity requires."""
        prim = int(np.random.randint(0, num_prims))
        arity = PRIMITIVE_TEMPLATES[prim][1]
        return prim, [random_param() for _ in range(arity)]

    def clean_tokens(tokens):
        """Deep-copy tokens, coercing ids and parameters to plain ints."""
        return [(int(tok[0]), [int(x) for x in tok[1]]) for tok in tokens]

    def random_program() -> Program:
        length = np.random.randint(1, max_prog_len + 1)
        return Program(tokens=[random_token() for _ in range(length)])

    def copy_program(p: Program) -> Program:
        return Program(tokens=clean_tokens(p.tokens))

    def mutate_program(program: Program) -> Program:
        """Mutate one token, then optionally insert and/or delete a token."""
        p = copy_program(program)
        if len(p.tokens) == 0:
            return random_program()

        # pick a position to mutate
        i = int(np.random.randint(0, len(p.tokens)))
        prim, params = p.tokens[i]

        if np.random.rand() < prim_mutation_prob:
            # new primitive; parameters are regenerated for its arity
            prim, params = random_token()
        else:
            # keep the primitive, mutate parameters individually
            arity = PRIMITIVE_TEMPLATES[prim][1]
            new_params = []
            for k in range(arity):
                if np.random.rand() < param_mutation_prob:
                    new_params.append(random_param())
                elif k < len(params):
                    new_params.append(int(params[k]))
                else:
                    # token had fewer params than the arity: fill randomly
                    new_params.append(random_param())
            params = new_params

        p.tokens[i] = (prim, params)

        # insertion (the random draw happens even when the program is full)
        if np.random.rand() < insertion_prob and len(p.tokens) < max_prog_len:
            new_token = random_token()
            insert_pos = int(np.random.randint(0, len(p.tokens) + 1))
            p.tokens.insert(insert_pos, new_token)

        # deletion (never removes the last remaining token)
        if np.random.rand() < deletion_prob and len(p.tokens) > 1:
            del_pos = int(np.random.randint(0, len(p.tokens)))
            p.tokens.pop(del_pos)

        return p

    def crossover_programs(a: Program, b: Program) -> Program:
        """One-point crossover: head of ``a`` followed by tail of ``b``."""
        if len(a.tokens) == 0 or len(b.tokens) == 0:
            return random_program()
        # split points may be anywhere from 0 to len (inclusive)
        i = int(np.random.randint(0, len(a.tokens) + 1))
        j = int(np.random.randint(0, len(b.tokens) + 1))
        # clip to max length
        new_tokens = (list(a.tokens[:i]) + list(b.tokens[j:]))[:max_prog_len]
        if len(new_tokens) == 0:
            return random_program()
        return Program(tokens=clean_tokens(new_tokens))

    def to_index_tensors(programs: List[Program]):
        """Batchify programs and move the index tensors to ``device``."""
        prim_ids, param_ids = batchify_programs(
            programs, seq_len=max_prog_len, max_params=max_params
        )
        prim_ids = prim_ids.to(device=device, dtype=torch.long)
        param_ids = param_ids.to(device=device, dtype=torch.long)
        return prim_ids, param_ids

    # Vectorized proposer primitive logprobs (for a list of programs)
    def proposer_prim_logprob_batch(programs: List[Program]) -> np.ndarray:
        """
        Sum, per program, the proposer's log-probability of each primitive
        given the prefix of tokens before it. Only primitive logits are
        scored, not parameters.
        """
        sums = np.zeros(len(programs), dtype=float)
        if proposer is None:
            return sums

        # Flatten all (prefix -> next primitive) pairs of all programs.
        prefixes = []
        owners = []   # index of the program each prefix belongs to
        targets = []  # primitive that follows each prefix
        for owner, prog in enumerate(programs):
            for t, token in enumerate(prog.tokens):
                if t == 0:
                    # NOTE(review): the empty prefix is encoded as a single
                    # placeholder token with primitive id -1; confirm that
                    # program_embedder accepts this id.
                    prefixes.append(Program(tokens=[(-1, [-1] * max_params)]))
                else:
                    prefixes.append(Program(tokens=clean_tokens(prog.tokens[:t])))
                owners.append(owner)
                targets.append(int(token[0]))

        if len(prefixes) == 0:
            return sums

        prim_ids, param_ids = to_index_tensors(prefixes)
        p_vec = program_embedder(prim_ids, param_ids)
        z_rep = z.repeat(p_vec.shape[0], 1)
        prim_logits, _ = proposer.forward_step(z_rep, p_vec)

        prim_logprobs = torch.log_softmax(prim_logits, dim=-1).cpu().numpy()
        for row, (owner, target) in enumerate(zip(owners, targets)):
            sums[owner] += float(prim_logprobs[row, target])
        return sums

    # ---------------- fitness evaluation ----------------
    def evaluate_population(programs: List[Program]) -> np.ndarray:
        """Return fitness array of shape (len(programs),). Higher = better."""
        if len(programs) == 0:
            return np.array([])

        # 1) task score via program_encoder + score_fn (vectorized)
        prim_ids, param_ids = to_index_tensors(programs)
        prog_emb = program_encoder(prim_ids, param_ids)
        # Flatten so that (P,), (1, P), (P, 1) and a 0-d score (P == 1) all
        # end up as a 1-D array of P values.
        scores = score_fn(z, prog_emb).detach().cpu().reshape(-1).numpy()
        if scores.shape[0] != len(programs):
            raise ValueError(
                f"score_fn returned {scores.shape[0]} scores "
                f"for {len(programs)} programs"
            )
        fitness = scores.astype(np.float64)

        # 2) optional proposer primitive logprobs
        if use_proposer_guidance and proposer is not None:
            fitness = fitness + proposer_lambda * proposer_prim_logprob_batch(programs)

        # 3) length penalty
        lengths = np.array([len(p.tokens) for p in programs], dtype=float)
        return fitness - length_penalty * lengths

    def tournament_select(candidates: List[Program], fitness: np.ndarray) -> Program:
        """Return the fittest of ``tournament_size`` distinct random programs."""
        contenders = np.random.choice(pop_size, size=tournament_size, replace=False)
        winner_idx = int(contenders[np.argmax(fitness[contenders])])
        return candidates[winner_idx]

    # ---------------- initialize population ----------------
    population: List[Program] = [random_program() for _ in range(pop_size)]

    # ---------------- evolutionary loop ----------------
    for _ in range(num_generations):
        fitness = evaluate_population(population)

        # elites are carried over unchanged, so the best program seen in this
        # generation is always part of the next one
        elite_idx = np.argsort(fitness)[::-1][:elite_count]
        new_population: List[Program] = [copy_program(population[i]) for i in elite_idx]

        # create offspring until population full
        while len(new_population) < pop_size:
            parent = tournament_select(population, fitness)

            # second parent for crossover (another tournament)
            if np.random.rand() < crossover_rate:
                parent2 = tournament_select(population, fitness)
                child = crossover_programs(parent, parent2)
            else:
                child = copy_program(parent)

            if np.random.rand() < mutation_rate:
                child = mutate_program(child)

            new_population.append(child)

        # Duplicates are intentionally kept; no de-duplication is performed.
        population = new_population

    # final evaluation and best pick
    final_fitness = evaluate_population(population)
    best_idx = int(np.argmax(final_fitness))
    return population[best_idx], float(final_fitness[best_idx])


In [None]:
#| hide
from MAWM.models.program_encoder import ProgramPredictor, ProgramEncoder

import torch
def loss_fn(z_hat, z, loss_exp= 1):
    """Mean of ``|z_hat - z| ** loss_exp`` over all elements, divided by ``loss_exp``."""
    abs_err = (z_hat - z).abs()
    powered = abs_err.pow(loss_exp)
    return powered.mean() / loss_exp

def score_fn(z, m, predictor= None):
    """Score program embeddings by how well they predict the latent ``z``.

    Args:
        z: target latent; broadcast against each predicted row.
        m: batch of program embeddings, one row per program.
        predictor: callable mapping ``m`` to predicted latents. When ``None``
            a fresh ``ProgramPredictor()`` is constructed for this call.

    Returns:
        1-D CPU tensor with one score per row of ``predictor(m)``: the
        negative mean absolute error between that row and ``z`` (the same
        value as ``-loss_fn(row.unsqueeze(0), z)`` with ``loss_exp=1``).
        Higher is better.
    """
    # Honour a caller-supplied predictor; only build a default when none is given.
    if predictor is None:
        predictor = ProgramPredictor()

    with torch.no_grad():
        z_hat = predictor(m)
        # Insert a singleton axis after the batch axis so every row is
        # compared with z exactly as `row.unsqueeze(0) - z` would be, then
        # average all remaining elements per row.
        abs_err = (z_hat.unsqueeze(1) - z).abs()
        per_row = abs_err.flatten(start_dim=1).mean(dim=1)

    return (-per_row).cpu()


In [None]:
# best_prog, best_score = evolutionary_program_search(
#     z=torch.randn(32),  # torch tensor (D,)
#     program_embedder=p_embed,
#     proposer=proposer,               # or None
#     program_encoder=p_encoder,
#     score_fn=score_fn,
#     pop_size=128,
#     num_generations=50,
#     max_prog_len=5,
#     grid_size=7,
#     use_proposer_guidance=True,      # optional
#     proposer_lambda=0.5,
#     length_penalty=0.25,
#     device="cpu",
#     seed=42,
# )


In [None]:
# lst_res = []
# z = torch.randn(32).to(device)
# for i in range(10000):
#     res = evolutionary_program_search(z=z,  # torch tensor (D,)
#                                     program_embedder=p_embed,
#                                     proposer=proposer,               # or None
#                                     program_encoder=p_encoder,
#                                     score_fn=score_fn,
#                                     pop_size=128,
#                                     num_generations=50,
#                                     max_prog_len=5,
#                                     grid_size=7,
#                                     use_proposer_guidance=True,      # optional
#                                     proposer_lambda=0.5,
#                                     length_penalty=0.25,
#                                     device="cpu",
#                                     seed=42,
#                                     )
#     lst_res.append(res)

KeyboardInterrupt: 

In [None]:
# print("best:", best_prog, best_score)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()