# Program Encoder

> Get a continuous representation of a program.

In [None]:
#| default_exp models.program_encoder

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

## Program Encoder

In [None]:
#|export
import torch
import torch.nn as nn
import torch.nn.functional as F
from MAWM.models.dense import DenseModel
from MAWM.models.program_embedder import ProgramEmbedder
from MAWM.core import PRIMITIVE_TEMPLATES

class ProgramEncoder(nn.Module):
    """Encode a batch of programs into fixed-size continuous vectors.

    Tokens are embedded with `ProgramEmbedder`, the per-token embeddings are
    flattened into one vector per program, projected to
    `model_info['node_size']` by a linear layer, and passed through a
    `DenseModel` that outputs `output_dim` features.
    """

    def __init__(
        self,
        num_primitives,
        param_cardinalities,   # list: for each slot, how many discrete values possible
        max_params_per_primitive,
        seq_len=4,  # number of tokens per program expected by `forward`
        d_name=32,
        d_param=32,
        output_dim=256,
        model_info=None,  # DenseModel config; None selects the default below
    ):
        super().__init__()

        # Build the default per instance instead of sharing one mutable dict
        # across every ProgramEncoder created without an explicit config.
        if model_info is None:
            model_info = {'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None}

        self.program_embedder = ProgramEmbedder(
            num_primitives=num_primitives,
            param_cardinalities=param_cardinalities,
            max_params_per_primitive=max_params_per_primitive,
            d_name=d_name,
            d_param=d_param,
        )

        # NOTE(review): the per-token width is hard-coded as one name embedding
        # plus two parameter embeddings. That presumably matches
        # max_params_per_primitive == 2 only -- confirm against ProgramEmbedder
        # before using a different number of parameter slots.
        token_dim = d_name + d_param + d_param
        self.fuse = nn.Linear(seq_len * token_dim, model_info['node_size'])
        self.program_mlp = DenseModel(
            output_shape=(output_dim,),
            input_size=model_info['node_size'],
            info=model_info,
        )

    def forward(self, primitive_ids, param_ids):
        """
        primitive_ids: LongTensor of shape [B, L]
        param_ids: LongTensor of shape [B, L, max_params] with -1 for missing parameters

        L must equal the `seq_len` given at construction, because the fuse
        layer's input width is fixed to `seq_len` tokens.

        Returns a tensor of shape [B, output_dim].
        """
        combined_B_L_D = self.program_embedder(primitive_ids, param_ids)  # shape: [B, L, D]
        B, L, D = combined_B_L_D.shape
        # reshape (rather than view) also accepts non-contiguous embedder output.
        combined_B_LD = combined_B_L_D.reshape(B, L * D)  # Flatten to [B, L*D]
        fused_B_H = self.fuse(combined_B_LD)  # shape: [B, node_size]
        primitive_vec = self.program_mlp(fused_B_H)  # shape: [B, output_dim]

        return primitive_vec

In [None]:
#| hide
# Example encoder: 3-token programs, two parameter slots per primitive,
# each slot taking one of 7 discrete values.
pe = ProgramEncoder(
    num_primitives=len(PRIMITIVE_TEMPLATES),
    param_cardinalities=[7, 7],
    max_params_per_primitive=2,
    seq_len=3,
)

In [None]:
#| hide
# Show the module tree (embedder, fuse layer, and MLP) of the example encoder.
pe

ProgramEncoder(
  (program_embedder): ProgramEmbedder(
    (name_embed): Embedding(6, 32)
    (param_embeds): ModuleList(
      (0-1): 2 x Embedding(7, 32)
    )
  )
  (fuse): Linear(in_features=288, out_features=128, bias=True)
  (program_mlp): DenseModel(
    (model): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=256, bias=True)
    )
  )
)

In [None]:
#| hide
from MAWM.models.program_embedder import batchify_programs, get_indices
from MAWM.core import Program, PRIMITIVE_TEMPLATES
# Two example programs of different lengths; each token is a
# (primitive id, parameter list) pair.
batch_programs = [
    Program(tokens=[(0, [1.0, 2.0]), (4, []), (1, [3.0, 4.0])]),  # L=3
    Program(tokens=[(2, [5.0]), (3, [6.0, 5.0])]),  # L=2
]

# Convert the programs into primitive-id and parameter-id tensors.
input_prim, input_params = batchify_programs(batch_programs)


In [None]:
# Inspect the batched tensors; missing tokens and parameters show up as -1.
input_prim, input_params

(tensor([[ 0,  4,  1],
         [ 2,  3, -1]]),
 tensor([[[ 1,  2],
          [-1, -1],
          [ 3,  4]],
 
         [[ 5, -1],
          [ 6,  5],
          [-1, -1]]]))

In [None]:
# Run the example batch through the encoder and check the output shape
# (one output_dim-sized vector per program).
pe(input_prim, input_params).shape

torch.Size([2, 256])

In [None]:
#|export
import torch
import torch.nn as nn
import torch.nn.functional as F
from MAWM.models.dense import DenseModel
from MAWM.models.program_embedder import ProgramEmbedder
from MAWM.core import PRIMITIVE_TEMPLATES

class ProgramPredictor(nn.Module):
    """Predict an `output_dim`-sized vector from an input feature vector.

    Thin wrapper around a single `DenseModel` whose input width is
    `model_info['node_size']`.
    """

    def __init__(
        self,
        output_dim=32,
        model_info=None,  # DenseModel config; None selects the default below
    ):
        super().__init__()

        # Build the default per instance instead of sharing one mutable dict
        # across every ProgramPredictor created without an explicit config.
        if model_info is None:
            model_info = {'layers': 3, 'node_size': 256, 'activation': nn.ReLU, 'dist': None}

        self.predictor = DenseModel(
            output_shape=(output_dim,),
            input_size=model_info['node_size'],
            info=model_info,
        )

    def forward(self, x):
        """
        x: input features; presumably a tensor of shape [B, node_size], since
           the DenseModel is built with input_size=model_info['node_size']
           (TODO confirm against DenseModel).

        Returns z_hat, the predictor output of shape [B, output_dim].
        """
        z_hat = self.predictor(x)  # shape: [B, output_dim]

        return z_hat

In [None]:
#| hide
# Run nbdev's export step.
import nbdev

nbdev.nbdev_export()