In [None]:
import os
from os import listdir
import pandas as pd
import numpy as np
import glob
import cv2
import json
from os.path import expanduser
import splitfolders
import shutil
from define_path import Def_Path

from tqdm import tqdm

import torch 
import torchvision
from torchvision import models
from torchvision.models.detection.rpn import AnchorGenerator
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch.nn.functional as func
import torchvision.transforms as T
from torchvision.transforms import functional as F
from torchsummary import summary
from torch.cuda.amp import GradScaler, autocast
from sklearn.model_selection import train_test_split
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

import albumentations as A # Library for augmentations

import matplotlib.pyplot as plt 
from PIL import Image

import transforms, utils, engine, train
from utils import collate_fn
from engine import train_one_epoch, evaluate


# Report GPU memory (in bytes) for CUDA device 0: total, reserved and allocated.
# Guarded so the notebook still runs on machines without a CUDA device, where
# torch.cuda.get_device_properties(0) would raise.
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
    print(t)
    torch.cuda.empty_cache()

    r = torch.cuda.memory_reserved(0)
    print(r)
    a = torch.cuda.memory_allocated(0)
    print(a)
    # f = r-a  # free inside reserved
else:
    print("CUDA not available; skipping GPU memory report")

In [None]:
# Build dataset paths from Def_Path rather than hard-coding a machine-specific prefix.
# path.home is presumably the user's home directory (Def_Path is defined elsewhere).
# Resulting layout: <home>/Pictures/Data/<year>-<month>-<day>/
path = Def_Path()

parent_path = path.home + "/Pictures/Data/"
date_folder = "-".join([path.year, path.month, path.day])
root_dir = parent_path + date_folder + "/"

print(root_dir)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# torch.cuda.set_per_process_memory_fraction(0.9, 0)
print(device)

In [None]:
def train_transform():
    """Return the albumentations augmentation pipeline used to diversify training images.

    Every sample is rotated by a random multiple of 90 degrees and gets a
    brightness/contrast jitter; it is then resized to a fixed size. Keypoints
    (``'xy'`` format) and pascal_voc bounding boxes (whose labels are passed in the
    ``bboxes_labels`` field) are transformed together with the image.
    """
    augmentations = A.Sequential(
        [
            A.RandomRotate90(p=1),
            # always_apply is left at its default (False), so it is not passed explicitly.
            A.RandomBrightnessContrast(
                brightness_limit=0.3,
                contrast_limit=0.2,
                brightness_by_max=True,
                p=1,
            ),
        ],
        p=1,
    )
    # Applied after all other transforms. A.Resize takes (height, width), so the output
    # is 640 rows x 480 columns. NOTE(review): if "640x480" was meant as width x height,
    # this should be A.Resize(480, 640); confirm against the source frame size.
    resize = A.Resize(640, 480)

    return A.Compose(
        [augmentations, resize],
        keypoint_params=A.KeypointParams(format='xy'),
        bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bboxes_labels']),
    )

In [None]:
def train_test_split(src_dir):
    """Split the raw dataset in ``src_dir`` into train/val/test folders (70/20/10).

    All ``*.jpg`` files directly in ``src_dir`` are copied into ``src_dir/images`` and all
    ``*.json`` files into ``src_dir/annotations``. ``splitfolders.ratio`` then copies
    (``move=False``) the sub-folders of ``src_dir`` into
    ``<parent_path>split_folder_output-<year>-<month>-<day>``, where callers read the
    ``train``, ``val`` and ``test`` sub-folders. The staging ``images`` and
    ``annotations`` folders are removed afterwards, even if copying or splitting fails.

    Note: this name shadows ``sklearn.model_selection.train_test_split`` imported at the
    top of the notebook.

    NOTE(review): ``images`` and ``annotations`` are shuffled as two independent
    folders. They only stay paired if splitfolders produces the same permutation for
    both (same seed, same file count, matching sorted names). Confirm, or consider
    ``group_prefix``.

    Args:
        src_dir: Folder holding the raw ``.jpg`` / ``.json`` files.

    Returns:
        Path of the split output folder (no trailing slash).
    """
    # os.path.join works whether or not src_dir ends with a separator. Plain
    # concatenation would create sibling folders that splitfolders never sees.
    dst_dir_img = os.path.join(src_dir, "images")
    dst_dir_anno = os.path.join(src_dir, "annotations")

    if os.path.exists(dst_dir_img) and os.path.exists(dst_dir_anno):
        print("folders exist")
    # exist_ok: a single leftover folder from an interrupted run must not make this fail.
    os.makedirs(dst_dir_img, exist_ok=True)
    os.makedirs(dst_dir_anno, exist_ok=True)

    try:
        for jpgfile in glob.iglob(os.path.join(src_dir, "*.jpg")):
            shutil.copy(jpgfile, dst_dir_img)

        for jsonfile in glob.iglob(os.path.join(src_dir, "*.json")):
            shutil.copy(jsonfile, dst_dir_anno)

        output = parent_path + "split_folder_output" + "-" + path.year + "-" + path.month + "-" + path.day

        print(output)

        splitfolders.ratio(src_dir,  # The location of dataset
                           output=output,  # The output location
                           seed=42,  # Fixed seed so the split is reproducible
                           ratio=(.7, .2, .1),  # train / val / test fractions
                           group_prefix=None,  # Set if one sample spans several files (e.g. ".jpg" + ".pdf")
                           move=False  # Copy instead of move, so src_dir is left intact
                           )
    finally:
        # Always clean up the staging folders. ignore_errors avoids masking an
        # earlier exception with a cleanup failure.
        shutil.rmtree(dst_dir_img, ignore_errors=True)
        shutil.rmtree(dst_dir_anno, ignore_errors=True)

    return output

In [None]:
class ClassDataset(Dataset):
    """Dataset of RGB images paired with JSON bounding-box / keypoint annotations.

    Expects two sub-folders under ``root``: ``images`` and ``annotations``. Files are
    paired purely by position in the sorted directory listings, so the i-th image must
    correspond to the i-th annotation file. Each annotation JSON must contain
    ``'bboxes'`` (pascal_voc ``[x1, y1, x2, y2]``, matching ``train_transform``) and
    ``'keypoints'`` (per object, a list of ``[x, y, visibility]`` points).

    Args:
        root: Folder containing ``images/`` and ``annotations/``.
        transform: Optional albumentations pipeline. It is called with ``image``,
            ``bboxes``, ``bboxes_labels`` and ``keypoints``.
        demo: If True, ``__getitem__`` also returns the untransformed image and target.
    """

    def __init__(self, root, transform=None, demo=False):
        self.root = root
        self.transform = transform
        self.demo = demo
        self.imgs_files = sorted(os.listdir(os.path.join(root, "images")))
        self.annotations_files = sorted(os.listdir(os.path.join(root, "annotations")))

    @staticmethod
    def _make_target(bboxes, labels, keypoints, idx):
        """Build a torchvision-detection-style target dict from plain Python data."""
        boxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if boxes.numel() == 0:
            # Keep a (0, 4) shape so the column slicing in "area" still works.
            boxes = boxes.reshape(0, 4)
        return {
            "boxes": boxes,
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros(len(boxes), dtype=torch.int64),
            "keypoints": torch.as_tensor(keypoints, dtype=torch.float32),
        }

    def __getitem__(self, idx):
        """Return ``(img, target)``. In demo mode, also return the untransformed pair."""
        img_path = os.path.join(self.root, "images", self.imgs_files[idx])
        annotations_path = os.path.join(self.root, "annotations", self.annotations_files[idx])
        img_original = cv2.imread(img_path)
        if img_original is None:
            # cv2.imread signals an unreadable/missing file by returning None, not raising.
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)

        with open(annotations_path) as f:
            data = json.load(f)
        bboxes_original = data['bboxes']
        keypoints_original = data['keypoints']

        # One integer class id per box, in annotation order. For the standard 6-box arm
        # annotation: 1=base_kp, 2=joint1, 3=joint2, 4=joint3, 5=joint4, 6=joint5.
        # Deriving them from the box count keeps boxes and labels the same length.
        labels_original = list(range(1, len(bboxes_original) + 1))

        if self.transform:
            # albumentations takes a flat list of (x, y) points. Visibility is re-attached below.
            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp]
            transformed = self.transform(image=img_original, bboxes=bboxes_original,
                                         bboxes_labels=labels_original,
                                         keypoints=keypoints_original_flattened)
            img = transformed['image']
            bboxes = transformed['bboxes']
            # Labels travel with their boxes through the transform, so they stay matched.
            labels = [int(label) for label in transformed['bboxes_labels']]
            # Assumes exactly one keypoint per object (hence the (-1, 1, 2) reshape).
            keypoints_transformed_unflattened = np.reshape(np.array(transformed['keypoints']), (-1, 1, 2)).tolist()

            keypoints = []
            for o_idx, obj in enumerate(keypoints_transformed_unflattened):
                obj_keypoints = []
                for k_idx, kp in enumerate(obj):
                    # Re-attach the original visibility flag to the transformed (x, y).
                    obj_keypoints.append(kp + [keypoints_original[o_idx][k_idx][2]])
                keypoints.append(obj_keypoints)
        else:
            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original
            labels = labels_original

        target = self._make_target(bboxes, labels, keypoints, idx)
        img = F.to_tensor(img)

        if self.demo:
            # The untransformed sample is only needed (and only built) in demo mode.
            target_original = self._make_target(bboxes_original, labels_original, keypoints_original, idx)
            return img, target, F.to_tensor(img_original), target_original
        return img, target

    def __len__(self):
        return len(self.imgs_files)
    

In [None]:
def construct_graph_for_training(gt_keypoints):
    """Build a ring-shaped graph over ground-truth keypoints for GNN training.

    Node ``i`` has one directed edge to node ``(i + 1) % N``. Each node's feature is
    its ``(x, y)`` coordinate.

    Args:
        gt_keypoints: ``(N, >=2)`` numpy array or tensor. Only the first two columns
            are used (assumed ``x, y, visibility`` as produced by ``ClassDataset``).

    Returns:
        torch_geometric ``Data`` with ``x`` of shape ``(N, 2)`` (float32, on the same
        device as a tensor input) and ``edge_index`` of shape ``(2, N)``.
    """
    N = gt_keypoints.shape[0]
    edge_index = [(i, (i + 1) % N) for i in range(N)]

    # reshape(-1, 2) keeps a well-formed (2, 0) edge_index when N == 0
    # (torch.tensor([]) alone would be 1-D).
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # as_tensor + detach().clone() copies numpy arrays and tensors alike into a fresh
    # leaf tensor. This avoids the "copy construct from a tensor" warning that
    # torch.tensor() emits for tensor input.
    x = torch.as_tensor(gt_keypoints[:, :2], dtype=torch.float).detach().clone()
    return Data(x=x, edge_index=edge_index)

# def construct_graph_for_training(gt_keypoints):
#     N = gt_keypoints.shape[0]
#     edge_index = []
#     for i in range(N):
#         for j in range(N):
#             edge_index.append((i, j))

#     edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
#     x = torch.tensor(gt_keypoints[:, :2], dtype=torch.float)
#     return Data(x=x, edge_index=edge_index)


In [None]:
def construct_graph_for_prediction(keypoints, total_keypoints=6):
    """Build a ring-shaped graph over detected keypoints for GNN inference.

    Missing detections are padded with dummy nodes at the origin, so the graph has at
    least ``total_keypoints`` nodes. Node ``i`` has one directed edge to
    ``(i + 1) % N``.

    Args:
        keypoints: Sequence or array of keypoints, each ``[x, y]`` or ``[x, y, ...]``
            (e.g. the list of lists produced by ``predict_keypoints``). May be empty.
        total_keypoints: Minimum number of nodes in the graph.

    Returns:
        torch_geometric ``Data`` with ``x`` of shape ``(max(N, total_keypoints), 2)``
        (float32) and a matching ring ``edge_index`` of shape ``(2, num_nodes)``.
    """
    # Normalize to a 2-D float array. A Python list has no .shape and no 2-D slicing,
    # and an empty list would be 1-D and break the concatenation below.
    keypoints = np.asarray(keypoints, dtype=float)
    if keypoints.ndim == 1:
        # Empty input becomes (0, 2). A single flat [x, y, ...] point becomes (1, D).
        keypoints = keypoints.reshape(0, 2) if keypoints.size == 0 else keypoints.reshape(1, -1)

    if len(keypoints) < total_keypoints:
        # Pad with as many columns as the input has, so (x, y, v) input works too.
        dummy_keypoints = np.zeros((total_keypoints - len(keypoints), keypoints.shape[1]))
        keypoints = np.concatenate([keypoints, dummy_keypoints], axis=0)

    N = keypoints.shape[0]
    edge_index = [(i, (i + 1) % N) for i in range(N)]

    # reshape(-1, 2) keeps edge_index 2-D even if N == 0 (total_keypoints=0, no points).
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()

    x = torch.tensor(keypoints[:, :2], dtype=torch.float)
    return Data(x=x, edge_index=edge_index)

# def construct_graph_for_prediction(keypoints, total_keypoints=6):
#     # If there are missing keypoints, add dummy nodes.
#     if len(keypoints) < total_keypoints:
#         dummy_keypoints = np.zeros((total_keypoints - len(keypoints), 3))
#         keypoints = np.concatenate([keypoints, dummy_keypoints], axis=0)

#     N = keypoints.shape[0]
#     edge_index = []
#     for i in range(N):
#         for j in range(N):
#             edge_index.append((i, j))

#     edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    
#     x = torch.tensor(keypoints[:, :2], dtype=torch.float)
#     return Data(x=x, edge_index=edge_index)

In [None]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class SimpleGNNLayer(MessagePassing):
    """GCN-style message-passing layer.

    Applies a linear map to the node features, adds self-loops, and sums messages
    from neighbours scaled by ``1 / sqrt(deg[src] * deg[dst])``.
    """
    def __init__(self, in_channels, out_channels):
        super(SimpleGNNLayer, self).__init__(aggr='add')  # 'add' aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Return updated node features of width ``out_channels`` (one row per node in ``x``)."""
        # Add self loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node feature matrix.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=self.lin(x))

    def message(self, x_j, edge_index, size):
        """Scale each source-node message by the symmetric degree normalisation.

        NOTE(review): PyG's MessagePassing appears to supply these arguments by name
        (``x_j`` = features of each edge's source node). Keep the names unless verified.
        """
        # Compute normalization.
        row, col = edge_index
        # Degree is counted over edge_index[0] (row). The ring graphs built in this
        # notebook have equal in- and out-degree, so a col-based count gives the same.
        # Self-loops guarantee deg >= 1, so pow(-0.5) below stays finite.
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return norm.view(-1, 1) * x_j

class SimpleGNN(torch.nn.Module):
    """Two ReLU-activated ``SimpleGNNLayer``s followed by a per-node linear head.

    Args:
        in_channels: Width of each input node feature.
        hidden_channels: Width of both message-passing layers.
        out_channels: Width of the per-node output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Registration order (layer1, layer2, fc) fixes state_dict / parameter order.
        self.layer1 = SimpleGNNLayer(in_channels, hidden_channels)
        self.layer2 = SimpleGNNLayer(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        """Map a graph with node features ``data.x`` to per-node outputs."""
        edge_index = data.edge_index
        hidden = data.x
        for gnn_layer in (self.layer1, self.layer2):
            hidden = torch.relu(gnn_layer(hidden, edge_index))
        return self.fc(hidden)

In [None]:
import torch.nn as nn
import torchvision.models as models

class HybridModel(nn.Module):
    """ResNet-50 image encoder feeding a ``SimpleGNN`` over keypoint graphs.

    The pooled 2048-d image feature is appended to every graph node's ``(x, y)``
    feature. The GNN then outputs a 2-vector per node.
    """

    def __init__(self):
        super(HybridModel, self).__init__()

        # Pre-trained ResNet-50 as a global feature extractor. Replacing ``fc`` with
        # Identity makes it return the pooled 2048-d feature vector per image.
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()

        # Node features: 2 keypoint coordinates + 2048 image features -> 2 outputs per node.
        self.gnn = SimpleGNN(in_channels=2 + 2048, hidden_channels=128, out_channels=2)

    def forward(self, x, data):
        """Predict a 2-vector per graph node.

        Args:
            x: Image batch for ``self.resnet``. The repeat/reshape below only lines up
                with ``data.x`` when the batch holds a single image, which is how every
                caller in this notebook uses it.
            data: Graph with ``x`` of shape ``(n, 2)`` and ``edge_index``. It is
                not modified.

        Returns:
            Tensor of shape ``(n, 2)`` for a single-image batch.
        """
        features = self.resnet(x)  # (B, 2048)

        # Repeat the image feature once per node: (B, 2048) -> (B * n, 2048).
        n = data.x.size(0)
        repeated_features = features.unsqueeze(2).repeat(1, 1, n).transpose(1, 2).reshape(-1, 2048)

        # Work on a copy so the caller's graph keeps its 2-column ``x``. Overwriting
        # ``data.x`` in place would make a second forward() on the same graph fail with
        # a feature-size mismatch.
        graph = data.clone()
        graph.x = torch.cat((data.x, repeated_features), dim=1)

        return self.gnn(graph)




In [None]:
# Training setup: model, optimizer, loss and train/val/test data loaders.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0005)
criterion = torch.nn.MSELoss()

# Split once and reuse the output folder. Each train_test_split() call re-copies the
# dataset, re-runs splitfolders and deletes its staging folders.
split_output_dir = train_test_split(root_dir)
KEYPOINTS_FOLDER_TRAIN = split_output_dir + "/train"
KEYPOINTS_FOLDER_VAL = split_output_dir + "/val"
KEYPOINTS_FOLDER_TEST = split_output_dir + "/test"

num_epochs = 100
batch_size = 4

dataset_train = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=False)
dataset_val = ClassDataset(KEYPOINTS_FOLDER_VAL, transform=None, demo=False)
dataset_test = ClassDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)

data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True)
data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, collate_fn=collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)


# Training loop. The optimizer steps once per image; batch_size only controls how
# many images the loader yields at a time.
for epoch in range(num_epochs):
    model.train()  # explicitly enable training-mode BatchNorm behaviour in the ResNet
    epoch_loss_sum = 0.0
    num_samples = 0
    for images, targets in tqdm(data_loader_train, desc=f"epoch {epoch}", leave=False):
        images = torch.stack(images).to(device)
        for i in range(len(images)):
            gt_keypoints = targets[i]['keypoints'].to(device).squeeze()
            data = construct_graph_for_training(gt_keypoints).to(device)
            optimizer.zero_grad()
            out = model(images[i].unsqueeze(0), data)
            loss = criterion(out, gt_keypoints[:, :2])
            loss.backward()
            optimizer.step()
            epoch_loss_sum += loss.item()
            num_samples += 1
    # Report the mean over the epoch rather than only the last sample's loss.
    mean_loss = epoch_loss_sum / num_samples if num_samples else float('nan')
    print(f'Epoch:{epoch} and Mean Loss:{mean_loss}')

In [None]:
# Inference setup. Reuse the HybridModel trained above. A freshly constructed
# HybridModel() has only ImageNet ResNet weights and a randomly initialised GNN head.
gnn_model = model
# NOTE(review): absolute, user-specific paths. Consider building them from parent_path.
weights_path = '/home/jc-merlab/Pictures/Data/trained_models/keypointsrcnn_weights_sim_b1_e25_v0.pth'
# torch.load unpickles a whole model object, which can execute code. Only load trusted files.
cnn_model = torch.load(weights_path).to(device)
image_path = '/home/jc-merlab/Pictures/Data/2023-08-14-Occluded/002654.rgb.jpg'
image = Image.open(image_path).convert("RGB")

def predict_keypoints(cnn_model, gnn_model, image, score_threshold=0.7, nms_iou_threshold=0.3):
    """Detect keypoints with the CNN, then refine them with the hybrid GNN.

    Both models are switched to eval mode. The CNN's detections are filtered by
    score, then by NMS. The first keypoint of each surviving detection is ordered by
    class label and fed to ``gnn_model`` as a ring graph (padded by
    ``construct_graph_for_prediction``).

    Args:
        cnn_model: Detection model returning dicts with ``boxes``, ``scores``,
            ``labels`` and ``keypoints`` (as indexed below).
        gnn_model: ``HybridModel``-like module called as ``gnn_model(image_batch, graph)``.
        image: PIL image (or anything ``F.to_tensor`` accepts).
        score_threshold: Keep detections whose score is strictly above this.
        nms_iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        ``(predicted_keypoints, initial_keypoints)``. The first is the GNN output as a
        numpy array. The second is the list of integer ``[x, y]`` CNN detections
        sorted by label.
    """
    gnn_model.eval()
    cnn_model.eval()
    image = F.to_tensor(image).to(device)
    with torch.no_grad():
        output = cnn_model([image])[0]
        scores = output['scores'].cpu().numpy()
        # Keep confident detections, then suppress overlapping boxes among them.
        high_scores_idxs = np.where(scores > score_threshold)[0].tolist()
        post_nms_idxs = torchvision.ops.nms(output['boxes'][high_scores_idxs],
                                            output['scores'][high_scores_idxs],
                                            nms_iou_threshold).cpu().numpy()
        kept_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
        kept_labels = output['labels'][high_scores_idxs][post_nms_idxs].cpu().numpy()

        # First keypoint of each detection, truncated to integer pixel coordinates.
        keypoints = [list(map(int, kps[0, 0:2])) for kps in kept_keypoints]
        labels = list(kept_labels)
        # Order by class label so the graph's node order is consistent between images.
        initial_keypoints = [x for _, x in sorted(zip(labels, keypoints))]
        print(initial_keypoints)

        data = construct_graph_for_prediction(initial_keypoints).to(device)
        predicted_keypoints = gnn_model(image.unsqueeze(0), data).cpu().numpy()
    print(predicted_keypoints)
    return predicted_keypoints, initial_keypoints

In [None]:
# NOTE: "gt_keypoints" here holds the CNN's initial detections, not annotated ground truth.
predicted_keypoints, gt_keypoints = predict_keypoints(cnn_model=cnn_model, gnn_model=gnn_model, image=image)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_keypoints(image_path, keypoints, gt_keypoints):
    """Overlay two sets of keypoints on an image.

    Args:
        image_path: Either a path to an image file (str / os.PathLike), which is
            opened as RGB, or an already-loaded image accepted by ``ax.imshow``
            (e.g. a PIL image).
        keypoints: ``(N, >=2)`` array-like of ``(x, y)`` points, drawn in red.
        gt_keypoints: ``(M, >=2)`` array-like of reference ``(x, y)`` points, drawn in
            blue. May be empty.
    """

    def _as_points(points):
        # Coerce to a 2-D float array. An empty input becomes shape (0, 2).
        arr = np.asarray(points, dtype=float)
        return arr.reshape(0, 2) if arr.size == 0 else np.atleast_2d(arr)

    if isinstance(image_path, (str, os.PathLike)):
        img = Image.open(image_path).convert("RGB")
    else:
        img = image_path

    fig, ax = plt.subplots(1)
    ax.imshow(img)

    pred = _as_points(keypoints)
    ref = _as_points(gt_keypoints)

    ax.scatter(pred[:, 0], pred[:, 1], c='r', s=40, label="Keypoints")
    ax.scatter(ref[:, 0], ref[:, 1], c='b', s=40, label="gt_keypoints")

    ax.legend()
    plt.show()


In [None]:
# Red: GNN-refined keypoints; blue: the CNN's initial detections.
visualize_keypoints(image_path=image, keypoints=predicted_keypoints, gt_keypoints=gt_keypoints)