In [15]:
from __future__ import annotations
import warnings

import torch
import torch.nn as nn
from torch.distributions import Normal
from rsl_rl.utils import resolve_nn_activation
import os

class Memory(torch.nn.Module):
    def __init__(
        self, 
        input_size, 
        type="gru", 
        num_layers=1, 
        hidden_size=256,
        mlp_dim=64,
        enc_activation=True
    ):
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        
        mlp_layers = []
        mlp_layers.append(nn.Linear(hidden_size, mlp_dim))
        if enc_activation:
            mlp_layers.append(nn.ELU())
        self.mlp_enc = nn.Sequential(*mlp_layers)
        
        self.reset()

    def forward(self, input):
        hid_out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        out = self.mlp_enc(hid_out).squeeze(0)
        return out

    def reset(self):
        self.hidden_states = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)


class Recurrent_o1_Student(nn.Module):
    is_recurrent = True

    def __init__(
        self,
        num_obs,
        num_actions,
        student_hidden_dims=[256, 256, 256],
        activation="elu",
        rnn_type="gru",
        rnn_hidden_dim=128,
        rnn_num_layers=1,
        init_noise_std=0.1,
        num_latent=64,
        enc_activation=True,
        **kwargs,
    ):
        super().__init__()
        if "rnn_hidden_size" in kwargs:
            warnings.warn(
                "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
                "Please use `rnn_hidden_dim` instead.",
                DeprecationWarning,
            )
            if rnn_hidden_dim == 256:  # Only override if the new argument is at its default
                rnn_hidden_dim = kwargs.pop("rnn_hidden_size")
        if kwargs:
            print(
                "StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: "
                + str(kwargs.keys()),
            )

        activation = resolve_nn_activation(activation)

        self.num_obs = num_obs

        # student
        self.student_enc = Memory(
            num_obs, 
            type=rnn_type, 
            num_layers=rnn_num_layers, 
            hidden_size=rnn_hidden_dim, 
            mlp_dim=num_latent,
            enc_activation=enc_activation
        )
        student_layers = []
        student_layers.append(nn.Linear(num_latent + self.num_obs, student_hidden_dims[0]))
        student_layers.append(activation)
        for layer_index in range(len(student_hidden_dims)):
            if layer_index == len(student_hidden_dims) - 1:
                student_layers.append(nn.Linear(student_hidden_dims[layer_index], num_actions))
            else:
                student_layers.append(nn.Linear(student_hidden_dims[layer_index], student_hidden_dims[layer_index + 1]))
                student_layers.append(activation)
        self.student = nn.Sequential(*student_layers)

        print(f"Student CNN: {self.student_enc}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {sum([p.numel() for p in self.student_enc.parameters()]) + sum([p.numel() for p in self.student.parameters()])}\n")

    def forward(self, x):
        l_t = self.student_enc(x)
        actions_mean = self.student(torch.cat((x, l_t), dim=1))
        return actions_mean
    
    @torch.jit.export
    def reset(self):
        self.student_enc.reset()

    def export(self, path, filename):
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, filename)
        self.to("cpu")
        traced_script_module = torch.jit.script(self)
        traced_script_module.save(path)

In [29]:
# model = Recurrent_o1_Student(
#     num_obs=45,
#     num_actions=12,
#     student_hidden_dims=[256, 128, 64],
#     rnn_type='gru',
#     rnn_hidden_dim=256,
#     rnn_num_layers=1,
#     len_latent=64,
#     enc_activation=True
# )

# model = Recurrent_o1_Student(
#     num_obs=45,
#     num_actions=12,
#     student_hidden_dims=[256, 128, 64],
#     rnn_type='gru',
#     rnn_hidden_dim=256,
#     rnn_num_layers=1,
#     len_latent=64,
#     enc_activation=False
# )

model = Recurrent_o1_Student(
    num_obs=59,
    num_actions=16,
    student_hidden_dims=[256, 128, 64],
    rnn_type='gru',
    rnn_hidden_dim=256,
    rnn_num_layers=1,
    len_latent=64,
    enc_activation=False
)

StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: dict_keys(['len_latent'])
Student CNN: Memory(
  (rnn): GRU(59, 256)
  (mlp_enc): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
  )
)
Student MLP: Sequential(
  (0): Linear(in_features=123, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=64, out_features=16, bias=True)
)
Student parameters: 333840



In [36]:
model_path = '/home/aivizw/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_ftg_flat/2025-09-25_17-02-23'
model_name = 'model_5000.pt'

In [37]:
load_state = torch.load(model_path + '/' + model_name, weights_only=True, map_location=torch.device('cpu'))
model_state_dict = {key: value for key, value in model.state_dict().items()}

for key, value in load_state['model_state_dict'].items():
    if "student" in key:
        model_state_dict[key] = value

model.load_state_dict(model_state_dict)
model.eval()

Recurrent_o1_Student(
  (student_enc): Memory(
    (rnn): GRU(59, 256)
    (mlp_enc): Sequential(
      (0): Linear(in_features=256, out_features=64, bias=True)
    )
  )
  (student): Sequential(
    (0): Linear(in_features=123, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=64, out_features=16, bias=True)
  )
)

In [38]:
model.export(model_path, model_name.replace('.pt', '_jit.pt')) 