# Bee Orientation Project Experiment Notebook

This notebook runs the full Bee Orientation experiment end to end.
It covers dataset preparation, model definition, training, evaluation, visualization of predictions, and model analysis.

## 1. Setup and Imports

This section imports the third-party libraries and the project modules from the `src` directory of the `CV-BeeOrientation` repository.

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Import project-specific modules
from src.dataset import BeeSegmentationDataset
from src.models.base import UNet3
from src.models.resunet import ResUNet18
from src.utils.split_dataset import split_dataset_two_stage
from src.utils.training import train_model
from src.utils.evaluation import evaluate_on_test, collect_evaluation_data, load_checkpoint
from src.utils import plots

## 2. Dataset Setup

This section runs the preparation script, which checks whether the dataset is already available.
If it is not, the original dataset archives are downloaded and extracted, and cropped images and masks are produced ready for training.

In [None]:
!python scripts/prepare_dataset.py
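
The availability check described above can be pictured as a simple directory test. The following is a minimal sketch, not the logic of `scripts/prepare_dataset.py` itself: the helper name `dataset_is_prepared` is hypothetical, and the default paths are the ones used in the configuration cell of this notebook.

```python
from pathlib import Path


def dataset_is_prepared(images_dir="data/processed/images",
                        masks_dir="data/processed/masks"):
    """Return True if both directories exist and each contains a file.

    Hypothetical helper for illustration; the real preparation script may
    apply different or stricter checks.
    """
    for directory in (Path(images_dir), Path(masks_dir)):
        if not directory.is_dir():
            return False
        # any() stops at the first file found, so large folders are cheap to test
        if not any(entry.is_file() for entry in directory.iterdir()):
            return False
    return True
```

A check like this only confirms that some files are present; it does not verify that every image has a matching mask.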

## 3. Configure and Load Dataset

Here we configure the dataset paths and the data loader parameters, including batch size and shuffling of the training data.
The `split_dataset_two_stage` utility splits the dataset into training, validation, and test sets and returns the seed used for the split.

In [None]:
# Dataset Configuration
DATA_IMAGES_PATH = "data/processed/images"
DATA_MASKS_PATH = "data/processed/masks"
CSV_PATH = "data/processed/labels.csv"
BATCH_SIZE = 16
SHUFFLE_TRAIN_DATA = True
RANDOM_SEED = None

# Set device for training and inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print(f"Loading dataset from images: {DATA_IMAGES_PATH} and masks: {DATA_MASKS_PATH}")
dataset = BeeSegmentationDataset(DATA_IMAGES_PATH, DATA_MASKS_PATH)

train_dataset, val_dataset, test_dataset, seed = split_dataset_two_stage(dataset, seed=RANDOM_SEED)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE_TRAIN_DATA)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
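
A two-stage split typically first holds out a test set, then divides the remainder into training and validation sets. The sketch below illustrates that idea on plain index lists using only the standard library. It is an assumption about what `split_dataset_two_stage` does, not its actual implementation: the function name `two_stage_split_indices`, the 20% fractions, and the way a seed is drawn when none is given are all hypothetical.

```python
import random


def two_stage_split_indices(n_samples, test_fraction=0.2, val_fraction=0.2, seed=None):
    """Split range(n_samples) into train/val/test index lists.

    Stage 1 holds out the test set; stage 2 splits the rest into train and
    validation. Fractions and seed handling are illustrative assumptions.
    """
    if seed is None:
        # Draw a seed so the split can be reproduced later by passing it back in
        seed = random.randrange(2**32)
    rng = random.Random(seed)

    indices = list(range(n_samples))
    rng.shuffle(indices)

    # Stage 1: hold out the test set
    n_test = int(n_samples * test_fraction)
    test_idx, remaining = indices[:n_test], indices[n_test:]

    # Stage 2: split what is left into validation and training sets
    n_val = int(len(remaining) * val_fraction)
    val_idx, train_idx = remaining[:n_val], remaining[n_val:]

    return train_idx, val_idx, test_idx, seed
```

Passing the returned seed back in reproduces the same split, which is why it is worth logging when the seed was not fixed in advance.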

## 4. Global Experiment Parameters

In [None]:
GLOBAL_LEARNING_RATE = 1e-3
GLOBAL_NUM_EPOCHS = 20
# Per-class weights for the cross-entropy loss; class 0 is down-weighted
GLOBAL_LOSS_WEIGHTS = torch.tensor([0.1, 1.0, 1.0], device=device)

## 5. Experiment for UNet3 Model

This section runs the complete experiment pipeline specifically for the `UNet3` model.

### 5.1. UNet3 Model Definition

In [None]:
print(f"\n{'='*80}")
print("Starting Experiment for Model: UNet3")
print(f"{'='*80}\n")

model_unet3 = UNet3().to(device)

### 5.2. UNet3 Training Setup

In [None]:
optimizer_unet3 = torch.optim.Adam(model_unet3.parameters(), lr=GLOBAL_LEARNING_RATE)
criterion_unet3 = nn.CrossEntropyLoss(weight=GLOBAL_LOSS_WEIGHTS)
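
With a `weight` tensor and the default `reduction='mean'`, `nn.CrossEntropyLoss` scales each pixel's loss by the weight of its target class and divides by the sum of those weights rather than by the pixel count. The plain-Python sketch below reproduces that arithmetic for a handful of pixels. The function name `weighted_cross_entropy` is made up for illustration, and which class each index represents is not stated in this notebook.

```python
import math


def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean cross-entropy over a list of pixels.

    logits:  one list of per-class scores per pixel
    targets: the true class index of each pixel
    weights: one weight per class, e.g. [0.1, 1.0, 1.0]
    """
    weighted_sum = 0.0
    weight_total = 0.0
    for scores, target in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the max before exponentiating
        peak = max(scores)
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        pixel_loss = log_norm - scores[target]  # equals -log softmax(scores)[target]
        weighted_sum += weights[target] * pixel_loss
        weight_total += weights[target]
    return weighted_sum / weight_total
```

With weights `[0.1, 1.0, 1.0]`, a misclassified class 0 pixel contributes ten times less to the loss than a misclassified pixel of either other class.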

### 5.3. UNet3 Model Training

In [None]:
train_losses_unet3, val_losses_unet3, best_path_unet3 = train_model(
    model_unet3,
    train_loader,
    val_loader,
    optimizer_unet3,
    criterion_unet3,
    device,
    num_epochs=GLOBAL_NUM_EPOCHS,
    checkpoint_filename=None
)

plots.plot_training_curves(train_losses_unet3, val_losses_unet3, model_unet3, save_path="results/UNet3 Loss.png")
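
`train_model` returns the training and validation loss histories together with the path of the best checkpoint. One common criterion for "best" is the lowest validation loss, and the sketch below shows that selection on a list of losses. Whether `train_model` uses this criterion is an assumption, and `best_epoch` is a hypothetical helper.

```python
def best_epoch(val_losses):
    """Return (epoch_index, loss) for the lowest validation loss.

    Epoch indices are zero-based; ties resolve to the earliest epoch.
    """
    if not val_losses:
        raise ValueError("val_losses is empty")
    index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return index, val_losses[index]
```

Applied to the validation losses returned above, this gives a quick cross-check against the training curve plot.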

### 5.4. UNet3 Model Evaluation

In [None]:
model_unet3 = load_checkpoint(model_unet3, best_path_unet3, device)

In [None]:
df = pd.read_csv(CSV_PATH)
# Map each mask filename to its ground-truth angle
gt_csv = dict(zip(df["mask_filename"], df["angle"]))
data_unet3, avg_loss_unet3 = collect_evaluation_data(model_unet3, test_loader, criterion_unet3, device, gt_csv)
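
Comparing predicted and ground-truth angles has to respect the circular nature of angles: a prediction of 350° against a ground truth of 10° is off by 20°, not 340°. The following sketch wraps the difference into [-180°, 180°) to get a signed error and takes its magnitude for the absolute error. It assumes angles in degrees with a full 360° period. The function names are hypothetical, and the project's own evaluation code may use a different convention.

```python
def signed_angle_error_deg(pred_deg, gt_deg):
    """Signed difference pred - gt, wrapped into [-180, 180)."""
    return (pred_deg - gt_deg + 180.0) % 360.0 - 180.0


def abs_angle_error_deg(pred_deg, gt_deg):
    """Smallest absolute angular distance, in [0, 180]."""
    return abs(signed_angle_error_deg(pred_deg, gt_deg))
```

The wrap relies on Python's `%` returning a non-negative result for a positive modulus, so no extra branch is needed for negative differences.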

In [None]:
results_unet3 = evaluate_on_test(data_unet3, avg_loss_unet3)
plots.plot_orientation_error_distribution(results_unet3["orientation"]["all_errors_deg"], model_unet3, save_path="results/UNet3 Orientation Error.png")
plots.plot_miou_vs_orientation_error(data_unet3, model_unet3)
plots.plot_miou_vs_orientation_error_hexbin(data_unet3, model_unet3, save_path="results/UNet3 mIoU vs Orientation Error.png")
plots.plot_signed_orientation_error_distribution(data_unet3, model_unet3, save_path="results/UNet3 Signed Orientation Error Distribution.png")
plots.plot_orientation_error_vs_gt_angle_hexbin(data_unet3, model_unet3, save_path="results/UNet3 Orientation Error vs GT-Angle.png")
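
Several of the plots above relate mean intersection over union (mIoU) to orientation error. As a reminder of the metric, the sketch below computes per-class IoU between two flat label sequences and averages over the classes that occur in either one. This is a generic illustration: how the project's evaluation code computes mIoU (for example, whether every class is included) is not shown in this notebook, `mean_iou` is a hypothetical name, and the default of three classes is inferred from the three loss weights.

```python
def mean_iou(pred, target, num_classes=3):
    """Mean IoU between two equal-length sequences of class labels.

    Classes absent from both prediction and target are skipped so they do
    not distort the average.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    ious = []
    for cls in range(num_classes):
        intersection = sum(p == cls and t == cls for p, t in zip(pred, target))
        union = sum(p == cls or t == cls for p, t in zip(pred, target))
        if union > 0:
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else float("nan")
```

For a 2D mask, flatten the prediction and target to one sequence of labels each before calling the function.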

### 5.5. UNet3 Model Size

In [None]:
total_params_unet3 = sum(p.numel() for p in model_unet3.parameters() if p.requires_grad)
print(f"\nNumber of trainable parameters in UNet3: {total_params_unet3}")

### 5.6. UNet3 Example Prediction Masks

In [None]:
plots.plot_predictions(data_unet3, model_unet3, save_path="results/UNet3 Prediction Masks.png")
plots.plot_worst_orientation_errors(data_unet3, model_unet3, save_path="results/UNet3 Worst Orientation Errors.png")
plots.plot_orientation_errors_in_range(data_unet3, model_unet3, save_path="results/UNet3 Orientation Errors 90deg Range.png")

## 6. Experiment for ResUNet18 Model

This section runs the complete experiment pipeline specifically for the `ResUNet18` model.

### 6.1. ResUNet18 Model Definition

In [None]:
print(f"\n{'='*80}")
print("Starting Experiment for Model: ResUNet18")
print(f"{'='*80}\n")

model_resunet18 = ResUNet18().to(device)

### 6.2. ResUNet18 Training Setup

In [None]:
criterion_resunet18 = nn.CrossEntropyLoss(weight=GLOBAL_LOSS_WEIGHTS)
optimizer_resunet18 = torch.optim.Adam(model_resunet18.parameters(), lr=GLOBAL_LEARNING_RATE)

### 6.3. ResUNet18 Model Training

In [None]:
train_losses_resunet18, val_losses_resunet18, best_path_resunet18 = train_model(
    model_resunet18,
    train_loader,
    val_loader,
    optimizer_resunet18,
    criterion_resunet18,
    device,
    num_epochs=GLOBAL_NUM_EPOCHS,
    checkpoint_filename=None
)

plots.plot_training_curves(train_losses_resunet18, val_losses_resunet18, model_resunet18, save_path="results/ResUNet18 Loss.png")

### 6.4. ResUNet18 Model Evaluation

In [None]:
model_resunet18 = load_checkpoint(model_resunet18, best_path_resunet18, device)

In [None]:
df = pd.read_csv(CSV_PATH)
# Map each mask filename to its ground-truth angle
gt_csv = dict(zip(df["mask_filename"], df["angle"]))
data_resunet18, avg_loss_resunet18 = collect_evaluation_data(model_resunet18, test_loader, criterion_resunet18, device, gt_csv)

In [None]:
results_resunet18 = evaluate_on_test(data_resunet18, avg_loss_resunet18)
plots.plot_orientation_error_distribution(results_resunet18["orientation"]["all_errors_deg"], model_resunet18, save_path="results/ResUNet18 Orientation Error.png")
plots.plot_miou_vs_orientation_error(data_resunet18, model_resunet18)
plots.plot_miou_vs_orientation_error_hexbin(data_resunet18, model_resunet18, save_path="results/ResUNet18 mIoU vs Orientation Error.png")
plots.plot_signed_orientation_error_distribution(data_resunet18, model_resunet18, save_path="results/ResUNet18 Signed Orientation Error Distribution.png")
plots.plot_orientation_error_vs_gt_angle_hexbin(data_resunet18, model_resunet18, save_path="results/ResUNet18 Orientation Error vs GT-Angle.png")

### 6.5. ResUNet18 Model Size

In [None]:
total_params_resunet18 = sum(p.numel() for p in model_resunet18.parameters() if p.requires_grad)
print(f"\nNumber of trainable parameters in ResUNet18: {total_params_resunet18}")

### 6.6. ResUNet18 Example Prediction Masks

In [None]:
plots.plot_predictions(data_resunet18, model_resunet18, save_path="results/ResUNet18 Prediction Masks.png")
plots.plot_worst_orientation_errors(data_resunet18, model_resunet18, save_path="results/ResUNet18 Worst Orientation Errors.png")
plots.plot_orientation_errors_in_range(data_resunet18, model_resunet18, save_path="results/ResUNet18 Orientation Errors 90deg Range.png")
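
Once both models are evaluated, their orientation errors can be compared side by side. The sketch below summarizes a sequence of absolute errors in degrees, such as the `all_errors_deg` entries used above. The helper name `summarize_errors` and the chosen statistics are illustrative assumptions, not part of the project's utilities.

```python
import statistics


def summarize_errors(errors_deg):
    """Return mean, median, and maximum of a sequence of errors in degrees."""
    errors = [float(e) for e in errors_deg]  # also accepts array-like input
    if not errors:
        raise ValueError("errors_deg is empty")
    return {
        "mean": statistics.fmean(errors),
        "median": statistics.median(errors),
        "max": max(errors),
    }
```

In this notebook it could be applied to `results_unet3["orientation"]["all_errors_deg"]` and to the corresponding `results_resunet18` entry to print one summary per model.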