# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import math
import numpy as np
from torch.autograd import Variable
import torch_geometric.nn as pyg
from torch_geometric.data import Data

_EPS = 1e-10

class MLP(nn.Module):
    """Two fully-connected ELU layers followed by batch normalisation.

    Args:
        n_in: size of the last input dimension.
        n_hid: width of the hidden layer.
        n_out: width of the output layer (also the BatchNorm1d feature count).
        do_prob: dropout probability applied after the first layer.
    """

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)
        self.bn = nn.BatchNorm1d(n_out)
        self.dropout_prob = do_prob
        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights / 0.1 biases for linear layers; identity batch norm."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm1d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight.data)
                layer.bias.data.fill_(0.1)

    def batch_norm(self, inputs):
        """Apply BatchNorm1d to a 3-D tensor by merging its first two dimensions."""
        n_first, n_second = inputs.size(0), inputs.size(1)
        flat = inputs.view(n_first * n_second, -1)
        return self.bn(flat).view(n_first, n_second, -1)

    def forward(self, inputs):
        """Run fc1 -> ELU -> dropout -> fc2 -> ELU -> batch norm on a 3-D input."""
        hidden = func.elu(self.fc1(inputs))
        hidden = func.dropout(hidden, self.dropout_prob, training=self.training)
        hidden = func.elu(self.fc2(hidden))
        return self.batch_norm(hidden)
class GraphEncoder(nn.Module):
    """MLP-based graph encoder that maps node features to per-edge outputs.

    Args:
        n_in: flattened feature size of each node.
        n_hid: hidden width shared by all internal MLPs.
        n_out: size of the per-edge output produced by ``fc_out``.
        do_prob: dropout probability forwarded to every MLP.
        factor: when True, do an extra edge->node->edge pass before the
            final MLP; when False, stay in edge space throughout.
    """

    def __init__(self, n_in, n_hid, n_out=4, do_prob=0., factor=True):
        super().__init__()
        self.factor = factor
        self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob)
        self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob)
        self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob)
        # The factor variant concatenates receiver+sender features (2*n_hid)
        # with the skip connection (n_hid); the plain variant concatenates
        # mlp3's output (n_hid) with the skip connection (n_hid).
        mlp4_inputs = n_hid * 3 if self.factor else n_hid * 2
        self.mlp4 = MLP(mlp4_inputs, n_hid, n_hid, do_prob)
        if self.factor:
            print("Using factor graph MLP encoder.")
        else:
            print("mlp4", self.mlp4)
            print("Using MLP graph encoder.")
        self.fc_out = nn.Linear(n_hid, n_out)
        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights and 0.1 biases for every linear layer."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight.data)
            layer.bias.data.fill_(0.1)

    def edge2node(self, x, rel_rec, rel_send):
        """Sum edge features into their receiver node and divide by the node count.

        NOTE: assumes the same graph across all samples.
        """
        gathered = torch.matmul(rel_rec.t(), x)
        return gathered / gathered.size(1)

    def node2edge(self, x, rel_rec, rel_send):
        """Build edge features by concatenating receiver and sender node features.

        NOTE: assumes the same graph across all samples.
        """
        return torch.cat(
            [torch.matmul(rel_rec, x), torch.matmul(rel_send, x)], dim=2)

    def forward(self, inputs, rel_rec, rel_send):
        """Encode node features into per-edge outputs of size ``n_out``.

        Everything after the second dimension of ``inputs`` is flattened
        into a single per-node feature vector before the first MLP.
        """
        nodes = inputs.view(inputs.size(0), inputs.size(1), -1)
        nodes = self.mlp1(nodes)  # 2-layer ELU net per node
        edges = self.mlp2(self.node2edge(nodes, rel_rec, rel_send))
        skip = edges
        if self.factor:
            refined = self.mlp3(self.edge2node(edges, rel_rec, rel_send))
            refined = self.node2edge(refined, rel_rec, rel_send)
        else:
            refined = self.mlp3(edges)
        merged = torch.cat((refined, skip), dim=2)  # Skip connection
        return self.fc_out(self.mlp4(merged))
# Part 2 of the code:
class GraphDecoder(nn.Module):
    """Message-passing decoder with one message MLP per edge type.

    Args:
        n_in_node: number of features per node.
        edge_types: number of edge types (one message MLP each).
        msg_hid: hidden width of each message MLP.
        msg_out: output width of each message MLP.
        n_hid: hidden width of the node-update MLP.
        do_prob: dropout probability (only active in training mode).
        skip_first: when True, edge type 0 sends no messages.
    """

    def __init__(self, n_in_node, edge_types, msg_hid, msg_out, n_hid,
                 do_prob=0., skip_first=False):
        super(GraphDecoder, self).__init__()
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * n_in_node, msg_hid) for _ in range(edge_types)])
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(msg_hid, msg_out) for _ in range(edge_types)])
        self.msg_out_shape = msg_out
        self.skip_first_edge_type = skip_first

        self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, n_in_node)
        print('Using learned graph decoder.')
        self.dropout_prob = do_prob

    def _dropout(self, x):
        """Dropout that honours train/eval mode (F.dropout defaults to training=True)."""
        return func.dropout(x, p=self.dropout_prob, training=self.training)

    def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
                            single_timestep_rel_type):
        """Predict the next node state from the current one.

        Args:
            single_timestep_inputs: node features, indexed here as
                [batch, node, feature].
            rel_rec: [edge, node] receiver selection matrix.
            rel_send: [edge, node] sender selection matrix.
            single_timestep_rel_type: per-edge weights; channel ``i`` of the
                last dimension scales the messages of edge type ``i``.

        Returns:
            Tensor shaped like ``single_timestep_inputs``: the input plus
            the predicted difference.
        """
        # Node2edge
        receivers = torch.matmul(rel_rec, single_timestep_inputs)
        senders = torch.matmul(rel_send, single_timestep_inputs)
        pre_msg = torch.cat([receivers, senders], dim=-1)

        # Allocate the accumulator on the same device/dtype as the inputs
        # instead of creating it on CPU and forcing it onto CUDA.
        all_msgs = pre_msg.new_zeros(
            pre_msg.size(0), pre_msg.size(1), self.msg_out_shape)

        start_idx = 1 if self.skip_first_edge_type else 0
        # Run a separate MLP for every edge type.
        # NOTE: To exclude one edge type, simply offset the range by 1.
        for i in range(start_idx, len(self.msg_fc2)):
            msg = func.relu(self.msg_fc1[i](pre_msg))
            msg = self._dropout(msg)
            msg = func.relu(self.msg_fc2[i](msg))
            all_msgs = all_msgs + msg * single_timestep_rel_type[:, :, i:i + 1]

        # Aggregate all msgs to receiver
        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()
        # Skip connection
        aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)
        # Output MLP
        pred = self._dropout(func.relu(self.out_fc1(aug_inputs)))
        pred = self._dropout(func.relu(self.out_fc2(pred)))
        pred = self.out_fc3(pred)
        # Predict position/velocity difference
        return single_timestep_inputs + pred

    def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
        """Run a single prediction step over ``inputs``.

        NOTE: assumes the same graph across all samples. ``pred_steps`` is
        accepted for interface compatibility but is not used: exactly one
        step is predicted.
        """
        return self.single_step_forward(inputs, rel_rec, rel_send, rel_type)
    
def my_softmax(input, axis=1):
    """Return the softmax of ``input`` along dimension ``axis``.

    Delegates directly to ``torch.nn.functional.softmax`` so the result is
    a fresh contiguous tensor rather than a transposed view.
    """
    return func.softmax(input, dim=axis)

# Part 3 of the code:
class KeypointPipeline(nn.Module):
    def __init__(self, weights_path):
        super(KeypointPipeline, self).__init__()  
        self.keypoint_model = torch.load(weights_path).to(device)
        self.encoder = GraphEncoder(4,512,4,0.5,False)
        self.decoder = GraphDecoder(n_in_node=4,
                                 edge_types=2,
                                 msg_hid=512,
                                 msg_out=512,
                                 n_hid=512,
                                 do_prob=0.5,
                                 skip_first=False)
        
        self.off_diag = np.ones([6,6]) - np.eye(6)
        self.rel_rec = np.array(encode_onehot(np.where(self.off_diag)[1]), dtype=np.float32)
        self.rel_send = np.array(encode_onehot(np.where(self.off_diag)[0]), dtype=np.float32)
        self.rel_rec = torch.FloatTensor(self.rel_rec)
        self.rel_send = torch.FloatTensor(self.rel_send)
self.rel_rec = np.array(encode_onehot(np.where(self.off_diag)[1]), dtype=np.float32)
        self.rel_send = np.array(encode_onehot(np.where(self.off_diag)[0]), dtype=np.float32)
        self.rel_rec = torch.FloatTensor(self.rel_rec).to(device)
        self.rel_send = torch.FloatTensor(self.rel_send).to(device)
        self.encoder= self.encoder.cuda()
        self.decoder = self.decoder.cuda()
        self.rel_rec = self.rel_rec.cuda()
        self.rel_send = self.rel_send.cuda()    
    def process_model_output(self, output):
        scores = output[0]['scores'].detach().cpu().numpy()
        high_scores_idxs = np.where(scores > 0.7)[0].tolist()
        post_nms_idxs = torchvision.ops.nms(output[0]['boxes'][high_scores_idxs], 
                                            output[0]['scores'][high_scores_idxs], 0.3).cpu().numpy()
        confidence = output[0]['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        labels = output[0]['labels'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()        keypoints = []
        for idx, kps in enumerate(output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()):
            keypoints.append(list(map(int, kps[0,0:2])) + [confidence[idx]] + [labels[idx]])        
        # Sort keypoints based on label
        keypoints.sort(key=lambda x: x[-1])
        return keypoints    
    def keypoints_to_graph(self, keypoints, image_width, image_height):
        # keypoints is expected to be a tensor with shape (num_keypoints, 4),
        # where each keypoint is (x, y, score, label).
        # Convert all elements in keypoints to tensors if they are not already
        keypoints = [torch.tensor(kp, dtype=torch.float32).to(device) if not isinstance(kp, torch.Tensor) else kp for kp in keypoints]
        # Then stack them
        keypoints = torch.stack(keypoints).to(device)       
        # Remove duplicates: Only keep the keypoint with the highest score for each label
        unique_labels, best_keypoint_indices = torch.unique(keypoints[:, 3], return_inverse=True)
        best_scores, best_indices = torch.max(keypoints[:, 2].unsqueeze(0) * (best_keypoint_indices == torch.arange(len(unique_labels)).unsqueeze(1).cuda()), dim=1)
        keypoints = keypoints[best_indices]        
        print("init keypoints in graph features", keypoints)
        # Normalize x and y to be in the range [-1, 1]
        keypoints[:, 0] = (keypoints[:, 0] - image_width / 2) / (image_width / 2)
        keypoints[:, 1] = (keypoints[:, 1] - image_height / 2) / (image_height / 2)
        # Use only x, y, and score for the graph features
        graph_features = keypoints[:, :4]  # Now shape is (num_keypoints, 3)        
        # Ensure the shape is [num_keypoints, 3] before returning
        graph_features = graph_features.view(-1, 4)  # Reshape to ensure it's [num_keypoints, 3]
#         print("graph features", graph_features)
        print("graph features shape", graph_features.shape)
        return graph_features   
Part4 Continued from KeypointPipeline
    def forward(self, imgs):
        # Temporarily set the keypoint model to evaluation mode
        keypoint_model_training = self.keypoint_model.training
        self.keypoint_model.eval()
        # Process each image in the batch
        with torch.no_grad():
            batch_outputs = [self.keypoint_model(img.unsqueeze(0).to(device)) for img in imgs]
        # Set the keypoint model back to its original training mode
        self.keypoint_model.train(mode=keypoint_model_training)
        # Process model outputs to get labeled keypoints
        batch_labeled_keypoints = [self.process_model_output(output) for output in batch_outputs]
        # Generate graph input tensor for each image and handle varying number of keypoints
        batch_x = []
        for labeled_keypoints in batch_labeled_keypoints:
            keypoints = self.keypoints_to_graph(labeled_keypoints, 640, 480)
            # Initialize x with zeros for 6 nodes with 4 features each
            x = torch.zeros(1, 6, 4, device=device)
            # Ensure that keypoints are on the correct device and fill in x
            num_keypoints_detected = keypoints.size(0)
            if num_keypoints_detected <= 6:
                x[0, :num_keypoints_detected, :] = keypoints
            else:
                raise ValueError("Number of keypoints detected exceeds the maximum of 6.")
            batch_x.append(x)
        # Stack the batch of x tensors for batch processing
        batch_x = torch.cat(batch_x, dim=0)
        # Forward pass through the encoder and decoder
        logits = self.encoder(batch_x, self.rel_rec, self.rel_send)
        edges = my_softmax(logits, -1)
        KGNN2D = self.decoder(batch_x, edges, self.rel_rec, self.rel_send)
        return logits, KGNN2D, batch_labeled_keypoints
    
def loss_edges(valid_points, edges):
    """Cross-entropy between edge logits and keypoint co-validity labels.

    An edge (i, j), i != j, is labelled 1 when both ``valid_points[i]`` and
    ``valid_points[j]`` are non-zero and 0 otherwise. Edges are ordered
    row-major over the off-diagonal entries of the node-by-node matrix.

    Args:
        valid_points: ``[batch, num_nodes]`` (or ``[num_nodes]``) 0/1 mask.
        edges: edge logits; reshaped to ``[-1, 2]`` for the two classes.
            When that yields a whole multiple of the number of edges (e.g.
            4 channels per edge), each edge label is repeated accordingly.

    Returns:
        Scalar cross-entropy loss on the device of ``edges``.

    Raises:
        ValueError: if the logits cannot be matched to the edge labels.
    """
    if valid_points.ndim == 1:
        valid_points = valid_points.unsqueeze(0)  # Reshape to 2D if necessary
    valid_points = valid_points.to(edges.device)
    num_nodes = valid_points.shape[1]

    # Outer product per sample: entry (i, j) is valid_i * valid_j.
    pair_valid = valid_points.unsqueeze(2) * valid_points.unsqueeze(1)
    off_diag = ~torch.eye(num_nodes, dtype=torch.bool, device=edges.device)
    relations = pair_valid[:, off_diag].long().reshape(-1)

    logits = edges.reshape(-1, 2)
    if relations.numel() == 0 or logits.shape[0] % relations.numel() != 0:
        raise ValueError(
            "edges with shape {} cannot be matched to {} edge labels".format(
                tuple(edges.shape), relations.numel()))
    repeats = logits.shape[0] // relations.numel()
    targets = relations.repeat_interleave(repeats)
    return func.cross_entropy(logits, targets)
def nll_gaussian(preds, target, variance, add_const=False):
    """Gaussian negative log-likelihood with a fixed variance.

    Sums ``(preds - target)**2 / (2 * variance)`` over all elements
    (plus ``0.5 * log(2 * pi * variance)`` per element when ``add_const``
    is set) and divides by the product of the first two dimensions of
    ``target``.
    """
    squared_error = (preds - target) ** 2
    neg_log_p = squared_error / (2 * variance)
    if add_const:
        neg_log_p += 0.5 * np.log(2 * np.pi * variance)
    normaliser = target.size(0) * target.size(1)
    return neg_log_p.sum() / normaliser
def kgnn2d_loss(keypoints_gt, valid_points, keypoints_logits):
    """Masked Gaussian NLL between ground-truth and predicted (x, y).

    Both keypoint tensors are multiplied by ``valid_points`` (broadcast
    over the feature dimension) so that masked-out keypoints contribute
    zero error; only the first two features are compared, with a fixed
    variance of 0.1.

    All tensors are cast to float32 once, on the device of
    ``keypoints_logits``, instead of bouncing between CPU and CUDA.
    """
    target_device = keypoints_logits.device
    mask = valid_points.to(device=target_device, dtype=torch.float32).unsqueeze(2)
    masked_gt = keypoints_gt.to(device=target_device, dtype=torch.float32) * mask
    masked_pred = keypoints_logits.to(dtype=torch.float32) * mask
    loss_occ = nll_gaussian(masked_gt[:, :, 0:2], masked_pred[:, :, 0:2], 0.1)
    return loss_occ
def encode_onehot(labels):
    """One-hot encode ``labels`` as an int32 array.

    Columns follow the sorted order of the distinct labels, so the
    encoding is deterministic (it no longer depends on ``set`` iteration
    order).

    Returns:
        ``[len(labels), num_distinct_labels]`` int32 array.
    """
    classes, inverse = np.unique(np.asarray(labels), return_inverse=True)
    return np.eye(len(classes), dtype=np.int32)[inverse.ravel()]

# Part 5 of the code:
def process_batch_keypoints(target_dicts):
    """Split ground-truth keypoints into coordinates and visibility masks.

    Args:
        target_dicts: list of dicts whose ``'keypoints'`` tensor becomes
            ``[num_keypoints, 3]`` (x, y, visibility) after ``squeeze(1)``.

    Returns:
        Tuple ``(keypoints_gt, valid_vis_all, valid_invis_all)``, all on
        the module-level ``device``:
        ``keypoints_gt`` is float ``[batch, num_keypoints, 2]``;
        ``valid_vis_all`` is a long mask of visibility == 1;
        ``valid_invis_all`` is a long mask of visibility == 0.
    """
    keypoints_list = []
    visibilities_list = []
    for target in target_dicts:
        keypoints = target['keypoints'].squeeze(1).to(device)
        keypoints_list.append(keypoints[:, :2])   # x, y coordinates
        visibilities_list.append(keypoints[:, 2])  # visibility flags
    # Tensors already live on `device`, so no extra .cuda() is needed.
    keypoints_gt = torch.stack(keypoints_list).float()
    visibilities = torch.stack(visibilities_list)
    valid_vis_all = (visibilities == 1).long()
    valid_invis_all = (visibilities == 0).long()
    return keypoints_gt, valid_vis_all, valid_invis_all
def reorder_batch_keypoints(batch_keypoints):
    """Place each keypoint's (x, y) in the row given by its rounded label.

    ``batch_keypoints`` is ``[batch, num_keypoints, num_features]`` with
    the label stored in the last feature. For every label 1..6, row
    ``label - 1`` of the result receives the (x, y) of the first keypoint
    carrying that label. When no keypoint has the label, the next unused
    keypoint whose label lies outside 1..6 is used instead; if none is
    left the row stays zero.

    Returns:
        Tensor of shape ``[batch, num_keypoints, 2]``.
    """
    batch_size, num_keypoints, num_features = batch_keypoints.shape
    reordered_batch = []
    for sample in batch_keypoints:
        labels = torch.round(sample[:, -1]).int().tolist()
        ordered = torch.zeros(num_keypoints, 2, device=batch_keypoints.device)
        # Rows with out-of-range labels, consumed in order as fallbacks.
        spare_rows = iter(
            [row for row, lab in enumerate(labels) if lab < 1 or lab > 6])
        for label in range(1, 7):
            if label in labels:
                source_row = labels.index(label)
            else:
                source_row = next(spare_rows, None)
            if source_row is not None:
                ordered[label - 1] = sample[source_row, :2]
        reordered_batch.append(ordered)
    return torch.stack(reordered_batch)
def denormalize_keypoints(batch_keypoints, width=640, height=480):
    """Map (x, y) from the normalised range back to pixel coordinates.

    Each sample's first two columns are scaled by half the image size and
    shifted by half the image size, the inverse of
    ``(v - size / 2) / (size / 2)``.

    Returns:
        Tensor of shape ``[batch, num_keypoints, 2]``.
    """
    half_width, half_height = width / 2, height / 2
    restored = [
        torch.stack(
            (sample[:, 0] * half_width + half_width,
             sample[:, 1] * half_height + half_height),
            dim=1)
        for sample in batch_keypoints
    ]
    return torch.stack(restored)

# Part 6 of the code:
# Training script: builds the pipeline, the datasets/loaders and runs the
# optimisation loop. `weights_path`, `root_dir`, `device`, `train_test_split`,
# `KPDataset`, `DataLoader` and `collate_fn` are not defined in this part of
# the file and must already be in scope.

# Define the model
model = KeypointPipeline(weights_path)
model = model.to(device)
# Define the optimizer
# NOTE(review): model.parameters() presumably also yields the keypoint
# detector's parameters; the detector is only run under torch.no_grad() in
# KeypointPipeline.forward, so no gradients reach it from this loop.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 2 # Define your number of epochs
batch_size = 2
split_folder_path = train_test_split(root_dir)
KEYPOINTS_FOLDER_TRAIN = split_folder_path +"/train" #train_test_split(root_dir) +"/train"
KEYPOINTS_FOLDER_VAL = split_folder_path +"/val"
KEYPOINTS_FOLDER_TEST = split_folder_path +"/test"
dataset_train = KPDataset(KEYPOINTS_FOLDER_TRAIN, transform=None, demo=False)
dataset_val = KPDataset(KEYPOINTS_FOLDER_VAL, transform=None, demo=False)
dataset_test = KPDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)
data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# The validation and test loaders are created here but not used by the loop below.
data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, collate_fn=collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)
# NOTE(review): `v` is not referenced anywhere in the visible code.
v = 1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for imgs, target_dicts, _ in data_loader_train:
        imgs = [img.to(device) for img in imgs]
        optimizer.zero_grad()
        # Forward pass for batch
        logits, KGNN2D, batch_labeled_keypoints = model(imgs)
        # Process keypoints for the entire batch
        keypoints_gt, valid_vis_all, valid_invis_all = process_batch_keypoints(target_dicts)           
        # Reorder the decoder output by its (rounded) predicted label; only
        # the (x, y) columns are kept by reorder_batch_keypoints.
        reordered_normalized_keypoints = reorder_batch_keypoints(KGNN2D)
        # Denormalize the reordered keypoints for the entire batch
        denormalized_keypoints = denormalize_keypoints(reordered_normalized_keypoints)        
        # NOTE(review): the position loss is masked with valid_invis_all
        # (visibility == 0), so only keypoints flagged as not visible
        # contribute to it -- confirm this is intended.
        loss_kgnn2d = kgnn2d_loss(keypoints_gt, valid_invis_all, denormalized_keypoints)
        # Edge loss: edge labels are derived from valid_vis_all (visibility == 1).
        edge_loss = loss_edges(valid_vis_all, logits)
        # Combine the losses
        total_batch_loss = edge_loss + loss_kgnn2d
        total_batch_loss.backward()
        optimizer.step()
        total_loss += total_batch_loss.item()
    # Reports the summed batch losses divided by len(data_loader_train)
    # (presumably the number of batches -- TODO confirm for this DataLoader).
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(data_loader_train)}')

# In [None]:
class KeypointPipeline(nn.Module):
    def __init__(self, weights_path):
        super(KeypointPipeline, self).__init__()  
        self.keypoint_model = torch.load(weights_path).to(device)
        # probably the autoencoder model will go in here 
       
    def process_model_output(self, output):
        scores = output[0]['scores'].detach().cpu().numpy()
        high_scores_idxs = np.where(scores > 0.7)[0].tolist()
        post_nms_idxs = torchvision.ops.nms(output[0]['boxes'][high_scores_idxs], 
                                            output[0]['scores'][high_scores_idxs], 0.3).cpu().numpy()
        confidence = output[0]['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        labels = output[0]['labels'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()        keypoints = []
        for idx, kps in enumerate(output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()):
            keypoints.append(list(map(int, kps[0,0:2])) + [confidence[idx]] + [labels[idx]])        
        # Sort keypoints based on label
        keypoints.sort(key=lambda x: x[-1])
        return keypoints    
    def keypoints_to_graph(self, keypoints, image_width, image_height):
        # keypoints is expected to be a tensor with shape (num_keypoints, 4),
        # where each keypoint is (x, y, score, label).
        # Convert all elements in keypoints to tensors if they are not already
        keypoints = [torch.tensor(kp, dtype=torch.float32).to(device) if not isinstance(kp, torch.Tensor) else kp for kp in keypoints]
        # Then stack them
        keypoints = torch.stack(keypoints).to(device)       
        # Remove duplicates: Only keep the keypoint with the highest score for each label
        unique_labels, best_keypoint_indices = torch.unique(keypoints[:, 3], return_inverse=True)
        best_scores, best_indices = torch.max(keypoints[:, 2].unsqueeze(0) * (best_keypoint_indices == torch.arange(len(unique_labels)).unsqueeze(1).cuda()), dim=1)
        keypoints = keypoints[best_indices]        
        print("init keypoints in graph features", keypoints)
        # Normalize x and y to be in the range [-1, 1]
        keypoints[:, 0] = (keypoints[:, 0] - image_width / 2) / (image_width / 2)
        keypoints[:, 1] = (keypoints[:, 1] - image_height / 2) / (image_height / 2)
        # Use only x, y, and score for the graph features
        graph_features = keypoints[:, :4]  # Now shape is (num_keypoints, 3)        
        # Ensure the shape is [num_keypoints, 3] before returning
        graph_features = graph_features.view(-1, 4)  # Reshape to ensure it's [num_keypoints, 3]
#         print("graph features", graph_features)
        print("graph features shape", graph_features.shape)
        return graph_features   
    def forward(self, imgs):
        # Temporarily set the keypoint model to evaluation mode
        keypoint_model_training = self.keypoint_model.training
        self.keypoint_model.eval()
        # Process each image in the batch
        with torch.no_grad():
            batch_outputs = [self.keypoint_model(img.unsqueeze(0).to(device)) for img in imgs]
        # Set the keypoint model back to its original training mode
        self.keypoint_model.train(mode=keypoint_model_training)
        # Process model outputs to get labeled keypoints
        batch_labeled_keypoints = [self.process_model_output(output) for output in batch_outputs]
        # Generate graph input tensor for each image and handle varying number of keypoints
        batch_x = []
        for labeled_keypoints in batch_labeled_keypoints:
            keypoints = self.keypoints_to_graph(labeled_keypoints, 640, 480)
            # Initialize x with zeros for 6 nodes with 4 features each
            x = torch.zeros(1, 6, 4, device=device)
            # Ensure that keypoints are on the correct device and fill in x
            num_keypoints_detected = keypoints.size(0)
            if num_keypoints_detected <= 6:
                x[0, :num_keypoints_detected, :] = keypoints
            else:
                raise ValueError("Number of keypoints detected exceeds the maximum of 6.")
            batch_x.append(x)
        # Stack the batch of x tensors for batch processing
        batch_x = torch.cat(batch_x, dim=0)
        # Forward pass through the encoder and decoder
        logits = self.encoder(batch_x, self.rel_rec, self.rel_send)
        edges = my_softmax(logits, -1)
        KGNN2D = self.decoder(batch_x, edges, self.rel_rec, self.rel_send)
        return logits, KGNN2D, batch_labeled_keypoints