# Vision encoder

> A ConvNet module for perception.

In [None]:
#| default_exp models.vision.enums

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from mawm.cfg.pldm import ConfigBase
from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class BackboneConfig(ConfigBase):
    """Configuration of the vision backbone (encoder).

    Every field has a default, so an instance can be created with no
    arguments. Field order is the positional order of the generated
    ``__init__``, so keep it stable.
    """

    arch: str = "menet5"  # name of the backbone architecture class
    backbone_subclass: str = "a"  # variant key the architecture uses to pick its layer configuration
    backbone_width_factor: int = 1  # multiplier scaling the width of the backbone layers
    backbone_mlp: Optional[str] = None  # mlp to slap on top of backbone (projector; presumably a spec string — confirm)
    backbone_norm: str = "batch_norm"  # type of normalization layer used in the backbone
    backbone_pool: str = "avg_pool"  # type of pooling layer used in the backbone
    backbone_final_fc: bool = True  # NOTE(review): presumably toggles a final fully-connected layer — confirm
    channels: int = 1  # number of input image channels
    input_dim: Optional[int] = None  # if it's none, we assume it's image.
    position_dim: Optional[int] = None  # dimension of the position input
    position_encoder_arch: Optional[str] = None  # architecture of the position encoder
    fc_output_dim: Optional[int] = None  # if it's none, it will be a spatial output
    final_ln: bool = False  # NOTE(review): presumably applies a final LayerNorm — confirm




The backbone is the vision encoder. Its configuration, `BackboneConfig`, has the following parameters:
- `arch`: The name of the backbone architecture class.
- `backbone_subclass`: A variant key used inside the model architecture to select the configuration of its layers.
- `backbone_width_factor`: A multiplicative factor that scales the width of the backbone layers.
- `backbone_mlp`: An optional projector MLP stacked on top of the backbone to project the features to a desired dimension (`None` by default).
- `backbone_norm`: The type of normalization layer used in the backbone.
- `backbone_pool`: The type of pooling layer used in the backbone.
- `backbone_final_fc`: Whether the backbone ends with a final fully-connected layer (to be confirmed against the architecture code).
- `channels`: The number of input channels of the images.
- `input_dim`: The dimension of a non-image input; if `None`, the input is assumed to be an image.
- `position_dim`: The dimension of the position input.
- `position_encoder_arch`: The architecture of the position encoder.
- `fc_output_dim`: The output feature dimension; if `None`, the backbone produces a spatial output.
- `final_ln`: Whether a final layer normalization is applied (to be confirmed against the architecture code).

In [None]:
#| export
class BackboneOutput:
    """Result of a backbone forward pass.

    Holds the full ``encodings`` plus optional observation and position
    components. Reading ``obs_component`` falls back to ``encodings``
    whenever no explicit observation component is stored, including after
    it has been reset to ``None``.
    """

    def __init__(
        self,
        encodings: torch.Tensor,
        obs_component: Optional[torch.Tensor] = None,
        pos_component: Optional[torch.Tensor] = None,
    ):
        self.encodings = encodings
        self.pos_component = pos_component
        # Go through the property setter so the raw value lands in
        # ``_obs_component``; the getter resolves the fallback lazily.
        self.obs_component = obs_component

    @property
    def obs_component(self):
        """Observation component, or ``encodings`` when none is stored."""
        stored = self._obs_component
        if stored is None:
            return self.encodings
        return stored

    @obs_component.setter
    def obs_component(self, value: Optional[torch.Tensor]):
        # Assigning None re-enables the fallback to ``encodings``.
        self._obs_component = value

In [None]:
#| hide
# Regenerate the library modules from this project's `#| export` notebook cells.
import nbdev

nbdev.nbdev_export()