In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical, Normal
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter

from sub_models.functions_losses import SymLogTwoHotLoss

from sub_models.attention_blocks import get_subsequent_mask_with_batch_length, get_subsequent_mask
from sub_models.transformer_model import StochasticTransformerKVCache
from sub_models.attention_blocks import get_vector_mask
from sub_models.attention_blocks import PositionalEncoding1D, AttentionBlock, AttentionBlockKVCache

In [18]:
class StochasticTransformerKVCache2(nn.Module):
    def __init__(self, stoch_dim, action_dim, feat_dim, num_layers, num_heads, max_length, dropout):
        super().__init__()
        self.action_dim = action_dim
        self.feat_dim = feat_dim

        # mix image_embedding and action
        self.stem = nn.Sequential(
            nn.Linear(stoch_dim+action_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.LayerNorm(feat_dim)
        )
        self.position_encoding = PositionalEncoding1D(max_length=max_length, embed_dim=feat_dim)
        self.layer_stack = nn.ModuleList([
            AttentionBlockKVCache(feat_dim=feat_dim, hidden_dim=feat_dim*2, num_heads=num_heads, dropout=dropout) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(feat_dim, eps=1e-6)  # TODO: check if this is necessary

    def forward(self, samples, action, mask):
        '''
        Normal forward pass
        '''
        # action is not one hot
        # action = F.one_hot(action.long(), self.action_dim).float() 
        print(mask)
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        for layer in self.layer_stack:
            print("pre:",feats)
            feats, attn = layer(feats, feats, feats, mask)
            print("post:",feats)
        return feats

    def reset_kv_cache_list(self, batch_size, dtype):
        '''
        Reset self.kv_cache_list
        '''
        self.kv_cache_list = []
        for layer in self.layer_stack:
            self.kv_cache_list.append(torch.zeros(size=(batch_size, 0, self.feat_dim), dtype=dtype, device="cpu"))
            

    def forward_with_kv_cache(self, samples, action,test):
        '''
        Forward pass with kv_cache, cache stored in self.kv_cache_list
        '''
        assert samples.shape[1] == 1
        mask = get_vector_mask(self.kv_cache_list[0].shape[1]+1, samples.device)
        # print(mask)
        # action = F.one_hot(action.long(), self.action_dim).float()
        feats = self.stem(torch.cat([samples, action], dim=-1))
        print("feats1.dim",feats.shape)
        feats = self.position_encoding.forward_with_position(feats, position=self.kv_cache_list[0].shape[1])
        print("feats2.dim",feats.shape)
        feats = self.layer_norm(feats)
        print("feats3.dim",feats.shape)
        for idx, layer in enumerate(self.layer_stack):
            self.kv_cache_list[idx] = torch.cat([self.kv_cache_list[idx], feats], dim=1)
            # print("pre:",self.kv_cache_list[idx])
            print("feats4.dim",feats.shape)
            feats, attn = layer(feats, self.kv_cache_list[idx], self.kv_cache_list[idx], mask)
            # print("post:",feats)

        return feats

    def forward_context(self, samples, action, mask):
        '''
        Normal forward pass
        '''
        # action is not one hot
        # action = F.one_hot(action.long(), self.action_dim).float() 
        # print(mask)
        feats = self.stem(torch.cat([samples, action], dim=-1))
        feats = self.position_encoding(feats)
        feats = self.layer_norm(feats)
        kv_cache_test = feats
        for layer in self.layer_stack:
            # print("pre:",feats)
            feats, attn = layer(feats, feats, feats, mask)
            # print("post:",feats)
        return feats,kv_cache_test


In [29]:
storm_transformer = StochasticTransformerKVCache2(
            stoch_dim=3,
            action_dim=1,
            feat_dim=3,
            num_layers=1,
            num_heads=1,
            max_length=5,
            dropout=0
        )

In [30]:
latent = torch.rand(4,5,3)
action = torch.rand(4,5,1)
temporal_mask = get_subsequent_mask(latent)
print(temporal_mask)

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])


In [31]:
# with torch.inference_mode():
#     dist_feat = storm_transformer(latent, action, temporal_mask)
storm_transformer.eval()
dist_feat, kv_test = storm_transformer.forward_context(latent, action, temporal_mask)

In [32]:
# with torch.inference_mode():
#     storm_transformer.reset_kv_cache_list(4,torch.float32)
#     test_list = []
#     for i in range(3):
#         dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
#         test_list.append(dist_feat_kv)
storm_transformer.reset_kv_cache_list(4,torch.float32)
test_list = []
for i in range(5):
    dist_feat_kv = storm_transformer.forward_with_kv_cache(latent[:,i:i+1,:], action[:,i:i+1,:],i)
    test_list.append(dist_feat_kv)

feats1.dim torch.Size([4, 1, 3])
feats2.dim torch.Size([4, 1, 3])
feats3.dim torch.Size([4, 1, 3])
feats4.dim torch.Size([4, 1, 3])
feats1.dim torch.Size([4, 1, 3])
feats2.dim torch.Size([4, 1, 3])
feats3.dim torch.Size([4, 1, 3])
feats4.dim torch.Size([4, 1, 3])
feats1.dim torch.Size([4, 1, 3])
feats2.dim torch.Size([4, 1, 3])
feats3.dim torch.Size([4, 1, 3])
feats4.dim torch.Size([4, 1, 3])
feats1.dim torch.Size([4, 1, 3])
feats2.dim torch.Size([4, 1, 3])
feats3.dim torch.Size([4, 1, 3])
feats4.dim torch.Size([4, 1, 3])
feats1.dim torch.Size([4, 1, 3])
feats2.dim torch.Size([4, 1, 3])
feats3.dim torch.Size([4, 1, 3])
feats4.dim torch.Size([4, 1, 3])


In [33]:
dist_feat[:,:,:]

tensor([[[ 0.3618,  1.0031, -1.3649],
         [-1.3932,  0.9071,  0.4861],
         [-1.3437,  0.2900,  1.0537],
         [-0.2231,  1.3210, -1.0978],
         [-1.4036,  0.8514,  0.5522]],

        [[-0.6642,  1.4134, -0.7492],
         [-1.4035,  0.8524,  0.5510],
         [-1.2966,  0.1593,  1.1373],
         [-1.4082,  0.8174,  0.5908],
         [-1.1769,  1.2676, -0.0907]],

        [[ 0.1203,  1.1602, -1.2804],
         [-1.1619,  1.2792, -0.1173],
         [-1.4012,  0.5349,  0.8663],
         [-0.6396,  1.4121, -0.7725],
         [-1.3687,  0.9926,  0.3760]],

        [[ 0.2109,  1.1056, -1.3165],
         [-1.1716,  1.2717, -0.1001],
         [-1.3078,  0.1879,  1.1199],
         [-1.3728,  0.9805,  0.3924],
         [-1.3350,  1.0716,  0.2634]]], grad_fn=<SliceBackward0>)

In [34]:
result = torch.cat(test_list,dim=1)
result[:,:,:]

tensor([[[ 0.3618,  1.0031, -1.3649],
         [-1.3932,  0.9071,  0.4861],
         [-1.3437,  0.2900,  1.0537],
         [-0.2231,  1.3210, -1.0978],
         [-1.4036,  0.8514,  0.5522]],

        [[-0.6642,  1.4134, -0.7492],
         [-1.4035,  0.8524,  0.5510],
         [-1.2966,  0.1593,  1.1373],
         [-1.4082,  0.8174,  0.5908],
         [-1.1769,  1.2676, -0.0907]],

        [[ 0.1203,  1.1602, -1.2804],
         [-1.1619,  1.2792, -0.1173],
         [-1.4012,  0.5349,  0.8663],
         [-0.6396,  1.4121, -0.7725],
         [-1.3687,  0.9926,  0.3760]],

        [[ 0.2109,  1.1056, -1.3165],
         [-1.1716,  1.2717, -0.1001],
         [-1.3078,  0.1879,  1.1199],
         [-1.3728,  0.9805,  0.3924],
         [-1.3350,  1.0716,  0.2634]]], grad_fn=<SliceBackward0>)

In [35]:
storm_transformer.kv_cache_list

[tensor([[[ 0.0911,  1.1767, -1.2677],
          [-1.2728,  1.1702,  0.1027],
          [-1.4135,  0.7461,  0.6673],
          [-0.3324,  1.3566, -1.0242],
          [-1.2743,  1.1683,  0.1059]],
 
         [[-0.6688,  1.4135, -0.7447],
          [-1.2728,  1.1702,  0.1027],
          [-1.4117,  0.6334,  0.7784],
          [-1.2633,  1.1821,  0.0812],
          [-0.9749,  1.3747, -0.3997]],
 
         [[-0.1154,  1.2784, -1.1630],
          [-1.0452,  1.3476, -0.3024],
          [-1.3808,  0.9552,  0.4256],
          [-0.6298,  1.4115, -0.7817],
          [-1.2110,  1.2380, -0.0270]],
 
         [[-0.0415,  1.2450, -1.2035],
          [-1.0557,  1.3428, -0.2871],
          [-1.4117,  0.6334,  0.7784],
          [-1.2049,  1.2436, -0.0387],
          [-1.1499,  1.2879, -0.1380]]], grad_fn=<CatBackward0>)]

In [10]:
kv_test[:,:,:]

tensor([[[ 1.3687, -0.3761, -0.9926],
         [ 1.2716, -0.0997, -1.1718],
         [ 1.1253,  0.1792, -1.3045],
         [ 1.3443, -0.2919, -1.0524],
         [ 0.2353,  1.0900, -1.3253]],

        [[ 1.3525, -0.3185, -1.0341],
         [ 1.2684, -0.0925, -1.1759],
         [ 1.1253,  0.1792, -1.3045],
         [ 1.3443, -0.2919, -1.0524],
         [ 0.2310,  1.0928, -1.3238]],

        [[ 1.3687, -0.3761, -0.9926],
         [ 1.2622, -0.0786, -1.1835],
         [ 1.1253,  0.1792, -1.3045],
         [ 1.3443, -0.2919, -1.0524],
         [ 0.2310,  1.0928, -1.3238]],

        [[ 1.3687, -0.3761, -0.9926],
         [ 1.2716, -0.0997, -1.1718],
         [ 1.1253,  0.1792, -1.3045],
         [ 1.3443, -0.2919, -1.0524],
         [ 0.1941,  1.1161, -1.3102]]], grad_fn=<SliceBackward0>)