# World Model

> World model (predictor) and reward model, each conditioned on an observation latent, an action, and a message.

In [None]:
#| default_exp models.worldmodel

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from typing import Optional

import torch
import torch.nn as nn

from MAWM.models.dense import DenseModel


In [None]:
#| export
from MAWM.models.dense import DenseModel
class WorldModel(nn.Module):
    """World-model predictor: maps ``(z, action, msg)`` to an output of size ``output_dim``.

    Each input is linearly projected to ``model_info['node_size']`` features. The three
    projections are concatenated, fused back to ``node_size`` features by a linear layer,
    and passed to a ``DenseModel`` head with ``output_shape=(output_dim,)``.

    Args:
        obs_dim: Size of the last dimension of the observation latent ``z``.
        msg_dim: Size of the last dimension of the message ``msg``.
        action_dim: Size of the last dimension of ``action``.
        output_dim: Output size of the ``DenseModel`` head.
        model_info: Config dict forwarded to ``DenseModel`` as ``info``; must contain
            ``'node_size'``. When omitted (or ``None``), a fresh
            ``{'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None}``
            is used.
    """

    def __init__(
        self,
        obs_dim: int,
        msg_dim: int,
        action_dim: int,
        output_dim: int,
        model_info: Optional[dict] = None,
    ):
        super().__init__()
        if model_info is None:
            # Build the default per call: a dict literal as the default argument would be
            # one object shared by every instance (and by the DenseModel that receives it).
            model_info = {'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None}
        node_size = model_info['node_size']

        # Attribute names and registration order are kept unchanged: they define the
        # state_dict keys of saved checkpoints.
        self.message_proj = nn.Linear(msg_dim, node_size, bias=True)  # [B, msg_dim] -> [B, D]
        self.action_proj = nn.Linear(action_dim, node_size, bias=True)  # [B, action_dim] -> [B, D]
        self.obs_proj = nn.Linear(obs_dim, node_size, bias=True)  # [B, obs_dim] -> [B, D]

        # NOTE(review): the three projections and `fuse` are all linear with no activation
        # in between, so together they are equivalent to one linear map of the
        # concatenated inputs; any nonlinearity comes from the DenseModel head.
        self.fuse = nn.Linear(node_size * 3, node_size)  # [B, 3D] -> [B, D]
        self.wm = DenseModel(output_shape=(output_dim,), input_size=node_size, info=model_info)

    def forward(self, z, action, msg):
        """Return ``wm(fuse(cat[obs_proj(z), action_proj(action), message_proj(msg)]))``.

        ``z``, ``action`` and ``msg`` are concatenated along the last dimension after
        projection, so they must share all leading (batch) dimensions. The return value is
        whatever the ``DenseModel`` head produces. With ``dist=None`` the notebook example
        yields a ``[B, output_dim]`` tensor; other ``dist`` settings may return a
        distribution object instead.
        """
        msg_feat = self.message_proj(msg)
        act_feat = self.action_proj(action)
        obs_feat = self.obs_proj(z)

        fused = self.fuse(torch.cat([obs_feat, act_feat, msg_feat], dim=-1))
        return self.wm(fused)


In [None]:
#| hide
# Smoke test: a batch of 4 (latent, action, message) triples -> [4, output_dim].
z = torch.randn(4, 32)
action = torch.randn(4, 1)
msg = torch.randn(4, 256)
model = WorldModel(obs_dim=32, msg_dim=256, action_dim=1, output_dim=32,
                   model_info={'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None})
out = model(z, action, msg)
# Assert rather than only display, so `nbdev_test` catches shape regressions.
assert out.shape == (4, 32)
out.shape

torch.Size([4, 32])

In [None]:
#| export
class RewardModel(nn.Module):
    """Reward head: maps ``(z, action, msg)`` through a ``DenseModel`` of size ``output_dim``.

    The architecture matches the world model. Each input is linearly projected to
    ``model_info['node_size']`` features. The projections are concatenated, fused back to
    ``node_size`` features, and passed to a ``DenseModel`` head with
    ``output_shape=(output_dim,)``.

    Args:
        obs_dim: Size of the last dimension of the observation latent ``z``.
        msg_dim: Size of the last dimension of the message ``msg``.
        action_dim: Size of the last dimension of ``action``.
        output_dim: Output size of the ``DenseModel`` head.
        model_info: Config dict forwarded to ``DenseModel`` as ``info``; must contain
            ``'node_size'``. When omitted (or ``None``), a fresh
            ``{'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None}``
            is used.
    """

    def __init__(
        self,
        obs_dim: int,
        msg_dim: int,
        action_dim: int,
        output_dim: int,
        model_info: Optional[dict] = None,
    ):
        super().__init__()
        if model_info is None:
            # Build the default per call: a dict literal as the default argument would be
            # one object shared by every instance (and by the DenseModel that receives it).
            model_info = {'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': None}
        node_size = model_info['node_size']

        # Attribute names and registration order are kept unchanged: they define the
        # state_dict keys of saved checkpoints.
        self.message_proj = nn.Linear(msg_dim, node_size, bias=True)  # [B, msg_dim] -> [B, D]
        self.action_proj = nn.Linear(action_dim, node_size, bias=True)  # [B, action_dim] -> [B, D]
        self.obs_proj = nn.Linear(obs_dim, node_size, bias=True)  # [B, obs_dim] -> [B, D]

        self.fuse = nn.Linear(node_size * 3, node_size)  # [B, 3D] -> [B, D]
        # Head output is (output_dim,), or a distribution object when model_info['dist']
        # is set (the notebook example with dist='binary' calls .sample() on it).
        self.rm = DenseModel(output_shape=(output_dim,), input_size=node_size, info=model_info)

    def forward(self, z, action, msg):
        """Return ``rm(fuse(cat[obs_proj(z), action_proj(action), message_proj(msg)]))``.

        ``z``, ``action`` and ``msg`` are concatenated along the last dimension after
        projection, so they must share all leading (batch) dimensions. The return value is
        whatever the ``DenseModel`` head produces, which can be a distribution when
        ``model_info['dist']`` is set.
        """
        msg_feat = self.message_proj(msg)
        act_feat = self.action_proj(action)
        obs_feat = self.obs_proj(z)

        fused = self.fuse(torch.cat([obs_feat, act_feat, msg_feat], dim=-1))
        return self.rm(fused)

In [None]:
#| hide
# Smoke test: with dist='binary' the head returns a distribution whose samples are 0/1.
z = torch.randn(4, 32)
action = torch.randn(4, 1)
msg = torch.randn(4, 256)
model = RewardModel(obs_dim=32, msg_dim=256, action_dim=1, output_dim=1,
                    model_info={'layers': 3, 'node_size': 128, 'activation': nn.ReLU, 'dist': 'binary'})
out = model(z, action, msg)
sample = out.sample()
# The sample is random, so check its shape and value set instead of exact values.
assert sample.shape == (4, 1)
assert ((sample == 0) | (sample == 1)).all()
sample

tensor([[0.],
        [1.],
        [0.],
        [0.]])

In [None]:
#| hide
# Run nbdev's export step, writing this project's `#| export` cells out to Python modules.
import nbdev
nbdev.nbdev_export()