# Final Project: ASL Recognition
### Professor: Weizhe Li
### Student: Levan Sulimanov

# Getting the model defined:

In [1]:
# Model:
###################################### CREDITS ######################################
# https://learnopencv.com/human-action-recognition-using-detectron2-and-lstm/
#####################################################################################

import numpy as np
import torch.optim as optim
import torchmetrics
import pytorch_lightning as pl

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import time
from numpy import genfromtxt

class PoseDataset(Dataset):
    """Dataset over ``[csv_path, label]`` entries, loading each sample lazily.

    Nothing is read at construction time. Every ``__getitem__`` call parses
    the CSV file of the requested entry with ``genfromtxt`` and returns it as
    a float32 tensor together with the entry's label.
    """

    def __init__(self, data_arr):
        # Indexable collection of [sample_path, class_number] entries.
        self.data_arr = data_arr

    def __len__(self):
        """Return the number of ``[csv_path, label]`` entries."""
        return len(self.data_arr)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for entry ``idx``.

        ``features`` is the comma-separated file parsed into a float32
        tensor; ``label`` is returned exactly as stored in the entry.
        """
        entry = self.data_arr[idx]
        csv_path, label = entry[0], entry[1]
        features = genfromtxt(csv_path, delimiter=',')
        return torch.from_numpy(features).float(), label


class PoseDataModule(pl.LightningDataModule):
    """Lightning data module serving the ``train`` and ``val`` splits of ``data_root``.

    The expected layout is ``<data_root>/<split>/<class_folder>/<sample file>``;
    class folder names are translated to numeric labels through the
    module-level ``class_to_label_num`` mapping.
    """

    def __init__(self, data_root, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.data_root = data_root

    def load_X_and_Y(self, data_path, train_mode="train"):
        """Collect ``[sample_path, class_number]`` pairs for one split.

        Args:
            data_path: Root directory that contains the split folders.
            train_mode: Name of the split sub-folder to scan.

        Returns:
            A list with one ``[sample_path, class_number]`` entry per file
            found under every class folder of the split.
        """
        split_dir = os.path.join(data_path, train_mode)
        samples = []
        for class_name in os.listdir(split_dir):
            class_dir = os.path.join(split_dir, class_name)
            # The label lookup happens once per sample, so an empty class
            # folder never touches class_to_label_num.
            samples.extend(
                [os.path.join(class_dir, file_name), class_to_label_num[class_name]]
                for file_name in os.listdir(class_dir)
            )
        return samples

    def setup(self, stage=None):
        """Index the train and val splits and wrap them in ``PoseDataset``."""
        train_entries = self.load_X_and_Y(self.data_root, train_mode="train")
        val_entries = self.load_X_and_Y(self.data_root, train_mode="val")
        # The "test" split is not loaded here.
        self.train_dataset = PoseDataset(train_entries)
        self.val_dataset = PoseDataset(val_entries)

    def _build_loader(self, dataset, shuffle):
        """Create a single-process DataLoader using the configured batch size."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=0,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._build_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._build_loader(self.val_dataset, False)

    # No test_dataloader is defined: setup() does not load a test split.

# We have 9 output action classes; this is the output size of the classifier's final linear layer.
TOT_ACTION_CLASSES = 9 #6

#lstm classifier definition
class ActionClassificationLSTM(pl.LightningModule):
    """LSTM sequence classifier with a linear head over TOT_ACTION_CLASSES classes.

    Args:
        input_features: Number of features per time step fed to the LSTM.
        hidden_dim: Size of the LSTM hidden state.
        learning_rate: Initial learning rate for the Adam optimiser.
    """

    # initialise method
    def __init__(self, input_features, hidden_dim, learning_rate=0.001):
        super().__init__()
        # save hyperparameters (configure_optimizers reads hparams.learning_rate)
        self.save_hyperparameters()
        # The LSTM consumes batch-first sequences of per-step feature vectors
        # and produces hidden states with dimensionality hidden_dim.
        self.lstm = nn.LSTM(input_features, hidden_dim, batch_first=True)
        # The linear layer that maps from hidden state space to classes
        self.linear = nn.Linear(hidden_dim, TOT_ACTION_CLASSES)

    def forward(self, x):
        """Return class logits computed from the final hidden state of the LSTM."""
        _, (ht, _) = self.lstm(x)
        # ht[-1] is the last entry of the final hidden-state tensor.
        return self.linear(ht[-1])

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` for one batch; shared by train and val steps."""
        x, y = batch
        # Flatten labels to shape (batch,). reshape(-1) is used instead of
        # squeeze() because squeeze() turns a batch holding a single label
        # into a 0-d tensor, which cross_entropy rejects.
        y = y.reshape(-1).long()
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        # softmax is monotonic, so the arg-max of the logits is the predicted class
        pred = torch.argmax(y_pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, y, task="multiclass", num_classes=TOT_ACTION_CLASSES)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Run one training batch, log its loss/accuracy and return them."""
        loss, acc = self._shared_step(batch)
        dic = {
            'batch_train_loss': loss,
            'batch_train_acc': acc
        }
        # log the metrics for pytorch lightning progress bar or any other operations
        self.log('batch_train_loss', loss, prog_bar=True)
        self.log('batch_train_acc', acc, prog_bar=True)
        # 'loss' is what Lightning optimises; 'result' is read by training_epoch_end
        return {'loss': loss, 'result': dic}

    def training_epoch_end(self, training_step_outputs):
        """Log the mean training loss and accuracy over the epoch's batches."""
        avg_train_loss = torch.tensor([x['result']['batch_train_loss'] for x in training_step_outputs]).mean()
        avg_train_acc = torch.tensor([x['result']['batch_train_acc'] for x in training_step_outputs]).mean()
        self.log('train_loss', avg_train_loss, prog_bar=True)
        self.log('train_acc', avg_train_acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log its loss/accuracy and return them."""
        loss, acc = self._shared_step(batch)
        dic = {
            'batch_val_loss': loss,
            'batch_val_acc': acc
        }
        # log the metrics for pytorch lightning progress bar and any further processing
        self.log('batch_val_loss', loss, prog_bar=True)
        self.log('batch_val_acc', acc, prog_bar=True)
        return dic

    def validation_epoch_end(self, validation_step_outputs):
        """Log the mean validation loss and accuracy over the epoch's batches."""
        avg_val_loss = torch.tensor([x['batch_val_loss']
                                     for x in validation_step_outputs]).mean()
        avg_val_acc = torch.tensor([x['batch_val_acc']
                                    for x in validation_step_outputs]).mean()
        # 'val_loss' is the metric monitored by the LR scheduler below
        self.log('val_loss', avg_val_loss, prog_bar=True)
        self.log('val_acc', avg_val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler monitoring ``val_loss``."""
        # adam optimiser
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # halve the learning rate after 10 epochs without val_loss improvement
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-15, verbose=True)
        # scheduler reduces learning rate based on the value of val_loss metric
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1, "monitor": "val_loss"}}

# Main Real-Time Core Part of Exegete:

In [2]:
# Library imports:
import os
import cv2
import mediapipe as mp
import time
from PIL import Image, ImageOps
import numpy as np
# from src.lstm import ActionClassificationLSTM
import torch
import torch.nn.functional as F
import winsound
from threading import Thread
from itertools import groupby
import argparse
from datetime import datetime
from pyfiglet import Figlet
import traceback
import imutils


# Define Fixed Video Size
# (width, height) in pixels; get_keypoints_from_frame resizes and pads each frame to this size.
MAIN_WIDTH, MAIN_HEIGHT = 640, 480


# Define output labels (class numbers association from LSTM model)
# detect_sign looks up the predicted class index here to obtain the word to display.
LABELS = {0: "hello", 1: "my", 2: "world", 3: "me", 4: "every", 5: "moment", 6: "is", 7: "new", 8: "beginning"}

# add directory, if it does not exist
def mkdir_if_none(dir):
    """Create directory ``dir`` (with missing parents) unless the path already exists.

    Args:
        dir: Path of the directory to create. If anything already exists at
            this path, nothing is done.
    """
    if not os.path.exists(dir):
        # exist_ok avoids a FileExistsError when something else creates the
        # directory between the existence check above and this call.
        os.makedirs(dir, exist_ok=True)


# Pad the frame to a square of the specified side length
def padding(img, expected_size):
    """Add a border around ``img`` so it measures ``expected_size`` on both sides.

    Args:
        img: Image object exposing ``size`` as ``(width, height)``.
        expected_size: Target side length, applied to width and height.

    Returns:
        The result of ``ImageOps.expand`` with the computed border.
    """
    extra_w = expected_size - img.size[0]
    extra_h = expected_size - img.size[1]
    # Split the missing pixels between both sides; the right/bottom side
    # receives the odd pixel when the difference is not even.
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(img, border)

# pad the frame up to the specified (width, height)
def resize_with_padding(img, expected_size):
    """Add a border around ``img`` so it measures ``expected_size``.

    Despite its name, this function does not rescale the image; it only
    computes a border and hands it to ``ImageOps.expand``.

    Args:
        img: Image object exposing ``size`` as ``(width, height)``.
        expected_size: Target ``(width, height)``.

    Returns:
        The result of ``ImageOps.expand`` with the computed border.
    """
    target_w, target_h = expected_size[0], expected_size[1]
    extra_w = target_w - img.size[0]
    extra_h = target_h - img.size[1]
    left, top = extra_w // 2, extra_h // 2
    # (left, top, right, bottom); right/bottom absorb any odd remainder.
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(img, border)

# total number of joints detected (hands, elbows, and shoulders):
# 21 per hand + 2 elbows + 2 shoulders = 46 (x, y) pairs per frame.
MAX_NUM_OF_XY_KEYPOINTS_LIST = 46

# if something errored out, return pre-defined array:
# 92 values (x and y for each keypoint), all 2.0, the same placeholder used for undetected landmarks.
BACKUP_ARRAY = np.array([2.0, 2.0] * MAX_NUM_OF_XY_KEYPOINTS_LIST)

    
# flatten one landmark list into [x0, y0, x1, y1, ...], or placeholders when it is missing
def _landmark_xy_or_placeholder(landmarks, expected_count, placeholder=2.0):
    if landmarks:
        flat = []
        for point in landmarks.landmark:
            flat.extend((point.x, point.y))
        return flat
    return [placeholder] * (2 * expected_count)


# just get keypoints from frame:
def get_keypoints_from_frame(frame, results, mp_holistic, switch_turned_on, hand_keypoint_default=21, verbose=True):
    """Build the flat keypoint vector for one frame and optionally draw it.

    The vector is laid out as: right-hand landmarks, left-hand landmarks,
    then right elbow, left elbow, right shoulder, left shoulder, each as an
    x value followed by a y value. Any group that was not detected is filled
    with 2.0 placeholders.

    Args:
        frame: Image array for the current frame.
        results: Holistic detection result exposing ``right_hand_landmarks``,
            ``left_hand_landmarks`` and ``pose_landmarks``.
        mp_holistic: Module providing the ``PoseLandmark`` indices.
        switch_turned_on: Selects the marker colour, (0, 255, 0) when true
            and (0, 0, 255) otherwise.
        hand_keypoint_default: Number of placeholder keypoints emitted for a
            hand that was not detected.
        verbose: When true, draw a filled circle for every keypoint.

    Returns:
        ``(coordinates, frame)`` where ``coordinates`` is a NumPy array and
        ``frame`` is the resized, padded and annotated image. If anything
        fails, ``(BACKUP_ARRAY, frame)`` is returned with ``frame`` in
        whatever state it had reached.
    """
    try:
        # Bring the frame to MAIN_WIDTH x MAIN_HEIGHT, then swap the first and
        # last colour channels.
        frame = imutils.resize(frame, width=MAIN_WIDTH)
        pil = Image.fromarray(frame)
        padded = resize_with_padding(pil, (MAIN_WIDTH, MAIN_HEIGHT))
        frame = cv2.cvtColor(np.array(padded), cv2.COLOR_RGB2BGR)

        frame_coordinates = []

        # hand points: RIGHT first, then LEFT
        frame_coordinates.extend(
            _landmark_xy_or_placeholder(results.right_hand_landmarks, hand_keypoint_default))
        frame_coordinates.extend(
            _landmark_xy_or_placeholder(results.left_hand_landmarks, hand_keypoint_default))

        # SHOULDERS AND ELBOWS (the order here defines the feature layout):
        if results.pose_landmarks:
            pose = results.pose_landmarks.landmark
            joints = mp_holistic.PoseLandmark
            for joint in (joints.RIGHT_ELBOW, joints.LEFT_ELBOW,
                          joints.RIGHT_SHOULDER, joints.LEFT_SHOULDER):
                point = pose[joint]
                frame_coordinates.extend((point.x, point.y))
        else:
            # fake x/y values for the four upper-body joints
            frame_coordinates.extend([2.0] * 8)

        if verbose:
            color = (0, 255, 0) if switch_turned_on else (0, 0, 255)
            for coord_idx in range(0, len(frame_coordinates), 2):
                # Coordinates are scaled by the frame width/height, i.e. they
                # are treated as fractions of the frame size.
                cx = int(frame_coordinates[coord_idx] * frame.shape[1])
                cy = int(frame_coordinates[coord_idx + 1] * frame.shape[0])
                cv2.circle(frame, (cx, cy), 4, color, cv2.FILLED)

        return np.array(frame_coordinates), frame
    except Exception:
        # Best effort: report the problem and fall back to the placeholder
        # vector so that the caller's frame loop keeps running.
        print("<<<ERROR IN GETTING KEYPOINTS>>>")
        print(traceback.format_exc())
        print("---")
        return BACKUP_ARRAY, frame
    

    
def detect_sign(sign_model, buffer):
    """Classify one buffered keypoint sequence.

    Args:
        sign_model: Callable model that maps a ``(1, frames, features)``
            float tensor to class logits.
        buffer: Sequence of per-frame keypoint vectors.

    Returns:
        ``(label, confidence)`` where ``label`` is the ``LABELS`` entry for
        the most probable class and ``confidence`` is its softmax
        probability. On any error the traceback is printed and
        ``(None, 0.0)`` is returned.
    """
    try:
        # stack the buffered frames and add the batch dimension
        model_input = torch.Tensor(np.array(buffer, dtype=np.float32))
        model_input = torch.unsqueeze(model_input, dim=0)
        # inference only: no gradient bookkeeping is needed
        with torch.no_grad():
            y_pred = sign_model(model_input)
            prob = F.softmax(y_pred, dim=1)
        pred_index = prob.argmax(dim=1).item()
        # Retrieve the predicted label and convert it to its associated string
        label = LABELS[pred_index]
        conf = prob[0][pred_index].item()
        return label, conf
    except Exception:
        # Keep the best-effort contract, but make the failure visible instead
        # of discarding it silently.
        print("<<<ERROR IN DETECTING SIGN>>>")
        print(traceback.format_exc())
        print("---")
        return None, 0.0

# Work out what to hand to cv2.VideoCapture for the requested source
def _resolve_capture_source(capture_system_path):
    """Return ``(device_capture, use_CAP_DSHOW)`` for a device id, URL or file path."""
    if capture_system_path.isdigit():
        # Numeric string: local device id, opened through DirectShow.
        print("#######################")
        print("Connecting to Device...")
        print("#######################\n")
        device_capture = int(capture_system_path)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print("Selected Device:", device_capture)
        return device_capture, True

    if "http" in capture_system_path:
        # Online stream: resolve it to a direct media URL through pafy.
        print("Connecting to Video Stream...")
        # Imported lazily: only this branch needs the optional pafy package.
        import pafy
        video = pafy.new(capture_system_path)
        best = video.getbest(preftype="mp4")
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print("Selected Device:", capture_system_path)
        return best.url, False

    # Anything else is treated as the path of an offline video.
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("Selected Device:", capture_system_path)
    return capture_system_path, False


# Main ASL Exegete pipeline
def process_images(model_weights_path, capture_system_path,
                   fps=30, coord_buffer_len=30, thresh=.70, running_string_line_len=30, filtered=False):
    """Run recognition on a capture source and display the annotated feed.

    Frames are read until the source stops delivering them or Esc is
    pressed. Each time the keypoint buffer holds ``coord_buffer_len`` frames
    it is classified and then restarted from the current frame.

    Args:
        model_weights_path: Checkpoint path loaded into ActionClassificationLSTM.
        capture_system_path: Numeric string (device id), a string containing
            "http" (online stream) or any other string (video file path).
        fps: Frame rate requested from the capture; the marker colour is
            also toggled every ``fps`` frames.
        coord_buffer_len: Number of frames collected before each prediction.
        thresh: Minimum confidence for a label to be appended to the overlay.
        running_string_line_len: Maximum length of the overlaid text.
        filtered: Currently unused.

    Errors are reported on stdout instead of being raised.
    """
    cap = None
    try:
        # setup our LSTM and load its weights
        model_path_dir = model_weights_path
        lstm_classifier = ActionClassificationLSTM.load_from_checkpoint(model_path_dir)
        lstm_classifier.eval()

        # setup MediaPipe:
        mp_holistic = mp.solutions.holistic
        print(f"mp.solutions: {mp.solutions}")

        device_capture, use_CAP_DSHOW = _resolve_capture_source(capture_system_path)

        print("Selected Model Path:", model_path_dir)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n")

        print("<<<           ASL is called        >>>")
        print("<<<      Starting Recognition!!!   >>>\n")

        # Start getting frames from the chosen device / stream / video
        if use_CAP_DSHOW:
            cap = cv2.VideoCapture(device_capture, cv2.CAP_DSHOW)
        else:
            cap = cv2.VideoCapture(device_capture)
        cap.set(cv2.CAP_PROP_FPS, fps)

        # Buffer of keypoint vectors from the most recent consecutive frames
        coord_buffer = []
        coord_buffer_max = coord_buffer_len

        recognized_string = " " * running_string_line_len
        recognized_string_max = running_string_line_len

        # Probe the source once; this first frame is not processed.
        ret, frame = cap.read()
        if not ret or frame is None:
            print("Failed to read first frame")
            return

        frame_count = 0
        switch_turned_on = True

        with mp_holistic.Holistic(min_detection_confidence=0.5,
                                  min_tracking_confidence=0.5,
                                  static_image_mode=False) as holistic:
            while cap.isOpened():

                # start a fresh line once the overlay text is full
                if len(recognized_string) >= recognized_string_max:
                    recognized_string = ""

                ret, frame = cap.read()
                if not ret or frame is None:
                    # End of the video or a dropped device: stop cleanly
                    # rather than passing an empty frame to the detector.
                    break

                # make detections:
                # NOTE(review): the frame is passed exactly as read from
                # cv2 (no colour conversion); confirm this matches how the
                # training keypoints were extracted.
                results = holistic.process(frame)

                # get coordinates:
                curr_points_set, frame = get_keypoints_from_frame(frame, results, mp_holistic, switch_turned_on, hand_keypoint_default=21, verbose=True)

                if len(coord_buffer) < coord_buffer_max:
                    coord_buffer.append(curr_points_set)
                else:
                    label, conf = detect_sign(lstm_classifier, coord_buffer)
                    if conf >= thresh:
                        recognized_string += f" {label}"
                        recognized_string = recognized_string[-recognized_string_max:]
                    else:
                        print(f"{label} | {round(conf*100, 2)}")
                    # restart the window with the current frame
                    coord_buffer = [curr_points_set]

                # overlay recognized string to frame:
                cv2.putText(frame, recognized_string.lstrip(),
                            (10, MAIN_HEIGHT-50), cv2.FONT_HERSHEY_COMPLEX, 0.9, (102, 255, 255), 4)

                # toggle the keypoint marker colour every `fps` frames
                if frame_count % fps == 0:
                    switch_turned_on = not switch_turned_on

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                cv2.imshow('Raw Webcam Feed', frame)
                frame_count += 1
                if cv2.waitKey(1) == 27:  # if Esc pressed
                    break

        print("\n")
        print("<<<      Closing ASL Exagete      >>>\n")
    except Exception:
        print("<<<Something went wrong while running ASL Translation... Refer to the error description below.>>>")
        print(traceback.format_exc())
        print("===========\n")
    finally:
        # Release the capture and close the windows on every exit path,
        # including errors raised inside the frame loop.
        if cap is not None:
            cap.release()
            cv2.destroyAllWindows()

def main():
    """Print the banner, parse the command-line options and start recognition.

    Unknown command-line arguments are ignored (``parse_known_args``), so
    arguments injected by a hosting environment do not abort the run.
    """
    # stylistic printout of ASL Exegete software
    f = Figlet(font='slant')
    print(f.renderText('--------'))
    print(f.renderText('Exegete'))
    print(f.renderText('--------'))
    time.sleep(2.5)

    # let user pass the model weights location and the video device (webcam ID, offline or online stream video)
    parser = argparse.ArgumentParser(description="Initialize ASL Exagete's settings.")
    parser.add_argument('--model_dir', type=str, help='Directory that contains the model weights (defaults to ./models).', default=os.path.join(os.getcwd(), "models"))
    parser.add_argument('--model_name', type=str, help='Model weights file name inside the model directory (".ckpt" is appended if missing).', default="9_class_model_base_lstm.ckpt")
    parser.add_argument('--device_path', type=str, help='Either pass *device number* for WebCam, or YouTube *Link* for certain video stream (offline or online).', default="0")

    args, _ = parser.parse_known_args()

    # set model weights path, appending the checkpoint extension when it is missing
    if not args.model_name.endswith(".ckpt"):
        args.model_name = f"{args.model_name}.ckpt"
    model_name = os.path.join(args.model_dir, args.model_name)

    # set camera ID or stream path:
    capture_system_path = str(args.device_path)

    # start our ASL Exegete software
    process_images(model_name, capture_system_path, thresh=.85)


# Run only when executed directly (this is also the case inside a notebook
# cell), not when the file is imported as a module.
if __name__ == "__main__":
    main()

                                                 
                                                 
 ________________________________________________
/_____/_____/_____/_____/_____/_____/_____/_____/
                                                 
                                                 

    ______                     __     
   / ____/  _____  ____ ____  / /____ 
  / __/ | |/_/ _ \/ __ `/ _ \/ __/ _ \
 / /____>  </  __/ /_/ /  __/ /_/  __/
/_____/_/|_|\___/\__, /\___/\__/\___/ 
                /____/                

                                                 
                                                 
 ________________________________________________
/_____/_____/_____/_____/_____/_____/_____/_____/
                                                 
                                                 

mp.solutions: <module 'mediapipe.python.solutions' from 'C:\\Users\\lrspr\\AppData\\Roaming\\Python\\Python310\\site-packages\\mediapipe\\python\\solutions\\__init