In [7]:
import torch
import torch.nn as nn
import plotly.graph_objects as go

In [8]:
# Architecture of the saved PyTorch model (layer layout must match the checkpoint's state_dict keys)
class NeuralNetwork(nn.Module):
    """Fully connected ReLU network bundled with its MSE loss and Adam optimizer.

    Args:
        input_dim: Number of input features of the first Linear layer.
        hidden_sizes: Iterable of hidden layer widths; each entry adds a
            Linear + ReLU pair.
        output_dim: Number of outputs of the final Linear layer (no activation
            follows it).
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, input_dim, hidden_sizes, output_dim, learning_rate=0.001):
        super().__init__()

        # Use the GPU when one is visible, otherwise stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Consecutive entries of `widths` are the (in, out) sizes of each hidden Linear.
        widths = [input_dim, *hidden_sizes]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stack.append(nn.Linear(widths[-1], output_dim))

        # The attribute name `model` becomes the "model.<index>." prefix of the
        # state_dict keys, so it has to stay as is for saved checkpoints to load.
        self.model = nn.Sequential(*stack).to(self.device)

        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def forward(self, x):
        """Move `x` to the network's device and return the layer stack's output."""
        return self.model(x.to(self.device))

In [12]:
# Load the checkpoint's state_dict and rebuild the matching network for inspection.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a
# trusted source; if the installed PyTorch supports it, consider weights_only=True.
model_path = "../agents/pt_model/pro_agent_0802_1903.pt"
state_dict = torch.load(model_path, map_location="cpu")  # CPU mapping: loads without a GPU

input_dim = 15
hidden_sizes = [64, 64, 64]

# Every hidden layer adds a Linear + ReLU pair to the Sequential, so the output
# Linear sits at index 2 * len(hidden_sizes) (6 for three hidden layers).
# Deriving the key keeps it in sync with hidden_sizes instead of hard-coding "6".
output_key = f"model.{2 * len(hidden_sizes)}.weight"
if output_key not in state_dict:
    raise KeyError(
        f"{output_key!r} not found in {model_path}; hidden_sizes={hidden_sizes} "
        f"does not match the checkpoint keys {list(state_dict)}"
    )
# Linear weights are stored as (out_features, in_features).
output_dim = state_dict[output_key].shape[0]

# Initialize model with correct shape
model = NeuralNetwork(input_dim, hidden_sizes, output_dim)
model.load_state_dict(state_dict)
# Switch to inference mode; as the cell's last expression this also displays the model.
model.eval()

NeuralNetwork(
  (model): Sequential(
    (0): Linear(in_features=15, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=3, bias=True)
  )
  (loss_fn): MSELoss()
)

In [23]:
# Draw the weight matrix of every Linear layer as an interactive 3D surface.
for i, layer in enumerate(model.model):
    # Only Linear layers are plotted; the other modules in the stack are skipped.
    if not isinstance(layer, nn.Linear):
        continue

    # Copy the weights to host memory as a NumPy array for Plotly.
    weights = layer.weight.data.cpu().numpy()

    scene_labels = dict(
        xaxis_title='Input Features',
        yaxis_title='Neurons',
        zaxis_title='Weight Value'
    )

    # `i` is the layer's index inside the Sequential, not a count of Linear layers.
    fig = go.Figure(data=[go.Surface(z=weights)])
    fig.update_layout(
        title=f'3D Plot - Layer {i} Weight Matrix',
        width=800,
        height=800,
        scene=scene_labels
    )
    fig.show()
