# core

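> Core definitions shared across the project, including hyperparameters, the primitive templates used to build programs, and the `Program` token-sequence representation.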

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *



The training pipeline proceeds in three stages:

- Collect random trajectories from the environment.
- Train a VAE encoder on the collected trajectories.
- Train the world model and the communication module end to end, using the trained encoder.

In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Look up ``class_name`` on the module named ``module_name``.

    The module is imported with :func:`importlib.import_module` (which reuses
    an already-imported module), and the requested attribute is returned.
    Despite the name, any attribute can be fetched, not only classes.

    Raises:
        ModuleNotFoundError: if ``module_name`` cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from typing import List, Tuple
# Module-wide sizes and hyperparameters.
# NOTE(review): most meanings below are inferred from the names and the
# notebook outline (VAE encoder, world model, communication module) and are
# not established by code in this file — confirm against the model code.
Z_DIM = 32            # presumably the VAE latent (z) size — confirm
ACTION_DIM = 1
PROG_HID = 128        # presumably a hidden size for the program model — confirm
PRIM_EMB = 32         # presumably the primitive-embedding size — confirm
PARAM_EMB = 16        # presumably the parameter-embedding size — confirm
PROG_RNN_HID = 128
MSG_DIM = 32  # NOTE(review): a trailing "#23" here suggested 23 was used before; 32 is the active value
MAX_PARAMS = 2        # maximum params per primitive; everything is padded to this (max arity in PRIMITIVE_TEMPLATES is 2)
GRID_SIZE = 5
BEAM_WIDTH = 5        # presumably the beam size for program search — confirm
PROP_TOPK = 6
MAX_PROG_LEN = 5
LAMBDA_Z = 1.0        # presumably loss weights for the z / reward terms — confirm
LAMBDA_R = 1.0
LEARNING_RATE = 1e-4


In [None]:
#| export
# -------------------------
# Primitive templates: (name, arity)
# arity = number of numeric parameters the primitive takes (at most MAX_PARAMS)
# -------------------------
PRIMITIVE_TEMPLATES = [
    ("CellEmpty", 2),            # cell (i, j) is empty
    ("CellObstacle", 2),
    ("CellItem", 2),
    ("CellGoal", 2),
    ("CellAgent", 2),
    # disabled: ("AgentAt", 2)
    ("GoalAt", 2),
    # disabled: ("ObstacleAt", 2)
    ("ItemAt", 2),
    ("Near", 0),                 # boolean style (no params)
    ("SeeGoal", 0),
    ("CanMove", 1),              # direction (0..3)
    ("OtherAgentAt", 2),
    ("OtherAgentNear", 0),
    ("OtherAgentDirection", 1),
]

# Reverse lookup: primitive name -> its position in PRIMITIVE_TEMPLATES.
PRIM_NAME_TO_IDX = {
    tmpl_name: tmpl_idx
    for tmpl_idx, (tmpl_name, _arity) in enumerate(PRIMITIVE_TEMPLATES)
}
NUM_PRIMS = len(PRIMITIVE_TEMPLATES)
print(PRIM_NAME_TO_IDX)

{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [None]:
#| export
# -------------------------
# Program representation
# -------------------------
class Program:
    """An immutable-by-convention sequence of primitive tokens.

    Each token is a ``(prim_idx, params)`` pair: ``prim_idx`` indexes
    ``PRIMITIVE_TEMPLATES`` and ``params`` is that primitive's parameter list.
    Two sentinel indices are recognised:

    * ``-1`` -- begin-of-program marker, rendered as ``<BOP>``;
    * ``EOS_IDX`` (== ``len(PRIMITIVE_TEMPLATES)``) -- end-of-program. Passing
      it to :meth:`extend` marks the program finished instead of appending.

    :meth:`extend` never mutates ``self``; it returns a new ``Program`` (or
    ``self`` when the program is already finished).
    """

    def __init__(self, tokens: List[Tuple[int, List[float]]] = None, finished: bool = False):
        # tokens: list of (prim_idx, params_list); None (or an empty list)
        # becomes a fresh empty list.
        self.tokens = tokens or []
        # One past the last real primitive index, used as the end-of-program id.
        self.EOS_IDX = len(PRIMITIVE_TEMPLATES)
        self.finished = finished

    def extend(self, prim_idx, params):
        """Return a program with ``(prim_idx, params)`` appended.

        * If this program is already finished, ``self`` is returned unchanged.
        * If ``prim_idx == EOS_IDX``, a finished program with the same tokens
          is returned and ``params`` is ignored.
        * Otherwise a new, unfinished program with the token appended is
          returned.
        """
        if self.finished:
            return self  # don't extend finished programs
        if prim_idx == self.EOS_IDX:
            return Program(self.tokens, finished=True)
        return Program(self.tokens + [(prim_idx, params)], finished=False)

    def __len__(self):
        """Number of tokens, including a ``-1`` begin marker if present."""
        return len(self.tokens)

    def __repr__(self):
        """Render tokens as ``Name(params) | ...``.

        Sentinel tokens are rendered per token (``<BOP>`` for ``-1``,
        ``<EOS>`` for ``EOS_IDX``). Previously a ``-1`` marker followed by
        other tokens was looked up as ``PRIMITIVE_TEMPLATES[-1]`` and
        mislabelled as the last primitive.
        """
        if not self.tokens:
            return "<EMPTY>"

        rendered = []
        for pidx, params in self.tokens:
            if pidx == -1:
                rendered.append("<BOP>")
            elif pidx == self.EOS_IDX:
                rendered.append("<EOS>")
            else:
                name = PRIMITIVE_TEMPLATES[pidx][0]
                rendered.append(f"{name}{tuple(params)}")
        return " | ".join(rendered)

In [None]:
# Build a 3-token program from raw primitive indices. With the current
# PRIMITIVE_TEMPLATES, indices 0, 4, 1 are CellEmpty, CellAgent, CellObstacle.
# NOTE(review): CellAgent has arity 2 but is given no params here; Program
# does not validate arity, so this is accepted as-is.
P = Program(tokens=[(0, [1.0, 2.0]), (4, []), (1, [3.0, 4.0])])
P

CellEmpty(1.0, 2.0) | CellAgent() | CellObstacle(3.0, 4.0)

In [None]:
# Number of tokens in P (3 for the example program built above).
len(P)

3

In [None]:
# Append primitive index 2 (CellItem in the current templates). extend returns
# a new Program instead of mutating P, so the result is rebound to P.
P = P.extend(2, [5.0, 6.0])
P

CellEmpty(1.0, 2.0) | CellAgent() | CellObstacle(3.0, 4.0) | CellItem(5.0, 6.0)

In [None]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells
# (nbdev). The `#| hide` directive keeps this cell out of the rendered docs.
import nbdev; nbdev.nbdev_export()