In [6]:
import cv2
import mediapipe as mp
import time
import numpy as np
import joblib
from safetensors.torch import load_file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphConv, global_mean_pool
from torch_geometric.utils import to_undirected
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

In [7]:
# Every tensor and the model in this notebook are placed on `device`.
# It is pinned to the CPU here; to use a GPU when one is available, swap in:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

In [8]:
# Short aliases for the MediaPipe "solutions" sub-modules used by the capture loop:
# drawing helpers, their default styles, and the hand-landmark detector.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands

In [9]:
class GCN(nn.Module):
    """Graph-level classifier built from two GraphConv layers and a linear head.

    Args:
        input_channels: Number of features carried by each node.
        hidden_channels: Width of both graph-convolution layers.
        output_channels: Number of values (class scores) emitted per graph.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # Creation order and attribute names are kept stable so that saved
        # state dicts keep matching this module.
        self.conv1 = GraphConv(input_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, output_channels)

    def forward(self, x, edge_index, batch):
        """Return one row of raw (un-normalised) scores per graph in the batch.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both GraphConv layers.
            batch: Assignment of each node to its graph, used for pooling.
        """
        # Node-level message passing; only the first layer is followed by ReLU.
        node_embeddings = F.relu(self.conv1(x, edge_index))
        node_embeddings = self.conv2(node_embeddings, edge_index)

        # Collapse the node embeddings of each graph into a single vector.
        graph_embedding = global_mean_pool(node_embeddings, batch)

        # Dropout is active only while the module is in training mode.
        graph_embedding = F.dropout(graph_embedding, p=0.5, training=self.training)
        return self.lin(graph_embedding)

# Build the classifier: 2 features per node, 32 hidden units, 6 output scores.
model = GCN(input_channels=2, hidden_channels=32, output_channels=6)
model = model.to(device)
print(model)

GCN(
  (conv1): GraphConv(2, 32)
  (conv2): GraphConv(32, 32)
  (lin): Linear(in_features=32, out_features=6, bias=True)
)


In [None]:
# Restore the trained weights from the safetensors file into `model`.
# With its default strict matching, load_state_dict raises if the stored
# keys or shapes do not line up with the module.
state_dict = load_file("gnn.safetensors")
model.load_state_dict(state_dict)
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code.
# Only load label_encoder.pkl from a trusted source.
# `le` is used later to turn a predicted class index back into its label.
le = joblib.load("label_encoder.pkl")

In [11]:
# Connectivity of the 20-node hand graph, written as (source, target) pairs and
# transposed into the 2 x num_edges layout.
# NOTE(review): the chains look like the five fingers plus the links between their
# base joints, assuming MediaPipe's landmark order with the first point dropped
# (as graphify does) -- confirm against the training pipeline.
edge_index = torch.tensor(
    [
        (0, 1), (1, 2), (2, 3),
        (4, 5), (4, 8), (5, 6), (6, 7),
        (8, 9), (8, 12), (9, 10), (10, 11),
        (12, 16), (12, 13), (13, 14), (14, 15),
        (16, 17), (17, 18), (18, 19),
    ],
    dtype=torch.long,
).t().contiguous().to(device)

# Add the reverse of every edge so the graph is undirected.
edge_index = to_undirected(edge_index)

In [12]:
def graphify(c):
    """Wrap one set of landmark coordinates in a single-graph DataLoader.

    Args:
        c: Sequence of coordinate pairs. The first entry is discarded and the
            remainder must reshape to (20, 2), otherwise numpy raises.

    Returns:
        A DataLoader (batch_size=1) holding one Data object whose node features
        are the 20 remaining points and whose edges are the module-level
        ``edge_index``.
    """
    # Drop the leading point so the node count matches edge_index (nodes 0..19).
    node_coords = np.array(c[1:]).reshape(20, 2)
    node_features = torch.tensor(node_coords, dtype=torch.float).to(device)

    graph = Data(x=node_features, edge_index=edge_index)

    # The loader is used only to attach the `batch` vector the model expects.
    return DataLoader([graph], batch_size=1)

In [13]:
# Live demo: grab webcam frames, detect hand landmarks with MediaPipe, classify
# each detected hand with the GCN and show the frame with the predicted label.
# Press "q" in the window to stop.
cap = cv2.VideoCapture(0)

model.eval()
try:
    with mp_hands.Hands(
        model_complexity=0,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as hands:

        while cap.isOpened():
            label = ""

            # perf_counter is monotonic and higher resolution than time.time().
            time_start = time.perf_counter()

            retval, frame = cap.read()

            if not retval:
                print("Camera Error; Exiting")
                break

            # MediaPipe expects RGB input. Mark the RGB copy read-only (the
            # MediaPipe examples do this as a performance hint) and keep the
            # original BGR frame writeable for drawing, which also avoids a
            # needless RGB -> BGR round trip.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            rgb_frame.flags.writeable = False
            results = hands.process(rgb_frame)

            if results.multi_hand_landmarks:
                for hand_landmarks in results.multi_hand_landmarks:

                    data = graphify([[lmk.x, lmk.y] for lmk in hand_landmarks.landmark])
                    data = next(iter(data))

                    # Inference only: skip autograd bookkeeping.
                    with torch.no_grad():
                        logits = model(data.x, data.edge_index, data.batch)

                    # When several hands are visible, the label of the last one
                    # processed is the one displayed.
                    label = logits.argmax().item()
                    label = le.inverse_transform([label])[0]

                    mp_drawing.draw_landmarks(
                        frame,
                        hand_landmarks,
                        mp_hands.HAND_CONNECTIONS,
                        mp_drawing_styles.get_default_hand_landmarks_style(),
                        mp_drawing_styles.get_default_hand_connections_style()
                    )

            # Add a 50 px strip below the image to hold the label text.
            frame = cv2.copyMakeBorder(frame, 0, 50, 0, 0, cv2.BORDER_CONSTANT)

            # Guard against a zero interval on very fast iterations.
            elapsed = time.perf_counter() - time_start
            fps = 1 / elapsed if elapsed > 0 else 0.0

            frame = cv2.putText(frame, f"FPS: {round(fps)}", (10,30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
            # The text origin is (x, y); y must come from the frame HEIGHT
            # (shape[0]). 20 px above the bottom edge places the baseline inside
            # the border strip at any camera resolution (same spot as before
            # for a 640x480 camera).
            frame = cv2.putText(frame, f"{label}", (10, frame.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 3)

            cv2.imshow("Hand Landmarks", frame)

            # Mask to the low byte: some backends return extra high bits.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always free the camera and close the window, even if the loop raised
    # or the kernel was interrupted.
    cap.release()
    cv2.destroyAllWindows()

In [14]:
# Manual cleanup cell: releases the webcam and closes any OpenCV windows.
# NOTE(review): presumably kept so the camera can be freed by hand when the
# capture cell above is interrupted before reaching its own cleanup -- confirm.
cap.release()
cv2.destroyAllWindows()