In [1]:
import torch
import torch.nn.functional as F


def resize_1d_pos_embedding(posemb_ori: torch.Tensor,
                            new_length: int,
                            interpolation: str = 'linear',
                            antialias: bool = True,
                            verbose: bool = False) -> torch.Tensor:
    """
    Resize a 1D positional embedding of shape [T, C] to ``new_length`` rows.

    Only the first T-1 rows are resampled along the token axis (via
    ``F.interpolate``) to ``new_length - 1`` rows. The LAST row is copied
    through unchanged and appended at the end, so that final token is never
    blended with its neighbours.

    Args:
        posemb_ori: 2-D tensor of shape [T, C] (tokens x channels).
        new_length: target number of rows in the result; must be >= 1.
        interpolation: ``mode`` forwarded to ``F.interpolate``. It must accept
            a 3-D (N, C, L) input, e.g. 'linear' or 'nearest'.
        antialias: accepted for backward compatibility but NOT used.
            ``F.interpolate`` only supports anti-aliasing for 4-D input in
            'bilinear'/'bicubic' modes, so forwarding it here would raise.
        verbose: if True, print a one-line summary after resizing.

    Returns:
        Tensor of shape [new_length, C] with the same dtype as ``posemb_ori``.
        If ``new_length`` already equals T, ``posemb_ori`` itself is returned
        (no copy).

    Raises:
        ValueError: if ``posemb_ori`` is not 2-D, if ``new_length`` < 1, or if
            T == 1 while ``new_length`` > 1 (no rows to interpolate from).
    """
    if posemb_ori.dim() != 2:
        raise ValueError(
            f"Expected a 2-D [T, C] positional embedding, got shape {tuple(posemb_ori.shape)}."
        )
    if new_length < 1:
        raise ValueError(f"new_length must be >= 1, got {new_length}.")

    old_length = posemb_ori.shape[0]
    if old_length == new_length:
        return posemb_ori

    last_token = posemb_ori[-1:, :]  # (1, C), preserved verbatim
    num_interp = new_length - 1

    if num_interp == 0:
        # Only the preserved last token survives; F.interpolate rejects size 0.
        posemb = last_token.clone()
    else:
        body = posemb_ori[:-1, :]  # (T-1, C)
        if body.shape[0] == 0:
            raise ValueError(
                f"Cannot resize from length {old_length} to {new_length}: "
                "there are no rows to interpolate besides the preserved last token."
            )
        # Interpolate in full precision. CPU linear interpolation may not support
        # reduced-precision dtypes (e.g. bfloat16) on some PyTorch versions.
        if posemb_ori.dtype in (torch.float32, torch.float64):
            compute_dtype = posemb_ori.dtype
        else:
            compute_dtype = torch.float32
        body = body.to(compute_dtype).t().unsqueeze(0)  # (1, C, T-1)
        body = F.interpolate(body, size=num_interp, mode=interpolation)
        body = body.squeeze(0).t().to(posemb_ori.dtype)  # (new_length-1, C)
        posemb = torch.cat([body, last_token], dim=0)

    if verbose:
        print(f"Resized position embedding from length {old_length} to {new_length}.")

    return posemb

def update_ckpt_pos_embedding(ckpt_path, save_path, new_len, key='positional_embedding',
                              module='action_model'):
    """
    Load a checkpoint, resize one 1-D positional embedding in it, and save a copy.

    The checkpoint is read as a nested dict whose state dict lives at
    ``ckpt['model'][module]``. Inside that state dict, the entry named exactly
    ``key`` is used if present. Otherwise the single entry whose name ends with
    ``'.' + key`` is used. With the defaults this resolves to
    ``ckpt['model']['action_model']['net.positional_embedding']``.

    The tensor is resized with ``resize_1d_pos_embedding`` (which preserves its
    last row), written back into the checkpoint, and the whole checkpoint is
    saved to ``save_path``. The source file is not modified.

    Args:
        ckpt_path: path of the checkpoint to read.
        save_path: path to write the modified checkpoint to.
        new_len: target number of rows for the positional embedding.
        key: state-dict entry name, or its suffix after the last module prefix.
        module: name of the sub-dict under ``ckpt['model']`` to edit.

    Raises:
        KeyError: if ``ckpt['model'][module]`` is missing, or if ``key`` matches
            zero or more than one state-dict entry.
    """
    # torch.load unpickles the file: only use this on trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt['model'][module]

    if key in state:
        entry = key
    else:
        suffix = '.' + key
        matches = [name for name in state if name.endswith(suffix)]
        if len(matches) != 1:
            raise KeyError(
                f"Expected exactly one entry matching {key!r} in ckpt['model'][{module!r}], "
                f"found {matches}."
            )
        entry = matches[0]

    posemb = state[entry]  # presumably [T, C]; resize_1d_pos_embedding validates rank
    resized = resize_1d_pos_embedding(posemb, new_len)
    print(resized.shape)
    state[entry] = resized
    torch.save(ckpt, save_path)
    print(f"Saved resized checkpoint to {save_path}")
    
# Rewrite the training checkpoint with a positional embedding of TARGET_POS_LEN rows.
# NOTE(review): what each term of 7 + 1 + 1 counts is not documented here; confirm
# against the model config that produced the checkpoint.
SRC_CKPT_PATH = './checkpoints/step-050000-epoch-02-loss=0.0504.pt'
DST_CKPT_PATH = './reshape_embedding_step2.pt'
TARGET_POS_LEN = 7 + 1 + 1

update_ckpt_pos_embedding(
    ckpt_path=SRC_CKPT_PATH,
    save_path=DST_CKPT_PATH,
    new_len=TARGET_POS_LEN,
)

In [2]:
import torch

# Sanity check: reload the rewritten checkpoint from disk and report the
# shape of its positional embedding.
ckpt = torch.load('./reshape_embedding_step2.pt', map_location='cpu')
action_model_state = ckpt['model']['action_model']
print(action_model_state['net.positional_embedding'].shape)

torch.Size([9, 768])
