In [2]:
import torch
import torch.nn as nn
from rsl_rl.modules import (
    ActorCritic,
    ActorCriticRecurrent,
    EmpiricalNormalization,
    StudentTeacher,
    StudentTeacherRecurrent,
)

In [9]:
# Hyperparameters of the student/teacher distillation model, gathered in one
# place so they can be inspected or tweaked before the model is built.
student_teacher_config = {
    # Observation / action sizes.
    "num_student_obs": 45,
    "num_teacher_obs": 396,
    "num_actions": 12,
    "init_noise_std": 0.1,
    # MLP layouts (hidden layer widths) and their activation.
    "student_hidden_dims": [128, 128],
    "teacher_hidden_dims": [128, 128, 128],
    "activation": "elu",
    # Recurrent memory settings; only the student is recurrent here.
    "rnn_type": 'lstm',
    "rnn_hidden_dim": 128,
    "rnn_num_layers": 1,
    "teacher_recurrent": False,
}

# Constructing the model prints the resulting architecture (see the cell
# output): the student LSTM consumes the 45-dim observation and its 128-dim
# hidden state feeds the student MLP, while the teacher MLP consumes the
# 396-dim observation directly.
model = StudentTeacherRecurrent(**student_teacher_config)

Student MLP: Sequential(
  (0): Linear(in_features=128, out_features=128, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=12, bias=True)
)
Teacher MLP: Sequential(
  (0): Linear(in_features=396, out_features=128, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=128, out_features=128, bias=True)
  (5): ELU(alpha=1.0)
  (6): Linear(in_features=128, out_features=12, bias=True)
)
Student RNN: Memory(
  (rnn): LSTM(45, 128)
)


In [10]:
# Source checkpoints. Only their 'model_state_dict' entry is used.
#   - RNN_ACTOR_CRITIC_CKPT: provides the 'actor.*' and 'memory_a.*' weights
#     that a later cell copies into the student.
#   - TEACHER_CKPT: provides the 'actor.*' weights that a later cell copies
#     into the teacher.
# NOTE(review): these are absolute, machine-specific paths — adjust them (or
# derive them from a configurable log directory) before running elsewhere.
RNN_ACTOR_CRITIC_CKPT = '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_async_flat_rnn/2025-05-09_16-40-23/model_6000.pt'
TEACHER_CKPT = '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_rma_flat/2025-05-08_02-51-09_teacher/model_43000.pt'

# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from: without it, tensors saved on a GPU fail to deserialize on a
# machine without CUDA and otherwise occupy GPU memory needlessly. The tensors
# are only used as a source for load_state_dict(), which copies values into
# the model's own parameters, so their device does not affect the result.
# weights_only=True restricts unpickling to tensors and plain containers.
rnn_actor_critic = torch.load(
    RNN_ACTOR_CRITIC_CKPT, map_location='cpu', weights_only=True
)['model_state_dict']
teacher_state_dict = torch.load(
    TEACHER_CKPT, map_location='cpu', weights_only=True
)['model_state_dict']

In [11]:
# Build the state dict for `model` by transplanting pretrained weights:
#   student.*   <- rnn_actor_critic['actor.*']
#   memory_s.*  <- rnn_actor_critic['memory_a.*']
#   teacher.*   <- teacher_state_dict['actor.*']
# Each entry is (prefix in `model`, prefix in the source, source state dict).
weight_sources = (
    ('student.', 'actor.', rnn_actor_critic),
    ('memory_s.', 'memory_a.', rnn_actor_critic),
    ('teacher.', 'actor.', teacher_state_dict),
)

# Start from the model's own (freshly initialised) entries so that every key
# the model expects is present. Keys that match none of the prefixes keep
# these initial values.
# NOTE(review): keys outside the three prefixes (e.g. a top-level 'std', if the
# model has one) are therefore NOT taken from either checkpoint — confirm this
# is intended.
new_dict = dict(model.state_dict())

# Iterate over a snapshot of the keys: the dict is written to inside the loop.
for key in list(new_dict):
    for target_prefix, source_prefix, source_dict in weight_sources:
        # Match on the key's prefix only (not anywhere in the name) and swap
        # just that prefix, so a parameter name that merely contains e.g.
        # 'teacher.' further down can never be remapped by accident.
        if not key.startswith(target_prefix):
            continue
        source_key = source_prefix + key[len(target_prefix):]
        if source_key not in source_dict:
            raise KeyError(
                f"cannot fill '{key}': '{source_key}' is missing from the source checkpoint"
            )
        new_dict[key] = source_dict[source_key]
        break  # a key belongs to at most one sub-module

In [12]:
# Copy the assembled weights into the model. The call's return value is the
# cell output (True in the recorded run).
# NOTE(review): presumably the rsl_rl override returns a bool rather than
# torch's usual key-matching report — confirm what True signifies.
model.load_state_dict(new_dict)

True

In [13]:
# Wrap the merged weights under the 'model_state_dict' key, i.e. the same key
# this notebook reads when it loads checkpoints.
saved_dict = dict(model_state_dict=model.state_dict())

In [14]:
# Write the merged checkpoint next to the teacher run it was derived from.
pretrained_checkpoint_path = '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_rma_flat/2025-05-08_02-51-09_teacher/pretrained_rnn.pt'
torch.save(saved_dict, pretrained_checkpoint_path)

In [10]:
# Inspect the weights stored in a pretrained checkpoint.
# NOTE(review): this reads 'pretrained_ts.pt' from a different home directory
# than the 'pretrained_rnn.pt' file written earlier in this notebook, and the
# cell's execution count is out of sequence — confirm which file is meant.
inspected_checkpoint = torch.load('/home/aivizw/Downloads/pretrained_ts.pt', weights_only=True)
inspected_checkpoint['model_state_dict']

OrderedDict([('std',
              tensor([0.3512, 0.4535, 0.3805, 0.3540, 0.4551, 0.3816, 0.3438, 0.3995, 0.3711,
                      0.3449, 0.3939, 0.3768])),
             ('student.0.weight',
              tensor([[ 0.0404,  0.1344, -0.0347,  ..., -0.1187,  0.0727, -0.1436],
                      [ 0.2064, -0.0596, -0.0717,  ..., -0.0331,  0.1622, -0.1386],
                      [ 0.0929,  0.0191,  0.1609,  ..., -0.2139, -0.0409,  0.0605],
                      ...,
                      [-0.2613,  0.0813, -0.0223,  ...,  0.0268,  0.2091, -0.1658],
                      [ 0.0560,  0.1813,  0.0370,  ...,  0.0993, -0.0481,  0.0192],
                      [-0.1768,  0.2862, -0.0030,  ...,  0.0944, -0.0012, -0.0919]])),
             ('student.0.bias',
              tensor([-8.3863e-03, -2.1921e-01, -9.5453e-02, -1.7385e-02, -1.1340e-01,
                      -6.2723e-02, -1.4344e-01, -3.9621e-02, -1.2094e-01, -1.2916e-02,
                      -2.3050e-01, -1.4431e-01, -1.5297e-01,  