In [5]:
import math
import numpy as np
import torch
import torch.nn as nn   
from functorch import jacrev, vmap
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv,ResGatedGraphConv,SAGEConv, GatedGraphConv
import torch.nn as nn 
from OnlineAdaptation.modules.vq_torch.vq_quantize import VectorQuantize

In [6]:
def get_activation(act_name):
    """Return a fresh activation module for the given short name.

    Args:
        act_name: one of "elu", "selu", "relu", "crelu", "lrelu", "tanh",
            "sigmoid", "identity".

    Returns:
        A newly constructed ``nn.Module`` instance, or ``None`` (after printing
        a warning) when the name is not recognised.
    """
    # Ordered (name, factory) pairs; matched with `==` in order, first hit wins.
    known_activations = (
        ("elu", nn.ELU),
        ("selu", nn.SELU),
        ("relu", nn.ReLU),
        # NOTE(review): "crelu" maps to a plain ReLU here (no channel
        # concatenation) -- confirm this is intended.
        ("crelu", nn.ReLU),
        ("lrelu", nn.LeakyReLU),
        ("tanh", nn.Tanh),
        ("sigmoid", nn.Sigmoid),
        ("identity", nn.Identity),
    )
    for candidate, factory in known_activations:
        if act_name == candidate:
            return factory()
    print("invalid activation function!")
    return None

In [7]:
# Graph node ids: 0 = base, 1-4 = hips, 5-8 = thighs, 9-12 = calves,
# i.e. leg-joint node id = 1 + 4 * joint + leg with joint 0/1/2 = hip/thigh/calf.
#
# NOTE(review): per the decoder comments in GraphForward the 52 observation
# columns are presumably laid out as
#   0-11  base: base_vel, base_ang, projected gravity, cmd (3 each)
#   12-23 dof_pos, 24-35 dof_vel, 36-47 actions (joint 3 * leg + joint)
#   48-51 per-leg contact
# -- confirm against the environment's observation builder.


def _leg_joint_index(joint):
    """Observation columns for one joint type (0 hip, 1 thigh, 2 calf) on all 4 legs.

    Returns a dict keyed by node id as a string (``str(1 + 4 * joint + leg)``),
    each value holding 4 columns: ``offset + 3 * leg + joint`` for offsets
    12, 24 and 36, followed by ``48 + leg``.
    """
    return {
        str(1 + 4 * joint + leg): [offset + 3 * leg + joint for offset in (12, 24, 36)] + [48 + leg]
        for leg in range(4)
    }


# Base node reads the first 12 columns plus the 4 columns 48-51.
body_index = {'0': list(range(12)) + [48 + leg for leg in range(4)]}

hip_index = _leg_joint_index(0)
thigh_index = _leg_joint_index(1)
calf_index = _leg_joint_index(2)

# Edge pairs [src, dst]: base -> every other node, then per leg hip -> thigh -> calf.
# NOTE(review): each edge is listed in one direction only; if message passing
# is meant to be symmetric, the reversed pairs would need to be added.
edge_index = (
    [[0, node] for node in range(1, 13)]
    + [pair for leg in range(4) for pair in ([1 + leg, 5 + leg], [5 + leg, 9 + leg])]
)

def mlp(input_dim, out_dim, hidden_sizes, activations):
    """Build ``Linear -> act -> ... -> Linear`` as an ``nn.Sequential``.

    Args:
        input_dim: input feature width.
        out_dim: output feature width of the final (activation-free) Linear.
        hidden_sizes: iterable of hidden layer widths; may be empty.
        activations: a single module instance inserted after every hidden
            Linear (the same instance is reused, not copied).

    Returns:
        ``nn.Sequential`` with ``2 * len(hidden_sizes) + 1`` modules.
    """
    widths = [input_dim, *hidden_sizes]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), activations]
    # Final projection has no activation after it.
    modules.append(nn.Linear(widths[-1], out_dim))
    return nn.Sequential(*modules)


In [8]:
class GraphEncoder(nn.Module):
    """Encode an observation history into one latent vector per graph node.

    The observation columns are grouped into 13 nodes using the module-level
    ``body_index`` (node 0), ``hip_index``, ``thigh_index`` and ``calf_index``
    tables (4 nodes each). Each group has its own MLP over the flattened
    history of its columns; the node features are then mixed by two
    ``ResGatedGraphConv`` layers over ``edge_index``.

    Args:
        num_obs: per-step observation width; stored as ``self.num_obs`` but
            not otherwise used.
        num_history: number of history steps per sample.
        num_latent: width of each node's output latent.
        activation: activation name understood by ``get_activation``.
    """

    def __init__(self,
                 num_obs,
                 num_history,
                 num_latent,
                 activation = 'relu',):
        super(GraphEncoder, self).__init__()
        self.num_obs = num_obs
        self.num_latent = num_latent
        activation_fn = get_activation(activation)

        # graph info: observation column indices for each node group.
        # The hip/thigh/calf tables have 4 entries of 4 columns -> (4, 4) tensors.
        node_base = torch.tensor(list(body_index.values()), dtype=torch.long).squeeze()
        node_hip = torch.tensor(list(hip_index.values()), dtype=torch.long)
        node_thigh = torch.tensor(list(thigh_index.values()), dtype=torch.long)
        node_calf = torch.tensor(list(calf_index.values()), dtype=torch.long)

        # Frozen Parameters (as before) so parameter / state_dict layout is unchanged.
        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False)

        # A plain tensor attribute is not moved by `.to(device)` / `.cuda()`, so the
        # graph convs would get a CPU edge list next to GPU features. A buffer moves
        # with the module; persistent=False keeps it out of state_dict so existing
        # checkpoints still load. `.t().contiguous()` yields the (2, num_edges)
        # layout PyG expects.
        self.register_buffer(
            "edge",
            torch.as_tensor(edge_index, dtype=torch.long).t().contiguous(),
            persistent=False,
        )

        # build feature extractor for base, hip, thigh and calf
        # (each node MLP sees its columns over the whole history)
        base_input_size = num_history * len(list(body_index.values())[0])
        hip_input_size = num_history * len(list(hip_index.values())[0])
        thigh_input_size = num_history * len(list(thigh_index.values())[0])
        calf_input_size = num_history * len(list(calf_index.values())[0])
        self.base_net = mlp(base_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.hip_net = mlp(hip_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.thigh_net = mlp(thigh_input_size, 2 * num_latent, [128, 64], activation_fn)
        self.calf_net = mlp(calf_input_size, 2 * num_latent, [128, 64], activation_fn)

        # build graph net: 2*num_latent -> 2*num_latent -> num_latent
        self.gn = ResGatedGraphConv(in_channels=2 * num_latent, out_channels=2 * num_latent)
        self.gn2 = ResGatedGraphConv(in_channels=2 * num_latent, out_channels=num_latent)
        self.act = activation_fn

    def _history2node(self, obs_history):
        """Embed an observation history into 13 node feature vectors.

        Args:
            obs_history: tensor indexed as (bz, n_history, n_obs).

        Returns:
            Tensor of shape (bz, 13, 2 * num_latent): base node, then the 4 hip,
            4 thigh and 4 calf nodes.
        """
        base = obs_history[:, :, self.node_base].unsqueeze(1)         # (bz, 1, n_history, n_base)
        hip = obs_history[:, :, self.node_hip].permute(0, 2, 1, 3)    # (bz, 4, n_history, 4)
        thigh = obs_history[:, :, self.node_thigh].permute(0, 2, 1, 3)
        calf = obs_history[:, :, self.node_calf].permute(0, 2, 1, 3)
        # Flatten (history, feature) so each node MLP sees its full history.
        base = self.base_net(base.flatten(-2, -1))      # (bz, 1, 2 * num_latent)
        hip = self.hip_net(hip.flatten(-2, -1))         # (bz, 4, 2 * num_latent)
        thigh = self.thigh_net(thigh.flatten(-2, -1))
        calf = self.calf_net(calf.flatten(-2, -1))
        return torch.cat([base, hip, thigh, calf], dim=1)  # (bz, 13, 2 * num_latent)

    def forward(self, obs_history):
        """Return per-node latents of shape (bz, 13, num_latent)."""
        nodes = self._history2node(obs_history)
        nodes = self.gn(nodes, self.edge)
        nodes = self.act(nodes)
        nodes = self.gn2(nodes, self.edge)
        return nodes

    
class GraphActor(nn.Module):
    """Graph policy: current observation + per-node latents -> 12 joint outputs.

    The observation is split into 13 node groups using the module-level
    ``body_index`` / ``hip_index`` / ``thigh_index`` / ``calf_index`` tables
    and embedded per group. Each node embedding is concatenated with the
    matching node latent and mixed by two ``ResGatedGraphConv`` layers. Each
    leg's (base, hip, thigh, calf) node latents then pass through one shared
    ``leg_policy`` MLP that outputs 3 values per leg.

    Args:
        num_obs: unused; kept for interface compatibility.
        num_latent: width of the per-node latent and of the obs embeddings.
        num_actions: unused; the output width is fixed at 4 legs x 3 = 12.
        activation: activation name understood by ``get_activation``.
        actor_hidden_dims: unused; the MLP widths are hard-coded below.
    """

    def __init__(self,
                 num_obs,
                 num_latent,
                 num_actions,
                 activation = 'elu',
                 actor_hidden_dims = [512, 256, 128]):
        super().__init__()
        self.num_latent = num_latent
        # graph info: observation column indices for each node group
        # (hip/thigh/calf tables have 4 entries of 4 columns -> (4, 4) tensors).
        node_base = torch.tensor(list(body_index.values()), dtype=torch.long).squeeze()
        node_hip = torch.tensor(list(hip_index.values()), dtype=torch.long)
        node_thigh = torch.tensor(list(thigh_index.values()), dtype=torch.long)
        node_calf = torch.tensor(list(calf_index.values()), dtype=torch.long)
        activation_fn = get_activation(activation)
        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False)

        # Register as a buffer so it follows `.to(device)` / `.cuda()`; a plain tensor
        # attribute stays on the CPU and mismatches GPU node features in the graph
        # convs. persistent=False keeps state_dict (and existing checkpoints) unchanged.
        self.register_buffer(
            "edge",
            torch.as_tensor(edge_index, dtype=torch.long).t().contiguous(),
            persistent=False,
        )

        # Pipeline
        # obs -> node, concat with latent -> node latent
        # node latent -> node level policy -> node action
        base_input_size = len(list(body_index.values())[0])
        hip_input_size = len(list(hip_index.values())[0])
        thigh_input_size = len(list(thigh_index.values())[0])
        calf_input_size = len(list(calf_index.values())[0])
        self.base_net = mlp(base_input_size, num_latent, [128], activation_fn)
        self.hip_net = mlp(hip_input_size, num_latent, [128], activation_fn)
        self.thigh_net = mlp(thigh_input_size, num_latent, [128], activation_fn)
        self.calf_net = mlp(calf_input_size, num_latent, [128], activation_fn)

        # graph neural network: input is [obs embedding, latent] per node
        self.gn = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=2 * num_latent)
        self.act = activation_fn
        self.gn2 = ResGatedGraphConv(in_channels=2 * num_latent,
                                     out_channels=num_latent)

        # graph policy : (base, hip, thigh, calf) -> leg_control, shared by all legs
        self.leg_policy = mlp(4 * num_latent, 3, [256, 128], activation_fn)
        # Node ids [base, hip, thigh, calf] gathered for each leg.
        self.FL_Leg = nn.Parameter(torch.tensor([0, 1, 5, 9], dtype=torch.long), requires_grad=False)
        self.FR_Leg = nn.Parameter(torch.tensor([0, 2, 6, 10], dtype=torch.long), requires_grad=False)
        self.RL_Leg = nn.Parameter(torch.tensor([0, 3, 7, 11], dtype=torch.long), requires_grad=False)
        self.RR_Leg = nn.Parameter(torch.tensor([0, 4, 8, 12], dtype=torch.long), requires_grad=False)

    def _obs2node(self, obs):
        """Embed a (bz, n_obs) observation into (bz, 13, num_latent) node features."""
        base = obs[:, self.node_base].unsqueeze(1)  # (bz, 1, n_base)
        hip = obs[:, self.node_hip]                 # (bz, 4, 4)
        thigh = obs[:, self.node_thigh]
        calf = obs[:, self.node_calf]
        base = self.base_net(base)
        hip = self.hip_net(hip)
        thigh = self.thigh_net(thigh)
        calf = self.calf_net(calf)
        return torch.cat([base, hip, thigh, calf], dim=1)  # (bz, 13, num_latent)

    def forward(self, obs, latent):
        """Compute the 12 leg outputs.

        Args:
            obs: tensor indexed as (bz, n_obs).
            latent: per-node latents; must be (bz, 13, num_latent) so that the
                concatenation matches ``gn``'s ``2 * num_latent`` input width.

        Returns:
            Tensor (bz, 12): 3 values each for FL, FR, RL, RR in that order.
        """
        obs_nodes = self._obs2node(obs)                             # (bz, 13, num_latent)
        nodes_latent = torch.cat([obs_nodes, latent], dim=-1)       # (bz, 13, 2*num_latent)
        nodes_latent = self.gn(nodes_latent, self.edge)             # (bz, 13, 2*num_latent)
        nodes_latent = self.act(nodes_latent)
        nodes_latent = self.gn2(nodes_latent, self.edge)            # (bz, 13, num_latent)
        leg_actions = [
            self.leg_policy(nodes_latent[:, leg_nodes, :].reshape(-1, 4 * self.num_latent))
            for leg_nodes in (self.FL_Leg, self.FR_Leg, self.RL_Leg, self.RR_Leg)
        ]
        return torch.cat(leg_actions, dim=1)  # (bz, 12)

class GraphForward(nn.Module):
    """Graph forward model: obs + action + per-node latent -> predicted next obs.

    obs -> node features; each hip/thigh/calf node also receives its joint's
    action value. Node features are concatenated with the per-node latents,
    mixed by two ``ResGatedGraphConv`` layers, and decoded: the base node by
    ``base_decoder`` (12 values) and each leg's (base, hip, thigh, calf) nodes
    by a shared ``leg_decoder`` (10 values per leg).

    Args:
        num_obs: unused; kept for interface compatibility.
        num_latent: width of the per-node latent and of the node embeddings.
        num_actions: unused; action columns are selected by fixed index tensors.
        activation: activation name understood by ``get_activation``.
        actor_hidden_dims: unused; the MLP widths are hard-coded below.
    """

    def __init__(self,
                 num_obs,
                 num_latent,
                 num_actions,
                 activation = 'elu',
                 actor_hidden_dims = [512, 256, 128]):
        super().__init__()
        self.num_latent = num_latent
        # graph info: observation column indices for each node group
        # (hip/thigh/calf tables have 4 entries of 4 columns -> (4, 4) tensors).
        node_base = torch.tensor(list(body_index.values()), dtype=torch.long).squeeze()
        node_hip = torch.tensor(list(hip_index.values()), dtype=torch.long)
        node_thigh = torch.tensor(list(thigh_index.values()), dtype=torch.long)
        node_calf = torch.tensor(list(calf_index.values()), dtype=torch.long)

        # Action columns feeding each joint type, one per leg.
        node_hip_action = torch.tensor([0, 3, 6, 9], dtype=torch.long)
        node_thigh_action = torch.tensor([1, 4, 7, 10], dtype=torch.long)
        node_calf_action = torch.tensor([2, 5, 8, 11], dtype=torch.long)
        self.node_hip_action = nn.Parameter(node_hip_action, requires_grad=False)
        self.node_thigh_action = nn.Parameter(node_thigh_action, requires_grad=False)
        self.node_calf_action = nn.Parameter(node_calf_action, requires_grad=False)

        activation_fn = get_activation(activation)
        self.node_base = nn.Parameter(node_base, requires_grad=False)
        self.node_hip = nn.Parameter(node_hip, requires_grad=False)
        self.node_thigh = nn.Parameter(node_thigh, requires_grad=False)
        self.node_calf = nn.Parameter(node_calf, requires_grad=False)
        # Kept as a frozen Parameter (moves with the module, same state_dict key).
        self.edge = nn.Parameter(torch.as_tensor(edge_index, dtype=torch.long).t().contiguous(),
                                 requires_grad=False)
        # pipeline: leg-joint nodes get one extra input column for their action
        base_input_size = len(list(body_index.values())[0])
        hip_input_size = len(list(hip_index.values())[0]) + 1
        thigh_input_size = len(list(thigh_index.values())[0]) + 1
        calf_input_size = len(list(calf_index.values())[0]) + 1
        self.base_net = mlp(base_input_size, num_latent, [128], activation_fn)
        self.hip_net = mlp(hip_input_size, num_latent, [128], activation_fn)
        self.thigh_net = mlp(thigh_input_size, num_latent, [128], activation_fn)
        self.calf_net = mlp(calf_input_size, num_latent, [128], activation_fn)
        # graph neural network
        self.gn = ResGatedGraphConv(in_channels=2 * num_latent,
                                    out_channels=2 * num_latent)
        self.act = activation_fn
        self.gn2 = ResGatedGraphConv(in_channels=2 * num_latent,
                                     out_channels=num_latent)
        # decoder
        ## base_vel, base_ang, project_gravity, cmd
        ## dof_pos,vel,actions, contact
        self.base_decoder = mlp(num_latent, 3 + 3 + 3 + 3, [128], activation_fn)
        self.leg_decoder = mlp(num_latent * 4, 3 + 3 + 3 + 1, [128], activation_fn)
        # Node ids [base, hip, thigh, calf] gathered for each leg.
        self.FL_Leg = nn.Parameter(torch.tensor([0, 1, 5, 9], dtype=torch.long), requires_grad=False)
        self.FR_Leg = nn.Parameter(torch.tensor([0, 2, 6, 10], dtype=torch.long), requires_grad=False)
        self.RL_Leg = nn.Parameter(torch.tensor([0, 3, 7, 11], dtype=torch.long), requires_grad=False)
        self.RR_Leg = nn.Parameter(torch.tensor([0, 4, 8, 12], dtype=torch.long), requires_grad=False)

    def _obsaction2node(self, obs, action):
        """Embed obs (bz, n_obs) and action (bz, n_act) into (bz, 13, num_latent) nodes."""
        base = obs[:, self.node_base].unsqueeze(1)  # (bz, 1, n_base)
        hip = obs[:, self.node_hip]                 # (bz, 4, 4)
        thigh = obs[:, self.node_thigh]
        calf = obs[:, self.node_calf]
        hip_action = action[:, self.node_hip_action].unsqueeze(-1)      # (bz, 4, 1)
        thigh_action = action[:, self.node_thigh_action].unsqueeze(-1)
        calf_action = action[:, self.node_calf_action].unsqueeze(-1)
        base = self.base_net(base)
        hip = self.hip_net(torch.cat([hip, hip_action], dim=-1))
        # Bug fix: thigh features previously went through hip_net, leaving
        # thigh_net unused. NOTE: models trained before this fix never trained
        # thigh_net, so old checkpoints will produce different outputs.
        thigh = self.thigh_net(torch.cat([thigh, thigh_action], dim=-1))
        calf = self.calf_net(torch.cat([calf, calf_action], dim=-1))
        return torch.cat([base, hip, thigh, calf], dim=1)  # (bz, 13, num_latent)

    def forward(self, obs, action, latent):
        """Predict the decoded next-observation vector.

        Args:
            obs: tensor indexed as (bz, n_obs).
            action: tensor indexed as (bz, n_act); columns 0-11 are read.
            latent: per-node latents; must be (bz, 13, num_latent) to match
                ``gn``'s ``2 * num_latent`` input width.

        Returns:
            Tensor (bz, 52): 12 base values, then 12 pos, 12 vel, 12 act
            (each ordered FL, FR, RL, RR with 3 per leg) and 4 contact values.
        """
        obsaction_node = self._obsaction2node(obs, action)
        nodes_latent = torch.cat([obsaction_node, latent], dim=-1)  # (bz, 13, 2*num_latent)
        nodes_latent = self.gn(nodes_latent, self.edge)             # (bz, 13, 2*num_latent)
        nodes_latent = self.act(nodes_latent)
        nodes_latent = self.gn2(nodes_latent, self.edge)            # (bz, 13, num_latent)

        base_latent = nodes_latent[:, 0:1, :].reshape(-1, self.num_latent)
        base_decoded = self.base_decoder(base_latent)  # (bz, 12)

        # One (bz, 10) decode per leg, in FL, FR, RL, RR order.
        legs_decoded = [
            self.leg_decoder(nodes_latent[:, leg_nodes, :].reshape(-1, 4 * self.num_latent))
            for leg_nodes in (self.FL_Leg, self.FR_Leg, self.RL_Leg, self.RR_Leg)
        ]
        # Regroup per field across legs: pos [0:3], vel [3:6], act [6:9], contact [9:10].
        per_field = [
            torch.cat([leg[:, lo:hi] for leg in legs_decoded], dim=-1)
            for lo, hi in ((0, 3), (3, 6), (6, 9), (9, 10))
        ]
        return torch.cat([base_decoded, *per_field], dim=-1)

In [10]:
# Smoke-test inputs. Seed torch first so the random inputs (and torch-RNG-based
# layer init in later cells) are reproducible on Restart & Run All.
torch.manual_seed(0)

BATCH_SIZE = 2
NUM_HISTORY = 10
NUM_OBS = 52  # observation width; the node index tables use columns 0-51

obs_history = torch.randn(BATCH_SIZE, NUM_HISTORY, NUM_OBS)
obs = torch.randn(BATCH_SIZE, NUM_OBS)

obs_history_dim = NUM_HISTORY * NUM_OBS

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                  # a number of papers have shown smaller codebook dimension to be acceptable
    heads = 8,                          # number of heads to vector quantize, codebook shared across all heads
    separate_codebook_per_head = True,  # whether to have a separate codebook per head. False would mean 1 shared codebook
    codebook_size = 32,
)

In [13]:
# Project the flattened observation history to the VQ input width (256,
# matching `dim` of the VectorQuantize config), then quantize it.
net = nn.Sequential(
    nn.Linear(in_features=obs_history_dim, out_features=256),
    vq,
)

In [14]:
# Flatten (history, obs) into one feature axis without hard-coding the batch size.
res = net(obs_history.flatten(start_dim=1))

In [18]:
# Inspect the shape of the third element returned by the quantizer.
# NOTE(review): index 2 is a magic number; presumably the output is
# (quantized, indices, commit_loss) -- confirm against vq_quantize.VectorQuantize.
res[2].shape

torch.Size([1])