In [1]:
import math
import numpy as np
import torch
import torch.nn as nn   
from functorch import jacrev, vmap
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv,ResGatedGraphConv,SAGEConv, GatedGraphConv
import torch.nn as nn 
from OnlineAdaptation.modules.vq_torch.vq_quantize import VectorQuantize
import torch.optim as optim

In [2]:
def get_activation(act_name):
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    elif act_name == "identity":
        return nn.Identity()
    else:
        print("invalid activation function!")
        return None

In [3]:
body_index = {
    '0': [0,1,2,3,4,5,6,7,8,9,10,11,48,49,50,51],
}

hip_index = {
    '1': [12,24,36,48],
    '2': [15,27,39,49],
    '3': [18,30,42,50],
    '4': [21,33,45,51],
}
thigh_index = {
    '5': [13,25,37,48],
    '6': [16,28,40,49],
    '7': [19,31,43,50],
    '8': [22,34,46,51],
}
calf_index = {
    '9': [14,26,38,48],
    '10': [17,29,41,49],
    '11': [20,32,44,50],
    '12': [23,35,47,51],
}

edge_index = [
    [0,1],[0,2],[0,3],[0,4],[0,5],[0,6],[0,7],[0,8],[0,9],[0,10],[0,11],[0,12],
    [1,5],[5,9],
    [2,6],[6,10],
    [3,7],[7,11],
    [4,8],[8,12],
]

def mlp(input_dim, out_dim, hidden_sizes, activations):
    layers = []
    prev_h = input_dim
    for h in hidden_sizes:
        layers.append(nn.Linear(prev_h, h))
        layers.append(activations)
        prev_h = h
    layers.append(nn.Linear(prev_h, out_dim))
    return nn.Sequential(*layers)


In [4]:
class GraphEncoder(nn.Module):
    def __init__(self,
                 num_obs,
                 num_history,
                 num_latent,
                 activation = 'relu',):
        super(GraphEncoder, self).__init__()
        self.num_obs = num_obs
        self.num_latent = num_latent
        activation_fn = get_activation(activation)
        # graph info
        node_base = torch.tensor(list(body_index.values()),dtype=torch.long).squeeze()
        node_hip = torch.stack([torch.tensor(list(hip_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_thigh = torch.stack([torch.tensor(list(thigh_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_calf = torch.stack([torch.tensor(list(calf_index.values()),dtype=torch.long)],dim=0).squeeze()

        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False) 

        self.edge = torch.as_tensor(edge_index, dtype=torch.long).contiguous().t()

        # build feature extractor for base, hip, thigh and calf
        base_input_size = num_history * len(list(body_index.values())[0])
        hip_input_size = num_history * len(list(hip_index.values())[0])
        thigh_input_size = num_history * len(list(thigh_index.values())[0])
        calf_input_size = num_history * len(list(calf_index.values())[0])
        self.base_net = mlp(base_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.hip_net = mlp(hip_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.thigh_net = mlp(thigh_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.calf_net = mlp(calf_input_size, 2* num_latent, [128, 64], activation_fn)

        # build graph net 
        self.gn = ResGatedGraphConv(in_channels= 2* num_latent,out_channels=2* num_latent)
        self.gn2 = ResGatedGraphConv(in_channels= 2* num_latent,out_channels= num_latent)
        self.act = activation_fn
    
    def _history2node(self,obs_history):
        # obs_history.shape = (bz, n_histroy, n_obs)
        base = obs_history[:,:,self.node_base].unsqueeze(1) # (bz, n_history, n_base)
        hip = obs_history[:,:,self.node_hip].permute(0,2,1,3) # (bz, n_history, 4, 4)
        thigh = obs_history[:,:,self.node_thigh].permute(0,2,1,3)
        calf = obs_history[:,:,self.node_calf].permute(0,2,1,3) 
        base = self.base_net(base.flatten(-2,-1))
        hip = self.hip_net(hip.flatten(-2,-1))
        thigh = self.thigh_net(thigh.flatten(-2,-1))
        calf = self.calf_net(calf.flatten(-2,-1))
        return torch.cat([base,hip,thigh,calf],dim=1) # (bz, n_node,num_latent)
    
    def forward(self,obs_history):
        nodes = self._history2node(obs_history)
        nodes = self.gn(nodes,self.edge)
        nodes = self.act(nodes)
        nodes = self.gn2(nodes,self.edge)
        return nodes

    
class GraphActor(nn.Module):
    def __init__(self, 
                 num_obs,
                 num_latent,
                 num_actions,
                 activation = 'elu',
                 actor_hidden_dims = [512, 256, 128]):
        super().__init__()
        self.num_latent = num_latent
        # graph info
        node_base = torch.tensor(list(body_index.values()),dtype=torch.long).squeeze()
        node_hip = torch.stack([torch.tensor(list(hip_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_thigh = torch.stack([torch.tensor(list(thigh_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_calf = torch.stack([torch.tensor(list(calf_index.values()),dtype=torch.long)],dim=0).squeeze()
        activation_fn = get_activation(activation)
        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False) 
        self.edge = torch.as_tensor(edge_index, dtype=torch.long).contiguous().t()
        # Pipeline
        # obs-> node, concat with latent -> node latent
        # node latent -> node level policy -> node action
        base_input_size =  len(list(body_index.values())[0])
        hip_input_size = len(list(hip_index.values())[0])
        thigh_input_size =len(list(thigh_index.values())[0])
        calf_input_size = len(list(calf_index.values())[0])
        self.base_net = mlp(base_input_size, num_latent, [128], activation_fn)
        self.hip_net = mlp(hip_input_size, num_latent, [128], activation_fn)
        self.thigh_net = mlp(thigh_input_size, num_latent, [128], activation_fn)
        self.calf_net = mlp(calf_input_size, num_latent, [128], activation_fn)

        # graph neural network 
        self.gn = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=2 * num_latent)
        self.act = activation_fn
        self.gn2 = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=num_latent)
        
        # graph policy : (base, hip, thigh, calf) -> leg_control 
        self.leg_policy = mlp(4*num_latent, 3, [256,128], activation_fn)
        self.FL_Leg = nn.Parameter(torch.tensor([0,1,5,9],dtype=torch.long), requires_grad=False)
        self.FR_Leg = nn.Parameter(torch.tensor([0,2,6,10],dtype=torch.long), requires_grad=False)
        self.RL_Leg = nn.Parameter(torch.tensor([0,3,7,11],dtype=torch.long), requires_grad=False)
        self.RR_Leg = nn.Parameter(torch.tensor([0,4,8,12],dtype=torch.long), requires_grad=False)

    def _obs2node(self, obs):
        # obs.shape = (bz, n_obs)
        base = obs[:,self.node_base].unsqueeze(1) # (bz, 1, n_base)
        hip = obs[:,self.node_hip]# (bz, 4, 4)
        thigh = obs[:,self.node_thigh]
        calf = obs[:,self.node_calf]
        base = self.base_net(base)
        hip = self.hip_net(hip)
        thigh = self.thigh_net(thigh)
        calf = self.calf_net(calf)
        return torch.cat([base,hip,thigh,calf],dim=1) # (bz, n_node,num_latent)
    def forward(self, obs, latent):
        # obs.shape = (bz, n_obs)
        # latent.shape = (bz,n_node, num_latent)
        obs_nodes = self._obs2node(obs) # (bz, n_node, num_latent) 
        nodes_latent = torch.cat([obs_nodes,latent],dim=-1) # (bz, n_node, 2*num_latent)
        nodes_latent = self.gn(nodes_latent, self.edge) # (bz, n_node, 2*num_latent)
        nodes_latent = self.act(nodes_latent)
        nodes_latent = self.gn2(nodes_latent, self.edge) # (bz, n_node, num_latent) 
        FL_Leg_latent = nodes_latent[:,self.FL_Leg,:].reshape(-1,4*self.num_latent)
        FR_Leg_latent = nodes_latent[:,self.FR_Leg,:].reshape(-1,4*self.num_latent)
        RL_Leg_latent = nodes_latent[:,self.RL_Leg,:].reshape(-1,4*self.num_latent)
        RR_Leg_latent = nodes_latent[:,self.RR_Leg,:].reshape(-1,4*self.num_latent)
        FL_Leg_action = self.leg_policy(FL_Leg_latent)
        FR_Leg_action = self.leg_policy(FR_Leg_latent)
        RL_Leg_action = self.leg_policy(RL_Leg_latent)
        RR_Leg_action = self.leg_policy(RR_Leg_latent)
        return torch.cat([FL_Leg_action,FR_Leg_action,RL_Leg_action,RR_Leg_action],dim=1) # (bz, 12)

class GraphForward(nn.Module):
    def __init__(self, 
                 num_obs,
                 num_latent,
                 num_actions,
                 activation = 'elu',
                 actor_hidden_dims = [512, 256, 128]):
        super().__init__()
        """
        obs + action + latent -> next_obs 
        obs -> node 
        action -> node
        """
        self.num_latent = num_latent
        # graph info
        node_base = torch.tensor(list(body_index.values()),dtype=torch.long).squeeze()
        node_hip = torch.stack([torch.tensor(list(hip_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_thigh = torch.stack([torch.tensor(list(thigh_index.values()),dtype=torch.long)],dim=0).squeeze()
        node_calf = torch.stack([torch.tensor(list(calf_index.values()),dtype=torch.long)],dim=0).squeeze()
        
        node_hip_action = torch.tensor([0, 3, 6, 9],dtype=torch.long)
        node_thigh_action = torch.tensor([1, 4, 7, 10],dtype=torch.long)
        node_calf_action = torch.tensor([2, 5, 8, 11],dtype=torch.long)
        self.node_hip_action = nn.Parameter(node_hip_action, requires_grad=False)
        self.node_thigh_action = nn.Parameter(node_thigh_action, requires_grad=False)
        self.node_calf_action = nn.Parameter(node_calf_action, requires_grad=False)

        activation_fn = get_activation(activation)
        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False) 
        self.edge = nn.Parameter(torch.as_tensor(edge_index, dtype=torch.long).contiguous().t(),requires_grad=False)
        # pipeline 
        base_input_size =  len(list(body_index.values())[0])
        hip_input_size = len(list(hip_index.values())[0]) + 1
        thigh_input_size =len(list(thigh_index.values())[0]) + 1 
        calf_input_size = len(list(calf_index.values())[0]) + 1 
        self.base_net = mlp(base_input_size, num_latent, [128], activation_fn)
        self.hip_net = mlp(hip_input_size, num_latent, [128], activation_fn)
        self.thigh_net = mlp(thigh_input_size, num_latent, [128], activation_fn)
        self.calf_net = mlp(calf_input_size, num_latent, [128], activation_fn)
        # graph neural network 
        self.gn = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=2 * num_latent)
        self.act = activation_fn
        self.gn2 = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=num_latent)
        # decoder 
        ## base_vel, base_ang, project_gravity, cmd
        ## dof_pos,vel,actions, contact
        self.base_decoder = mlp(num_latent, 3 + 3 + 3 + 3, [128], activation_fn)
        self.leg_decoder = mlp(num_latent * 4 , 3 + 3 + 3 + 1, [128], activation_fn)
        self.FL_Leg = nn.Parameter(torch.tensor([0,1,5,9],dtype=torch.long), requires_grad=False)
        self.FR_Leg = nn.Parameter(torch.tensor([0,2,6,10],dtype=torch.long), requires_grad=False)
        self.RL_Leg = nn.Parameter(torch.tensor([0,3,7,11],dtype=torch.long), requires_grad=False)
        self.RR_Leg = nn.Parameter(torch.tensor([0,4,8,12],dtype=torch.long), requires_grad=False)


    def _obsaction2node(self,obs,action):
        base = obs[:,self.node_base].unsqueeze(1) # (bz, 1, n_base)
        hip = obs[:,self.node_hip]# (bz, 4, 4)
        thigh = obs[:,self.node_thigh]
        calf = obs[:,self.node_calf]
        hip_action = action[:,self.node_hip_action].unsqueeze(-1)
        thigh_action = action[:,self.node_thigh_action].unsqueeze(-1)
        calf_action = action[:,self.node_calf_action].unsqueeze(-1)
        base = self.base_net(base)
        hip = self.hip_net(torch.cat([hip,hip_action],dim=-1))
        thigh = self.hip_net(torch.cat([thigh,thigh_action],dim=-1))
        calf = self.calf_net(torch.cat([calf,calf_action],dim=-1))
        node = torch.cat([base,hip,thigh,calf],dim=1)
        return node # shape (bz, 13, 4)

    def forward(self,obs,action,latent):
        obsaction_node = self._obsaction2node(obs,action) 
        nodes_latent = torch.cat([obsaction_node,latent],dim=-1) # (bz, n_node, 2*num_latent)
        nodes_latent = self.gn(nodes_latent, self.edge) # (bz, n_node, 2*num_latent)
        nodes_latent = self.act(nodes_latent)
        nodes_latent = self.gn2(nodes_latent, self.edge) # (bz, n_node, num_latent) 

        Base_latent = nodes_latent[:,0:1,:].reshape(-1,self.num_latent)
        FL_Leg_latent = nodes_latent[:,self.FL_Leg,:].reshape(-1,4*self.num_latent)
        FR_Leg_latent = nodes_latent[:,self.FR_Leg,:].reshape(-1,4*self.num_latent)
        RL_Leg_latent = nodes_latent[:,self.RL_Leg,:].reshape(-1,4*self.num_latent)
        RR_Leg_latent = nodes_latent[:,self.RR_Leg,:].reshape(-1,4*self.num_latent)

        base_decoded = self.base_decoder(Base_latent) # (bz,12)
        FL_Leg_decoded = self.leg_decoder(FL_Leg_latent) # (bz, 4) 
        FR_Leg_decoded = self.leg_decoder(FR_Leg_latent)
        RL_Leg_decoded = self.leg_decoder(RL_Leg_latent)
        RR_Leg_decoded = self.leg_decoder(RR_Leg_latent)
        decoded_pos = torch.cat([FL_Leg_decoded[:,0:3],FR_Leg_decoded[:,0:3],RL_Leg_decoded[:,0:3],RR_Leg_decoded[:,0:3]],dim=-1) # (bz,12)
        decoded_vel = torch.cat([FL_Leg_decoded[:,3:6],FR_Leg_decoded[:,3:6],RL_Leg_decoded[:,3:6],RR_Leg_decoded[:,3:6]],dim=-1)
        decoded_act = torch.cat([FL_Leg_decoded[:,6:9],FR_Leg_decoded[:,6:9],RL_Leg_decoded[:,6:9],RR_Leg_decoded[:,6:9]],dim=-1)
        decoded_contact = torch.cat([FL_Leg_decoded[:,9:10],FR_Leg_decoded[:,9:10],RL_Leg_decoded[:,9:10],RR_Leg_decoded[:,9:10]],dim=-1)
        decoded = torch.cat((base_decoded,decoded_pos,decoded_vel,decoded_act,decoded_contact),dim=-1)
        return decoded

In [6]:
obs_history = torch.randn(2,10,52)
obs = torch.randn(2,52)

obs_history_dim = 10 * 52 

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                  # a number of papers have shown smaller codebook dimension to be acceptable
    heads = 8,                          # number of heads to vector quantize, codebook shared across all heads
    separate_codebook_per_head = True,  # whether to have a separate codebook per head. False would mean 1 shared codebook
    codebook_size = 32,
)

## VQ Test

In [5]:
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
from torch.cuda.amp import autocast

from einops import rearrange, repeat, reduce, pack, unpack

from typing import Callable

def batched_embedding(indices, embeds):
    batch, dim = indices.shape[1], embeds.shape[-1]
    indices = repeat(indices, 'h b n -> h b n d', d = dim)
    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
    return embeds.gather(2, indices)

def ema_inplace(old, new, decay):
    old.mul_(decay).add_(new * (1 - decay))

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t
def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]
def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)
def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device = device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device = device)

    return samples[indices]

def batched_sample_vectors(samples, num):
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)

def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    #! t.shape (n_book, n_codebook, dim)
    h, n = t.shape[:2]
    normed_codes = l2norm(t)
    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)

def orthogonal_loss_fn_with_mask(t, mask):
    #! t.shape (n_book, n_codebook, dim)
    #! mask.shape (n_book, n_codebook)
    # perform othogonal loss on masked codes where mask==1
    h, n = t.shape[:2]
    normed_codes = l2norm(t)
    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
    print("Check: ", cosine_sim, mask)
    cosine_sim = cosine_sim * mask.unsqueeze(-1) * mask.unsqueeze(-2)
    print("Check: ", cosine_sim)
    n_effective = mask.sum(dim = -1)[0]
    print("Check: ", n_effective)
    return (cosine_sim ** 2).sum() / (h * n_effective ** 2) - (1 / n_effective)
    



def cdist(x, y):
    x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
    y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
    xy = einsum('b i d, b j d -> b i j', x, y) * -2
    return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).sqrt()
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
    # x.shape = (n_book, n_code)
    denom = x.sum(dim = dim, keepdim = True)
    return (x + eps) / (denom + n_categories * eps) # (n_book,n_code), n_code 的概率

def simple_sample(logits,mask,temperature=1,deterministic=True,dim = -1):
    # loigits.shape = (n_head, bz, n_codebook)
    # mask.shape = (n_head, n_codebook)
    # sample logits with argmax from the mask==1 
    dtype, size = logits.dtype, logits.shape[dim]
    if deterministic:
        sampling_logits = logits
    else:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    if mask is not None:
        mask = repeat(mask, 'h c -> h b c', b = logits.shape[1])
        sampling_logits[~mask] = -1e10
    ind = sampling_logits.argmax(dim = dim)
    one_hot = F.one_hot(ind, size).type(dtype)
    return ind, one_hot

def gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = False,
    straight_through = False,
    reinmax = False,
    dim = -1,
    training = True
):
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim = dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'

    if not straight_through or temperature <= 0. or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
    # algorithm 2

    if reinmax:
        π0 = logits.softmax(dim = dim)
        π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
        π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1)
        π2 = 2 * π1 - 0.5 * π0
        one_hot = π2 - π2.detach() + one_hot
    else:
        π1 = (logits / temperature).softmax(dim = dim)
        one_hot = one_hot + π1 - π1.detach()

    return ind, one_hot
class CodeBook(nn.Module):
    def __init__(self,
                 dim, # 这个 dim 就是分化好的
                 num_codebooks,
                 codebook_size,
                    ema_update = True,
                    decay = 0.8,
                    eps = 1e-5,
                    threshold_ema_dead_code = 2
                 ) -> None:
        super().__init__()
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        init_fn = uniform_init # 可能可以尝试有 pretrain 的初始化 
        self.gumbel_sample = gumbel_sample
        embed = init_fn(num_codebooks,codebook_size, dim)
        self.sample_codebook_temp = 1.0
        self.ema_update = ema_update
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code 
        self.reset_cluster_size = threshold_ema_dead_code
        self.sample_fn = batched_sample_vectors

        #! 是不是统计一下每个 embed 的选取个数, 可以查看 dead code
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())
        self.embed = nn.Parameter(embed)

        #! 算法需求
        self.register_buffer('batch_mean', None)
        self.register_buffer('batch_variance', None)

        self.register_buffer('codebook_mean_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim))
        self.register_buffer('codebook_variance_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
    
    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            sampled = rearrange(sampled, '1 ... -> ...')
            
            self.embed.data[ind][mask] = sampled

            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)


    @torch.jit.ignore
    def update_with_decay(self, buffer_name, new_value, decay):
        old_value = getattr(self, buffer_name)

        needs_init = getattr(self, buffer_name + "_needs_init", False)

        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        if not exists(old_value) or needs_init:
            self.register_buffer(buffer_name, new_value.detach())

            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)
    
    @autocast(enabled = False)
    def forward(
        self,
        x,
        sample_codebook_temp = None,
        mask = None,
        freeze_codebook = False
    ):
        # x.shape = 'h b n d', head 数量前置, b 是 batch size, n 感觉是辅助的 ?? d 是 dim
        needs_codebook_dim = x.ndim < 4
        x = x.float()
        if needs_codebook_dim:
            x = rearrange(x, 'h ... -> h 1 ...')
        flatten,ps = pack_one(x,'h * d') # 这步的操作类似 从 h,b,n,d -> h, b*n,d
        embed = self.embed
        dist = -cdist(flatten, embed) # 越大相似度越高 

        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp, training = self.training)
        embed_ind = unpack_one(embed_ind, ps, 'h *')
        if self.training:
            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        
        if self.training and self.ema_update:
            cluster_size = embed_onehot.sum(dim = 1) 
            ema_inplace(self.cluster_size, cluster_size, self.decay)
            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) # 选择出来哪些 emb 被选中了 
            ema_inplace(self.embed_avg, embed_sum, self.decay)
            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)
        
        dist = unpack_one(dist, ps, 'h * d')
        if needs_codebook_dim:
            quantize = rearrange(quantize, 'h 1 n d -> h n d')
            dist = rearrange(dist, 'h 1 n d -> h n d')
            embed_ind = rearrange(embed_ind, 'h 1 n -> h n')
        return quantize, embed_ind, dist
        
class SimpleCodeBook(nn.Module):
    def __init__(self,
                 dim, # 这个 dim 就是分化好的
                 num_codebooks,
                 codebook_size,
                    ema_update = True,
                    decay = 0.8,
                    eps = 1e-5,
                    threshold_ema_dead_code = 2
                 ) -> None:
        super().__init__()
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        init_fn = uniform_init # 可能可以尝试有 pretrain 的初始化 
        self.sample_method = simple_sample
        embed = init_fn(num_codebooks,codebook_size, dim)
        self.sample_codebook_temp = 1.0
        self.ema_update = ema_update
        self.decay = decay
        self.eps = eps
        self.sample_fn = batched_sample_vectors

        #! 是不是统计一下每个 embed 的选取个数, 可以查看 dead code
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('valid_codebook', torch.ones(num_codebooks, codebook_size,dtype=torch.bool))
        self.register_buffer('embed_avg', embed.clone())
        self.embed = nn.Parameter(embed)

        #! 算法需求
        self.register_buffer('batch_mean', None)
        self.register_buffer('batch_variance', None)

        self.register_buffer('codebook_mean_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim))
        self.register_buffer('codebook_variance_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
    
    def unmask_all(self):
        self.valid_codebook.data.copy_(torch.ones_like(self.valid_codebook,dtype=torch.bool))
    def mask_percentage(self,percentage):
        assert percentage >= 0 and percentage <= 1
        valid_num = int(self.codebook_size * (1-percentage))
        self.valid_codebook.data.copy_(torch.ones_like(self.valid_codebook,dtype=torch.bool))
        self.valid_codebook.data[:,valid_num:] = 0
    def random_mask(self):
        # 对每个头的 codebook 随机 mask 一些
        mask = torch.randint(low=0,high = 2, size = (self.codebook_size,),dtype=torch.bool)
        self.valid_codebook.data[:,mask] = 0
    @torch.jit.ignore
    def update_with_decay(self, buffer_name, new_value, decay):
        old_value = getattr(self, buffer_name)

        needs_init = getattr(self, buffer_name + "_needs_init", False)

        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        if not exists(old_value) or needs_init:
            self.register_buffer(buffer_name, new_value.detach())

            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)
    
    @autocast(enabled = False)
    def forward(
        self,
        x
    ):
        # x.shape = 'h b d', head 数量前置, b 是 batch size, d 是 dim
        x = x.float()
        
        embed = self.embed
        dist = -cdist(x, embed) # 越大相似度越高 

        embed_ind, embed_onehot = self.sample_method(logits = dist,
                                                     mask = self.valid_codebook,
                                                     temperature=1.0,
                                                     deterministic=not self.training,
                                                     dim = -1)
        if self.training:
            quantize = einsum('h b c, h c d -> h b d', embed_onehot, embed)
        else:
            batch, dim = embed_ind.shape[1], embed.shape[-1]
            indices = repeat(embed_ind, 'h b -> h b d', d = dim)
            repeated_embeds = repeat(embed, 'h c d -> h b c d', b = batch)
            quantize = repeated_embeds.gather(2, indices)

        
        if self.training and self.ema_update:
            #! onehot.shape = (n_book,bz,n_code)-> n_code 激活的数量
            cluster_size = embed_onehot.sum(dim = 1) 
            ema_inplace(self.cluster_size, cluster_size, self.decay) #! 统计每个 code 的激活数量均值
            embed_sum = einsum('h b d, h b c -> h c d', x, embed_onehot) # 选择出来哪些 emb 被选中了 
            ema_inplace(self.embed_avg, embed_sum, self.decay) #! 统计用来激活 code 的emb 原来长什么样

            #! 这个平滑挺有趣的, 因为有的 code 没有被采样到,但是不能直接 + 1, 所以先算概率(Laplace平滑), 再乘上总数, 相当于略微的平分
            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
            #! 这一步是缓慢更新 emb
            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
            self.embed.data.copy_(embed_normalized)
        
        return quantize, embed_ind, dist

class VectorQuantize(nn.Module):
    def __init__(self,
                 input_dim, n_head, codebook_size,
                 commitment_weight = 1.,
                 orthogonal_reg_weight = 0.,
                 orthogonal_reg_active_codes_only = False):
        super().__init__()
        self.input_dim = input_dim 
        self.n_head = n_head
        self.codebook_size = codebook_size
        self.commitment_weight = commitment_weight
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 

        codebook_dim = input_dim // n_head # 这个是每个 head 的维度
        assert codebook_dim * n_head == input_dim, 'input_dim must be divisible by n_head'
        self._codebook = SimpleCodeBook(
            dim = codebook_dim,
            num_codebooks = n_head,
            codebook_size = codebook_size
        )
    
    def forward(self,x):
        # x.shape = 'bz, d',
        # x = rearrange(x,'bz d -> bz h d')
        ein_rhs_eq = 'h b d'
        x = rearrange(input, f"b (h d) -> {ein_rhs_eq}", h = self.n_head)
        # quantize
        quantize, embed_ind, distances = self._codebook(x)
        if self.training:
            # straight through
            commit_quantize = quantize
            quantize = x + (quantize - x).detach()

        embed_ind = rearrange(embed_ind, 'h b -> b h', h = self.n_head)
        

        loss = torch.tensor([0.], device = x.device, requires_grad = self.training)

        if self.training:
            #! commit loss
            commit_loss = F.mse_loss(commit_quantize, x)
            loss = loss + commit_loss * self.commitment_weight
            #! othogonal loss 
            if self.orthogonal_reg_weight > 0:
                orthogonal_reg_loss = orthogonal_loss_fn_with_mask(self._codebook.embed, self._codebook.valid_codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight


        quantize = rearrange(quantize, 'h b d -> b (h d)', h = self.n_head)
        
        return quantize, embed_ind, loss

    def get_info(self,x):
        ein_rhs_eq = 'h b d'
        x = rearrange(input, f"b (h d) -> {ein_rhs_eq}", h = self.n_head)
        # quantize
        quantize, embed_ind, distances = self._codebook(x)
        embed_ind = rearrange(embed_ind, 'h b -> b h', h = self.n_head)
        quantize = rearrange(quantize, 'h b d -> b (h d)', h = self.n_head)
        distances = rearrange(distances, 'h b d -> b h d')
        return quantize, embed_ind, distances


        
VQ = VectorQuantize(
    input_dim=8,
    n_head=4,
    codebook_size=3,
)

In [10]:



def get_activation(act_name):
    if act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "relu":
        return nn.ReLU()
    elif act_name == "crelu":
        return nn.ReLU()
    elif act_name == "lrelu":
        return nn.LeakyReLU()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    elif act_name == "identity":
        return nn.Identity()
    elif act_name == "elephant":
        return Elephant(0.2,4)
    else:
        print("invalid activation function!")
        return None

In [11]:
net = nn.Sequential(
    nn.Linear(10, 10),
    get_activation('elephant'),
    nn.Linear(10, 10),
)

In [12]:
x  = torch.randn(2,10)

In [13]:
res = net(x)

In [13]:
sz = (1,32,32)
t = torch.empty(sz)
nn.init.orthogonal_(t)
print(t)


tensor([[[ 0.0133, -0.0153,  0.0105,  ..., -0.0596,  0.0052,  0.0633],
         [ 0.0126, -0.0127,  0.0336,  ..., -0.0941,  0.0138,  0.0018],
         [-0.0151, -0.0243,  0.0075,  ...,  0.0183, -0.0211, -0.0049],
         ...,
         [-0.0162,  0.0233, -0.0164,  ...,  0.0485,  0.0029,  0.0357],
         [-0.0262,  0.0108,  0.0273,  ...,  0.0073, -0.0678,  0.0023],
         [-0.0461, -0.0137, -0.0086,  ..., -0.0030, -0.0267,  0.0228]]])


In [15]:
normed_t = l2norm(t)
res = einsum('h i d, h j d ->h i j', normed_t,normed_t)

In [16]:
res

tensor([[[ 1.0000e+00,  3.8754e-02, -2.2172e-02,  ...,  1.9489e-01,
          -1.0454e-01,  1.5987e-04],
         [ 3.8754e-02,  1.0000e+00, -2.6583e-01,  ..., -2.1387e-01,
           8.9059e-02,  9.0151e-03],
         [-2.2172e-02, -2.6583e-01,  1.0000e+00,  ..., -1.9104e-01,
           7.0439e-02,  1.6044e-01],
         ...,
         [ 1.9489e-01, -2.1387e-01, -1.9104e-01,  ...,  1.0000e+00,
           2.4896e-01, -1.2304e-01],
         [-1.0454e-01,  8.9059e-02,  7.0439e-02,  ...,  2.4896e-01,
           1.0000e+00,  2.1640e-01],
         [ 1.5987e-04,  9.0151e-03,  1.6044e-01,  ..., -1.2304e-01,
           2.1640e-01,  1.0000e+00]]])

In [18]:
loss = orthogonal_loss_fn(t)

In [5]:
class PushConfig:
    def __init__(self,
                 id,
                 body_index_list:list,
                 change_interval:int,
                 force_list:list, ) -> None:
        self.id = id 
        self.body_index_list =body_index_list
        self.change_interval = change_interval
        self.force_list = force_list 
        assert len(self.body_index_list) > 0 and  len(self.force_list) > 0
        self._force = self.force_list[0]
        self._body_index = self.body_index_list[0] 
    
    def _change(self):
        self._force = np.random.choice(self.force_list)
        self._body_index = np.random.choice(self.body_index_list) 


In [6]:
test_config = PushConfig(id = 0, 
                         body_index_list=[1,2,3,4],
                         change_interval=10,
                         force_list=[1,2,3,4])



In [3]:
tmp = np.load('logs/Eval/VQ-STG_Forward-stationary_push-DebugEval.npy',allow_pickle=True).item()

In [5]:
for k,v in tmp.items():
    print(k,v.shape)    

force (100, 1011, 17, 3)
done (100, 1011)
tracking_error (100, 1011)
base_vel (100, 1011, 3)
first_done (100,)
Fall (100,)


AttributeError: 'int' object has no attribute 'shape'