In [2]:
import torch
import torch.nn as nn
from torch.distributions import Normal
import os

In [3]:
class CNN1d_o1_Student(nn.Module):
    """Student policy: a 1D-CNN history encoder followed by an MLP action head.

    A flat observation vector is interpreted as ``num_obs`` consecutive frames
    of ``len_o1`` values each. The last frame is the current observation
    ``o_t``; the preceding ``num_obs - 1`` frames are the history that
    ``cnn_student`` compresses into a latent vector of size ``len_latent``.
    ``student`` then maps the concatenation ``[o_t, latent]`` to
    ``num_actions`` outputs.

    Both sub-networks are put into eval mode at construction time, so the
    BatchNorm layers use their running statistics.
    """

    def __init__(
        self,
        num_obs: int,
        num_actions: int,
        student_hidden_dims: list[int],
        student_cnn_kernel_sizes: list[int],
        student_cnn_strides: list[int],
        student_cnn_filters: list[int],
        student_cnn_paddings: list[int],
        student_cnn_dilations: list[int],
        len_latent: int = 64,
        len_o1: int = 48,
        enc_activation: bool = True
    ):
        """Build the CNN encoder and the MLP head.

        Args:
            num_obs: Number of frames in one flat observation (history frames
                plus the current frame).
            num_actions: Size of the MLP output.
            student_hidden_dims: Widths of the MLP hidden layers.
            student_cnn_kernel_sizes: Kernel size of each Conv1d layer.
            student_cnn_strides: Stride of each Conv1d layer.
            student_cnn_filters: Output channels of each Conv1d layer.
            student_cnn_paddings: Padding of each Conv1d layer.
            student_cnn_dilations: Dilation of each Conv1d layer.
            len_latent: Size of the latent vector produced by the encoder.
            len_o1: Number of values in a single frame; also the number of
                input channels of the first Conv1d layer.
            enc_activation: If True, an ELU is applied after the encoder's
                final linear layer.

        Raises:
            ValueError: If the per-layer CNN lists do not all have the same
                length.
        """
        super().__init__()

        # zip() below silently stops at the shortest list, while the final
        # Linear layer is sized from student_cnn_filters[-1]. A length mismatch
        # would therefore either be hidden or produce an inconsistent network,
        # so reject it up front.
        cnn_spec_lengths = {
            "student_cnn_filters": len(student_cnn_filters),
            "student_cnn_kernel_sizes": len(student_cnn_kernel_sizes),
            "student_cnn_strides": len(student_cnn_strides),
            "student_cnn_paddings": len(student_cnn_paddings),
            "student_cnn_dilations": len(student_cnn_dilations),
        }
        if len(set(cnn_spec_lengths.values())) > 1:
            raise ValueError(
                f"All student CNN per-layer lists must have the same length, got {cnn_spec_lengths}"
            )

        # A single ELU instance is shared by every layer (it holds no state).
        activation = nn.ELU()

        self.num_obs = num_obs
        self.len_o1 = len_o1
        self.len_latent = len_latent

        # student
        s_out_channels = student_cnn_filters
        # The first conv consumes one channel per frame value; every following
        # conv consumes the previous layer's filters.
        s_in_channels = [self.len_o1] + student_cnn_filters[:-1]

        cnn_student_layers = []
        # Sequence length seen by the CNN: all frames except the current one.
        s_cnn_out = self.num_obs - 1
        for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(
            s_in_channels,
            s_out_channels,
            student_cnn_kernel_sizes,
            student_cnn_strides,
            student_cnn_paddings,
            student_cnn_dilations
        ):
            cnn_student_layers.append(nn.Conv1d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation
            ))
            cnn_student_layers.append(nn.BatchNorm1d(out_ch))
            cnn_student_layers.append(activation)
            # Track the Conv1d output length so the Linear layer can be sized.
            s_cnn_out = (s_cnn_out + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

        cnn_student_layers.append(nn.Flatten())
        cnn_student_layers.append(nn.Linear(s_cnn_out * s_out_channels[-1], self.len_latent))
        if enc_activation:
            cnn_student_layers.append(activation)
        self.cnn_student = nn.Sequential(*cnn_student_layers)
        self.cnn_student.eval()

        # MLP head: input is the latent concatenated with the current frame.
        student_layers = []
        student_layers.append(nn.Linear(self.len_latent + self.len_o1, student_hidden_dims[0]))
        student_layers.append(activation)
        last_index = len(student_hidden_dims) - 1
        for layer_index in range(len(student_hidden_dims)):
            if layer_index == last_index:
                # Output layer: no activation after it.
                student_layers.append(nn.Linear(student_hidden_dims[layer_index], num_actions))
            else:
                student_layers.append(nn.Linear(student_hidden_dims[layer_index], student_hidden_dims[layer_index + 1]))
                student_layers.append(activation)
        self.student = nn.Sequential(*student_layers)
        self.student.eval()

        print(f"Student CNN: {self.cnn_student}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {sum([p.numel() for p in self.student.parameters()]) + sum([p.numel() for p in self.cnn_student.parameters()])}")

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Compute the action output for one observation or a batch of them.

        Args:
            observations: Either a 1-D tensor with ``num_obs * len_o1`` values
                or a 2-D tensor of shape ``(batch, num_obs * len_o1)``.

        Returns:
            A tensor of shape ``(num_actions,)`` for 1-D input, or
            ``(batch, num_actions)`` for 2-D input.
        """
        # Run everything in batched form; a single observation becomes a
        # batch of one and is unwrapped again before returning.
        unbatched = observations.dim() == 1
        if unbatched:
            observations = observations.unsqueeze(0)

        o_t = observations[:, -self.len_o1:]
        history = observations[:, :-self.len_o1]
        # (batch, frames, len_o1) -> (batch, len_o1, frames): Conv1d expects
        # channels before the sequence dimension.
        history = history.reshape(-1, self.num_obs - 1, self.len_o1).permute(0, 2, 1)
        z_t = self.cnn_student(history)
        actions_mean = self.student(torch.cat([o_t, z_t], dim=-1))

        if unbatched:
            actions_mean = actions_mean.squeeze(0)
        return actions_mean

    @torch.jit.export
    def reset(self):
        # Nothing to reset; the method is exported so the scripted module
        # still exposes a reset() entry point.
        pass

    def export(self, path, filename):
        """Script the model with TorchScript and save it to ``path/filename``.

        The directory is created if needed. As a side effect the model is
        moved to the CPU and switched to eval mode.
        """
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, filename)
        self.to("cpu")
        # Guarantee inference behaviour in the exported module even if the
        # model was put back into train() after construction; otherwise
        # BatchNorm would be scripted using batch statistics.
        self.eval()
        scripted_module = torch.jit.script(self)
        scripted_module.save(path)

In [4]:
# Hyperparameters of the student network, gathered in one place so they are
# easy to inspect. Presumably these must match the configuration the
# checkpoint was trained with - verify against the training config.
student_config = dict(
    num_obs=65,
    num_actions=12,
    student_hidden_dims=[256, 128, 64],
    # Six Conv1d layers; each list below has one entry per layer.
    student_cnn_filters=[32, 32, 32, 32, 32, 32],
    student_cnn_kernel_sizes=[5, 5, 5, 5, 5, 5],
    student_cnn_strides=[1, 2, 1, 2, 1, 2],
    student_cnn_paddings=[2, 2, 4, 2, 8, 2],
    student_cnn_dilations=[1, 1, 2, 1, 4, 1],
    len_latent=64,
    len_o1=48,
    enc_activation=False,
)
model = CNN1d_o1_Student(**student_config)

Student CNN: Sequential(
  (0): Conv1d(48, 32, kernel_size=(5,), stride=(1,), padding=(2,))
  (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ELU(alpha=1.0)
  (3): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
  (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ELU(alpha=1.0)
  (6): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(4,), dilation=(2,))
  (7): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ELU(alpha=1.0)
  (9): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
  (10): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ELU(alpha=1.0)
  (12): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(8,), dilation=(4,))
  (13): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ELU(alpha=1.0)
  (15): Conv1d(32, 32, kernel_size=(5,), stride=(2,), paddin

In [5]:
# Location of the training checkpoint to export. The defaults point at the
# original run directory; set RMA_MODEL_DIR / RMA_MODEL_NAME in the
# environment to export a different checkpoint without editing the notebook.
model_path = os.environ.get(
    'RMA_MODEL_DIR',
    '/home/aivizw/IsaacLab_5.0.0/logs/rsl_rl/go2_velocity_rma_v3_rough/2025-08-29_10-22-54',
)
model_name = os.environ.get('RMA_MODEL_NAME', 'model_5000.pt')

In [6]:
# Read the checkpoint on the CPU; weights_only=True restricts unpickling to
# tensors and plain containers.
load_state = torch.load(
    os.path.join(model_path, model_name),
    weights_only=True,
    map_location=torch.device('cpu'),
)
model_state_dict = dict(model.state_dict())

# Keep only the student tensors. The substring test matches both the
# "student." and "cnn_student." prefixes used by the model.
student_weights = {
    key: value
    for key, value in load_state['model_state_dict'].items()
    if "student." in key
}

# Any model entry that the checkpoint does not provide would silently keep its
# randomly initialised value, and load_state_dict would still report success.
# Fail loudly instead of exporting an untrained network.
missing_keys = sorted(set(model_state_dict) - set(student_weights))
if missing_keys:
    raise KeyError(f"Checkpoint does not provide these student tensors: {missing_keys}")

model_state_dict.update(student_weights)

model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [7]:
# Save the TorchScript export next to the checkpoint, using the checkpoint's
# file name with a "_jit" suffix.
jit_model_name = model_name.replace('.pt', '_jit.pt')
model.export(model_path, jit_model_name)

: 