# BIPN 145 Fly Tracker

This notebook tracks a fruit fly against a light background in a video, calculates its position and velocity over time, and produces path and velocity plots.

**Original MATLAB code** by Jeff Stafford, modified by A. Juavinett for BIPN 145.

## How to use
1. Run the **Setup** cell to install dependencies.
2. **Upload your video** (`.avi`, `.mp4`, etc.) using the upload cell.
3. **Set your parameters** (dish diameter, frame rate).
4. **Draw your ROI** on the first frame of the first video. The same ROI is applied to every uploaded video.
5. Run the remaining cells to track the fly and view results.

## Setup

In [None]:
!pip install ipympl -q

import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.spatial.distance import euclidean
from google.colab import files, output
from IPython.display import display, clear_output
import ipywidgets as widgets
import os

# Let Colab render custom ipywidgets / ipympl outputs. This is presumably
# required for the `%matplotlib widget` ROI selector and its Confirm button
# further down the notebook.
output.enable_custom_widget_manager()

print('All packages loaded successfully!')

## Upload Video(s)

Upload one or more fly videos. Supported formats: `.avi`, `.mp4`, `.mov`, `.mkv`

In [None]:
# Open Colab's file picker; the result is a dict keyed by uploaded filename.
uploaded = files.upload()
# Sort the names so videos are processed and labelled in a stable,
# alphabetical order.
video_files = sorted(uploaded)
print(f'{len(video_files)} file(s) uploaded: {video_files}')

## Set Parameters

In [None]:
#@title Parameters { run: "auto" }

# Real-world diameter of the dish in centimeters (the ROI is drawn around it)
diameter = 4  #@param {type:"number"}

# Frames per second; used to convert frame numbers into seconds
frame_rate = 30  #@param {type:"integer"}

# Side length (pixels) of the square window used to centre on the fly - generally leave at 20
search_size = 20  #@param {type:"integer"}

# Multiplied by search_size**2 to get the minimum total darkness a detection needs - generally leave at 1.5
per_pixel_threshold = 1.5  #@param {type:"number"}

# Width of each velocity bin in seconds
bin_size = 1  #@param {type:"number"}

# The ROI spans the dish edge to edge, so both real-world extents equal its diameter.
height = width = diameter

print('\n'.join([
    f'Dish diameter: {diameter} cm',
    f'Frame rate: {frame_rate} fps',
    f'Search size: {search_size} px',
    f'Velocity bin size: {bin_size} s',
]))

## Helper Functions

These replicate the MATLAB helper functions (`flyFinder`, `distFilter`, `interpolatePos`).

In [None]:
def fly_finder(roi_image, half_search, threshold, flip=True):
    """
    Locate the fly in a single grayscale ROI image.

    Search in two steps:
    1. Seed point. Take the mean position of every pixel that shares the
       image's extreme intensity: the minimum when `flip` is True (dark fly
       on a light background), otherwise the maximum.
    2. Centroid. Reduce a square window of `half_search` pixels on each side
       of the seed, clipped to the image, to its intensity-weighted centroid.

    With `flip`, the weights are 255 - pixel, so 8-bit input is assumed.

    Args:
        roi_image: 2-D grayscale image.
        half_search: half-width of the search window in pixels.
        threshold: minimum summed window weight required to accept a detection.
        flip: True to look for a dark object, False for a bright one.

    Returns:
        (x, y) in the pixel coordinates of `roi_image` (column, row), or
        (NaN, NaN) when the window's total weight is below `threshold`.
    """
    extreme = np.nanmin(roi_image) if flip else np.nanmax(roi_image)
    hit_rows, hit_cols = np.nonzero(roi_image == extreme)
    seed_x = round(np.mean(hit_cols))
    seed_y = round(np.mean(hit_rows))

    n_rows, n_cols = roi_image.shape
    left = max(int(seed_x - half_search), 0)
    right = min(int(seed_x + half_search), n_cols - 1)
    top = max(int(seed_y - half_search), 0)
    bottom = min(int(seed_y + half_search), n_rows - 1)

    window = roi_image[top:bottom + 1, left:right + 1].astype(np.float64)
    if flip:
        # Invert so the dark fly carries the large weights.
        window = 255.0 - window

    weight = np.sum(window)
    # Written as `not >=` so that a NaN total is rejected as well.
    if not weight >= threshold:
        return np.nan, np.nan

    # First moments along each axis, shifted back into full-image coordinates.
    col_moment = np.sum(window @ np.arange(window.shape[1]))
    row_moment = np.sum(window.T @ np.arange(window.shape[0]))
    return col_moment / weight + left, row_moment / weight + top


def dist_filter(array, tele_dist_threshold, num_avg=5):
    """
    Teleport filter: remove spurious points where the fly position jumps
    far from the mean of the surrounding frames.

    Args:
        array: Nx3 array [time, x, y]; NaN in x/y marks an untracked frame.
        tele_dist_threshold: maximum allowed distance (in x/y units). It is
            checked against two means: the `num_avg` frames before a point
            and the `num_avg` frames after it.
        num_avg: number of frames averaged on each side. It also sets how
            many frames at the start and end of the recording get the
            stricter frame-to-frame check (half the threshold).

    Returns:
        A filtered copy of `array` with rejected x/y values set to NaN.
        The input is not modified.
    """
    filtered = array.copy()
    tele_count = 0
    n = len(filtered)

    # Interior frames: compare against the neighbourhood means. Points
    # rejected earlier in this loop are already NaN and drop out of later means.
    for i in range(num_avg, n - num_avg):
        point = filtered[i, 1:3]
        if np.any(np.isnan(point)):
            continue

        last_set = filtered[i - num_avg:i, 1:3]
        last_set = last_set[~np.isnan(last_set[:, 0])]
        if len(last_set) == 0:
            continue
        last_mean = np.mean(last_set, axis=0)

        next_set = filtered[i + 1:i + 1 + num_avg, 1:3]
        next_set = next_set[~np.isnan(next_set[:, 0])]
        if len(next_set) == 0:
            continue
        next_mean = np.mean(next_set, axis=0)

        if (euclidean(point, last_mean) > tele_dist_threshold or
                euclidean(point, next_mean) > tele_dist_threshold):
            filtered[i, 1:3] = np.nan
            tele_count += 1

    # Edge frames lack a full neighbourhood on one side, so each one is
    # compared with the following frame using half the threshold. This covers
    # the first `num_avg` frames and the `num_avg` frames before the last one
    # (previously hard-coded to 5, which ignored `num_avg`).
    # NOTE(review): the earlier frame of each pair is the one blanked, so the
    # final frame itself is never removed by this check. Kept as-is to
    # preserve existing results.
    edge_indices = (list(range(0, min(num_avg, n - 1))) +
                    list(range(max(0, n - num_avg - 1), n - 1)))
    for idx in edge_indices:
        here = filtered[idx, 1:3]
        after = filtered[idx + 1, 1:3]
        if np.any(np.isnan(here)) or np.any(np.isnan(after)):
            continue
        if euclidean(here, after) > tele_dist_threshold / 2:
            filtered[idx, 1:3] = np.nan
            tele_count += 1

    print(f'{tele_count} points removed by the teleportation filter.')
    return filtered


def interpolate_pos(array, inter_dist_threshold):
    """
    Fill NaN gaps in the fly track by linear interpolation.

    Each run of untracked frames (NaN in the x column) that has a tracked
    frame on both sides is filled with evenly spaced points between those two
    frames. The fill only happens when the bracketing positions are at most
    `inter_dist_threshold` apart. Leading and trailing NaN runs are left alone.

    Args:
        array: Nx3 array [time, x, y].
        inter_dist_threshold: maximum distance (x/y units) between the gap's
            endpoints for the gap to be filled.

    Returns:
        An interpolated copy of `array`; the input is not modified.
    """
    result = array.copy()
    interp_count = 0

    # Frames whose x coordinate is known; consecutive entries bracket a gap
    # whenever they are not adjacent.
    known = np.flatnonzero(~np.isnan(result[:, 1]))
    for start, stop in zip(known[:-1], known[1:]):
        gap = stop - start - 1
        if gap == 0:
            continue
        before = result[start, 1:3]
        after = result[stop, 1:3]
        if euclidean(before, after) <= inter_dist_threshold:
            for j in range(1, gap + 1):
                result[start + j, 1:3] = before + (after - before) * (j / (gap + 1))
            interp_count += gap

    print(f'{interp_count} points recovered through interpolation.')
    return result


# Confirms that this cell ran and the tracking helpers above are now defined.
print('Helper functions defined.')

## Select ROI (Region of Interest)

Run this cell to display the first frame. **Click and drag** to draw a rectangle around the dish. You can redraw as many times as you like — only the last rectangle counts. When you're happy with the selection, click the **Confirm ROI** button.

In [None]:
%matplotlib widget

def select_roi_interactive(video_path):
    """
    Interactive ROI selector for the first frame of `video_path`.

    Click and drag on the image to draw a rectangle (drawing again replaces
    it), then press 'Confirm ROI'. Confirming copies the rectangle into the
    module-level `roi_result['value']`, which must exist before the button
    is clicked.

    Returns:
        (state, roi_output):
        - `state['roi']` is the last drawn rectangle as integer pixels
          (x, y, width, height). It is None if nothing usable has been drawn
          (e.g. a click without a drag).
        - `roi_output` is the widget the Confirm button prints into.

    Raises:
        ValueError: if the first frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise ValueError(f'Could not read first frame of {video_path}')

    # OpenCV decodes frames as BGR; matplotlib expects RGB.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(frame_rgb)
    ax.set_title('Click and drag to draw ROI, then click "Confirm ROI"')
    ax.set_xlabel('X (pixels)')
    ax.set_ylabel('Y (pixels)')

    # Drag state:
    # - x0, y0: corner where the drag started
    # - x1, y1: last cursor position inside the axes
    # - rect: red preview patch
    # - roi: resulting rectangle
    state = {'drawing': False, 'x0': 0, 'y0': 0, 'x1': None, 'y1': None,
             'rect': None, 'roi': None}

    def on_press(event):
        if event.inaxes != ax:
            return
        state['drawing'] = True
        state['x0'] = state['x1'] = event.xdata
        state['y0'] = state['y1'] = event.ydata
        # Remove old rectangle if redrawing
        if state['rect'] is not None:
            state['rect'].remove()
            state['rect'] = None
        fig.canvas.draw_idle()

    def on_motion(event):
        if not state['drawing'] or event.inaxes != ax:
            return
        state['x1'], state['y1'] = event.xdata, event.ydata
        if state['rect'] is not None:
            state['rect'].remove()
        state['rect'] = plt.Rectangle(
            (state['x0'], state['y0']),
            state['x1'] - state['x0'], state['y1'] - state['y0'],
            linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(state['rect'])
        fig.canvas.draw_idle()

    def on_release(event):
        # Always end the drag, even when the button is released outside the
        # image. Otherwise later mouse moves would keep resizing the box.
        if not state['drawing']:
            return
        state['drawing'] = False
        if event.inaxes == ax:
            x1, y1 = event.xdata, event.ydata
        else:
            # Released outside the image: use the last in-axes position.
            x1, y1 = state['x1'], state['y1']
        x0, y0 = state['x0'], state['y0']
        # Normalize so x,y is top-left corner
        rx = int(min(x0, x1))
        ry = int(min(y0, y1))
        rw = int(abs(x1 - x0))
        rh = int(abs(y1 - y0))
        # A click without a drag gives an empty box. That would later make the
        # pixel-to-cm scale divide by zero, so treat it as "no ROI drawn".
        state['roi'] = (rx, ry, rw, rh) if rw >= 1 and rh >= 1 else None

    fig.canvas.mpl_connect('button_press_event', on_press)
    fig.canvas.mpl_connect('motion_notify_event', on_motion)
    fig.canvas.mpl_connect('button_release_event', on_release)

    plt.tight_layout()
    plt.show()

    # Confirm button
    confirm_btn = widgets.Button(description='Confirm ROI', button_style='success')
    roi_output = widgets.Output()
    display(confirm_btn, roi_output)

    def on_confirm(b):
        with roi_output:
            if state['roi'] is None:
                print('No ROI drawn yet! Please click and drag on the image first.')
            else:
                roi_result['value'] = state['roi']
                print(f"ROI confirmed: x={state['roi'][0]}, y={state['roi'][1]}, "
                      f"w={state['roi'][2]}, h={state['roi'][3]}")

    confirm_btn.on_click(on_confirm)

    # The caller keeps `state` as a fallback in case Confirm is never clicked.
    return state, roi_output

# Filled in by the "Confirm ROI" button callback; it must exist before the
# button is clicked.
roi_result = dict(value=None)

# Show the first frame of the first uploaded video. The ROI drawn here is
# applied to every video during tracking.
roi_state, roi_out = select_roi_interactive(video_files[0])
print('\nDraw a rectangle around the dish, then click "Confirm ROI".')

In [None]:
# Run this cell AFTER you have confirmed your ROI above
%matplotlib inline

roi = roi_result['value']
if roi is None:
    # Fallback: use the last rectangle drawn even if "Confirm ROI" wasn't clicked.
    roi = roi_state['roi']

if roi is None:
    raise ValueError("No ROI selected! Go back and draw a rectangle on the image, then click 'Confirm ROI'.")

# A zero-width or zero-height ROI (a click without a drag) would make the
# pixel-to-cm scale divide by zero during tracking, so reject it here.
if roi[2] < 1 or roi[3] < 1:
    raise ValueError("ROI has zero width or height! Go back and click-and-drag a larger rectangle, then click 'Confirm ROI'.")

print(f'Using ROI: x={roi[0]}, y={roi[1]}, width={roi[2]}, height={roi[3]}')

## Track Fly in Video(s)

This processes each video frame-by-frame:
1. Creates a background image by averaging up to 100 randomly chosen frames
2. Subtracts the background from each frame
3. Finds the fly position using center-of-mass of dark pixels
4. Applies teleportation filter and interpolation
5. Converts pixel positions to centimeters

In [None]:
def _compute_background(cap, n_frames, n_samples=100):
    """
    Return the per-pixel mean grayscale image of up to `n_samples` frames.

    The frames are chosen at random from indices 0..n_frames-1 of `cap`.
    Frames that fail to decode are skipped rather than averaged in as black.
    Raises ValueError if none of the sampled frames can be read.
    """
    n_samples = min(n_samples, n_frames)
    indices = sorted(np.random.choice(n_frames, n_samples, replace=False))

    accum = None
    n_read = 0
    for frame_num in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
        if not ret:
            continue
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float64)
        if accum is None:
            accum = gray
        else:
            accum += gray
        n_read += 1

    if n_read == 0:
        raise ValueError('Could not read any frames to build the background image.')
    return (accum / n_read).astype(np.uint8)


def _track_positions(cap, n_frames, background, roi, half_search, threshold):
    """
    Read frames 0..n_frames-1 sequentially and locate the fly in each one.

    Returns an (n_frames, 3) array [frame_index, x_px, y_px]. The x/y values
    are in pixels relative to the ROI's top-left corner, and are NaN where a
    frame could not be read or no detection passed `threshold`.
    """
    roi_x, roi_y, roi_w, roi_h = roi
    rows = slice(roi_y, roi_y + roi_h)
    cols = slice(roi_x, roi_x + roi_w)

    # The division below is element-wise, so crop first and hoist the
    # constant background term out of the frame loop.
    bg_denominator = background[rows, cols].astype(np.float64) + 1
    if bg_denominator.size == 0:
        raise ValueError(f'ROI {roi} lies outside the {background.shape[1]}x'
                         f'{background.shape[0]} video frame.')

    pos_array = np.full((n_frames, 3), np.nan)
    pos_array[:, 0] = np.arange(n_frames)
    report_every = max(1, n_frames // 10)

    # Read frames in order from the start. Seeking before every frame can
    # force the decoder to restart from a keyframe (codec-dependent), which
    # is very slow.
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    for nofr in range(n_frames):
        ret, frame = cap.read()
        if ret:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[rows, cols].astype(np.float64)
            # Background subtraction using GIMP division formula
            frame_div = np.clip(256.0 * gray / bg_denominator, 0, 255).astype(np.uint8)
            pos_array[nofr, 1:] = fly_finder(frame_div, half_search, threshold, flip=True)

        # Progress update every 10%
        if (nofr + 1) % report_every == 0:
            pct = (nofr + 1) / n_frames * 100
            print(f'    {pct:.0f}% complete ({nofr + 1}/{n_frames} frames)')

    return pos_array


def process_video(video_path, roi, diameter, frame_rate, search_size, per_pixel_threshold):
    """
    Process a single fly video and return the corrected position array.

    Args:
        video_path: path of a video file readable by OpenCV.
        roi: (x, y, width, height) of the dish in pixels; width and height
            must be at least 1.
        diameter: real dish diameter in cm. The ROI is scaled to
            diameter x diameter.
        frame_rate: frames per second used to convert frame index to seconds.
        search_size: side length (px) of the window used by `fly_finder`.
        per_pixel_threshold: scaled by search_size**2 to get the detection
            threshold.

    Returns:
        corrected_array: Nx3 array [time_s, x_cm, y_cm], with one row per
        frame except the last reported one. The video's own FPS is printed
        (and compared with `frame_rate`) but not returned.

    Raises:
        ValueError: in any of these cases:
        - the ROI is empty;
        - the video cannot be opened or read;
        - the video has fewer than two frames.
    """
    height = diameter
    width = diameter
    roi_x, roi_y, roi_w, roi_h = roi
    if roi_w < 1 or roi_h < 1:
        raise ValueError(f'ROI {roi} has zero width or height; redraw it.')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f'Could not open video {video_path}')
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # The last reported frame is not processed (kept as in the original code).
        nfrm = total_frames - 1
        if nfrm < 1:
            raise ValueError(f'{video_path} reports {total_frames} frame(s); at least 2 are needed.')

        print(f'\nProcessing: {os.path.basename(video_path)}')
        print(f'  Video FPS: {fps}, Total frames: {total_frames}')
        print(f'  Using frame_rate parameter: {frame_rate} for time conversion')
        if fps > 0 and abs(fps - frame_rate) > 0.5:
            print(f'  NOTE: the video reports {fps:.2f} FPS but frame_rate is {frame_rate}; '
                  'times are computed from frame_rate.')

        print('  Calculating background...')
        background = _compute_background(cap, nfrm)

        print('  Tracking fly positions...')
        threshold = (search_size ** 2) * per_pixel_threshold
        half_search = round(search_size / 2)
        pos_array = _track_positions(cap, nfrm, background,
                                     (roi_x, roi_y, roi_w, roi_h),
                                     half_search, threshold)
    finally:
        cap.release()

    # --- Convert to real coordinates ---
    xscale = width / roi_w
    yscale = height / roi_h

    corrected_array = np.column_stack([
        pos_array[:, 0] / frame_rate,  # time in seconds
        pos_array[:, 1] * xscale,       # x in cm
        pos_array[:, 2] * yscale        # y in cm
    ])

    skipped = np.sum(np.isnan(corrected_array[:, 1]))
    print(f'  {skipped} points skipped out of {nfrm}.')

    # Apply teleport filter and interpolation (thresholds in cm)
    corrected_array = dist_filter(corrected_array, 2)
    corrected_array = interpolate_pos(corrected_array, 2)

    # Manual fix for 15fps videos (matching MATLAB behavior).
    # NOTE(review): this rescales the time column so that t = frame / 60;
    # confirm this is still intended for current recordings.
    if frame_rate == 15:
        corrected_array[:, 0] = corrected_array[:, 0] / 4

    return corrected_array


# --- Process all uploaded videos ---
# Track every uploaded video using the same ROI and parameters.
all_corrected = []
for video_path in video_files:
    all_corrected.append(
        process_video(video_path, roi, diameter, frame_rate,
                      search_size, per_pixel_threshold))

print(f'\nDone! Processed {len(all_corrected)} video(s).')

## Plot Fly Path(s)

Color-coded by time using the viridis colormap (dark purple = start, yellow = end).

In [None]:
from matplotlib.collections import LineCollection

for idx, corrected in enumerate(all_corrected):
    t = corrected[:, 0]
    x = corrected[:, 1]
    y = corrected[:, 2]

    fig, ax = plt.subplots(figsize=(6, 6))

    # Build one segment per pair of consecutive frames, then keep only the
    # segments whose two endpoints were both tracked. Untracked stretches
    # therefore show as gaps rather than straight lines the fly never walked.
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    valid = ~np.isnan(x) & ~np.isnan(y)
    keep = valid[:-1] & valid[1:]

    lc = LineCollection(segments[keep], cmap='viridis', linewidth=2)
    # Color each segment by the time of its first frame.
    lc.set_array(t[:-1][keep])
    ax.add_collection(lc)

    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)  # Invert y to match video
    ax.set_aspect('equal')
    ax.set_xlabel('X-coordinate (cm)', fontsize=11)
    ax.set_ylabel('Y-coordinate (cm)', fontsize=11)
    ax.set_title(f'Fly Path — {video_files[idx]}')

    if keep.any():
        cbar = fig.colorbar(lc, ax=ax, orientation='horizontal', pad=0.1)
        cbar.set_label('Time (s)')
    else:
        # A colorbar over an empty collection has no data range to show.
        ax.text(0.5, 0.5, 'No tracked path to draw', transform=ax.transAxes,
                ha='center', va='center')

    plt.tight_layout()
    plt.show()

## Calculate & Plot Velocity

In [None]:
all_velocity = []

for idx, corrected in enumerate(all_corrected):
    name = video_files[idx]
    n_rows = len(corrected)

    # Frames per bin = samples per second x bin length in seconds. The
    # sampling rate comes from the spacing of the time column, falling back
    # to frame_rate when there are fewer than two rows.
    dt = corrected[1, 0] - corrected[0, 0] if n_rows > 1 else 0
    data_rate = round(1.0 / dt) * bin_size if dt > 0 else frame_rate * bin_size
    data_rate = int(data_rate)

    if data_rate < 1:
        raise ValueError('bin_size is smaller than the minimum data rate.')

    # The number of bins is capped by two things:
    # - the nominal recording length (frames / frame_rate);
    # - needing a real frame at each bin's end.
    # Bins that cannot be measured are not created at all, so they cannot
    # appear as fake 0 mm/s values that drag the mean down.
    total_bins = int(np.floor(n_rows / frame_rate / bin_size))
    n_bins = max(0, min(total_bins, (n_rows - 1) // data_rate))

    # Positions at bin boundaries (rows 0, d, 2d, ..., n_bins*d) and the
    # straight-line distance between consecutive boundaries. A NaN endpoint
    # gives a NaN bin. The factor 10 converts cm to mm.
    boundaries = corrected[0:n_bins * data_rate + 1:data_rate, 1:3]
    step = np.diff(boundaries, axis=0)
    velocity = 10.0 * np.hypot(step[:, 0], step[:, 1])

    # Convert from mm/bin to mm/s
    velocity = velocity / bin_size

    all_velocity.append(velocity)

    # np.nanmax raises on empty input and warns on all-NaN input, so guard it.
    has_data = bool(np.any(~np.isnan(velocity)))
    peak = np.nanmax(velocity) if has_data else np.nan

    # Warn about absurd velocities
    if has_data and peak > 30:
        print(f'WARNING ({name}): Absurdly high velocities detected.')
        print('  Consider changing the ROI or re-recording the video.')
    if not has_data:
        print(f'WARNING ({name}): no velocity bins could be computed '
              '(recording too short or fly never tracked).')

    mean_vel = np.nanmean(velocity) if has_data else np.nan
    std_vel = np.nanstd(velocity) if has_data else np.nan
    print(f'\n--- {name} ---')
    print(f'  Mean velocity: {mean_vel:.2f} mm/s')
    print(f'  Std deviation: {std_vel:.2f} mm/s')

    # Plot velocity over time
    time_axis = np.arange(len(velocity)) * bin_size
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(time_axis, velocity, linewidth=1.5)
    ax.set_xlim(0, time_axis[-1] + bin_size if len(time_axis) > 0 else 1)
    ax.set_ylim(0, peak * 1.5 if has_data and peak > 0 else 1)
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title(f'Fly Velocity — {name}')
    plt.tight_layout()
    plt.show()

## Summary Across All Videos

In [None]:
num_files = len(all_velocity)

if num_files == 0:
    print('No velocity data to summarize. Upload and process at least one video first.')
elif num_files > 1:
    # Pad velocity arrays to the same length for comparison
    max_len = max(len(v) for v in all_velocity)
    velocity_matrix = np.full((num_files, max_len), np.nan)
    for i, v in enumerate(all_velocity):
        velocity_matrix[i, :len(v)] = v

    # Plot all velocities together
    fig, ax = plt.subplots(figsize=(10, 5))
    time_axis = np.arange(max_len) * bin_size
    for i in range(num_files):
        ax.plot(time_axis, velocity_matrix[i], linewidth=2,
                label=f'Fly {i + 1}')
    ax.set_xlabel('Time (s)', fontsize=11)
    ax.set_ylabel('Velocity (mm/s)', fontsize=11)
    ax.set_title('All Fly Velocities')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Per-video mean velocity. Videos with no measurable bins give NaN and
    # are left out of the across-video statistics.
    per_video_means = np.array(
        [np.nanmean(v) if np.any(~np.isnan(v)) else np.nan for v in all_velocity])
    mean_across = np.nanmean(per_video_means)
    # Sample SD (N-1 denominator): the videos are a sample of flies. This
    # also matches MATLAB's std(), which normalizes by N-1 by default.
    sd_across = np.nanstd(per_video_means, ddof=1)
    print(f'\n=== Summary Across {num_files} Videos ===')
    print(f'Mean velocity across videos: {mean_across:.2f} mm/s')
    print(f'SD of mean velocity across videos: {sd_across:.2f} mm/s')
else:
    mean_vel = np.nanmean(all_velocity[0])
    sd_vel = np.nanstd(all_velocity[0])
    print(f'\n=== Summary (1 Video) ===')
    print(f'Mean velocity: {mean_vel:.2f} mm/s')
    print(f'SD of velocity: {sd_vel:.2f} mm/s')

## Download Results (Optional)

Download the tracking data as CSV files.

In [None]:
#@title Download tracking data? { run: "auto" }
download_data = True  #@param {type:"boolean"}

if download_data:
    def _export_csv(csv_name, table, header):
        """Write `table` to `csv_name`, trigger a browser download, and report it."""
        np.savetxt(csv_name, table, delimiter=',', header=header, comments='')
        files.download(csv_name)
        print(f'Downloaded {csv_name}')

    # Position traces first: one CSV per video, named after the video file...
    for video_name, corrected in zip(video_files, all_corrected):
        base = os.path.splitext(video_name)[0]
        _export_csv(f'{base}_tracking.csv', corrected, 'Time_s,X_cm,Y_cm')

    # ...then the binned velocity for each video.
    for video_name, vel in zip(video_files, all_velocity):
        base = os.path.splitext(video_name)[0]
        bin_starts = np.arange(len(vel)) * bin_size
        _export_csv(f'{base}_velocity.csv', np.column_stack([bin_starts, vel]),
                    'Time_s,Velocity_mm_per_s')