In [1]:
import pickle
import jax
import jax.numpy as jnp
from dataclasses import asdict
import jax.tree_util as tree
import numpy as np
import torch.onnx

In [2]:
with open("/home/rajath/Documents/go2_mjx/go2_params-2.pkl", "rb") as f:
    params_loaded = pickle.load(f)

# Convert back to JAX arrays if needed
params_jax = jax.tree_map(jnp.array, params_loaded)

print("Params successfully loaded")

  params_jax = jax.tree_map(jnp.array, params_loaded)


Params successfully loaded


In [3]:
params_jax

(RunningStatisticsState(mean={'privileged_state': Array([ 1.8289642e-02,  1.3090867e-03,  1.2097746e-02,  2.2260242e-03,
        -8.2428427e-03,  6.3151391e-03,  6.8136035e-03, -3.1715224e-04,
        -9.9226946e-01,  3.9039836e-03,  2.3533672e-02,  4.5742471e-02,
         2.9967295e-02,  4.1372482e-02,  6.2746361e-02, -1.7049771e-02,
         2.8441222e-02,  7.2308421e-02, -9.9016363e-03,  1.4072988e-02,
         2.2956943e-02, -1.1215338e-02,  8.5668741e-03, -2.3886548e-02,
         1.0460826e-02,  3.2212827e-02, -5.0802894e-02, -1.0000665e-02,
         4.5650452e-03, -1.1325739e-02,  6.8191229e-03,  2.1560378e-02,
        -7.5513460e-02, -2.1749906e-01,  8.8398727e-03,  3.9951625e-01,
         3.5934207e-01,  2.0643292e-02,  2.5817654e-01, -3.8763985e-01,
         3.9937560e-02,  1.7850511e-01,  2.5764763e-01,  5.5519387e-02,
         4.0031371e-01, -1.8792097e-03, -4.4011694e-04,  8.3973340e-04,
         2.2218851e-03, -8.2435533e-03,  6.3129584e-03, -9.7282737e-02,
        -2.0234

In [4]:
mean_std_all = asdict(params_loaded[0])

In [7]:
mean_std_all.keys()

dict_keys(['mean', 'std', 'count', 'summed_variance'])

In [8]:
weights_bias = params_loaded[1]

In [10]:
weights_bias.keys()

dict_keys(['params'])

In [11]:
weights_bias["params"].keys()

dict_keys(['hidden_0', 'hidden_1', 'hidden_2', 'hidden_3'])

In [8]:
weights_bias["params"]['hidden_0'].keys()

dict_keys(['bias', 'kernel'])

In [23]:
weights_bias["params"]['hidden_3']['bias'].shape

(24,)

In [24]:
weights_bias["params"]['hidden_0']['kernel'].shape

(48, 512)

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(
        self,
        layer_sizes,
        activation=nn.ReLU(),
        kernel_init="lecun_uniform",
        activate_final=False,
        bias=True,
        layer_norm=False,
        mean_std=None,
    ):
        super().__init__()

        self.layer_sizes = layer_sizes
        self.activation = activation
        self.kernel_init = kernel_init
        self.activate_final = activate_final
        self.bias = bias
        self.layer_norm = layer_norm

        # Register mean and std as buffers (non-trainable parameters)
        if mean_std is not None:
            self.register_buffer("mean", torch.tensor(mean_std[0], dtype=torch.float32))
            self.register_buffer("std", torch.tensor(mean_std[1], dtype=torch.float32))
        else:
            self.mean = None
            self.std = None

        # Build the MLP block
        self.mlp_block = nn.Sequential()
        for i in range(len(self.layer_sizes) - 1):
            in_features = self.layer_sizes[i]
            out_features = self.layer_sizes[i + 1]

            # Add linear layer
            dense_layer = nn.Linear(in_features, out_features, bias=self.bias)
            self.mlp_block.add_module(f"hidden_{i}", dense_layer)

            # Initialize weights (e.g., Lecun uniform initialization)
            if self.kernel_init == "lecun_uniform":
                nn.init.kaiming_uniform_(dense_layer.weight, mode='fan_in', nonlinearity='relu')

            # Add layer normalization if enabled
            if self.layer_norm and i < len(self.layer_sizes) - 2:  # No layer norm after the last layer
                self.mlp_block.add_module(f"layer_norm_{i}", nn.LayerNorm(out_features))

            # Add activation function, except for the final layer if `activate_final` is False
            if i < len(self.layer_sizes) - 2 or self.activate_final:  # Add activation for all but the last layer
                self.mlp_block.add_module(f"activation_{i}", self.activation)

    def forward(self, inputs):
        # Handle list inputs
        if isinstance(inputs, list):
            inputs = inputs[0]

        # Normalize inputs if mean and std are provided
        if self.mean is not None and self.std is not None:
            inputs = (inputs - self.mean) / self.std

        # Pass through the MLP block
        logits = self.mlp_block(inputs)

        # Split the output into two parts and apply tanh to the first half
        loc, _ = torch.split(logits, logits.size(-1) // 2, dim=-1)
        return torch.tanh(loc)

def make_policy_network(
    observation_size,
    action_size,
    mean_std,
    hidden_layer_sizes=[256, 256],
    activation=nn.SiLU(),
    kernel_init="lecun_uniform",
    layer_norm=False,
):
    layers = hidden_layer_sizes + [action_size]
    print(layers)
    policy_network = MLP(
        layer_sizes= layers,
        activation=activation,
        kernel_init=kernel_init,
        layer_norm=layer_norm,
        mean_std=mean_std,
    )
    return policy_network

In [19]:
mean_std = (torch.tensor(mean_std_all['mean']['state']), torch.tensor(mean_std_all['std']['state']))

In [25]:
th_policy_network = make_policy_network(
    observation_size = 48,
    action_size = 12*2,
    mean_std=mean_std,
    hidden_layer_sizes=[48, 512, 256, 128])

[48, 512, 256, 128, 24]


  self.register_buffer("mean", torch.tensor(mean_std[0], dtype=torch.float32))
  self.register_buffer("std", torch.tensor(mean_std[1], dtype=torch.float32))


In [26]:
th_policy_network

MLP(
  (activation): SiLU()
  (mlp_block): Sequential(
    (hidden_0): Linear(in_features=48, out_features=512, bias=True)
    (activation_0): SiLU()
    (hidden_1): Linear(in_features=512, out_features=256, bias=True)
    (activation_1): SiLU()
    (hidden_2): Linear(in_features=256, out_features=128, bias=True)
    (activation_2): SiLU()
    (hidden_3): Linear(in_features=128, out_features=24, bias=True)
  )
)

In [27]:
# Access and modify weights and biases
for i, layer in enumerate(th_policy_network.mlp_block):
    if isinstance(layer, nn.Linear):
        print(f"Layer {i}:")
        print("Original weights:", layer.weight.shape)
        print("Original biases:", layer.bias.shape)

Layer 0:
Original weights: torch.Size([512, 48])
Original biases: torch.Size([512])
Layer 2:
Original weights: torch.Size([256, 512])
Original biases: torch.Size([256])
Layer 4:
Original weights: torch.Size([128, 256])
Original biases: torch.Size([128])
Layer 6:
Original weights: torch.Size([24, 128])
Original biases: torch.Size([24])


In [16]:
type(weights_bias["params"]['hidden_0'])

dict

In [17]:
for key, values in weights_bias["params"].items():
    print(f"layer:{key}------------------")
    for key, kernel_bias in values.items():
        print(f"{key} size: {kernel_bias.shape}")

layer:hidden_0------------------
bias size: (512,)
kernel size: (48, 512)
layer:hidden_1------------------
bias size: (256,)
kernel size: (512, 256)
layer:hidden_2------------------
bias size: (128,)
kernel size: (256, 128)
layer:hidden_3------------------
bias size: (24,)
kernel size: (128, 24)


In [18]:
for i, layer in enumerate(th_policy_network.mlp_block):
    print(i, layer)

0 Linear(in_features=48, out_features=512, bias=True)
1 SiLU()
2 Linear(in_features=512, out_features=256, bias=True)
3 SiLU()
4 Linear(in_features=256, out_features=128, bias=True)
5 SiLU()
6 Linear(in_features=128, out_features=24, bias=True)


In [28]:
# Assuming th_policy_network is already defined
values = [(key,value) for key, value in weights_bias["params"].items()]
j = 0
for i, layer in enumerate(th_policy_network.mlp_block):
    if isinstance(layer, nn.Linear):  # Check if the layer is a Linear layer
        #print(f"Layer {values[i][0]}:")
        #print("Weights shape:", layer.weight.shape)
        #print("Biases shape:", layer.bias.shape)
        #print(f"before: {layer.weight.data}, size: {layer.weight.data.shape}")
        #print(f"BEFOREvalues: {values[i][1]['kernel']}, size: {values[i][1]['kernel'].shape}")

        #print(f"size: {values[i][1]['bias'].shape}")
        #print(f"size: {values[i][1]['kernel'].shape}")

        if (i%2==0):
            
            print(layer)
            transpose_tensor_kernel = torch.tensor(values[j][1]['kernel']).t()
            transpose_tensor_bias = torch.tensor(values[j][1]['bias']).t()

            #print(f"ker: {transpose_tensor_kernel.shape}")
            #print(f"bia: {transpose_tensor_bias.shape}")
                # Assign new weights and biases
            layer.weight.data = transpose_tensor_kernel
            layer.bias.data = transpose_tensor_bias
            j = j + 1
     

Linear(in_features=48, out_features=512, bias=True)
Linear(in_features=512, out_features=256, bias=True)
Linear(in_features=256, out_features=128, bias=True)
Linear(in_features=128, out_features=24, bias=True)


In [29]:
batch_size = 1
input = torch.randn(batch_size, 48)  # For a batch of inputs

In [30]:
th_policy_network

MLP(
  (activation): SiLU()
  (mlp_block): Sequential(
    (hidden_0): Linear(in_features=48, out_features=512, bias=True)
    (activation_0): SiLU()
    (hidden_1): Linear(in_features=512, out_features=256, bias=True)
    (activation_1): SiLU()
    (hidden_2): Linear(in_features=256, out_features=128, bias=True)
    (activation_2): SiLU()
    (hidden_3): Linear(in_features=128, out_features=24, bias=True)
  )
)

In [31]:
input.shape

torch.Size([1, 48])

In [32]:
th_policy_network.forward(input)

tensor([[-0.6950, -0.0239, -0.3071, -0.1466, -0.7421,  0.9410, -0.6703,  0.2743,
          0.3922,  0.2304, -0.5998, -0.5471]], grad_fn=<TanhBackward0>)

In [23]:
mean_std

(tensor([-5.3753e-03, -3.7227e-03,  8.1214e-03,  2.5784e-03, -6.0854e-03,
         -1.6575e-03,  6.0498e-03, -1.8035e-03, -9.9622e-01,  9.2987e-05,
          2.4611e-02,  1.0910e-01,  1.6130e-02,  1.2811e-02,  9.1156e-02,
         -1.9285e-02,  9.3518e-03,  1.0188e-01,  3.8124e-02,  2.0091e-02,
          8.4244e-02, -4.6906e-03,  8.9083e-03, -2.5855e-02,  5.9413e-03,
          9.4363e-03, -2.6273e-02, -6.1056e-03,  4.4214e-03, -2.1119e-02,
          6.9455e-03, -1.0700e-03, -6.9315e-03, -2.3088e-01, -3.0316e-02,
          5.0119e-01,  2.5405e-01, -1.0623e-02,  5.2193e-01, -2.1492e-01,
          5.4053e-02,  4.6472e-01,  2.5415e-01,  3.8651e-02,  4.8790e-01,
         -2.9012e-03, -7.5254e-04, -1.1815e-04]),
 tensor([0.3682, 0.1376, 0.1472, 0.8580, 0.4785, 0.5599, 0.0573, 0.0689, 0.0445,
         0.1033, 0.1173, 0.1269, 0.1088, 0.1202, 0.1283, 0.0904, 0.1269, 0.1364,
         0.0931, 0.1213, 0.1330, 1.9139, 2.7987, 3.6095, 1.9457, 2.8429, 3.8450,
         1.9158, 3.0908, 3.5546, 1.9811, 

In [33]:

# Define the output ONNX file path
onnx_file_path = "mysai.onnx"

# Export the model
torch.onnx.export(
    th_policy_network,                  # Model to export
    input,            # Dummy input
    onnx_file_path,         # Output file path
    export_params=True,     # Export model parameters (weights)
    opset_version=11,       # ONNX opset version (e.g., 11 is widely supported)
    do_constant_folding=True,  # Optimize the model by folding constants
    input_names=["state"],  # Input tensor name
    output_names=["actions"],  # Output tensor name
)

print(f"Model exported to {onnx_file_path}")

Model exported to mysai.onnx
