# Tutorial: Tracking NTK and Plasticity

This notebook provides a tutorial on how to use the new features for tracking the Neural Tangent Kernel (NTK) and its relationship with network plasticity.

## 1. Introduction

We have added two new experiments to study how the empirical NTK evolves during training:

1.  **`basic_ntk_track`**: A simple experiment that trains a model on a static dataset and tracks the NTK eigenvalue spectrum over time.
2.  **`tracking_plasticity_and_ntk`**: A more advanced experiment that trains a model on a shifting data distribution and tracks both plasticity metrics (like feature rank) and NTK eigenvalues simultaneously.

These experiments use the `neural-tangents` library to compute the NTK and reuse our existing tools for plasticity analysis.
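
For reference, the quantity both experiments track is the empirical NTK of the network at training step $t$. For brevity, this summary assumes a scalar-output network $f_\theta$ with $P$ parameters, and treats $\nabla_\theta f_{\theta_t}(x)$ as a $1 \times P$ row vector:

$$
\Theta_t(x, x') = \nabla_\theta f_{\theta_t}(x)\,\nabla_\theta f_{\theta_t}(x')^\top,
\qquad
\Theta_t(X, X) = J_t J_t^\top \in \mathbb{R}^{n \times n},
$$

Here $J_t \in \mathbb{R}^{n \times P}$ stacks the per-example gradients for a batch $X$ of $n$ inputs. Because $\Theta_t(X, X)$ is symmetric positive semidefinite, its eigenvalues satisfy $\lambda_1 \ge \dots \ge \lambda_n \ge 0$, and this spectrum is what gets logged. For a network with $k$ outputs the kernel is $nk \times nk$ unless the output dimension is traced out.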

## 2. Setup

Before running the experiments, make sure you have installed the necessary dependencies. If you have already run the installation step in the main README, you should be all set. If not, you can install the required packages using pip:

```
!pip install jax jaxlib neural-tangents
```

## 3. The `basic_ntk_track` Experiment

This experiment establishes a baseline for how the NTK evolves during standard (non-continual) training. It trains a model on a static dataset, such as MNIST, and periodically computes and logs the NTK eigenvalues.
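
To make "NTK eigenvalues" concrete, here is a minimal NumPy sketch of the computation for a toy scalar-output two-layer MLP with a hand-derived Jacobian. It is illustrative only. The model and the names `mlp_forward`, `mlp_jacobian`, and `empirical_ntk` are hypothetical and unrelated to the project's PyTorch models. `neural-tangents` obtains the Jacobians through autodiff rather than by hand.

```python
import numpy as np


def mlp_forward(params, x):
    """Toy scalar-output MLP: f(x) = v . tanh(W x), evaluated for a batch x of shape (n, d)."""
    W, v = params
    return np.tanh(x @ W.T) @ v  # shape (n,)


def mlp_jacobian(params, x):
    """Per-example gradient of f w.r.t. all parameters, flattened to shape (n, P)."""
    W, v = params
    h = np.tanh(x @ W.T)                                   # hidden activations, (n, hidden)
    dW = ((1.0 - h**2) * v)[:, :, None] * x[:, None, :]    # df/dW[k, j] = v_k (1 - h_k^2) x_j
    dv = h                                                 # df/dv[k] = h_k
    return np.concatenate([dW.reshape(len(x), -1), dv], axis=1)


def empirical_ntk(params, x1, x2):
    """Empirical NTK: Theta[i, j] = <grad f(x1_i), grad f(x2_j)>."""
    return mlp_jacobian(params, x1) @ mlp_jacobian(params, x2).T


rng = np.random.default_rng(0)
d, hidden, n = 5, 16, 8
params = (rng.normal(size=(hidden, d)) / np.sqrt(d), rng.normal(size=hidden) / np.sqrt(hidden))
x = rng.normal(size=(n, d))

ntk = empirical_ntk(params, x, x)            # (n, n) Gram matrix J J^T
eigvals = np.linalg.eigvalsh(ntk)[::-1]      # spectrum, largest first
```

`eigvalsh` is used because the kernel is symmetric by construction. On larger batches, round-off can produce tiny negative eigenvalues, which are usually clipped before taking logarithms.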

### 3.1. Running the Experiment

You can run the experiment directly from this notebook. The leading `!` hands the command to the shell, and the path is relative to the notebook's directory:

```
!python ../experiments/basic_ntk_track/train_basic_ntk.py
```

### 3.2. Configuration

The experiment's configuration is located in `experiments/basic_ntk_track/cfg/config.yaml`. You can modify this file to change the experiment's parameters. Key options include:

- `epochs`: The total number of training epochs.
- `track_ntk`: A boolean to enable or disable NTK tracking.
- `ntk_measure_freq_epoch`: How often (in epochs) to compute the NTK.
- `ntk_batch_size`: The number of samples to use for the NTK computation.
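
The interaction of the last three options can be sketched as follows. This is a hypothetical illustration, not the logic in `train_basic_ntk.py`. The config values are made up. The helper names `should_measure_ntk` and `fixed_ntk_indices` are invented. Two behaviors are assumptions: measurement happens on epochs divisible by `ntk_measure_freq_epoch`, and one fixed subset of `ntk_batch_size` examples is reused so that spectra from different epochs are comparable. Because the kernel is `ntk_batch_size × ntk_batch_size` (per output pair) and needs per-example gradients for every parameter, this option largely determines the measurement cost.

```python
import random

# Hypothetical values for the config keys listed above (illustrative, not the repo defaults)
cfg = {"epochs": 10, "track_ntk": True, "ntk_measure_freq_epoch": 3, "ntk_batch_size": 4}


def should_measure_ntk(epoch: int, cfg: dict) -> bool:
    """Assumed rule: measure at epoch 0 and then every `ntk_measure_freq_epoch` epochs."""
    return cfg["track_ntk"] and epoch % cfg["ntk_measure_freq_epoch"] == 0


def fixed_ntk_indices(num_examples: int, cfg: dict, seed: int = 0) -> list:
    """Draw one fixed subset of examples so every NTK measurement uses the same inputs."""
    return random.Random(seed).sample(range(num_examples), cfg["ntk_batch_size"])


measure_epochs = [e for e in range(cfg["epochs"]) if should_measure_ntk(e, cfg)]
ntk_indices = fixed_ntk_indices(num_examples=60_000, cfg=cfg)
```

With these illustrative values, the NTK would be measured at epochs 0, 3, 6, and 9.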

## 4. The `tracking_plasticity_and_ntk` Experiment

This experiment investigates the relationship between plasticity and the NTK in a continual learning setting. It trains a model on a data distribution that shifts over time and logs both plasticity metrics and NTK eigenvalues at each task switch.
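
The control flow described above can be outlined as follows. This is a hypothetical sketch, not the code in `train_plasticity_ntk.py`. The names `run_continual`, `train_on_task`, and `metric_fns` are illustrative. The outline also assumes that metrics are recorded once at the end of each task, just before the distribution shifts.

```python
from typing import Callable, Dict, List


def run_continual(
    num_tasks: int,
    train_on_task: Callable[[int], None],
    metric_fns: Dict[str, Callable[[], object]],
) -> List[dict]:
    """Hypothetical outline: train on each task in turn and snapshot metrics at every switch."""
    history = []
    for task_id in range(num_tasks):
        train_on_task(task_id)  # train until the data distribution shifts
        # Task switch: record plasticity and NTK metrics before moving on
        snapshot = {"task": task_id}
        snapshot.update({name: fn() for name, fn in metric_fns.items()})
        history.append(snapshot)
    return history


# Toy usage with stand-in callables (no real model involved)
seen_tasks = []
history = run_continual(
    num_tasks=3,
    train_on_task=seen_tasks.append,
    metric_fns={"feature_rank": lambda: len(seen_tasks), "ntk_top_eig": lambda: 1.0},
)
```

Keeping the metrics as a dictionary of callables lets rank, rank-drop, and NTK measurements be switched on or off independently, much as the config flags below do.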

### 4.1. Running the Experiment

You can run the experiment with the following command:

```
!python ../experiments/tracking_plasticity_and_ntk/train_plasticity_ntk.py
```

### 4.2. Configuration

The configuration for this experiment is in `experiments/tracking_plasticity_and_ntk/cfg/config.yaml`. In addition to the NTK parameters above, it controls task shifting and plasticity tracking:

- `task_shift_mode`: The type of data shift to apply (e.g., `continuous_input_deformation`).
- `track_rank`: A boolean to enable or disable rank tracking.
- `track_rank_drop`: A boolean to enable or disable rank drop dynamics tracking.
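
This tutorial does not define the "feature rank" behind `track_rank`. One common choice in the plasticity literature is the `srank` measure from Kumar et al.'s work on implicit under-parameterization: the smallest number of singular values that together account for a $(1 - \delta)$ fraction of the total. The sketch below assumes that definition, applied to a matrix of learned features (for example, penultimate-layer activations). It is not necessarily what the project's rank tracker computes.

```python
import numpy as np


def srank(features, delta=0.01):
    """Smallest k whose top-k singular values hold at least (1 - delta) of the total.

    `features` is assumed to be a (num_examples, feature_dim) matrix of activations.
    """
    s = np.linalg.svd(np.asarray(features, dtype=float), compute_uv=False)  # sorted descending
    total = s.sum()
    if total == 0.0:
        return 0  # all-zero features: no usable directions
    cumulative = np.cumsum(s) / total
    # First index where the cumulative fraction reaches 1 - delta, converted to a count
    return int(np.searchsorted(cumulative, 1.0 - delta) + 1)


rank_low = srank(np.diag([3.0, 1.0, 0.0, 0.0]))  # two active directions
rank_full = srank(np.eye(4))                      # four equally strong directions
```

Unlike `np.linalg.matrix_rank`, this measure ignores directions that are nonzero but carry a negligible share of the spectrum. That property makes it sensitive to the gradual feature collapse associated with lost plasticity.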

## 5. Analyzing the Results

Both experiments log their results to Weights & Biases (`wandb`). You can view the logged metrics, including the NTK eigenvalue distributions, in the `wandb` dashboard. The eigenvalues are logged as histograms, so you can see how the spectrum changes over time.
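
NTK spectra often span several orders of magnitude, so a histogram with linear bins can lump most eigenvalues into one bar. One option, which the experiments may or may not use, is to bin the log10 eigenvalues yourself. The resulting `(counts, edges)` pair should be accepted by `wandb.Histogram(np_histogram=(counts, edges))`. The helper name `log10_eigenvalue_histogram` and the clipping floor are assumptions made for this sketch.

```python
import numpy as np


def log10_eigenvalue_histogram(eigvals, num_bins=20, floor=1e-12):
    """Bin NTK eigenvalues on a log10 scale.

    Tiny or slightly negative eigenvalues (numerical noise) are clipped to `floor`
    so that the logarithm is defined.
    """
    eigvals = np.clip(np.asarray(eigvals, dtype=float), floor, None)
    counts, edges = np.histogram(np.log10(eigvals), bins=num_bins)
    return counts, edges


counts, edges = log10_eigenvalue_histogram([1e-6, 1e-2, 10.0, 1e3], num_bins=3)
```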

## 6. Under the Hood: `ntk_tools.py`

The NTK computation is handled by the functions in `src/utils/ntk_tools.py`. The core of this module is the `get_ntk_fn` function, which uses `neural-tangents` to compute the empirical NTK of a PyTorch model.

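A key challenge in this implementation is bridging PyTorch (used throughout this project) and JAX (used by `neural-tangents`). We do this with `jax.experimental.host_callback.call`, which lets a JAX-jitted function call the PyTorch model's forward pass on the host. Note that `host_callback` has been deprecated and then removed in recent JAX releases in favor of `jax.pure_callback` and `jax.experimental.io_callback`, so this code path may require pinning an older JAX version. Host callbacks are also opaque to JAX's automatic differentiation, and each call moves data between device and host.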
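
On JAX versions without `host_callback`, the same host round-trip can be sketched with `jax.pure_callback`. This is a minimal sketch, not the code in `ntk_tools.py`. `host_forward` is a hypothetical NumPy stand-in for the PyTorch forward pass. A real bridge would convert the input to a `torch.Tensor`, run the model, and convert the output back to NumPy. `pure_callback` assumes the callback has no side effects, which holds for a model in eval mode. JAX cannot differentiate through the callback, so computing an NTK this way would also need a differentiation rule, for example via `jax.custom_jvp`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for a PyTorch model: a fixed NumPy weight matrix evaluated on the host
HOST_W = np.array([[1.0, -2.0], [0.5, 3.0]], dtype=np.float32)


def host_forward(x):
    """Runs outside JAX tracing; a real bridge would call the PyTorch model here."""
    return np.tanh(np.asarray(x) @ HOST_W).astype(np.float32)


@jax.jit
def jax_forward(x):
    # Declare the callback's output shape and dtype so JAX can trace around it
    out_spec = jax.ShapeDtypeStruct((x.shape[0], HOST_W.shape[1]), jnp.float32)
    return jax.pure_callback(host_forward, out_spec, x)


y = jax_forward(jnp.ones((3, 2), dtype=jnp.float32))
```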