# Animation

In [None]:
import ray
from ray.rllib.algorithms.algorithm import Algorithm
import numpy as np
from custom_env import CustomEnvironment
from config import run_config

def process_observations(observation, agent_ids, truncation=None, history=None):
    """Append one time step of per-agent state to a history dict and return it.

    One numpy array is appended under each of the keys "loc_x", "loc_y",
    "heading" and "still_in_the_game"; each array has one entry per agent,
    in ``agent_ids`` order.

    Args:
        observation: Mapping agent id -> observation vector for this step.
            Elements 2, 3 and 4 are read as x location, y location and
            heading. Agents without an observation get 0 in all three.
        agent_ids: Iterable of every agent id; fixes the column order.
        truncation: Optional mapping agent id -> truncated flag (as returned
            by ``env.step``). When falsy (None or empty), every agent is
            marked as still in the game.
        history: Dict of lists to append to. Defaults to the notebook-global
            ``observations`` dict so existing calls keep working.

    Returns:
        The ``history`` dict that was appended to.
    """
    if history is None:
        # Backward-compatible fallback to the module-level accumulator.
        history = observations
    # Materialise once: the ids are iterated several times below.
    agent_ids = list(agent_ids)

    loc_x = [observation[key][2] if key in observation else 0 for key in agent_ids]
    loc_y = [observation[key][3] if key in observation else 0 for key in agent_ids]
    heading = [observation[key][4] if key in observation else 0 for key in agent_ids]
    if truncation:
        # An agent absent from the truncation dict is treated as not truncated
        # instead of raising KeyError.
        # NOTE(review): `termination` is not consulted here; confirm the env
        # only signals agent removal via truncation.
        still_in_the_game = [0 if truncation.get(key, False) else 1 for key in agent_ids]
    else:
        still_in_the_game = [1] * len(agent_ids)

    history["loc_x"].append(np.array(loc_x))
    history["loc_y"].append(np.array(loc_y))
    history["heading"].append(np.array(heading))
    history["still_in_the_game"].append(np.array(still_in_the_game))

    return history


# Checkpoint of the trained PPO run to replay; adjust to your own ray_results.
path_to_checkpoint = "/Users/tanguy/ray_results/PPO_2023-12-07_15-33-17/PPO_CustomEnvironment_8f9d7_00000_0_2023-12-07_15-33-17/checkpoint_000025"

# This does ray.init()
algo = Algorithm.from_checkpoint(path_to_checkpoint)
# After loading the algorithm, list the policies it knows about.
available_policy_ids = list(algo.workers.local_worker().policy_map.keys())
print("Available Policy IDs:", available_policy_ids)

env = CustomEnvironment(run_config["env"])

# Per-step history of agent state, appended to in place by process_observations().
observations = {"loc_x": [], "loc_y": [], "heading": [], "still_in_the_game": []}
observation, _ = env.reset()
agent_ids = env._agent_ids
# The return value is this same `observations` dict, so it is not rebound.
process_observations(observation, agent_ids)
step_count = 1
max_steps = 500

while step_count < max_steps:
    # agent_type 0 is driven by the "prey" policy, anything else by "predator".
    actions = {
        key: algo.compute_single_action(
            value, policy_id="prey" if env.agents[key].agent_type == 0 else "predator"
        )
        for key, value in observation.items()
    }

    observation, _, termination, truncation, _ = env.step(actions)
    process_observations(observation, agent_ids, truncation)
    step_count += 1
    # Stop as soon as the whole episode is over instead of stepping a finished env.
    if termination.get("__all__", False) or truncation.get("__all__", False):
        break

# Scale normalised positions back to grid units for plotting.
grid_diagonal = env.grid_diagonal
observations["loc_x"] = np.array(observations["loc_x"]) * grid_diagonal
observations["loc_y"] = np.array(observations["loc_y"]) * grid_diagonal
observations["heading"] = np.array(observations["heading"])
observations["still_in_the_game"] = np.array(observations["still_in_the_game"])

env.close()
# NOTE(review): later cells still call algo.get_policy(...).get_weights() after
# this shutdown; that appears to only touch the local worker — confirm.
ray.shutdown()

In [None]:
from importlib import reload

from IPython.display import HTML

import animation

# Reload the local animation module so edits to it are picked up without
# restarting the kernel, then take the renderer from the refreshed module.
animation = reload(animation)
generate_animation_3d = animation.generate_animation_3d

# Render the recorded rollout (observations and env from the rollout cell)
# at 10 fps; the HTML(...) call is the cell's last expression, so it is displayed.
ani = generate_animation_3d(observations, env, fps=10)
HTML(ani.to_html5_video())

# Network visualization

In [None]:
policy_weights = algo.get_policy(available_policy_ids[0]).get_weights()
# Sanity-check the architecture. Printed explicitly because a bare expression
# that is not the last line of a cell is never displayed.
print(policy_weights["_hidden_layers.1._model.0.weight"].shape)

# Keep every weight EXCEPT those whose key contains "_value_branch"
# (the value-function head, judging by the name), leaving the layers drawn below.
nn_weights = {
    key: value for key, value in policy_weights.items() if "_value_branch" not in key
}

nn_weights.keys()

In [None]:
import numpy as np

from graph_tool.all import *

def plot_mlp(neural_network, layer_spacing=20, output_size=(800, 800)):
    """Draw a fully connected network as a layered graph with graph-tool.

    ``neural_network`` maps parameter names to weight arrays using the names
    ``_hidden_layers.<k>._model.0.weight`` (hidden layers, taken in the dict's
    insertion order, starting with ``_hidden_layers.0``) and
    ``_logits._model.0.weight`` (output layer). Each matrix is indexed
    ``[out_neuron, in_neuron]``.

    Vertices are labelled "I" (input), "H" (hidden) and "O" (output). Every
    layer is centred vertically, and edge width is the absolute weight.

    Args:
        neural_network: Mapping of parameter name -> 2-D weight array.
        layer_spacing: Horizontal distance between consecutive layers.
        output_size: Size of the rendered drawing passed to ``graph_draw``.

    Raises:
        KeyError: if the first hidden layer or the logits weight is missing.
        ValueError: if consecutive layer shapes do not chain.
    """
    first_key = "_hidden_layers.0._model.0.weight"
    output_key = "_logits._model.0.weight"
    layer_keys = (
        [first_key]
        + [
            key
            for key in neural_network
            if "_hidden" in key and ".weight" in key and key != first_key
        ]
        + [output_key]
    )
    matrices = [np.asarray(neural_network[key]) for key in layer_keys]

    for (prev_key, prev), (key, cur) in zip(
        zip(layer_keys, matrices), zip(layer_keys[1:], matrices[1:])
    ):
        if cur.shape[1] != prev.shape[0]:
            raise ValueError(
                f"{key} expects {cur.shape[1]} inputs but {prev_key} "
                f"produces {prev.shape[0]} outputs"
            )

    # The input width comes from the first matrix's columns; every other
    # layer's width is the row count of the matrix feeding it.
    layer_sizes = [matrices[0].shape[1]] + [m.shape[0] for m in matrices]
    # Include the input layer so it is centred correctly when it is widest.
    max_neurons = max(layer_sizes)

    g = Graph(directed=True)
    v_label = g.new_vertex_property("string")
    e_width = g.new_edge_property("double")
    pos = g.new_vertex_property("vector<double>")

    ## VERTICES ##
    last = len(layer_sizes) - 1
    layers = []
    for depth, size in enumerate(layer_sizes):
        label = "I" if depth == 0 else ("O" if depth == last else "H")
        # Centre the layer vertically; neuron 0 is drawn at the top.
        starting_y = (max_neurons - size) / 2
        vertices = [g.add_vertex() for _ in range(size)]
        for i, v in enumerate(vertices):
            pos[v] = (depth * layer_spacing, starting_y + size - 1 - i)
            v_label[v] = label
        layers.append(vertices)

    ## EDGES ##
    for weights, sources, targets in zip(matrices, layers, layers[1:]):
        for i, src in enumerate(sources):
            for j, dst in enumerate(targets):
                e = g.add_edge(src, dst)
                # A pen width must be non-negative, so draw the weight's
                # magnitude; negative weights would otherwise not render sensibly.
                e_width[e] = abs(float(weights[j, i]))

    # Draw the graph
    graph_draw(g, pos=pos, vertex_text=v_label, edge_text=None, edge_pen_width=e_width, vertex_size=15, vertex_font_size=10, edge_font_size=10, output_size=output_size)
    
# Example usage: draw the policy network stored in nn_weights (built in the
# weight-extraction cell, with the "_value_branch" weights removed).
plot_mlp(nn_weights)