In [1]:
import torch
from torch.utils.data import Dataset
import json
import os
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import DataLoader
import torch.optim as optim
from os.path import expanduser
import splitfolders
import shutil
import glob
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd
from torchsummary import summary

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

cuda


In [2]:
class KpVelDataset(Dataset):
    """Dataset of (start keypoints, next keypoints, position) samples read from JSON files.

    Every file in ``json_folder`` whose name ends with ``'_combined.json'`` is loaded
    (in sorted filename order). Each must contain the keys ``'start_kp'``,
    ``'next_kp'`` and ``'position'``. The raw values are stored as-is; conversion to
    tensors and size validation happen lazily in ``__getitem__``.

    Args:
        json_folder: Directory containing the ``*_combined.json`` files.
        expected_numel: Number of values each flattened keypoint tensor must have
            (default 18, the size previously hard-coded).
    """

    def __init__(self, json_folder, expected_numel=18):
        super(KpVelDataset, self).__init__()
        self.expected_numel = expected_numel
        self.data = []
        for json_file in sorted(os.listdir(json_folder)):
            if not json_file.endswith('_combined.json'):
                continue
            # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
            with open(os.path.join(json_folder, json_file), 'r', encoding='utf-8') as file:
                record = json.load(file)
            self.data.append((record['start_kp'], record['next_kp'], record['position']))

    def __len__(self):
        """Return the number of ``*_combined.json`` files that were loaded."""
        return len(self.data)

    @staticmethod
    def _flatten_xy(keypoints):
        """Flatten the first two values of each entry's first element into one float tensor.

        NOTE(review): presumably each entry is ``[[x, y, ...], ...]`` so this keeps
        (x, y) per keypoint -- confirm against the annotation format.
        """
        return torch.tensor(
            [kp for sublist in keypoints for kp in sublist[0][:2]], dtype=torch.float
        )

    def __getitem__(self, idx):
        """Return ``(start_kp_flat, next_kp_flat, position)`` float tensors for sample ``idx``.

        Raises:
            ValueError: if either flattened keypoint tensor does not contain exactly
                ``expected_numel`` elements.
            Any error raised while building the tensors is re-raised after the raw
            keypoints of the offending sample are printed for debugging.
        """
        start_kp, next_kp, position = self.data[idx]
        try:
            start_kp_flat = self._flatten_xy(start_kp)
            next_kp_flat = self._flatten_xy(next_kp)
            position = torch.tensor(position, dtype=torch.float)
            if (start_kp_flat.nelement() != self.expected_numel
                    or next_kp_flat.nelement() != self.expected_numel):
                raise ValueError(f"Invalid number of elements: start_kp {start_kp_flat.nelement()}, next_kp {next_kp_flat.nelement()} at index {idx}")
        except Exception as e:
            print(f"Error processing index {idx}: {e}")
            print(f"Start KP: {start_kp}, Next KP: {next_kp}")
            raise
        return start_kp_flat, next_kp_flat, position

In [3]:
def train_test_split(src_dir, output_root=None, ratio=(0.8, 0.1, 0.1), seed=42):
    """Split the ``*_combined.json`` files in ``src_dir`` into train/val/test folders.

    NOTE: this name shadows ``sklearn.model_selection.train_test_split`` imported in
    the first cell; it is kept unchanged so existing callers keep working.

    The JSON files are copied into a temporary ``<src_dir>/annotations`` subfolder
    (``splitfolders.ratio`` splits each subfolder of its input directory), split into
    ``<output_root>/split_folder_reg``, and the temporary subfolder is removed
    afterwards -- also when copying or splitting fails.

    Args:
        src_dir: Directory containing the ``*_combined.json`` files.
        output_root: Directory under which ``split_folder_reg`` is created. Defaults
            to the notebook-level ``root_dir`` global (the previous behaviour).
        ratio: Train/val/test fractions passed to ``splitfolders.ratio``.
        seed: Random seed passed to ``splitfolders.ratio``.

    Returns:
        Path of the ``split_folder_reg`` output directory.
    """
    if output_root is None:
        # Backward-compatible fallback; `root_dir` must be defined elsewhere in the notebook.
        output_root = root_dir

    # os.path.join works whether or not the caller's path ends with a separator;
    # plain concatenation produced a sibling dir (e.g. "dataannotations") that the
    # split below never looked at.
    dst_dir_anno = os.path.join(src_dir, "annotations")
    output = os.path.join(output_root, "split_folder_reg")

    if os.path.exists(dst_dir_anno):
        print("folders exist")
    else:
        os.mkdir(dst_dir_anno)

    try:
        for jsonfile in glob.iglob(os.path.join(src_dir, "*_combined.json")):
            shutil.copy(jsonfile, dst_dir_anno)

        splitfolders.ratio(
            src_dir,
            output=output,
            seed=seed,
            ratio=ratio,         # train / val / test
            group_prefix=None,   # each sample is a single file
            move=False,          # copy, leaving the source files in place
        )
    finally:
        # Always drop the temporary copies so a failed run doesn't leave them behind.
        shutil.rmtree(dst_dir_anno)

    return output

In [4]:
class PosRegModel(nn.Module):
    """MLP that regresses an output vector from two flattened keypoint vectors.

    ``forward`` concatenates ``start_kp`` and ``next_kp`` along dim 1 and applies
    three Linear -> BatchNorm1d -> ReLU stages (512, 256, 128 units, with dropout
    after the first two) followed by a Linear output head.

    Attribute names (``fc1`` ... ``fc4``, ``bn1`` ... ``bn3``, ``dropout1/2``) and their
    registration order are part of the ``state_dict`` layout; keep them stable so
    saved checkpoints still load.

    Args:
        input_size: Length of each keypoint vector; the first layer receives
            ``2 * input_size`` features.
        output_size: Number of regressed outputs (default 3).
        dropout: Dropout probability after the first two hidden stages (default 0.3).
    """

    def __init__(self, input_size, output_size=3, dropout=0.3):
        super(PosRegModel, self).__init__()
        self.fc1 = nn.Linear(input_size * 2, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(p=dropout)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, output_size)

    def forward(self, start_kp, next_kp):
        """Predict outputs from a pair of keypoint batches.

        Args:
            start_kp: Tensor expected to be ``(batch, input_size)``.
            next_kp: Tensor expected to be ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        x = torch.cat((start_kp, next_kp), dim=1)
        x = func.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = func.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = func.relu(self.bn3(self.fc3(x)))
        x = self.fc4(x)
        return x

In [5]:
# Build the regressor for 18-value keypoint vectors and move it to the selected device.
model = PosRegModel(input_size=18)
model = model.to(device)

# forward() takes two tensors (start_kp, next_kp), so summary needs one shape per input.
summary(model, input_size=[(18,), (18,)], device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]          18,944
       BatchNorm1d-2                  [-1, 512]           1,024
           Dropout-3                  [-1, 512]               0
            Linear-4                  [-1, 256]         131,328
       BatchNorm1d-5                  [-1, 256]             512
           Dropout-6                  [-1, 256]               0
            Linear-7                  [-1, 128]          32,896
       BatchNorm1d-8                  [-1, 128]             256
            Linear-9                    [-1, 3]             387
Total params: 185,347
Trainable params: 185,347
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.71
Estimated Total Size (MB): 0.73
-------------------------------------------