# Embedding Module

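> Each program token is embedded as a vector of fixed dimension `D = d_name + max_params * d_param`: the primitive-name embedding concatenated with one embedding per parameter slot. With the defaults (`d_name = d_param = 32`, two parameter slots), each of these three parts takes up one-third of the vector.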

In [None]:
#| default_exp models.program.embedder

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

## Tokenization

In [None]:
#| hide
from mawm.core import Program, PRIMITIVE_TEMPLATES

p = Program(tokens= [(0, [1.0, 2.0]), (4, []), (1, [3.0, 4.0])])

{'AgentAt': 0, 'GoalAt': 1, 'ObstacleAt': 2, 'ItemAt': 3, 'Near': 4, 'CanMove': 5}


In [None]:
len(PRIMITIVE_TEMPLATES)

In [None]:
#| export
from mawm.core import Program, PRIMITIVE_TEMPLATES
import torch
import torch.nn as nn
import torch.nn.functional as F

MAX_PARAMS = 2

def get_indices(program, max_params=MAX_PARAMS, device="cpu", padding_vals=(-1, -1)):
    """Convert one program into primitive-id and parameter-id tensors.

    program: object whose `tokens` attribute is an iterable of
        `(prim_idx, params)` pairs.
    max_params: number of parameter slots per token. Extra parameters are
        dropped; missing ones are filled with `padding_vals[1]`.
    device: torch device for the returned tensors.
    padding_vals: `(prim_pad, param_pad)`. Only `param_pad` is used here
        (to fill missing parameter slots); `prim_pad` is accepted so the same
        pair can be passed through from `batchify_programs`.

    Returns `(prim_ids, param_tensor)` with shapes `(1, L)` and
    `(1, L, max_params)`, both `torch.long`. Parameter values are cast to
    long, so floats such as `2.0` become `2`. An empty program yields shapes
    `(1, 0)` and `(1, 0, max_params)`.
    """
    param_pad = padding_vals[1]
    prim_rows = []
    param_rows = []
    for prim_idx, params in program.tokens:
        prim_rows.append(int(prim_idx))
        row = list(params)[:max_params]
        # Right-pad short parameter lists so every token has max_params slots.
        row += [param_pad] * (max_params - len(row))
        param_rows.append(row)

    if not prim_rows:
        # torch.tensor([[]]) would lose the trailing max_params dimension,
        # so build correctly-shaped empty tensors explicitly.
        return (
            torch.empty((1, 0), dtype=torch.long, device=device),
            torch.empty((1, 0, max_params), dtype=torch.long, device=device),
        )

    prim_ids = torch.tensor([prim_rows], dtype=torch.long, device=device)       # (1, L)
    param_tensor = torch.tensor([param_rows], dtype=torch.long, device=device)  # (1, L, max_params)
    return prim_ids, param_tensor



# For the example batch below (program lengths 4 and 2):
# final_batch_prim_ids shape: (2, 4)
# final_batch_param_tensor shape: (2, 4, 2)

In [None]:
#| export
def batchify_programs(batch_programs, padding_vals=(-1, -1), max_params=MAX_PARAMS, device="cpu"):
    """Convert a list of programs into right-padded batch tensors.

    batch_programs: sequence of programs accepted by `get_indices`.
    padding_vals: `(prim_pad, param_pad)`. `prim_pad` fills primitive ids
        past a program's end. `param_pad` fills both missing parameter slots
        and parameter rows past a program's end. Both default to -1.
    max_params: parameter slots per token (forwarded to `get_indices`).
    device: torch device for the returned tensors.

    Returns `(batch_prim_ids, batch_param_tensor)` with shapes `(B, L_max)`
    and `(B, L_max, max_params)`, where `L_max` is the longest program length.
    An empty batch yields shapes `(0, 0)` and `(0, 0, max_params)`.
    """
    prim_pad, param_pad = padding_vals[0], padding_vals[1]

    per_program = [
        get_indices(program, max_params=max_params, device=device, padding_vals=padding_vals)
        for program in batch_programs
    ]
    if not per_program:
        # torch.stack([]) raises, so return correctly-shaped empty tensors.
        return (
            torch.empty((0, 0), dtype=torch.long, device=device),
            torch.empty((0, 0, max_params), dtype=torch.long, device=device),
        )

    # Drop the leading batch dim of size 1 added by get_indices.
    prim_seqs = [prim.squeeze(0) for prim, _ in per_program]      # each (L,)
    param_seqs = [params.squeeze(0) for _, params in per_program]  # each (L, max_params)
    max_len = max(seq.size(0) for seq in prim_seqs)

    # Right-pad every sequence along its length dimension to max_len.
    batch_prim_ids = torch.stack(
        [F.pad(seq, (0, max_len - seq.size(0)), value=prim_pad) for seq in prim_seqs],
        dim=0,
    )
    batch_param_tensor = torch.stack(
        [F.pad(seq, (0, 0, 0, max_len - seq.size(0)), value=param_pad) for seq in param_seqs],
        dim=0,
    )
    return batch_prim_ids, batch_param_tensor


In [None]:
# #| export
# def batchify_programs(batch_programs):
#     # --- Correct Batching Logic ---
#     all_prim_tensors = []
#     all_param_tensors = []
#     max_len = 0

#     # 1. Get individual tensors and find max_len
#     for program in batch_programs:
#         prim_ids, param_tensor = get_indices(program)
#         all_prim_tensors.append(prim_ids.squeeze(0))     # Remove the (1) batch dim: (L)
#         all_param_tensors.append(param_tensor.squeeze(0)) # Remove the (1) batch dim: (L, max_params)
#         max_len = max(max_len, prim_ids.size(1))

#     # 2. Pad and Stack
#     padded_prim_ids = []
#     padded_param_tensors = []
#     PAD_VALUE = -1 # Use -1 for padding as in your get_indices function

#     for prim_t, param_t in zip(all_prim_tensors, all_param_tensors):
#         L = prim_t.size(0)
        
#         # Pad Prim IDs: (L) -> (L_max)
#         pad_len = max_len - L
#         padded_prim_ids.append(F.pad(prim_t, (0, pad_len), value=PAD_VALUE))
        
#         # Pad Param Tensor: (L, max_p) -> (L_max, max_p)
#         # F.pad takes (padding_left, padding_right, padding_top, padding_bottom, ...)
#         padded_param_tensors.append(F.pad(param_t, (0, 0, 0, pad_len), value=PAD_VALUE))

#     # 3. Concatenate (Stack)
#     batch_prim_ids = torch.stack(padded_prim_ids, dim=0)    # (B, L_max)
#     batch_param_tensor = torch.stack(padded_param_tensors, dim=0) # (B, L_max, max_params)
#     return batch_prim_ids, batch_param_tensor


In [None]:
# Example batch: two programs of different lengths (4 and 2 tokens), used to
# exercise batchify_programs. Each token is (primitive_id, params); params may
# have fewer than MAX_PARAMS entries and are padded by get_indices.
batch_programs = [
    Program(tokens= [(0, [1.0, 2.0]), (4, []), (1, [3.0, 4.0]),(5, [1]) ]), # L=4
    Program(tokens= [(2, [5.0]), (3, [6.0, 5.0])]) # L=2
]


In [None]:
final_batch_prim_ids, final_batch_param_tensor = batchify_programs(batch_programs)

In [None]:
final_batch_param_tensor, final_batch_prim_ids

(tensor([[[ 1,  2],
          [-1, -1],
          [ 3,  4],
          [ 1, -1]],
 
         [[ 5, -1],
          [ 6,  5],
          [-1, -1],
          [-1, -1]]]),
 tensor([[ 0,  4,  1,  5],
         [ 2,  3, -1, -1]]))

In [None]:
final_batch_param_tensor.shape, final_batch_prim_ids.shape

(torch.Size([2, 4, 2]), torch.Size([2, 4]))

In [None]:
final_batch_prim_ids[0]

tensor([0, 4, 1, 5])

In [None]:
final_batch_param_tensor.reshape(2, -1)

tensor([[ 1,  2, -1, -1,  3,  4,  1, -1],
        [ 5, -1,  6,  5, -1, -1, -1, -1]])

In [None]:
final_batch_param_tensor.reshape(2, -1).shape

torch.Size([2, 8])

## Program Embedder

In [None]:
#| export
import torch
import torch.nn as nn

class ProgramEmbedder(nn.Module):
    "Embeds each token of a batch of programs into a fixed-size vector."
    def __init__(
        self,
        num_primitives,
        param_cardinalities,   # list: for each slot, how many discrete values possible
        max_params_per_primitive,
        d_name=32,
        d_param=32,
    ):
        super().__init__()

        self.num_primitives = num_primitives
        self.max_params = max_params_per_primitive
        # Learned vectors substituted at padded (negative-id) positions.
        self.empty_name = nn.Parameter(torch.zeros(d_name))
        self.empty_param = nn.Parameter(torch.zeros(d_param))

        # One row per primitive plus one extra row (id == num_primitives).
        # NOTE(review): the extra row is presumably reserved for an EOS token — confirm.
        self.name_embed = nn.Embedding(num_primitives + 1, d_name)
        # One independent embedding table per parameter slot.
        self.param_embeds = nn.ModuleList([
            nn.Embedding(card, d_param) for card in param_cardinalities
        ])

    @staticmethod
    def _embed_with_padding(embed, ids, empty):
        "Look up `ids` in `embed`, replacing positions where `ids < 0` with `empty`."
        pad_mask = (ids < 0).unsqueeze(-1)   # (..., 1), broadcasts over the feature dim
        # Clamp so padded ids are valid indices; their vectors are discarded below.
        vecs = embed(ids.clamp(min=0))
        # Out-of-place select keeps autograd simple; gradients reach `empty`
        # at padded positions and the embedding table elsewhere.
        return torch.where(pad_mask, empty, vecs)

    def forward(self, prim_ids, params_ids):
        """
        prim_ids: (B, L) LongTensor of primitive IDs; negative values mark padding
        params_ids: (B, L, max_params_per_primitive) LongTensor of parameter IDs;
            negative values mark padding
        returns: (B, L, d_name + len(param_cardinalities) * d_param) Tensor. Each
            token is its name embedding followed by one embedding per parameter slot.

        Works for empty programs (L == 0) and for zero parameter slots.
        """
        name_vecs = self._embed_with_padding(self.name_embed, prim_ids, self.empty_name)  # (B, L, d_name)
        slot_vecs = [
            self._embed_with_padding(embed, params_ids[:, :, slot], self.empty_param)     # (B, L, d_param)
            for slot, embed in enumerate(self.param_embeds)
        ]
        # Concatenating along the feature dim gives the same layout as
        # stacking the slots and flattening them. Unlike view(B, L, -1), it
        # also works when L == 0 or there are no slots.
        return torch.cat([name_vecs, *slot_vecs], dim=-1)
    

In [None]:
#| hide
# Smoke test: embed the example batch. With the default d_name = d_param = 32
# and two parameter slots, each token maps to 32 + 2 * 32 = 96 features.
emb = ProgramEmbedder(
    num_primitives= len(PRIMITIVE_TEMPLATES),
    param_cardinalities= [7, 7],
    max_params_per_primitive= 2,
)
emb(final_batch_prim_ids, final_batch_param_tensor).shape

torch.Size([2, 4, 96])

In [None]:
a = final_batch_prim_ids
mask = (a == -1)
mask.shape

torch.Size([2, 4])

In [None]:
a[mask]

tensor([-1, -1])

In [None]:
emb_layer = nn.Embedding(6, 32)

prim_vec = emb_layer(a.clamp(min=0))
prim_vec.shape

torch.Size([2, 4, 32])

In [None]:
prim_vec

tensor([[[ 1.5241,  0.6214, -0.2335,  0.8660, -2.6045, -0.3754, -0.1755,
           0.6233,  0.0270, -0.8318, -0.6073,  0.0662,  2.3057,  0.1164,
          -0.2514, -0.3850, -0.4884, -2.8957, -1.4303, -0.8880, -0.1274,
          -0.9298,  0.6627,  0.0627,  1.1778, -1.2716, -1.0698,  1.3303,
          -1.4972, -0.7821, -0.6830,  0.5980],
         [-0.9660,  0.6694, -0.2911, -1.2771,  0.3612, -1.9370,  0.7846,
           0.1405, -0.3342,  0.1671,  1.1493,  0.4928,  0.0317, -1.3324,
           1.0530, -0.6200,  0.6001,  0.5116, -0.2895,  0.1138,  0.5460,
           0.5285, -0.8443,  1.0123, -0.0393, -0.2357,  0.0041,  0.7768,
          -0.4302,  0.5859,  0.4848, -0.0973],
         [-2.2233,  1.4776,  0.9002,  1.4389,  1.2548,  0.4807, -0.5116,
          -0.9130,  0.6739, -0.4171,  0.3680,  1.0555, -0.1830, -0.1053,
           1.0511,  0.9505, -0.8490, -0.2548, -1.8161,  0.5618,  0.6206,
          -0.3532, -1.2162, -0.0491,  1.1764, -0.4762, -1.3392, -0.3072,
          -0.0452, -1.1065, -0

In [None]:
prim_vec[mask] = nn.Parameter(torch.zeros(32))

In [None]:
prim_vec

tensor([[[ 1.5241,  0.6214, -0.2335,  0.8660, -2.6045, -0.3754, -0.1755,
           0.6233,  0.0270, -0.8318, -0.6073,  0.0662,  2.3057,  0.1164,
          -0.2514, -0.3850, -0.4884, -2.8957, -1.4303, -0.8880, -0.1274,
          -0.9298,  0.6627,  0.0627,  1.1778, -1.2716, -1.0698,  1.3303,
          -1.4972, -0.7821, -0.6830,  0.5980],
         [-0.9660,  0.6694, -0.2911, -1.2771,  0.3612, -1.9370,  0.7846,
           0.1405, -0.3342,  0.1671,  1.1493,  0.4928,  0.0317, -1.3324,
           1.0530, -0.6200,  0.6001,  0.5116, -0.2895,  0.1138,  0.5460,
           0.5285, -0.8443,  1.0123, -0.0393, -0.2357,  0.0041,  0.7768,
          -0.4302,  0.5859,  0.4848, -0.0973],
         [-2.2233,  1.4776,  0.9002,  1.4389,  1.2548,  0.4807, -0.5116,
          -0.9130,  0.6739, -0.4171,  0.3680,  1.0555, -0.1830, -0.1053,
           1.0511,  0.9505, -0.8490, -0.2548, -1.8161,  0.5618,  0.6206,
          -0.3532, -1.2162, -0.0491,  1.1764, -0.4762, -1.3392, -0.3072,
          -0.0452, -1.1065, -0

In [None]:
layers_embeds = nn.ModuleList([
            nn.Embedding(card, 32) for card in [7, 7]
])

In [None]:
for slot, embed in enumerate(layers_embeds):
    print(slot, embed)

0 Embedding(7, 32)
1 Embedding(7, 32)


In [None]:
final_batch_param_tensor

tensor([[[ 1,  2],
         [-1, -1],
         [ 3,  4],
         [ 1, -1]],

        [[ 5, -1],
         [ 6,  5],
         [-1, -1],
         [-1, -1]]])

In [None]:
b = final_batch_param_tensor
b.shape

torch.Size([2, 4, 2])

In [None]:
final_batch_param_tensor[:, 2].clamp(0)

tensor([[3, 4],
        [0, 0]])

In [None]:
# Manual walk-through of the parameter branch of ProgramEmbedder.forward:
# embed each parameter slot separately and overwrite padded (-1) positions.
params_emb = []
for slot, embed in enumerate(layers_embeds):
    print(slot)
    param = final_batch_param_tensor[:, :, slot]  # (B, L) ids for this slot
    mask = (param == -1)  # True where the slot is padding
    inp = param.clamp(0)  # -1 -> 0 so the lookup is in range; masked rows are overwritten below
    print(inp)
    em = embed(inp)
    em[mask] = nn.Parameter(torch.zeros(32))
    params_emb.append(em)

0
tensor([[1, 0, 3, 1],
        [5, 6, 0, 0]])
1
tensor([[2, 0, 4, 0],
        [0, 5, 0, 0]])


In [None]:
params_emb[0].shape, params_emb[1].shape

(torch.Size([2, 4, 32]), torch.Size([2, 4, 32]))

In [None]:
torch.stack(params_emb, dim=2).shape

torch.Size([2, 4, 2, 32])

In [None]:
params_emb_stacked = torch.stack(params_emb, dim=2)

In [None]:
prim_vec.shape

torch.Size([2, 4, 32])

In [None]:
params_emb_stacked[:, 0, :, :].shape

torch.Size([2, 2, 32])

In [None]:
params_emb_stacked[:, 0, :, :].view(params_emb_stacked.shape[0], -1).shape

torch.Size([2, 64])

In [None]:
prim_vec[0, :].shape

torch.Size([4, 32])

In [None]:
prim_vec[:, 0, :].shape

torch.Size([2, 32])

In [None]:
prim_vec.shape

torch.Size([2, 4, 32])

In [None]:
# xy_embed_p0 = params_emb_stacked[:, 0, :, :] # (B, num_params, d)
# prim_embed_p0 = prim_vec[:, 0, :]               # (B, d)
torch.stack([torch.cat([prim_vec[:, i, :], params_emb_stacked[:, i, :, :].view(params_emb_stacked.shape[0], -1)], dim=-1) for i in range(prim_vec.shape[1])]).shape

torch.Size([4, 2, 96])

In [None]:
out = torch.cat([prim_vec[:, 0, :], params_emb_stacked[:, 0, :, :].view(params_emb_stacked.shape[0], -1)], dim=-1)
out.shape

torch.Size([2, 96])

In [None]:
len(params_emb)

2

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()