In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical, Normal
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter

from sub_models.functions_losses import SymLogTwoHotLoss

from sub_models.attention_blocks import get_subsequent_mask_with_batch_length, get_subsequent_mask
from sub_models.transformer_model import StochasticTransformerKVCache
from sub_models.attention_blocks import get_vector_mask
from sub_models.attention_blocks import PositionalEncoding1D, AttentionBlock, AttentionBlockKVCache

In [72]:
class StochasticTransformerKVCache2(nn.Module):
    def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.action_dim = action_dim
        self.feat_dim = feat_dim

        # mix image_embedding and action
        self.stem = nn.Sequential(
            nn.Linear(stoch_dim+action_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim)
        )
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_stack = nn.ModuleList([
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)  # TODO: check if this is necessary

    def forward(self, samples, action, mask):
        '''
        Normal forward pass
        '''
        # action is not one hot
        # action = F.one_hot(action.long(), self.action_dim).float() 
        print(mask)
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        for layer in self.layer_stack:
            print("pre:",feats)
            feats, attn = layer(feats, feats, feats, mask)
            print("post:",feats)
        return feats

    def reset_kv_cache_list(self, batch_size, dtype):
        '''
        Reset self.kv_cache_list
        '''
        self.kv_cache_list = []
        for layer in self.layer_stack:
            self.kv_cache_list.append(torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device="cpu"))
            

    def forward_with_kv_cache(self, samples, action,test):
        '''
        Forward pass with kv_cache, cache stored in self.kv_cache_list
        '''
        assert samples.shape[1] == 1
        mask = get_vector_mask(self.kv_cache_list[0].shape[1]+1, samples.device)
        # print(mask)
        # action = F.one_hot(action.long(), self.action_dim).float()
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding.forward_with_position(feats, position=self.kv_cache_list[0].shape[1])
        feats = self.layer_norm(feats)
        for idx, layer in enumerate(self.layer_stack):
            self.kv_cache_list[idx] = torch.cat([self.kv_cache_list[idx], feats], dim=1)
            print("pre:",self.kv_cache_list[idx])
            print("pres",feats)
            feats, attn = layer(feats, self.kv_cache_list[idx], self.kv_cache_list[idx], mask)
            print("post:",feats)

        return feats


In [73]:
storm_transformer = StochasticTransformerKVCache2(
            stoch_dim=3,
            action_dim=1,
            feat_dim=3,
            num_layers=1,
            num_heads=1,
            max_length=3,
            dropout=0
        )

In [74]:
latent = torch.rand(4,3,3)
action = torch.rand(4,3,1)
temporal_mask = get_subsequent_mask(latent)
print(temporal_mask)

tensor([[[ True, False, False],
         [ True,  True, False],
         [ True,  True,  True]]])


In [83]:
# with torch.inference_mode():
#     dist_feat = storm_transformer(latent, action, temporal_mask)
storm_transformer.eval()
dist_feat = storm_transformer(latent, action, temporal_mask)

tensor([[[ True, False, False],
         [ True,  True, False],
         [ True,  True,  True]]])
pre: tensor([[[ 1.1645,  0.1126, -1.2772],
         [ 0.2367, -1.3258,  1.0891],
         [ 1.1531, -1.2856,  0.1325]],

        [[ 1.1580,  0.1241, -1.2821],
         [-0.3539, -1.0088,  1.3627],
         [ 1.2158, -1.2335,  0.0177]],

        [[-0.5184,  1.3987, -0.8803],
         [-0.0443, -1.2020,  1.2463],
         [ 0.9855, -1.3711,  0.3856]],

        [[ 1.1651,  0.1117, -1.2768],
         [-0.6524, -0.7605,  1.4128],
         [ 0.9855, -1.3711,  0.3856]]], grad_fn=<NativeLayerNormBackward0>)
post: tensor([[[ 1.0269,  0.3286, -1.3555],
         [ 0.2449, -1.3287,  1.0838],
         [ 1.2687, -1.1754, -0.0933]],

        [[ 1.0212,  0.3366, -1.3579],
         [-0.2319, -1.0922,  1.3241],
         [ 1.3218, -1.0963, -0.2256]],

        [[-0.6127,  1.4102, -0.7975],
         [ 0.0307, -1.2398,  1.2091],
         [ 1.0552, -1.3430,  0.2879]],

        [[ 1.0274,  0.3280, -1.3553],
     

In [84]:
# with torch.inference_mode():
#     storm_transformer.reset_kv_cache_list(4,torch.float32)
#     test_list = []
#     for i in range(3):
#         dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
#         test_list.append(dist_feat_kv)
storm_transformer.reset_kv_cache_list(4,torch.float32)
test_list = []
for i in range(3):
    dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
    test_list.append(dist_feat_kv)

pre: tensor([[[ 1.1645,  0.1126, -1.2772]],

        [[ 1.1580,  0.1241, -1.2821]],

        [[-0.5184,  1.3987, -0.8803]],

        [[ 1.1651,  0.1117, -1.2768]]], grad_fn=<CatBackward0>)
pres tensor([[[ 1.1645,  0.1126, -1.2772]],

        [[ 1.1580,  0.1241, -1.2821]],

        [[-0.5184,  1.3987, -0.8803]],

        [[ 1.1651,  0.1117, -1.2768]]], grad_fn=<NativeLayerNormBackward0>)
post: tensor([[[ 1.0269,  0.3286, -1.3555]],

        [[ 1.0212,  0.3366, -1.3579]],

        [[-0.6127,  1.4102, -0.7975]],

        [[ 1.0274,  0.3280, -1.3553]]], grad_fn=<NativeLayerNormBackward0>)
pre: tensor([[[ 1.1645,  0.1126, -1.2772],
         [ 0.2367, -1.3258,  1.0891]],

        [[ 1.1580,  0.1241, -1.2821],
         [-0.3539, -1.0088,  1.3627]],

        [[-0.5184,  1.3987, -0.8803],
         [-0.0443, -1.2020,  1.2463]],

        [[ 1.1651,  0.1117, -1.2768],
         [-0.6524, -0.7605,  1.4128]]], grad_fn=<CatBackward0>)
pres tensor([[[ 0.2367, -1.3258,  1.0891]],

        [[-0.3539, -1.

In [85]:
dist_feat[:,:,:]

tensor([[[ 1.0269,  0.3286, -1.3555],
         [ 0.2449, -1.3287,  1.0838],
         [ 1.2687, -1.1754, -0.0933]],

        [[ 1.0212,  0.3366, -1.3579],
         [-0.2319, -1.0922,  1.3241],
         [ 1.3218, -1.0963, -0.2256]],

        [[-0.6127,  1.4102, -0.7975],
         [ 0.0307, -1.2398,  1.2091],
         [ 1.0552, -1.3430,  0.2879]],

        [[ 1.0274,  0.3280, -1.3553],
         [-0.4938, -0.9007,  1.3946],
         [ 1.0624, -1.3396,  0.2772]]], grad_fn=<SliceBackward0>)

In [86]:
result = torch.cat(test_list,dim=1)
result[:,:,:]

tensor([[[ 1.0269,  0.3286, -1.3555],
         [ 0.2449, -1.3287,  1.0838],
         [ 1.2687, -1.1754, -0.0933]],

        [[ 1.0212,  0.3366, -1.3579],
         [-0.2319, -1.0922,  1.3241],
         [ 1.3218, -1.0963, -0.2256]],

        [[-0.6127,  1.4102, -0.7975],
         [ 0.0307, -1.2398,  1.2091],
         [ 1.0552, -1.3430,  0.2879]],

        [[ 1.0274,  0.3280, -1.3553],
         [-0.4938, -0.9007,  1.3946],
         [ 1.0624, -1.3396,  0.2772]]], grad_fn=<SliceBackward0>)