In [1]:
import torch
from torch.utils.data import Dataset
import json
import os
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import DataLoader
import torch.optim as optim
from os.path import expanduser
import splitfolders
import shutil
import glob
import numpy as np
from sklearn.model_selection import train_test_split

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.set_per_process_memory_fraction(0.9, 0)
print(device)

cuda


In [2]:
class KpVelDataset(Dataset):
    def __init__(self, json_folder):
        super(KpVelDataset, self).__init__()
        self.data = []
        for json_file in sorted(os.listdir(json_folder)):
            if json_file.endswith('_combined.json'):
                with open(os.path.join(json_folder, json_file), 'r') as file:
                    data = json.load(file)
                    start_kp = data['start_kp']
                    next_kp = data['next_kp']
                    position = data['position']
                    self.data.append((start_kp, next_kp, position))

#     def __init__(self, json_folder):
#         super(KpVelDataset, self).__init__()
#         self.data = []
#         for json_file in sorted(os.listdir(json_folder)):
#             if json_file.endswith('_combined.json'):
#                 with open(os.path.join(json_folder, json_file), 'r') as file:
#                     data = json.load(file)
#                     # Ensure data contains 'start_kp', 'next_kp', and 'position'
#                     if 'start_kp' in data and 'next_kp' in data and 'position' in data:
#                         start_kp = data['start_kp']
#                         next_kp = data['next_kp']
#                         position = data['position']
#                         # Only append if start_kp and next_kp are not empty
#                         if start_kp and next_kp:
#                             self.data.append((start_kp, next_kp, position))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        start_kp, next_kp, position = self.data[idx]
        # Ensure start_kp and next_kp have consistent dimensions
#         if not start_kp or not next_kp:
#             raise ValueError(f"Empty keypoints found at index {idx}")
        start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0][:2]], dtype=torch.float)
        next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0][:2]], dtype=torch.float)
        position = torch.tensor(position, dtype=torch.float)
        return start_kp_flat, next_kp_flat, position

#     def __len__(self):
#         return len(self.data)

# #     def __getitem__(self, idx):
# #         start_kp, next_kp, velocity = self.data[idx]
# #         start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0]])
# #         next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0]])
# #         velocity = torch.tensor(velocity)
# #         return start_kp_flat, next_kp_flat, velocity
  
#     def __getitem__(self, idx):
#         start_kp, next_kp, position = self.data[idx]
#         # Extract and flatten the first two elements of each keypoint in start_kp
#         start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0][:2]], dtype=torch.float)
#         # Extract and flatten the first two elements of each keypoint in next_kp
#         next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0][:2]], dtype=torch.float)
#         position = torch.tensor(position)
#         return start_kp_flat, next_kp_flat, position
    

In [3]:
def train_test_split(src_dir):
#     dst_dir_img = src_dir + "images"
    dst_dir_anno = src_dir + "annotations"
    
    if os.path.exists(dst_dir_anno):
        print("folders exist")
    else:
        os.mkdir(dst_dir_anno)
        
#     for jpgfile in glob.iglob(os.path.join(src_dir, "*.jpg")):
#         shutil.copy(jpgfile, dst_dir_img)

    for jsonfile in glob.iglob(os.path.join(src_dir, "*_combined.json")):
        shutil.copy(jsonfile, dst_dir_anno)
        
    output = root_dir + "split_folder_reg"
    
    splitfolders.ratio(src_dir, # The location of dataset
                   output=output, # The output location
                   seed=42, # The number of seed
                   ratio=(0.8, 0.1, 0.1), # The ratio of split dataset
                   group_prefix=None, # If your dataset contains more than one file like ".jpg", ".pdf", etc
                   move=False # If you choose to move, turn this into True
                   )
    
#     shutil.rmtree(dst_dir_img)
    shutil.rmtree(dst_dir_anno)
    
    return output  

In [4]:
class PosRegModel(nn.Module):
    def __init__(self, input_size):
        super(PosRegModel, self).__init__()
        self.fc1 = nn.Linear(input_size * 2, 1024)  # Assuming start_kp and next_kp are concatenated
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512,256)
        self.fc4 = nn.Linear(256,128)
        self.fc5 = nn.Linear(128,64)
        self.fc6 = nn.Linear(64,64)
        self.fc7 = nn.Linear(64,3)  # Output size is 3 for velocity

    def forward(self, start_kp, next_kp):
        x = torch.cat((start_kp.to(device), next_kp.to(device)), dim=1)
        x = func.relu(self.fc1(x))
        x = func.relu(self.fc2(x))
        x = func.relu(self.fc3(x))
        x = func.relu(self.fc4(x))
        x = func.relu(self.fc5(x))
        x = func.relu(self.fc6(x))
        x = self.fc7(x)
        return x
    

In [17]:
# Initialize dataset and data loader
# to generalize home directory. User can change their parent path without entering their home directory
num_epochs = 300
batch_size = 64
v = 1
root_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/regression_combined_test_new/'
print(root_dir)
split_folder_path = train_test_split(root_dir)
dataset = KpVelDataset(root_dir)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model
model = PosRegModel(12)  # Adjust input_size as necessary
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Training loop
for epoch in range(num_epochs):
    for start_kp, next_kp, position in data_loader:
        optimizer.zero_grad()
        position = position.squeeze(1)
        print(position)
        print(start_kp.shape)
        print(position.shape)
        output = model(start_kp, next_kp)
        loss = criterion(output, position.to(device))
        loss.backward()
        optimizer.step()
        print("output", output)
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')
    
# # Save the trained model
model_save_path = f'/home/jc-merlab/Pictures/Data/trained_models/reg_pos_b{batch_size}_e{num_epochs}_v{v}.pth'
torch.save(model.state_dict(), model_save_path)

# model_save_path = f'/home/jc-merlab/Pictures/Data/trained_models/reg_nkp_b{batch_size}_e{num_epochs}_v{v}.pth'
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'model_structure': KeypointRegressionNet()
# }, model_save_path)
# print(f"Model saved to {model_save_path}")


/home/jc-merlab/Pictures/panda_data/panda_sim_vel/regression_combined_test_new/


Copying files: 12428 files [00:00, 21412.15 files/s]


tensor([[ 0.0000, -0.0032,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [-0.0098,  0.0000,  0.0000],
        [ 0.0000,  0.0709,  0.5500],
        [-0.4024, -0.1385, -0.2000],
        [ 0.0000, -0.0419, -0.3400],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [-0.0098,  0.0000,  0.0000],
        [ 0.0000,  0.0016,  0.0000],
        [ 0.2879,  0.0236,  0.3500],
        [ 0.0000, -0.0016,  0.0000],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000, -0.0043,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.3730,  0.1514,  0.3600],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000, -0.0021,  0.0000],
        [ 0.0065,  0.0000,  0.0000],
        [ 0.0000,  0.0032,  0.0000],
        [ 0.0000, -0.0021,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
 

output tensor([[ 3.1522e-02, -1.3716e-02, -4.6788e-03],
        [ 1.1892e-02, -2.8638e-03,  3.2473e-03],
        [ 3.0291e-02, -1.4094e-02, -4.3356e-03],
        [ 3.0887e-02, -1.4706e-02, -2.6313e-03],
        [ 3.0474e-02, -1.0721e-02,  1.8433e-04],
        [ 2.5711e-02, -4.4835e-03,  4.2173e-03],
        [ 1.9965e-02, -4.9533e-03,  8.8155e-05],
        [ 2.2647e-02, -2.5438e-03,  4.4529e-03],
        [ 2.0134e-02, -3.5742e-03,  2.9258e-03],
        [ 2.0928e-02, -3.2394e-03,  2.5421e-03],
        [ 3.1788e-02, -1.4403e-02, -2.0833e-03],
        [ 1.1818e-02, -2.5361e-03,  3.8691e-03],
        [ 1.3715e-02, -4.6640e-03,  2.4978e-03],
        [ 2.3386e-02, -1.1075e-02, -1.1308e-02],
        [ 2.9966e-02, -1.4131e-02, -1.0034e-02],
        [ 1.8077e-02, -8.7129e-03, -1.2639e-03],
        [ 1.1871e-02, -2.3445e-03,  4.1836e-03],
        [ 3.1307e-02, -1.4204e-02, -8.9516e-03],
        [ 3.0267e-02, -1.4234e-02, -4.8118e-03],
        [ 3.2034e-02, -1.3497e-02, -2.6511e-03],
        [ 3.0

tensor([[ 0.0000,  0.0000,  0.0100],
        [ 0.0049,  0.0000,  0.0000],
        [ 0.0000, -0.0548, -0.3600],
        [ 0.0000,  0.0021,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.1369, -0.0300],
        [ 0.0000,  0.0021,  0.0000],
        [ 0.0000,  0.0000, -0.1100],
        [ 0.0065,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [-0.2846, -0.1417, -0.8800],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000, -0.0016,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0016,  0.0000],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.1578,  0.0400],
        [ 0.0000,  0.0215,  0.0000],
        [ 0.0000, -0.0032,  0.0000],
        [ 0.0000, -0.0016,  0.0000],
 

output tensor([[-1.4796e-02,  4.5009e-03,  1.1221e-02],
        [-1.9642e-02,  5.6453e-03,  1.1449e-02],
        [-2.4117e-03,  7.0441e-03,  6.9267e-03],
        [ 5.1371e-03,  1.0635e-02,  1.1489e-03],
        [-8.6521e-03,  8.0781e-03,  1.3554e-02],
        [-6.2103e-03,  5.1205e-03,  6.4685e-03],
        [-1.5471e-02,  5.9623e-03,  1.1854e-02],
        [-2.7092e-03,  7.0331e-03,  7.0296e-03],
        [-8.1179e-03,  8.0640e-03,  1.3759e-02],
        [-4.2774e-03,  9.2338e-03,  9.4015e-03],
        [-1.5197e-03,  7.0741e-03,  6.3767e-03],
        [ 4.0270e-03,  6.9006e-03,  7.3906e-04],
        [-5.9603e-03,  6.7949e-03,  9.1804e-03],
        [ 6.0314e-03,  8.8652e-03,  4.0562e-04],
        [-5.8194e-03,  5.3154e-03,  6.2108e-03],
        [-1.6521e-02,  5.9116e-03,  1.1769e-02],
        [ 7.2024e-03,  8.2959e-03, -5.0513e-03],
        [ 1.9736e-03,  8.2947e-03,  3.5025e-03],
        [ 2.2700e-03,  7.6901e-03,  3.9078e-03],
        [-1.0794e-02,  6.3960e-03,  1.1017e-02],
        [ 7.9

tensor([[ 0.0000,  0.0000,  0.0100],
        [-0.4122, -0.1450,  0.2500],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000, -0.1047,  0.7300],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000,  0.0000,  0.0100],
        [ 0.0000, -0.0016,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000, -0.0789,  0.0500],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0098,  0.0000,  0.0000],
        [ 0.0000,  0.0725,  0.3800],
        [ 0.0000, -0.0032,  0.0000],
        [ 0.0000, -0.0923, -0.1000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000, -0.0016,  0.0000],
        [ 0.0000,  0.0021,  0.0000],
        [ 0.0000,  0.0000,  0.4500],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0016,  0.0000],
        [ 0.0000,  0.0000, -0.0100],
        [ 0.0000,  0.0000,  0.0100],
 

output tensor([[-4.0008e-03, -3.7859e-04,  9.9794e-03],
        [ 7.2759e-03, -1.4667e-03,  7.7342e-03],
        [ 3.7391e-03,  4.1958e-03,  6.3527e-03],
        [ 5.2731e-03,  1.6159e-03,  7.5877e-03],
        [-2.9774e-03, -6.5231e-03,  6.1922e-03],
        [ 2.8467e-03, -4.7119e-03,  6.7459e-03],
        [ 1.1533e-02, -1.0669e-03,  2.8955e-03],
        [-5.2021e-03, -2.1911e-05,  1.0102e-02],
        [-2.8707e-04,  2.7394e-03,  1.1999e-02],
        [-1.0861e-03,  1.4029e-04,  1.0779e-02],
        [ 7.1306e-03,  1.0896e-03,  6.7401e-03],
        [ 8.6769e-03, -1.0130e-02, -2.2085e-03],
        [ 7.0089e-03,  1.7521e-03,  6.4846e-03],
        [ 3.8641e-03, -4.4744e-03,  6.5989e-03],
        [ 7.9344e-03,  6.8203e-03,  2.3344e-03],
        [ 5.2750e-04,  5.9075e-03,  9.9119e-03],
        [ 1.5763e-03, -3.3075e-03,  7.9354e-03],
        [ 3.2486e-03, -4.4864e-03,  5.9304e-03],
        [ 8.2251e-03,  2.8840e-03,  4.6706e-03],
        [ 3.3754e-03, -4.6330e-03,  6.7601e-03],
        [ 3.4

output tensor([[ 4.6144e-02,  6.2033e-03, -2.5152e-02],
        [ 5.0372e-02,  8.3626e-03, -2.7044e-02],
        [ 4.9485e-02,  7.2727e-03, -2.6514e-02],
        [ 4.6652e-02,  1.9789e-03, -2.9586e-02],
        [ 4.6918e-02, -1.5950e-03, -2.8773e-02],
        [ 4.7713e-02,  2.3002e-05, -3.0454e-02],
        [ 4.9258e-02,  2.6369e-03, -2.8486e-02],
        [ 4.7442e-02,  7.4564e-03, -2.5192e-02],
        [ 4.9500e-02, -2.0959e-04, -2.9964e-02],
        [ 4.6596e-02,  2.1340e-03, -2.7288e-02],
        [ 4.8714e-02,  1.8967e-03, -2.8510e-02],
        [ 4.6787e-02,  5.8904e-03, -2.5347e-02],
        [ 5.0286e-02,  1.0757e-02, -2.6515e-02],
        [ 4.5747e-02,  2.0453e-03, -2.6855e-02],
        [ 4.9505e-02,  5.1520e-03, -2.7278e-02],
        [ 4.9034e-02, -1.2372e-03, -2.9618e-02],
        [ 4.5753e-02,  2.0443e-03, -2.6860e-02],
        [ 4.8997e-02, -1.2391e-03, -2.9609e-02],
        [ 4.7000e-02, -1.1558e-03, -2.8927e-02],
        [ 4.6920e-02,  3.4518e-03, -3.2755e-02],
        [ 4.0

output tensor([[-9.3161e-03,  5.4911e-03,  8.2412e-04],
        [ 2.1278e-03,  5.8045e-03, -5.2808e-03],
        [ 7.9355e-04,  5.5631e-03, -4.3511e-03],
        [ 2.3704e-03,  6.1190e-03, -5.6969e-03],
        [-1.0297e-02,  5.0500e-03,  1.4179e-03],
        [ 2.8137e-04,  5.7725e-03, -4.1118e-03],
        [-7.2706e-03,  5.3795e-03, -7.4625e-05],
        [-1.1359e-03,  5.8684e-03, -2.7997e-03],
        [-1.1878e-03,  5.8784e-03, -2.7955e-03],
        [-3.6204e-03,  5.3773e-03, -2.5091e-03],
        [-1.2519e-02,  6.3905e-03,  1.5749e-03],
        [ 7.4681e-04,  5.8608e-03, -4.5214e-03],
        [ 2.7201e-03,  5.8708e-03, -5.7022e-03],
        [-4.5353e-03,  5.6469e-03, -1.8533e-03],
        [-5.8009e-03,  5.9204e-03, -1.2695e-03],
        [-2.7940e-03,  5.2718e-03, -3.0427e-03],
        [-1.8943e-03,  5.5353e-03, -3.8787e-03],
        [-8.5569e-03,  6.3380e-03, -7.4851e-04],
        [-4.5126e-03,  4.6961e-03, -4.6764e-04],
        [-5.2775e-03,  5.0728e-03, -7.6164e-04],
        [-1.4

output tensor([[-1.1485e-02,  3.8828e-03, -2.4390e-02],
        [-4.1155e-03,  4.4698e-03, -2.5828e-02],
        [-7.0797e-03,  5.9512e-03, -2.4858e-02],
        [-8.9112e-03,  4.0890e-03, -2.4906e-02],
        [-7.0693e-03,  4.8127e-03, -2.5015e-02],
        [-4.6942e-03,  3.7784e-03, -2.6155e-02],
        [-4.4198e-03,  4.2317e-03, -2.5949e-02],
        [-4.4217e-03,  4.7284e-03, -2.5635e-02],
        [-3.2282e-03,  4.3380e-03, -2.6148e-02],
        [-1.8421e-02,  1.6911e-03, -2.3795e-02],
        [-9.0393e-03,  4.7505e-03, -2.4466e-02],
        [-1.5015e-02,  2.3777e-03, -2.4330e-02],
        [ 6.8013e-04,  4.1528e-03, -2.7238e-02],
        [-3.3217e-02, -5.5787e-05, -2.0868e-02],
        [-1.6285e-03,  4.2225e-03, -2.6612e-02],
        [-4.6608e-04,  3.6752e-03, -2.7237e-02],
        [-1.0889e-02,  2.6909e-03, -2.5135e-02],
        [-1.7068e-02,  1.9401e-03, -2.4015e-02],
        [-2.2644e-04,  3.6381e-03, -2.7312e-02],
        [-1.0741e-02,  3.9934e-03, -2.4561e-02],
        [-2.7

output tensor([[ 1.5801e-02, -1.4553e-03, -9.9935e-03],
        [ 2.1249e-02, -1.5621e-03, -1.4031e-02],
        [ 1.3322e-02, -6.9609e-04, -7.6185e-03],
        [ 2.3203e-02, -2.6733e-03, -1.6551e-02],
        [ 1.8536e-02,  1.8833e-04, -1.1446e-02],
        [ 1.2019e-02,  2.4376e-03, -5.2836e-03],
        [ 1.4772e-02,  1.5395e-03, -7.6936e-03],
        [ 1.7956e-02, -2.2295e-03, -1.2089e-02],
        [ 2.2263e-02, -2.6825e-03, -1.5665e-02],
        [ 2.1889e-02, -2.8274e-03, -1.5346e-02],
        [ 5.1558e-03,  4.1716e-03,  4.1712e-04],
        [ 1.4031e-02,  4.5577e-04, -7.6117e-03],
        [ 1.1739e-02,  2.6232e-03, -4.9572e-03],
        [ 1.8849e-02, -2.1781e-03, -1.2708e-02],
        [ 5.7976e-03,  2.9458e-03, -1.0018e-03],
        [ 2.2602e-02, -2.6028e-03, -1.6143e-02],
        [ 1.5520e-02, -3.7193e-04, -9.3978e-03],
        [ 1.7304e-02, -5.1605e-04, -1.0837e-02],
        [ 2.1507e-02, -3.5107e-04, -1.4136e-02],
        [ 2.1362e-02, -2.2526e-03, -1.4632e-02],
        [ 2.3

output tensor([[-0.0041, -0.0067, -0.0067],
        [-0.0022, -0.0075, -0.0080],
        [ 0.0131, -0.0030, -0.0122],
        [ 0.0113, -0.0033, -0.0114],
        [ 0.0109, -0.0034, -0.0111],
        [ 0.0090, -0.0032, -0.0096],
        [ 0.0110, -0.0036, -0.0118],
        [ 0.0115, -0.0038, -0.0123],
        [ 0.0100, -0.0043, -0.0119],
        [ 0.0138, -0.0035, -0.0132],
        [ 0.0131, -0.0036, -0.0130],
        [ 0.0069, -0.0057, -0.0119],
        [-0.0039, -0.0073, -0.0070],
        [-0.0044, -0.0068, -0.0065],
        [ 0.0125, -0.0038, -0.0130],
        [-0.0124, -0.0063, -0.0028],
        [ 0.0103, -0.0027, -0.0101],
        [ 0.0100, -0.0038, -0.0114],
        [ 0.0132, -0.0033, -0.0125],
        [ 0.0117, -0.0049, -0.0138],
        [ 0.0010, -0.0055, -0.0082],
        [ 0.0076, -0.0024, -0.0081],
        [ 0.0091, -0.0034, -0.0101],
        [ 0.0140, -0.0033, -0.0132],
        [ 0.0143, -0.0035, -0.0136],
        [ 0.0060, -0.0027, -0.0073],
        [ 0.0080, -0.0038, -0.0

output tensor([[ 1.5211e-02, -3.8708e-05, -2.7926e-02],
        [ 1.9139e-02,  2.6325e-03, -2.5604e-02],
        [ 1.5465e-02,  1.6356e-04, -2.7233e-02],
        [-1.3738e-03, -8.6939e-05, -2.3165e-02],
        [ 1.7164e-02,  4.5495e-03, -2.3483e-02],
        [-3.3721e-03,  2.6024e-03, -2.1096e-02],
        [ 6.9750e-04,  1.6328e-03, -2.2587e-02],
        [ 1.6424e-02,  1.0477e-04, -2.7668e-02],
        [ 8.0999e-03,  7.9060e-05, -2.4864e-02],
        [ 1.4744e-02, -2.2215e-04, -2.7843e-02],
        [ 1.8741e-02,  3.5282e-03, -2.4698e-02],
        [ 8.7424e-03, -1.8510e-04, -2.5250e-02],
        [ 5.7810e-03,  1.0500e-04, -2.4420e-02],
        [ 4.5549e-03, -1.0146e-03, -2.5086e-02],
        [-6.5453e-03,  1.0185e-03, -2.1368e-02],
        [ 1.2016e-02,  2.5023e-03, -2.3697e-02],
        [ 1.8071e-02,  2.3541e-03, -2.5581e-02],
        [-5.0525e-03,  2.3623e-03, -2.0875e-02],
        [-8.7192e-04, -1.6417e-03, -2.3820e-02],
        [-7.6898e-03,  1.6406e-03, -2.0540e-02],
        [-8.5

output tensor([[ 1.2346e-02,  1.7049e-03, -1.0590e-02],
        [ 1.7649e-02, -8.7118e-04, -1.4143e-02],
        [ 1.8199e-02, -2.3321e-03, -1.5011e-02],
        [ 1.9426e-02,  9.9578e-04, -1.3942e-02],
        [ 2.0005e-02,  7.6092e-05, -1.4598e-02],
        [ 1.6661e-02,  1.6375e-03, -1.2344e-02],
        [ 1.4933e-02,  1.3483e-03, -1.1714e-02],
        [ 1.2878e-02,  2.8940e-03, -9.9403e-03],
        [ 1.7733e-02, -1.8646e-03, -1.4613e-02],
        [ 1.6374e-02,  2.2295e-03, -1.1931e-02],
        [ 1.6002e-02,  1.6664e-03, -1.2160e-02],
        [ 1.7809e-02, -7.0401e-05, -1.3811e-02],
        [ 1.3131e-02,  1.4174e-03, -1.0955e-02],
        [ 2.0435e-02,  7.7476e-04, -1.4500e-02],
        [ 1.2322e-02,  3.1473e-03, -9.5532e-03],
        [ 1.1008e-02,  2.9661e-03, -9.1260e-03],
        [ 1.7346e-02,  1.8647e-03, -1.2559e-02],
        [ 1.7277e-02,  1.2885e-03, -1.2791e-02],
        [ 2.0334e-02,  8.0820e-04, -1.4429e-02],
        [ 1.8016e-02,  8.4378e-04, -1.3362e-02],
        [ 1.7

output tensor([[ 4.4163e-03,  8.3085e-04, -4.5403e-03],
        [ 5.2747e-03,  2.6828e-03, -2.1582e-03],
        [ 3.8661e-03,  3.0465e-03, -1.8689e-03],
        [ 2.8978e-04,  3.3523e-03,  7.4648e-04],
        [ 5.7523e-03,  2.0880e-03, -3.1216e-03],
        [ 3.6118e-03,  3.1938e-03, -1.4181e-03],
        [ 4.3684e-03,  2.9214e-03, -1.5827e-03],
        [ 6.8866e-03,  5.1516e-04, -5.2959e-03],
        [ 5.7974e-03,  2.0819e-03, -3.1434e-03],
        [ 5.7747e-03,  2.1926e-03, -3.1157e-03],
        [ 2.2370e-03,  2.5123e-03, -1.0615e-03],
        [-4.3944e-04,  3.3277e-03,  1.1275e-03],
        [ 7.7799e-04,  2.5714e-03, -2.7451e-04],
        [-7.6905e-03,  2.9477e-03,  4.1460e-03],
        [ 5.1888e-03,  3.0592e-03, -1.9764e-03],
        [ 2.3563e-03,  3.5118e-03,  1.9278e-05],
        [ 6.6143e-03,  1.3651e-03, -4.6026e-03],
        [-1.2672e-02,  5.0938e-04,  5.8037e-03],
        [-8.1256e-03,  1.9866e-03,  3.6820e-03],
        [ 4.6574e-03,  2.6887e-03, -2.3502e-03],
        [-5.0

output tensor([[-0.0144, -0.0034, -0.0067],
        [-0.0018, -0.0050, -0.0173],
        [ 0.0112, -0.0069, -0.0220],
        [ 0.0143, -0.0054, -0.0235],
        [ 0.0153, -0.0053, -0.0235],
        [-0.0073, -0.0033, -0.0092],
        [ 0.0059, -0.0058, -0.0185],
        [ 0.0038, -0.0039, -0.0164],
        [ 0.0028, -0.0038, -0.0153],
        [ 0.0122, -0.0042, -0.0215],
        [ 0.0116, -0.0044, -0.0197],
        [ 0.0095, -0.0057, -0.0209],
        [ 0.0099, -0.0039, -0.0190],
        [ 0.0109, -0.0042, -0.0183],
        [ 0.0132, -0.0053, -0.0226],
        [ 0.0088, -0.0042, -0.0211],
        [ 0.0144, -0.0047, -0.0232],
        [ 0.0108, -0.0067, -0.0217],
        [ 0.0097, -0.0050, -0.0173],
        [-0.0072, -0.0033, -0.0096],
        [ 0.0082, -0.0038, -0.0172],
        [ 0.0055, -0.0036, -0.0151],
        [ 0.0114, -0.0064, -0.0221],
        [ 0.0111, -0.0061, -0.0225],
        [ 0.0131, -0.0061, -0.0227],
        [ 0.0068, -0.0057, -0.0188],
        [ 0.0008, -0.0038, -0.0

output tensor([[-2.7267e-03,  8.8168e-03, -3.6276e-03],
        [ 2.5188e-03,  6.8457e-03, -7.8214e-03],
        [-1.1937e-02,  6.8672e-03, -3.9689e-03],
        [ 4.6532e-03,  5.3220e-03, -1.0507e-02],
        [-3.2105e-03,  8.5865e-03, -3.7525e-03],
        [ 2.6594e-03,  7.0956e-03, -7.5082e-03],
        [ 7.8500e-04,  7.2066e-03, -6.9662e-03],
        [ 3.2708e-04,  6.2966e-03, -6.6068e-03],
        [ 6.8375e-04,  6.9005e-03, -8.3370e-03],
        [ 2.2497e-03,  6.9282e-03, -8.8117e-03],
        [-8.1970e-03,  7.0326e-03, -4.9798e-03],
        [ 5.1579e-03,  6.2981e-03, -9.6919e-03],
        [ 1.1173e-03,  7.1733e-03, -7.9820e-03],
        [ 2.4708e-04,  8.1938e-03, -5.5018e-03],
        [ 1.2370e-03,  7.2755e-03, -7.1089e-03],
        [ 4.0551e-03,  6.5326e-03, -9.9626e-03],
        [-2.5494e-04,  7.6011e-03, -6.3374e-03],
        [-1.9089e-03,  7.1635e-03, -6.6229e-03],
        [-2.1737e-04,  8.2667e-03, -5.6577e-03],
        [ 3.3798e-03,  4.6031e-03, -9.6683e-03],
        [ 4.2

output tensor([[ 1.6868e-03,  1.8151e-06,  8.2559e-03],
        [ 1.2031e-03,  1.8737e-03,  1.1205e-02],
        [ 7.5269e-05,  3.6823e-04,  9.4642e-03],
        [-6.3067e-03,  2.6513e-04,  1.3414e-02],
        [-7.8788e-04,  1.0214e-03,  1.1158e-02],
        [-9.3035e-03,  1.2506e-03,  1.7094e-02],
        [-8.9024e-03,  1.3484e-03,  1.6915e-02],
        [-1.2254e-02,  7.8173e-04,  1.8515e-02],
        [-5.0055e-03,  1.5261e-03,  1.4694e-02],
        [-1.0016e-03,  1.9936e-03,  1.2932e-02],
        [ 4.7579e-04, -1.3185e-03,  6.9997e-03],
        [-1.2887e-04,  3.3339e-03,  1.2924e-02],
        [ 7.3012e-04,  1.6653e-03,  1.1288e-02],
        [-1.2502e-02,  7.5016e-04,  1.8934e-02],
        [ 2.7699e-04, -1.2622e-03,  7.2768e-03],
        [-4.9314e-04,  2.8682e-03,  1.3539e-02],
        [ 1.2969e-03,  1.1088e-04,  8.1931e-03],
        [-2.5932e-03,  2.5648e-03,  1.4631e-02],
        [ 8.0056e-04,  2.1148e-04,  9.1328e-03],
        [-1.7696e-02,  6.1627e-04,  2.3758e-02],
        [ 1.3

output tensor([[ 2.4262e-03,  2.5224e-03,  2.9429e-03],
        [ 2.7775e-03,  2.0914e-03,  2.4646e-03],
        [ 1.4671e-03,  2.8406e-03,  3.7799e-03],
        [-1.5141e-02,  3.9748e-04,  1.0987e-02],
        [-1.1001e-03,  4.0975e-03,  5.7768e-03],
        [ 1.9982e-04,  3.5292e-03,  4.9101e-03],
        [-4.5745e-03,  5.2199e-03,  8.4377e-03],
        [ 2.9566e-04,  2.4918e-03,  4.5980e-03],
        [-5.7011e-03,  5.1897e-03,  9.4923e-03],
        [-3.4819e-03,  1.9295e-03,  6.7914e-03],
        [ 1.9341e-04,  4.1109e-03,  4.9061e-03],
        [ 9.2164e-04,  1.9765e-03,  3.6969e-03],
        [ 3.8697e-04,  3.4288e-03,  4.1889e-03],
        [-3.0592e-03,  6.0161e-03,  7.1153e-03],
        [-2.7712e-03,  3.3960e-03,  7.1103e-03],
        [ 8.5924e-05,  3.5263e-03,  4.9939e-03],
        [ 1.4659e-03,  1.9412e-03,  3.4391e-03],
        [-4.7250e-03,  5.1220e-03,  8.6156e-03],
        [-2.8961e-02, -1.3801e-03,  1.8367e-02],
        [-7.5090e-03,  8.4543e-04,  7.2509e-03],
        [-5.2

output tensor([[-0.0174,  0.0050,  0.0112],
        [-0.0230,  0.0041,  0.0121],
        [-0.0231,  0.0075,  0.0160],
        [-0.0158,  0.0050,  0.0101],
        [-0.0172,  0.0052,  0.0106],
        [-0.0218,  0.0073,  0.0152],
        [-0.0187,  0.0062,  0.0128],
        [-0.0178,  0.0006,  0.0047],
        [-0.0177,  0.0041,  0.0099],
        [-0.0158,  0.0056,  0.0109],
        [-0.0372,  0.0039,  0.0198],
        [-0.0169,  0.0044,  0.0094],
        [-0.0167,  0.0038,  0.0084],
        [-0.0136,  0.0031,  0.0066],
        [-0.0229,  0.0041,  0.0121],
        [-0.0133,  0.0046,  0.0085],
        [-0.0181,  0.0027,  0.0079],
        [-0.0227,  0.0074,  0.0157],
        [-0.0113,  0.0056,  0.0083],
        [-0.0227,  0.0041,  0.0120],
        [-0.0147,  0.0036,  0.0074],
        [-0.0198,  0.0068,  0.0137],
        [-0.0139,  0.0037,  0.0071],
        [-0.0172,  0.0044,  0.0095],
        [-0.0264,  0.0079,  0.0177],
        [-0.0182,  0.0065,  0.0128],
        [-0.0263,  0.0036,  0.0

KeyboardInterrupt: 

In [8]:
def test_model(model_path, test_data_dir):
    # Load the test dataset
    test_dataset = KpVelDataset(test_data_dir)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize the model and load the saved state
    model = PosRegModel(12)    
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    model.eval()

    # Criterion for evaluation
    criterion = nn.MSELoss()
    total_loss = 0

    # No gradient needed for evaluation
    with torch.no_grad():
        for start_kp, next_kp, position in test_loader:
            output = model(start_kp.to(device), next_kp.to(device))
            for i in range(start_kp.size(0)):
                individual_start_kp = start_kp[i]
                individual_next_kp = next_kp[i]
                individual_position = position[i]
                predicted_position = output[i]

                print("Start KP:", individual_start_kp)
                print("Next KP:", individual_next_kp)
                print("Actual Position:", individual_position)
                print("Predicted Position:", predicted_position)
                print("-----------------------------------------")
            loss = criterion(output, predicted_position)
            total_loss += loss.item()
            
    # Calculate the average loss
    avg_loss = total_loss / len(test_loader)
    print(f'Average Test Loss: {avg_loss}')

# Usage
model_path = '/home/jc-merlab/Pictures/Data/trained_models/reg_pos_b32_e200_v1.pth'  # Update with your model path
test_data_dir = '/home/jc-merlab/Pictures/panda_data/panda_sim_vel/regression_combined_test_new/split_folder_reg/test/annotations/'  # Update with your test data path
test_model(model_path, test_data_dir)

RuntimeError: Error(s) in loading state_dict for PosRegModel:
	Missing key(s) in state_dict: "fc5.weight", "fc5.bias", "fc6.weight", "fc6.bias", "fc7.weight", "fc7.bias". 
	size mismatch for fc1.weight: copying a param with shape torch.Size([256, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([512, 1024]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for fc3.weight: copying a param with shape torch.Size([64, 128]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for fc3.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for fc4.weight: copying a param with shape torch.Size([3, 64]) from checkpoint, the shape in current model is torch.Size([128, 256]).
	size mismatch for fc4.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([128]).