#### PyTorch Dataset

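   A dataset that reads the coco_like JSON annotation list, loads each referenced image file from disk, and returns an (image, target) pair. For training, pass the detection-aware transforms defined below: plain torchvision image transforms do not update the boxes and keypoints in the target.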

In [1]:
# dataset_saffron.py
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
import os

class SaffronKeypointDataset(Dataset):
    """Keypoint-detection dataset backed by a coco_like JSON list of records.

    Each record must provide ``file_name`` (relative to ``images_dir``),
    ``boxes`` (list of ``[x1, y1, x2, y2]``), ``labels`` and ``image_id``;
    ``keypoints`` (one ``[x, y, v]`` triple per box), ``area`` and
    ``iscrowd`` are optional. ``__getitem__`` returns ``(image, target)``
    where ``target`` is the dict consumed by torchvision's Keypoint R-CNN.
    """

    def __init__(self, images_dir, annotation_json, transforms=None):
        """
        Args:
            images_dir: directory that each record's ``file_name`` is relative to.
            annotation_json: path to the coco_like list saved earlier.
            transforms: optional callable ``(image, target) -> (image, target)``.
                Without it ``__getitem__`` returns a PIL image, which the
                model cannot consume directly.
        """
        with open(annotation_json, "r", encoding="utf-8") as f:
            self.records = json.load(f)
        self.images_dir = images_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.records)

    @staticmethod
    def _build_target(rec):
        """Convert one annotation record into the target dict of tensors."""
        # reshape keeps boxes at (N, 4) even for images without objects;
        # as_tensor([]) alone would give shape (0,), which breaks box indexing.
        boxes = torch.as_tensor(rec['boxes'], dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(rec['labels'], dtype=torch.int64)
        num_objs = boxes.shape[0]

        # keypoints expected shape (N, K, 3) -> K=1
        kps = rec.get('keypoints')
        if not kps:
            # No keypoint annotations: visibility 0 marks every point as unlabeled.
            keypoints = torch.zeros((num_objs, 1, 3), dtype=torch.float32)
        else:
            # convert list of [x, y, v] to (N, 1, 3)
            keypoints = torch.as_tensor(kps, dtype=torch.float32).reshape(-1, 1, 3)

        image_id = torch.tensor([rec['image_id']])
        if 'area' in rec:
            area = torch.as_tensor(rec['area'], dtype=torch.float32)
        else:
            # Box area (x2 - x1) * (y2 - y1) when the record does not store it.
            area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        iscrowd = torch.as_tensor(rec.get('iscrowd', [0] * num_objs), dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "keypoints": keypoints,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    def __getitem__(self, idx):
        rec = self.records[idx]
        img_path = os.path.join(self.images_dir, rec['file_name'])
        # The context manager releases the file handle; convert() returns a
        # new, fully loaded image that no longer depends on it.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        target = self._build_target(rec)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

# Quick sanity-check instance. No transforms are passed, so indexing it
# yields PIL images rather than tensors.
dataset = SaffronKeypointDataset(
    "Dataset_txt/image",
    "annotations_coco_like.json",
)


#### Transforms (using torchvision utilities)

    Keypoint models need transforms that update the boxes and keypoints together with the image. Here is a minimal set modeled after torchvision's references/detection/transforms.py:

In [2]:
# transforms.py
import torchvision
from PIL import Image
import random
import torch
import torchvision.transforms.functional as F

class Compose:
    """Apply a sequence of ``(image, target) -> (image, target)`` transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        img_out, tgt_out = image, target
        for step in self.transforms:
            # Each step must hand back exactly an (image, target) pair.
            img_out, tgt_out = step(img_out, tgt_out)
        return img_out, tgt_out

class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through untouched."""

    def __call__(self, image, target):
        return F.to_tensor(image), target

class RandomHorizontalFlip:
    """Horizontally flip an image and its boxes/keypoints with probability ``prob``.

    Accepts a tensor image whose last axis is width (as produced by
    ``ToTensor``) or a PIL image. ``target["boxes"]`` is an ``(N, 4)``
    tensor of ``[x1, y1, x2, y2]``; ``target["keypoints"]``, when present,
    holds ``[x, y, v]`` triples in its last dimension. Keypoints are not
    re-indexed, which is only correct when there are no left/right keypoint
    pairs (true for the single plucking point used here).
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target

        if isinstance(image, torch.Tensor):
            # Equivalent to F.hflip for tensors: reverse the width (last) axis.
            width = image.shape[-1]
            image = image.flip(-1)
        else:
            width = image.size[0]  # PIL reports (width, height)
            image = F.hflip(image)

        # Shallow copy so the caller's dict is not mutated in place.
        target = dict(target)

        # flip boxes: the new x1 comes from the old x2 (and vice versa) so x1 <= x2 holds
        old_boxes = target["boxes"]
        boxes = old_boxes.clone()
        boxes[:, [0, 2]] = width - old_boxes[:, [2, 0]]
        target["boxes"] = boxes

        # flip keypoints: x -> width - x
        kps = target.get("keypoints")
        if kps is not None:
            kps = kps.clone()
            kps[..., 0] = width - kps[..., 0]
            # Keep the COCO convention that unlabeled keypoints (v == 0) stay at (0, 0).
            kps[kps[..., 2] == 0] = 0
            target["keypoints"] = kps
        return image, target

def get_train_transforms():
    """Training pipeline: convert to tensor, then flip horizontally half of the time."""
    steps = [ToTensor(), RandomHorizontalFlip(0.5)]
    return Compose(steps)

def get_test_transforms():
    """Evaluation pipeline: tensor conversion only, no random augmentation."""
    return Compose([ToTensor()])


#### Model definition (Keypoint R-CNN)

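    Use torchvision's keypointrcnn_resnet50_fpn. Important: the model must predict num_keypoints=1 (you have a single plucking point). num_classes counts the background too, so with a single foreground class num_classes=2. The COCO-pretrained checkpoint has a 17-keypoint head. Requesting num_keypoints=1 together with pretrained weights therefore fails: recent torchvision raises a ValueError, and older versions fail to load the state dict. Instead, load the pretrained model with its default head sizes and then replace the keypoint predictor (and the box predictor, if num_classes differs) with a new head of the right size.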

In [3]:
# model.py
import torchvision
from torchvision.models.detection.keypoint_rcnn import KeypointRCNN
from torchvision.models.detection import keypointrcnn_resnet50_fpn

def get_model(num_classes=2, num_keypoints=1, pretrained_backbone=True):
    """Build a COCO-pretrained Keypoint R-CNN (ResNet-50 FPN) with resized heads.

    The COCO person-keypoint checkpoint only loads with its native head sizes
    (2 classes, 17 keypoints); torchvision rejects other sizes when pretrained
    weights are requested. So the model is created with those defaults and
    the box / keypoint predictors are then replaced by freshly initialised
    heads of the requested size (the usual fine-tuning recipe).

    Args:
        num_classes: number of classes including background.
        num_keypoints: keypoints predicted per detected instance.
        pretrained_backbone: forwarded to torchvision; it has no practical
            effect here because the full checkpoint (backbone included) is loaded.

    Returns:
        The torchvision Keypoint R-CNN model.
    """
    # Local imports: the predictor classes are only needed to swap the heads.
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor

    # load model with pre-trained COCO weights and their native head sizes
    model = keypointrcnn_resnet50_fpn(pretrained=True, progress=True,
                                      pretrained_backbone=pretrained_backbone)
    roi_heads = model.roi_heads

    box_predictor = roi_heads.box_predictor
    if box_predictor.cls_score.out_features != num_classes:
        roi_heads.box_predictor = FastRCNNPredictor(
            box_predictor.cls_score.in_features, num_classes)

    kp_predictor = roi_heads.keypoint_predictor
    if kp_predictor.kps_score_lowres.out_channels != num_keypoints:
        roi_heads.keypoint_predictor = KeypointRCNNPredictor(
            kp_predictor.kps_score_lowres.in_channels, num_keypoints)
    return model


#### Training loop
   A typical training loop adapted from the torchvision detection references; a checkpoint is saved after every epoch.

In [4]:
# train.py
import torch
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from dataset_saffron import SaffronKeypointDataset
from transforms import get_train_transforms, get_test_transforms
from model import get_model
import utils  # helper functions from torchvision references or simple collate

def collate_fn(batch):
    """Regroup ``[(img, tgt), ...]`` column-wise into ``((img, ...), (tgt, ...))``.

    Detection models take lists of variable-sized images, so the default
    stacking collate cannot be used.
    """
    columns = zip(*batch)
    return tuple(columns)

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=50):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: detection model that returns a dict of loss tensors in train mode.
        optimizer: optimizer over the model's trainable parameters.
        data_loader: iterable of ``(images, targets)`` batches (see ``collate_fn``).
        device: device the images and target tensors are moved to.
        epoch: epoch index, used only for logging.
        print_freq: print the batch loss every ``print_freq`` iterations.

    Returns:
        Mean total loss over the epoch's batches (``nan`` for an empty loader).

    Raises:
        FloatingPointError: if a batch yields a non-finite loss; stepping the
            optimizer on it would corrupt the weights.
    """
    import math  # local: only needed for the finiteness check

    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (images, targets) in enumerate(data_loader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()

        if not math.isfinite(loss_value):
            components = {name: value.item() for name, value in loss_dict.items()}
            raise FloatingPointError(
                f"Non-finite loss {loss_value} at epoch {epoch}, iter {i}: {components}"
            )

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += loss_value
        num_batches += 1
        if i % print_freq == 0:
            print(f"Epoch {epoch} Iter {i} Loss {loss_value:.4f}")

    return running_loss / num_batches if num_batches else float("nan")

def main():
    """Train Keypoint R-CNN on the saffron data and save per-epoch checkpoints.

    The records are split 80/20 into train/validation with a fixed seed. The
    two subsets wrap separate dataset instances so that only the training
    images receive the random-flip augmentation.
    """
    from torch.utils.data import Subset  # local: only needed for the split

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    images_dir = "Dataset_txt/image"
    ann_json = "annotations_coco_like.json"

    # Same records, two pipelines: augmented for training, deterministic for validation.
    train_base = SaffronKeypointDataset(images_dir, ann_json, transforms=get_train_transforms())
    val_base = SaffronKeypointDataset(images_dir, ann_json, transforms=get_test_transforms())

    # split
    n = len(train_base)
    val_size = int(n * 0.2)
    # Fixed seed keeps the split identical across runs (e.g. when resuming).
    split_gen = torch.Generator().manual_seed(42)
    perm = torch.randperm(n, generator=split_gen).tolist()
    train_dataset = Subset(train_base, perm[val_size:])
    val_dataset = Subset(val_base, perm[:val_size])

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
    # Not consumed by the loop below yet; ready for a validation / PCK pass.
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2, collate_fn=collate_fn)

    model = get_model(num_classes=2, num_keypoints=1)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    # Divide the learning rate by 10 every 5 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    num_epochs = 20
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=20)
        lr_scheduler.step()

        # save checkpoint
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pth")

    # Save final
    torch.save(model.state_dict(), "keypointrcnn_saffron_final.pth")


if __name__ == "__main__":
    main()


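ModuleNotFoundError: No module named 'dataset_saffron'

(The cell above imports from dataset_saffron.py, transforms.py, model.py and utils.py, which exist only as notebook cells here, not as files. Either save those cells as files next to the notebook (e.g. with %%writefile), or drop the local imports and rely on the definitions already executed in the earlier cells. The next cell takes the second route.)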

In [None]:
from torch.utils.data import DataLoader
import torch

def collate_fn(batch):
    """Turn a list of (image, target) samples into (images, targets) tuples."""
    return (*zip(*batch),)

# The model needs tensor images: without transforms the dataset yields PIL
# images, which have no .to(device).
dataset = SaffronKeypointDataset("Dataset_txt/image", "annotations_coco_like.json",
                                 transforms=get_train_transforms())
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = get_model().to(device)

optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                            lr=0.005, momentum=0.9, weight_decay=0.0005)

for epoch in range(10):
    model.train()
    epoch_loss, num_batches = 0.0, 0
    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the epoch mean rather than whichever batch happened to be last;
    # this also avoids a NameError when the loader is empty.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch}, Loss: {mean_loss:.4f}")


#### Evaluation

    For full detection+keypoint evaluation use COCO keypoints metrics (AP for detection and keypoint AP). Pycocotools and proper annotation format are required. Alternatively compute simpler metrics:

    Detection: mAP computed from box IoUs (torchvision.ops.box_iou provides the IoU matrix) or with an external evaluator such as pycocotools.

    Keypoint correctness: PCK (Percentage of Correct Keypoints) — consider a keypoint correct if distance between predicted and GT is <= alpha * max(box_width, box_height) (alpha e.g. 0.1).

In [None]:
# evaluation_utils.py
import numpy as np

def pck_single(gt_kp, pred_kp, bbox, alpha=0.1):
    """Decide whether one predicted keypoint is correct under PCK.

    Args:
        gt_kp: ground-truth ``[x, y, v]``; ``v == 0`` means not annotated.
        pred_kp: predicted keypoint; only its first two entries ``[x, y]`` are used.
        bbox: reference box ``[x1, y1, x2, y2]`` that sets the distance scale.
        alpha: tolerance as a fraction of the box's longer side.

    Returns:
        ``None`` when the ground truth is not annotated, otherwise a plain
        ``bool`` (not ``numpy.bool_``, so results work with ``is`` checks
        and ``json``) telling whether the Euclidean distance is within
        ``alpha * max(box_width, box_height)``.
    """
    if gt_kp[2] == 0:
        return None  # not annotated
    gt = np.asarray(gt_kp[:2], dtype=float)
    pred = np.asarray(pred_kp[:2], dtype=float)
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    thresh = alpha * max(w, h)
    dist = np.linalg.norm(pred - gt)
    return bool(dist <= thresh)


#### Inference & visualization

    Load model, run on an image, filter by score, and draw box + keypoint.

In [None]:
# inference.py
import torch
from PIL import Image, ImageDraw
import torchvision.transforms.functional as F
from model import get_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(num_classes=2, num_keypoints=1)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load("keypointrcnn_saffron_final.pth", map_location=device))
# Module.to() returns the module itself, so eval() applies to the same object.
model.to(device).eval()

def predict_and_visualize(image_path, score_thresh=0.6, save_path=None):
    """Run the module-level ``model`` on one image and draw its detections.

    Args:
        image_path: path of the image to run inference on.
        score_thresh: detections with a score below this are not drawn.
        save_path: if given, the annotated image is also saved there.

    Returns:
        The RGB PIL image with a red rectangle per kept box and a blue dot
        for each keypoint whose third component is positive.
    """
    # The context manager closes the file; convert() returns an independent image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = F.to_tensor(img).to(device)
    with torch.no_grad():
        out = model([img_tensor])[0]

    boxes = out['boxes'].cpu()
    scores = out['scores'].cpu()
    keypoints = out['keypoints'].cpu()  # shape (num_dets, num_kpts, 3)
    draw = ImageDraw.Draw(img)
    radius = 3

    for box, score, kpts in zip(boxes, scores, keypoints):
        if score < score_thresh:
            continue
        draw.rectangle(box.tolist(), outline="red", width=2)
        # only one keypoint per instance: [x, y, v]
        x, y, v = (float(c) for c in kpts[0][:3])
        if v > 0:
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="blue")

    if save_path:
        img.save(save_path)
    return img

if __name__ == "__main__":
    # Annotate the example image, write it to disk, then open a viewer on it.
    img = predict_and_visualize(
        "images/example.jpg",
        save_path="out_example.jpg",
    )
    img.show()
