In [1]:
from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal

from rsl_rl.utils import resolve_nn_activation


class ActorCritic_o1(nn.Module):
    is_recurrent = False

    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        actor_hidden_dims,
        critic_hidden_dims,
        activation="elu",
        init_noise_std=1.0,
        noise_std_type: str = "scalar",
        enc_dims=[128, 64],
        len_o1=48,
        enc_activation=True,
        **kwargs,
    ):
        if kwargs:
            print(
                "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        super().__init__()
        activation = resolve_nn_activation(activation)

        self.len_obs = num_actor_obs
        self.len_o1 = len_o1

        # Policy
        actor_enc_layers = []
        actor_enc_layers.append(nn.Linear(self.len_obs - self.len_o1, enc_dims[0]))
        actor_enc_layers.append(activation)
        for layer_index in range(len(enc_dims) - 1):
            actor_enc_layers.append(nn.Linear(enc_dims[layer_index], enc_dims[layer_index + 1]))
            if layer_index != len(enc_dims) - 2:
                actor_enc_layers.append(activation)
            elif enc_activation:
                actor_enc_layers.append(activation)
                
        self.actor_enc = nn.Sequential(*actor_enc_layers)

        actor_layers = []
        actor_layers.append(nn.Linear(enc_dims[-1] + self.len_o1, actor_hidden_dims[0]))
        actor_layers.append(activation)
        for layer_index in range(len(actor_hidden_dims)):
            if layer_index == len(actor_hidden_dims) - 1:
                actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
            else:
                actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
                actor_layers.append(activation)
        self.actor = nn.Sequential(*actor_layers)

        # Value function
        critic_enc_layers = []
        critic_enc_layers.append(nn.Linear(self.len_obs - self.len_o1, enc_dims[0]))
        critic_enc_layers.append(activation)
        for layer_index in range(len(enc_dims) - 1):
            critic_enc_layers.append(nn.Linear(enc_dims[layer_index], enc_dims[layer_index + 1]))
            if layer_index != len(enc_dims) - 2:
                critic_enc_layers.append(activation)
            elif enc_activation:
                critic_enc_layers.append(activation)

        self.critic_enc = nn.Sequential(*critic_enc_layers)

        critic_layers = []
        critic_layers.append(nn.Linear(enc_dims[-1] + self.len_o1, critic_hidden_dims[0]))
        critic_layers.append(activation)
        for layer_index in range(len(critic_hidden_dims)):
            if layer_index == len(critic_hidden_dims) - 1:
                critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
            else:
                critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
                critic_layers.append(activation)
        self.critic = nn.Sequential(*critic_layers)

        print(f"Actor Encoder: {self.actor_enc}")
        print(f"Actor MLP: {self.actor}")
        print(f"Actor parameters: {sum([p.numel() for p in self.actor.parameters()]) + sum([p.numel() for p in self.actor_enc.parameters()])}\n")
        print(f"Critic Encoder: {self.critic_enc}")
        print(f"Critic MLP: {self.critic}")
        print(f"Critic parameters: {sum([p.numel() for p in self.critic.parameters()]) + sum([p.numel() for p in self.critic_enc.parameters()])}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    @staticmethod
    # not used at the moment
    def init_weights(sequential, scales):
        [
            torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
            for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
        ]

    def reset(self, dones=None):
        pass

    def forward(self):
        raise NotImplementedError

    @property
    def action_mean(self):
        return self.distribution.mean

    @property
    def action_std(self):
        return self.distribution.stddev

    @property
    def entropy(self):
        return self.distribution.entropy().sum(dim=-1)

    def update_distribution(self, observations):
        # compute mean
        o_t = observations[:, :self.len_o1]
        x_t = observations[:, self.len_o1:]
        l_t = self.actor_enc(x_t)
        mean = self.actor(torch.cat((o_t, l_t), dim=1))
        # compute standard deviation
        if self.noise_std_type == "scalar":
            std = self.std.expand_as(mean)
        elif self.noise_std_type == "log":
            std = torch.exp(self.log_std).expand_as(mean)
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        # create distribution
        self.distribution = Normal(mean, std)

    def act(self, observations, **kwargs):
        self.update_distribution(observations)
        return self.distribution.sample()

    def get_actions_log_prob(self, actions):
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, observations):
        o_t = observations[:, :self.len_o1]
        x_t = observations[:, self.len_o1:]
        l_t = self.actor_enc(x_t)
        actions_mean = self.actor(torch.cat((o_t, l_t), dim=1))
        return actions_mean

    def evaluate(self, critic_observations, **kwargs):
        o_t = critic_observations[:, :self.len_o1]
        x_t = critic_observations[:, self.len_o1:]
        l_t = self.critic_enc(x_t)
        value = self.critic(torch.cat((o_t, l_t), dim=1))
        return value

    def load_state_dict(self, state_dict, strict=True):
        super().load_state_dict(state_dict, strict=strict)
        return True


In [2]:
model = ActorCritic_o1(
    num_actor_obs=280,
    num_critic_obs=280,
    num_actions=12, 
    actor_hidden_dims=[256, 128, 64],
    critic_hidden_dims=[256, 128, 64],
    enc_dims=[256, 128, 64],
    len_o1=45,
    enc_activation=False
)

Actor Encoder: Sequential(
  (0): Linear(in_features=235, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
)
Actor MLP: Sequential(
  (0): Linear(in_features=109, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=64, out_features=12, bias=True)
)
Actor parameters: 171660

Critic Encoder: Sequential(
  (0): Linear(in_features=235, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
)
Critic MLP: Sequential(
  (0): Linear(in_features=109, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_

In [3]:
# load_state = torch.load('/home/aivizw/IsaacLab_5.0.0/logs/rsl_rl/go2_velocity_rma_v3_flat_v2/2025-09-11_11-46-58_teacher/model_39000.pt')  
# load_state['model_state_dict']["actor.0.weight"] = torch.cat((load_state['model_state_dict']["actor.0.weight"][:, :3], load_state['model_state_dict']["actor.0.weight"][:, 6:]), dim=1)
# load_state['model_state_dict']["critic.0.weight"] = torch.cat((load_state['model_state_dict']["critic.0.weight"][:, :3], load_state['model_state_dict']["critic.0.weight"][:, 6:]), dim=1)
# model.load_state_dict(load_state['model_state_dict'])
# load_state.pop("optimizer_state_dict")

load_state = torch.load('/home/aivizw/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_flat/2025-09-23_13-29-14_teacher/model_47000.pt')  
load_state.pop("optimizer_state_dict")

{'state': {0: {'step': tensor(320080.),
   'exp_avg': tensor([ 1.9521e-03, -2.2447e-03, -8.8397e-04,  2.3337e-04, -3.0605e-04,
            4.8897e-05, -1.5588e-04,  4.8077e-03, -2.0507e-03, -4.2651e-03,
           -1.6100e-04,  3.6352e-04], device='cuda:0'),
   'exp_avg_sq': tensor([3.4074e-04, 1.0350e-04, 1.1261e-04, 2.8901e-04, 1.0013e-04, 1.2299e-04,
           3.2847e-04, 1.5093e-04, 9.6700e-05, 3.1919e-04, 1.4540e-04, 1.0989e-04],
          device='cuda:0')},
  1: {'step': tensor(320080.),
   'exp_avg': tensor([[ 2.1984e-05, -1.1238e-05, -3.9152e-06,  ...,  1.6795e-05,
             9.5483e-06,  1.1016e-06],
           [ 1.3241e-05,  9.0679e-06, -1.5492e-06,  ..., -2.4441e-05,
            -1.9891e-05,  5.3387e-07],
           [-1.4483e-05,  1.2523e-06,  2.4870e-07,  ..., -1.0342e-04,
            -3.4190e-05, -7.4608e-07],
           ...,
           [-1.6481e-06,  6.2552e-06, -1.4977e-06,  ...,  1.8388e-05,
             6.2045e-06, -2.8542e-07],
           [ 4.2133e-05, -8.6492e-06,

In [26]:
load_state.keys()

dict_keys(['model_state_dict', 'iter', 'infos'])

In [4]:
torch.save(load_state, "/home/aivizw/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_flat/2025-09-23_13-29-14_teacher/model_47000_1.pt")