In [14]:
%pip install graphviz

Note: you may need to restart the kernel to use updated packages.




In [13]:
import torch
import torch.nn as nn
from torchviz import make_dot

# Define the neural network architecture
class ActorCriticPolicy(nn.Module):
    """Small actor-critic network with a shared MLP trunk.

    Observations are flattened (all dimensions after the batch dimension),
    passed through a two-layer tanh MLP, and the resulting latent vector feeds
    two linear heads: ``action_net`` (action outputs) and ``value_net``
    (a single value estimate per sample).

    Args:
        obs_dim: Number of features after flattening one observation.
        n_actions: Output size of the action head.
        hidden_dim: Width of both hidden layers of the shared MLP.

    The defaults (20, 2, 64) reproduce the original hard-coded architecture.
    """

    def __init__(self, obs_dim: int = 20, n_actions: int = 2, hidden_dim: int = 64):
        super().__init__()

        # nn.Flatten() keeps dim 0 (start_dim=1) and flattens the rest.
        # The separate pi/vf extractors are kept as attributes so the module
        # layout mirrors an actor-critic policy; they hold no parameters and
        # forward() routes both heads through the shared features_extractor.
        self.features_extractor = nn.Flatten()
        self.pi_features_extractor = nn.Flatten()
        self.vf_features_extractor = nn.Flatten()

        self.mlp_extractor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )

        self.action_net = nn.Linear(hidden_dim, n_actions)  # Output for actions
        self.value_net = nn.Linear(hidden_dim, 1)           # Output for value estimation

    def forward(self, x: torch.Tensor):
        """Return ``(action_output, value_output)`` for a batch of observations.

        Args:
            x: Tensor of shape ``(batch, ...)`` whose non-batch dimensions
               flatten to ``obs_dim`` features.

        Returns:
            A tuple of tensors shaped ``(batch, n_actions)`` and ``(batch, 1)``.
        """
        features = self.features_extractor(x)
        # Both heads share the same latent representation.
        latent = self.mlp_extractor(features)
        return self.action_net(latent), self.value_net(latent)

# Create an instance of the model
model = ActorCriticPolicy()

# Example input: one sample with 20 features.
x = torch.randn(1, 20)

# NOTE: autograd must stay enabled here. make_dot draws a graph by walking each
# tensor's grad_fn chain; under torch.no_grad() no chain is recorded and every
# diagram collapses to a single output node.
#
# Each component is fed a detached copy of the previous stage's output, so each
# diagram shows only that component's operations instead of the whole network
# back to the input. Forward values are unchanged by detach().
# (x does not require grad and Flatten has no parameters, so the first diagram
# contains only the output tensor node.)
features_extractor_output = model.features_extractor(x)
mlp_output = model.mlp_extractor(features_extractor_output.detach())
action_output = model.action_net(mlp_output.detach())
value_output = model.value_net(mlp_output.detach())

# Parameter-name lookup used by make_dot to label weight/bias nodes; built once.
params = dict(model.named_parameters())

# Create separate visualizations for each component
dot1 = make_dot(features_extractor_output, params=params)
dot2 = make_dot(mlp_output, params=params)
dot3 = make_dot(action_output, params=params)
dot4 = make_dot(value_output, params=params)

# Render each diagram to component_<i>.png (i = 1..4).
for i, dot in enumerate([dot1, dot2, dot3, dot4], start=1):
    dot.format = 'png'  # You can change the format to 'svg' or other supported formats
    dot.render(f"component_{i}")
