In [1]:
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal

from rsl_rl.utils import resolve_nn_activation


class CNN1d_o1_StudentTeacher(nn.Module):
    """Student/teacher network pair used for distillation.

    The module owns four sub-networks plus a learnable action noise parameter:

    * ``cnn_student`` - a stack of ``Conv1d -> BatchNorm1d -> ReLU`` stages followed by
      ``Flatten`` and a linear projection to ``teacher_enc_dims[-1]`` features. The first
      convolution has ``len_o1`` input channels and the projection is sized for a sequence
      length of ``sum_student_obs - 1``.
    * ``student`` - an MLP mapping ``teacher_enc_dims[-1] + len_o1`` features to ``num_actions``.
    * ``teacher_enc`` - an MLP encoder mapping ``num_teacher_obs - len_o1`` features to
      ``teacher_enc_dims[-1]`` features (every linear layer is followed by the activation).
    * ``teacher`` - an MLP mapping ``teacher_enc_dims[-1] + len_o1`` features to ``num_actions``.

    Only construction is defined here; no forward/inference methods are part of this class
    as shown in this notebook cell.
    """

    is_recurrent = False

    def __init__(
        self,
        num_student_obs,
        num_teacher_obs,
        num_actions,
        student_hidden_dims=[256, 256, 256],
        teacher_hidden_dims=[256, 256, 256],
        activation="elu",
        init_noise_std=0.1,
        student_cnn_kernel_sizes: list[int] = [3, 3, 3],
        student_cnn_strides: list[int] = [3, 3, 3],
        student_cnn_filters: list[int] = [32, 16, 8],
        student_cnn_paddings: list[int] = [0, 0, 1],
        student_cnn_dilations: list[int] = [1, 1, 1],
        teacher_enc_dims: list[int] = [128, 64],
        len_o1: int = 48,
        sum_student_obs: int = 65,
        enc_activation: bool = True,
        **kwargs,
    ):
        """Build the student and teacher networks.

        Args:
            num_student_obs: Accepted for interface compatibility; not used by this constructor.
            num_teacher_obs: Size of the teacher observation. The teacher encoder consumes
                ``num_teacher_obs - len_o1`` of these features.
            num_actions: Output size of both the student and the teacher MLP.
            student_hidden_dims: Hidden layer widths of the student MLP.
            teacher_hidden_dims: Hidden layer widths of the teacher MLP.
            activation: Activation name, resolved through ``resolve_nn_activation``.
            init_noise_std: Initial value of every entry of the ``std`` parameter.
            student_cnn_kernel_sizes: Kernel size of each student convolution.
            student_cnn_strides: Stride of each student convolution.
            student_cnn_filters: Output channels of each student convolution.
            student_cnn_paddings: Padding of each student convolution.
            student_cnn_dilations: Dilation of each student convolution.
            teacher_enc_dims: Layer widths of the teacher encoder; the last entry is also the
                output width of the student CNN encoder.
            len_o1: Number of input channels of the student CNN, and the number of extra
                features concatenated to the encoding in front of both MLPs.
            sum_student_obs: The student CNN is sized for sequences of length
                ``sum_student_obs - 1``.
            enc_activation: If True, append the activation after the CNN's final projection.
            **kwargs: Ignored; their names are printed.

        Note:
            The per-layer CNN lists are combined with ``zip``, so the number of convolution
            stages equals the length of the shortest list.
        """
        if kwargs:
            print(
                "StudentTeacher.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        super().__init__()
        # A single activation module instance is shared by every layer that uses it.
        activation = resolve_nn_activation(activation)
        self.loaded_teacher = False  # indicates if teacher has been loaded

        self.len_student_obs = sum_student_obs
        self.num_teacher_obs = num_teacher_obs
        self.len_o1 = len_o1

        # ---- student: 1D CNN encoder ----
        conv_in_channels = [self.len_o1] + student_cnn_filters[:-1]
        conv_out_channels = student_cnn_filters

        cnn_layers = []
        seq_len = self.len_student_obs - 1
        for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(
            conv_in_channels,
            conv_out_channels,
            student_cnn_kernel_sizes,
            student_cnn_strides,
            student_cnn_paddings,
            student_cnn_dilations,
        ):
            cnn_layers.append(
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                )
            )
            cnn_layers.append(nn.BatchNorm1d(out_ch))
            cnn_layers.append(nn.ReLU())
            # Standard Conv1d output-length formula; tracks the length seen by the next stage.
            seq_len = (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

        cnn_layers.append(nn.Flatten())
        cnn_layers.append(nn.Linear(seq_len * conv_out_channels[-1], teacher_enc_dims[-1]))
        if enc_activation:
            cnn_layers.append(activation)
        self.cnn_student = nn.Sequential(*cnn_layers)

        # ---- student: policy MLP ----
        policy_input_dim = teacher_enc_dims[-1] + self.len_o1
        self.student = self._build_mlp(policy_input_dim, student_hidden_dims, num_actions, activation)

        # ---- teacher: encoder (activation after every linear layer, including the last) ----
        enc_widths = [self.num_teacher_obs - self.len_o1, *teacher_enc_dims]
        teacher_enc_layers = []
        for in_features, out_features in zip(enc_widths[:-1], enc_widths[1:]):
            teacher_enc_layers.append(nn.Linear(in_features, out_features))
            teacher_enc_layers.append(activation)
        self.teacher_enc = nn.Sequential(*teacher_enc_layers)
        self.teacher_enc.eval()

        # ---- teacher: policy MLP ----
        self.teacher = self._build_mlp(policy_input_dim, teacher_hidden_dims, num_actions, activation)
        self.teacher.eval()

        student_param_count = sum(p.numel() for p in self.cnn_student.parameters()) + sum(
            p.numel() for p in self.student.parameters()
        )
        teacher_param_count = sum(p.numel() for p in self.teacher_enc.parameters()) + sum(
            p.numel() for p in self.teacher.parameters()
        )
        print(f"Student CNN: {self.cnn_student}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {student_param_count}\n")
        print(f"Teacher Encoder: {self.teacher_enc}")
        print(f"Teacher MLP: {self.teacher}")
        print(f"Teacher parameters: {teacher_param_count}")

        # action noise
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None
        # Disable args validation for speedup. ``set_default_validate_args`` is a static
        # method and must be *called*; assigning ``False`` to it would replace the method
        # and leave validation enabled.
        Normal.set_default_validate_args(False)

    @staticmethod
    def _build_mlp(input_dim, hidden_dims, output_dim, activation):
        """Return ``Linear -> act`` for each hidden width, then a final ``Linear`` without activation."""
        widths = [input_dim, *hidden_dims]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(activation)
        layers.append(nn.Linear(widths[-1], output_dim))
        return nn.Sequential(*layers)

In [2]:
# Constructor arguments collected in one place so they are easy to inspect and tweak.
# The student CNN has six convolution stages (32 filters each, kernel size 5) that
# alternate stride 1 and stride 2, with dilations 1, 1, 2, 1, 4, 1.
student_teacher_kwargs = dict(
    num_student_obs=258,
    num_teacher_obs=258,
    num_actions=12,
    student_hidden_dims=[256, 128, 64],
    teacher_hidden_dims=[256, 128, 64],
    activation="elu",
    student_cnn_kernel_sizes=[5, 5, 5, 5, 5, 5],
    student_cnn_strides=[1, 2, 1, 2, 1, 2],
    student_cnn_filters=[32, 32, 32, 32, 32, 32],
    student_cnn_paddings=[2, 2, 4, 2, 8, 2],
    student_cnn_dilations=[1, 1, 2, 1, 4, 1],
    teacher_enc_dims=[128, 64],
    len_o1=48,
    sum_student_obs=65,
    enc_activation=True,
)

# Building the model prints the layer layout and parameter counts.
student_teacher = CNN1d_o1_StudentTeacher(**student_teacher_kwargs)

Student CNN: Sequential(
  (0): Conv1d(48, 32, kernel_size=(5,), stride=(1,), padding=(2,))
  (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
  (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(4,), dilation=(2,))
  (7): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
  (10): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(8,), dilation=(4,))
  (13): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReLU()
  (15): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
  (16): BatchNorm1d(32, eps=1e-0

In [6]:
# Teacher checkpoint to distil from.
# NOTE: this is an absolute, machine-specific path; adjust it before running elsewhere.
path = '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_rma_v2_1_rough/2025-09-02_18-29-31_teacher/model_87000.pt'

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a machine
# without CUDA; the values are later copied into the (CPU) model by load_state_dict.
# weights_only=True restricts unpickling to tensors and plain containers.
teacher_state_dict = torch.load(path, map_location="cpu", weights_only=True)['model_state_dict']

In [7]:
# Start from the model's current weights so every key it expects is present;
# entries that have a counterpart in the teacher checkpoint are overwritten below.
model_state_dict = dict(student_teacher.state_dict())

# Checkpoint key fragment -> fragments it is rewritten to in this model.
# The actor weights seed both the teacher MLP and the student MLP.
key_rewrites = {
    "actor.": ("teacher.", "student."),
    "actor_enc.": ("teacher_enc.",),
}

for ckpt_key, ckpt_tensor in teacher_state_dict.items():
    for fragment, replacements in key_rewrites.items():
        if fragment not in ckpt_key:
            continue
        for replacement in replacements:
            model_state_dict[ckpt_key.replace(fragment, replacement)] = ckpt_tensor

# Strict load: raises if a rewritten key does not exist in the model or a shape differs.
student_teacher.load_state_dict(model_state_dict)

# Same top-level layout as the loaded checkpoint ('model_state_dict' entry).
saved_dict = {"model_state_dict": student_teacher.state_dict()}

In [8]:
# Write the merged weights next to the teacher checkpoint. The file-name suffix
# presumably encodes the student variant (sum_student_obs=65, 32 CNN filters) -
# change it when exporting a different configuration.
student_checkpoint_path = path.replace('.pt', '_student_65_f32.pt')
torch.save(saved_dict, student_checkpoint_path)