<a href="https://colab.research.google.com/github/Carmenwang0724/Drosophila-Addiction-Tracking/blob/main/docs/notebooks/Interactive_and_resumable_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive and resumable training

Most of the time, you will be training models through the GUI or using the [`sleap-train` CLI](https://sleap.ai/guides/cli.html#sleap-train).

If you'd like to customize the training process, however, you can use SLEAP's low-level training functionality interactively. This allows you to define scripts that train models according to your own workflow, for example, to **resume training** on an already trained model. Another possible application would be to train a model using **transfer learning**, where a pretrained model can be used to initialize the weights of the new model.

In this notebook we will explore how to set up a training job and train a model for multiple rounds without the GUI or CLI.

## 1. Setup SLEAP

Run this cell first to install SLEAP. If you get a dependency error in subsequent cells, just click **Runtime** → **Restart runtime** to reload the packages.

Don't forget to set **Runtime** → **Change runtime type** → **GPU** as the accelerator.

Import SLEAP to make sure it installed correctly and print out some information about the system:

In [None]:
# 1. Clean and install in one go
!pip uninstall -y numpy opencv-python opencv-python-headless sleap
!pip install "numpy<2.0" "sleap[nn]==1.5.2" "opencv-python-headless"

# 2. Restart to apply changes
import os
os.kill(os.getpid(), 9)

Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Successfully uninstalled numpy-2.0.2
Found existing installation: opencv-python 4.13.0.90
Uninstalling opencv-python-4.13.0.90:
  Successfully uninstalled opencv-python-4.13.0.90
Found existing installation: opencv-python-headless 4.13.0.90
Uninstalling opencv-python-headless-4.13.0.90:
  Successfully uninstalled opencv-python-headless-4.13.0.90
[0mCollecting numpy<2.0
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sleap==1.5.2 (from sleap[nn]==1.5.2)
  Downloading sleap-1.5.2-py3-none-any.whl.metadata (11 kB)
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.13.0.90-cp37-abi3-manylinux_2_28_x86_64.whl.metadata (19 kB)
Collecting imgstore (from sleap==1.5.2->sleap[nn]==1.5.2)
  Downloading imgstore-0.3.7-

## 2. Setup training data

Here we will download an existing training dataset package. This is an `.slp` file that contains both the labeled poses, as well as the image data for labeled frames.

If running on Google Colab, you'll want to replace this with mounting your Google Drive folder containing your own data, or if running locally, simply change the path to your labels below in `TRAINING_SLP_FILE`.

In [2]:
from google.colab import drive
import os
drive.mount('/content/drive')

# UPDATE THESE PATHS TO YOUR ACTUAL FILES
TRAINING_DATA = "/content/drive/MyDrive/SLEAP/addictionrun_final.pkg.slp"
VIDEO_DIR = "/content/drive/MyDrive/SLEAP/"

# List of your 6 videos
video_files = [
    "10mins baseline.mp4",
    "Group 3-2.mp4",
    "Group 4-1.mp4",
    "Group 4-2.mp4",
    "Group3 1.2.mp4",
    "Group3 1.1.mp4"
]

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 3. Setup training job

A SLEAP `TrainingJobConfig` is a structure that contains all of the hyperparameters needed to train a SLEAP model. This is typically saved out to `initial_config.json` and `training_config.json` in the model folder so that training runs can be reproduced if needed, as well as to store metadata necessary for inference.

Normally, these are generated interactively by the GUI, or manually by editing an existing JSON file in a text editor. Here, we will define a configuration interactively entirely in Python.

In [None]:
from omegaconf import OmegaConf

centroid_cfg = {
    "data_config": {
        "train_labels_path": [TRAINING_DATA],
        "validation_fraction": 0.1
    },
    "trainer_config": {
        "max_epochs": 50,
        "run_name": "fly_centroids",
        "ckpt_dir": "/content/models",
        "save_ckpt": True
    },
    "model_config": {
        "backbone_config": {"unet": {"filters": 16, "output_stride": 4}},
        "head_configs": {
            "centroid": {
                "confmaps": {
                    "anchor_part": "thorax",
                    "sigma": 5.0,
                    "output_stride": 4
                }
            }
        }
    }
}
OmegaConf.save(centroid_cfg, "centroid_config.yaml")

print("--- Starting Centroid Training ---")
!sleap-nn train --config-dir . --config-name centroid_config

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
                                                               [3m0.000            [0m
                                                               [3mval_loss_epoch:  [0m
                                                               [3m0.001            [0m
                                                               [3mtrain_loss_epoch:[0m
[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2KEpoch 8/49 [35m━━━━━━━━━━━━━━━[0m[35m╸[0m 197/200 [2m0:00:22 • 0:00:01[0m [2;4m8.71it/s[0m [3mtrain_loss_step: [0m
                                                               [3m0.001            [0m
                                                               [3mlearning_rate_st…[0m
                                                               [3m0.000            [0m
                                                               [3mval_loss_step:   [0m
      

Existing configs can also be loaded from a `.json` file with:

```python
cfg = sleap.load_config("training_config.json")
```

## 4. Training
Next we will create a SLEAP `Trainer` from the configuration we just specified. This handles all the nitty gritty mechanics necessary to setup training in the backend.

In [5]:

from sleap_nn.training.model_trainer import ModelTrainer

trainer = ModelTrainer(config=cfg)

ModuleNotFoundError: No module named 'sleap_nn'

In [None]:
parts_cfg = {
    "data_config": {
        "train_labels_path": [TRAINING_DATA],
        "validation_fraction": 0.1
    },
    "trainer_config": {
        "max_epochs": 70,
        "run_name": "fly_parts",
        "ckpt_dir": "/content/models",
        "save_ckpt": True
    },
    "model_config": {
        "backbone_config": {"unet": {"filters": 32, "output_stride": 4}},
        "head_configs": {
            "centered_instance": {
                "confmaps": {
                    "anchor_part": "thorax",
                    "sigma": 1.5,
                    "output_stride": 4
                }
            }
        }
    }
}
OmegaConf.save(parts_cfg, "parts_config.yaml")

print("--- Starting Parts Training ---")
!sleap-nn train --config-dir . --config-name parts_config

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
                                                               [3mtrain_loss_epoch:[0m
[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2KEpoch 13/69 [35m━━━━━━━━━━━━━━[0m[35m╸[0m[90m━[0m 186/200 [2m0:00:15 •      [0m [2;4m11.90it/s[0m [3mhead_step: 0.003 [0m
                                     [2m0:00:02        [0m           [3mthorax_step:     [0m
                                                               [3m0.000            [0m
                                                               [3mabdomen_step:    [0m
                                                               [3m0.003            [0m
                                                               [3mtrain_loss_step: [0m
                                                               [3m0.002            [0

In [3]:
from omegaconf import OmegaConf

# 1. CREATE THE CENTROID CONFIG (High Accuracy Version)
centroid_cfg = {
    "data_config": {
        "train_labels_path": ["/content/drive/MyDrive/SLEAP/addictionrun_final.pkg.slp"]
    },
    "trainer_config": {
        "max_epochs": 50,                 # Increased for stable detection of 6 flies
        "run_name": "fly_centroids",
        "ckpt_dir": "/content",
        "save_ckpt": True
    },
    "model_config": {
        "backbone_config": {"unet": {"filters": 16, "output_stride": 4}}, # Increased filters to 16
        "head_configs": {
            "centroid": {
                "confmaps": {
                    "anchor_part": "thorax",
                    "sigma": 5.0,
                    "output_stride": 4
                }
            }
        }
    }
}
OmegaConf.save(centroid_cfg, "/content/centroid_config.yaml")

# 2. START CENTROID TRAINING
print("Training the Centroid Finder (High Accuracy Mode)...")
!sleap-nn train --config-dir /content --config-name centroid_config

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
                                                               [3m0.000            [0m
                                                               [3mtrain_loss_epoch:[0m
[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2KEpoch 48/49 [35m━━━━━━━━━━━━━━[0m[35m╸[0m[90m━[0m 187/200 [2m0:00:20 •       [0m [2;4m8.95it/s[0m [3mtrain_loss_step: [0m
                                     [2m0:00:02         [0m          [3m0.000            [0m
                                                               [3mlearning_rate_st…[0m
                                                               [3m0.000            [0m
                                                               [3mval_loss_step:   [0m
                                                               [3m0.000            [0m
                                                               [3mlearning_ra

In [7]:
print("🔍 Checking for final model files...")
!find /content -name "best.ckpt"

🔍 Checking for final model files...
/content/fly_addiction_v1/best.ckpt
/content/fly_centroids/best.ckpt


In [None]:
# List your 7 videos here
video_list = [
    "/content/drive/MyDrive/SLEAP/10mins baseline.mp4",
    "/content/drive/MyDrive/SLEAP/Group 3-2.mp4",
    "/content/drive/MyDrive/SLEAP/Group 4-1.mp4",
    "/content/drive/MyDrive/SLEAP/Group 4-2.mp4",
    "/content/drive/MyDrive/SLEAP/Group3 1.2.mp4",
    "/content/drive/MyDrive/SLEAP/Group3 1.1.mp4"

    # ... add the rest
]

for i, video in enumerate(video_list):
    output = f"final_results_video_{i}.slp"
    print(f"Tracking Video {i+1} ...")

    # target_instance_count 6 is vital here!
    !sleap-track "$video" \
        -m "/content/fly_centroids" \
        -m "/content/fly_addiction_v1" \
        --tracking.tracker simple \
        --tracking.target_instance_count 6 \
        --batch_size 4 \
        -o "$output"

    # Convert each to CSV for your final spreadsheet analysis
    !sleap-convert "$output" --format analysis.csv

Tracking Video 1 ...
INFO:numexpr.utils:NumExpr defaulting to 12 threads.
2026-02-05 01:11:15 | INFO | sleap_nn.predict:run_inference:349 | Started inference at: 2026-02-05 01:11:15.836341
2026-02-05 01:11:15 | INFO | sleap_nn.predict:run_inference:365 | Using device: cuda
[2KPredicting... [90m━━━━━━━━━━━[0m [35m100%[0m [32m72181/721…[0m ETA: [36m0:00:00[0m Elapsed: [33m0:25:10[0m [31m46.6 FPS[0m
[?25h

In [24]:
import pandas as pd
import sleap_io as sio
import numpy as np

# 1. Load the results file
print("📂 Opening AI tracking results...")
labels = sio.load_slp("fly_study_results.slp")

# 2. Get node names to create CSV columns
skeleton = labels.skeletons[0]
node_names = [node.name for node in skeleton.nodes]
print(f"🦴 Skeleton nodes: {node_names}")

# 3. Extract coordinates using the Matrix method
print(f"🧪 Extracting coordinates for {len(labels.labeled_frames)} frames...")
rows = []

for lf in labels.labeled_frames:
    for inst in lf.instances:
        # Get Fly ID (Track)
        track_id = inst.track.name if inst.track else "untracked"

        # Start the row
        row = {"frame": lf.frame_idx, "fly_id": track_id}

        # THE FIX: Convert the whole fly instance to a simple numpy array
        # This returns an array of shape (nodes, 2) -> (3, 2)
        points_matrix = inst.numpy()

        for i, name in enumerate(node_names):
            # i=0 is head, i=1 is thorax, i=2 is abdomen
            x_val = points_matrix[i, 0]
            y_val = points_matrix[i, 1]

            # Only add if the value is a real number (not NaN)
            row[f"{name}_x"] = x_val if not np.isnan(x_val) else ""
            row[f"{name}_y"] = y_val if not np.isnan(y_val) else ""

        rows.append(row)

# 4. Final Processing and Save
if rows:
    df = pd.DataFrame(rows).sort_values(["frame", "fly_id"])
    df.to_csv("FLY_COORDINATES_FINAL.csv", index=False)
    print("-" * 30)
    print("🎉 SUCCESS! DATA EXTRACTION COMPLETE.")
    print(f"Total entries: {len(df)}")
    print("📁 DOWNLOAD 'FLY_COORDINATES_FINAL.csv' FROM THE SIDEBAR.")
else:
    print("❌ No data points found to extract.")

# Preview the data for your HMK report
print("\nTop 5 rows of your research data:")
print(df.head())

📂 Opening AI tracking results...
🦴 Skeleton nodes: ['head', 'thorax', 'abdomen']
🧪 Extracting coordinates for 27466 frames...
------------------------------
🎉 SUCCESS! DATA EXTRACTION COMPLETE.
Total entries: 53565
📁 DOWNLOAD 'FLY_COORDINATES_FINAL.csv' FROM THE SIDEBAR.

Top 5 rows of your research data:
   frame   fly_id      head_x      head_y    thorax_x    thorax_y   abdomen_x  \
0   1611  track_0  892.603088  551.390930  896.929321  549.751526  889.077393   
1   1613  track_0  886.168091  544.648193  896.516296  547.871399  887.650879   
2   1666  track_0  892.512634  551.503357  896.847046  549.836365  888.922424   
3   1690  track_0  892.728699  551.772217  897.057251  550.119812  889.209534   
4   1711  track_0  892.668030  551.728149  896.978699  550.069885  889.079407   

    abdomen_y  
0  541.262268  
1  539.089417  
2  541.547363  
3  541.891235  
4  541.824402  


In [21]:
import pandas as pd
import sleap_io as sio

# 1. Load the results
print("📂 Opening file...")
labels = sio.load_slp("fly_study_results.slp")

print(f"📊 Total frames in file: {len(labels.labeled_frames)}")

# 2. The Extraction Loop
rows = []
for lf in labels.labeled_frames:
    # We check for BOTH regular instances and predicted instances
    for inst in lf.instances:
        # Get Fly ID (Track)
        track_id = inst.track.name if inst.track else "untracked"

        row = {
            "frame_idx": lf.frame_idx,
            "fly_id": track_id,
            "type": "prediction" # labeling the data type
        }

        # Extract points using the most direct dictionary access
        for node, pt in inst.points.items():
            row[f"{node.name}_x"] = pt.x
            row[f"{node.name}_y"] = pt.y

        rows.append(row)

# 3. Check and Save
if len(rows) == 0:
    print("❌ Still 0 rows found. Let's try one more deep-access method...")
    # Deep-access for certain SLEAP-NN versions
    for lf in labels:
        for inst in lf:
            # (Repeating extraction logic inside deep loop)
            track_id = inst.track.name if inst.track else "untracked"
            row = {"frame_idx": lf.frame_idx, "fly_id": track_id}
            for node, pt in inst.points.items():
                row[f"{node.name}_x"] = pt.x
                row[f"{node.name}_y"] = pt.y
            rows.append(row)

if len(rows) > 0:
    df = pd.DataFrame(rows)
    df.to_csv("FLY_TRACKING_DATA_FINAL.csv", index=False)
    print("-" * 30)
    print(f"🎉 SUCCESS! Found {len(df)} fly positions.")
    print(f"Detected Nodes: {[c for c in df.columns if '_x' in c]}")
    print("📁 Download 'FLY_TRACKING_DATA_FINAL.csv' now!")
else:
    print("‼️ ERROR: The file is physically empty. This means the AI 'Zeroed Out'.")
    print("Check: Did you label BOTH the Centroid and the Head in your 30 frames?")

📂 Opening file...
📊 Total frames in file: 27466


AttributeError: 'PredictedPointsArray' object has no attribute 'items'

In [17]:
# Force-track with a very low threshold (0.05) so it doesn't delete your dots
!sleap-track "/content/drive/MyDrive/SLEAP/phase1allpatterns.mov" \
    -m "/content/fly_centroids" \
    -m "/content/fly_addiction_v1" \
    --tracking.tracker simple \
    --peak_threshold 0.05 \
    -o "low_threshold_results.slp"

INFO:numexpr.utils:NumExpr defaulting to 2 threads.
2026-02-01 07:55:14 | INFO | sleap_nn.predict:run_inference:349 | Started inference at: 2026-02-01 07:55:14.511258
2026-02-01 07:55:14 | INFO | sleap_nn.predict:run_inference:365 | Using device: cuda
[2K2026-02-01 08:13:25 | ERROR | sleap_nn.data.providers:run:200 | Error when 
reading video frame. Stopping video reader.
Failed to read frame index 30771.
[2KPredicting... [91m━━━━━━━━━━[0m[91m╸[0m [35m100%[0m [32m30771/307…[0m ETA: [36m0:00:01[0m Elapsed: [33m0:18:11[0m [31m26.4 FPS[0m
[?25h
Aborted!


In [2]:
# 1. THE NUMPY PROTECTION (Mandatory)
import sys, numpy as np
if np.version.version.startswith("2."):
    import numpy._core as core
    sys.modules['numpy.core'] = core

# 2. THE STABLE ENGINE (Bypasses all Merge/Type errors)
# Put your file path in the first line below
!sleap-nn train \
    data_config.train_labels_path=["/content/drive/MyDrive/SLEAP/fly_trial.v001.pkg.slp"] \
    trainer_config.max_epochs=50 \
    trainer_config.run_name="fly_addiction_v1" \
    model_config.backbone_config.unet.filters=16 \
    model_config.head_configs.centered_instance.confmaps.anchor_part="thorax" \
    model_config.head_configs.centered_instance.confmaps.output_stride=4


sleap-nn train — Train SLEAP models from a config YAML file.

Usage:
  sleap-nn train --config-dir <dir> --config-name <name> [overrides]

Common overrides:
  trainer_config.max_epochs=100
  trainer_config.batch_size=32

Examples:
  Start new run:
    sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun
  Resume 20 more epochs:
    sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun \
      trainer_config.resume_ckpt_path=<path/to/ckpt> \
      trainer_config.max_epochs=20

Tips:
  - Use -m/--multirun for sweeps; outputs go under hydra.sweep.dir.
  - For Hydra flags and completion, use --hydra-help.

For a detailed list of all available config options, please refer to https://nn.sleap.ai/config/.



Great, now we're ready to do the first round of training. This is when the model will actually start to improve over time:

In [None]:
trainer.train()

2026-01-29 10:27:10 | INFO | sleap_nn.training.model_trainer:train:849 | Setting up for training...
2026-01-29 10:27:10 | INFO | sleap_nn.training.model_trainer:_setup_train_val_labels:216 | Creating train-val split...


TypeError: 'TrainingJobConfig' object is not subscriptable

## 5. Continuing training

If we still have the trainer in memory, we can continue training by simply calling `trainer.train()` again with a potentially different number of epochs:

In [None]:
trainer.config.optimization.epochs = 3
trainer.train()

AttributeError: 'TrainingJobConfig' object has no attribute 'optimization'

As you can see, the loss and accuracy pick up from where it left off in the previous training.


Usually, however, if you're continuing training it's likely because you're starting off from an already trained model.

In this case, all you need to do to continue training is to create a new `Trainer` from the existing model configuration and load up the weights before continuing training:

In [None]:
# Load config.
cfg = sleap.load_config("models/baseline_model.topdown")
# cfg.outputs.run_name = "new_folder"  # Set the run_name to a new value if you want the model to be saved to a different folder.

# Create and initialize the trainer.
trainer = sleap.nn.training.Trainer.from_config(cfg)
trainer.setup()

# Replace the randomly initialized weights with the saved weights.
trainer.keras_model.load_weights("models/baseline_model.topdown/best_model.h5")

INFO:sleap.nn.training:Loading training labels from: labels.pkg.slp
INFO:sleap.nn.training:Creating training and validation splits from validation fraction: 0.1
INFO:sleap.nn.training:  Splits: Training = 1440 / Validation = 160.
INFO:sleap.nn.training:Setting up for training...
INFO:sleap.nn.training:Setting up pipeline builders...
INFO:sleap.nn.training:Setting up model...
INFO:sleap.nn.training:Building test pipeline...
INFO:sleap.nn.training:Loaded test example. [0.925s]
INFO:sleap.nn.training:  Input shape: (160, 160, 1)
INFO:sleap.nn.training:Created Keras model.
INFO:sleap.nn.training:  Backbone: UNet(stacks=1, filters=16, filters_rate=2.0, kernel_size=3, stem_kernel_size=7, convs_per_block=2, stem_blocks=0, down_blocks=4, middle_block=True, up_blocks=2, up_interpolate=False, block_contraction=False)
INFO:sleap.nn.training:  Max stride: 16
INFO:sleap.nn.training:  Parameters: 2,101,501
INFO:sleap.nn.training:  Heads: 
INFO:sleap.nn.training:    [0] = CenteredInstanceConfmapsHead

In [None]:
trainer.config.optimization.epochs = 3
trainer.train()

INFO:sleap.nn.training:Creating tf.data.Datasets for training data generation...
INFO:sleap.nn.training:Finished creating training datasets. [17.7s]
INFO:sleap.nn.training:Starting training loop...
Epoch 1/3
360/360 - 9s - loss: 8.3664e-04 - head: 3.5190e-04 - thorax: 1.7037e-04 - abdomen: 9.8467e-04 - wingL: 7.9929e-04 - wingR: 8.0385e-04 - forelegL4: 0.0012 - forelegR4: 0.0012 - midlegL4: 9.5228e-04 - midlegR4: 9.8510e-04 - hindlegL4: 0.0013 - hindlegR4: 0.0013 - eyeL: 4.0772e-04 - eyeR: 3.9413e-04 - val_loss: 8.7351e-04 - val_head: 4.0943e-04 - val_thorax: 1.7453e-04 - val_abdomen: 9.4413e-04 - val_wingL: 8.3617e-04 - val_wingR: 8.4860e-04 - val_forelegL4: 0.0012 - val_forelegR4: 0.0012 - val_midlegL4: 9.4441e-04 - val_midlegR4: 0.0011 - val_hindlegL4: 0.0014 - val_hindlegR4: 0.0014 - val_eyeL: 4.4847e-04 - val_eyeR: 4.4179e-04 - lr: 1.0000e-04 - 9s/epoch - 24ms/step
Epoch 2/3
360/360 - 7s - loss: 8.0541e-04 - head: 3.4627e-04 - thorax: 1.6070e-04 - abdomen: 9.4325e-04 - wingL: 7.72

Output()

INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.train.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.train.npz
INFO:sleap.nn.evals:OKS mAP: 0.585451


Output()

INFO:sleap.nn.evals:Saved predictions: models/baseline_model.topdown/labels_pr.val.slp
INFO:sleap.nn.evals:Saved metrics: models/baseline_model.topdown/metrics.val.npz
INFO:sleap.nn.evals:OKS mAP: 0.574921


Again, the loss and accuracy pick up from where they left off prior to this round of training.

The resulting model can be used as usual for inference on new data.