- Verify that the outputs of the DSM model are correct, i.e., compare them to those of a randomly initialized model of the same shape.
- Look at the representations of the learned SAC Pendulum model?
We want to understand the internal representations of the model, i.e., plot the model's internal activity as a function of external (environment) states.
For one neuron in the model (e.g., a unit of the dense layer 2 kernel): build a 3D rate map — its average activity as a function of the input dimensions.
Hypothesis: different parts of the model track different things — in the Pendulum environment, one neuron may encode velocity while another tracks certain x–y regions.


In [1]:
import contextlib
import einops
import inspect
import logging
import operator
import os
import sys
import typing
from typing import Any

import fancyflags as ff
import fiddle as fdl
import fiddle.extensions.jax
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import tqdm.rich as tqdm
from absl import app, flags
from absl import logging as absl_logging
from clu import metric_writers
from dm_env import specs
from etils import epath
from fiddle import absl_flags as fdl_flags
from fiddle import printing
from fiddle.codegen import codegen
from fiddle.experimental import serialization
from flax import traverse_util

from dsm import configs, console, datasets, envs, metrics, plotting, stade, train
from dsm.state import State

# Command-line flags for this run.
# --workdir: where checkpoints are written/read (defaults to ./logdir).
_WORKDIR = epath.DEFINE_path("workdir", "logdir", "Working directory.")
# --checkpoint_from: optional external checkpoint directory; consulted by
# _maybe_restore_state before the workdir's own checkpoints.
_CHECKPOINT_FROM = flags.DEFINE_string(
    "checkpoint_from",
    None,
    # Adjacent string literals are concatenated, so each fragment must end
    # with the separating space (the second one previously did not).
    "Checkpoint to load from, we'll only restore from this checkpoint if "
    "the checkpoint step is greater than the current step. "
    "If not specified, will load from the latest checkpoint in the working directory.",
)
_PROFILE = flags.DEFINE_bool("profile", False, "Enable profiling.")


def _maybe_restore_state(checkpoint_manager: orbax.checkpoint.CheckpointManager, state: State) -> State:
    """Restore ``state`` from the newest available checkpoint, if any.

    Checkpoint sources, in priority order:

    1. ``--checkpoint_from``: if the flag is set and that directory contains
       checkpoint steps, its newest step is restored when
       ``checkpoint_manager``'s own directory has no checkpoint yet or only
       older steps.
    2. Otherwise, the newest step in ``checkpoint_manager``'s directory.

    Args:
        checkpoint_manager: Manager whose directory holds this run's checkpoints.
        state: Freshly initialised state. Its ``generator``/``discriminator``
            are passed to ``checkpoint_manager.restore`` as the restore items,
            and it is returned unchanged when no checkpoint is found.

    Returns:
        A new ``State`` with the restored generator/discriminator and ``step``
        set to the restored step, or ``state`` itself if nothing was restored.
    """
    latest_step = checkpoint_manager.latest_step()

    def _restore_state(step: int, directory: os.PathLike[str] | None = None) -> State:
        # Was a bare print(); route through logging so it can be silenced.
        logging.debug("Debug directory %s", directory)
        logging.info(f"Restoring checkpoint from {directory or checkpoint_manager.directory} at step {step}.")
        restored = checkpoint_manager.restore(
            step,
            {"generator": state.generator, "discriminator": state.discriminator},
            directory=os.path.abspath(directory or checkpoint_manager.directory),
        )
        g_state, d_state = operator.itemgetter("generator", "discriminator")(restored)
        return State(step=jnp.int32(step), generator=g_state, discriminator=d_state)

    # `latest_step()` returns None when nothing has been saved. Compare with
    # None explicitly: a truthiness test would treat a checkpoint saved at
    # step 0 as "no checkpoint".
    if _CHECKPOINT_FROM.value and (checkpoint_steps := orbax.checkpoint.utils.checkpoint_steps(_CHECKPOINT_FROM.value)):
        logging.info(f"Found checkpoint directory {_CHECKPOINT_FROM.value} with steps {checkpoint_steps}.")
        latest_checkpoint_step = max(checkpoint_steps)
        if latest_step is None or latest_checkpoint_step > latest_step:
            return _restore_state(latest_checkpoint_step, _CHECKPOINT_FROM.value)
    if latest_step is not None:
        return _restore_state(latest_step)

    logging.info("No checkpoint found.")
    return state

# Let JAX read its own --jax_* options from the command line via absl.
# NOTE(review): this presumably also marks absl's FLAGS as parsed, which the
# later cell relies on when it reads FLAGS.fdl_config — confirm.
jax.config.parse_flags_with_absl()


def _maybe_remove_absl_logger() -> None:
    """Detach absl's logging handler from the root logger when it is attached.

    Does nothing if the handler is not currently among the root logger's handlers.
    """
    handler = absl_logging.get_absl_handler()
    if handler not in logging.root.handlers:
        return
    logging.root.removeHandler(handler)


from dsm.state import FittedValueTrainState
import numpy.typing as npt
from dsm import datasets, plotting, rewards
from dsm.configs import Config
from dsm.plotting import utils as plot_utils


In [2]:
# buildable from config
# Turn on Fiddle's JAX extension (presumably so JAX objects in configs
# print/serialize properly — see fiddle.extensions.jax).
fiddle.extensions.jax.enable()

logging.getLogger("jax").setLevel(logging.INFO)
# Make implicit NumPy-style rank promotion in jax.numpy an error instead of
# silently broadcasting arrays of different rank.
jax.config.update("jax_numpy_rank_promotion", "raise")

# from absl.flags import FLAGS, define_string
# define_string('fdl_config', None, 'The Fiddle configuration to use.')
# define_string('fdl_config_file', None, 'Path to the Fiddle configuration file.')

# NOTE: same module as the `fdl_flags` imported in the first cell; re-import is harmless.
import fiddle.absl_flags as fdl_flags
FLAGS = flags.FLAGS
# flags.DEFINE_string('fdl_config', 'base', 'The Fiddle configuration to use.')
# Notebook default: with neither --fdl_config nor --fdl_config_file given,
# fall back to the 'base' config. (These flags are presumably registered by
# fiddle.absl_flags on import, since the DEFINE_string calls are commented out.)
if not FLAGS.fdl_config and not FLAGS.fdl_config_file:
        FLAGS.fdl_config = 'base'
# _maybe_remove_absl_logger()
# Build an (unbuilt) Fiddle buildable from the flags, resolving config names in `configs`.
buildable = fdl_flags.create_buildable_from_flags(configs)


In [3]:
# checkpoint manager
# Log the flattened config, then materialise it into a concrete Config object.
logging.info(printing.as_str_flattened(buildable))
config: configs.Config = fdl.build(buildable)

workdir: epath.Path = _WORKDIR.value
workdir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the path is a host-side string, so a plain print() would do;
# jax.debug.print is not needed here.
jax.debug.print("DEBUG directory {bar}", bar=os.path.abspath(workdir))
# Checkpoint manager over the workdir with one PyTree checkpointer per model
# ("generator" and "discriminator"), keeping at most the 2 newest checkpoints
# and saving synchronously.
checkpoint_manager = orbax.checkpoint.CheckpointManager(
os.path.abspath(workdir),
checkpointers={
        "generator": orbax.checkpoint.PyTreeCheckpointer(),
        "discriminator": orbax.checkpoint.PyTreeCheckpointer(),
},
options=orbax.checkpoint.CheckpointManagerOptions( max_to_keep=2, enable_async_checkpointing=False, async_options=None,create=True, ),
)

DEBUG directory c:\Users\sruth\Documents\UCL\A_Thesis_main\codes\distributional-sr\logdir


In [4]:
# Build the environment (wrapped by stade.GymEnvWrapper) and the offline dataset.
env = envs.make(config.env)
env = stade.GymEnvWrapper(env, with_infos=False, seed=None)
# NumPy Generator seeded from the config; used below to draw the JAX seed.
rng = np.random.default_rng(config.seed)

data = datasets.make_dataset(config.env)

# Derive a JAX PRNG key from a random int64 drawn with the seeded NumPy generator.
rng_key = jax.random.PRNGKey(rng.integers(np.iinfo(np.int64).min, np.iinfo(np.int64).max))
rng_key, state_rng_key = jax.random.split(rng_key)

# for checkpointing
# Initialise a fresh training state, then overwrite it from a checkpoint if one exists.
# typing.cast has no runtime effect. NOTE(review): the cast target is
# DiscreteArray although the observation spec is presumably continuous for
# Pendulum — confirm what make_state expects.
state = train.make_state(state_rng_key, typing.cast(specs.DiscreteArray, env.observation_spec()), config)
state = _maybe_restore_state(checkpoint_manager, state)
print('Saved model state: ')

def print_params_shapes(params: Any, prefix: str = ""):
    """Recursively print the shape of every array leaf in a nested params tree.

    One line is printed per leaf, formatted as ``"<dotted.path>: <shape>"``.
    Nested mappings extend the dotted path with their keys. A leaf that is
    not an array is printed with its type instead of a shape.

    Args:
        params: A (possibly nested) mapping of arrays, e.g.
            ``state.generator.params``. Any ``collections.abc.Mapping`` is
            descended into, so read-only mapping types (such as Flax's
            ``FrozenDict``) work as well as plain ``dict``. Both JAX and NumPy
            arrays are reported by shape.
        prefix: Dotted path of ``params`` within the enclosing tree; empty at
            the root.
    """
    from collections.abc import Mapping

    if isinstance(params, Mapping):
        for key, value in params.items():
            print_params_shapes(value, f"{prefix}.{key}" if prefix else key)
    elif isinstance(params, (jax.Array, np.ndarray)):
        print(f"{prefix}: {params.shape}")
    else:
        print(f"{prefix}: {type(params)}")
# Dump the parameter-tree shapes of both networks held in the restored state.
print("Generator params shapes:")
print_params_shapes(state.generator.params)
print("\nDiscriminator params shapes:")
print_params_shapes(state.discriminator.params)

def print_main_keys(obj):
    """Print a header line, then each instance-attribute name of ``obj`` on its own line.

    Raises:
        TypeError: if ``obj`` has no ``__dict__`` (raised by ``vars``).
    """
    attribute_names = list(vars(obj))
    print("\nMain keys of state:")
    if attribute_names:
        print(*attribute_names, sep="\n")
# List the attribute names of the top-level State object.
print_main_keys(state)

Debug directory None




Saved model state: 
Generator params shapes:
params.model.Dense_0.bias: (16, 256)
params.model.Dense_0.kernel: (16, 5, 256)
params.model.Dense_1.bias: (16, 256)
params.model.Dense_1.kernel: (16, 256, 256)
params.model.Dense_2.bias: (16, 256)
params.model.Dense_2.kernel: (16, 256, 256)
params.model.Dense_3.bias: (16, 3)
params.model.Dense_3.kernel: (16, 256, 3)

Discriminator params shapes:
model.layer_instance.Dense_0.bias: (256,)
model.layer_instance.Dense_0.kernel: (3, 256)
model.layer_instance.Dense_1.bias: (2,)
model.layer_instance.Dense_1.kernel: (256, 2)
model.layer_instance.MLP_0.Dense_0.bias: (256,)
model.layer_instance.MLP_0.Dense_0.kernel: (256, 256)
model.layer_instance.MLP_0.Dense_1.bias: (256,)
model.layer_instance.MLP_0.Dense_1.kernel: (256, 256)
model.layer_instance.MLP_0.Dense_2.bias: (256,)
model.layer_instance.MLP_0.Dense_2.kernel: (256, 256)
model.layer_instance.MLP_1.Dense_0.bias: (256,)
model.layer_instance.MLP_1.Dense_0.kernel: (256, 256)
model.layer_instance.MLP_

In [5]:
# Display the generator's apply_fn; it is a bound Module.apply, so its repr shows the module config.
state.generator.apply_fn

<bound method Module.apply of DistributionalSRGenerator(
    # attributes
    model = MLP(
        # attributes
        num_layers = 3
        num_hidden_units = 256
        num_outputs = None
        module = Dense
        activation = leaky_relu
        dtype = float32
        param_dtype = float32
    )
    num_atoms = 16
    num_state_dims = 3
)>

In [6]:
# def get_activations(apply_fn, params, x, layer_indices, model):
#     activations = {}

#     def hook(module, inputs, outputs, name):
#         activations[name] = outputs

#     hooks = []
    
#     # Attach hooks to the specified layers
#     for i in layer_indices:
#         layer = model.layers[i]
#         layer_name = f'layer_{i}'
#         hooks.append((layer, layer_name))
    
#     # Apply the model and capture activations
#     def forward_fn(params, x):
#         for layer, name in hooks:
#             x = layer(x)
#             hook(layer, x, x, name)
#         return x
    
#     output = forward_fn(params, x)

#     return output, activations

# # input_image = jnp.ones((1, 784))  # Adjust the shape based on your input
# # # Initialize parameters
# # rng = jax.random.PRNGKey(0)
# # params = state.generator.params  # Use the params from your state
# # # Specify which layers to capture
# # layer_indices = [0, 1, 2]  # For example, capture activations from layers 0, 1, and 2
# # # Get the activations
# # output, activations = get_activations(state.generator.apply_fn, params, input_image, layer_indices, model)


In [7]:
# from flax import linen as nn
# def get_activations(model, params, x, layer_indices):
#     activations = {}

#     def capture_activations(mod, inputs, outputs):
#         layer_name = mod.name
#         if layer_name in [f'Dense_{i}' for i in layer_indices]:
#             activations[layer_name] = outputs
#         return outputs

#     hooks = []

#     def hook_fn(layer):
#         def inner_hook(mod, inputs, outputs):
#             return capture_activations(mod, inputs, outputs)
#         return inner_hook

#     for i in layer_indices:
#         layer = model.layers[i]
#         hooks.append(nn.Module._add_forward_hook(layer, hook_fn(layer)))

#     output = model.apply(params, x)

#     for hook in hooks:
#         hook.remove()

#     return output, activations

# def compute_DSM_samples(
#     state, rng,    *,  config, source_state,  capture_activations,   layer_indices):
#     import einops
#     zs = jax.random.normal(rng, (config.plot_num_samples, config.num_outer, config.latent_dims))
#     context = einops.repeat(source_state, "s -> i o s", i=config.plot_num_samples, o=config.num_outer)
#     xs = jnp.concatenate((zs, context), axis=-1)
    
#     if capture_activations and layer_indices:
#         activations = {}
#         for i in range(xs.shape[0]):
            
#             sample_activations = get_activations(state.apply_fn, state.params, xs[i], layer_indices)
#             for layer_name, activation in sample_activations.items():
#                 if layer_name not in activations:
#                     activations[layer_name] = []
#                 activations[layer_name].append(activation)
        
#         # Convert lists to arrays
#         for layer_name in activations.keys():
#             activations[layer_name] = jnp.array(activations[layer_name])
        
#         return source_state, None, activations
    
#     ys = jax.vmap(state.apply_fn, in_axes=(None, 0))(state.params, xs)
#     samples = einops.rearrange(ys, "i o s -> o i s")
#     return source_state, samples, {}

# saved_source_states = plotting.source_states(config.env)
# sources_all = saved_source_states[1]
# source_state_current = sources_all[0]

# source, samples, activations = compute_DSM_samples(
#     state.generator, jax.random.PRNGKey(0), config=config, source_state=source_state_current,
#     capture_activations=True, layer_indices=[0, 1, 2]  # Specify layers to visualize
# )
# print(activations)
# # if activations:
# #     visualize_activations(activations)
# # else:
# #     print("No activations captured.")


In [8]:
# print_main_keys(state.generator)  #step, apply_fn, params, tx,opt_state, target_params, target_params_update, metrics 
# target_params_update=SoftTargetParamsUpdate(step_size=0.01), 
# metrics=_InlineCollection(
#       _reduction_counter=_ReductionCounter(value=Array(501, dtype=int32)), 
#       mmd=Metric.from_output.<locals>.FromOutput(total=Array(25.586227, dtype=float32), count=Array(500, dtype=int32)), 
#       observation=Metric.from_output.<locals>.FromOutput(total=Array(154.22746, dtype=float32), count=Array(500, dtype=int32)), 
#       embedding=Metric.from_output.<locals>.FromOutput(total=Array(507.3972, dtype=float32), count=Array(500, dtype=int32))))

# Names of the Dense layers inside the generator's MLP parameter subtree.
print(state.generator.params['params']['model'].keys())
# output ['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])

# Output-layer kernel shape; the leading axis presumably indexes atoms.
print(state.generator.params['params']['model']['Dense_3']['kernel'].shape)

dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3'])
(16, 256, 3)


In [9]:
# from flax import linen as nn
# def get_activations(model, params, x, layer_indices):
#     activations = {}

#     def capture_activations(mod, inputs, outputs):
#         layer_name = mod.name
#         if layer_name in [f'Dense_{i}' for i in layer_indices]:
#             activations[layer_name] = outputs
#         return outputs

#     hooks = []

#     def hook_fn(layer):
#         def inner_hook(mod, inputs, outputs):
#             return capture_activations(mod, inputs, outputs)
#         return inner_hook

#     for layer in params['params']['model']:
#         hooks.append(nn.Module._add_forward_hook(layer, hook_fn(layer)))

#     output = model.apply(params, x)

#     for hook in hooks:
#         hook.remove()

#     return output, activations

# def get_activations(model,params,x, layer_indices):
#     activations = {}
#     def hook(module, input, output, layer_name):
#         activations[layer_name] = output
#     hooks = []
#     # Attach hooks
#     num_layers = len(params['params']['model'].keys()) 
#     for i in range(num_layers):
#         if i in layer_indices:
#             layer = model.layers[i]
#             hook_fn = lambda module, input, output, layer_name=f'Dense_{i}': hook(module, input, output, layer_name)
#             hooks.append(layer.register_forward_hook(hook_fn))

#     # Run the model
#     output = model.apply(params, x)

#     # Remove hooks
#     for hook in hooks:
#         hook.remove()

#     return output, activations

# def get_activations(model, params, x, layer_indices):
#     _, activations = model({"params": params["params"]}, x, capture_activations=True)
#     return {f'layer_{i}': activations[f'layer_{i}'] for i in layer_indices if f'layer_{i}' in activations}


# def fetch_activations(params, x):
#     kernel_atom = params['params']['model']['Dense_0']['kernel']#[0]  # 1st atom of 16
#     bias_atom = params['params']['model']['Dense_0']['bias']#[0]  # 1st atom of 16
#     activation = x
#     activation = jax.nn.relu(jnp.dot(kernel_atom, activation) + bias_atom)
#     return activation

# def get_activations(model,params,x, layer_indices):
#     activations = {}
#     def hook(module, input, output, layer_name):
#         activations[layer_name] = output
#     hooks = []
#     # Attach hooks
#     for i in range(model.num_layers):
#         if i in layer_indices:
#             layer = model.layers[i]
#             hook_fn = lambda module, input, output, layer_name=f'layer_{i}': hook(module, input, output, layer_name)
#             hooks.append(layer.register_forward_hook(hook_fn))

#     # Run the model
#     output = model.apply(params, x)

#     # Remove hooks
#     for hook in hooks:
#         hook.remove()

#     return output, activations

# # Specify which layers to capture
# layer_indices = [0, 1, 2]  # For example, capture activations from layers 0, 1, and 2
# # Get the activations
# output, activations = get_activations(model, params, input_image, layer_indices)

# # Visualize activations
# import matplotlib.pyplot as plt
# def visualize_activations(activations: Dict[str, jax.Array]):
#     for layer_name, activation in activations.items():
#         print(f'Visualizing {layer_name} with shape {activation.shape}')
#         rows = activation.shape[-1] // 16
#         if activation.shape[-1] % 16 != 0:
#             rows += 1
#         fig, axarr = plt.subplots(rows, 16, figsize=(15, rows * 2))
#         for idx in range(activation.shape[-1]):
#             ax = axarr[idx // 16, idx % 16]
#             ax.imshow(activation[0, :, idx], cmap='gray')
#             ax.axis('off')
#         plt.show()

# visualize_activations(activations)


In [10]:
# model.params['params']['model']
model_generator = state.generator
# Every generator param leaf carries a leading axis of 16, matching num_atoms = 16
# in the module repr — presumably one weight set per atom. Indexing [0] selects atom 0.
kernel_atom = model_generator.params['params']['model']['Dense_0']['kernel']#[0]  # 1st atom of 16
print(kernel_atom.shape)

print(len(model_generator.params['params']['model']['Dense_3'])) #kernel and bias
# NOTE: despite the name, this is the output layer (Dense_3), not a hidden layer.
hidden_layers = model_generator.params['params']['model']['Dense_3']
print(hidden_layers['kernel'][0].shape) # (1 of 16) shape (256,3)
print(hidden_layers['bias'][0].shape) # shape(3)

(16, 5, 256)
2
(256, 3)
(3,)


In [14]:
# apply_fn is a bound method, so __self__ is the DistributionalSRGenerator
# module itself, and .model is its MLP submodule.
model1 = model_generator.apply_fn.__self__.model
print(model1.num_layers)
print(model1.num_hidden_units)
# model.module.apply

print('applyfn_self',model_generator.apply_fn.__self__)  #.model.module

3
256
applyfn_self DistributionalSRGenerator(
    # attributes
    model = MLP(
        # attributes
        num_layers = 3
        num_hidden_units = 256
        num_outputs = None
        module = Dense
        activation = leaky_relu
        dtype = float32
        param_dtype = float32
    )
    num_atoms = 16
    num_state_dims = 3
)


In [22]:
# The layer class the MLP instantiates for each layer (displays flax.linen.linear.Dense).
model1.module
# model1.num_layers
# i=0
# model_generator.params['params']['model'][f'Dense_{i}']

flax.linen.linear.Dense

In [None]:
# (num_atoms?, in_features, hidden_units): leading axis presumably indexes atoms — confirm.
model_generator.params['params']['model']['Dense_0']['kernel'].shape

(16, 5, 256)

In [None]:
# Scratch: apply one Dense layer of the MLP on its own.
# NOTE(review): `i` and `y` are not defined in this cell (`y` is never assigned
# at module level in this notebook), so this raises NameError when run.
# NOTE(review): the stored Dense_{i} params carry a leading atom axis (e.g.
# Dense_0.kernel is (16, 5, 256)) while a Dense layer expects (in, features);
# presumably slice one atom first, e.g.
# jax.tree_util.tree_map(lambda p: p[0], ...), and pass an input with the
# matching in_features — confirm.
model1.module(model1.num_hidden_units, dtype=model1.dtype, param_dtype=model1.param_dtype).apply({'params': model_generator.params['params']['model'][f'Dense_{i}']}, y)

In [34]:
# Scratch: attempt to re-initialise generator params from a dummy input.
# NOTE(review): the generator's real inputs are (num_samples, 16, 5); this
# shape mirrors the Dense_0 *kernel* shape (16, 5, 256), not an input — confirm intent.
x = jnp.ones((16, 5, 256)) 
# NOTE(review): model_generator is state.generator (a train-state object);
# the Flax module is model_generator.apply_fn.__self__, which is what has
# `.init`. Depending on execution order, `rng` here may still be the NumPy
# Generator from the env cell rather than a JAX PRNG key.
# This also rebinds the global name `params`.
params = model_generator.init(rng, x)


In [47]:
# NOTE(review): num_state_dims is an attribute of the generator *module* (see its repr);
# presumably model_generator.apply_fn.__self__.num_state_dims was intended — confirm.
model_generator.num_state_dims

In [None]:
# def get_activations(model, params, x: jax.Array, layer_indices):
#     activations = {}

#     def capture_hook(layer_idx):
#         def hook(x):
#             activations[f'layer_{layer_idx}'] = x
#             return x
#         return hook

#     def modified_apply_fn(params, x):
#         # y = jnp.expand_dims(x,axis=0)
#         # print('debug input',y.shape)
#         y=x
#         for i in range(model.num_layers):
#             layer = model.module(model.num_hidden_units, dtype=model.dtype, param_dtype=model.param_dtype)
#             # print('DEBUG layer:', layer)
#             # print(params['model'][f'Dense_{i}']['kernel'].shape)
#             # print(y.shape)
#             y = layer.apply({'params': params['model'][f'Dense_{i}']}, y)
#             y = model.activation(y)
#             if i in layer_indices:
#                 y = capture_hook(i)(y)
#         y = model.module(model.num_outputs, dtype=model.dtype, param_dtype=model.param_dtype).apply({'params': params['model'][f'Dense_{model.num_layers}']}, y)
#         return y

#     # Run the modified apply function to capture activations
#     modified_apply_fn(params, x)
#     return activations

In [46]:
# compute_DSM_samples in model_viz.ipynb  #adapted from compute_return_distribution code
saved_source_states = plotting.source_states(config.env) 
# Two parallel collections of source states (9 each per the original note):
#   [0] raw states [theta, thetadot]; [1] observations (3 values, e.g. the
#   printed [6.1e-17, 1.0, -4.0]).
# NOTE(review): this note originally listed the observation as [sin(theta), cos(theta), thetadot], but the
# plotting code computes theta = arctan2(obs[1], obs[0]), i.e. treats obs[0]
# as cos and obs[1] as sin — confirm the order.
# print('DEBUG source states 0 - polar coords, angular velocity',saved_source_states[0]) 
# print('DEBUG source states 1 - cartesian coords, angular velocity',saved_source_states[1])
sources_all = saved_source_states[1]
source_state = sources_all[0]
# Fixed JAX key for reproducible noise; note this rebinds `rng` (previously the NumPy Generator).
rng = jax.random.PRNGKey(0)
model_generator = state.generator # FittedValueTrainState,
jax.debug.print("Selected Source {bar}", bar=source_state)
num_samples=config.plot_num_samples # Number of state samples 
num_outer=config.num_outer # Number of model atoms   
num_latent_dims=config.latent_dims # Dimension of input noise 
print('num_samples', num_samples,' num_outer:', num_outer, ' num_latent_dims ',num_latent_dims)
# NOTE: the bare string below is a no-op expression (a label, not a docstring).
'Simulating trajectories in an MDP'
#Code from plot_utils.sample_from_sr # samples = plot_utils.sample_from_sr(...) 
# Build generator inputs: Gaussian noise concatenated with the source state,
# which is repeated across the sample (i) and outer/atom (o) axes.
# Result shape: (num_samples, num_outer, num_latent_dims + state_dims).
zs = jax.random.normal(rng, (num_samples, num_outer, num_latent_dims))
context = einops.repeat(source_state, "s -> i o s", i=num_samples, o=num_outer)
xs = jnp.concatenate((zs, context), axis=-1)
print('model input: ',xs.shape)


def get_activations(model, params, x: jax.Array, layer_indices):
    """Re-run the generator's MLP layer by layer and return hidden activations.

    The generator stores one set of MLP weights per atom, stacked along a
    leading axis: every leaf of ``params['model']`` has that leading axis
    (e.g. ``Dense_0.kernel`` prints as ``(16, 5, 256)``). The forward pass is
    therefore ``jax.vmap``-ed over that axis for both the weights and the
    inputs, so each atom's input row goes through its own weights.

    Activations are returned as vmap *outputs*. Storing values into an outer
    dict from inside ``jax.vmap`` would leak tracers that cannot be used
    afterwards.

    Only the hidden layers ``Dense_0 .. Dense_{num_layers-1}`` are evaluated;
    the output head is not needed for the activations and is skipped.

    Args:
        model: The MLP module (``generator_module.model``). Its ``num_layers``,
            ``module``, ``activation``, ``dtype`` and ``param_dtype`` are read.
        params: Generator params subtree containing ``'model'`` ->
            ``'Dense_{i}'`` -> ``{'kernel', 'bias'}``, each leaf with a leading
            atom axis.
        x: Inputs with the same leading atom axis, e.g. ``xs[i]`` of shape
            ``(num_atoms, in_features)``.
        layer_indices: Hidden-layer indices whose post-activation outputs are
            recorded. Indices outside ``range(model.num_layers)`` are ignored.

    Returns:
        Dict mapping ``'layer_{i}'`` to an array of shape
        ``(num_atoms, features_i)``.
    """
    wanted = set(layer_indices)

    def per_atom_forward(atom_params, atom_input):
        captured = {}
        y = atom_input
        for i in range(model.num_layers):
            layer_params = atom_params[f'Dense_{i}']
            # Derive the width from the stored kernel so it always matches the checkpoint.
            layer = model.module(layer_params['kernel'].shape[-1], dtype=model.dtype, param_dtype=model.param_dtype)
            y = model.activation(layer.apply({'params': layer_params}, y))
            if i in wanted:
                captured[f'layer_{i}'] = y
        return captured

    # in_axes=(0, 0): map over the atom axis of every param leaf and of the inputs.
    return jax.vmap(per_atom_forward, in_axes=(0, 0))(params['model'], x)

# Either collect hidden-layer activations for every sample (capture branch)
# or just generate samples with the trained generator (else branch).
capture_activations=1
layer_indices=[0, 1, 2]
if capture_activations and layer_indices:
    # activations[layer_name] accumulates one array per sample; stacked below.
    activations = {}

    outputs=[]
    for i in range(xs.shape[0]): #32 samples
        # single_input = xs[i]  # Extract the i-th input (shape: (16, 5))
        # output = model_generator.apply_fn(model_generator.params, single_input)  # Apply the model
        # print(output.shape)
        # outputs.append(output)  # Collect the output
        sample_activations = get_activations(model_generator.apply_fn.__self__.model, model_generator.params['params'], xs[i], layer_indices)
        for layer_name, activation in sample_activations.items():
            # Create the per-layer list on first use AND append every sample.
            # An `if new: [] else: append` split would silently drop sample 0.
            activations.setdefault(layer_name, []).append(activation)
    # Stack per-sample arrays along a new leading axis of length xs.shape[0].
    activations = {layer_name: jnp.stack(per_sample) for layer_name, per_sample in activations.items()}

    # print(len(outputs))
    # ys = jnp.stack(outputs)
else:
    ys = jax.vmap(model_generator.apply_fn, in_axes=(None, 0))(model_generator.params, xs)
    # ys_list = []
    # for x in xs:
    #     y = model.apply_fn(model.params, x)
    #     ys_list.append(y)
    # ys = jnp.stack(ys_list)

    print('ys shape',ys.shape)
    samples =  einops.rearrange(ys, "i o s -> o i s")
    print('samples',samples.shape) #num_outer, num_samples, 3
    # # print(samples[-1])

# print('ys shape',ys.shape)  #(32, 16, 3)
# #thetas = np.arctan2(samples[i, :, 1], samples[i, :, 0]) % (2 * np.pi)
# # velocities = samples[i, :, -1]
# return source_state_current, samples

Selected Source [ 6.123234e-17  1.000000e+00 -4.000000e+00]
num_samples 32  num_outer: 16  num_latent_dims  2
model input:  (32, 16, 5)


In [None]:
# plotting generated samples
import matplotlib.pyplot as plt
from dsm.plotting import utils as plotting_utils

def plot_samples(samples, source=None):
    """Scatter generated samples in (theta, angular-velocity) space, one colour per atom.

    Args:
        samples: Array-like of shape ``(num_outer, num_samples, 3)``, e.g. the
            output of ``einops.rearrange(ys, "i o s -> o i s")``. Per
            observation, theta is computed as ``arctan2(obs[1], obs[0])``
            (wrapped to ``[0, 2π)``) and the last entry is used as angular
            velocity.
        source: Source observation the samples were generated from, drawn as
            a red cross. If omitted, falls back to a notebook-global
            ``source`` (what this function originally read).

    Raises:
        NameError: If ``source`` is not given and no global ``source`` exists.
    """
    if source is None:
        try:
            source = globals()["source"]
        except KeyError:
            raise NameError("plot_samples() needs a source state: pass source=...") from None

    samples = np.asarray(samples)
    # A 1x1 grid returns a single Axes, so plot on it directly instead of
    # indexing/iterating (which previously always raised and was swallowed
    # by bare `except:` clauses).
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    # Plot atom scatter: each colour is one outer index (atom).
    print('DEBUG plot_samples() - atoms? - samples.shape[0]',samples.shape[0])
    cmap = plt.get_cmap("Dark2")  # pyright: ignore
    for i, atom_samples in enumerate(samples):
        thetas = np.arctan2(atom_samples[:, 1], atom_samples[:, 0]) % (2 * np.pi)
        velocities = atom_samples[:, -1]
        # Dark2 has only cmap.N (8) colours; integer indices past N are
        # clamped to the last colour, so cycle instead.
        ax.scatter(thetas, velocities, color=cmap(i % cmap.N), s=2.0, alpha=0.25)

    # Plot source state
    theta = np.arctan2(source[1], source[0]) % (2 * np.pi)
    ax.scatter(theta, source[-1], marker="x", s=64, alpha=0.8, color="red")

    # set bounds
    # ax.set_ylim(-8.5, 8.5)
    ax.set_aspect("auto")
    ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi])
    ax.set_xticklabels(["0", "π/2", "π", "3π/2", "2π"])

    # plt.show() takes no figure argument; it shows all open figures.
    plt.show()

In [None]:
def visualize_feature_maps(model, input_image, layer_index):
    """Plot every channel of one layer's output for a single input (PyTorch models only).

    A forward hook on ``model.encoder.layers[layer_index]`` captures that
    layer's output while ``model(input_image)`` runs. Channel ``c`` of the
    first batch item (``output[0, c]``) is then drawn as a grayscale image in
    a grid with 16 columns.

    NOTE(review): this uses the PyTorch hook API (``register_forward_hook``,
    ``Tensor.detach``/``.cpu``/``.size``). It does not apply to the Flax/JAX
    generator in this notebook; use ``get_activations`` for that. It assumes
    the captured output is indexable as ``[batch, channel, ...]`` with 2-D
    per-channel maps — confirm for the target layer.

    Args:
        model: Module exposing ``model.encoder.layers[layer_index]``.
        input_image: Batch passed to ``model``; only item 0 is plotted.
        layer_index: Index into ``model.encoder.layers`` of the layer to show.
    """
    print('Visualizing layer',layer_index,' on passing image through model')
    activation = {}

    def get_activation(name):
        # Forward hooks are called as hook(module, inputs, output).
        def hook(module, inputs, output):
            activation[name] = output.detach()
        return hook

    layer = model.encoder.layers[layer_index]
    handle = layer.register_forward_hook(get_activation('feature_map'))
    try:
        model(input_image)
    finally:
        # Always detach the hook, even if the forward pass raises, so the
        # layer is not left with a stale hook.
        handle.remove()

    act = activation['feature_map']
    print('Activations shape',act.shape)
    num_maps = act.size(1)
    cols = 16
    rows = (num_maps + cols - 1) // cols  # ceil(num_maps / cols)
    # squeeze=False keeps axarr 2-D even when rows == 1, so [r, c] indexing
    # and .flat iteration always work.
    fig, axarr = plt.subplots(rows, cols, squeeze=False)
    for idx, ax in enumerate(axarr.flat):
        ax.axis('off')  # also hides unused slots in the last row
        if idx < num_maps:
            ax.imshow(act[0, idx].cpu().numpy(), cmap='gray')
    plt.show()
# plt.imshow(act[0, 0].cpu().numpy())