In [8]:
from __future__ import annotations

import os
import warnings

import numpy as np
import torch
import torch.nn as nn
from rsl_rl.utils import resolve_nn_activation
from torch.distributions import Normal

class Memory(torch.nn.Module):
    """Recurrent encoder: a GRU/LSTM followed by a single linear projection.

    The recurrent hidden state is cached on the module (``self.hidden_states``)
    between calls to :meth:`forward`, so successive calls continue from the
    previous state until :meth:`reset` clears all or part of it.
    """

    def __init__(
        self,
        input_size,
        type="gru",
        num_layers=1,
        hidden_size=256,
        mlp_dim=64,
        enc_activation=True,
    ):
        """Build the recurrent layer and the projection head.

        Args:
            input_size: Feature size of the vectors fed to the RNN.
            type: ``"gru"`` (case-insensitive) selects ``nn.GRU``; any other
                value selects ``nn.LSTM``.
            num_layers: Number of stacked recurrent layers.
            hidden_size: Width of the recurrent hidden state.
            mlp_dim: Output width of the linear projection applied to the RNN output.
            enc_activation: If true, an ``nn.ELU`` follows the projection.
        """
        super().__init__()
        # RNN
        rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        # None means "start from the RNN's default initial state" on the next forward().
        self.hidden_states = None

        mlp_layers = [nn.Linear(hidden_size, mlp_dim)]
        if enc_activation:
            mlp_layers.append(nn.ELU())
        self.mlp_enc = nn.Sequential(*mlp_layers)

    def forward(self, input):
        """Advance the RNN by one step and return the projected output.

        A leading length-1 dimension is added to ``input`` before the RNN call
        and removed from the result again; the updated hidden state is stored
        on the module.
        """
        # presumably `input` is (num_envs, input_size) — TODO confirm against callers
        hid_out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
        out = self.mlp_enc(hid_out).squeeze(0)
        return out

    def reset(self, dones=None, hidden_states=None):
        """Reset the cached hidden state, fully or only for finished environments.

        Args:
            dones: ``None`` to replace the whole state; otherwise a mask whose
                entries equal to 1 select the batch entries to zero out.
            hidden_states: Replacement state used when ``dones`` is ``None``
                (``None`` clears the state).

        Raises:
            NotImplementedError: If both ``dones`` and ``hidden_states`` are
                given while a hidden state is cached; per-environment resets to
                custom values are not supported.
        """
        if dones is None:  # reset all hidden states
            self.hidden_states = hidden_states
        elif self.hidden_states is not None:  # reset hidden states of done environments
            if hidden_states is not None:
                # The exception used to be constructed without `raise`, which
                # turned this unsupported call into a silent no-op.
                raise NotImplementedError(
                    "Resetting hidden states of done environments with custom hidden states is not implemented"
                )
            if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
                for hidden_state in self.hidden_states:
                    hidden_state[..., dones == 1, :] = 0.0
            else:
                self.hidden_states[..., dones == 1, :] = 0.0

    def detach_hidden_states(self, dones=None):
        """Detach the cached hidden state from the autograd graph.

        Args:
            dones: ``None`` to detach everything; otherwise a mask whose entries
                equal to 1 select the batch entries to detach.
        """
        if self.hidden_states is None:
            return
        if dones is None:  # detach all hidden states
            if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
                self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
            else:
                self.hidden_states = self.hidden_states.detach()
        else:  # detach hidden states of done environments
            # NOTE(review): the detached slice is written back in place into the
            # same tensor; confirm this gives the intended partial-detach semantics.
            if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
                for hidden_state in self.hidden_states:
                    hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
            else:
                self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()


class Recurrent_o1_StudentTeacher(nn.Module):
    """Student/teacher network pair in which the student uses a recurrent encoder.

    Sub-modules built by the constructor:

    * ``student_enc``: :class:`Memory` (RNN + projection) over the student observations.
    * ``student``: MLP over ``[encoding, student observations]`` producing ``num_actions`` outputs.
    * ``teacher_enc``: MLP over the ``num_teacher_obs - num_student_obs`` teacher-only inputs.
    * ``teacher``: MLP over ``[encoding, student observations]`` producing ``num_actions`` outputs.

    Both encoders end in ``teacher_enc_dims[-1]`` features, so the student and
    teacher MLPs share the same input width.
    """

    is_recurrent = True
    # Default of `rnn_hidden_dim`; used to detect whether the caller left it untouched.
    _DEFAULT_RNN_HIDDEN_DIM = 128

    def __init__(
        self,
        num_student_obs,
        num_teacher_obs,
        num_actions,
        student_hidden_dims=[256, 256, 256],
        teacher_hidden_dims=[256, 256, 256],
        activation="elu",
        rnn_type="gru",
        rnn_hidden_dim=_DEFAULT_RNN_HIDDEN_DIM,
        rnn_num_layers=1,
        init_noise_std=0.1,
        teacher_enc_dims=[128, 64],
        enc_activation=True,
        **kwargs,
    ):
        """Construct the student and teacher networks.

        Args:
            num_student_obs: Size of the student observation vector.
            num_teacher_obs: Size of the teacher observation vector; the first
                teacher-encoder layer takes ``num_teacher_obs - num_student_obs`` inputs.
            num_actions: Output size of both the student and the teacher MLP.
            student_hidden_dims: Hidden layer widths of the student MLP.
            teacher_hidden_dims: Hidden layer widths of the teacher MLP.
            activation: Activation name resolved through ``resolve_nn_activation``.
            rnn_type: Recurrent cell type forwarded to :class:`Memory`.
            rnn_hidden_dim: Hidden size of the student RNN.
            rnn_num_layers: Number of stacked recurrent layers.
            init_noise_std: Initial value of every entry of the ``std`` parameter.
            teacher_enc_dims: Layer widths of the teacher encoder; the last entry
                is also the output width of the student encoder.
            enc_activation: Whether the final encoder layer (student and teacher)
                is followed by an activation.
            **kwargs: ``rnn_hidden_size`` is accepted as a deprecated alias of
                ``rnn_hidden_dim``; anything else is reported and ignored.
        """
        super().__init__()
        if "rnn_hidden_size" in kwargs:
            warnings.warn(
                "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
                "Please use `rnn_hidden_dim` instead.",
                DeprecationWarning,
            )
            # Always consume the alias so it is not also reported as unexpected below.
            legacy_hidden_size = kwargs.pop("rnn_hidden_size")
            # Only override if the new argument is at its default. The check used to
            # compare against 256 although the default is 128, so it never matched.
            if rnn_hidden_dim == self._DEFAULT_RNN_HIDDEN_DIM:
                rnn_hidden_dim = legacy_hidden_size
        if kwargs:
            print(
                "StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: "
                + str(kwargs.keys()),
            )

        activation = resolve_nn_activation(activation)

        self.num_student_obs = num_student_obs
        self.num_teacher_obs = num_teacher_obs
        self.rnn_hidden_dim = rnn_hidden_dim

        latent_dim = teacher_enc_dims[-1]
        policy_input_dim = latent_dim + self.num_student_obs

        # student
        self.student_enc = Memory(
            num_student_obs,
            type=rnn_type,
            num_layers=rnn_num_layers,
            hidden_size=rnn_hidden_dim,
            mlp_dim=latent_dim,
            enc_activation=enc_activation,
        )
        self.student = self._build_policy_mlp(policy_input_dim, student_hidden_dims, num_actions, activation)

        # teacher
        teacher_enc_layers = [
            nn.Linear(self.num_teacher_obs - self.num_student_obs, teacher_enc_dims[0]),
            activation,
        ]
        last_enc_index = len(teacher_enc_dims) - 2
        for layer_index in range(len(teacher_enc_dims) - 1):
            teacher_enc_layers.append(nn.Linear(teacher_enc_dims[layer_index], teacher_enc_dims[layer_index + 1]))
            # Intermediate layers are always activated; the last one only on request.
            if layer_index != last_enc_index or enc_activation:
                teacher_enc_layers.append(activation)
        self.teacher_enc = nn.Sequential(*teacher_enc_layers)
        self.teacher_enc.eval()

        self.teacher = self._build_policy_mlp(policy_input_dim, teacher_hidden_dims, num_actions, activation)
        self.teacher.eval()

        student_param_count = sum(p.numel() for p in self.student_enc.parameters()) + sum(
            p.numel() for p in self.student.parameters()
        )
        teacher_param_count = sum(p.numel() for p in self.teacher_enc.parameters()) + sum(
            p.numel() for p in self.teacher.parameters()
        )
        # Labels are kept verbatim so the printed summary matches earlier runs.
        print(f"Student CNN: {self.student_enc}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {student_param_count}\n")
        print(f"Teacher Encoder: {self.teacher_enc}")
        print(f"Teacher MLP: {self.teacher}")
        print(f"Teacher parameters: {teacher_param_count}")

        # action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # disable args validation for speedup; this must be a call, because
        # assigning to the attribute would replace the static method itself.
        Normal.set_default_validate_args(False)

    @staticmethod
    def _build_policy_mlp(input_dim, hidden_dims, output_dim, activation):
        """Return ``Linear -> act -> ... -> Linear(output_dim)`` with the given hidden widths."""
        layers = [nn.Linear(input_dim, hidden_dims[0]), activation]
        final_index = len(hidden_dims) - 1
        for layer_index, width in enumerate(hidden_dims):
            if layer_index == final_index:
                layers.append(nn.Linear(width, output_dim))
            else:
                layers.append(nn.Linear(width, hidden_dims[layer_index + 1]))
                layers.append(activation)
        return nn.Sequential(*layers)

In [12]:
# Network configuration used for the conversion below.
# An earlier variant of this cell used num_student_obs=45, num_teacher_obs=280
# and num_actions=12 with all other settings identical.
student_teacher_kwargs = {
    "num_student_obs": 59,
    "num_teacher_obs": 294,
    "num_actions": 16,
    "student_hidden_dims": [256, 128, 64],
    "teacher_hidden_dims": [256, 128, 64],
    "activation": "elu",
    "rnn_type": 'gru',
    "rnn_hidden_dim": 256,
    "teacher_enc_dims": [256, 128, 64],
    "enc_activation": False,
}

# The constructor prints a summary of both networks as a side effect.
student_teacher = Recurrent_o1_StudentTeacher(**student_teacher_kwargs)

Student CNN: Memory(
  (rnn): GRU(59, 256)
  (mlp_enc): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
  )
)
Student MLP: Sequential(
  (0): Linear(in_features=123, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=64, out_features=16, bias=True)
)
Student parameters: 333840

Teacher Encoder: Sequential(
  (0): Linear(in_features=235, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
)
Teacher MLP: Sequential(
  (0): Linear(in_features=123, out_features=256, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=256, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=64, bias=True)
  (5)

In [13]:
# Teacher checkpoint whose weights are transplanted into `student_teacher`.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = '/home/aivizw/tema_lab/logs/rsl_rl/go2_velocity_rma_v3_ftg_flat/2025-09-25_15-41-46_teacher/model_14000.pt'
# map_location="cpu" lets the checkpoint load on machines without the device it
# was saved from; `student_teacher` is built on the CPU, so nothing is lost.
# weights_only=True restricts unpickling to tensors and plain containers.
teacher_state_dict = torch.load(path, map_location="cpu", weights_only=True)['model_state_dict']

In [14]:
# Start from the freshly constructed model's own tensors so that every entry the
# teacher checkpoint does not provide keeps its initial value.
model_state_dict = dict(student_teacher.state_dict())

# Checkpoint entries are matched by key prefix (not substring) so that only the
# leading module name is rewritten.
num_actor_tensors = 0
num_encoder_tensors = 0
for key, value in teacher_state_dict.items():
    if key.startswith("actor."):
        layer_key = key[len("actor."):]
        # The teacher MLP and the student MLP both receive the same actor weights.
        model_state_dict["teacher." + layer_key] = value
        model_state_dict["student." + layer_key] = value
        num_actor_tensors += 1
    elif key.startswith("actor_enc."):
        model_state_dict["teacher_enc." + key[len("actor_enc."):]] = value
        num_encoder_tensors += 1

# Without this check, a checkpoint with different key names would leave the model
# at its random initialisation and that model would then be saved unnoticed.
if num_actor_tensors == 0 or num_encoder_tensors == 0:
    raise ValueError(
        "Teacher checkpoint provided "
        f"{num_actor_tensors} 'actor.*' and {num_encoder_tensors} 'actor_enc.*' tensors; "
        "expected at least one of each."
    )

# Strict loading: unexpected keys or shape mismatches raise here.
student_teacher.load_state_dict(model_state_dict)

saved_dict = {
    "model_state_dict": student_teacher.state_dict()
}

In [15]:
# Build the output name by inserting a tag before the file extension only.
# The earlier `path.replace('.pt', ...)` rewrote every ".pt" occurrence in the
# path and, when none was present, returned the teacher path unchanged so that
# the save overwrote the source checkpoint.
# (An earlier naming scheme used the fixed tag "_student_101".)
ckpt_root, ckpt_ext = os.path.splitext(path)
student_ckpt_path = f"{ckpt_root}_student_rnn{student_teacher.rnn_hidden_dim}{ckpt_ext}"
torch.save(saved_dict, student_ckpt_path)