# Approach

1. ResNet-50 Backbone: Use a pre-trained ResNet-50 model, modifying it to output features suitable for pose estimation. ResNet-50 is ideal for extracting spatial features due to its convolutional layers.

2. SSDLite Detector: Use SSDLite as the detection network for the body part localization, where SSDLite's lightweight design is efficient for mobile or low-resource deployment. SSDLite can detect bounding boxes around bird keypoints or entire bird regions.

3. Top-Down Pose Estimation: After detecting the bird, crop the region and apply ResNet-50 to localize keypoints. Top-down approach involves first detecting the bird in the image, and then estimating the pose of the detected bird.

In [1]:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from PIL import Image
from rich import print
import os

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.models import ResNet50_Weights
import torch.nn.functional as F


# SSDLite Detection Network
from torchvision.models.detection import ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights

In [3]:

# Device-Agnostic
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"""
Device: {device}
Device CUDNN enabled: {torch.backends.cudnn.enabled}
""")

In [4]:
IMG_WIDTH = 320
IMG_HEIGHT = 240

DATASET_ROOT = "datasets"
DATASET_FILE = DATASET_ROOT + "/preprocessed_dataset.csv"
dataset = pd.read_csv(DATASET_FILE)

In [5]:
# DATASET CLASS

class PoseDataset(Dataset):
    def __init__(self, dataframe: pd.DataFrame, dataset_root_folder: str, img_transform=None, kp_transform=None):
        self.annotations = dataframe  # Load the pandas DataFrame directly
        self.dataset_root_folder = dataset_root_folder  # Root folder for the dataset
        self.img_transform = img_transform
        self.kp_transform = kp_transform

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int):
        # Construct the image path
        behavior = self.annotations.iloc[idx]['behavior']
        image_id = self.annotations.iloc[idx]['image_id']
        image_file = self.annotations.iloc[idx]['image_file']
        
        # Create the full image path
        img_path = os.path.join(self.dataset_root_folder, behavior, image_id, image_file)
        
        # Load and process the image
        image = Image.open(img_path).convert("RGB")

        # Extract the keypoints (head_x, head_y, ..., body2_x, body2_y)
        keypoints = self.annotations.iloc[idx, 3:].values.astype('float32')

        # Ensure keypoints are in the correct shape
        if len(keypoints) % 2 != 0:
            raise ValueError("Keypoints must contain an even number of values (x, y pairs).")
        
        # Apply transformations if provided
        if self.img_transform:
            image = self.img_transform(image)

        if self.kp_transform:
            keypoints = self.kp_transform(keypoints)

        return image, keypoints


In [6]:
# Load pretrained ResNet-50 backbone
class ResNetBackbone(nn.Module):
    def __init__(self):
        super(ResNetBackbone, self).__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # Remove the FC layers

    def forward(self, x):
        return self.features(x)


In [7]:
# SSDLite Detection Network
ssd_model = ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)

In [8]:
class PoseEstimation(nn.Module):
    def __init__(self, num_keypoints=7):  # Adjust according to the bird's body parts
        super(PoseEstimation, self).__init__()
        self.num_keypoints = num_keypoints
        self.conv = nn.Conv2d(2048, 512, kernel_size=3, padding=1)  # Tune the layers as necessary
        self.fc = nn.Linear(512, num_keypoints * 2)  # Assuming output as (x, y) for each keypoint

    def forward(self, x):
        x = self.conv(x)
        x = nn.functional.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x.view(-1, self.num_keypoints * 2)  # (batch_size, num_keypoints, (x, y))


In [9]:
class BirdPoseEstimationModel(nn.Module):
    def __init__(self, num_keypoints=7):
        super(BirdPoseEstimationModel, self).__init__()
        self.backbone = ResNetBackbone()
        self.ssd_detector = ssd_model
        self.pose_head = PoseEstimation(num_keypoints=num_keypoints)

    def forward(self, x):
        # Step 1: Detect bird region with SSDLite
        with torch.inference_mode():
            self.ssd_detector.eval()
            detections = self.ssd_detector(x)
        
        # Placeholder for simplicity: select one detected bird box
        # (Assume post-processing is done to select the most confident detection)
        if len(detections[0]['boxes']) > 0:
            bird_box = detections[0]['boxes'][0]  # For simplicity, use the first detection
        else:
            return None  # If no detections, return None
        
        # Step 2: Crop and process the bird region for pose estimation
        bird_region = x[:, :, int(bird_box[1]):int(bird_box[3]), int(bird_box[0]):int(bird_box[2])]
        
        # Step 3: Extract features from ResNet-50 backbone
        features = self.backbone(bird_region)
        print(features.shape)
        
        # Step 4: Predict keypoints using pose estimation head
        keypoints = self.pose_head(features)
        return keypoints

In [10]:
# KP Regression Transformation

class NormalizeKeypoints:
    def __init__(self, image_width: int, image_height: int):
        self.image_width = image_width
        self.image_height = image_height

    def __call__(self, keypoints):
        # Convert to tensor if not already a tensor
        keypoints = torch.tensor(keypoints, dtype=torch.float32) if not isinstance(keypoints, torch.Tensor) else keypoints
        
        keypoints[0::2] /= self.image_width  # Normalize x-coordinates
        keypoints[1::2] /= self.image_height  # Normalize y-coordinates
        return keypoints


class DenormalizeKeypoints:
    def __init__(self, image_width: int, image_height: int):
        self.image_width = image_width
        self.image_height = image_height

    def __call__(self, keypoints):
        # Convert to tensor if not already a tensor
        keypoints = torch.tensor(keypoints, dtype=torch.float32) if not isinstance(keypoints, torch.Tensor) else keypoints
        
        keypoints[0::2] *= self.image_width  # Denormalize x-coordinates
        keypoints[1::2] *= self.image_height  # Denormalize y-coordinates
        return keypoints

In [11]:

# Define transformations (modify as needed)
img_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),  # HxW Resize images to the input size of SSDLite
    transforms.ToTensor(),  # Convert image to tensor
])

kp_transform = NormalizeKeypoints(IMG_WIDTH, IMG_HEIGHT)

# Initialize the dataset and dataloader
pose_dataset = PoseDataset(dataframe=dataset, dataset_root_folder=DATASET_ROOT, img_transform=img_transform, kp_transform=kp_transform)
dataloader = DataLoader(pose_dataset, batch_size=16, shuffle=False, num_workers=0, pin_memory=True)


In [12]:
from tqdm import tqdm  # Import tqdm for the progress bar

# Initialize the model
num_keypoints = 7
model = BirdPoseEstimationModel(num_keypoints=num_keypoints).to(device)

# Define loss functions
criterion_bbox = nn.SmoothL1Loss()  # For bounding box regression 
criterion_kp = nn.MSELoss()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adjust learning rate as needed

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    total_loss = 0

    # Wrap the dataloader with tqdm to display a progress bar
    with tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        for images, keypoints in dataloader:
            images = images.to(device)
            keypoints = keypoints.to(device)

            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)

            if outputs is None:
                continue  # Skip this batch if no detections

            # Keypoints loss
            loss_kp = criterion_kp(outputs, keypoints)  # Keypoints loss

            # Combine losses if you have both bbox and kp losses
            total_loss = loss_kp  # If only using keypoints, otherwise: total_loss = loss_bbox + loss_kp
            
            # Backward pass and optimization step
            total_loss.backward()
            optimizer.step()

            # Update progress bar
            pbar.set_postfix(loss=total_loss.item())  # Update the loss in the progress bar
            pbar.update(1)  # Move the progress bar one step forward

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss.item():.4f}')

Epoch 1/10:   0%|          | 0/73 [00:00<?, ?batch/s]

Epoch 1/10:   0%|          | 0/73 [00:02<?, ?batch/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x27648 and 512x14)