In [47]:
import torch
import torch.nn as nn
from torch.distributions import Normal
import os

class CNN1d_o1_Student(nn.Module):
    """Student policy: 1-D CNN encoder over an observation history plus an MLP head.

    ``forward`` takes one unbatched, flat observation vector of length
    ``num_obs * len_o1``:

    * the first ``(num_obs - 1) * len_o1`` values are reshaped to
      ``(num_obs - 1, len_o1)`` and transposed, so the ``len_o1`` features become
      Conv1d channels over a sequence of ``num_obs - 1`` entries (presumably past
      time steps -- confirm against the environment's observation layout);
    * the last ``len_o1`` values (``o_t``) are concatenated in front of the CNN
      latent and passed through the MLP, which returns ``num_actions`` values.

    The layer order inside ``cnn_student`` / ``student`` defines the state_dict
    keys, so it must stay in sync with the checkpoints loaded into this module.
    """

    def __init__(
        self,
        num_obs: int,
        num_actions: int,
        student_hidden_dims: list[int],
        student_cnn_kernel_sizes: list[int],
        student_cnn_strides: list[int],
        student_cnn_filters: list[int],
        student_cnn_paddings: list[int],
        student_cnn_dilations: list[int],
        len_latent: int = 64,
        len_o1: int = 48,
        enc_activation: bool = True
    ):
        """Build the CNN encoder and the MLP head.

        Args:
            num_obs: Number of ``len_o1``-sized chunks in the flat observation;
                ``num_obs - 1`` of them form the CNN input sequence.
            num_actions: Output size of the MLP.
            student_hidden_dims: Hidden layer widths of the MLP (ELU after each).
            student_cnn_kernel_sizes, student_cnn_strides, student_cnn_filters,
            student_cnn_paddings, student_cnn_dilations: Per-layer Conv1d
                settings, consumed in parallel; every conv is followed by
                BatchNorm1d and ReLU.
            len_latent: Size of the encoder output.
            len_o1: Size of one observation chunk (= CNN input channels).
            enc_activation: Append ELU after the encoder's final Linear.
        """
        super().__init__()
        activation = nn.ELU()

        self.num_obs = num_obs
        self.len_o1 = len_o1
        self.len_latent = len_latent

        # --- CNN encoder -------------------------------------------------------
        in_channels = [self.len_o1] + student_cnn_filters[:-1]
        cnn_layers = []
        seq_len = self.num_obs - 1  # length of the sequence fed to the first conv
        for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(
            in_channels,
            student_cnn_filters,
            student_cnn_kernel_sizes,
            student_cnn_strides,
            student_cnn_paddings,
            student_cnn_dilations,
        ):
            cnn_layers += [
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                ),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ]
            # Conv1d output-length formula (torch.nn.Conv1d docs).
            seq_len = (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

        cnn_layers.append(nn.Flatten())
        cnn_layers.append(nn.Linear(seq_len * student_cnn_filters[-1], self.len_latent))
        if enc_activation:
            cnn_layers.append(activation)
        self.cnn_student = nn.Sequential(*cnn_layers)
        self.cnn_student.eval()

        # --- MLP head: (o_t, latent) -> hidden dims with ELU -> num_actions -----
        # Same layer order as [Linear, ELU] * len(hidden) + [Linear], which keeps
        # the `student.<idx>.*` state_dict keys stable.
        dims = [self.len_latent + self.len_o1, *student_hidden_dims]
        mlp_layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            mlp_layers += [nn.Linear(d_in, d_out), activation]
        mlp_layers.append(nn.Linear(dims[-1], num_actions))
        self.student = nn.Sequential(*mlp_layers)
        self.student.eval()

        print(f"Student CNN: {self.cnn_student}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {sum([p.numel() for p in self.student.parameters()]) + sum([p.numel() for p in self.cnn_student.parameters()])}")

    def forward(self, observations):
        """Map one flat observation vector (length ``num_obs * len_o1``) to actions."""
        o_t = observations[-self.len_o1:]
        # (num_obs - 1, len_o1) -> (len_o1, num_obs - 1): features become channels.
        h = observations[:-self.len_o1].reshape(self.num_obs - 1, self.len_o1).permute(1, 0)
        # Add and remove a batch dimension of 1 around the Conv1d/BatchNorm1d stack.
        z_t = self.cnn_student(h.unsqueeze(0)).squeeze(0)
        actions_mean = self.student(torch.cat((o_t, z_t)))
        return actions_mean

    @torch.jit.export
    def reset(self):
        # No recurrent state; kept so callers can invoke reset() on the exported module.
        pass

    def export(self, path, filename):
        """Script the policy with TorchScript and save it to ``path/filename``.

        The module is moved to CPU and switched (with all submodules) to eval
        mode first, so BatchNorm layers are scripted to use their running
        statistics even if ``train()`` was called after construction. Both
        changes persist on ``self``. ``path`` is created if missing.
        """
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, filename)
        self.to("cpu")
        self.eval()
        scripted_module = torch.jit.script(self)
        scripted_module.save(path)

In [48]:
# Six-layer history CNN (64 filters, kernel 5). Strides alternate 1/2; the stride-1
# layers use dilation 1, 2, 4 with padding 2, 4, 8 so they keep the sequence length.
model = CNN1d_o1_Student(
    num_obs=65,
    num_actions=12,
    len_o1=48,
    len_latent=64,
    student_hidden_dims=[256, 128, 64],
    student_cnn_filters=[64] * 6,
    student_cnn_kernel_sizes=[5] * 6,
    student_cnn_strides=[1, 2] * 3,
    student_cnn_dilations=[1, 1, 2, 1, 4, 1],
    student_cnn_paddings=[2, 2, 4, 2, 8, 2],
    # enc_activation keeps its default (True); pass False to drop the encoder's final ELU.
)

Student CNN: Sequential(
  (0): Conv1d(48, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv1d(64, 64, kernel_size=(5,), stride=(2,), padding=(2,))
  (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(4,), dilation=(2,))
  (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): Conv1d(64, 64, kernel_size=(5,), stride=(2,), padding=(2,))
  (10): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(8,), dilation=(4,))
  (13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReLU()
  (15): Conv1d(64, 64, kernel_size=(5,), stride=(2,), padding=(2,))
  (16): BatchNorm1d(64, eps=1e-0

In [52]:
# Training run directory and the checkpoint inside it to convert.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
model_path, model_name = (
    '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_rma_v2_1_rough/2025-09-04_12-25-00',
    'model_3000.pt',
)

In [53]:
# Keep only the student tensors from the training checkpoint. The "student." filter
# matches both `cnn_student.*` and `student.*` keys. Loading the checkpoint's own
# tensors strictly makes a missing or renamed student key an error. Previously such a
# key silently kept the model's freshly initialised weights.
load_state = torch.load(os.path.join(model_path, model_name), weights_only=True, map_location=torch.device('cpu'))
model_state_dict = {key: value for key, value in load_state['model_state_dict'].items() if "student." in key}

model.load_state_dict(model_state_dict, strict=True)

<All keys matched successfully>

In [54]:
# Save a TorchScript copy next to the checkpoint (e.g. model_3000.pt -> model_3000_jit.pt).
jit_name = model_name.replace('.pt', '_jit.pt')
model.export(model_path, jit_name)

In [39]:
from __future__ import annotations

import math
import torch
import torch.nn as nn
from torch.distributions import Normal

from rsl_rl.utils import resolve_nn_activation


class ResidualBlock(nn.Module):
    """1-D residual block.

    Main path: a stack of ``Conv1d -> BatchNorm1d -> ReLU`` layers, one per entry
    of the (parallel) per-layer argument lists. Shortcut: a 1x1 ``Conv1d`` from
    ``in_channels[0]`` to ``out_channels[-1]`` channels with stride
    ``prod(strides)`` and no padding, followed by ``BatchNorm1d``. The two paths
    are summed and passed through a final ReLU.

    Args:
        kernel_sizes: Kernel size of each main-path conv.
        strides: Stride of each main-path conv.
        in_channels: Input channels of each main-path conv.
        out_channels: Output channels of each main-path conv.
        paddings: Padding of each main-path conv.
        dilations: Dilation of each main-path conv.

    The caller must pick paddings/dilations so the main-path output length
    equals the shortcut's output length; otherwise the sum in ``forward`` will
    not line up.
    """

    def __init__(
        self,
        kernel_sizes: list[int],
        strides: list[int],
        in_channels: list[int],
        out_channels: list[int],
        paddings: list[int],
        dilations: list[int]
    ):
        super().__init__()

        # One stateless ReLU instance, shared by the main path and the output.
        self.activation = nn.ReLU()

        cnn_layers = []
        for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(
            in_channels,
            out_channels,
            kernel_sizes,
            strides,
            paddings,
            dilations
        ):
            cnn_layers.append(nn.Conv1d(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation
            ))
            cnn_layers.append(nn.BatchNorm1d(out_ch))
            cnn_layers.append(self.activation)

        self.cnn = nn.Sequential(*cnn_layers)

        # Shortcut projection: matches channel count and total stride of the main path.
        self.downsample = nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels[0],
                out_channels=out_channels[-1],
                kernel_size=1,
                stride=math.prod(strides),
            ),
            nn.BatchNorm1d(out_channels[-1])
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.cnn(x)
        # Out-of-place add: `out` is the output of the main path's final ReLU, which
        # autograd saves for ReLU's backward. `out += ...` would modify that saved
        # tensor in place and make backward() raise.
        out = out + self.downsample(x)
        return self.activation(out)


class CNN1d_o1_res_Student(nn.Module):
    """Student policy: residual 1-D CNN encoder over an observation history plus an MLP head.

    ``forward`` takes one unbatched, flat observation vector of length
    ``num_obs * len_o1``. The first ``(num_obs - 1) * len_o1`` values are reshaped
    to ``(num_obs - 1, len_o1)`` and transposed, so the ``len_o1`` features are
    Conv1d channels over a sequence of ``num_obs - 1`` entries. The last
    ``len_o1`` values (``o_t``) are concatenated in front of the CNN latent and
    fed to the MLP, which returns ``num_actions`` values.

    Encoder: one ``ResidualBlock`` per entry of the ``student_cnn_blocks_*``
    lists, then ``Flatten -> Linear(len_latent)`` (plus ``activation`` if
    ``enc_activation``). The layer order defines the state_dict keys and must
    match the checkpoints being loaded.
    """

    is_recurrent = False

    def __init__(
        self,
        num_obs,
        num_actions,
        student_hidden_dims=[256, 256, 256],
        activation="elu",
        student_cnn_blocks_kernel_sizes: list[list[int]] = [[5, 5, 5], [5, 5, 5]],
        student_cnn_blocks_strides: list[list[int]] = [[1, 2, 1], [2, 1, 2]],
        student_cnn_blocks_in_channels: list[list[int]] = [[48, 32, 32], [32, 32, 32]],
        student_cnn_blocks_out_channels: list[list[int]] = [[32, 32, 32], [32, 32, 32]],
        student_cnn_blocks_paddings: list[list[int]] = [[2, 2, 4], [2, 8, 2]],
        student_cnn_blocks_dilations: list[list[int]] = [[1, 1, 2], [1, 4, 1]],
        len_latent: int = 64,
        len_o1: int = 48,
        enc_activation: bool = True,
        **kwargs,
    ):
        """Build the residual CNN encoder and the MLP head.

        Args:
            num_obs: Number of ``len_o1``-sized chunks in the flat observation;
                ``num_obs - 1`` of them form the CNN input sequence.
            num_actions: Output size of the MLP.
            student_hidden_dims: MLP hidden widths. The list defaults are shared
                between calls but are only read here.
            activation: Name passed to ``resolve_nn_activation``. The resulting
                module is used after each MLP hidden layer and, if
                ``enc_activation``, after the encoder's final Linear.
            student_cnn_blocks_*: One inner list per ResidualBlock, one entry per
                Conv1d inside that block.
            len_latent: Size of the encoder output.
            len_o1: Size of one observation chunk (= CNN input channels).
            enc_activation: Append ``activation`` after the encoder's Linear.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the conv settings shrink the sequence to length < 1.
        """
        super().__init__()
        activation = resolve_nn_activation(activation)

        self.num_obs = num_obs
        self.len_o1 = len_o1
        # Sequence length after the CNN, derived from the input length instead of
        # being hard-coded (it is 8 for num_obs=65 with the default block settings).
        s_cnn_out = self._cnn_out_length(
            self.num_obs - 1,
            student_cnn_blocks_kernel_sizes,
            student_cnn_blocks_strides,
            student_cnn_blocks_paddings,
            student_cnn_blocks_dilations,
        )
        if s_cnn_out < 1:
            raise ValueError(
                f"CNN output length is {s_cnn_out} for an input length of {self.num_obs - 1}; "
                "check the kernel/stride/padding/dilation settings."
            )

        # --- CNN encoder -------------------------------------------------------
        cnn_student_layers = [
            ResidualBlock(
                kernel_sizes,
                strides,
                in_channels,
                out_channels,
                paddings,
                dilations
            ) for kernel_sizes, strides, in_channels, out_channels, paddings, dilations in zip(
                student_cnn_blocks_kernel_sizes,
                student_cnn_blocks_strides,
                student_cnn_blocks_in_channels,
                student_cnn_blocks_out_channels,
                student_cnn_blocks_paddings,
                student_cnn_blocks_dilations
            )
        ]

        cnn_student_layers.append(nn.Flatten())
        cnn_student_layers.append(nn.Linear(s_cnn_out * student_cnn_blocks_out_channels[-1][-1], len_latent))
        if enc_activation:
            cnn_student_layers.append(activation)
        self.cnn_student = nn.Sequential(*cnn_student_layers)

        # --- MLP head: (o_t, latent) -> hidden dims -> num_actions --------------
        # Same layer order as [Linear, act] * len(hidden) + [Linear], which keeps
        # the `student.<idx>.*` state_dict keys stable.
        dims = [len_latent + self.len_o1, *student_hidden_dims]
        student_layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            student_layers += [nn.Linear(d_in, d_out), activation]
        student_layers.append(nn.Linear(dims[-1], num_actions))
        self.student = nn.Sequential(*student_layers)

        print(f"Student CNN: {self.cnn_student}")
        print(f"Student MLP: {self.student}")
        print(f"Student parameters: {sum([p.numel() for p in self.cnn_student.parameters()]) + sum([p.numel() for p in self.student.parameters()])}\n")

    @staticmethod
    def _cnn_out_length(
        length: int,
        blocks_kernel_sizes: list[list[int]],
        blocks_strides: list[list[int]],
        blocks_paddings: list[list[int]],
        blocks_dilations: list[list[int]],
    ) -> int:
        """Sequence length after the main path of every ResidualBlock.

        Applies the ``torch.nn.Conv1d`` output-length formula layer by layer.
        """
        for kernel_sizes, strides, paddings, dilations in zip(
            blocks_kernel_sizes, blocks_strides, blocks_paddings, blocks_dilations
        ):
            for kernel_size, stride, padding, dilation in zip(kernel_sizes, strides, paddings, dilations):
                length = (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        return length

    def forward(self, observations):
        """Map one flat observation vector (length ``num_obs * len_o1``) to actions."""
        o_t = observations[-self.len_o1:]
        # (num_obs - 1, len_o1) -> (len_o1, num_obs - 1): features become channels.
        h = observations[:-self.len_o1].reshape(self.num_obs - 1, self.len_o1).permute(1, 0)
        # Add and remove a batch dimension of 1 around the CNN.
        z_t = self.cnn_student(h.unsqueeze(0)).squeeze(0)
        actions_mean = self.student(torch.cat((o_t, z_t)))
        return actions_mean

    @torch.jit.export
    def reset(self):
        # No recurrent state; kept so callers can invoke reset() on the exported module.
        pass

    def export(self, path, filename):
        """Script the policy with TorchScript and save it to ``path/filename``.

        The module is moved to CPU and switched (with all submodules) to eval
        mode first. Otherwise the BatchNorm layers would be scripted in training
        mode, normalising with per-call statistics and updating their running
        stats at inference time. Both changes persist on ``self``. ``path`` is
        created if missing.
        """
        import os  # this cell does not import os itself

        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, filename)
        self.to("cpu")
        self.eval()
        scripted_module = torch.jit.script(self)
        scripted_module.save(path)
        
    

In [41]:
# Two residual blocks of three Conv1d layers each (kernel 5, 32 channels after the
# first conv); the dilated stride-1 convs use padding 2 * dilation to keep length.
model = CNN1d_o1_res_Student(
    num_obs=65,
    num_actions=12,
    len_o1=48,
    len_latent=64,
    activation="elu",
    enc_activation=True,
    student_hidden_dims=[256, 128, 64],
    student_cnn_blocks_in_channels=[[48, 32, 32], [32, 32, 32]],
    student_cnn_blocks_out_channels=[[32, 32, 32], [32, 32, 32]],
    student_cnn_blocks_kernel_sizes=[[5, 5, 5], [5, 5, 5]],
    student_cnn_blocks_strides=[[1, 2, 1], [2, 1, 2]],
    student_cnn_blocks_dilations=[[1, 1, 2], [1, 4, 1]],
    student_cnn_blocks_paddings=[[2, 2, 4], [2, 8, 2]],
)

Student CNN: Sequential(
  (0): ResidualBlock(
    (activation): ReLU()
    (cnn): Sequential(
      (0): Conv1d(48, 32, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(4,), dilation=(2,))
      (7): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
    )
    (downsample): Sequential(
      (0): Conv1d(48, 32, kernel_size=(1,), stride=(2,))
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): ResidualBlock(
    (activation): ReLU()
    (cnn): Sequential(
      (0): Conv1d(32, 32, kernel_size=(5,), stride=(2,), padding=(2,))
      (1): Batc

In [44]:
# Training run directory and the checkpoint inside it to convert.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
model_path, model_name = (
    '/home/tema/IsaacLab/logs/rsl_rl/go2_velocity_rma_v2_1_rough/2025-09-04_11-11-55',
    'model_7000.pt',
)

In [45]:
# Keep only the student tensors from the training checkpoint. The "student." filter
# matches both `cnn_student.*` and `student.*` keys. Loading the checkpoint's own
# tensors strictly makes a missing or renamed student key an error. Previously such a
# key silently kept the model's freshly initialised weights.
load_state = torch.load(os.path.join(model_path, model_name), weights_only=True, map_location=torch.device('cpu'))
model_state_dict = {key: value for key, value in load_state['model_state_dict'].items() if "student." in key}

model.load_state_dict(model_state_dict, strict=True)

# Save a TorchScript copy next to the checkpoint (e.g. model_7000.pt -> model_7000_jit.pt).
model.export(model_path, model_name.replace('.pt', '_jit.pt'))