## New LDS tokenizer class

Based on the matrix and set up a tokenizer.

Observations: seems this method requires more tokens in the vocab

In [83]:

import logging
from typing import ClassVar

import numpy as np
from scipy.fft import dct
from scipy.fft import idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin

class LDSActionProcessor(ProcessorMixin):
    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        scale: float = 10,
        vocab_size: int = 1024,
        min_token: int = 0,
        *,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        self.scale = scale # TODO: need tuning
        self.vocab_size = vocab_size
        self.min_token = min_token

        # Action horizon and dimension needed during decoding. These can be specified
        # in three ways (in order of priority):
        # 1. passed in as kwargs to decode()
        # 2. in the constructor
        # 3. cached from the last time decode() was called
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        super().__init__(bpe_tokenizer)

    @staticmethod
    def get_A_mat(action_chunk: np.array) -> np.array:
        # action chunk: normalized action chunk
        assert action_chunk.ndim == 2, "Only 2 dimensions supported: [timesteps, action_dim]"

        X = action_chunk[:-1, :]  # x_t
        Y = action_chunk[1:, :]  # x_{t+1}
        A, residuals, rank, s = np.linalg.lstsq(X, Y, rcond=None)
        return A            

    def __call__(self, action_chunk: np.array) -> np.array:
        # action chunk: normalized action chunk
        assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        # Cache the time horizon and action dimension for decoding
        self.called_time_horizon = action_chunk.shape[-2]
        self.called_action_dim = action_chunk.shape[-1]

        ##################################################################
        ##### MOD: using LDS to directly approximate the action chunk #####
        LDS_coeffs = []
        for i in range(action_chunk.shape[0]): # for each batch
            # Solve for A using least squares: A = (X^T * X)^(-1) * X^T * Y
            LDS_coeff = self.get_A_mat(action_chunk[i])
            # do the rounding to convert LDS_coeff to integer
            LDS_coeff = np.around(LDS_coeff * self.scale)
            LDS_coeffs.append(LDS_coeff)
        ######################## End of MOD ###########################

        tokens = []
        for elem in LDS_coeffs:
            token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
            tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        initial_norm_action: list[np.array],
        action_dim: int | None = None,
    ) -> np.array:
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim

        # Cache the time horizon and action dimension for the next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert (
            self.time_horizon is not None and self.action_dim is not None
        ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."

        decoded_actions = []
        for batch_idx, token in zip(range(len(tokens)), tokens): # for each batch
            try:
                decoded_tokens = self.bpe_tokenizer.decode(token)
                
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                assert (
                    decoded_dct_coeff.shape
                    == (
                        self.action_dim,
                        self.action_dim,
                    )
                ), f"Decoded LDS coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
            except Exception as e:
                print(f"Error decoding tokens: {e}")
                print(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
            # decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
            
            # original LDS A matrix
            A_mat = decoded_dct_coeff/self.scale

            # rollout the get the action
            A_cum = A_mat
            decoded_actions_norm = [initial_norm_action[batch_idx]]
            for _ in range(self.time_horizon-1):
                decoded_actions_norm.append(np.matmul(A_cum, decoded_actions_norm[-1])) # A * x_{t}
                A_cum = np.matmul(A_cum, A_mat) # A_{t+1} = A_t * A

            decoded_actions.append(np.stack(decoded_actions_norm))
                
        return np.stack(decoded_actions)

    @classmethod
    def fit(
        cls,
        action_data: list[np.array],
        scale: float = 4, # DCT using 10, but 4705 too many for vocab
        vocab_size: int = 2048, # TODO: 1024
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "LDSActionProcessor":
        # # Run DCT over all inputs
        # dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]

        ############## Mod for the DCT #################
        lds_tokens = [LDSActionProcessor.get_A_mat(a).flatten() for a in action_data]
        ############## End of Mod #####################

        # Quantize and find min token
        max_token = int(np.around(np.concatenate(lds_tokens) * scale).max())
        min_token = int(np.around(np.concatenate(lds_tokens) * scale).min())
        min_vocab_size = max_token - min_token

        assert (
            min_vocab_size <= vocab_size
        ), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
        if min_vocab_size + 100 > vocab_size:
            logging.warning(
                f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
                f"size {vocab_size}, consider increasing vocab size"
            )

        # Make token iterator for BPE training
        def _token_iter():
            for tokens in lds_tokens:
                rounded_tokens = np.around(tokens * scale) - min_token
                rounded_tokens = rounded_tokens.astype(int)
                string = "".join(map(chr, rounded_tokens))
                yield string

        # Train BPE tokenizer
        bpe = ByteLevelBPETokenizer()

        # Set up the entire range of possible tokens as the initial alphabet
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )

        # Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
        # because it doesn't support custom alphabets)
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        return cls(
            PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )


## Dataloader of robomimic sim dataset
Test the input output error side by side with LDS-Tokenizer

In [30]:
import numpy as np
import matplotlib.pyplot as plt
# import seaborn as sns
import h5py

def list_action_data(file_path_list, chunk_length=20, num_steps=10):
    ## read from list of file_path, get each file_path, return list[np.array]
    # chunk_length: the length of the action chunk
    # num_steps: the starting idx in demo data to sample for action chunk

    action_data_list_all = []

    ## Read state and action data
    demo_state = []
    demo_action = []

    for file_path in file_path_list:
        with h5py.File(file_path, 'r') as f:
            f = f['data']
            print(f.keys())
            print(f['demo_0'].keys())
            for demo in f.keys():
                demo_state.append(f[demo]['states'][:])
                demo_action.append(f[demo]['actions'][:])
    
    # create action chunk from each demo actions
    # TODO: sample different starting idx for each demo with num_steps
    for demo_idx in range(len(demo_action)):
        demo = demo_action[demo_idx]
        demo_length = len(demo)
        for i in range(demo_length // num_steps):
            action_chunk = demo[i*num_steps:i*num_steps+chunk_length]
            action_data_list_all.append(action_chunk)
    return action_data_list_all

def get_normalized_action_stats(action_data_list):
    # get normalization stats for the action data
    action_data_array = np.concatenate(action_data_list, axis=0)
    print(f"normalized action data shape: {action_data_array.shape}")
    mean, std = np.mean(action_data_array, axis=0), np.std(action_data_array, axis=0)
    return mean, std

def normalize_action_data(action_data_list, mean, std):
    # normalize the action data to [0, 1], where action_data_list = [np.array, np.array, ...]
    for data in action_data_list:
        data = (data - mean) / std
    return action_data_list

file_path = ['/srv/rl2-lab/flash7/zhenyang/data/robomimic-sim/low_dim_v141.hdf5']
action_data_list = list_action_data(file_path)

<KeysViewHDF5 ['demo_0', 'demo_1', 'demo_10', 'demo_100', 'demo_101', 'demo_102', 'demo_103', 'demo_104', 'demo_105', 'demo_106', 'demo_107', 'demo_108', 'demo_109', 'demo_11', 'demo_110', 'demo_111', 'demo_112', 'demo_113', 'demo_114', 'demo_115', 'demo_116', 'demo_117', 'demo_118', 'demo_119', 'demo_12', 'demo_120', 'demo_121', 'demo_122', 'demo_123', 'demo_124', 'demo_125', 'demo_126', 'demo_127', 'demo_128', 'demo_129', 'demo_13', 'demo_130', 'demo_131', 'demo_132', 'demo_133', 'demo_134', 'demo_135', 'demo_136', 'demo_137', 'demo_138', 'demo_139', 'demo_14', 'demo_140', 'demo_141', 'demo_142', 'demo_143', 'demo_144', 'demo_145', 'demo_146', 'demo_147', 'demo_148', 'demo_149', 'demo_15', 'demo_150', 'demo_151', 'demo_152', 'demo_153', 'demo_154', 'demo_155', 'demo_156', 'demo_157', 'demo_158', 'demo_159', 'demo_16', 'demo_160', 'demo_161', 'demo_162', 'demo_163', 'demo_164', 'demo_165', 'demo_166', 'demo_167', 'demo_168', 'demo_169', 'demo_17', 'demo_170', 'demo_171', 'demo_172', '

## Train the tokenizer using the robomimic sim dataset

In [88]:

## Normalize the action data
mean, std = get_normalized_action_stats(action_data_list)
action_data_list_norm = normalize_action_data(action_data_list, mean, std)

print(f"shape of action_data_list_norm: {action_data_list_norm[0].shape}")

## Train the tokenizer
tokenizer = LDSActionProcessor.fit(action_data_list_norm)

## Test the reconstruction error (using training set)
def compare_gt_pred(gt_action, mean, std, tokenizer):
    """
    take an original action, then we normalize and get the tokens
    then recover the action from the tokens, and compare the original action and the recovered action
    """
    gt_action_norm = (gt_action - mean) / std
    tokens = tokenizer(gt_action_norm) # list[list[int]]
    print(f"tokens shape: {len(tokens[0])}")
    recovered_action_norm = tokenizer.decode(tokens, initial_norm_action=[gt_action_norm[0]])
    recovered_action = recovered_action_norm * std + mean
    
    reconstr_error = np.linalg.norm(gt_action - recovered_action[0, :gt_action.shape[0]], axis=0) # in case action dim < 0
    return reconstr_error

## Test the overfitting result
reconstr_error_list = []
for i in range(len(action_data_list)):
    reconstr_error = compare_gt_pred(action_data_list[i], mean, std, tokenizer)
    print(f"reconstr_error: {reconstr_error}")
    reconstr_error_list.append(reconstr_error)

print(f"reconstr_error_list shape: {np.array(reconstr_error_list).shape}")
print(f"reconstr_error average across all actions: {np.mean(reconstr_error_list, axis=0)}")



normalized action data shape: (16386, 7)
shape of action_data_list_norm: (20, 7)



tokens shape: 22
reconstr_error: [4.77226169e+06 7.14765295e+05 3.45996463e+07 6.41486834e+05
 1.88314140e+06 4.89313924e+06 2.51874000e+08]
tokens shape: 26
reconstr_error: [ 1.47314797  0.54527374  2.65337147  0.14412243  0.31054542  1.54525343
 33.9172244 ]
tokens shape: 32
reconstr_error: [ 47.09911682   2.88010345  67.23805298   1.28372719   9.86761245
   9.55550205 204.88929077]
tokens shape: 25
reconstr_error: [1.37051589 0.3238242  0.54095953 0.03603334 0.89985195 0.34173843
 3.99723723]
tokens shape: 27
reconstr_error: [5.79987766e+13 9.03734133e+12 7.58731419e+13 7.31749723e+12
 2.57115956e+12 1.63952437e+13 1.21967748e+14]
tokens shape: 27
reconstr_error: [1.42238177e+08 8.82559474e+07 5.46715644e+08 6.77022061e+06
 1.10069822e+08 7.20384119e+08 3.35299187e+09]
tokens shape: 27
reconstr_error: [4.62821207e-01 6.85239087e-01 9.81161146e+01 1.95061780e+01
 2.96539578e+01 1.91779343e+02 1.698244

  s = (x.conj() * x).real


tokens shape: 34
reconstr_error: [1.98116907e+12 3.74361214e+11 7.66966987e+12 1.27845124e+11
 5.17524571e+11 1.87212374e+13 2.50969179e+13]
tokens shape: 29
reconstr_error: [ 1.42688204  0.47381096  1.35356264  0.09218642  0.21696319  1.32314472
 11.45184944]
tokens shape: 29
reconstr_error: [2.97396317e+00 8.99609505e-01 1.45417342e+01 7.91261948e-02
 3.79051964e+00 3.77017633e+00 2.40566311e+03]
tokens shape: 25
reconstr_error: [5.96967195e+14 1.63390790e+15 1.16679970e+16 3.20273720e+15
 4.81247445e+15 9.67948296e+15 6.86174340e+16]
tokens shape: 27
reconstr_error: [9.91938429e+10 7.23332801e+08 4.52737016e+10 3.01126801e+09
 3.56146426e+09 1.66733972e+11 1.43821158e+11]
tokens shape: 25
reconstr_error: [1.44034858e+00 8.69355984e-01 1.28399321e+00 6.19588184e-02
 6.47346040e-01 1.77474466e-01 8.80617107e+01]
tokens shape: 29
reconstr_error: [1.63381821e+03 4.42261780e-01 1.08552306e+04 1.42548995e+03
 1.04663383e+03 2.07521931e+04 7.23968144e+05]
tokens shape: 24
reconstr_error: [