# Vision encoder

> A ConvNet module for perception.

In [None]:
#| default_exp models.vision.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from torch import nn
from mawm.models.vision.enums import BackboneOutput


class SequenceBackbone(nn.Module):
    """Base class for backbones that encode single inputs or sequences.

    ``forward_multiple`` collapses the two leading dimensions of a sequence
    input (T and BS) into one before calling ``forward`` and restores them
    on the resulting encodings afterwards.

    Attributes read by the helpers below:
        output_position_dim: size of the position component of the
            embeddings. Initialised to ``0`` (no position component). When
            the output is spatial it is indexed with ``[0]`` to obtain the
            number of position channels.
        output_dim: NOT set in this class -- presumably defined by
            subclasses (TODO confirm). An ``int`` selects the flat code
            path; anything else is unpacked as the spatial shape of a single
            embedding.
    """

    def __init__(self):
        """Initialise the module with no position component."""
        super().__init__()
        self.output_position_dim = 0

    def _remove_pos_component_for_spatial(self, embeddings):
        """
        remove the position component from spatial embeddings

        Flattened inputs (fewer than 4 dims) are first viewed as spatial
        using ``self.output_dim``, the trailing position channels are
        dropped, and the result is flattened again so it has the same number
        of dimensions as the input.

        Input:
            embeddings: tensor
            (T, BS, Ch, W, H) or
            (BS, Ch, W, H) or
            (T, BS, H) or
            (BS, H)

        Returns:
            the embeddings without their last ``output_position_dim[0]``
            channels; the input itself when there is nothing to remove.
        """
        # Nothing to strip. This has to be decided *before* slicing because
        # ``t[:-0]`` is ``t[:0]``: it would silently return an empty tensor.
        if not self.output_position_dim:
            return embeddings
        position_channels = self.output_position_dim[0]
        if not position_channels:
            return embeddings

        og_shape = tuple(embeddings.shape)
        flattened_input = len(og_shape) < 4

        # first reshape to spatial dimension if needed
        if flattened_input:
            spatial_shape = (*og_shape[:-1], *self.output_dim)
            embeddings = embeddings.view(spatial_shape)

        # remove the position channels; the channel axis is dim 2 when a
        # leading time axis is present and dim 1 otherwise
        rank = embeddings.dim()
        if rank == 5:
            embeddings = embeddings[:, :, :-position_channels]
        elif rank == 4:
            embeddings = embeddings[:, :-position_channels]

        # reflatten tensor if needed
        if flattened_input:
            embeddings = embeddings.view(*og_shape[:-1], -1)

        return embeddings

    def remove_pos_component(self, embeddings):
        """
        remove the position component from embeddings

        Dispatches on the type of ``self.output_dim``: an ``int`` means the
        embeddings are flat and the position component is the last
        ``output_position_dim`` features; otherwise the spatial helper is
        used.

        Input:
            embeddings: tensor
            (T, BS, Ch, W, H) or
            (BS, Ch, W, H) or
            (T, BS, H) or
            (BS, H)

        Returns:
            the embeddings without the position component; the input itself
            when ``output_position_dim`` is falsy.
        """
        if not self.output_position_dim:
            return embeddings

        if isinstance(self.output_dim, int):
            return embeddings[..., : -self.output_position_dim]
        return self._remove_pos_component_for_spatial(embeddings)

    def forward_multiple(self, x, position=None):
        """
        Run ``forward`` on an input that may carry two leading batch-like
        dimensions.

        input:
            x: 2-D or 4-D tensor (no time dimension), or a tensor whose
               first two dimensions are [T, BS, *] or [BS, T, *]
            position: optional tensor passed to ``forward`` as a second
               argument; its first two dimensions are flattened together
               with those of ``x`` when ``x`` has a time dimension
        output:
            * no time dimension: whatever ``forward`` returns, unchanged
            * otherwise: the ``encodings`` attribute of the ``forward``
              output, reshaped so its first two dimensions match ``x``

        NOTE(review): the two paths return different kinds of objects (full
        ``forward`` output vs. the bare encodings tensor). Kept as-is for
        backward compatibility -- confirm against callers.
        """
        # ``position`` is only forwarded when it was actually given, so
        # backbones whose ``forward`` takes a single argument keep working.

        # if no time dimension, just feed it directly to backbone
        if x.dim() in (2, 4):
            if position is None:
                return self.forward(x)
            return self.forward(x, position)

        # merge the two leading dimensions into a single batch dimension
        merged = x.flatten(0, 1)
        if position is None:
            output = self.forward(merged)
        else:
            output = self.forward(merged, position.flatten(0, 1))

        # restore the two leading dimensions on the encodings
        encodings = output.encodings
        return encodings.reshape(*x.shape[:2], *encodings.shape[1:])

In [None]:
#| hide
# Run nbdev's export step for this notebook's project.
import nbdev

nbdev.nbdev_export()