### import modules

In [3]:
import cv2
import torch
import mediapipe as mp

### define model

In [4]:
import torch.nn as nn

class FallDetectionLSTM(nn.Module):
    """LSTM-based binary classifier over a sequence of feature vectors.

    A stacked ``nn.LSTM`` (``batch_first=True``) consumes the sequence; the
    LSTM output at the last time step is passed through a single-unit linear
    layer followed by a sigmoid, giving one score in [0, 1] per batch element.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each sequence in the batch.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_size).

        Returns:
            Tensor of shape (batch_size, 1) holding sigmoid outputs.
        """
        # new_zeros inherits both dtype and device from x, so the initial
        # states still match the input when the model is run in double or
        # half precision (plain torch.zeros would always be float32).
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        # out: (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Only the last time step feeds the classifier head.
        logits = self.fc(out[:, -1, :])
        return torch.sigmoid(logits)


### model parameters and pose detection initialization

In [None]:
# LSTM hyperparameters. These are passed to FallDetectionLSTM and must match
# the architecture the saved state dict was trained with.
# input_size is the length of the flattened per-frame landmark vector
# (x, y, z, visibility per landmark) -- presumably 33 landmarks * 4 = 132;
# confirm against the training pipeline.
input_size = 132
hidden_size = 132
num_layers = 3

# MediaPipe pose estimator configured for video (static_image_mode=False),
# with segmentation disabled.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,
    model_complexity=1,
    enable_segmentation=False,
    min_detection_confidence=0.5,
)

### create an instance of the model 

In [None]:
# Build the network with the hyperparameters defined above and restore the
# trained weights from disk.
model = FallDetectionLSTM(input_size, hidden_size, num_layers)
# map_location='cpu' lets a checkpoint that was saved from a CUDA device load
# on a machine without a GPU; the model itself is never moved off the CPU.
model.load_state_dict(torch.load('lstm_model.pth', map_location='cpu'))
# Switch to inference mode (disables training-only layer behaviour).
model.eval()

### predict on a recorded video

In [None]:
# Run per-frame fall prediction on a recorded video and show the annotated
# frames in a window until the video ends or 'q' is pressed.
video_path = "/Users/varunshankarhoskere/Downloads/WhatsApp Video 2023-11-22 at 12.45.34.mp4"
cap = cv2.VideoCapture(video_path)

try:
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video file: {video_path}")

    while True:
        ret, frame = cap.read()

        # read() reports ret=False when no frame could be grabbed (e.g. the
        # end of the video); frame is not usable in that case, so stop
        # instead of passing it on to cvtColor.
        if not ret:
            break

        # MediaPipe expects RGB input, OpenCV delivers BGR.
        results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Only annotate frames in which a pose was found.
        if results.pose_landmarks:
            # Flatten (x, y, z, visibility) of every landmark into one
            # feature vector, then add batch and sequence dimensions.
            # NOTE(review): the model therefore sees a sequence of length 1
            # per frame -- confirm this matches how it was trained.
            features = torch.tensor(
                [[lm.x, lm.y, lm.z, lm.visibility]
                 for lm in results.pose_landmarks.landmark]
            ).flatten()
            features = features.unsqueeze(0).unsqueeze(0)

            with torch.no_grad():
                output = model(features)
            fall_detected = output.item() > 0.5

            # Colours are BGR: red for a fall, green otherwise.
            if fall_detected:
                label, colour = "Fall Detected!", (0, 0, 255)
            else:
                label, colour = "No Fall Detected", (0, 255, 0)
            cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, colour, 2)

        # Display the frame
        cv2.imshow('Fall Detection', frame)

        # Break the loop when 'q' key is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the video source and close all windows, even if processing
    # raised part-way through.
    cap.release()
    cv2.destroyAllWindows()

: 