# Quantification of animal behavior

- practice notebook for first steps in visualizing tracked centroid positions
- original notebook by Lukas Breitzler 2024
- updated 2026 by Johannes Larsch

### In this exercise we will be using numpy arrays (a type of data structure in python) to visualize the movement of a fish.

- we will be using tracked data from the fighting zebrafish videos that you hand annotated in the first session
- we have tracked these videos for you with idTracker.ai, a high-performance animal tracker.

### Select the cell below and press CTRL + Enter to import the libraries that we will need for this exercise

In [None]:
import numpy as np                  # This imports the library numpy, which is used for numerical operations and mathematical computations. It allows you to handle multi-dimensional arrays and matrices in python and apply many mathematical functions to them.
import matplotlib.pyplot as plt      # This line imports Matplotlib, which is a widely used library for creating visualizations in python.
import pandas as pd                    # This line imports the pandas library, which is used for data manipulation and analysis. It provides data structures and functions needed to manipulate structured data.
from scipy.signal import find_peaks  # This line imports the find_peaks function from the scipy.signal module, which is used to identify peaks in a signal or data set. It can be used to find local maxima in a 1D array, which is useful for analyzing data such as time series or signals.

<br>

### 1: let's load some fish trajectories using pandas. 
- Select the cell below and press CTRL + Enter to import the data we will be analyzing and save it as an object named 'fish_coordinates'.
- You can choose which of the trajectories you want to load: gt1, gt2, or gt3 by changing the file name.
- We wrote the code for you for this one.

In [None]:
fish_coordinates = pd.read_csv('gt1_trajectories.csv') # This line loads the fish coordinates from a CSV file named 'gt1_trajectories.csv' using the pandas function read_csv. It reads the CSV file and creates a DataFrame object that contains the data from the file, which can be easily manipulated and analyzed using pandas functions.


<br>

### 2: Take a look at the data.
- In the cell below, type in 'fish_coordinates' and execute the cell. 
- Do you get an output? What do you think this array and the numbers in it represent?

##### Enter your code here:

##### Enter your answer here:

<br>

### 3: How much data do you have?
- You might have noticed that you only see a few lines of your fish_coordinates dataframe.
- That means the dataframe is too big to be displayed in full, so pandas only shows the first and last few rows.
- Try adding the command '.shape' after fish_coordinates to see the dimensions of your dataframe.
- How many rows and how many columns does it have?


##### Code:

##### Answer:

### 4: Selecting data of one fish
- You can extract parts of your dataframe by its ['column name']
- e.g. fish_coordinates['x1'] returns all values stored in the column named 'x1'.
- alternatively, you can index individual elements via .iloc.
- fish_coordinates.iloc[0,0] returns the value stored in the first row/first column. 
- fish_coordinates.iloc[0,:] returns all values stored in the first row. 
- As you can see, indexing starts at 0.

- In the two cells below, write lines of code that return all values in the first and second columns of your fish_coordinates dataframe. (Bonus: try adding .shape after the new objects and check whether / how the dimensions changed.)
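To see the difference between selecting by column name and selecting by position, here is a minimal sketch on a made-up table. The dataframe `toy`, its values, and its column names are assumptions for illustration; your real file may use different column names.

```python
import pandas as pd

# Toy stand-in for the tracking table (values and column names are made up)
toy = pd.DataFrame({"x1": [10.0, 11.0, 12.5], "y1": [5.0, 5.5, 7.0]})

column_by_name = toy["x1"]     # the whole column called 'x1'
single_value = toy.iloc[0, 0]  # one element: first row, first column
first_row = toy.iloc[0, :]     # all columns of the first row

print(toy.shape)             # (3, 2): 3 rows, 2 columns
print(column_by_name.shape)  # (3,): a single column has only one dimension
print(single_value)          # 10.0
```

The same patterns should carry over to fish_coordinates once you know its column names.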

##### Code to return all values in first column:

##### Code to return all values in second column:

<br>

### 5: Plotting trajectories

- You can create various plots using the matplotlib library. 
- E.g. plt.plot(x, y) creates a line plot of whatever x and y values you enter. 
- Generate a scatter plot that displays all of the values stored in your fish_coordinates array.
- Use the command plt.scatter(x, y) for this and replace the x and y values with those stored in your fish_coordinates array.

What do you see? (Tip: if you add 's = 0.01' in the plt.scatter brackets after your values, things might become clearer.)

##### Code:

<br>

b) Now let's visualize time in this scatter plot. 
- Add the parameter c = np.arange(0, fish_coordinates.shape[0]) in plt.scatter and look at the output.
- add a colorbar using plt.colorbar() - what does it show?

<br>

c) Lastly, let's visualize the same data, but as a heatmap. 
- You can do this using plt.hist2d(). Insert your x and y coordinates and add the parameter 'bins = 50' inside the function to increase histogram resolution.
- function calls are more organized if you create clearly labeled variables, e.g. x1=fish_coordinates['x1'].
- depending on your data set, you may get an error with hist2d, which is related to NaN (not a number) - missing data for a few frames.
- you can remove those with a command like x1_clean=x1[~np.isnan(x1)]. np.isnan(x1) marks the NaN entries and ~ inverts that mask, so only valid coordinates are kept. Apply the same mask to both x and y so that the two stay the same length.
- you can check how many values are affected like so: np.isnan(x1).sum()
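A minimal sketch of the NaN filtering on made-up numbers. The arrays below are invented for illustration, and the variable `valid` is a hypothetical helper name. It keeps a frame only if both coordinates are present, which is one reasonable way to keep x and y aligned.

```python
import numpy as np

# Toy coordinates with a few missing frames (values are made up)
x1 = np.array([1.0, 2.0, np.nan, 4.0, 5.0])
y1 = np.array([1.5, np.nan, 2.5, 3.5, 4.5])

# A frame is usable only if both x and y are not NaN
valid = ~np.isnan(x1) & ~np.isnan(y1)
x1_clean = x1[valid]
y1_clean = y1[valid]

print(np.isnan(x1).sum())  # 1 missing x value
print(valid.sum())         # 3 usable frames
print(x1_clean, y1_clean)  # [1. 4. 5.] [1.5 3.5 4.5]
```

The cleaned arrays can then be passed on, e.g. plt.hist2d(x1_clean, y1_clean, bins = 50).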

<br>

<br>

# You got to this point? Congratulations! 
- Please upload your time-trajectory and the heatmap to Moodle.
- The following steps are slightly more advanced but still conceptually very doable.
- Try to get as far as you can and post your plots to Moodle.
- Ask us, your neighbor, Google or ChatGPT for help and enjoy!
 

# Optional steps below:

### 6: How to calculate speed?

- Let's say we want to learn more about the behavior of this fish.
- For this we want to look at its movement more closely and determine how the fish moves over time.
- To determine speed we need to know two variables: distance and time.
- First, think about how you can determine distance mathematically using the data from your fish_coordinates file.
- Then think, in high-level terms, about how you want to calculate speed from it.


<br>

### 7: Implement a speed formula

- Using whatever mathematical formula you decided on, enter some code below which calculates the distance between the first and second time point of the fish_coordinates array.
- it may help to save the x and y coordinates of both time points as objects and perform your calculation on them
- you can also write a function if you want.
- you can think about this via the Pythagorean theorem, or skip ahead and use a ready-made function for the norm of a vector. There are many good solutions to this.

**---------------------------------------------------------------------------------------------------------------------**
#### Tip: Python and NumPy support various arithmetic operations on arrays. Here are some examples:

**Addition:**   
array1 = np.array([1, 2, 3])    
array2 = np.array([4, 5, 6])   
array1 + array2  
Output: [5 7 9]

**Subtraction:**    
array2 - array1    
Output: [3 3 3]

**Multiplication:**    
array1 * array2   
Output: [ 4 10 18]

**Power:**  
2 ** 3    
Output: 8

**Square root:**    
np.sqrt(9)    
Output: 3.0    

**---------------------------------------------------------------------------------------------------------------------**
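As a sanity check for whatever formula you pick, here is a small sketch with two made-up positions whose distance is easy to verify by hand (a 3-4-5 triangle). The names `xa`, `ya`, `xb`, `yb` are hypothetical and the numbers are not from the fish data.

```python
import numpy as np

# Two made-up positions in pixels
xa, ya = 0.0, 0.0
xb, yb = 3.0, 4.0

# Option 1: Pythagorean theorem written out
dist_pythagoras = np.sqrt((xb - xa) ** 2 + (yb - ya) ** 2)

# Option 2: length (norm) of the displacement vector
dist_norm = np.linalg.norm([xb - xa, yb - ya])

print(dist_pythagoras, dist_norm)  # 5.0 5.0
```

Both options give the same result, so use whichever reads more clearly to you.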


- Save the first and second time points as objects:
- ( this is just to help you conceptualize how you want to calculate this for the entire time series)

In [None]:
x1 = 
y1 = 
x2 = 
y2 = 

Code to calculate the distance between the first and second time point:

<br>

### 8: Write code to determine the distances between each of the time points for the entire fish_coordinates array. 

- You might need several lines of code for this.
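- Save the output in an object called "speed". With a constant frame rate, the distance travelled between consecutive frames is the speed in pixels per frame.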

**---------------------------------------------------------------------------------------------------------------------**
#### Tip: NumPy supports various vector operations on arrays. Here are some examples:

**Calculate all differences along a given array:**   
np.diff([1, 2, 3, 5])    
Output: [1 1 2]

**Calculate the length or magnitude of a vector:**    
np.linalg.norm([1,1])    
Output: 1.4142135623730951

**---------------------------------------------------------------------------------------------------------------------**
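The two tips can be combined to get all frame-to-frame distances at once. This is only a sketch on a made-up four-point track, assuming x and y are stored as separate one-dimensional arrays. It is one possible approach, not necessarily the intended solution.

```python
import numpy as np

# Made-up trajectory: 4 time points
x = np.array([0.0, 3.0, 3.0, 6.0])
y = np.array([0.0, 4.0, 4.0, 8.0])

dx = np.diff(x)  # change in x between consecutive frames
dy = np.diff(y)  # change in y between consecutive frames

# Distance per frame via the Pythagorean theorem
speed = np.sqrt(dx ** 2 + dy ** 2)

# Same result with the vector norm, computed row by row (axis=1)
speed_norm = np.linalg.norm(np.column_stack([dx, dy]), axis=1)

print(speed)                 # [5. 0. 5.]
print(len(x), len(speed))    # 4 3
```

Note that the result is one entry shorter than the trajectory, because there are only N-1 steps between N time points.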




Code to calculate the distance between all time points:

<br>

### 9: Plot the speed time series
- you can use plt.plot() for this.
- Tip: Zoom in a little using plt.xlim(), or numpy '[ ]' subsetting commands, to get a better idea of what's going on.
- You should be able to see the speed of the fish and identify individual swim bouts.

Code:

<br>

### 10: Speed-annotated trajectory
- Use the speed values from the previous question to color code the coordinate scatter plot you generated in question 5.
- For this, use the 'c' parameter in plt.scatter() and assign your speed values to it.
- The speed trace and the coordinates have to be the same length, but the speed trace has one entry fewer than the trajectory.
- You can prepend a 0 using np.insert(speed, 0, 0), or skip the first entry of the trajectory.   
- Tip: further parameters to add: cmap = 'Reds', s=5, or make the size also dependent on speed!
- Can you see isolated 'high-speed events'? What might those correspond to?
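The two ways of matching the lengths look like this on made-up numbers. The arrays are toy values for illustration, not fish data.

```python
import numpy as np

# Made-up x coordinates (4 frames) and the 3 speeds between them
x = np.array([0.0, 3.0, 3.0, 6.0])
speed = np.array([5.0, 0.0, 5.0])

# Option 1: prepend a 0 so speed has one value per frame
speed_padded = np.insert(speed, 0, 0)

# Option 2: drop the first frame of the trajectory instead
x_trimmed = x[1:]

print(speed_padded)                    # [0. 5. 0. 5.]
print(speed_padded.shape == x.shape)   # True
print(x_trimmed.shape == speed.shape)  # True
```

Whichever option you choose, apply it consistently to both x and y before calling plt.scatter.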

Code:

<br>

# Still hanging on? Congratulations!

- below, questions will get a bit harder


### 11: Let's segment the movement of this fish further. 
- a) Use the find_peaks() function from scipy on the array you generated in Question 8 to identify individual bouts. 
- Read the documentation here https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.find_peaks.html and select appropriate parameters, e.g. distance and prominence.      
- b) Check how many bouts / peaks you detected     
- c) Plot the detected peaks on top of the speed traces. Similar to your plot from Question 9. Use plt.scatter in combination with plt.plot for this in two separate lines of code. 

**Tip for part c**: You can subset numpy arrays by specifying indices using another array. For example:      
indices = np.array([0, 2, 5, 9])      
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])     
data[indices]    
Output: [0 2 5 9]
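To get a feel for find_peaks before running it on real data, here is a sketch on a short made-up speed trace with two clear bouts and one small wiggle. The parameter values (prominence=1, distance=3) are arbitrary choices for this toy signal. You will likely need different values for the fish data.

```python
import numpy as np
from scipy.signal import find_peaks

# Made-up speed trace: two bouts (at index 3 and 11) and a small wiggle (index 7)
speed = np.array([0, 0, 1, 4, 1, 0, 0, 0.5, 0, 0, 2, 5, 2, 0, 0], dtype=float)

# prominence filters out the small wiggle, distance sets the minimum spacing of peaks
peaks, properties = find_peaks(speed, prominence=1, distance=3)

print(peaks)         # [ 3 11]  -> indices of the detected peaks
print(len(peaks))    # 2        -> number of detected bouts
print(speed[peaks])  # [4. 5.]  -> speed values at the peaks
```

For the overlay in part c, `peaks` gives the x positions and `speed[peaks]` the y positions of the markers.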

Code:

<br>

<br>

### 12: Plot an average of all bouts - an event triggered average!
Use the peak indices from the previous question to generate an average bout plot. For this, plot all speed traces centered on each bout peak and include all speed values from 15 frames before until 15 frames after the peak.
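One possible way to cut out the windows is sketched below on made-up data. The trace, the peak indices, and the helper names `window`, `valid_peaks`, and `segments` are assumptions for illustration. This is not the course's reference solution. Peaks too close to the start or end of the recording are skipped here, because they do not have a full window.

```python
import numpy as np

# Made-up speed trace and peak indices (stand-ins for your real results)
speed = np.arange(200, dtype=float)
peaks = np.array([5, 50, 120, 195])
window = 15  # frames before and after each peak

# Keep only peaks that have a complete window on both sides
valid_peaks = peaks[(peaks >= window) & (peaks + window < len(speed))]

# One row per bout, each row has 2 * window + 1 samples centered on the peak
segments = np.array([speed[p - window : p + window + 1] for p in valid_peaks])

# Average across bouts (rows) to get the event triggered average
average_bout = segments.mean(axis=0)

print(segments.shape)        # (2, 31)
print(average_bout.shape)    # (31,)
print(average_bout[window])  # 85.0, the mean of the values at the two kept peaks
```

You could then plot each row of `segments` in a light color and `average_bout` on top with plt.plot.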


Code to plot an average of all bouts:

In [None]:
# %load ./snippets/solution07.py
#uncomment above to load a solution


<br>

<br>

### 13: Bonus task:
You can try to load your manual fight annotations into this notebook 
- overlay your annotations with the trajectories
- overlay your annotations with the speed trace
- do you see any relationships emerging?

Code to plot annotation overlays: