
Predict2D over batches #6

Open

timsainb opened this issue Mar 28, 2023 · 3 comments

Comments

@timsainb
Hey Timo,

Do you have code to run predict2D over batches?

The only code I can find loops over frames.

https://github.com/JARVIS-MoCap/JARVIS-HybridNet/blob/master/jarvis/prediction/predict2D.py#L91
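For context, the linked code runs the networks one frame at a time, in a pattern roughly like this (a paraphrased sketch, not the verbatim repository code; jarvisPredictor, get_frame, and num_frames are stand-ins for the real objects):

import torch

# Paraphrased shape of the existing per-frame loop (not the actual source).
for frame_idx in range(num_frames):
    frame = get_frame(frame_idx)                     # one (H, W, 3) uint8 image
    img = torch.from_numpy(frame).cuda().float()
    img = img.permute(2, 0, 1).unsqueeze(0) / 255.0  # batch of size 1: (1, 3, H, W)
    points2D, confidences = jarvisPredictor(img)     # one forward pass per frame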

Thanks!

@timsainb
Author

timsainb commented Mar 28, 2023

If there isn't already something, I wrote this:

import numpy as np
import torch
from torchvision import transforms

# Assumes a jarvisPredictor object (as set up in predict2D.py) is in scope.
def get_2D_keypoints_batch(frames):
    """
    Given a batch of images (frames), return the 2D keypoints and confidence scores for each image.
    Uses the jarvisPredictor object to perform keypoint detection on the images.

    Args:
        frames (numpy.ndarray): A batch of images with shape (batch_size, height, width, channels).

    Returns:
        tuple: A tuple containing the 2D keypoints and confidence scores for each image in the batch.
            The keypoints have shape (batch_size, num_keypoints, 2) and the confidence scores have
            shape (batch_size, num_keypoints).
    """
    # Convert the input frames to a PyTorch tensor, move channels first, and scale to [0, 1]
    imgs = torch.from_numpy(frames).cuda().float().permute(0, 3, 1, 2) / 255.0
    batch_size = imgs.shape[0]

    # Get the size of the input images as (width, height)
    img_size = torch.tensor([imgs.shape[3], imgs.shape[2]], device=torch.device("cuda"))

    # Compute the downsampling scale for the center detection
    downsampling_scale = torch.tensor(
        [
            imgs.shape[3] / float(jarvisPredictor.center_detect_img_size),
            imgs.shape[2] / float(jarvisPredictor.center_detect_img_size),
        ],
        device=torch.device("cuda"),
    ).float()

    # Resize the input images to the size expected by the center detection network
    imgs_resized = transforms.functional.resize(
        imgs,
        [
            jarvisPredictor.center_detect_img_size,
            jarvisPredictor.center_detect_img_size,
        ],
    )

    # Normalize the resized images
    imgs_resized = (
        imgs_resized - jarvisPredictor.transform_mean
    ) / jarvisPredictor.transform_std

    # Run the center detection network on the resized images
    outputs = jarvisPredictor.centerDetect(imgs_resized)

    # Flatten the center-detection heatmaps and take the per-heatmap argmax.
    # For a flattened index m over an (H, W) heatmap, x = m % W and y = m // W
    # (the heatmaps are square here, so H == W).
    heatmaps_gpu = outputs[1].view(outputs[1].shape[0], outputs[1].shape[1], -1)
    m = heatmaps_gpu.argmax(2).view(heatmaps_gpu.shape[0], heatmaps_gpu.shape[1], 1)
    # Center confidences; predict2D.py thresholds these at 50 to count detections,
    # which could be reinstated here if undetected frames should be skipped.
    maxvals = heatmaps_gpu.gather(2, m) / 255.0

    # Convert the argmax locations to input-image coordinates: the heatmaps are at
    # half the resolution of the resized input (hence the factor of 2), scaled back
    # up by the downsampling factor. squeeze(1) drops the single-channel dimension
    # without collapsing a batch of size 1.
    centerHMs = (
        torch.cat((m % outputs[1].shape[3], m // outputs[1].shape[3]), dim=2).squeeze(1)
        * downsampling_scale
        * 2
    )
    # Clamp the centers so the crop boxes stay inside the image
    centerHMs[:, 0] = torch.clamp(
        centerHMs[:, 0], jarvisPredictor.bbox_hw, img_size[0] - jarvisPredictor.bbox_hw
    )
    centerHMs[:, 1] = torch.clamp(
        centerHMs[:, 1], jarvisPredictor.bbox_hw, img_size[1] - jarvisPredictor.bbox_hw
    )

    # Crop the input images to fixed-size bounding boxes around the detected centers
    imgs_cropped = torch.zeros(
        (
            batch_size,
            3,
            jarvisPredictor.bounding_box_size,
            jarvisPredictor.bounding_box_size,
        ),
        device=torch.device("cuda"),
    )
    centerHMs = centerHMs.int().cpu().numpy()
    for i in range(batch_size):
        imgs_cropped[i] = imgs[
            i,
            :,
            centerHMs[i, 1]
            - jarvisPredictor.bbox_hw : centerHMs[i, 1]
            + jarvisPredictor.bbox_hw,
            centerHMs[i, 0]
            - jarvisPredictor.bbox_hw : centerHMs[i, 0]
            + jarvisPredictor.bbox_hw,
        ]

    # Normalize the cropped images
    imgs_cropped = (
        imgs_cropped - jarvisPredictor.transform_mean
    ) / jarvisPredictor.transform_std

    # Run the keypoint detection network on the cropped images
    outputs = jarvisPredictor.keypointDetect(imgs_cropped)

    # Flatten the keypoint heatmaps, take the per-keypoint argmax, and convert the
    # flattened indices back to (x, y) crop coordinates (again at half resolution,
    # hence the factor of 2)
    heatmaps = outputs[1].view(outputs[1].shape[0], outputs[1].shape[1], -1)
    m = heatmaps.argmax(2).view(heatmaps.shape[0], heatmaps.shape[1], 1)
    points2D = (
        torch.cat((m % outputs[1].shape[3], m // outputs[1].shape[3]), dim=2) * 2
    )
    confidences = heatmaps.gather(2, m).squeeze(2)
    confidences = torch.clamp(confidences, max=255.0) / 255.0

    # Shift the crop-relative keypoints back into full-image coordinates
    points2D = (
        points2D.cpu().numpy() + np.expand_dims(centerHMs, 1) - jarvisPredictor.bbox_hw
    )
    
    # Convert the PyTorch tensors to numpy arrays and return as a tuple
    return points2D, confidences.cpu().numpy()
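A minimal usage sketch, assuming jarvisPredictor has already been set up as in predict2D.py; the video path, batch size, and color-order handling are placeholders/assumptions:

import cv2
import numpy as np

# Read one batch of frames and run the batched predictor.
cap = cv2.VideoCapture("video.mp4")  # placeholder path
batch_size = 16                      # tune to fit GPU memory

frames = []
while len(frames) < batch_size:
    ret, frame = cap.read()
    if not ret:
        break
    # OpenCV reads BGR; convert if the networks expect RGB (assumption)
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()

if frames:
    batch = np.stack(frames)  # (batch_size, height, width, 3), uint8
    points2D, confidences = get_2D_keypoints_batch(batch)
    print(points2D.shape, confidences.shape)  # (B, num_keypoints, 2), (B, num_keypoints)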

@JARVIS-MoCap
Owner

JARVIS-MoCap commented Apr 4, 2023

Hi Tim, you're right, there is currently no code to perform batched predictions, so thank you very much for sharing your implementation! I'll add it to the repo if you don't mind. Alternatively, you can open a pull request so you get credit for your contribution; just let me know which you prefer :)

@NirvikNU

Hi,
I tested this batch processing for 2D video prediction and it generates the CSV files correctly. However, the visualization module no longer works, since info.yaml now contains only the path to the recordings instead of the recordings themselves. Can visualize2D be modified to account for this? (One possible direction is sketched below.)
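A possible direction, as a sketch only; the "recordings" key, the file pattern, and the info.yaml layout are assumptions, not the actual format:

import glob
import os
import yaml

# Hypothetical sketch: let the visualizer accept either a list of recordings
# or a directory path, resolving the path to video files when needed.
with open("info.yaml") as f:  # placeholder path
    info = yaml.safe_load(f)

recordings = info.get("recordings")  # assumed key name
if isinstance(recordings, str):
    # New format (assumed): a directory containing the recordings
    recordings = sorted(glob.glob(os.path.join(recordings, "*.mp4")))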
