# core

> Fill in a module description here

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *



- Collect random trajectories from the environment.
- Train an encoder using VAE on the collected trajectories.
- End-to-end training of the world model and communication module using the trained encoder.

In [None]:
#| export
import importlib
def get_cls(module_name, class_name):
    """Import `module_name` and return the attribute named `class_name` from it.

    Import failures and missing attributes propagate as raised by
    `importlib.import_module` and `getattr` respectively.
    """
    return getattr(importlib.import_module(module_name), class_name)

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from typing import List, Tuple
# Hyperparameter-style module constants. Nothing in the visible part of this
# file reads them, so their exact roles are inferred from the names only --
# TODO confirm against the code that consumes them.
Z_DIM = 32
ACTION_DIM = 1
PROG_HID = 128
PRIM_EMB = 32
PARAM_EMB = 16
PROG_RNN_HID = 128
MSG_DIM = 32#23
MAX_PARAMS = 2        # maximum params per primitive; everything is padded to this
# NOTE(review): the design notes further down describe a 7*7 local grid; check
# whether GRID_SIZE = 5 is still the intended value.
GRID_SIZE = 5
BEAM_WIDTH = 5
PROP_TOPK = 6
MAX_PROG_LEN = 5
LAMBDA_Z = 1.0
LAMBDA_R = 1.0
LEARNING_RATE = 1e-4


In [None]:
#| export
# -------------------------
# Primitive templates (name, arity)
# -------------------------
# Each entry is (primitive name, arity). A primitive's index in this list is
# its token id, so the order matters.
PRIMITIVE_TEMPLATES = [
    ("CellEmpty", 2),  # cell (i, j) is empty
    ("CellObstacle", 2),
    ("CellItem", 2),
    ("CellGoal", 2),
    ("CellAgent", 2),
    # ("AgentAt", 2),      (currently disabled)
    ("GoalAt", 2),
    # ("ObstacleAt", 2),   (currently disabled)
    ("ItemAt", 2),
    ("Near", 0),  # boolean style (no params)
    ("SeeGoal", 0),
    ("CanMove", 1),  # direction (0..3)
    ("OtherAgentAt", 2),
    ("OtherAgentNear", 0),
    ("OtherAgentDirection", 1),
]
NUM_PRIMS = len(PRIMITIVE_TEMPLATES)
# Reverse lookup: primitive name -> index into PRIMITIVE_TEMPLATES.
PRIM_NAME_TO_IDX = {
    template[0]: idx for idx, template in enumerate(PRIMITIVE_TEMPLATES)
}
# print(PRIM_NAME_TO_IDX)

{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [None]:
#| export
# -------------------------
# Program representation
# -------------------------
class Program:
    """A sequence of primitive tokens plus a "finished" flag.

    Each token is a ``(prim_idx, params)`` pair where ``prim_idx`` indexes
    ``PRIMITIVE_TEMPLATES``. Two indices are special:

    * ``EOS_IDX`` (``len(PRIMITIVE_TEMPLATES)``): passing it to ``extend``
      marks the program finished instead of appending a token.
    * ``-1``: rendered by ``__repr__`` as the ``<BOP>`` marker.

    ``extend`` never mutates the receiver; it returns a new ``Program``
    (or ``self`` when the program is already finished).
    """

    # Index that __repr__ renders as the begin-of-program marker.
    _BOP_IDX = -1

    def __init__(self, tokens: List[Tuple[int, List[float]]] = None, finished: bool = False):
        """
        Args:
            tokens: list of ``(prim_idx, params_list)`` pairs; ``None`` (or any
                falsy value) yields an empty program.
            finished: whether the program has already been terminated.
        """
        self.tokens = tokens or []
        self.EOS_IDX = len(PRIMITIVE_TEMPLATES)
        self.finished = finished

    def extend(self, prim_idx, params):
        """Return the program obtained by appending ``(prim_idx, params)``.

        A finished program is returned unchanged. ``prim_idx == EOS_IDX``
        yields a finished program with the same tokens (``params`` is ignored).
        """
        if self.finished:
            return self  # don't extend finished programs
        if prim_idx == self.EOS_IDX:
            return Program(self.tokens, finished=True)
        return Program(self.tokens + [(prim_idx, params)], finished=False)

    def __len__(self):
        """Number of tokens in the program."""
        return len(self.tokens)

    def __repr__(self):
        """Render the tokens as ``Name(params) | Name(params) | ...``.

        An empty program renders as ``<EMPTY>``. Marker and out-of-range
        indices are rendered symbolically instead of being looked up in
        ``PRIMITIVE_TEMPLATES``: a plain lookup would make ``-1`` wrap around
        to the last primitive and would raise ``IndexError`` for ``EOS_IDX``.
        """
        if len(self.tokens) == 0:
            return "<EMPTY>"

        rendered = []
        for pidx, params in self.tokens:
            if pidx == self._BOP_IDX:
                rendered.append("<BOP>")
            elif pidx == self.EOS_IDX:
                rendered.append("<EOS>")
            elif 0 <= pidx < len(PRIMITIVE_TEMPLATES):
                name = PRIMITIVE_TEMPLATES[pidx][0]
                rendered.append(f"{name}{tuple(params)}")
            else:
                # Unknown index: keep repr usable for debugging rather than raising.
                rendered.append(f"<UNK:{pidx}>")
        return " | ".join(rendered)

In [None]:
# Build a three-token program by hand and display its repr.
# NOTE(review): the recorded output below (AgentAt / Near / GoalAt) does not
# match the current PRIMITIVE_TEMPLATES table, where indices 0, 4 and 1 are
# CellEmpty, CellAgent and CellObstacle -- it looks stale; re-run to refresh.
# Also, index 4 (CellAgent) is declared with arity 2 but is given no params here.
P = Program(tokens=[(0, [1.0, 2.0]), (4, []), (1, [3.0, 4.0])])
P

AgentAt(1.0, 2.0) | Near() | GoalAt(3.0, 4.0)

In [None]:
len(P)

3

In [None]:
# extend() returns a new Program with the extra token appended, so rebind P.
# NOTE(review): index 2 is CellItem in the current PRIMITIVE_TEMPLATES table,
# so the recorded "ObstacleAt" output below looks stale.
P = P.extend(2, [5.0, 6.0])
P

AgentAt(1.0, 2.0) | Near() | GoalAt(3.0, 4.0) | ObstacleAt(5.0, 6.0)

In [None]:
#| export
import argparse
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, Tuple, Union, cast, List

from omegaconf import OmegaConf

In [None]:
#| export
DataClass = Any
DataClassType = Any


class DataclassArgParser(argparse.ArgumentParser):
    """An ``argparse.ArgumentParser`` whose options are generated from dataclass fields.

    Closely based on Hugging Face's HfArgumentParser class,
    extended to support recursive (nested) dataclasses.
    """

    def __init__(
        self,
        dataclass_types: Union[DataClassType, Iterable[DataClassType]],
        **kwargs,
    ):
        """
        Args:
            dataclass_types:
                Dataclass type, or iterable of dataclass types for which we
                will "fill" instances with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular
                way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [cast(DataClassType, dataclass_types)]
        # Materialise the iterable: it is walked here and again in
        # `parse_args_into_dataclasses`, which would silently yield nothing
        # the second time for a generator.
        self.dataclass_types = list(cast(Iterable[DataClassType], dataclass_types))
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _unwrap_optional(field_type):
        """Return ``T`` for ``Optional[T]`` with ``T`` in (int, float, str).

        Any other type is returned unchanged. The check is structural
        (``typing.get_origin`` / ``get_args``, Python 3.8+) rather than based
        on ``str(field_type)``, whose spelling differs between Python versions.
        """
        # Function-scope imports keep this class independent of the
        # file-level import list.
        import types
        from typing import get_args, get_origin

        origin = get_origin(field_type)
        pep604_union = getattr(types, "UnionType", None)  # `int | None`, 3.10+
        is_union = origin is Union or (
            pep604_union is not None and origin is pep604_union
        )
        if not is_union:
            return field_type
        args = get_args(field_type)
        if len(args) == 2 and type(None) in args:
            (inner,) = [a for a in args if a is not type(None)]
            if inner in (int, float, str):
                return inner
        return field_type

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """Register one ``--<field>`` option per field of `dtype`.

        Fields whose type is itself a dataclass are expanded recursively, so
        nested fields become top-level options of this parser.
        """
        for f in dataclasses.fields(dtype):
            field_name = f"--{f.name}"
            # Field metadata is forwarded verbatim as `add_argument` kwargs.
            kwargs = dict(f.metadata).copy()
            # Work on a local: the field object is shared by every user of the
            # dataclass and must not be rewritten.
            field_type = self._unwrap_optional(f.type)
            if isinstance(field_type, type) and issubclass(field_type, Enum):
                kwargs["choices"] = list(field_type)
                kwargs["type"] = field_type
                if f.default is not dataclasses.MISSING:
                    kwargs["default"] = f.default
            elif field_type is bool:
                # Booleans become flags; a True default is switched off with
                # `--no-<name>` while still storing into `<name>`.
                kwargs["action"] = "store_false" if f.default is True else "store_true"
                if f.default is True:
                    field_name = f"--no-{f.name}"
                    kwargs["dest"] = f.name
            elif dataclasses.is_dataclass(field_type):
                self._add_dataclass_arguments(field_type)
                # The placeholder option added below for the nested field
                # itself keeps its name present on the parsed namespace, which
                # `_populate_dataclass` relies on when it deletes consumed keys.
            else:
                kwargs["type"] = field_type
                if f.default is not dataclasses.MISSING:
                    kwargs["default"] = f.default
                elif f.default_factory is not dataclasses.MISSING:
                    # A factory-provided default is still a default; do not
                    # force the option to be given on the command line.
                    kwargs["default"] = f.default_factory()
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)

    def parse_args_into_dataclasses(
        self,
        args=None,
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass
        types.  This relies on argparse's `ArgumentParser.parse_known_args`.
        See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
        Args:
            args:
                List of strings to parse. The default is taken from sys.argv.
                (same as argparse.ArgumentParser)
        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they
                  were passed to the initializer.
                - if applicable, an additional namespace for more
                  (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings.
                  (same as argparse.ArgumentParser.parse_known_args)
        """
        namespace, unknown = self.parse_known_args(args=args)
        outputs = []

        for dtype in self.dataclass_types:
            outputs.append(self._populate_dataclass(dtype, namespace))
        if len(namespace.__dict__) > 0:
            # additional namespace.
            outputs.append(namespace)
        if len(unknown) > 0:
            outputs.append(unknown)
        return tuple(outputs)

    @staticmethod
    def _populate_dataclass(dtype: DataClassType, namespace: argparse.Namespace):
        """Build a `dtype` instance from `namespace`, consuming the keys it uses.

        Every field name of `dtype` is removed from `namespace`; nested
        dataclass fields are then built recursively from what remains.
        """
        keys = {f.name for f in dataclasses.fields(dtype)}
        inputs = {k: v for k, v in vars(namespace).items() if k in keys}
        for k in keys:
            delattr(namespace, k)
        sub_dataclasses = {
            f.name: f.type
            for f in dataclasses.fields(dtype)
            if dataclasses.is_dataclass(f.type)
        }
        for k, s in sub_dataclasses.items():
            inputs[k] = DataclassArgParser._populate_dataclass(s, namespace)
        return dtype(**inputs)

    @staticmethod
    def _populate_dataclass_from_dict(dtype: DataClassType, d: dict):
        """Build a `dtype` instance from a nested dict.

        Keys that are not fields of `dtype` are ignored. A nested dataclass
        field is built from the sub-dict stored under its name, or from an
        empty dict when that key is absent. `d` itself is not modified.
        """
        d = DataclassArgParser.legacy_transform_dict(d.copy())
        keys = {f.name for f in dataclasses.fields(dtype)}
        inputs = {k: v for k, v in d.items() if k in keys}
        sub_dataclasses = {
            f.name: f.type
            for f in dataclasses.fields(dtype)
            if dataclasses.is_dataclass(f.type)
        }
        for k, s in sub_dataclasses.items():
            v = inputs[k] if k in inputs else {}
            inputs[k] = DataclassArgParser._populate_dataclass_from_dict(s, v)
        return dtype(**inputs)

    @staticmethod
    def _populate_dataclass_from_flat_dict(dtype: DataClassType, d: dict):
        """Build a `dtype` instance from a flat dict.

        Fields of `dtype` are taken from `d`; each nested dataclass field is
        then built from the remaining keys of the same dict.
        NOTE(review): the nested level is populated with
        `_populate_dataclass_from_dict`, so dataclasses nested two levels deep
        are read as sub-dicts rather than flat keys -- confirm this is intended.
        """
        d = DataclassArgParser.legacy_transform_dict(d.copy())
        keys = {f.name for f in dataclasses.fields(dtype)}
        inputs = {k: v for k, v in d.items() if k in keys}
        # Drop the keys consumed at this level before handing `d` down.
        for k in keys:
            if k in d:
                del d[k]
        sub_dataclasses = {
            f.name: f.type
            for f in dataclasses.fields(dtype)
            if dataclasses.is_dataclass(f.type)
        }
        for k, s in sub_dataclasses.items():
            inputs[k] = DataclassArgParser._populate_dataclass_from_dict(s, d)
        return dtype(**inputs)

    @staticmethod
    def legacy_transform_dict(d: dict):
        """Return a copy of `d` with the `*_config` keys renamed.

        `training_config`, `model_config` and `cost_config` become `training`,
        `model` and `cost`; every other key is copied unchanged.
        """
        key_mapping = {
            "training_config": "training",
            "model_config": "model",
            "cost_config": "cost",
        }
        return {key_mapping.get(k, k): v for k, v in d.items()}


In [None]:
#| export
def omegaconf_parse(cls):
    """Build a `cls` config from the `--configs` and `--values` CLI options.

    Both options take zero or more values and default to an empty list;
    any other command-line argument is ignored. The collected lists are
    handed to `omegaconf_parse_files_vals`.
    """
    cli = argparse.ArgumentParser(fromfile_prefix_chars="@")
    option_help = (
        ("--configs", "Configs to load"),
        ("--values", "Dot values to change configs"),
    )
    for flag, description in option_help:
        cli.add_argument(flag, nargs="*", default=[], help=description)
    known, _ignored = cli.parse_known_args()

    return omegaconf_parse_files_vals(cls, known.configs, known.values)


def omegaconf_parse_files_vals(cls, files_paths: List[str], dotlist: List[str]):
    """Merge config layers and parse the result into a `cls` instance.

    Layers are merged in this order: the structured config of `cls`, each
    file in `files_paths` (in the order given), then the `dotlist` entries.
    The merged config is converted to a plain container and passed to
    `cls.parse_from_dict`.
    """
    layers = [
        OmegaConf.structured(cls),
        *(OmegaConf.load(file_path) for file_path in files_paths),
        OmegaConf.from_dotlist(dotlist),
    ]
    merged = OmegaConf.merge(*layers)
    return cls.parse_from_dict(OmegaConf.to_container(merged))


def combine_cli_dict(cls, c_dict):
    """Parse a `cls` config from the command line, then apply `c_dict` on top.

    The command-line config comes from `cls.parse_from_command_line()`;
    the merge itself is done by `combine_dataclass_dict`.
    """
    return combine_dataclass_dict(cls.parse_from_command_line(), c_dict)


def combine_dataclass_dict(dcls, c_dict):
    """Return a new config: the values of instance `dcls` updated with `c_dict`.

    `dcls` is converted with `dataclasses.asdict`, each `c_dict` entry is
    applied through `OmegaConf.update`, and the result is re-parsed with
    `dcls.parse_from_dict`.
    """
    merged = OmegaConf.create(dataclasses.asdict(dcls))
    for key in c_dict:
        OmegaConf.update(merged, key, c_dict[key])
    as_plain = OmegaConf.to_container(merged)
    return dcls.parse_from_dict(as_plain)

In [None]:
#| export
import argparse
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, Tuple, Union, cast, List

from omegaconf import OmegaConf

DataClass = Any
DataClassType = Any


@dataclass
class ConfigBase:
    """Common base for config dataclasses.

    Provides alternate constructors (command line, file, nested dict,
    flat dict) and a `save` method.
    """

    @classmethod
    def parse_from_command_line(cls):
        """Build an instance from the command line via `omegaconf_parse`."""
        return omegaconf_parse(cls)

    @classmethod
    def parse_from_file(cls, path: str):
        """Load `path` with OmegaConf and parse the resulting container."""
        loaded = OmegaConf.load(path)
        as_plain = OmegaConf.to_container(loaded)
        return cls.parse_from_dict(as_plain)

    @classmethod
    def parse_from_command_line_deprecated(cls):
        """Build an instance with `DataclassArgParser` (argparse-based path).

        Raises:
            RuntimeError: if the parser returns anything besides the single
                dataclass instance.
        """
        parser = DataclassArgParser(cls, fromfile_prefix_chars="@")
        parsed = parser.parse_args_into_dataclasses()
        if len(parsed) > 1:
            raise RuntimeError(
                f"The following arguments were not recognized: {parsed[1:]}"
            )
        return parsed[0]

    @classmethod
    def parse_from_dict(cls, inputs):
        """Build an instance from a nested dict; `inputs` is copied first."""
        snapshot = inputs.copy()
        return DataclassArgParser._populate_dataclass_from_dict(cls, snapshot)

    @classmethod
    def parse_from_flat_dict(cls, inputs):
        """Build an instance from a flat dict; `inputs` is copied first."""
        snapshot = inputs.copy()
        return DataclassArgParser._populate_dataclass_from_flat_dict(cls, snapshot)

    def save(self, path: str):
        """Write this config to `path` using `OmegaConf.save`."""
        with open(path, "w") as out_file:
            OmegaConf.save(config=self, f=out_file)




Let me reiterate the exact setting. I have two agents, each with partial local observation as an RGB image, representing the 7*7 grid around it. We derive a message from this observation by parsing the grid information, resulting in a one-hot encoded message of shape (5*7*7), where 5 is the channel number representing the number of object types in the environment (wall, goal cell, ...). Each agent can send and receive messages. They both have an image encoder and a message encoder, which are trained in an SSL manner using LEJEPA (or vicreg). The training paradigm is centralized, but the execution has to be decentralized. The following is an example of the forward pass regarding the message:


C = self.msg_encoder(msg)

z_sender = self.model.backbone(obs_sender, position = pos_sender)

sigreg_img = self.sigreg(proj_z)

sigreg_msg = self.sigreg(proj_c)

sigreg_loss = 0.5 * (sigreg_img + sigreg_msg )

lejepa_loss = (1- self.lambda_) * inv_loss + self.lambda_ * sigreg_loss 


Now, to learn the dynamics model, we have two more models:

an image encoder (same as above) and a predictor, which predicts T-1 steps ahead in latent space:

Z0, Z = self.model(x= obs, pos= pos, actions= act, msgs= C, T= act.size(1)-1)

vicreg_loss = self.vicreg(Z0, Z, mask= mask)


The training pipeline optimizes:

lejepa_loss + vicreg_loss['total_loss']


Now I want to focus on how this model can be leveraged to plan with the discrete CEM 

(and, ideally, allow communication between the agents during planning)

OK, let's discard the generative modeling part for now. If we allow the agents to share only the very first discrete message, at the very first environment interaction, and also to share the actions taken during planning, can we not rely on the predictor to "recreate" an imagined series of messages?

The mechanism is as follows. In the training phase, along with training JEPA for good encoders and next-latent prediction, we train two additional predictors: one that predicts the latent of the observation from the message, and one that goes the other way around. At planning time, we use the first of these to get z from the message, so we can obtain an approximation of z_0 without access to the other agent's observation.

Now, what happens in planning is as follows:

t = 0
Agent j sends its message (the grid encoded as 7*7, so the agent sends only 49 integers as a 2D array).

Agent i encodes the message using the message encoder; for this, agent i applies one-hot encoding to obtain the msg of shape 5*7*7.

Agent i uses the encoded message of agent j, h, to condition the predictor (along with its own action) and get z_1^i: z_1^i = f(z_0^i, a_0^i, h_0^j)

finally, get the z_0 of agent j by using the observation predictor (at agent i) as:
z_0^j = obsPred(h^j)

Agent j has ALREADY communicated its action, so the action at t=0 is available; agent i then uses its dynamics model to roll agent j's latent forward and get z_1^j:
z_1^j = f(z_0^j, a_0^j, msgEncoder(msg_1^i))

t = 1

get the encoded message of agent j by using the message predictor (which already learned to predict the latent of the message from the latent of the observation):
h_1^j = msgPred(z_1^j)
use the dynamics model again, but now for agent i itself, to get z_2^i:
z_2^i = f(z_1^i, a_1^i, h_1^j)

Note: during training all models are shared across the two agents, to simplify things 


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()