# Beam Search Module

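> This module implements beam / enumerative search over primitive programs. Given a target latent `z`, a proposer suggests primitives and parameters, candidate programs are encoded and scored against `z`, and the best-scoring program is returned.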

In [None]:
#| default_exp search.beam_search_batched

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
from torchvision.utils import save_image
import torch
import os
from torch import nn
import torch.nn.functional as F
import pandas as pd

In [None]:
#| hide
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from MAWM.core import Program, PRIMITIVE_TEMPLATES
from MAWM.models.program_embedder import ProgramEmbedder
from MAWM.models.program_encoder import ProgramEncoder
from MAWM.models.program_synthizer import Proposer

# Small interactive fixtures for trying out the search functions in this notebook.
device = 'cpu'
grid_size= 7

num_primitives = len(PRIMITIVE_TEMPLATES)
# Program-token embedder consumed by the proposer. Each of the 2 parameter
# slots takes 7 values (presumably chosen to match grid_size -- confirm if
# grid_size changes).
p_embed = ProgramEmbedder(
    num_primitives= num_primitives,
    param_cardinalities= [7, 7],
    max_params_per_primitive= 2,
    d_name= 32,
    d_param= 32,
).to(device)

# NOTE(review): positional args are presumably (num_primitives, param
# cardinalities, max params per primitive) -- confirm against ProgramEncoder.
# Only p_embed is explicitly moved to `device` in this cell.
p_encoder = ProgramEncoder(num_primitives, [grid_size, grid_size],2, seq_len=5)
proposer = Proposer(obs_dim= 32,
                    num_prims= num_primitives,
                    max_params= 2,
                    seq_len= 5,
                    prog_emb_dim_x= 32,
                    prog_emb_dim_y= 32,
                    prog_emb_dim_prims= 32)

In [None]:
SEQ_LEN = 5
PAD_PRIM = len(PRIMITIVE_TEMPLATES)  # index for padding primitive
PAD_PARAM = -1  # value for padding parameters
MAX_PARAMS = 2  # maximum number of parameters per primitive
# NOTE(review): PAD_PRIM is len(PRIMITIVE_TEMPLATES), the same value this
# notebook later uses as EOS_IDX, so padding and EOS share an index -- confirm.

def program_to_indices(program):
    """
    Turn one Program into padded index tensors with a leading batch axis of 1.

    Only the first SEQ_LEN tokens are kept. Each token contributes its
    primitive index (``token[0]``) and at most MAX_PARAMS of its parameters
    (``token[1]``). Short parameter lists are padded with PAD_PARAM; missing
    tokens become PAD_PRIM with a full row of PAD_PARAM.

    Returns:
        (prim_ids, params) with shapes (1, SEQ_LEN) and (1, SEQ_LEN, MAX_PARAMS).
    """
    kept = program.tokens[:SEQ_LEN]
    n_missing = SEQ_LEN - len(kept)

    prim_ids = [tok[0] for tok in kept] + [PAD_PRIM] * n_missing

    param_rows = []
    for tok in kept:
        row = list(tok[1])[:MAX_PARAMS]
        row.extend([PAD_PARAM] * (MAX_PARAMS - len(row)))
        param_rows.append(row)
    param_rows.extend([PAD_PARAM] * MAX_PARAMS for _ in range(n_missing))

    return torch.tensor(prim_ids)[None], torch.tensor(param_rows)[None]

In [None]:
def batchify_programs(programs, seq_len=SEQ_LEN, max_params=MAX_PARAMS):
    """
    Stack Programs into padded (primitive, parameter) index tensors.

    Padding matches ``program_to_indices``: at most ``seq_len`` tokens and
    ``max_params`` parameters per token are kept. Missing parameter slots are
    filled with PAD_PARAM; missing tokens become PAD_PRIM with a full row of
    PAD_PARAM. Unlike the previous implementation, ``seq_len`` and
    ``max_params`` are honoured instead of silently falling back to the
    module-level SEQ_LEN / MAX_PARAMS.

    Args:
        programs: iterable of objects exposing ``tokens``, a sequence of
            (primitive_index, params) pairs.
        seq_len: number of token positions in the output.
        max_params: number of parameter slots per token.

    Returns:
        (prim_batch, param_batch) with shapes (B, seq_len) and
        (B, seq_len, max_params). For an empty input, empty ``long`` tensors
        of shapes (0, seq_len) and (0, seq_len, max_params) are returned
        instead of raising.
    """
    programs = list(programs)
    if not programs:
        return (
            torch.empty((0, seq_len), dtype=torch.long),
            torch.empty((0, seq_len, max_params), dtype=torch.long),
        )

    prim_rows = []
    param_rows = []
    for prog in programs:
        tokens = prog.tokens[:seq_len]
        pad_len = seq_len - len(tokens)

        prim_rows.append([t[0] for t in tokens] + [PAD_PRIM] * pad_len)

        rows = []
        for t in tokens:
            p = list(t[1])[:max_params]
            p += [PAD_PARAM] * (max_params - len(p))
            rows.append(p)
        rows += [[PAD_PARAM] * max_params for _ in range(pad_len)]
        param_rows.append(rows)

    # One torch.tensor call over the whole batch: dtype is inferred exactly as
    # per-program tensors followed by a promoting stack would give.
    prim_batch = torch.tensor(prim_rows)      # (B, seq_len)
    param_batch = torch.tensor(param_rows)    # (B, seq_len, max_params)
    return prim_batch, param_batch


In [None]:
# Special token indices used by the search: EOS sits one past the last primitive.
EOS_IDX = len(PRIMITIVE_TEMPLATES)
sos_idx = proposer.sos_idx

# Despite the name, the start token's parameter slots hold the -1 sentinel.
zero_params = torch.full((MAX_PARAMS,), -1, device=device)
init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])

In [None]:
# Quick shape check: batch the same start program twice.
a, b = batchify_programs([init_prog] * 2)
a.shape, b.shape

(torch.Size([2, 5]), torch.Size([2, 5, 2]))

In [None]:
# Inspect the dtype of the padded parameter batch produced above.
b.dtype

torch.int64

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from MAWM.core import Program, PRIMITIVE_TEMPLATES
from MAWM.models.program_encoder import ProgramEncoder
from MAWM.models.program_synthizer import Proposer
from typing import List, Tuple, Optional
from collections import defaultdict

from typing import List, Tuple, Optional
from collections import defaultdict


def analyze_program_diversity(programs: List[Program], verbose=True):
    """
    Analyze diversity in a set of programs.

    Only real primitives are counted. The leading SOS token is skipped, and
    any token whose index lies outside ``[0, len(PRIMITIVE_TEMPLATES))`` is
    ignored. This covers the EOS index (``len(PRIMITIVE_TEMPLATES)``) that
    the search functions in this module append. Without that filter, EOS
    would inflate program lengths, repetition counts and the number of
    unique primitives.

    Args:
        programs: programs to analyse; each exposes ``tokens`` as a sequence
            of (primitive_index, params) pairs.
        verbose: if True, print a human-readable summary (top 10 primitives).

    Returns:
        dict with keys 'num_programs', 'unique_primitives', 'avg_length',
        'max_repetition', 'avg_repetition' and 'prim_distribution'
        (primitive index -> total count). Returns ``{}`` for an empty input.
    """
    if not programs:
        return {}

    num_prims = len(PRIMITIVE_TEMPLATES)

    # Count primitive usage across all programs
    prim_usage = defaultdict(int)
    prog_lengths = []
    repeated_prims = []

    for prog in programs:
        # Skip SOS; keep only real primitive indices (drops EOS / padding / negatives).
        prog_prims = [p for p, _ in prog.tokens[1:] if 0 <= p < num_prims]
        prog_lengths.append(len(prog_prims))

        # Count primitives in this program
        local_counts = defaultdict(int)
        for p in prog_prims:
            prim_usage[p] += 1
            local_counts[p] += 1

        # Largest number of times any single primitive appears in this program
        repeated_prims.append(max(local_counts.values(), default=0))

    # programs is non-empty here, so both lists are non-empty.
    stats = {
        'num_programs': len(programs),
        'unique_primitives': len(prim_usage),
        'avg_length': np.mean(prog_lengths),
        'max_repetition': max(repeated_prims),
        'avg_repetition': np.mean(repeated_prims),
        'prim_distribution': dict(prim_usage)
    }

    if verbose:
        print("\n=== Program Diversity Analysis ===")
        print(f"Number of programs: {stats['num_programs']}")
        print(f"Unique primitives used: {stats['unique_primitives']}")
        print(f"Average program length: {stats['avg_length']:.2f}")
        print(f"Max same primitive in one program: {stats['max_repetition']}")
        print(f"Avg same primitive repetition: {stats['avg_repetition']:.2f}")
        print("\nPrimitive usage distribution:")
        top_usage = sorted(prim_usage.items(), key=lambda x: x[1], reverse=True)[:10]
        for prim_idx, count in top_usage:
            prim_name = PRIMITIVE_TEMPLATES[prim_idx][0]
            print(f"  {prim_name} (idx {prim_idx}): {count} times")

    return stats


def batchify_programs(programs, seq_len, max_params):
    """
    Pack a list of Programs into padded tensors.

    Args:
        programs: list of objects whose ``tokens`` is a sequence of
            (primitive_index, params) pairs.
        seq_len: tokens past this position are dropped.
        max_params: parameters past this slot are dropped.

    Returns:
        (prim_ids, params): a ``torch.long`` tensor of shape (B, seq_len) and
        a ``torch.float32`` tensor of shape (B, seq_len, max_params). Every
        slot not filled from a program holds -1.
    """
    n_progs = len(programs)
    prim_ids = torch.full((n_progs, seq_len), -1, dtype=torch.long)
    params = torch.full((n_progs, seq_len, max_params), -1, dtype=torch.float32)

    for row, prog in enumerate(programs):
        # zip with a range truncates to seq_len / max_params without explicit checks
        for col, (prim_idx, param_list) in zip(range(seq_len), prog.tokens):
            prim_ids[row, col] = prim_idx
            for slot, value in zip(range(max_params), param_list):
                params[row, col, slot] = value

    return prim_ids, params


@torch.no_grad()
def parallel_enumerative_search(
    z,
    program_embedder: nn.Module,
    proposer: nn.Module,
    program_encoder: nn.Module,
    score_fn,
    MAX_PARAMS=3,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas=(0.25, 0.5, 0.25),
    num_param_samples=3,
    device="cuda",
    temperature=1.0,
    diversity_penalty=0.1,
    dedup_programs=True
):
    """
    Parallel enumerative program search (DreamCoder-style) for one latent.

    At each depth, every unfinished program in the frontier is expanded with
    a single batched ``proposer.forward_step`` call. All resulting candidates
    are encoded and scored in one batch, and the frontier is pruned back to
    ``beam_width``. Candidates are ranked by::

        lambda_1 * logprob + lambda_2 * score - lambda_3 * len(program.tokens)

    Args:
        z: target latent; a 1-D tensor is treated as a batch of one.
        program_embedder: called as ``program_embedder(prim_ids, params)``;
            its output is fed to ``proposer.forward_step``.
        proposer: exposes ``sos_idx`` and ``forward_step(z, p_vec)``, which
            returns (primitive logits, parameter predictions).
        program_encoder: called as ``program_encoder(prim_ids, params)``; its
            output is passed to ``score_fn``.
        score_fn: ``score_fn(z_expanded, prog_emb)`` returning one score per
            candidate.
        MAX_PARAMS: parameter slots per token.
        beam_width: frontier size kept after each depth.
        topk: primitives expanded per program. It is clamped to the number of
            selectable indices (primitives + EOS) and to the logit width.
        max_prog_len: maximum search depth and padded sequence length.
        grid_size: parameter instantiations are clipped to [0, grid_size - 1].
        lambdas: (log-prob, score, length-penalty) weights.
        num_param_samples: parameter instantiations tried per primitive
            (duplicates dropped).
        device: device for the search tensors.
        temperature: Controls sampling diversity (higher = more diverse);
            must be > 0.
        diversity_penalty: Penalizes repeated primitives. It is subtracted
            from a primitive's log-prob once per prior use in the parent.
        dedup_programs: Remove duplicate programs from frontier.

    Returns:
        (best_program, best_score). Only programs containing at least one real
        primitive (index in [0, EOS_IDX)) are eligible. If none is found, the
        initial [SOS] program is returned with score -inf.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    z = z.to(device)
    if z.dim() == 1:
        z = z.unsqueeze(0)

    lambda_1, lambda_2, lambda_3 = lambdas
    sos_idx = proposer.sos_idx
    EOS_IDX = len(PRIMITIVE_TEMPLATES)
    zero_params = torch.full((MAX_PARAMS,), -1, device=device)

    # Initial frontier: just [SOS]
    init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])
    frontier = [(init_prog, 0.0, 0.0)]  # (program, log_prob, combined_score)

    best_program = init_prog
    best_score = -float("inf")

    for depth in range(1, max_prog_len + 1):
        # Separate finished and alive programs
        alive = [(p, lp, s) for (p, lp, s) in frontier if not p.finished]
        finished = [(p, lp, s) for (p, lp, s) in frontier if p.finished]

        if not alive:
            break

        # ===== PHASE 1: Batch Proposer Forward Pass =====
        alive_programs = [p for (p, _, _) in alive]
        prev_idx_batch, prev_params_batch = batchify_programs(
            alive_programs, seq_len=max_prog_len, max_params=MAX_PARAMS
        )
        prev_idx_batch = prev_idx_batch.to(device).long()
        prev_params_batch = prev_params_batch.to(device).long()

        B = prev_idx_batch.shape[0]

        # Get program embeddings
        p_vec = program_embedder(prev_idx_batch, prev_params_batch)

        # Replicate z for batch
        z_batch = z.repeat(B, 1)

        # Get predictions for all alive programs
        prim_logits_batch, param_pred_batch = proposer.forward_step(z_batch, p_vec)

        # Apply temperature for diversity
        prim_logits_batch = prim_logits_batch / temperature
        prim_logprobs_batch = F.log_softmax(prim_logits_batch, dim=-1)

        # Only indices 0..EOS_IDX are selectable. torch.topk raises if k exceeds
        # the dimension size, so clamp k to both limits.
        valid_topk = min(topk, EOS_IDX + 1, prim_logprobs_batch.shape[-1])

        # ===== PHASE 2: Parallel Enumeration =====
        all_candidates = []

        for b_i, (prog_parent, parent_logprob, _) in enumerate(alive):
            # Count primitive usage in parent to penalize repetition
            prim_counts = defaultdict(int)
            for token_prim, _ in prog_parent.tokens[1:]:  # Skip SOS
                if 0 <= token_prim < EOS_IDX:
                    prim_counts[token_prim] += 1

            prim_logprobs = prim_logprobs_batch[b_i].clone()

            # Anything past EOS is not a primitive; never let topk pick it.
            prim_logprobs[EOS_IDX + 1:] = -float("inf")

            # Apply diversity penalty to repeated primitives
            for prim_idx, count in prim_counts.items():
                if prim_idx < len(prim_logprobs):
                    prim_logprobs[prim_idx] -= diversity_penalty * count

            top_vals, top_idx = torch.topk(prim_logprobs, k=valid_topk)

            for k_i in range(valid_topk):
                prim_idx = int(top_idx[k_i].item())
                prim_logp = float(top_vals[k_i].item())
                new_logprob = parent_logprob + prim_logp

                # Defensive: masking above should already exclude these.
                if not 0 <= prim_idx <= EOS_IDX:
                    continue

                # Handle EOS
                if prim_idx == EOS_IDX:
                    new_prog = prog_parent.extend(prim_idx, [])
                    all_candidates.append((new_prog, new_logprob))
                    continue

                arity = PRIMITIVE_TEMPLATES[prim_idx][1]

                # Generate parameter instantiations
                if arity > 0:
                    pred_params = param_pred_batch[b_i][:arity].cpu().numpy()
                    instantiations = []

                    for sample_i in range(num_param_samples):
                        if sample_i == 0:
                            # Predicted parameters scaled to the grid; clipped
                            # like the noisy samples so they stay in range.
                            inst = [
                                float(np.clip(
                                    round(float(p) * (grid_size - 1)),
                                    0, grid_size - 1
                                ))
                                for p in pred_params
                            ]
                        elif np.random.rand() < 0.5:
                            # Predicted parameters jittered by -1/0/+1
                            inst = [
                                float(np.clip(
                                    round(float(p) * (grid_size - 1)) + np.random.randint(-1, 2),
                                    0, grid_size - 1
                                ))
                                for p in pred_params
                            ]
                        else:
                            # Pure random sampling
                            inst = [
                                float(np.random.randint(0, grid_size))
                                for _ in range(arity)
                            ]

                        # Avoid duplicate instantiations
                        if inst not in instantiations:
                            instantiations.append(inst)
                else:
                    instantiations = [[]]

                # Create candidate programs
                for inst_params in instantiations:
                    new_prog = prog_parent.extend(prim_idx, inst_params)
                    all_candidates.append((new_prog, new_logprob))

        if not all_candidates:
            break

        # ===== PHASE 3: Batch Scoring =====
        cand_programs = [p for (p, _) in all_candidates]

        prim_ids_padded, params_padded = batchify_programs(
            cand_programs, seq_len=max_prog_len, max_params=MAX_PARAMS
        )
        prim_ids_padded = prim_ids_padded.to(device).long()
        params_padded = params_padded.to(device).long()

        # Encode and score all candidates in one batch
        prog_emb_batch = program_encoder(prim_ids_padded, params_padded)
        z_expanded = z.expand(prog_emb_batch.shape[0], -1)
        scores = score_fn(z_expanded, prog_emb_batch)

        # ===== PHASE 4: Combine Scores and Prune =====
        scored_candidates = []
        for idx, (prog, logprob) in enumerate(all_candidates):
            # Combined score: log_prob + similarity - length_penalty
            combined = (
                lambda_1 * logprob +
                lambda_2 * float(scores[idx].item()) -
                lambda_3 * len(prog.tokens)
            )
            scored_candidates.append((prog, logprob, combined))

        # Deduplicate programs if requested (first occurrence wins)
        if dedup_programs:
            seen_progs = set()
            unique_candidates = []
            for prog, lp, cscore in scored_candidates:
                prog_str = str(prog.tokens)
                if prog_str not in seen_progs:
                    seen_progs.add(prog_str)
                    unique_candidates.append((prog, lp, cscore))
            scored_candidates = unique_candidates

        # Sort by combined score
        scored_candidates.sort(key=lambda x: x[2], reverse=True)

        # Update best: the first (highest-scoring) candidate that contains a real
        # primitive, so an empty [SOS, EOS] program can never be reported.
        for prog, lp, cscore in scored_candidates:
            if any(0 <= p < EOS_IDX for p, _ in prog.tokens[1:]):
                if cscore > best_score:
                    best_program = prog
                    best_score = cscore
                break

        # Prune to beam_width (finished programs from the previous frontier compete too)
        frontier = scored_candidates[:beam_width] + finished
        frontier.sort(key=lambda x: x[2], reverse=True)
        frontier = frontier[:beam_width]

    return best_program, best_score


@torch.no_grad()
def parallel_enumerative_search_batched(
    z_batch,
    program_embedder,
    proposer,
    program_encoder,
    score_fn,
    MAX_PARAMS=3,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas=(0.25, 0.5, 0.25),
    num_param_samples=3,
    temperature=1.0,
    diversity_penalty=0.1,
    dedup_programs=True,
    device="cuda",
):
    """
    Run ``parallel_enumerative_search`` independently for each row of ``z_batch``.

    A 1-D ``z_batch`` is treated as a batch of one. The rows are searched one
    after another with identical hyper-parameters; nothing is shared between
    the per-row searches.

    Returns:
        list of length B, each entry is (best_program, best_score)
    """
    latents = z_batch.to(device)
    if latents.dim() == 1:
        latents = latents.unsqueeze(0)

    # Hyper-parameters forwarded unchanged to every per-row search.
    search_kwargs = dict(
        MAX_PARAMS=MAX_PARAMS,
        beam_width=beam_width,
        topk=topk,
        max_prog_len=max_prog_len,
        grid_size=grid_size,
        lambdas=lambdas,
        num_param_samples=num_param_samples,
        temperature=temperature,
        diversity_penalty=diversity_penalty,
        dedup_programs=dedup_programs,
        device=device,
    )

    return [
        tuple(
            parallel_enumerative_search(
                z_row,
                program_embedder,
                proposer,
                program_encoder,
                score_fn,
                **search_kwargs,
            )
        )
        for z_row in latents
    ]


@torch.no_grad()
def fully_parallel_enumerative_search(
    z_batch,
    program_embedder,
    proposer,
    program_encoder,
    score_fn,
    MAX_PARAMS=2,
    beam_width=5,
    topk=6,
    max_prog_len=5,
    grid_size=7,
    lambdas=(0.25, 0.5, 0.25),
    num_param_samples=3,
    temperature=1.0,
    diversity_penalty=0.1,
    dedup_programs=True,
    device="cuda",
    debug=False,
):
    """
    Run one independent beam search per row of ``z_batch``, batching the
    proposer forward pass across all rows.

    Each batch element keeps its OWN frontier (there is no shared search
    tree). At every depth, the unfinished programs of all frontiers are
    embedded and passed through ``proposer.forward_step`` in a single call.
    Candidate encoding and scoring (``program_encoder`` / ``score_fn``) is
    then done in a separate loop for each batch element. Candidates are
    ranked by::

        lambda_1 * logprob + lambda_2 * score - lambda_3 * len(program.tokens)

    Only indices 0..EOS_IDX (primitives plus EOS) can be selected; any
    further logits are masked to -inf.

    Args:
        z_batch: target latents; a 1-D tensor is treated as a batch of one.
        topk: primitives expanded per program, clamped to EOS_IDX + 1.
        temperature: divides the logits before log_softmax.
        diversity_penalty: subtracted from a primitive's log-prob once per
            prior use in the parent program.
        dedup_programs: drop candidates whose ``str(tokens)`` was already seen.
        debug: print diagnostics (mostly for the first alive program).

    Returns:
        list of length B, each entry is (best_program, best_score). A row for
        which no candidate containing a real primitive (index in [0, EOS_IDX))
        was found keeps the initial [SOS] program with score -inf.
    """
    z_batch = z_batch.to(device)
    if z_batch.dim() == 1:
        z_batch = z_batch.unsqueeze(0)

    B = z_batch.shape[0]
    lambda_1, lambda_2, lambda_3 = lambdas
    sos_idx = proposer.sos_idx
    EOS_IDX = len(PRIMITIVE_TEMPLATES)
    # Parameter slots of the SOS token hold the -1 sentinel (not zeros).
    zero_params = torch.full((MAX_PARAMS,), -1, device=device)

    if debug:
        print(f"=== Search Init ===")
        print(f"SOS_IDX: {sos_idx}, EOS_IDX: {EOS_IDX}")
        print(f"Num primitives: {len(PRIMITIVE_TEMPLATES)}")
        print(f"Batch size: {B}")

    # Initialize one frontier per batch element.
    # Frontier entries are (program, log_prob, combined_score).
    init_prog = Program(tokens=[(sos_idx, zero_params.tolist())])
    frontiers = [[(init_prog, 0.0, 0.0)] for _ in range(B)]
    best_programs = [init_prog] * B
    best_scores = [-float("inf")] * B

    for depth in range(1, max_prog_len + 1):
        if debug:
            print(f"\n=== Depth {depth} ===")

        # Collect all alive programs across all batch elements;
        # batch_indices[i] records which row all_alive[i] belongs to.
        all_alive = []
        batch_indices = []

        for b_idx in range(B):
            alive = [(p, lp, s) for (p, lp, s) in frontiers[b_idx] if not p.finished]
            for item in alive:
                all_alive.append(item)
                batch_indices.append(b_idx)

        if debug:
            print(f"Total alive programs: {len(all_alive)}")

        if len(all_alive) == 0:
            break

        # Batch process all alive programs (one proposer call for every row)
        alive_programs = [p for (p, _, _) in all_alive]
        prev_idx_batch, prev_params_batch = batchify_programs(
            alive_programs, seq_len=max_prog_len, max_params=MAX_PARAMS
        )
        prev_idx_batch = prev_idx_batch.to(device).long()
        prev_params_batch = prev_params_batch.to(device).long()

        # Get embeddings and predictions; each alive program is paired with its own row's latent
        p_vec = program_embedder(prev_idx_batch, prev_params_batch)
        z_expanded = torch.stack([z_batch[b_idx] for b_idx in batch_indices])
        prim_logits_batch, param_pred_batch = proposer.forward_step(z_expanded, p_vec)

        # Apply temperature (NOTE(review): temperature <= 0 is not guarded)
        prim_logits_batch = prim_logits_batch / temperature
        prim_logprobs_batch = F.log_softmax(prim_logits_batch, dim=-1)

        if debug:
            print(f"Logits shape: {prim_logits_batch.shape}")
            print(f"Logits range: [{prim_logits_batch.min().item():.2f}, {prim_logits_batch.max().item():.2f}]")

        # Generate candidates per batch element
        batch_candidates = defaultdict(list)

        for alive_idx, (prog_parent, parent_logprob, _) in enumerate(all_alive):
            b_idx = batch_indices[alive_idx]

            # Count primitive usage for diversity penalty
            prim_counts = defaultdict(int)
            for token_prim, _ in prog_parent.tokens[1:]:
                if token_prim >= 0 and token_prim < EOS_IDX:  # CRITICAL: Only count valid primitives
                    prim_counts[token_prim] += 1

            # Apply diversity penalty
            prim_logprobs = prim_logprobs_batch[alive_idx].clone()

            # CRITICAL FIX: Mask out invalid indices
            # Only allow primitives in range [0, EOS_IDX]
            valid_mask = torch.zeros_like(prim_logprobs, dtype=torch.bool)
            valid_mask[:EOS_IDX + 1] = True  # Include EOS
            prim_logprobs[~valid_mask] = -float('inf')  # Mask invalid primitives

            for prim_idx, count in prim_counts.items():
                if 0 <= prim_idx < len(prim_logprobs):
                    prim_logprobs[prim_idx] -= diversity_penalty * count

            # CRITICAL FIX: Ensure topk doesn't exceed valid primitives
            valid_topk = min(topk, EOS_IDX + 1)
            top_vals, top_idx = torch.topk(prim_logprobs, k=valid_topk)

            if debug and alive_idx == 0:
                print(f"\nParent program: {prog_parent.tokens}")
                print(f"Top-{valid_topk} primitives: {top_idx.tolist()}")
                print(f"Top-{valid_topk} logprobs: {top_vals.tolist()}")

            for k_i in range(valid_topk):
                prim_idx = int(top_idx[k_i].item())
                prim_logp = float(top_vals[k_i].item())

                # CRITICAL FIX: Validate primitive index
                if prim_idx < 0 or prim_idx > EOS_IDX:
                    if debug:
                        print(f"WARNING: Invalid prim_idx {prim_idx}, skipping")
                    continue

                # EOS ends the program and takes no parameters.
                if prim_idx == EOS_IDX:
                    new_prog = prog_parent.extend(prim_idx, [])
                    new_logprob = parent_logprob + prim_logp
                    batch_candidates[b_idx].append((new_prog, new_logprob))
                    if debug and alive_idx == 0:
                        print(f"  -> EOS selected")
                    continue

                arity = PRIMITIVE_TEMPLATES[prim_idx][1]

                if debug and alive_idx == 0 and k_i == 0:
                    print(f"  -> Primitive {prim_idx} ({PRIMITIVE_TEMPLATES[prim_idx][0]}) with arity {arity}")

                # Up to num_param_samples distinct instantiations: the scaled prediction,
                # then either a +/-1 jitter of it (clipped) or a uniform random grid value.
                # NOTE(review): the first (predicted) sample is not clipped to [0, grid_size - 1].
                if arity > 0:
                    pred_params = param_pred_batch[alive_idx][:arity].cpu().numpy()
                    instantiations = []
                    for sample_i in range(num_param_samples):
                        if sample_i == 0:
                            inst = [float(round(float(p) * (grid_size - 1))) for p in pred_params]
                        elif np.random.rand() < 0.5:
                            inst = [
                                float(np.clip(
                                    round(float(p) * (grid_size - 1)) + np.random.randint(-1, 2),
                                    0, grid_size - 1
                                ))
                                for p in pred_params
                            ]
                        else:
                            inst = [float(np.random.randint(0, grid_size)) for _ in range(arity)]
                        if inst not in instantiations:
                            instantiations.append(inst)
                else:
                    instantiations = [[]]

                for inst_params in instantiations:
                    new_prog = prog_parent.extend(prim_idx, inst_params)
                    new_logprob = parent_logprob + prim_logp
                    batch_candidates[b_idx].append((new_prog, new_logprob))

        # Encode and score candidates separately for each batch element
        # (one program_encoder / score_fn call per row, not one overall).
        for b_idx in range(B):
            if not batch_candidates[b_idx]:
                if debug:
                    print(f"Batch {b_idx}: No candidates generated")
                continue

            cand_programs = [p for (p, _) in batch_candidates[b_idx]]
            cand_logprobs = [lp for (_, lp) in batch_candidates[b_idx]]

            if debug:
                print(f"\nBatch {b_idx}: {len(cand_programs)} candidates")
                for i, prog in enumerate(cand_programs[:3]):
                    print(f"  Candidate {i}: {prog.tokens}")

            prim_ids, params = batchify_programs(cand_programs, seq_len=max_prog_len, max_params=MAX_PARAMS)
            prim_ids = prim_ids.to(device).long()
            params = params.to(device).long()

            prog_emb_batch = program_encoder(prim_ids, params)
            z_single = z_batch[b_idx].unsqueeze(0).expand(prog_emb_batch.shape[0], -1)
            scores = score_fn(z_single, prog_emb_batch)

            # Score and prune
            scored = []
            for idx, (prog, logprob) in enumerate(batch_candidates[b_idx]):
                combined = (
                    lambda_1 * logprob +
                    lambda_2 * float(scores[idx].item()) -
                    lambda_3 * len(prog.tokens)
                )
                scored.append((prog, logprob, combined))

            scored.sort(key=lambda x: x[2], reverse=True)

            # Deduplicate if requested (keeps the highest-scoring copy, since sorted first)
            if dedup_programs:
                seen = set()
                unique = []
                for prog, lp, cscore in scored:
                    prog_str = str(prog.tokens)
                    if prog_str not in seen:
                        seen.add(prog_str)
                        unique.append((prog, lp, cscore))
                scored = unique

            # Update best. The break only fires once a new best is recorded, so the
            # loop keeps scanning lower-scored candidates otherwise (harmless: sorted).
            if scored:
                for prog, lp, cscore in scored:
                    # CRITICAL FIX: Validate program has valid primitives
                    has_valid_prims = any(
                        0 <= p < EOS_IDX 
                        for p, _ in prog.tokens[1:]  # Skip SOS
                    )
                    if has_valid_prims and cscore > best_scores[b_idx]:
                        best_programs[b_idx] = prog
                        best_scores[b_idx] = cscore
                        if debug:
                            print(f"  New best for batch {b_idx}: score={cscore:.3f}, prog={prog.tokens[:3]}...")
                        break

            # Update frontier: new candidates compete with previously finished programs
            finished = [(p, lp, s) for (p, lp, s) in frontiers[b_idx] if p.finished]
            frontiers[b_idx] = (scored[:beam_width] + finished)
            frontiers[b_idx].sort(key=lambda x: x[2], reverse=True)
            frontiers[b_idx] = frontiers[b_idx][:beam_width]

    return [(best_programs[b], best_scores[b]) for b in range(B)]

In [None]:
# hide
from MAWM.models.program_encoder import ProgramPredictor, ProgramEncoder

import torch
def loss_fn(z_hat, z, loss_exp= 1):
    """Return mean(|z_hat - z| ** loss_exp) / loss_exp, broadcasting z_hat against z."""
    return torch.mean(torch.abs(z_hat - z) ** loss_exp) / loss_exp

def score_fn(z, m, predictor= None):
    """
    Score program embeddings ``m`` against the target latent ``z``.

    Each row of ``predictor(m)`` is compared with ``z`` via ``loss_fn``
    (broadcasting the single row against all of ``z``). The negated loss is
    returned, so a higher score is better.

    Args:
        z: target latent(s). The search functions pass one latent expanded to
            one row per candidate, so broadcasting gives the per-row loss.
        m: program embeddings, one row per candidate.
        predictor: callable mapping ``m`` to predicted latents. Previously this
            argument was ignored; it is now used when given.

    Returns:
        1-D tensor of scores, one per row of ``predictor(m)``.
    """
    if predictor is None:
        # NOTE(review): a new ProgramPredictor is built on every call when no
        # predictor is supplied (presumably untrained) -- pass one to get
        # scores that are comparable across calls.
        predictor = ProgramPredictor()
    with torch.no_grad():
        z_hat = predictor(m)

    return torch.tensor([-loss_fn(z_hat[i].unsqueeze(0), z).item() for i in range(z_hat.size(0))])


In [None]:
# predictor = ProgramPredictor()
# # set predictor parameters to stop gradients
# for param in predictor.parameters():
#     param.requires_grad = False

In [None]:
# #| hide
# from MAWM.models.program_embedder import ProgramEmbedder, PRIMITIVE_TEMPLATES
# from MAWM.models.program_encoder import ProgramEncoder
# from MAWM.core import *
# from MAWM.models.program_synthizer import Proposer
# import torch
# z = torch.randn(1, 32)
# device = 'cpu'
# grid_size= 7

# num_primitives = len(PRIMITIVE_TEMPLATES)
# p_embed = ProgramEmbedder(
#     num_primitives= num_primitives,
#     param_cardinalities= [7, 7],
#     max_params_per_primitive= 2,
#     d_name= 32,
#     d_param= 32,
# )

# p_encoder = ProgramEncoder(num_primitives, [grid_size, grid_size],2, seq_len=5)
# proposer = Proposer(obs_dim= 32,
#                     num_prims= num_primitives,
#                     max_params= 2,
#                     seq_len= 5,
#                     prog_emb_dim_x= 32,
#                     prog_emb_dim_y= 32,
#                     prog_emb_dim_prims= 32)



In [None]:
# lst_res = []
# z = torch.randn(32, 32).to(device)
# for i in range(10):
#     res = fully_parallel_enumerative_search(z, p_embed, proposer, p_encoder, score_fn, temperature= 0.7, debug= True, beam_width=3, MAX_PARAMS=2, topk=4, max_prog_len= 5, lambdas= (0.1, 0.89, 0.01), device= device)
#     lst_res.append(res)

=== Search Init ===
SOS_IDX: -1, EOS_IDX: 13
Num primitives: 13
Batch size: 32

=== Depth 1 ===
Total alive programs: 32
Logits shape: torch.Size([32, 14])
Logits range: [-0.16, 0.09]

Parent program: [(-1, [-1, -1])]
Top-4 primitives: [12, 11, 13, 4]
Top-4 logprobs: [-2.5233490467071533, -2.527421712875366, -2.5916450023651123, -2.6133015155792236]
  -> Primitive 12 (OtherAgentDirection) with arity 1
  -> EOS selected

Batch 0: 8 candidates
  Candidate 0: [(-1, [-1, -1]), (12, [3.0])]
  Candidate 1: [(-1, [-1, -1]), (12, [6.0])]
  Candidate 2: [(-1, [-1, -1]), (12, [0.0])]
  New best for batch 0: score=-0.981, prog=[(-1, [-1, -1]), (12, [3.0])]...

Batch 1: 7 candidates
  Candidate 0: [(-1, [-1, -1]), (12, [3.0])]
  Candidate 1: [(-1, [-1, -1]), (12, [4.0])]
  Candidate 2: [(-1, [-1, -1]), (12, [6.0])]
  New best for batch 1: score=-1.030, prog=[(-1, [-1, -1]), (12, [6.0])]...

Batch 2: 8 candidates
  Candidate 0: [(-1, [-1, -1]), (12, [3.0])]
  Candidate 1: [(-1, [-1, -1]), (12, [0.0

In [None]:
# lst_res[1]

[(OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,),
  -0.9761027997732163),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(0.0,),
  -1.0502350401878358),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,),
  -0.9906164830923081),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(4.0,),
  -1.1198440313339235),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,), -0.785519821047783),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,),
  -0.9584705013036727),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(4.0,),
  -1.0078499525785447),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,), -1.023550968170166),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,),
  -1.0407267010211945),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(4.0,),
  -0.9610244685411453),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(4.0,), -1.066969296336174),
 (OtherAgentDirection(-1, -1) | OtherAgentDirection(3.0,),
  -0.9588128596544266),
 (OtherAgentD

In [None]:
#| hide
# Export this notebook's `#| export` cells through nbdev.
import nbdev; nbdev.nbdev_export()