In [2]:
from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal

from rsl_rl.utils import resolve_nn_activation
from math import sqrt


class ActorCritic_o1(nn.Module):
    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        actor_hidden_dims,
        critic_hidden_dims,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        enc_dims=[128, 64],
        len_o1=48,
        enc_activation=True,
        **kwargs,
    ):
        if kwargs:
            print(
                "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        super().__init__()
        activation = resolve_nn_activation(activation)

        self.len_obs = num_actor_obs
        self.len_o1 = len_o1

        # Policy
        actor_enc_layers = []
        actor_enc_layers.append(self.layer_init(nn.Linear(self.len_obs - self.len_o1, enc_dims[0])))
        actor_enc_layers.append(activation)
        for layer_index in range(len(enc_dims) - 1):
            actor_enc_layers.append(self.layer_init(nn.Linear(enc_dims[layer_index], enc_dims[layer_index + 1])))
            if layer_index != len(enc_dims) - 2:
                actor_enc_layers.append(activation)
            elif enc_activation:
                actor_enc_layers.append(activation)
                
        self.actor_enc = nn.Sequential(*actor_enc_layers)

        actor_layers = []
        actor_layers.append(self.layer_init(nn.Linear(enc_dims[-1] + self.len_o1, actor_hidden_dims[0])))
        actor_layers.append(activation)
        for layer_index in range(len(actor_hidden_dims)):
            if layer_index == len(actor_hidden_dims) - 1:
                actor_layers.append(self.layer_init(nn.Linear(actor_hidden_dims[layer_index], num_actions), std=1.0))
            else:
                actor_layers.append(self.layer_init(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1])))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_enc_layers = []
        critic_enc_layers.append(self.layer_init(nn.Linear(self.len_obs - self.len_o1, enc_dims[0])))
        critic_enc_layers.append(activation)
        for layer_index in range(len(enc_dims) - 1):
            critic_enc_layers.append(self.layer_init(nn.Linear(enc_dims[layer_index], enc_dims[layer_index + 1])))
            if layer_index != len(enc_dims) - 2:
                critic_enc_layers.append(activation)
            elif enc_activation:
                critic_enc_layers.append(activation)

        self.critic_enc = nn.Sequential(*critic_enc_layers)

        critic_layers = []
        critic_layers.append(self.layer_init(nn.Linear(enc_dims[-1] + self.len_o1, critic_hidden_dims[0])))
        critic_layers.append(activation)
        for layer_index in range(len(critic_hidden_dims)):
            if layer_index == len(critic_hidden_dims) - 1:
                critic_layers.append(self.layer_init(nn.Linear(critic_hidden_dims[layer_index], 1), std=1.0))
            else:
                critic_layers.append(self.layer_init(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1])))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor Encoder: {self.actor_enc}")
        print(f"Actor MLP: {self.actor}")
        print(f"Actor parameters: {sum([p.numel() for p in self.actor.parameters()]) + sum([p.numel() for p in self.actor_enc.parameters()])}\n")
        print(f"Critic Encoder: {self.critic_enc}")
        print(f"Critic MLP: {self.critic}")
        print(f"Critic parameters: {sum([p.numel() for p in self.critic.parameters()]) + sum([p.numel() for p in self.critic_enc.parameters()])}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    @staticmethod
    def layer_init(layer, std=sqrt(2), bias_const=0.0):
        torch.nn.init.orthogonal_(layer.weight, std)
        torch.nn.init.constant_(layer.bias, bias_const)
        return layer

    def reset(self, dones=None):
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        return self.distribution.mean

    @property
    def action_std(self):
        return self.distribution.stddev

    @property
    def entropy(self):
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        # compute mean
        o_t = observations[:, :self.len_o1]
        x_t = observations[:, self.len_o1:]
        l_t = self.actor_enc(x_t)
        mean = self.actor(torch.cat((o_t, l_t), dim=1))
        # compute standard deviation
        if self.noise_std_type == "scalar":
            std = self.std.expand_as(mean)
        elif self.noise_std_type == "log":
            std = torch.exp(self.log_std).expand_as(mean)
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        # create distribution
        self.distribution = Normal(mean, std)

    def act(self, observations, **kwargs):
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        o_t = observations[:, :self.len_o1]
        x_t = observations[:, self.len_o1:]
        l_t = self.actor_enc(x_t)
        actions_mean = self.actor(torch.cat((o_t, l_t), dim=1))
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        o_t = critic_observations[:, :self.len_o1]
        x_t = critic_observations[:, self.len_o1:]
        l_t = self.critic_enc(x_t)
        value = self.critic(torch.cat((o_t, l_t), dim=1))
        return value

    def load_state_dict(self, state_dict, strict=True):
        super().load_state_dict(state_dict, strict=strict)
        return True


In [3]:
model = ActorCritic_o1(
    num_actor_obs=208,
    num_critic_obs=208,
    num_actions=12, 
    actor_hidden_dims=[256, 128, 64],
    critic_hidden_dims=[256, 128, 64],
    enc_dims=[256, 128, 64],
    len_o1=45,
    enc_activation=False
)

Actor Encoder: Sequential(
  (0): Linear(in_features=163, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
)
Actor MLP: Sequential(
  (0): Linear(in_features=109, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=64, out_features=12, bias=True)
)
Actor parameters: 153228

Critic Encoder: Sequential(
  (0): Linear(in_features=163, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
)
Critic MLP: Sequential(
  (0): Linear(in_features=109, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_

In [12]:
load_state = torch.load('/home/tema/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_rough/2025-09-25_17-41-36_teacher/model_250000.pt')  
load_state['model_state_dict']["actor_enc.0.weight"] = load_state['model_state_dict']["actor_enc.0.weight"][:, :-72]
load_state['model_state_dict']["critic_enc.0.weight"] = load_state['model_state_dict']["critic_enc.0.weight"][:, :-72]
model.load_state_dict(load_state['model_state_dict'])
load_state.pop("optimizer_state_dict")

{'state': {0: {'step': tensor(5260160.),
   'exp_avg': tensor([ 0.0013,  0.0005, -0.0012, -0.0010, -0.0014,  0.0006, -0.0005,  0.0014,
           -0.0006,  0.0007,  0.0020, -0.0008], device='cuda:0'),
   'exp_avg_sq': tensor([1.5564e-04, 4.7150e-05, 5.8223e-05, 1.4310e-04, 4.5348e-05, 6.3191e-05,
           1.7995e-04, 6.8655e-05, 4.3499e-05, 1.4753e-04, 6.5041e-05, 6.2744e-05],
          device='cuda:0')},
  1: {'step': tensor(5260160.),
   'exp_avg': tensor([[-7.4939e-07,  1.1228e-06,  7.9082e-08,  ..., -1.0546e-06,
             1.3124e-06,  8.7981e-08],
           [-5.3547e-06,  1.5503e-06, -3.9382e-07,  ...,  1.2440e-05,
            -1.9981e-06,  1.2944e-07],
           [-2.7471e-06,  2.3899e-06, -6.6121e-07,  ..., -1.7493e-05,
            -1.3125e-05, -4.8412e-07],
           ...,
           [ 1.1975e-06, -1.3009e-06,  5.9522e-07,  ..., -5.3752e-06,
            -3.2332e-06, -2.5089e-07],
           [ 2.0615e-06,  2.8316e-07,  2.7339e-08,  ..., -1.1719e-05,
            -1.6224e-06,

In [13]:
torch.save(load_state, "/home/tema/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_rough/2025-09-25_17-41-36_teacher/model_250000_1.pt")

In [6]:
list(model.actor_enc.parameters()) + list(model.actor.parameters())

[Parameter containing:
 tensor([[ 0.1281, -0.1141,  0.0232,  ...,  0.0280,  0.1828, -0.0274],
         [ 0.0634,  0.0262, -0.0042,  ...,  0.0053, -0.0101,  0.0944],
         [ 0.0772,  0.0512,  0.2252,  ...,  0.0072, -0.0374, -0.0072],
         ...,
         [ 0.0308,  0.1775,  0.0351,  ...,  0.0609,  0.0321,  0.0171],
         [ 0.0019,  0.0226, -0.0826,  ..., -0.0429,  0.0054,  0.0077],
         [-0.2465,  0.0017,  0.1208,  ...,  0.0340,  0.0708, -0.0299]],
        requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [7]:
[p.numel() for p in list(model.actor_enc.parameters()) + list(model.actor.parameters())]

[41728, 256, 32768, 128, 8192, 64, 27904, 256, 32768, 128, 8192, 64, 768, 12]

In [14]:
from isaaclab.utils import configclass

In [29]:
@configclass
class Data:
    a: list[int] = [1, 2]

In [41]:
@configclass
class Data2(Data):
    def __post_init__(self):
        self.a[0] = 0

In [42]:
cfg = Data2()

In [43]:
cfg.a

[0, 2]