### Извлечение признаков из VPT

In [5]:
from minestudio.models import VPTPolicy
import torch.nn as nn
import torch

In [2]:
# Fetch the pretrained VPT policy checkpoint and place it on the GPU.
model = VPTPolicy.from_pretrained(
    "CraftJarvis/MineStudio_VPT.rl_from_early_game_2x"
)
model = model.to("cuda")
# Kept as the last expression of the cell: eval() returns the module itself,
# so the notebook echoes the module structure.
model.eval()

VPTPolicy(
  (value_head): ScaledMSEHead(
    (linear): Linear(in_features=2048, out_features=1, bias=True)
    (normalizer): NormalizeEwma()
  )
  (pi_head): DictActionHead(
    (buttons): CategoricalActionHead(
      (linear_layer): Linear(in_features=2048, out_features=8641, bias=True)
    )
    (camera): CategoricalActionHead(
      (linear_layer): Linear(in_features=2048, out_features=121, bias=True)
    )
  )
  (net): MinecraftPolicy(
    (img_preprocess): ImgPreprocessing()
    (img_process): ImgObsProcess(
      (cnn): ImpalaCNN(
        (stacks): ModuleList(
          (0): CnnDownStack(
            (firstconv): FanInInitReLULayer(
              (layer): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
            (n): GroupNorm(1, 128, eps=1e-05, affine=True)
            (blocks): ModuleList(
              (0-1): 2 x CnnBasicBlock(
                (conv0): FanInInitReLULayer(
                  (norm): GroupNorm(1, 128, eps=1e-05, affine=True)
   

Имплементация VPTPolicy такова, что есть несколько модулей у нашей model:
* model.net - сам VPT для извлечения признаков: генерирует скрытые представления для pi- и value-голов. Важно отметить, что генерируемые латенты для pi и value одинаковые. Также модель кэширует предыдущие состояния для эффективного вычисления при инференсе.
* model.pi_head - модуль для генерации действий
* model.value_head - модуль для генерации ценности

Вот кусок кода, нас интересует метод forward

```python
class VPTPolicy(MinePolicy, PyTorchModelHubMixin):
    def forward(self, input, state_in, **kwargs):
        """Forward pass of the VPTPolicy.

        Takes observations and recurrent state, passes them through the underlying
        `MinecraftPolicy` network, and then through policy and value heads.

        :param input: Dictionary of input observations, expected to contain "image".
                      The "image" tensor should have shape (B, T, H, W, C) or similar.
        :type input: Dict[str, torch.Tensor]
        :param state_in: Input recurrent state. If None, an initial state is generated.
        :type state_in: Optional[List[torch.Tensor]]
        :param kwargs: Additional keyword arguments (not directly used in this method but part of signature).
        :returns: A tuple containing:
            - latents (Dict[str, torch.Tensor]): Dictionary with 'pi_logits' and 'vpred'.
            - state_out (List[torch.Tensor]): Output recurrent state.
        :rtype: Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
        """
        # Batch and time sizes are read from the two leading image dimensions.
        B, T = input["image"].shape[:2]
        # Every (batch, step) position is marked as "not an episode start".
        first = torch.tensor([[False]], device=self.device).repeat(B, T)
        state_in = self.initial_state(B) if state_in is None else state_in

        #input: 1, 128, 128, 128, 3
        #first: 1, 128
        # state_in[0]: 1, 1, 1, 128
        # state_in[1]: 1, 1, 128, 128
        try:
            (pi_h, v_h), state_out = self.net(input, state_in, context={"first": first})
        except Exception as e:
            # NOTE(review): debugging hook. The exception is not re-raised, so if
            # execution continues past the breakpoint, `pi_h`, `v_h` and
            # `state_out` are unbound and the lines below fail with a NameError.
            import ray
            ray.util.pdb.set_trace()
        # Both heads consume the hidden representations produced by `self.net`.
        pi_logits = self.pi_head(pi_h)
        vpred = self.value_head(v_h)
        latents = {'pi_logits': pi_logits, 'vpred': vpred}
        return latents, state_out
```

А вот метод forward у самой model.net

```python
class MinecraftPolicy(nn.Module):
    def forward(self, ob, state_in, context):
        """Forward pass of the MinecraftPolicy.

        Processes image observations, passes them through recurrent layers, and produces
        latent representations.

        :param ob: Dictionary of observations, expected to contain "image".
        :type ob: Dict[str, torch.Tensor]
        :param state_in: Input recurrent state.
        :type state_in: Any # Type depends on recurrence_type
        :param context: Context dictionary, expected to contain "first" (a tensor indicating episode starts).
        :type context: Dict[str, torch.Tensor]
        :returns: A tuple containing:
            - pi_latent_or_tuple (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
                If `single_output` is True, this is a single tensor for both policy and value.
                Otherwise, it's a tuple (pi_latent, vf_latent).
            - state_out (Any): Output recurrent state.
        :rtype: Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], Any]
        """
        first = context["first"]
        x = self.img_preprocess(ob["image"])
        x = self.img_process(x)

        # Optional extra observation stream, added element-wise to the image features.
        if self.diff_obs_process:
            processed_obs = self.diff_obs_process(ob["diff_goal"])
            x = processed_obs + x

        if self.pre_lstm_ln is not None:
            x = self.pre_lstm_ln(x)

        # Without a recurrent layer the state is passed through untouched.
        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)
        else:
            state_out = state_in

        x = F.relu(x, inplace=False)

        x = self.lastlayer(x)
        x = self.final_ln(x)
        # The policy and value latents are the very same tensor.
        pi_latent = vf_latent = x
        if self.single_output:
            return pi_latent, state_out
        return (pi_latent, vf_latent), state_out
```

pi_latent и vf_latent одинаковые, входные изображения должны иметь форму (B, T, H, W, C). Попробуем прогнать через сеть пробный тензор

In [3]:
B = 2
T = 129
# Dummy all-zero batch in (B, T, H, W, C) layout; named `obs` so the builtin
# `input` is not shadowed for the rest of the notebook.
obs = {"image": torch.zeros(B, T, 128, 128, 3, device="cuda")}
# No step is flagged as an episode start.
first = torch.zeros((B, T), dtype=torch.bool, device="cuda")
context = {"first": first}
state_in = model.net.initial_state(B)
# Only the shapes are inspected here, so no autograd graph needs to be kept
# for the B * T frames.
with torch.no_grad():
    latents, state_out = model.net(obs, state_in, context)
# latents is a (pi_latent, vf_latent) pair; show the shape of the first one.
print(latents[0].shape)

torch.Size([2, 129, 2048])


In [4]:
# Summarise the recurrent state: number of entries, then one line per tensor
# with its index, dtype, shape and (for floating-point tensors only) min / max.
print(len(state_out))
for idx, tensor in enumerate(state_out):
    if tensor.is_floating_point():
        lo, hi = tensor.min().item(), tensor.max().item()
    else:
        # Non-float entries get empty placeholders instead of a value range.
        lo = hi = ""
    print(idx, tensor.dtype, tuple(tensor.shape), lo, hi)

12
0 torch.bool (2, 1, 128)  
1 torch.float32 (2, 128, 2048) -79.74006652832031 81.1694564819336
2 torch.float32 (2, 128, 2048) -13.02964973449707 20.191967010498047
3 torch.bool (2, 1, 128)  
4 torch.float32 (2, 128, 2048) -80.62519073486328 112.31578826904297
5 torch.float32 (2, 128, 2048) -5.20977783203125 4.960232734680176
6 torch.bool (2, 1, 128)  
7 torch.float32 (2, 128, 2048) -68.98873138427734 148.8136749267578
8 torch.float32 (2, 128, 2048) -4.444644451141357 4.461331367492676
9 torch.bool (2, 1, 128)  
10 torch.float32 (2, 128, 2048) -804.27001953125 1025.3935546875
11 torch.float32 (2, 128, 2048) -10.492900848388672 10.701163291931152


Помимо признаков выводится также обновлённый кэш, который подаётся в VPT

То есть кэш имеет фиксированный размер контекста (T=128), однако на вход можно подавать последовательность любой длины: модель посчитает латенты для каждого её элемента

Теперь имплементируем простейший класс для извлечения латентов

In [None]:
class VPTFeatureExtractor(nn.Module):
    """Wrap the `net` backbone of a pretrained VPT policy as a feature extractor.

    The policy and value heads of the loaded policy are not kept; only
    `policy.net` is, and its latent output is returned, optionally pooled
    over the time dimension.
    """

    # Accepted values for `pooling_mode`; None and "none" are synonyms.
    _POOLING_MODES = (None, "none", "last", "mean")

    def __init__(self, model_name: str, device: str = "cuda", eval_mode: bool = True, *args, **kwargs):
        """Load the pretrained policy and keep its backbone.

        :param model_name: Identifier passed to `VPTPolicy.from_pretrained`.
        :param device: Device the backbone and all inputs are moved to.
        :param eval_mode: If True, the backbone is switched to eval mode.
        :param args: Extra positional arguments forwarded to `nn.Module.__init__`.
        :param kwargs: Extra keyword arguments forwarded to `nn.Module.__init__`.
        """
        super().__init__(*args, **kwargs)
        policy = VPTPolicy.from_pretrained(model_name).to(device)
        self.device = device
        # The whole policy was already moved to `device`, so `net` is too.
        self.net = policy.net
        if eval_mode:
            self.net.eval()

    @torch.no_grad()
    def init_state(self, batch_size: int):
        """Return the backbone's initial recurrent state on `self.device`.

        :param batch_size: Batch size the state is created for.
        :returns: List of state tensors, or None if the backbone returns None.
        """
        state = self.net.initial_state(batch_size)
        if state is None:
            return None
        return [s.to(self.device) for s in state]

    @torch.no_grad()
    def forward(self, obs: dict, state_in=None, context=None, pooling_mode=None):
        """Extract latent features from the VPT model.

        :param obs: Observation dict; must contain "image", a 5D array/tensor
            in (B, T, H, W, C) layout.
        :param state_in: Recurrent state from a previous call; a fresh initial
            state is created when None.
        :param context: Optional dict with a "first" tensor; when the dict or
            the key is missing, an all-False (B, T) mask is used.
        :param pooling_mode: None / "none" (features as is), "last" (last time
            step) or "mean" (mean over time).
        :returns: Tuple `(latents, state_out)`; latents are (B, T, D) without
            pooling and (B, D) with "last" / "mean".
        :raises KeyError: If `obs` has no "image" key.
        :raises ValueError: If the image is not 5D or `pooling_mode` is unknown.
        """
        # Fail before the expensive forward pass, not after it.
        self._validate_pooling(pooling_mode)
        if "image" not in obs:
            raise KeyError('Obs must contain "image" key')
        # x dim: (B, T, H, W, C)
        x = torch.as_tensor(obs["image"])
        if x.dim() != 5:
            raise ValueError(f"image must be in a 5D shape (B, T, H, W, C) but got {x.shape}")
        # The image may arrive on another device (or from numpy); the net
        # lives on self.device.
        x = x.to(self.device)

        B, T = x.shape[:2]
        # Copy so the caller's dict is never mutated.
        context = {} if context is None else dict(context)
        first = context.get("first")
        if first is None:
            first = torch.zeros((B, T), dtype=torch.bool, device=self.device)
        context["first"] = first.to(self.device)

        if state_in is None:
            state_in = self.init_state(B)
        else:
            state_in = [s.to(self.device) for s in state_in]

        net_out, state_out = self.net({"image": x}, state_in, context)
        # The backbone returns either a (pi_latent, vf_latent) pair or, in
        # single-output mode, one tensor; both latents of the pair are the same.
        latents = net_out[0] if isinstance(net_out, (tuple, list)) else net_out
        return self._pooling(latents, pooling_mode), state_out

    @staticmethod
    def _validate_pooling(pooling):
        """Raise ValueError unless `pooling` is one of the supported modes."""
        if pooling not in VPTFeatureExtractor._POOLING_MODES:
            raise ValueError(f"pooling can be either none | last | mean but {pooling} was given")

    def _pooling(self, latents, pooling):
        """Reduce (B, T, D) latents over the time dimension.

        :param latents: Tensor whose dimension 1 is time.
        :param pooling: None / "none", "last" or "mean".
        :returns: The latents unchanged, the last step, or the mean over time.
        :raises ValueError: If `pooling` is not a supported mode.
        """
        self._validate_pooling(pooling)
        if pooling == "last":
            return latents[:, -1]  # B, D
        if pooling == "mean":
            return latents.mean(dim=1)  # B, D
        return latents  # B, T, D

In [7]:
extractor = VPTFeatureExtractor("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x", device="cuda")

obs = {"image": torch.zeros(2, 16, 128, 128, 3, device="cuda", dtype=torch.uint8)}
pi_h, state = extractor(obs, state_in=None, context=None)
print(pi_h.shape)


torch.Size([2, 16, 2048])


TODO: добавить пулинг с разными режимами: возвращать признаки как есть, возвращать последний токен, возвращать среднее токенов