In [1]:

import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
from sp_utils import update_config, pose_estimation, save_model

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from PIL import Image
from rich import print
import os

In [4]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.ops import roi_align
import torch.nn.functional as F

In [5]:
# Device-Agnostic
DeviceLikeType = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(DeviceLikeType)
print(f"""
Device: {device}
Device CUDNN enabled: {torch.backends.cudnn.enabled}
""")

In [6]:
IMG_WIDTH = 320
IMG_HEIGHT = 240
NUM_KEYPOINTS = 7
NUM_BATCH = 16

CONFIG_PATH = "config.json"
MODEL_PATH = "models/pose_estimation"
DATASET_ROOT = "datasets"
TRAIN_DATASET_FILE = DATASET_ROOT + "/preprocessed_dataset.csv"
TEST_DATASET_FILE = DATASET_ROOT + "/test_dataset.csv"


train_df = pd.read_csv(TRAIN_DATASET_FILE)
test_df = pd.read_csv(TEST_DATASET_FILE)

In [7]:
train_df.head()

Unnamed: 0,behavior,image_id,image_file,head_x,head_y,beak_base_x,beak_base_y,beak_tip_x,beak_tip_y,neck_x,neck_y,body1_x,body1_y,body2_x,body2_y,tail_base_x,tail_base_y
0,0,n_001,59-20151230231705-00.jpg,19.234443,92.112384,41.246921,134.089668,39.711167,149.447212,61.211727,86.993203,79.640779,136.137341,123.153818,131.530078,176.393301,7.133978
1,0,n_001,59-20151230231706-00.jpg,12.579507,49.111263,43.806512,116.684453,55.580628,136.137341,70.426253,83.921694,83.736124,139.20885,145.166297,131.01816,164.619185,5.086305
2,0,n_001,59-20151230231714-00.jpg,24.865542,28.634538,38.175412,83.921694,35.615822,99.279237,61.211727,54.742362,86.295715,137.161177,139.535198,132.553914,193.286599,5.086305
3,0,n_001,59-20151230231720-00.jpg,23.841705,105.934174,38.175412,120.267882,42.270757,127.946653,49.949529,112.58911,73.497762,132.553916,111.379702,121.291718,126.225327,-0.544792
4,0,n_001,59-20151230231721-00.jpg,26.913214,101.83883,37.151576,117.708291,44.31843,125.387062,51.997201,109.005683,66.842826,139.720769,104.212848,131.018162,141.070952,3.550553


In [8]:
def soft_argmax(heatmaps: torch.Tensor, num_keypoints: int = 7) -> torch.Tensor:
    """
    Decode heatmaps into keypoint coordinates using soft-argmax.

    Args:
        heatmaps (torch.Tensor): Heatmaps of shape (batch_size, num_keypoints * 2, H, W).
        num_keypoints (int): The number of keypoints.

    Returns:
        torch.Tensor: Keypoint coordinates (batch_size, num_keypoints, 2).
    """
    batch_size, _, H, W = heatmaps.size()

    # Reshape the heatmaps to flatten the spatial dimensions (H, W)
    heatmaps = heatmaps.view(batch_size, num_keypoints * 2, -1)

    # Apply softmax along the flattened spatial dimensions
    probabilities = F.softmax(heatmaps, dim=-1)

    # Create index tensor for weighted sum
    indices = torch.arange(H * W, device=heatmaps.device).float().view(1, 1, -1)

    # Compute x and y coordinates by weighted sum
    x_coords = (indices % W) * probabilities  # Weighted x
    y_coords = (indices // W) * probabilities  # Weighted y

    # Sum over the spatial dimensions to get the weighted average coordinates
    x_coords = x_coords.sum(dim=-1)
    y_coords = y_coords.sum(dim=-1)

    # Return coordinates in (batch_size, num_keypoints, 2)
    return torch.stack([x_coords, y_coords], dim=-1)


In [None]:
class BirdPoseModel(nn.Module):
    def __init__(self, num_keypoints: int, num_classes: int = 1, name="resnet50_ssdlite_topdown"):
        super(BirdPoseModel, self).__init__()
        self.name = name

        # Load pretrained SSDLite for bounding box detection
        ssdlite = models.detection.ssdlite320_mobilenet_v3_large(weights='DEFAULT')
        self.detector = ssdlite
        self.detector.head.classification_head.num_classes = num_classes + 1  # +1: Default background class

        # Keypoint heatmap generation branch (try to predict the local `num_keypoints`)
        self.keypoint_head = nn.Sequential(
            nn.Conv2d(1280, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),  # BatchNorm to prevent neuron co-adaptation
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_keypoints * 2, kernel_size=1)  # Output heatmaps for keypoints
        )

        # Refinement branch with upsampling [num_keypoints -> 128] for better resolution
        self.refinement_head = nn.Sequential(
            nn.ConvTranspose2d(num_keypoints * 2, 128, kernel_size=4, stride=2, padding=1),  # Upsampling
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_keypoints * 2, kernel_size=1)
        )

    def forward(self, x: torch.Tensor, keypoints) -> dict:
        """
        Forward pass for SSDLite-based bounding box detection and keypoint prediction.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            dict: Dictionary containing bounding boxes and keypoint coordinates.
        """
        # SSDLite for bounding box detection (without ground truth)
        detections = self.detector(x)  # List[Dict] with bounding boxes and scores
        
        # Extract bounding boxes and corresponding features
        feature_maps = self.detector.backbone(x)  # Intermediate features from backbone
        detected_bboxes = [detection["boxes"] for detection in detections]  # Extract predicted bounding boxes

        keypoint_coords = []
        for i, bboxes in enumerate(detected_bboxes):
            if len(bboxes) > 0:
                # Crop regions using ROI Align
                cropped_features = roi_align(feature_maps[i:i+1], [bboxes], output_size=(7, 7))
                cropped_features = cropped_features.mean(dim=1, keepdim=True)  # Reduce channel dimension

                # Generate keypoint heatmaps
                heatmaps = self.keypoint_head(cropped_features)
                refined_heatmaps = self.refinement_head(heatmaps)

                # Decode heatmaps into keypoint coordinates
                coords = soft_argmax(refined_heatmaps, keypoints, num_keypoints=7).squeeze() 
                keypoint_coords.append(coords)
            else:
                keypoint_coords.append(torch.empty(0, device=x.device))  # No detections

        return {
            "detections": detections,  # SSDLite bounding boxes, and scores 
            "keypoints": keypoint_coords  # Decoded keypoint coordinates
        }


In [10]:
# Image Transformations Defination
img_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

kp_transform = pose_estimation.NormalizeKeypoints(IMG_WIDTH, IMG_HEIGHT)

# Create datasets
train_dataset = pose_estimation.PoseDataset(
    dataframe=train_df, 
    dataset_root_folder=DATASET_ROOT, 
    img_transform=img_transform, 
    kp_transform=kp_transform
)

test_dataset = pose_estimation.PoseDataset(
    dataframe=test_df, 
    dataset_root_folder=DATASET_ROOT, 
    img_transform=img_transform, 
    kp_transform=kp_transform
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=NUM_BATCH, shuffle=True, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=NUM_BATCH, shuffle=False, num_workers=0, pin_memory=True)

In [11]:
model = BirdPoseModel(NUM_KEYPOINTS).to(device)

In [12]:
EPOCHS = 10
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm  # Optional: For progress bar

# Loss functions for bounding boxes and keypoints
def compute_keypoint_loss(pred_heatmaps, target_keypoints, batch_size):
    """
    Computes the loss for keypoints by comparing predicted heatmaps to ground truth keypoints.
    
    Args:
        pred_heatmaps (torch.Tensor): Predicted heatmaps, shape (batch_size, num_keypoints, H, W).
        target_keypoints (torch.Tensor): Ground truth keypoints, shape (batch_size, num_keypoints, 2).
        batch_size (int): Batch size.
    
    Returns:
        torch.Tensor: Computed keypoint loss.
    """
    # Using mean squared error (MSE) for keypoints
    keypoint_loss = 0
    for i in range(batch_size):
        for j in range(target_keypoints.shape[1]):
            # For each keypoint, calculate MSE between predicted heatmap and ground truth
            pred_coords = soft_argmax(pred_heatmaps[i, j:j+1])  # Predicted keypoint from soft-argmax
            true_coords = target_keypoints[i, j]
            keypoint_loss += F.mse_loss(pred_coords, true_coords)
    
    return keypoint_loss / (batch_size * target_keypoints.shape[1])


In [None]:

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 10  # Number of epochs for training
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for images, keypoints in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device)
        keypoints = keypoints.to(device)

        # Forward pass
        outputs = model(images, keypoints)

        detections = outputs["detections"]
        predicted_keypoints = outputs["keypoints"]

        # Flatten keypoints
        keypoints = keypoints.view(-1, NUM_KEYPOINTS, 2)
        predicted_keypoints = torch.stack(predicted_keypoints).view(-1, NUM_KEYPOINTS, 2)

        # Compute loss (L2 loss for keypoint predictions)
        loss = F.mse_loss(predicted_keypoints, keypoints)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")



Epoch 1/10:   0%|          | 0/73 [00:00<?, ?it/s]


AssertionError: targets should not be none when in training mode