# GPU Use

## Overview


_Developer Note:_ if you may make a PR in the future, be sure to copy this
notebook and give the copy the `temp` prefix, which is gitignored, to avoid future conflicts.

This is one notebook in a multi-part series on decoding in Spyglass. To set up
your Spyglass environment and database, see 
[the Setup notebook](./00_Setup.ipynb).

In this tutorial, we'll set up GPU access for subsequent decoding analyses. While this notebook doesn't have any direct prerequisites, you will need 
[Spike Sorting](./02_Spike_Sorting.ipynb) data for the next step.


## GPU Clusters


### Connecting
 
Members of the Frank Lab have access to two GPU clusters, `breeze` and `zephyr`.
To access them, specify the cluster when you `ssh`, with the default port:

> `ssh username@{breeze or zephyr}.cin.ucsf.edu`

There are currently 10 available GPUs, each with 80 GB of memory, referred to by their IDs (0-9).

<!-- TODO: Use the position pipeline code for selecting GPU -->

### Selecting a GPU

For decoding, we first install `cupy`. Installing it with conda ensures that a
matching CUDA toolkit is installed as well:

```bash
conda install cupy
```

Next, we'll select a single GPU for decoding by using `cp.cuda.Device(GPU_ID)` as a context manager (i.e., in a `with` block). Below, we'll select GPU #6 (ID = 5).

_Warning:_ Omitting the context manager will cause cupy to default to using GPU 0.

### Which GPU?

You can see which GPUs are occupied by running the command `nvidia-smi` in
a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage. 

In the output below, GPUs 1, 4, 6, and 7 have low memory use and power draw (~42 W), so they are probably not in use.

In [None]:
!nvidia-smi

Sat Jun  4 09:37:07 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:4F:00.0 Off |                    0 |
| N/A   30C    P0    42W / 300W |     38MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:52:00.0 Off |                    0 |
| N/A   32C    P0    43W / 300W |     38MiB / 81920MiB |      0%      Default |
|       

We can monitor GPU use with the terminal command `watch -n 0.1 nvidia-smi`, which
reruns `nvidia-smi` every 0.1 s (100 ms). This won't work in a notebook, because
the cell won't display the updates.

Other ways to monitor GPU usage are:

- A
  [Jupyter widget by NVIDIA](https://github.com/rapidsai/jupyterlab-nvdashboard)
  to monitor GPU usage in the notebook
- A [terminal program](https://github.com/peci1/nvidia-htop) like `nvidia-smi`
  that shows more information about which GPUs are being used and by whom.
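If you prefer to script the choice instead of reading the table by eye, the sketch below asks `nvidia-smi` for per-GPU memory use in CSV form (via its `--query-gpu` and `--format` options) and ranks GPUs from least to most memory used. The helper names `rank_gpus_by_memory` and `query_nvidia_smi` are hypothetical, and ranking by memory alone is an assumption: a GPU with little memory allocated may still be busy, and another user can claim it right after you check.

```python
import shutil
import subprocess

# Fields requested from nvidia-smi (memory values are reported in MiB)
QUERY = "index,memory.used,memory.total"


def rank_gpus_by_memory(csv_text):
    """Return GPU IDs sorted from least to most memory in use.

    `csv_text` is the output of
    `nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader,nounits`,
    i.e., one line per GPU such as "5, 38, 81920".
    """
    usage = []
    for line in csv_text.strip().splitlines():
        index, used, total = (field.strip() for field in line.split(","))
        usage.append((int(used) / int(total), int(index)))
    # Sort by fraction of memory used; ties go to the lower GPU ID
    return [gpu_id for _, gpu_id in sorted(usage)]


def query_nvidia_smi():
    """Run nvidia-smi and return its CSV output (requires an NVIDIA driver)."""
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout


# Only query the hardware when nvidia-smi is actually installed
if shutil.which("nvidia-smi"):
    ranked = rank_gpus_by_memory(query_nvidia_smi())
    print("Least-used GPU:", ranked[0])
    print("Three least-used GPUs:", ranked[:3])
```

The first ranked ID could serve as `GPU_ID` in the single-GPU example, and the first few as the `CUDA_VISIBLE_DEVICES` list for a multi-GPU cluster.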

## Imports


In [None]:
import os
import datajoint as dj

import cupy as cp
import numpy as np

import dask
import dask.distributed  # provides dask.distributed.Client, used below
import dask_cuda

import replay_trajectory_classification as rtc
from replay_trajectory_classification import (
    sorted_spikes_simulation as rtc_spike_sim,
    environment as rtc_env,
    continuous_state_transitions as rts,
)


# change to the upper level folder to detect dj_local_conf.json
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
dj.config.load("dj_local_conf.json")  # load config for database connection info

import logging

# Set up logging message formatting
logging.basicConfig(
    level="INFO", format="%(asctime)s %(message)s", datefmt="%d-%b-%y %H:%M:%S"
)

[2023-08-01 11:16:06,882][INFO]: Connecting root@localhost:3306
[2023-08-01 11:16:06,909][INFO]: Connected root@localhost:3306


## Simulated data


First, we'll simulate some data.

In [None]:
(
    time,
    position,
    sampling_frequency,
    spikes,
    place_fields,
) = rtc_spike_sim.make_simulated_run_data()

replay_time, test_spikes = rtc_spike_sim.make_continuous_replay()

## Set up classifier

In [None]:
movement_var = rts.estimate_movement_var(position, sampling_frequency)

environment = rtc_env.Environment(place_bin_size=np.sqrt(movement_var))

continuous_transition_types = [
    [
        rts.RandomWalk(movement_var=movement_var * 120),
        rts.Uniform(),
        rts.Identity(),
    ],
    [rts.Uniform(), rts.Uniform(), rts.Uniform()],
    [
        rts.RandomWalk(movement_var=movement_var * 120),
        rts.Uniform(),
        rts.Identity(),
    ],
]

classifier = rtc.SortedSpikesClassifier(
    environments=environment,
    continuous_transition_types=continuous_transition_types,
    # Specify the GPU-enabled algorithm for computing the likelihood
    sorted_spikes_algorithm="spiking_likelihood_kde_gpu",
    sorted_spikes_algorithm_params={"position_std": 3.0},
)
state_names = ["continuous", "fragmented", "stationary"]

We use a context manager to specify which GPU (device) fits the model and runs the prediction:


In [None]:
GPU_ID = 5  # Use GPU #6

with cp.cuda.Device(GPU_ID):
    # Fit the model place fields
    classifier.fit(position, spikes)

    # Run the model on the simulated replay
    results = classifier.predict(
        test_spikes,
        time=replay_time,
        state_names=state_names,
        use_gpu=True,  # Use GPU for computation of causal/acausal posterior
    )

## Multiple GPUs

Using multiple GPUs requires the `dask_cuda` package:

```bash
conda install -c rapidsai -c nvidia -c conda-forge dask-cuda
```

We will set up a client to select GPUs. By default, the cluster uses all available
GPUs. Below, we select a subset (GPUs 4, 5, and 6) with the `CUDA_VISIBLE_DEVICES` argument.
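Why does each dask worker land on a different GPU without a `with cp.cuda.Device(...)` block? As we understand it, `LocalCUDACluster` starts one worker process per listed GPU and gives each worker its own `CUDA_VISIBLE_DEVICES` value, rotated so that the worker's assigned GPU comes first. Inside a process, CUDA numbers the visible GPUs from 0 in the listed order, so each worker's default device 0 is a different physical GPU. The helper `rotated_visible_devices` below is a hypothetical illustration of that rotation, not `dask_cuda`'s own code.

```python
def rotated_visible_devices(worker_index, gpu_ids):
    """CUDA_VISIBLE_DEVICES string for one worker, with its own GPU listed first.

    Hypothetical helper: CUDA renumbers visible GPUs from 0 in the listed
    order, so the worker's default device (0) becomes gpu_ids[worker_index].
    """
    gpu_ids = list(gpu_ids)
    rotated = gpu_ids[worker_index:] + gpu_ids[:worker_index]
    return ",".join(str(gpu_id) for gpu_id in rotated)


# One worker per GPU, matching CUDA_VISIBLE_DEVICES=[4, 5, 6] in the cell below
for worker_index in range(3):
    print(worker_index, rotated_visible_devices(worker_index, [4, 5, 6]))
# 0 4,5,6
# 1 5,6,4
# 2 6,4,5
```

The same renumbering explains the earlier warning: in a process where every GPU is visible, device 0 is physical GPU 0.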

In [3]:
cluster = dask_cuda.LocalCUDACluster(CUDA_VISIBLE_DEVICES=[4, 5, 6])
client = dask.distributed.Client(cluster)

client

2022-05-18 13:50:10,288 - distributed.diskutils - INFO - Found stale lock file and directory '/stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-ly7bpyy1', purging
2022-05-18 13:50:10,296 - distributed.diskutils - INFO - Found stale lock file and directory '/stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-n6nteep3', purging
2022-05-18 13:50:10,302 - distributed.diskutils - INFO - Found stale lock file and directory '/stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-okcse855', purging
2022-05-18 13:50:10,305 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-05-18 13:50:10,313 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-05-18 13:50:10,319 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize


Connection method: Cluster object
Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status

Dashboard: http://127.0.0.1:8787/status
Workers: 3
Total threads: 3
Total memory: 3.94 TiB
Status: running
Using processes: True

Comm: tcp://127.0.0.1:37725
Workers: 3
Dashboard: http://127.0.0.1:8787/status
Total threads: 3
Started: Just now
Total memory: 3.94 TiB

Comm: tcp://127.0.0.1:35731
Total threads: 1
Dashboard: http://127.0.0.1:46067/status
Memory: 1.31 TiB
Nanny: tcp://127.0.0.1:34151
Local directory: /stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-n17o941r
GPU: NVIDIA A100 80GB PCIe
GPU memory: 80.00 GiB

Comm: tcp://127.0.0.1:39549
Total threads: 1
Dashboard: http://127.0.0.1:42725/status
Memory: 1.31 TiB
Nanny: tcp://127.0.0.1:41769
Local directory: /stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-ppy7ls9e
GPU: NVIDIA A100 80GB PCIe
GPU memory: 80.00 GiB

Comm: tcp://127.0.0.1:39335
Total threads: 1
Dashboard: http://127.0.0.1:34189/status
Memory: 1.31 TiB
Nanny: tcp://127.0.0.1:46373
Local directory: /stelmo/edeno/nwb_datajoint/notebooks/dask-worker-space/worker-m7thw3ne
GPU: NVIDIA A100 80GB PCIe
GPU memory: 80.00 GiB


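To use this client, we wrap the function we want to run on the GPUs with the
`dask.delayed` decorator. Calling the decorated function returns a lazy task
instead of a result, and `dask.compute` then runs all the tasks on the cluster.

In the example below, we run `test_gpu` on each of the 10 items of `data`. The three GPU workers share these tasks, so each item is processed on one of GPUs 4, 5, or 6.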

In [4]:
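def setup_logger(name_logfile, path_logfile):
    """Set up a logger that writes both to the console and to a file.

    Handlers are only added once per logger name, so calling this again
    (e.g., when re-running the cell) doesn't duplicate every message.
    """
    logger = logging.getLogger(name_logfile)
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        "%(asctime)s %(message)s", datefmt="%d-%b-%y %H:%M:%S"
    )
    file_handler = logging.FileHandler(path_logfile, mode="w")
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    # Don't also pass messages to the root logger, which would print them a
    # second time if it has its own handler (e.g., from logging.basicConfig)
    logger.propagate = False

    return logger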


# This uses the dask.delayed decorator on the test_gpu function
@dask.delayed
def test_gpu(x, ind):
    # Create a log file for this run of the function
    logger = setup_logger(
        name_logfile=f"test_{ind}", path_logfile=f"test_{ind}.log"
    )

    # Test to see if these go into different log files
    logger.info(f"This is a test of {ind}")
    logger.info("This should be in a unique file")

    # Run a GPU computation
    return cp.asnumpy(cp.mean(x[:, None] @ x[:, None].T, axis=0))


# Make up 10 fake datasets
x = cp.random.normal(size=10_000, dtype=cp.float32)
data = [x + i for i in range(10)]

# Build a list of lazy (delayed) tasks; nothing is computed yet
results = [test_gpu(x, ind) for ind, x in enumerate(data)]

# `dask.compute` runs all the tasks on the cluster's GPU workers
dask.compute(*results)

18-May-22 13:50:12 This is a test of 4
18-May-22 13:50:12 This should be in a unique file
18-May-22 13:50:12 This is a test of 3
18-May-22 13:50:12 This should be in a unique file
18-May-22 13:50:12 This is a test of 1
18-May-22 13:50:12 This should be in a unique file
18-May-22 13:50:13 This is a test of 9
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 6
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 5
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 0
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 2
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 7
18-May-22 13:50:13 This should be in a unique file
18-May-22 13:50:13 This is a test of 8
18-May-22 13:50:13 This should be in a unique file


(array([ 0.00106875, -0.0032616 , -0.00759213, ..., -0.00612375,
         0.0059738 , -0.01329288], dtype=float32),
 array([ 1.1191896 ,  0.6762629 ,  0.23331782, ...,  0.38350907,
         1.6208985 , -0.34978032], dtype=float32),
 array([4.23731  , 3.3557875, 2.4742277, ..., 2.7731416, 5.235823 ,
        1.3137323], dtype=float32),
 array([ 9.3554325,  8.035313 ,  6.7151375, ...,  7.1627736, 10.850748 ,
         4.9772453], dtype=float32),
 array([16.47355 , 14.714837, 12.956048, ..., 13.552408, 18.465672,
        10.640757], dtype=float32),
 array([25.591675, 23.39436 , 21.196953, ..., 21.942041, 28.0806  ,
        18.304268], dtype=float32),
 array([36.7098  , 34.073883, 31.437868, ..., 32.331676, 39.695522,
        27.967781], dtype=float32),
 array([49.827915, 46.753407, 43.67878 , ..., 44.721313, 53.31045 ,
        39.631294], dtype=float32),
 array([64.946045, 61.432938, 57.9197  , ..., 59.11094 , 68.92538 ,
        53.2948  ], dtype=float32),
 array([82.064156, 78.11245 , 74.1

This example also shows how to use the `setup_logger` function to write a separate log file (`test_{ind}.log`) for each item in `data`.