In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from PIL import Image
from rich import print
import os

In [2]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms

In [3]:
# Device-Agnostic
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"""
Device: {device}
Device CUDNN enabled: {torch.backends.cudnn.enabled}
""")

In [4]:
IMG_WIDTH = 320
IMG_HEIGHT = 240
NUM_KEYPOINTS = 7
NUM_BATCH = 16

MODEL_PATH = "../models/pose_estimation"
DATASET_ROOT = "../datasets"
DATASET_FILE = DATASET_ROOT + "/preprocessed_dataset.csv"

dataset = pd.read_csv(DATASET_FILE)

# SKELETON = []

In [5]:
dataset.head()

Unnamed: 0,behavior,image_id,image_file,head_x,head_y,beak_base_x,beak_base_y,beak_tip_x,beak_tip_y,neck_x,neck_y,body1_x,body1_y,body2_x,body2_y,tail_base_x,tail_base_y
0,nesting,n_001,59-20151230231705-00.jpg,19.234443,92.112384,41.246921,134.089668,39.711167,149.447212,61.211727,86.993203,79.640779,136.137341,123.153818,131.530078,176.393301,7.133978
1,nesting,n_001,59-20151230231706-00.jpg,12.579507,49.111263,43.806512,116.684453,55.580628,136.137341,70.426253,83.921694,83.736124,139.20885,145.166297,131.01816,164.619185,5.086305
2,nesting,n_001,59-20151230231714-00.jpg,24.865542,28.634538,38.175412,83.921694,35.615822,99.279237,61.211727,54.742362,86.295715,137.161177,139.535198,132.553914,193.286599,5.086305
3,nesting,n_001,59-20151230231720-00.jpg,23.841705,105.934174,38.175412,120.267882,42.270757,127.946653,49.949529,112.58911,73.497762,132.553916,111.379702,121.291718,126.225327,-0.544792
4,nesting,n_001,59-20151230231721-00.jpg,26.913214,101.83883,37.151576,117.708291,44.31843,125.387062,51.997201,109.005683,66.842826,139.720769,104.212848,131.018162,141.070952,3.550553


In [6]:
class PoseDataset(Dataset):
    def __init__(self, dataframe, dataset_root_folder, img_transform=None, kp_transform=None):
        self.annotations = dataframe  # Load the pandas DataFrame directly
        self.dataset_root_folder = dataset_root_folder  # Root folder for the dataset
        self.img_transform = img_transform
        self.kp_transform = kp_transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # Construct the image path from behavior, image_id, and image_file columns
        behavior = self.annotations.iloc[idx]['behavior']
        image_id = self.annotations.iloc[idx]['image_id']
        image_file = self.annotations.iloc[idx]['image_file']
        
        # Create the full image path
        img_path = os.path.join(self.dataset_root_folder, behavior, image_id, image_file)
        
        # Load and process the image
        image = Image.open(img_path).convert("RGB")
        
        # Extract the keypoints (head_x, head_y, ..., body2_x, body2_y) as numpy array
        keypoints = self.annotations.iloc[idx, 3:].values.astype('float32')
        
        if self.img_transform:
            image = self.img_transform(image)

        if self.kp_transform:
            keypoints = self.kp_transform(keypoints)
        

        return image, keypoints


In [7]:
## Q. Why nn.SiLU (Swish Activation) [ f(x) = x * sigmoid(x) ]
## A. Allows smooth, non-monotonic behavior, enhancing gradient flow

# ARCHITECTURE = "efficientnet_b0_batch_norm2d_swish"

# class BirdPoseModel(nn.Module):
#     def __init__(self, num_keypoints: int):
#         super(BirdPoseModel, self).__init__()
        
#         # Efficient backbone (EfficientNet-B0) for feature extraction
#         efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
#         self.backbone = nn.Sequential(*list(efficientnet.features))
        
#         # Reduce feature map channel sizes
#         self.conv_layers = nn.Sequential(
#             nn.Conv2d(1280, 128, kernel_size=3, padding=1),  # From EfficientNet-B0 last layer (1280 channels)
#             nn.BatchNorm2d(128),
#             nn.SiLU(),  # Swish Activation
#             nn.Conv2d(128, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.SiLU()
#         )
        
#         # Global average pooling
#         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        
#         # Fully connected layer for keypoints prediction
#         self.fc = nn.Sequential(
#             nn.Dropout(0.3),  # Add dropout for regularization
#             nn.Linear(64, num_keypoints * 2),
#         )

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # Feature extraction
#         x = self.backbone(x)
        
#         # Channel size reduction
#         x = self.conv_layers(x)
        
#         # Global average pooling
#         x = self.global_avg_pool(x)  # Shape: (batch_size, 64, 1, 1)
#         x = torch.flatten(x, 1)  # Shape: (batch_size, 64)
        
#         # Fully connected layer for keypoints
#         x = self.fc(x)  # Shape: (batch_size, num_keypoints * 2)
#         return x


In [8]:
# ARCHITECTURE = "first_resnet50"

# class BirdPoseModel(nn.Module):
#     def __init__(self, num_keypoints: int):
#         super(BirdPoseModel, self).__init__()
#         resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        
#         # Use all layers except the last two
#         self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        
#         # Reduce the output channel size
#         self.conv = nn.Conv2d(2048, 512, kernel_size=3, padding=1)
        
#         # Global average pooling for spatial dimensions
#         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

#         # Fully connected layer for final predictions
#         self.fc = nn.Linear(512, num_keypoints * 2)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # Pass through the backbone
#         x = self.backbone(x)
        
#         # Apply convolutional layer
#         x = self.conv(x)
#         x = nn.ReLU()(x)
        
#         # Global average pooling
#         x = self.global_avg_pool(x)  # Shape: (batch_size, 512, 1, 1)
#         x = torch.flatten(x, 1)  # Shape: (batch_size, 512)
        
#         # Fully connected layer for keypoints
#         x = self.fc(x)  # Shape: (batch_size, num_keypoints * 2)
        
#         # Reshape to (batch_size, num_keypoints * 2)
#         return x.view(-1, NUM_KEYPOINTS * 2)

In [9]:
ARCHITECTURE = "resnet50_batch_norm2d_relu"

class BirdPoseModel(nn.Module):
    def __init__(self, num_keypoints: int):
        super(BirdPoseModel, self).__init__()
        
        # Load ResNet-50 backbone and remove the last two layers
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        
        # Reduce the channel size progressively
        self.conv_layers = nn.Sequential(
            nn.Conv2d(2048, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # Fully connected layer for keypoints
        self.fc = nn.Linear(64, num_keypoints * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass input through the ResNet backbone
        x = self.backbone(x)
        
        # Apply the convolutional layers
        x = self.conv_layers(x)
        
        # Global average pooling
        x = self.global_avg_pool(x)  # Shape: (batch_size, 64, 1, 1)
        x = torch.flatten(x, 1)  # Shape: (batch_size, 64)
        
        # Fully connected layer for keypoint prediction
        x = self.fc(x)  # Shape: (batch_size, num_keypoints * 2)
        
        return x


In [10]:
# ARCHITECTURE = "resnet50_batch_norm2d_swish"

# class BirdPoseModel(nn.Module):
#     def __init__(self, num_keypoints: int):
#         super(BirdPoseModel, self).__init__()
        
#         # Load ResNet-50 backbone and remove the last two layers
#         resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
#         self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        
#         # Reduce the channel size progressively with BatchNorm and SiLU
#         self.conv_layers = nn.Sequential(
#             nn.Conv2d(2048, 1024, kernel_size=3, padding=1),
#             nn.BatchNorm2d(1024),
#             nn.SiLU(),
            
#             nn.Conv2d(1024, 512, kernel_size=3, padding=1),
#             nn.BatchNorm2d(512),
#             nn.SiLU(),
            
#             nn.Conv2d(512, 256, kernel_size=3, padding=1),
#             nn.BatchNorm2d(256),
#             nn.SiLU(),
            
#             nn.Conv2d(256, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.SiLU()
#         )
        
#         # Global average pooling
#         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        
#         # Fully connected layer for keypoints
#         self.fc = nn.Linear(64, num_keypoints * 2)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # Pass input through the ResNet backbone
#         x = self.backbone(x)
        
#         # Apply the convolutional layers
#         x = self.conv_layers(x)
        
#         # Global average pooling
#         x = self.global_avg_pool(x)  # Shape: (batch_size, 64, 1, 1)
#         x = torch.flatten(x, 1)  # Shape: (batch_size, 64)
        
#         # Fully connected layer for keypoint prediction
#         x = self.fc(x)  # Shape: (batch_size, num_keypoints * 2)
        
#         return x


In [11]:
class NormalizeKeypoints:
    def __init__(self, image_width: int, image_height: int):
        self.image_width = image_width
        self.image_height = image_height

    def __call__(self, keypoints):
        # Convert to tensor if not already a tensor
        keypoints = torch.tensor(keypoints, dtype=torch.float32) if not isinstance(keypoints, torch.Tensor) else keypoints
        
        keypoints[0::2] /= self.image_width  # Normalize x-coordinates
        keypoints[1::2] /= self.image_height  # Normalize y-coordinates
        return keypoints


class DenormalizeKeypoints:
    def __init__(self, image_width: int, image_height: int):
        self.image_width = image_width
        self.image_height = image_height

    def __call__(self, keypoints):
        # Convert to tensor if not already a tensor
        keypoints = torch.tensor(keypoints, dtype=torch.float32) if not isinstance(keypoints, torch.Tensor) else keypoints.clone()
        
        # Perform non-in-place operations
        denormalized = keypoints.clone()
        denormalized[0::2] = denormalized[0::2] * self.image_width  # Denormalize x-coordinates
        denormalized[1::2] = denormalized[1::2] * self.image_height  # Denormalize y-coordinates
        return denormalized


In [12]:

# Image Transformations Defination
img_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

kp_transform = NormalizeKeypoints(IMG_WIDTH, IMG_HEIGHT)

# Initialize the dataset and dataloader
pose_dataset = PoseDataset(dataframe=dataset, dataset_root_folder=DATASET_ROOT, img_transform=img_transform, kp_transform=kp_transform)
dataloader = DataLoader(pose_dataset, batch_size=NUM_BATCH, shuffle=True, num_workers=0, pin_memory=True)

In [13]:
model = BirdPoseModel(NUM_KEYPOINTS).to(device)

In [14]:
# Why scaling?: 
# If you want to accumulate the loss over multiple batches,
# you need to ensure that each batch contributes the correct amount to the total loss,
# regardless of the batch size. Multiplying by images.size(0) gives the total loss for that batch instead of just the average.

In [15]:
from tqdm import tqdm  # Import tqdm for the progress bar

EPOCHS = 100
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Use torch.inference_mode() for validation
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    with tqdm(dataloader, desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='batch') as pbar:
        for batch in dataloader:
            images, keypoints = batch[0].to(device), batch[1].to(device)

            optimizer.zero_grad()  # Clear gradients
            outputs = model(images)
            loss = criterion(outputs, keypoints)
            
            loss.backward()  # Backpropagation
            optimizer.step()  # Update parameters
            
            batch_loss = loss.item() * images.size(0)  # Accumulate batch loss (scaled)
            running_loss += batch_loss
            
            pbar.update(1)
            pbar.set_postfix(batch_loss=batch_loss) 

    # Normalize and print epoch loss
    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Epoch Loss: {epoch_loss:.4f}")

Epoch 1/100:   0%|          | 0/73 [00:00<?, ?batch/s]

Epoch 1/100: 100%|██████████| 73/73 [00:38<00:00,  1.92batch/s, batch_loss=0.32] 


Epoch 2/100: 100%|██████████| 73/73 [00:40<00:00,  1.79batch/s, batch_loss=0.091] 


Epoch 3/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0424]


Epoch 4/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0364]


Epoch 5/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.045] 


Epoch 6/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0536]


Epoch 7/100: 100%|██████████| 73/73 [00:40<00:00,  1.79batch/s, batch_loss=0.112] 


Epoch 8/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0506]


Epoch 9/100: 100%|██████████| 73/73 [00:40<00:00,  1.78batch/s, batch_loss=0.0383]


Epoch 10/100: 100%|██████████| 73/73 [00:40<00:00,  1.78batch/s, batch_loss=0.036] 


Epoch 11/100: 100%|██████████| 73/73 [00:41<00:00,  1.77batch/s, batch_loss=0.0472]


Epoch 12/100: 100%|██████████| 73/73 [00:41<00:00,  1.77batch/s, batch_loss=0.0288]


Epoch 13/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0235]


Epoch 14/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.0165]


Epoch 15/100: 100%|██████████| 73/73 [00:39<00:00,  1.84batch/s, batch_loss=0.0498]


Epoch 16/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.0128]


Epoch 17/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0268]


Epoch 18/100: 100%|██████████| 73/73 [00:40<00:00,  1.79batch/s, batch_loss=0.0461]


Epoch 19/100: 100%|██████████| 73/73 [00:41<00:00,  1.77batch/s, batch_loss=0.0149]


Epoch 20/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0215]


Epoch 21/100: 100%|██████████| 73/73 [00:39<00:00,  1.86batch/s, batch_loss=0.0217]


Epoch 22/100: 100%|██████████| 73/73 [00:39<00:00,  1.86batch/s, batch_loss=0.0288]


Epoch 23/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.00989]


Epoch 24/100: 100%|██████████| 73/73 [00:39<00:00,  1.84batch/s, batch_loss=0.0216] 


Epoch 25/100: 100%|██████████| 73/73 [00:39<00:00,  1.84batch/s, batch_loss=0.0179]


Epoch 26/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.0145]


Epoch 27/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.0528] 


Epoch 28/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.0151] 


Epoch 29/100: 100%|██████████| 73/73 [00:40<00:00,  1.79batch/s, batch_loss=0.0109] 


Epoch 30/100: 100%|██████████| 73/73 [00:41<00:00,  1.78batch/s, batch_loss=0.012]  


Epoch 31/100: 100%|██████████| 73/73 [00:41<00:00,  1.77batch/s, batch_loss=0.0136] 


Epoch 32/100: 100%|██████████| 73/73 [00:41<00:00,  1.77batch/s, batch_loss=0.0166] 


Epoch 33/100: 100%|██████████| 73/73 [00:41<00:00,  1.76batch/s, batch_loss=0.01]   


Epoch 34/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.011]  


Epoch 35/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.0108] 


Epoch 36/100: 100%|██████████| 73/73 [00:39<00:00,  1.86batch/s, batch_loss=0.0204] 


Epoch 37/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.0143] 


Epoch 38/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.0075] 


Epoch 39/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.00817]


Epoch 40/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.00819]


Epoch 41/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.00576]


Epoch 42/100: 100%|██████████| 73/73 [00:39<00:00,  1.84batch/s, batch_loss=0.00731]


Epoch 43/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.0108] 


Epoch 44/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.00869]


Epoch 45/100: 100%|██████████| 73/73 [00:37<00:00,  1.97batch/s, batch_loss=0.00859]


Epoch 46/100: 100%|██████████| 73/73 [00:39<00:00,  1.85batch/s, batch_loss=0.00632]


Epoch 47/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.00708]


Epoch 48/100: 100%|██████████| 73/73 [00:37<00:00,  1.95batch/s, batch_loss=0.00732]


Epoch 49/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.0077] 


Epoch 50/100: 100%|██████████| 73/73 [00:32<00:00,  2.24batch/s, batch_loss=0.00542]


Epoch 51/100: 100%|██████████| 73/73 [00:23<00:00,  3.08batch/s, batch_loss=0.00802]


Epoch 52/100: 100%|██████████| 73/73 [00:23<00:00,  3.06batch/s, batch_loss=0.00642]


Epoch 53/100: 100%|██████████| 73/73 [00:23<00:00,  3.07batch/s, batch_loss=0.00801]


Epoch 54/100: 100%|██████████| 73/73 [00:27<00:00,  2.70batch/s, batch_loss=0.0123] 


Epoch 55/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0974]


Epoch 56/100: 100%|██████████| 73/73 [00:35<00:00,  2.03batch/s, batch_loss=0.142]


Epoch 57/100: 100%|██████████| 73/73 [00:40<00:00,  1.79batch/s, batch_loss=0.167] 


Epoch 58/100: 100%|██████████| 73/73 [00:40<00:00,  1.81batch/s, batch_loss=0.119] 


Epoch 59/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0616]


Epoch 60/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0673]


Epoch 61/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.054] 


Epoch 62/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0616]


Epoch 63/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0302]


Epoch 64/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0532]


Epoch 65/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0375]


Epoch 66/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0361]


Epoch 67/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0284]


Epoch 68/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0236]


Epoch 69/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0345]


Epoch 70/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0331]


Epoch 71/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0342]


Epoch 72/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0201]


Epoch 73/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0169]


Epoch 74/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0162]


Epoch 75/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0217]


Epoch 76/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0304] 


Epoch 77/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0251]


Epoch 78/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0349] 


Epoch 79/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.00753]


Epoch 80/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0202] 


Epoch 81/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0078] 


Epoch 82/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0185] 


Epoch 83/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.026]  


Epoch 84/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0252] 


Epoch 85/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0278] 


Epoch 86/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0185] 


Epoch 87/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0133] 


Epoch 88/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0093] 


Epoch 89/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0122] 


Epoch 90/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0141] 


Epoch 91/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0337] 


Epoch 92/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0195] 


Epoch 93/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.0146] 


Epoch 94/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0124] 


Epoch 95/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.00882]


Epoch 96/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.00479]


Epoch 97/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.00999]


Epoch 98/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.00807]


Epoch 99/100: 100%|██████████| 73/73 [00:40<00:00,  1.82batch/s, batch_loss=0.0118] 


Epoch 100/100: 100%|██████████| 73/73 [00:39<00:00,  1.83batch/s, batch_loss=0.00503]


In [16]:
model.eval()
with torch.inference_mode():

    for images, keypoints in dataloader:
        predictions = model(images.to(device))
        print(keypoints.shape)
        break # Single batch of dataset

kp_denormalize = DenormalizeKeypoints(IMG_HEIGHT, IMG_WIDTH)

for pred, exp in zip(predictions, keypoints):
    print(f"Prediction: {kp_denormalize(pred)}")
    print(f"Expected: {kp_denormalize(exp)}")
    print(f"{'-'* 100}")
    break # Single Dataset

In [17]:
from datetime import datetime, timezone

# Ensure the directory exists
os.makedirs(MODEL_PATH, exist_ok=True)

# Get the current UTC date and time in ISO format
current_time_utc = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M")

# Combine the directory and file name
# sp: snow_petrel, pe: pose_estimation
file_name = f"sp_pe_{ARCHITECTURE}_{current_time_utc}.pth"
full_path = os.path.join(MODEL_PATH, file_name)

# Save the model state_dict
torch.save(model.state_dict(), full_path)

print(f"Model saved as {full_path}")

In [64]:
def calculate_head_size(keypoints):
    """
    Calculate head size for a batch of flattened keypoints.

    Args:
        keypoints (torch.Tensor): A tensor of shape (batch_size, num_keypoints * 2),
                                  where each row contains flattened 2D coordinates.
                                  Keypoints are arranged as:
    Returns:
        torch.Tensor: A tensor of shape (batch_size,) containing the head size for each sample.
    """
    # Extract batch size and number of keypoints
    batch_size = keypoints.size(0)  # First dimension is the batch size
    num_keypoints = keypoints.size(1) // 2  # Number of keypoints

    # Reshape to (batch_size, num_keypoints, 2)
    keypoints = keypoints.view(batch_size, num_keypoints, 2)

    # Extract head and beak_tip keypoints
    head = keypoints[:, 0, :]       # Shape: (batch_size, 2)
    beak_tip = keypoints[:, 2, :]   # Shape: (batch_size, 2)

    # Calculate Euclidean distance between head and beak_tip
    head_size = torch.norm(head - beak_tip, p=2, dim=1)  # Shape: (batch_size,)

    return head_size


In [65]:
# When the threshold is 0.2 in the PCKh (Percentage of Correct Keypoints with Head Normalization) calculation,
# it means that a predicted keypoint is considered correct
# if the Euclidean distance between the predicted and ground truth keypoints is less than 20% of the head size.

def pckh(predictions, ground_truth, head_size, threshold=0.2):
    """
    Calculate PCKh (Percentage of Correct Keypoints with Head Normalization) for a batch of predictions.

    Args:
        predictions (Tensor): Predicted keypoints, shape (batch_size, num_keypoints * 2)
        ground_truth (Tensor): Ground truth keypoints, shape (batch_size, num_keypoints * 2)
        head_size (Tensor): Normalizing head size for each sample, shape (batch_size,)
        threshold (float): Normalized distance threshold (percentage of head size)

    Returns:
        float: PCKh metric as a percentage of correct keypoints
    """
    batch_size, num_flattened = predictions.size()
    num_keypoints = num_flattened // 2  # Derive number of keypoints
    
    # Reshape flattened predictions and ground truth to (batch_size, num_keypoints, 2)
    predictions = predictions.view(batch_size, num_keypoints, 2)
    ground_truth = ground_truth.view(batch_size, num_keypoints, 2)
    
    # Calculate Euclidean distance between predicted and ground truth keypoints
    distance = torch.norm(predictions - ground_truth, p=2, dim=2)  # shape: (batch_size, num_keypoints)
    
    # Normalize by head size for PCKh
    normalized_distance = distance / head_size.unsqueeze(1)  # shape: (batch_size, num_keypoints)
    
    # Calculate PCKh: Count keypoints that are within the threshold
    correct_keypoints = (normalized_distance < threshold).float()  # shape: (batch_size, num_keypoints)
    
    # Compute the percentage of correct keypoints
    pckh = correct_keypoints.sum() / (batch_size * num_keypoints) * 100
    
    return pckh.item()


def pe_accuracy(model, dataloader, device):
    """
    Calculate PCKh accuracy for the entire dataset.

    Args:
        model (nn.Module): The pose estimation model
        dataloader (DataLoader): DataLoader providing the dataset
        device (str): Device to run the model on (either 'cuda' or 'cpu')

    Returns:
        float: Average PCKh for the dataset
    """
    model.eval()  # Set model to evaluation mode
    total_pckh = 0.0
    total_samples = 0

    with torch.inference_mode():  # Disable gradient calculation for evaluation
        for images, keypoints in dataloader:
            images = images.to(device)
            keypoints = keypoints.to(device)
            head_sizes = calculate_head_size(keypoints).to(device)
            
            # Predict keypoints
            outputs = model(images)
            
            # Calculate PCKh for the current batch
            batch_pckh = pckh(outputs, keypoints, head_sizes)
            
            total_pckh += batch_pckh * images.size(0)
            total_samples += images.size(0)

    # Calculate the average PCKh for the dataset
    average_pckh = total_pckh / total_samples
    return average_pckh

In [None]:
# List all model files
model_files = [f for f in os.listdir(MODEL_PATH) if f.endswith('.pth')]

if model_files:
    # Sort the files by timestamp in descending order (newest first)
    latest_model_file = sorted(model_files, reverse=True)[0]

    # Get the full path of the latest model
    latest_model_path = os.path.join(MODEL_PATH, latest_model_file)

    # Load the model
    model = BirdPoseModel(NUM_KEYPOINTS)
    model.load_state_dict(torch.load(full_path, map_location=device, weights_only=True)) 
    model.to(device) 

    average_pckh = pe_accuracy(model, dataloader, device)
    print(f"Average PCKh: {average_pckh:.2f}%")
else:
    print("No model files found in the directory.")

In [None]:
# The Average PCKh (Percentage of Correct Keypoints with Head Normalization) being 94% means that on average,
# only 94% of the predicted keypoints are within the specified "threshold=0.2" (e.g., 20% of the head size) across the dataset.