In [None]:
import torch
import torch.nn as nn
import numpy as np

class Agent(nn.Module):
    """Actor-critic module with two independent MLP trunks that can be widened in place.

    Both ``critic`` and ``actor`` are ``nn.Sequential`` stacks laid out as
    ``Linear -> Tanh -> Linear -> Tanh -> head``. The critic head is a
    single-output ``Linear``; the actor head is a ``BetaHead``.
    ``expand_actor`` / ``expand_critic`` widen a hidden ``Linear`` layer after
    construction.

    NOTE(review): ``layer_init`` and ``BetaHead`` are defined elsewhere in the
    notebook; their behaviour is not visible here.
    """

    def __init__(self, envs, gaussian=False):
        """Build the critic and actor networks.

        Args:
            envs: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.shape``. The product of the observation
                shape is the input width of both trunks; the product of the
                action shape is passed to ``BetaHead`` as its second argument
                (presumably the action dimensionality -- confirm).
            gaussian: Stored as ``self.is_gaussian``.
        """
        super().__init__()

        # NOTE(review): nothing in this class reads is_gaussian, and the actor
        # below always ends in BetaHead -- confirm whether this flag was meant
        # to select a Gaussian head instead.
        self.is_gaussian = gaussian

        # Flattened observation size, shared by both trunks.
        obs_dim = np.array(envs.single_observation_space.shape).prod()

        # nn.Linear construction draws from the global RNG; the critic is
        # built before the actor so seeded runs initialise identically.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 1), std=1.0),
        )

        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 256)),
            nn.Tanh(),
            layer_init(nn.Linear(256, 256)),
            nn.Tanh(),
            BetaHead(256, np.prod(envs.single_action_space.shape)),
        )

    def expand_layer(self, network, layer_idx, new_out_features):
        """Widen ``network[layer_idx]`` and grow ``network[layer_idx + 2]``'s input to match.

        The module at ``layer_idx + 1`` is assumed to be a shape-preserving
        activation (Tanh in this class's networks) and is left untouched.

        Existing weights and biases are copied into the new layers unchanged.
        The added output rows of the widened layer get Xavier-uniform weights
        and zero bias; the added input columns of the following layer get
        Xavier-uniform weights. Because the new units are wired in with random,
        non-zero weights, the network's output generally changes after
        expansion (this is not a function-preserving expansion).

        Replacement layers are created on the same device and with the same
        dtype as the layers they replace. They hold brand-new ``nn.Parameter``
        objects, so any optimizer built over the old parameters will not update
        them; the caller must rebuild it (or update its param groups).

        Args:
            network: ``nn.Sequential`` holding the layers; modified in place.
            layer_idx: Index of the ``nn.Linear`` layer to widen.
            new_out_features: New output width; must be >= the current width.

        Raises:
            ValueError: If either layer is not ``nn.Linear``, or if
                ``new_out_features`` is smaller than the current width.
            IndexError: If ``layer_idx + 2`` is out of range for ``network``.
        """
        old_layer = network[layer_idx]
        next_layer = network[layer_idx + 2]  # skip the activation at layer_idx + 1

        if not isinstance(old_layer, nn.Linear) or not isinstance(next_layer, nn.Linear):
            raise ValueError(f"Layers at indices {layer_idx} and {layer_idx+2} must be Linear layers.")

        old_out_features = old_layer.out_features
        if new_out_features < old_out_features:
            raise ValueError(
                f"new_out_features ({new_out_features}) must be >= the current "
                f"out_features ({old_out_features}) of layer {layer_idx}."
            )
        has_new_units = new_out_features > old_out_features

        # Widened layer. Match the old layer's device/dtype: a default
        # nn.Linear lands on CPU/float32 and would break a model that was moved
        # with .to(device) or cast with .double()/.half().
        new_layer = nn.Linear(
            old_layer.in_features,
            new_out_features,
            bias=old_layer.bias is not None,
            device=old_layer.weight.device,
            dtype=old_layer.weight.dtype,
        )
        with torch.no_grad():
            new_layer.weight[:old_out_features].copy_(old_layer.weight)
            if has_new_units:
                nn.init.xavier_uniform_(new_layer.weight[old_out_features:])
            if old_layer.bias is not None:
                new_layer.bias[:old_out_features].copy_(old_layer.bias)
                new_layer.bias[old_out_features:].zero_()

        # Following layer: same output width, input grown to new_out_features.
        new_next_layer = nn.Linear(
            new_out_features,
            next_layer.out_features,
            bias=next_layer.bias is not None,
            device=next_layer.weight.device,
            dtype=next_layer.weight.dtype,
        )
        with torch.no_grad():
            new_next_layer.weight[:, :old_out_features].copy_(next_layer.weight)
            if has_new_units:
                # Columns that read from the newly added hidden units.
                nn.init.xavier_uniform_(new_next_layer.weight[:, old_out_features:])
            if next_layer.bias is not None:
                new_next_layer.bias.copy_(next_layer.bias)

        # Swap the replacements into the container (activation stays in place).
        network[layer_idx] = new_layer
        network[layer_idx + 2] = new_next_layer

    def expand_actor(self, layer_idx, new_out_features):
        """Widen hidden layer ``layer_idx`` of ``self.actor`` via ``expand_layer``.

        As built in ``__init__``, index 4 of the actor is ``BetaHead``, so
        ``layer_idx=2`` raises ValueError unless BetaHead is an ``nn.Linear``
        subclass (not visible here); ``layer_idx=0`` is the usable index.
        """
        self.expand_layer(self.actor, layer_idx, new_out_features)

    def expand_critic(self, layer_idx, new_out_features):
        """Widen hidden layer ``layer_idx`` of ``self.critic`` via ``expand_layer``.

        As built in ``__init__``, ``layer_idx`` may be 0 or 2 (each followed,
        two positions later, by an ``nn.Linear``).
        """
        self.expand_layer(self.critic, layer_idx, new_out_features)

ModuleNotFoundError: No module named 'torch._C'