Detailed explanation:
https://medium.com/@alexppppp/how-to-train-a-custom-keypoint-detection-model-with-pytorch-d9af90e111da

GitHub repo:
https://github.com/alexppppp/keypoint_rcnn_training_pytorch


# CUHK Jockey Club AI for the Future Project

# Step 1: Set Up Environment

Libraries and Packages:

In [None]:
import json
import os
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F

import albumentations as A  # Library for augmentations

# https://github.com/pytorch/vision/tree/main/references/detection
import transforms, utils, engine, train
from utils import collate_fn
from engine import train_one_epoch, evaluate

In [None]:
def _remove_files_in(directory):
    """Delete every regular file directly inside ``directory`` (subdirectories are kept)."""
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)
            print(f"Deleted: {file_path}")


def _copy_renamed(filenames, source_directory, destination_directory, prefix):
    """Copy ``filenames`` to ``destination_directory`` as ``<prefix>_<n><ext>``, numbering from 1."""
    for number, filename in enumerate(filenames, start=1):
        extension = os.path.splitext(filename)[1]
        source_file = os.path.join(source_directory, filename)
        destination_file = os.path.join(destination_directory, f"{prefix}_{number}{extension}")
        shutil.copy2(source_file, destination_file)
        print(f"Copied: {source_file} to {destination_file}")


def rename_and_copy_files(source_directory, train_directory, test_directory, train_count=15, test_count=5):
    """
    Rename and copy files from the source directory to train and test directories.

    The regular files in the source directory are sorted by name; the first
    ``test_count`` become the test set and the next ``train_count`` become the
    training set. Existing files in both destination directories are deleted
    before copying, and the destination directories are created if missing.

    Parameters:
    source_directory (str): Path to the source directory containing the raw photos.
    train_directory (str): Path to the directory where the training images will be copied.
    test_directory (str): Path to the directory where the testing images will be copied.
    train_count (int): Number of files to be used for training. Default is 15.
    test_count (int): Number of files to be used for testing. Default is 5.

    Raises:
    ValueError: If the source directory holds fewer than train_count + test_count files.
    """
    # os.listdir returns entries in arbitrary order and includes subdirectories:
    # keep regular files only and sort them so the split is reproducible.
    files = sorted(
        name for name in os.listdir(source_directory)
        if os.path.isfile(os.path.join(source_directory, name))
    )

    # Check if there are enough files in the source directory
    if len(files) < train_count + test_count:
        raise ValueError("Not enough files in the source directory to allocate to train and test sets")

    # Select files for training and testing
    test_files = files[:test_count]
    train_files = files[test_count:train_count + test_count]

    # Make sure both destinations exist, then remove any files left from a previous run
    for directory in (train_directory, test_directory):
        os.makedirs(directory, exist_ok=True)
        _remove_files_in(directory)

    # Rename and copy the selected files into their destination directories
    _copy_renamed(train_files, source_directory, train_directory, "train")
    _copy_renamed(test_files, source_directory, test_directory, "test")

def clear_directory(directory):
    """
    Empty a directory while leaving the directory itself in place.

    Files and symbolic links are unlinked, subdirectories are removed
    recursively. A failure on one entry is reported and the remaining
    entries are still processed.

    Parameters:
    directory (str): Path to the directory to be cleared.
    """
    if not os.path.exists(directory):
        print(f"Directory does not exist: {directory}")
        return

    for entry in os.listdir(directory):
        target = os.path.join(directory, entry)
        try:
            unlinkable = os.path.islink(target) or os.path.isfile(target)
            if unlinkable:
                os.unlink(target)
                print(f"Deleted file: {target}")
            elif os.path.isdir(target):
                shutil.rmtree(target)
                print(f"Deleted directory: {target}")
        except Exception as e:
            # Best effort: report the problem and carry on with the next entry
            print(f"Error deleting {target}: {e}")

        
def train_transform():
    """Build the albumentations pipeline applied to training samples.

    The pipeline always applies a random 90-degree rotation followed by a
    random brightness/contrast change, and transforms keypoints and bounding
    boxes together with the image.
    """
    augmentations = A.Sequential(
        [
            # Random rotation of an image by 90 degrees zero or more times
            A.RandomRotate90(p=1),
            # Random change of brightness & contrast
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, brightness_by_max=True, always_apply=False, p=1),
        ],
        p=1,
    )

    # Keypoint formats: https://albumentations.ai/docs/getting_started/keypoints_augmentation/
    keypoint_settings = A.KeypointParams(format='xy')

    # Bboxes must carry labels: https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/
    bbox_settings = A.BboxParams(format='pascal_voc', label_fields=['bboxes_labels'])

    return A.Compose(
        [augmentations],
        keypoint_params=keypoint_settings,
        bbox_params=bbox_settings,
    )

In [None]:

# Clone the project repository and make it the working directory;
# the relative "my_set/..." paths used in later cells are resolved from here.
!git clone https://github.com/Camellia-k27/Simple_Keypoint_Detect.git
%cd Simple_Keypoint_Detect

# Step 2: Upload images of objects for detection as training data (Optional)



In [None]:
# Specify the paths to the source and destination directories
# Specify the paths to the source and destination directories
source_directory = "my_set/raw_photos"
train_directory = "my_set/train/images"
test_directory = "my_set/test/images"

# Call the function to rename and copy files
# (uses the default split of 15 training and 5 testing photos; files already
# present in the two destination directories are deleted first)
rename_and_copy_files(source_directory, train_directory, test_directory)

# Create zip archives of the train and test directories
# (written as train.zip and test.zip in the current working directory)
shutil.make_archive('train', 'zip', train_directory)
shutil.make_archive('test', 'zip', test_directory)


# Step 3: Delete the original train and test labels (Optional)


In [None]:
# Specify the paths to the directories to be cleared
# Specify the paths to the directories to be cleared
train_labels_directory = "my_set/train/labels"
test_labels_directory = "my_set/test/labels"

# Call the function to clear the directories
# (removes every existing label file so the new images can be labelled from scratch)
clear_directory(train_labels_directory)
clear_directory(test_labels_directory)


# Step 4: Hand-Label Training Data



In [None]:
# Expose kernel port 5000 through Colab, then start the labelling tool.
# NOTE(review): presumably labeller.py serves its UI on port 5000 — confirm.
# NOTE(review): the name `output` is rebound to the model predictions in the
# prediction cell, so this cell's import must be re-run before using it again.
from google.colab import output
output.serve_kernel_port_as_window(5000)
!python /content/Simple_Keypoint_Detect/labeller/labeller.py

# Step 5: Categorize and Process Dataset


In [None]:
class ClassDataset(Dataset):
    """Keypoint-detection dataset read from ``<root>/images`` and ``<root>/labels``.

    Images and label files are paired by their position in the two sorted
    directory listings. Each label file is JSON with a ``bboxes`` list
    ([x1, y1, x2, y2] per object) and a ``keypoints`` list
    ([[x, y, visibility], ...] per object).

    Parameters:
    root (str): Directory containing the ``images`` and ``labels`` subdirectories.
    transform: Optional albumentations-style callable taking ``image``, ``bboxes``,
        ``bboxes_labels`` and ``keypoints`` keyword arguments.
    demo (bool): When True, ``__getitem__`` also returns the untransformed image
        and target (for example, for visualization purposes).
    """

    def __init__(self, root, transform=None, demo=False):
        self.root = root
        self.transform = transform
        self.demo = demo
        self.imgs_files = sorted(os.listdir(os.path.join(root, "images")))
        self.annotations_files = sorted(os.listdir(os.path.join(root, "labels")))

    @staticmethod
    def _build_target(idx, bboxes, keypoints):
        """Pack boxes and keypoints of sample ``idx`` into a dict of torch tensors."""
        boxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if boxes.numel() == 0:
            # An image without objects yields a 1-D empty tensor; give it the
            # (0, 4) shape so the column indexing below does not fail.
            boxes = boxes.reshape(0, 4)
        return {
            "boxes": boxes,
            "labels": torch.ones(len(boxes), dtype=torch.int64),  # all objects are glue tubes
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros(len(boxes), dtype=torch.int64),
            "keypoints": torch.as_tensor(keypoints, dtype=torch.float32),
        }

    def __getitem__(self, idx):
        """Return ``(img, target)``, or ``(img, target, img_original, target_original)`` in demo mode.

        Raises:
        FileNotFoundError: If the image file cannot be read.
        ValueError: If the transform returns a different number of keypoints than it was given.
        """
        img_path = os.path.join(self.root, "images", self.imgs_files[idx])
        annotations_path = os.path.join(self.root, "labels", self.annotations_files[idx])

        img_original = cv2.imread(img_path)
        if img_original is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        height, width = img_original.shape[:2]

        with open(annotations_path) as f:
            data = json.load(f)
        bboxes_original = data['bboxes']
        keypoints_original = data['keypoints']

        # Clip boxes that reach outside the image
        for bbox in bboxes_original:
            bbox[0] = max(0, bbox[0])
            bbox[1] = max(0, bbox[1])
            bbox[2] = min(width, bbox[2])
            bbox[3] = min(height, bbox[3])

        # All objects are glue tubes
        bboxes_labels_original = ['Glue tube'] * len(bboxes_original)

        if self.transform:
            # The transform works on a flat list of [x, y] keypoints, so drop the
            # visibility flag and flatten the per-object nesting:
            # [[obj1_kp1, obj1_kp2], [obj2_kp1, obj2_kp2]] -> [obj1_kp1, obj1_kp2, obj2_kp1, obj2_kp2]
            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp]

            # Apply augmentations
            transformed = self.transform(image=img_original, bboxes=bboxes_original, bboxes_labels=bboxes_labels_original, keypoints=keypoints_original_flattened)
            img = transformed['image']
            bboxes = transformed['bboxes']

            keypoints_transformed = np.array(transformed['keypoints']).reshape(-1, 2).tolist()
            if len(keypoints_transformed) != len(keypoints_original_flattened):
                # Regrouping by position would silently attach keypoints to the wrong object
                raise ValueError(
                    f"Transform returned {len(keypoints_transformed)} keypoints for {img_path}, "
                    f"expected {len(keypoints_original_flattened)}"
                )

            # Regroup the transformed coordinates per object (any number of keypoints
            # per object) and re-attach each keypoint's original visibility flag.
            coordinates = iter(keypoints_transformed)
            keypoints = [
                [next(coordinates) + [kp[2]] for kp in obj]
                for obj in keypoints_original
            ]
        else:
            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original

        # Convert everything into a torch tensor
        target = self._build_target(idx, bboxes, keypoints)
        img = F.to_tensor(img)

        if not self.demo:
            return img, target

        # The untransformed sample is only needed (and only built) in demo mode
        target_original = self._build_target(idx, bboxes_original, keypoints_original)
        img_original = F.to_tensor(img_original)
        return img, target, img_original, target_original

    def __len__(self):
        """Return the number of image files found under ``<root>/images``."""
        return len(self.imgs_files)

# Step 6: Visualize Random Labeled Sample 



In [None]:
# Dataset folder to sample from for the visualization below.
# NOTE(review): the variable is named ..._TRAIN but points at 'my_set/test' — confirm this is intended.
KEYPOINTS_FOLDER_TRAIN = 'my_set/test'
#KEYPOINTS_FOLDER_TRAIN = 'my_set/train_blocked'
#KEYPOINTS_FOLDER_TRAIN = 'my_set/test_blurred'
# demo=True makes every sample carry the untransformed image and target as well
dataset = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=True)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# Draw one random sample
iterator = iter(data_loader)
batch = next(iterator)

# Indices follow the sample layout (img, target, img_original, target_original);
# presumably collate_fn keeps that order — verify against utils.collate_fn.
print("Original targets:\n", batch[3], "\n\n")
print("Transformed targets:\n", batch[1])

# Maps a keypoint's position within an object to the label drawn next to it
keypoints_classes_ids2names = {0: 'mouse', 1: 'pen'} 
# Adjust the names to your dataset, e.g. keypoints_classes_ids2names = {0: 'cat', 1: 'dog'}

def _draw_annotations(image, bboxes, keypoints, box_thickness):
    """Return a copy of ``image`` with green boxes and labelled keypoints drawn on it."""
    # Copy once so the caller's array is never modified; ndarray.copy() also
    # returns a C-ordered array, which the OpenCV drawing functions require.
    canvas = image.copy()

    for bbox in bboxes:
        start_point = (bbox[0], bbox[1])
        end_point = (bbox[2], bbox[3])
        canvas = cv2.rectangle(canvas, start_point, end_point, (0, 255, 0), box_thickness)

    for kps in keypoints:
        for idx, kp in enumerate(kps):
            # First keypoint of each object is red, all following ones are blue
            color = (255, 0, 0) if idx == 0 else (0, 0, 255)
            # Fall back to the index when no name is configured for this keypoint
            label = " " + keypoints_classes_ids2names.get(idx, str(idx))
            canvas = cv2.circle(canvas, tuple(kp), 2, color, 2)
            canvas = cv2.putText(canvas, label, tuple(kp), cv2.FONT_HERSHEY_PLAIN, 1, color, 1, cv2.LINE_AA)

    return canvas


def visualize(image, bboxes, keypoints, image_original=None, bboxes_original=None, keypoints_original=None):
    """Plot an annotated image, optionally next to its annotated original.

    Parameters:
    image: Image array to annotate (drawn on a copy).
    bboxes: Boxes as [x1, y1, x2, y2] integer lists.
    keypoints: Per-object lists of [x, y] integer keypoints; labels come from
        the module-level ``keypoints_classes_ids2names`` mapping.
    image_original, bboxes_original, keypoints_original: Optional untransformed
        counterparts. Only when all three are given is a side-by-side
        "Original image" / "Transformed image" figure drawn; otherwise the
        annotated ``image`` is shown alone.
    """
    fontsize = 13

    image = _draw_annotations(image, bboxes, keypoints, box_thickness=1)

    if image_original is None or bboxes_original is None or keypoints_original is None:
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        return

    image_original = _draw_annotations(image_original, bboxes_original, keypoints_original, box_thickness=2)

    f, ax = plt.subplots(1, 2, figsize=(40, 20))

    ax[0].imshow(image_original)
    ax[0].set_title('Original image', fontsize=fontsize)

    ax[1].imshow(image)
    ax[1].set_title('Transformed image', fontsize=fontsize)

# Reorder the transformed image's axes with permute(1, 2, 0), scale by 255 and cast to uint8 for drawing
# (assumes a channel-first float image in [0, 1], as returned by the dataset — confirm)
image = (batch[0][0].permute(1,2,0).numpy() * 255).astype(np.uint8)
bboxes = batch[1][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()

# Keep only the x, y coordinates of every keypoint (drop the visibility flag)
keypoints = []
for kps in batch[1][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():
    keypoints.append([kp[:2] for kp in kps])

# Same conversion for the untransformed image and its annotations
image_original = (batch[2][0].permute(1,2,0).numpy() * 255).astype(np.uint8)
bboxes_original = batch[3][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()

keypoints_original = []
for kps in batch[3][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():
    keypoints_original.append([kp[:2] for kp in kps])

# Show original and transformed sample side by side
visualize(image, bboxes, keypoints, image_original, bboxes_original, keypoints_original)

# Step 7: Fine-Tune Model


In [None]:
def get_model(num_keypoints, weights_path=None):
    """Create a Keypoint R-CNN (ResNet-50 FPN) model for a single object class.

    Parameters:
    num_keypoints (int): Number of keypoints predicted per object.
    weights_path (str, optional): Path to a saved ``state_dict`` to load into the model.

    Returns:
    The constructed model, with the weights from ``weights_path`` loaded when given.
    """
    anchor_generator = AnchorGenerator(sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0))
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,
                                                                   pretrained_backbone=True,
                                                                   num_keypoints=num_keypoints,
                                                                   num_classes = 2, # Background is the first class, object is the second class
                                                                   rpn_anchor_generator=anchor_generator)

    if weights_path:
        # Load onto the CPU first so weights saved on a GPU also load on a
        # CPU-only runtime; the caller moves the model to its device afterwards.
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)

    return model


##########cells


# Train on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

KEYPOINTS_FOLDER_TRAIN = 'my_set/train'
 # Change the dataset to others (train_blocked/ train_blurred) to see the changes!
KEYPOINTS_FOLDER_TEST = 'my_set/test'
 # Change the dataset to others (test_blocked/ test_blurred) to see the changes!
# Augmentations are applied to the training set only; the test set is left untouched
dataset_train = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=False)
dataset_test = ClassDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)

data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Two keypoints per object; no weights_path, so no saved fine-tuned weights are loaded
model = get_model(num_keypoints = 2)

model.to(device)

# Optimize only the parameters that require gradients
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.0001)
# NOTE(review): the learning rate is multiplied by 0.3 every 2 epochs; over 50 epochs
# that is 25 decays (0.3**25 is about 1e-13), so later epochs barely update the
# weights — confirm this schedule is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.3)
num_epochs = 50

# One training pass, one scheduler step and one evaluation on the test set per epoch
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=1000)
    lr_scheduler.step()
    evaluate(model, data_loader_test, device)

# Save model weights after training
torch.save(model.state_dict(), 'keypointsrcnn_weights.pth')

# Step 8: Download Pre-trained Model (Optional)


If you have problems fine-tuning the model, you can download an existing pre-trained model and still make predictions.

In [None]:
# Download the pre-trained weights file (pretrained_weights.pth) into the current working directory
!wget https://huggingface.co/XBLUECATX/KeypointRCNN_Pretrained/resolve/main/pretrained_weights.pth

def load_model(num_keypoints, weights_path=None):
    """Build a Keypoint R-CNN (ResNet-50 FPN) model and optionally load saved weights.

    Parameters:
    num_keypoints (int): Number of keypoints predicted per object.
    weights_path (str, optional): Path to a saved ``state_dict``; it is loaded
        onto the CPU before being copied into the model.

    Returns:
    The constructed model.
    """
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=True,
        num_keypoints=num_keypoints,
        num_classes=2,  # Background is the first class, object is the second class
        rpn_anchor_generator=AnchorGenerator(
            sizes=(32, 64, 128, 256, 512),
            aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0),
        ),
    )

    # Nothing to load: hand back the freshly built model
    if not weights_path:
        return model

    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return model

# Replace `model` with a two-keypoint model initialised from the downloaded weights
model = load_model(num_keypoints = 2,weights_path='pretrained_weights.pth')




# Step 9: Visualize Model Predictions


In [None]:
KEYPOINTS_FOLDER_TEST = 'my_set/test'
 # Change the dataset to others (test_blocked/ test_blurred) to see the changes!
# No augmentation and no shuffling: test images are visited in directory order
dataset_test = ClassDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)


# Re-running the next cell advances this iterator to the following test image
iterator = iter(data_loader_test)

In [None]:
# Take the next test image and move it to the device used for inference
images, targets = next(iterator)
images = list(image.to(device) for image in images)

# Run the model in evaluation mode without tracking gradients
# NOTE(review): `output` rebinds the name imported from google.colab in the labelling cell.
with torch.no_grad():
    model.to(device)
    model.eval()
    output = model(images)

#print("Predictions: \n", output)

############cells

# Reorder the image axes with permute(1, 2, 0), scale by 255 and cast to uint8 for drawing
image = (images[0].permute(1,2,0).detach().cpu().numpy() * 255).astype(np.uint8)
scores = output[0]['scores'].detach().cpu().numpy()
print(scores)

high_scores_idxs = np.where(scores > 0.5)[0].tolist() # Indexes of boxes with scores > 0.5
post_nms_idxs = torchvision.ops.nms(output[0]['boxes'][high_scores_idxs], output[0]['scores'][high_scores_idxs], 0.3).cpu().numpy() # Indexes of boxes left after applying NMS (iou_threshold=0.3)

# Below, in output[0]['keypoints'][high_scores_idxs][post_nms_idxs] and output[0]['boxes'][high_scores_idxs][post_nms_idxs]
# Firstly, we choose only those objects, which have score above predefined threshold. This is done with choosing elements with [high_scores_idxs] indexes
# Secondly, we choose only those objects, which are left after NMS is applied. This is done with choosing elements with [post_nms_idxs] indexes

# Only the first surviving detection is kept: the loop stops right after its first append.
# NOTE(review): this is the "best option" only if nms returns indices sorted by decreasing score — confirm.
keypoints = []
for kps in output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy():
    keypoints.append([list(map(int, kp[:2])) for kp in kps])
    if len(keypoints) != 0: #shows only best option
      break

# Same selection for the bounding boxes: keep the first surviving box only
bboxes = []
for bbox in output[0]['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy():
    bboxes.append(list(map(int, bbox.tolist())))
    if len(bboxes) != 0: #shows only best option
      break

print(bboxes)

# Draw the selected box and keypoints on the test image
visualize(image, bboxes, keypoints)