In [20]:
import torch
from torch.utils.data import Dataset
import json
import os
import torch.nn as nn
import torch.nn.functional as func
from torch.utils.data import DataLoader
import torch.optim as optim
from os.path import expanduser
import splitfolders
import shutil
import glob
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)




cuda


In [21]:
class KpVelDataset(Dataset):
    def __init__(self, json_folder):
        super(KpVelDataset, self).__init__()
        self.data = []
        for json_file in sorted(os.listdir(json_folder)):
            if json_file.endswith('_combined.json'):
                with open(os.path.join(json_folder, json_file), 'r') as file:
                    data = json.load(file)
                    start_kp = data['start_kp']
                    next_kp = data['next_kp']
                    position = data['position']
                    self.data.append((start_kp, next_kp, position))
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        start_kp, next_kp, position = self.data[idx]
        # Ensure start_kp and next_kp have consistent dimensions
        # if not start_kp or not next_kp:
        #     raise ValueError(f"Empty keypoints found at index {idx}")
        start_kp_flat = torch.tensor([kp for sublist in start_kp for kp in sublist[0][:2]], dtype=torch.float)
        next_kp_flat = torch.tensor([kp for sublist in next_kp for kp in sublist[0][:2]], dtype=torch.float)
        position = torch.tensor(position, dtype=torch.float)
        return start_kp_flat, next_kp_flat, position

def train_test_split(src_dir):
    dst_dir_anno = src_dir + "annotations"
    
    if os.path.exists(dst_dir_anno):
        print("folders exist")
    else:
        os.mkdir(dst_dir_anno)

    for jsonfile in glob.iglob(os.path.join(src_dir, "*_combined.json")):
        shutil.copy(jsonfile, dst_dir_anno)
        
    output = root_dir + "split_folder_reg"
    
    splitfolders.ratio(src_dir, # The location of dataset
                   output=output, # The output location
                   seed=42, # The number of seed
                   ratio=(0.8, 0.1, 0.1), # The ratio of split dataset
                   group_prefix=None, # If your dataset contains more than one file like ".jpg", ".pdf", etc
                   move=False # If you choose to move, turn this into True
                   )
    
#     shutil.rmtree(dst_dir_img)
    shutil.rmtree(dst_dir_anno)
    
    return output  



In [22]:
class PosRegModel(nn.Module):
    def __init__(self, input_size):
        super(PosRegModel, self).__init__()
        self.fc1 = nn.Linear(input_size * 2, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(p=0.3)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 3)

    def forward(self, start_kp, next_kp):
        x = torch.cat((start_kp.cuda(), next_kp.cuda()), dim=1)
        x = func.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = func.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = func.relu(self.bn3(self.fc3(x)))
        x = self.fc4(x)
        return x

In [23]:
from torchinfo import summary

model = PosRegModel(input_size=12)
model.eval()

# This is the correct usage for models expecting multiple inputs.
summary(model, input_sizes=[(6, 2), (6, 2)])

Layer (type:depth-idx)                   Param #
PosRegModel                              --
├─Linear: 1-1                            12,800
├─BatchNorm1d: 1-2                       1,024
├─Dropout: 1-3                           --
├─Linear: 1-4                            131,328
├─BatchNorm1d: 1-5                       512
├─Dropout: 1-6                           --
├─Linear: 1-7                            32,896
├─BatchNorm1d: 1-8                       256
├─Linear: 1-9                            387
Total params: 179,203
Trainable params: 179,203
Non-trainable params: 0