In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical, Normal
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter

from sub_models.functions_losses import SymLogTwoHotLoss

from sub_models.attention_blocks import get_subsequent_mask_with_batch_length, get_subsequent_mask
from sub_models.transformer_model import StochasticTransformerKVCache
from sub_models.attention_blocks import get_vector_mask
from sub_models.attention_blocks import PositionalEncoding1D, AttentionBlock, AttentionBlockKVCache

In [14]:
class StochasticTransformerKVCache2(nn.Module):
    def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.action_dim = action_dim
        self.feat_dim = feat_dim

        # mix image_embedding and action
        self.stem = nn.Sequential(
            nn.Linear(stoch_dim+action_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim)
        )
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_stack = nn.ModuleList([
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)  # TODO: check if this is necessary

    def forward(self, samples, action, mask):
        '''
        Normal forward pass
        '''
        # action is not one hot
        # action = F.one_hot(action.long(), self.action_dim).float() 
        print(mask)
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        for layer in self.layer_stack:
            print("pre:",feats)
            feats, attn = layer(feats, feats, feats, mask)
            print("post:",feats)
        return feats

    def reset_kv_cache_list(self, batch_size, dtype):
        '''
        Reset self.kv_cache_list
        '''
        self.kv_cache_list = []
        for layer in self.layer_stack:
            self.kv_cache_list.append(torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device="cpu"))
            

    def forward_with_kv_cache(self, samples, action,test):
        '''
        Forward pass with kv_cache, cache stored in self.kv_cache_list
        '''
        assert samples.shape[1] == 1
        mask = get_vector_mask(self.kv_cache_list[0].shape[1]+1, samples.device)
        # print(mask)
        # action = F.one_hot(action.long(), self.action_dim).float()
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding.forward_with_position(feats, position=self.kv_cache_list[0].shape[1])
        feats = self.layer_norm(feats)
        for idx, layer in enumerate(self.layer_stack):
            self.kv_cache_list[idx] = torch.cat([self.kv_cache_list[idx], feats], dim=1)
            # print("pre:",self.kv_cache_list[idx])
            # print("pres",feats)
            feats, attn = layer(feats, self.kv_cache_list[idx], self.kv_cache_list[idx], mask)
            # print("post:",feats)

        return feats

    def forward_context(self, samples, action, mask):
        '''
        Normal forward pass
        '''
        # action is not one hot
        # action = F.one_hot(action.long(), self.action_dim).float() 
        # print(mask)
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        kv_cache_test = feats
        for layer in self.layer_stack:
            # print("pre:",feats)
            feats, attn = layer(feats, feats, feats, mask)
            # print("post:",feats)
        return feats,kv_cache_test


In [15]:
storm_transformer = StochasticTransformerKVCache2(
            stoch_dim=3,
            action_dim=1,
            feat_dim=3,
            num_layers=1,
            num_heads=1,
            max_length=5,
            dropout=0
        )

In [16]:
latent = torch.rand(4,5,3)
action = torch.rand(4,5,1)
temporal_mask = get_subsequent_mask(latent)
print(temporal_mask)

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])


In [17]:
# with torch.inference_mode():
#     dist_feat = storm_transformer(latent, action, temporal_mask)
storm_transformer.eval()
dist_feat, kv_test = storm_transformer.forward_context(latent, action, temporal_mask)

In [18]:
# with torch.inference_mode():
#     storm_transformer.reset_kv_cache_list(4,torch.float32)
#     test_list = []
#     for i in range(3):
#         dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
#         test_list.append(dist_feat_kv)
storm_transformer.reset_kv_cache_list(4,torch.float32)
test_list = []
for i in range(5):
    dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
    test_list.append(dist_feat_kv)

In [19]:
dist_feat[:,:,:]

tensor([[[ 1.3835, -0.4380, -0.9455],
         [ 0.9117,  0.4804, -1.3921],
         [ 1.4132, -0.6600, -0.7532],
         [ 0.6356,  0.7763, -1.4119],
         [ 0.9831,  0.3889, -1.3720]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 1.2935, -1.1419, -0.1516],
         [ 1.4102, -0.6135, -0.7968],
         [ 0.4802,  0.9119, -1.3921],
         [ 0.8993,  0.4956, -1.3949]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 0.9117,  0.4804, -1.3921],
         [ 1.4132, -0.6600, -0.7532],
         [ 0.6356,  0.7763, -1.4119],
         [ 0.9831,  0.3889, -1.3720]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 1.2362, -0.0232, -1.2130],
         [ 1.4136, -0.6722, -0.7414],
         [ 0.6218,  0.7891, -1.4109],
         [ 0.9802,  0.3927, -1.3729]]], grad_fn=<SliceBackward0>)

In [20]:
result = torch.cat(test_list,dim=1)
result[:,:,:]

tensor([[[ 1.3835, -0.4380, -0.9455],
         [ 0.9117,  0.4804, -1.3921],
         [ 1.4132, -0.6600, -0.7532],
         [ 0.6356,  0.7763, -1.4119],
         [ 0.9831,  0.3889, -1.3720]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 1.2935, -1.1419, -0.1516],
         [ 1.4102, -0.6135, -0.7968],
         [ 0.4802,  0.9119, -1.3921],
         [ 0.8993,  0.4956, -1.3949]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 0.9117,  0.4804, -1.3921],
         [ 1.4132, -0.6600, -0.7532],
         [ 0.6356,  0.7763, -1.4119],
         [ 0.9831,  0.3889, -1.3720]],

        [[ 1.3835, -0.4380, -0.9455],
         [ 1.2362, -0.0232, -1.2130],
         [ 1.4136, -0.6722, -0.7414],
         [ 0.6218,  0.7891, -1.4109],
         [ 0.9802,  0.3927, -1.3729]]], grad_fn=<SliceBackward0>)

In [21]:
storm_transformer.kv_cache_list

[tensor([[[ 1.1871,  0.0721, -1.2592],
          [ 0.4294,  0.9522, -1.3816],
          [ 1.3325, -0.2559, -1.0766],
          [ 0.1741,  1.1284, -1.3025],
          [ 0.5249,  0.8748, -1.3997]],
 
         [[ 1.1871,  0.0721, -1.2592],
          [ 1.3714, -0.9846, -0.3868],
          [ 1.3325, -0.2559, -1.0766],
          [ 0.1741,  1.1284, -1.3025],
          [ 0.5249,  0.8748, -1.3997]],
 
         [[ 1.1871,  0.0721, -1.2592],
          [ 0.4294,  0.9522, -1.3816],
          [ 1.3325, -0.2559, -1.0766],
          [ 0.1741,  1.1284, -1.3025],
          [ 0.5249,  0.8748, -1.3997]],
 
         [[ 1.1871,  0.0721, -1.2592],
          [ 0.8712,  0.5291, -1.4004],
          [ 1.3325, -0.2559, -1.0766],
          [ 0.1741,  1.1284, -1.3025],
          [ 0.5249,  0.8748, -1.3997]]], grad_fn=<CatBackward0>)]

In [22]:
kv_test[:,:,:]

tensor([[[ 1.1871,  0.0721, -1.2592],
         [ 0.4294,  0.9522, -1.3816],
         [ 1.3325, -0.2559, -1.0766],
         [ 0.1741,  1.1284, -1.3025],
         [ 0.5249,  0.8748, -1.3997]],

        [[ 1.1871,  0.0721, -1.2592],
         [ 1.3714, -0.9846, -0.3868],
         [ 1.3325, -0.2559, -1.0766],
         [ 0.1741,  1.1284, -1.3025],
         [ 0.5249,  0.8748, -1.3997]],

        [[ 1.1871,  0.0721, -1.2592],
         [ 0.4294,  0.9522, -1.3816],
         [ 1.3325, -0.2559, -1.0766],
         [ 0.1741,  1.1284, -1.3025],
         [ 0.5249,  0.8748, -1.3997]],

        [[ 1.1871,  0.0721, -1.2592],
         [ 0.8712,  0.5291, -1.4004],
         [ 1.3325, -0.2559, -1.0766],
         [ 0.1741,  1.1284, -1.3025],
         [ 0.5249,  0.8748, -1.3997]]], grad_fn=<SliceBackward0>)