In [34]:
import os
import time
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
from os import listdir
import pandas as pd
import numpy as np
import glob
import cv2
import json
from os.path import expanduser
import splitfolders
import shutil
from define_path import Def_Path
from datetime import datetime

from tqdm import tqdm

import torch 
import torchvision
from torchvision import models
from torchvision.models.detection.rpn import AnchorGenerator
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torchvision.transforms as T
from torchvision.transforms import functional as F
from torchsummary import summary
from sklearn.model_selection import train_test_split

import albumentations as A # Library for augmentations

import matplotlib.pyplot as plt 
from PIL import Image

import transforms, utils, engine, train
from utils import collate_fn
from engine import train_one_epoch, evaluate

t = torch.cuda.get_device_properties(0).total_memory
print(t)
torch.cuda.empty_cache()

r = torch.cuda.memory_reserved(0)
print(r)
a = torch.cuda.memory_allocated(0)
print(a)
# f = r-a  # free inside reserved

weights_path = '/home/jc-merlab/Pictures/Data/trained_models/keypointsrcnn_weights_sim_b1_e25_v0.pth'

16908615680
354418688
257395712


In [35]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.set_per_process_memory_fraction(0.9, 0)
print(device)

cuda


In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import math
import numpy as np
from torch.autograd import Variable
import torch_geometric.nn as pyg
from torch_geometric.data import Data

_EPS = 1e-10

class MLP(nn.Module):
    """Two-layer fully-connected ELU net with batch norm."""

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)
        self.bn = nn.BatchNorm1d(n_out)
        self.dropout_prob = do_prob

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def batch_norm(self, inputs):
        x = inputs.view(inputs.size(0) * inputs.size(1), -1)
        x = self.bn(x)
        return x.view(inputs.size(0), inputs.size(1), -1)

    def forward(self, inputs):
#         print("Input shape before any operations: ", inputs.shape)

        # Flatten the last two dimensions for the linear layer input
#         x = inputs.view(inputs.size(0), -1)
        x = func.elu(self.fc1(inputs))
        x = func.dropout(x, self.dropout_prob, training=self.training)
        x = func.elu(self.fc2(x))
        
        return self.batch_norm(x)

        # Assuming you want to maintain the second dimension for some reason
        # (like temporal sequence in a RNN), you would reshape the output
        # back to the desired shape. If not, this step is unnecessary.
        # output = x.view(inputs.size(0), inputs.size(1), -1)
        # print("Output shape after forward pass: ", output.shape)

#         return x


class GraphEncoder(nn.Module):
    def __init__(self, n_in, n_hid, n_out=4, do_prob=0., factor=True):
        super(GraphEncoder, self).__init__()

        self.factor = factor

        self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob)
        self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob)
        self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob)
        if self.factor:
            self.mlp4 = MLP(n_hid * 3, n_hid, n_hid, do_prob)
            print("Using factor graph MLP encoder.")
        else:
            self.mlp4 = MLP(n_hid * 2, n_hid, n_hid, do_prob)
            print("mlp4", self.mlp4)
            print("Using MLP graph encoder.")
        self.fc_out = nn.Linear(n_hid, n_out)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def edge2node(self, x, rel_rec, rel_send):
        # NOTE: Assumes that we have the same graph across all samples.
        incoming = torch.matmul(rel_rec.t(), x)
        return incoming / incoming.size(1)

    def node2edge(self, x, rel_rec, rel_send):
        # NOTE: Assumes that we have the same graph across all samples.
        receivers = torch.matmul(rel_rec, x)
        senders = torch.matmul(rel_send, x)
        edges = torch.cat([receivers, senders], dim=2)
        return edges

    def forward(self, inputs, rel_rec, rel_send):
        # Input shape: [num_sims, num_atoms, num_timesteps, num_dims]
        x = inputs.view(inputs.size(0), inputs.size(1), -1)
#         print("x shape:", x.shape)
#         print("rel_rec shape:", rel_rec.shape)
#         print("rel_send shape:", rel_send.shape)

        # New shape: [num_sims, num_atoms, num_timesteps*num_dims]
        x = self.mlp1(x)  # 2-layer ELU net per node

        x = self.node2edge(x, rel_rec, rel_send)
        x = self.mlp2(x)
        x_skip = x    
        
        if self.factor:
            x = self.edge2node(x, rel_rec, rel_send)
            x = self.mlp3(x)
            x = self.node2edge(x, rel_rec, rel_send)
            x = torch.cat((x, x_skip), dim=2)  # Skip connection
            x = self.mlp4(x)
        else:
            x = self.mlp3(x)
            x = torch.cat((x, x_skip), dim=2)  # Skip connection
            x = self.mlp4(x)

        return self.fc_out(x)    
    
class GraphDecoder(nn.Module):

    def __init__(self, n_in_node, edge_types, msg_hid, msg_out, n_hid,
                 do_prob=0., skip_first=False):
        super(GraphDecoder, self).__init__()
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * n_in_node, msg_hid) for _ in range(edge_types)])
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(msg_hid, msg_out) for _ in range(edge_types)])
        self.msg_out_shape = msg_out
        self.skip_first_edge_type = skip_first

        self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, n_in_node)

        print('Using learned graph decoder.')

        self.dropout_prob = do_prob

    def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
                            single_timestep_rel_type):

        # single_timestep_inputs has shape
        # [batch_size, num_timesteps, num_atoms, num_dims]

        # single_timestep_rel_type has shape:
        # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types]

        # Node2edge
        receivers = torch.matmul(rel_rec, single_timestep_inputs)
        senders = torch.matmul(rel_send, single_timestep_inputs)
        pre_msg = torch.cat([receivers, senders], dim=-1)

        all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),self.msg_out_shape))
        if single_timestep_inputs.is_cuda:
            all_msgs = all_msgs.cuda()

        if self.skip_first_edge_type:
            start_idx = 1
        else:
            start_idx = 0

        # Run separate MLP for every edge type
        # NOTE: To exlude one edge type, simply offset range by 1
        for i in range(start_idx, len(self.msg_fc2)):
            msg = func.relu(self.msg_fc1[i](pre_msg))
            msg = func.dropout(msg, p=self.dropout_prob)
            msg = func.relu(self.msg_fc2[i](msg))
            msg = msg * single_timestep_rel_type[:, :, i:i + 1]
            all_msgs += msg

        # Aggregate all msgs to receiver
        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()

        # Skip connection
        aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)

        # Output MLP
        pred = func.dropout(func.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob)
        pred = func.dropout(func.relu(self.out_fc2(pred)), p=self.dropout_prob)
        pred = self.out_fc3(pred)
#        print(pred.shape,single_timestep_inputs.shape)

        # Predict position/velocity difference
        return single_timestep_inputs + pred

    def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
        # NOTE: Assumes that we have the same graph across all samples.


        # Only take n-th timesteps as starting points (n: pred_steps)
        last_pred = inputs[:, :, :]
        #asa
        curr_rel_type = rel_type[:, :, :]
        preds=[]
        #print(curr_rel_type.shape)
        # NOTE: Assumes rel_type is constant (i.e. same across all time steps).

        # Run n prediction steps
        #for step in range(0, pred_steps):
        last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                                 curr_rel_type)
        preds.append(last_pred)

        sizes = [preds[0].size(0), preds[0].size(1),
                 preds[0].size(2)]

        output = Variable(torch.zeros(sizes))
        if inputs.is_cuda:
            output = output.cuda()

        # Re-assemble correct timeline
        for i in range(len(preds)):
            output[:, :, :] = preds[i]

        pred_all = output[:, :, :]

        # NOTE: We potentially over-predicted (stored in future_pred). Unused.
        # future_pred = output[:, (inputs.size(1) - 1):, :, :]

        return pred_all#.transpose(1, 2).contiguous()   

In [37]:
def my_softmax(input, axis=1):
    trans_input = input.transpose(axis, 0).contiguous()
    soft_max_1d = func.softmax(trans_input,dim=0)
    return soft_max_1d.transpose(axis, 0)


In [38]:
class KeypointPipeline(nn.Module):
    def __init__(self, weights_path):
        super(KeypointPipeline, self).__init__()  
        self.keypoint_model = torch.load(weights_path).to(device)
        self.encoder = GraphEncoder(4,512,4,0.5,True)
        self.decoder = GraphDecoder(n_in_node=4,
                                 edge_types=2,
                                 msg_hid=512,
                                 msg_out=512,
                                 n_hid=512,
                                 do_prob=0.5,
                                 skip_first=False)
        
#         self.off_diag = np.ones([6,6]) - np.eye(6)

#         self.rel_rec = np.array(encode_onehot(np.where(self.off_diag)[1]), dtype=np.float32)
#         self.rel_send = np.array(encode_onehot(np.where(self.off_diag)[0]), dtype=np.float32)
#         self.rel_rec = torch.FloatTensor(self.rel_rec)
#         self.rel_send = torch.FloatTensor(self.rel_send)

        # Define a unidirectional cyclical graph
        num_nodes = 6
        self.off_diag = np.zeros([num_nodes, num_nodes])
        
        # Creating a cycle: 1->2, 2->3, ..., 6->1
        for i in range(num_nodes):
            self.off_diag[i, (i + 1) % num_nodes] = 1

        # Update rel_rec and rel_send based on the new off_diag
        self.rel_rec = np.array(encode_onehot(np.where(self.off_diag)[1]), dtype=np.float32)
        self.rel_send = np.array(encode_onehot(np.where(self.off_diag)[0]), dtype=np.float32)
        self.rel_rec = torch.FloatTensor(self.rel_rec).to(device)
        self.rel_send = torch.FloatTensor(self.rel_send).to(device)

        self.encoder= self.encoder.cuda()
        self.decoder = self.decoder.cuda()
        self.rel_rec = self.rel_rec.cuda()
        self.rel_send = self.rel_send.cuda()
    
    def process_model_output(self, output):
        scores = output[0]['scores'].detach().cpu().numpy()
        high_scores_idxs = np.where(scores > 0.7)[0].tolist()

        post_nms_idxs = torchvision.ops.nms(output[0]['boxes'][high_scores_idxs], 
                                            output[0]['scores'][high_scores_idxs], 0.3).cpu().numpy()

        confidence = output[0]['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        labels = output[0]['labels'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        keypoints = []
        for idx, kps in enumerate(output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()):
            keypoints.append(list(map(int, kps[0,0:2])) + [confidence[idx]] + [labels[idx]])
        
        # Sort keypoints based on label
        keypoints.sort(key=lambda x: x[-1])
        return keypoints
    
    def keypoints_to_graph(self, keypoints, image_width, image_height):
        # keypoints is expected to be a tensor with shape (num_keypoints, 4),
        # where each keypoint is (x, y, score, label).
        # Convert all elements in keypoints to tensors if they are not already
        keypoints = [torch.tensor(kp, dtype=torch.float32).to(device) if not isinstance(kp, torch.Tensor) else kp for kp in keypoints]

        # Then stack them
        keypoints = torch.stack(keypoints).to(device)        
        
        # Remove duplicates: Only keep the keypoint with the highest score for each label
        unique_labels, best_keypoint_indices = torch.unique(keypoints[:, 3], return_inverse=True)
        best_scores, best_indices = torch.max(keypoints[:, 2].unsqueeze(0) * (best_keypoint_indices == torch.arange(len(unique_labels)).unsqueeze(1).cuda()), dim=1)
        keypoints = keypoints[best_indices]
        
#         print("init keypoints in graph features", keypoints)

        # Normalize x and y to be in the range [-1, 1]
        keypoints[:, 0] = (keypoints[:, 0] - image_width / 2) / (image_width / 2)
        keypoints[:, 1] = (keypoints[:, 1] - image_height / 2) / (image_height / 2)

        # Use only x, y, and score for the graph features
        graph_features = keypoints[:, :4]  # Now shape is (num_keypoints, 3)
        
        # Ensure the shape is [num_keypoints, 3] before returning
        graph_features = graph_features.view(-1, 4)  # Reshape to ensure it's [num_keypoints, 3]
#         print("graph features", graph_features)
        print("graph features shape", graph_features.shape)

        return graph_features
        
    def forward(self, imgs):
        # Temporarily set the keypoint model to evaluation mode
        keypoint_model_training = self.keypoint_model.training
        self.keypoint_model.eval()

        # Process each image in the batch
        with torch.no_grad():
            batch_outputs = [self.keypoint_model(img.unsqueeze(0).to(device)) for img in imgs]

        # Set the keypoint model back to its original training mode
        self.keypoint_model.train(mode=keypoint_model_training)

        # Process model outputs to get labeled keypoints
        batch_labeled_keypoints = [self.process_model_output(output) for output in batch_outputs]
        # Generate graph input tensor for each image and handle varying number of keypoints
        batch_x = []
        for labeled_keypoints in batch_labeled_keypoints:
            keypoints = self.keypoints_to_graph(labeled_keypoints, 640, 480)

            # Initialize x with zeros for 6 nodes with 4 features each
            x = torch.zeros(1, 6, 4, device=device)

            # Ensure that keypoints are on the correct device and fill in x
            num_keypoints_detected = keypoints.size(0)
            if num_keypoints_detected <= 6:
                x[0, :num_keypoints_detected, :] = keypoints
            else:
                raise ValueError("Number of keypoints detected exceeds the maximum of 6.")

            batch_x.append(x)

        # Stack the batch of x tensors for batch processing
        batch_x = torch.cat(batch_x, dim=0)

        # Forward pass through the encoder and decoder
        logits = self.encoder(batch_x, self.rel_rec, self.rel_send)
        edges = my_softmax(logits, -1)
        KGNN2D = self.decoder(batch_x, edges, self.rel_rec, self.rel_send)

        return logits, KGNN2D, batch_labeled_keypoints


In [39]:
# def loss_edges(valid_points, edges):
#     off_diag = np.ones([6, 6]) - np.eye(6)
#     idx =  torch.LongTensor(np.where(off_diag)[1].reshape(6,5)).cuda()
#     if valid_points.ndim == 1:
#         valid_points = valid_points.unsqueeze(0)  # Reshape to 2D if necessary

#     relations = torch.zeros(valid_points.shape[0],valid_points.shape[1]*(valid_points.shape[1]-1)).cuda()
#     for count,vis in enumerate(valid_points):
#         vis = vis.view(-1,1) 
#         vis = vis*vis.t()
#         vis = torch.gather(vis,1,idx)
#         relations[count] = vis.view(-1)
#     relations = relations.type(torch.LongTensor).cuda() 
#     loss_edges = func.cross_entropy(edges.view(-1, 2), relations.view(-1))
#     return loss_edges

# def loss_kp(gt_keypoints, pred_keypoints):
#     # Convert pred_keypoints to tensor if it's a list
    
#     if isinstance(pred_keypoints, list):
#         pred_keypoints = torch.stack([torch.tensor(kp, device=gt_keypoints.device, dtype=torch.float32) if isinstance(kp, list) else kp for kp in pred_keypoints])

#     # Ensure gt_keypoints is a tensor
#     if not isinstance(gt_keypoints, torch.Tensor):
#         gt_keypoints = torch.tensor(gt_keypoints, dtype=torch.float32, device=pred_keypoints.device)

#     # Check if the shape of gt_keypoints is as expected, it should be [N, M] where N is the number of keypoints and M is the properties of each keypoint (like x, y, visibility, etc.)
#     if gt_keypoints.dim() != 2 or gt_keypoints.size(-1) < 3:
#         raise ValueError("gt_keypoints must be a 2D tensor with shape [N, M] where M >= 3.")

#     # Initialize a mask for selecting valid keypoints in gt_keypoints
#     valid_gt_mask = (gt_keypoints[:, -1] == 1)

#     # Ensure the mask is one-dimensional
#     if valid_gt_mask.dim() != 1:
#         raise ValueError("The mask must be one-dimensional")

#     # Filter the gt_keypoints and pred_keypoints based on the mask
#     filtered_gt_keypoints = gt_keypoints[valid_gt_mask][:, :2]  # x, y columns
#     filtered_pred_keypoints = pred_keypoints[valid_gt_mask][:, :2]  # x, y columns

#     # Compute the loss using Smooth L1 Loss on the filtered keypoints
#     loss = func.smooth_l1_loss(filtered_pred_keypoints, filtered_gt_keypoints, reduction='none')

#     # Apply the mask to the loss to consider only valid keypoints
#     valid_loss = loss[valid_gt_mask]
#     return valid_loss.sum() / valid_gt_mask.float().sum()  
    
# def kgnn2d_loss(gt_keypoints, pred_keypoints):
#     loss = func.mse_loss(pred_keypoints, gt_keypoints)
    
#     return loss


In [40]:
def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot
def process_keypoints(keypoints):
    # Assuming keypoints is a list of Nx3 tensors where N is the number of keypoints
    # and each keypoint is represented as [x, y, visibility]
    # Remove the unnecessary middle dimension
    keypoints = [kp.squeeze(1) for kp in keypoints]
    visibilities = [kp[:, 2] for kp in keypoints]  # Extract visibility flags
    valid_vis_all = torch.cat([v == 1 for v in visibilities]).long().cuda()
    valid_invis_all = torch.cat([v == 0 for v in visibilities]).long().cuda()

    keypoints_gt = torch.cat([kp[:, :2] for kp in keypoints]).float().cuda()  # Gather all keypoints and discard visibility flags
    keypoints_gt = keypoints_gt.view(-1, 2).unsqueeze(0)  # Add an extra dimension to match expected shape for loss_edges

    return keypoints_gt, valid_vis_all, valid_invis_all

In [41]:
# import os
# import torch
# import torchvision
# import numpy as np
# from torchvision.transforms import functional as F
# from PIL import Image
# import cv2

# model = KeypointPipeline(weights_path)
# model = model.to(device)

# new_weights_path = '/home/jc-merlab/Pictures/Data/trained_models/krcnn_occ_b32_e200_v0.pth'
# model = torch.load(weights_path).to(device)

# # image = Image.open("/home/jc-merlab/Pictures/Data/split_folder_output-2023-10-22/test/images/002510.rgb.jpg")
# image = Image.open("/home/jc-merlab/Pictures/Data/2023-08-14-Occluded/000207.rgb.jpg")
# print(type(image))

# image = F.to_tensor(image).to(device)
# # image.unsqueeze_(0)
# print(image.shape)
# # image = list(image)

# with torch.no_grad():
#     model.to(device)
#     model.eval()
#     output = model(image)
    
# print(output[1])
    
# pred_kp = output[1]

# # Replace these with your actual image dimensions
# image_width = 640
# image_height = 480


# # Denormalize keypoints
# # denormalized_x = (pred_kp[:, :, 0] * (image_width / 2)) + (image_width / 2)
# # denormalized_y = (pred_kp[:, :, 1] * (image_height / 2)) + (image_height / 2)

# # Denormalize keypoints
# denormalized_x = (pred_kp[:, :, 0] + 1) * (image_width / 2)
# denormalized_y = (pred_kp[:, :, 1] + 1) * (image_height / 2)

# # Stack the denormalized x and y coordinates together to form [n, 2] tensor
# denormalized_keypoints = torch.stack((denormalized_x, denormalized_y), dim=2)

# print("Denormalized Keypoints:", denormalized_keypoints)

# kp_numpy = denormalized_keypoints.cpu().numpy()
# kp_flat_list = kp_numpy.reshape(-1).tolist()  # Convert to a flat list


In [42]:
# def visualize_keypoints(image_path, all_keypoints, point_radius=5, keypoint_color=(255, 0, 0)):
#     """
#     Visualize keypoints on the given image. Expects keypoints in the format of a list with a flat structure [x1,y1,x2,y2,...]
#     for each keypoint set associated with an object.
#     """
#     image = cv2.imread(image_path)
    
# #     if bboxes:
# #         for bbox in bboxes:
# #             cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox_color, 2)
    
#     for keypoints in all_keypoints:
#         for i in range(0, len(keypoints), 2):
#             x, y = int(keypoints[i]), int(keypoints[i+1])
#             cv2.circle(image, (x, y), point_radius, keypoint_color, -1)
    
#     return image

# image = "/home/jc-merlab/Pictures/Data/2023-08-14-Occluded/000207.rgb.jpg"

# output_image = visualize_keypoints(image,[kp_flat_list])

# # Display or save the output image as needed
# cv2.imshow("Keypoints", output_image)
# cv2.waitKey(0)
# cv2.destroyAllWindows()

In [43]:
import os
import json
import torch
import torchvision
import numpy as np
from torchvision.transforms import functional as F
from PIL import Image
import cv2

def calculate_accuracy(pred_kps, gt_kps, threshold=10):
    """Calculate accuracy based on distance threshold."""
    distances = torch.norm(pred_kps - gt_kps, dim=2)
    correct = torch.le(distances, threshold).all(dim=1)
    accuracy = torch.mean(correct.float())
    return accuracy.item()

def visualize_keypoints(image, pred_keypoints, gt_keypoints, pred_color=(0, 255, 0), gt_color=(255, 0, 0), pred_radius=7, gt_radius=5):
    """Visualize predicted and ground truth keypoints."""
    for x, y in pred_keypoints:
        cv2.circle(image, (int(x), int(y)), pred_radius, pred_color, -1)
    for x, y in gt_keypoints:
        cv2.circle(image, (int(x), int(y)), gt_radius, gt_color, -1)
    return image

In [44]:
import torch
import os
import json
from PIL import Image
import torchvision.transforms.functional as F
import numpy as np
import cv2

# Load your model
new_weights_path = '/home/jc-merlab/Pictures/Data/trained_models/krcnn_occ_b128_e200_v0.pth'
model = torch.load(new_weights_path).to(device)
model.eval()

# source_folder = '/home/jc-merlab/Pictures/Data/only_occ_data/'
source_folder = '/home/jc-merlab/Pictures/Data/split_folder_output-2023-11-17/val/folder'
output_folder = '/home/jc-merlab/Pictures/Data/occ_results/13/'

if not os.path.exists(output_folder):
    os.makedirs(output_folder)

total_accuracy = 0
num_images = 0
image_width = 640
image_height = 480

for filename in os.listdir(source_folder):
    if filename.endswith(".rgb.jpg"):
        image_path = os.path.join(source_folder, filename)
        annotation_path = os.path.join(source_folder, filename.replace('.rgb.jpg','.json'))

        # Load image and annotation
        image = Image.open(image_path).convert('RGB')
        with open(annotation_path, 'r') as f:
            annotation = json.load(f)

        # Preprocess image
        input_tensor = F.to_tensor(image).unsqueeze(0).to(device)

        # Predict keypoints
        with torch.no_grad():
            _, pred_kps, _ = model([input_tensor.squeeze(0)])
            
        print(pred_kps[0].shape)

        # pred_kps is a batch. Since batch size is 1, take the first element.
        pred_kps = pred_kps[0]

        # Denormalize keypoints
        denormalized_pred_kps = torch.stack(((pred_kps[:,0] * (image_width / 2)) + (image_width / 2), 
                                             (pred_kps[:,1] * (image_height / 2)) + (image_height / 2)), dim=1)

        # Extract ground truth keypoints and format them correctly
        gt_kps = torch.tensor([kp[0][:2] for kp in annotation['keypoints']], dtype=torch.float32)

        # Calculate accuracy (Implement calculate_accuracy according to your requirements)
        accuracy = calculate_accuracy(denormalized_pred_kps.to(device), gt_kps.unsqueeze(0).to(device))
        total_accuracy += accuracy
        num_images += 1

        # Visualize keypoints (Implement visualize_keypoints according to your requirements)
        image_np = np.array(image)
        output_image = visualize_keypoints(image_np, denormalized_pred_kps.cpu().numpy(), gt_kps.numpy())

        print(f'Accuracy for {filename}: {accuracy * 100:.2f}%')
        
        # Save the output image
        output_image_path = os.path.join(output_folder, f'visualized_{filename}')
        cv2.imwrite(output_image_path, cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR))

# Calculate mean accuracy
if num_images > 0:
    mean_accuracy = total_accuracy / num_images
    print(f'Mean Accuracy across all images: {mean_accuracy * 100:.2f}%')
else:
    print('No images were processed.')

graph features shape torch.Size([3, 4])
torch.Size([6, 4])
Accuracy for 004211.rgb.jpg: 0.00%
graph features shape torch.Size([2, 4])
torch.Size([6, 4])
Accuracy for 006465.rgb.jpg: 0.00%
graph features shape torch.Size([3, 4])
torch.Size([6, 4])
Accuracy for 004503.rgb.jpg: 0.00%
graph features shape torch.Size([3, 4])
torch.Size([6, 4])
Accuracy for 004906.rgb.jpg: 0.00%
graph features shape torch.Size([6, 4])
torch.Size([6, 4])
Accuracy for 001323.rgb.jpg: 100.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 004819.rgb.jpg: 0.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 003334.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])
Accuracy for 005915.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])
Accuracy for 001755.rgb.jpg: 0.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 002030.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])

graph features shape torch.Size([6, 4])
torch.Size([6, 4])
Accuracy for 000895.rgb.jpg: 100.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 003976.rgb.jpg: 0.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 001735.rgb.jpg: 0.00%
graph features shape torch.Size([3, 4])
torch.Size([6, 4])
Accuracy for 002655.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])
Accuracy for 001885.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])
Accuracy for 005293.rgb.jpg: 0.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 004417.rgb.jpg: 0.00%
graph features shape torch.Size([5, 4])
torch.Size([6, 4])
Accuracy for 003285.rgb.jpg: 0.00%
graph features shape torch.Size([4, 4])
torch.Size([6, 4])
Accuracy for 006593.rgb.jpg: 0.00%
graph features shape torch.Size([6, 4])
torch.Size([6, 4])
Accuracy for 000257.rgb.jpg: 100.00%
graph features shape torch.Size([6, 4])
torch.Size([6, 4