In [1]:
import heapq
import os
import sys
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import umap.umap_ as umap
from matplotlib.ticker import NullFormatter
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

import utils
from agent import DQN
#from agent import neural_network
from environment import CONTEXTS_LABELS

In [2]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Loading Data

In [3]:
save_path = Path().joinpath("save")  # relative path, resolved against the notebook's working directory
save_path.exists()

True

In [4]:
data_dir = save_path.joinpath("6-24-EW")  # folder holding data.tar and the trained-agent checkpoint
data_dir.exists()

True

In [5]:
data_path = data_dir.joinpath("data.tar")  # bundle read with torch.load in the next cell
data_path.exists()

True

In [6]:
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects (needed here: entries such as 'p' are objects with attributes, not
# tensors). Unpickling can execute code, so only load files you produced yourself.
data_dict = torch.load(
    data_path,
    map_location=DEVICE,
    weights_only=False,
)
data_dict.keys()

dict_keys(['rewards', 'steps', 'episodes', 'all_states', 'all_actions', 'all_qvalues', 'losses', 'p', 'epsilons', 'weights_val_stats', 'biases_val_stats', 'weights_grad_stats', 'biases_grad_stats', 'net', 'env', 'weights', 'biases'])

### Loading Model

In [18]:
model_path = data_dir.joinpath('trained-agent-state-7.pt')  # state_dict for the DQN loaded below
model_path.exists()

True

In [19]:
# Network dimensions recorded with the run; used to rebuild a matching DQN
parameters = data_dict['p']
n_observations, n_actions, n_units = (
    parameters.n_observations,
    parameters.n_actions,
    parameters.n_hidden_units,
)

In [20]:
# The checkpoint is mapped onto the CPU regardless of DEVICE. Later cells call
# .numpy() on these weights, which only works for CPU tensors.
cpu_device = torch.device('cpu')
model = DQN(n_observations, n_actions, n_units)
trained_state = torch.load(model_path, weights_only=True, map_location=cpu_device)
model.load_state_dict(trained_state)
model.eval()

DQN(
  (mlp): Sequential(
    (0): Linear(in_features=21, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=512, bias=True)
    (6): ReLU()
    (7): Linear(in_features=512, out_features=3, bias=True)
  )
)

In [21]:
state_dict = model.state_dict()

# Weight keys of the five Linear layers inside model.mlp. Indices 2, 4 and 6 of
# the Sequential are ReLUs, which have no parameters (see the model repr).
linear_weight_keys = ('mlp.0.weight', 'mlp.1.weight', 'mlp.3.weight', 'mlp.5.weight', 'mlp.7.weight')
(layer0_weights, layer1_weights, layer2_weights,
 layer3_weights, layer4_weights) = (state_dict[key] for key in linear_weight_keys)

weights = [layer0_weights, layer1_weights, layer2_weights, layer3_weights, layer4_weights]
print(len(weights))

5


In [22]:
print(len(weights))
# Shapes of the first two Linear layers and the output layer, each (out_features, in_features)
for idx in (0, 1, 4):
    print(weights[idx].shape)

5
torch.Size([512, 21])
torch.Size([512, 512])
torch.Size([3, 512])


In [23]:
# Layer-0 weights as NumPy: rows = 512 hidden units, columns = 21 input features
layer0_array = weights[0].detach().numpy()
layer0_array

array([[-0.06129713,  0.1042373 , -0.23667413, ..., -0.21718992,
        -0.0724961 ,  0.17855798],
       [-0.20761935, -0.11450304,  0.23524542, ...,  0.20971468,
        -0.11361881,  0.05589928],
       [ 0.13861041,  0.00610179,  0.16196385, ..., -0.10598647,
         0.05638567,  0.19389659],
       ...,
       [-0.2346766 ,  0.19501929, -0.02208532, ..., -0.00497762,
         0.12264702, -0.18566002],
       [-0.17513673, -0.22842345, -0.12635069, ...,  0.15744969,
        -0.04630079,  0.16374682],
       [-0.19398311, -0.06367713,  0.15361938, ...,  0.0387392 ,
         0.04802285, -0.10757945]], shape=(512, 21), dtype=float32)

## Model Architecture

In [24]:
# Render the architecture from the loaded model rather than a hard-coded copy.
# The printout then cannot drift out of sync with agent.DQN, and it prints as
# readable text instead of an escaped string repr.
# NOTE(review): in the loaded repr there is no activation between mlp.0 and
# mlp.1, so those two Linear layers compose into a single linear map —
# confirm this is intended in agent.DQN.
print(model)

'\nDQN(\n  (mlp): Sequential(\n    (0): Linear(in_features=21, out_features=512, bias=True)\n    (1): Linear(in_features=512, out_features=512, bias=True)\n    (2): ReLU()\n    (3): Linear(in_features=512, out_features=512, bias=True)\n    (4): ReLU()\n    (5): Linear(in_features=512, out_features=512, bias=True)\n    (6): ReLU()\n    (7): Linear(in_features=512, out_features=3, bias=True)\n  )\n)\n'

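## Strongest Weight Path

Find the input→output path whose edges have the largest product of absolute weights. Biases and activations are ignored, and the product is computed as a sum of $\log|w|$ to avoid underflow.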

In [25]:
# Strongest single input→output path: the path whose edges maximise the product
# of |weight| (biases and activations are ignored). Working in log space turns
# the product into a sum, which avoids underflow. This makes the search a
# max-sum DP: the best path into node j of layer l+1 is the best path into some
# node i of layer l plus the edge i→j.
#
# weights[l] is an nn.Linear weight tensor of shape (n_{l+1}, n_l), i.e.
# PyTorch's (out_features, in_features). Transposing gives W[i, j] = the edge
# from node i in layer l to node j in layer l+1.
mats = [w.detach().cpu().numpy().T for w in weights]
assert all(a.shape[1] == b.shape[0] for a, b in zip(mats, mats[1:])), "layer sizes do not chain"

n_layers = len(mats) + 1  # includes input and output
layer_sizes = [mats[0].shape[0]] + [m.shape[1] for m in mats]

# DP tables: best log-product score into each node, and the predecessor achieving it
scores = [np.full(size, -np.inf) for size in layer_sizes]
parents = [np.full(size, -1, dtype=int) for size in layer_sizes]

# Input layer: log-product is 0 (the neutral element for multiplication)
scores[0][:] = 0.0

# Forward pass, one vectorised relaxation per layer
for l, W in enumerate(mats):
    with np.errstate(divide='ignore'):
        log_w = np.log(np.abs(W))      # zero weights -> -inf, so they are never chosen
    cand = scores[l][:, None] + log_w  # cand[i, j]: score of reaching node j via node i
    best_i = np.argmax(cand, axis=0)   # first maximal i: same tie-break as a strict '>' scan
    best = cand.max(axis=0)
    reachable = np.isfinite(best)      # nodes with at least one all-non-zero incoming path
    scores[l + 1][reachable] = best[reachable]
    parents[l + 1][reachable] = best_i[reachable]

# Backtrack from the best output node
output_layer = n_layers - 1
end_node = np.argmax(scores[output_layer])
if not np.isfinite(scores[output_layer][end_node]):
    # parents would still be -1 here and silently index from the end
    raise ValueError("no input→output path with all non-zero weights")
path = [(output_layer, end_node)]
current = end_node

for l in range(output_layer, 0, -1):
    current = parents[l][current]
    path.append((l - 1, current))

path.reverse()
print(path)

[(0, np.int64(1)), (1, np.int64(172)), (2, np.int64(74)), (3, np.int64(307)), (4, np.int64(177)), (5, np.int64(2))]


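## Strongest Weight Paths — Top k

The $k$ highest-scoring input→output paths under the same score (sum of $\log|w|$ along the path). They are found by keeping the $k$ best partial paths into every node.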

In [26]:
# Top-k strongest input→output paths under the same score (sum of log|w|).
# Keeping the k best partial paths into every node loses nothing. Any globally
# top-k path that reaches node j through node i has a prefix among the top-k
# paths into i; otherwise k better prefixes, extended by the same edge, would
# outrank it.
num_layers = len(weights) + 1
k = 5

# Transposed to (n_l, n_{l+1}) so that W[i, j] is the edge from node i to node j
mats = [w.detach().cpu().numpy().T for w in weights]
assert all(a.shape[1] == b.shape[0] for a, b in zip(mats, mats[1:])), "layer sizes do not chain"
layer_sizes = [mats[0].shape[0]] + [m.shape[1] for m in mats]

# paths[l][i]: best-first list of at most k (log-score, [(layer, node), ...]) pairs
paths = [[[] for _ in range(size)] for size in layer_sizes]
for i in range(layer_sizes[0]):
    paths[0][i] = [(0.0, [(0, i)])]  # log-product = 0 at input

# Forward pass
for l, W in enumerate(mats):
    with np.errstate(divide='ignore'):
        log_w = np.log(np.abs(W))  # zero weights -> -inf and are dropped below

    # Flatten the surviving partial paths into candidate rows, in node-major order
    src_nodes, src_scores, src_paths = [], [], []
    for i in range(layer_sizes[l]):
        for score, prefix in paths[l][i]:
            src_nodes.append(i)
            src_scores.append(score)
            src_paths.append(prefix)

    # cand[c, j]: score of extending candidate c by the edge into node j.
    # Scores stay float32, matching the weight dtype.
    cand = np.asarray(src_scores, dtype=np.float32)[:, None] + log_w[src_nodes]
    # A stable descending sort keeps earlier candidates first on ties, as heapq.nlargest does
    order = np.argsort(-cand, axis=0, kind='stable')[:k]

    # Only the k survivors per node get a materialised path list
    next_paths = [[] for _ in range(layer_sizes[l + 1])]
    for j in range(layer_sizes[l + 1]):
        for c in order[:, j]:
            s = cand[c, j]
            if s == -np.inf:
                break  # sorted descending, so every remaining candidate is -inf too
            next_paths[j].append((s, src_paths[c] + [(l + 1, j)]))
    paths[l + 1] = next_paths

# Collect all paths in output layer and keep the top-k overall
all_output_paths = [entry for node_paths in paths[-1] for entry in node_paths]
top_k = heapq.nlargest(k, all_output_paths, key=lambda x: x[0])
print(top_k)

[(np.float32(-11.385213), [(0, 1), (1, 172), (2, 74), (3, 307), (4, 177), (5, 2)]), (np.float32(-11.410194), [(0, 5), (1, 87), (2, 5), (3, 36), (4, 188), (5, 1)]), (np.float32(-11.423874), [(0, 4), (1, 87), (2, 5), (3, 36), (4, 188), (5, 1)]), (np.float32(-11.461575), [(0, 14), (1, 87), (2, 5), (3, 36), (4, 188), (5, 1)]), (np.float32(-11.462954), [(0, 0), (1, 120), (2, 5), (3, 36), (4, 188), (5, 1)])]


In [27]:
# Spot-check one layer-0 weight. weights[0] is (512, 21) = (out, in), so this
# is the weight from input feature 2 into hidden unit 493.
weights[0][493, 2]

tensor(-0.0498)

In [28]:
# Print every edge of each top-k path with its signed weight.
# weights[layer] has shape (out, in), so the edge start→end is weights[layer][end][start].
for path in top_k:
    nodes = path[1]
    # Walk consecutive node pairs, so this works for any network depth
    for (layer, start_node), (_, end_node) in zip(nodes, nodes[1:]):
        edge_weight = weights[layer][end_node][start_node]
        print(f'NODE: {start_node} to NODE: {end_node} WITH WEIGHT {edge_weight}')
    print('\n\n--------------------------\n\n')

NODE: 1 to NODE: 172 WITH WEIGHT -0.2471952736377716
NODE: 172 to NODE: 74 WITH WEIGHT -0.07884187251329422
NODE: 74 to NODE: 307 WITH WEIGHT -0.10903805494308472
NODE: 307 to NODE: 177 WITH WEIGHT 0.0652514398097992
NODE: 177 to NODE: 2 WITH WEIGHT -0.0819406732916832


--------------------------


NODE: 5 to NODE: 87 WITH WEIGHT 0.20870856940746307
NODE: 87 to NODE: 5 WITH WEIGHT 0.08097855001688004
NODE: 5 to NODE: 36 WITH WEIGHT -0.07919877022504807
NODE: 36 to NODE: 188 WITH WEIGHT -0.11058726161718369
NODE: 188 to NODE: 1 WITH WEIGHT -0.07486546784639359


--------------------------


NODE: 4 to NODE: 87 WITH WEIGHT 0.20587307214736938
NODE: 87 to NODE: 5 WITH WEIGHT 0.08097855001688004
NODE: 5 to NODE: 36 WITH WEIGHT -0.07919877022504807
NODE: 36 to NODE: 188 WITH WEIGHT -0.11058726161718369
NODE: 188 to NODE: 1 WITH WEIGHT -0.07486546784639359


--------------------------


NODE: 14 to NODE: 87 WITH WEIGHT 0.1982559859752655
NODE: 87 to NODE: 5 WITH WEIGHT 0.08097855001688004
N