In [None]:
# Force reinstall numpy and PyTorch/Torchvision to resolve potential conflicts
!pip uninstall numpy torch torchvision torchaudio -y
!pip install --no-cache-dir -I numpy torch torchvision torchaudio

# -*- coding: utf-8 -*-
"""
Fall Detection using Mediapipe and RNN in Google Colab (Revised Paths).ipynb

Automatically generated by Colaboratory.

Adapts to train_data / test_data structure.
"""

# ==============================================================================
# Cell 1: Setup Environment - Mount Drive & Install Libraries
# ==============================================================================
import os
import sys
import subprocess
from google.colab import drive

# Mount Google Drive so the dataset and output folders under /content/drive
# are reachable from this runtime.
print("Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully.")

# Install necessary libraries, but only when one of them cannot be imported.
print("Installing libraries...")
try:
    import mediapipe
    import torch
    import torchvision
    import torchaudio
    print("Libraries already installed or installation not needed in this environment.")
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "mediapipe", "torch", "torchvision", "torchaudio", "opencv-python"])
    # Packages installed while the interpreter is running may not be visible
    # to the import system until its finder caches are invalidated.
    import importlib
    importlib.invalidate_caches()
    # The ImportError aborted the `try` body, so `torch` may never have been
    # bound. Later code uses the bare name (`torch.device`, `torch.tensor`, ...)
    # and `import torch.nn as nn` only binds `nn`, so import it explicitly.
    import torch
    print("Libraries installed successfully.")

# Import remaining libraries
import cv2
import numpy as np
import mediapipe as mp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# from sklearn.model_selection import train_test_split # No longer needed for this structure
import time
import pickle 
import glob # For finding files

print("All necessary libraries are imported.")

# ==============================================================================
# Cell 2: Configuration - Define Paths and Hyperparameters
# ==============================================================================

# --- Dataset locations on Google Drive ---
BASE_DRIVE_PATH = '/content/drive/MyDrive/dataset'  # root folder for inputs and outputs

TRAIN_DATA_DIR = os.path.join(BASE_DRIVE_PATH, 'train_data')
TEST_DATA_DIR = os.path.join(BASE_DRIVE_PATH, 'test_data')

# One sub-folder per class; the position in this list becomes the integer label.
CLASS_FOLDERS = ["backward_fall", "forward_fall", "side_fall", "non_fall"]

FEATURES_OUTPUT_DIR = os.path.join(BASE_DRIVE_PATH, 'processed_features')  # extracted keypoint files
MODEL_SAVE_DIR = os.path.join(BASE_DRIVE_PATH, 'trained_models')
MODEL_NAME = 'fall_detection_rnn_v2.pth'

# --- Preprocessing ---
FRAME_SKIP = 3          # a frame is processed when its index is a multiple of this
SEQUENCE_LENGTH = 30    # rows per feature sequence after padding/truncation

# --- Model and training hyperparameters ---
NUM_CLASSES = len(CLASS_FOLDERS)
INPUT_SIZE = 33 * 3  # 33 landmarks, each contributing (x, y, visibility)
HIDDEN_SIZE = 128
NUM_LAYERS = 2
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 50

# --- Compute device: GPU when available, otherwise CPU ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Make sure every output folder exists (features root, one sub-folder per
# class, then the model folder) ---
required_dirs = [FEATURES_OUTPUT_DIR]
required_dirs += [os.path.join(FEATURES_OUTPUT_DIR, folder) for folder in CLASS_FOLDERS]
required_dirs.append(MODEL_SAVE_DIR)
for required_dir in required_dirs:
    os.makedirs(required_dir, exist_ok=True)

# --- Class name <-> integer label lookups ---
label_map = dict(zip(CLASS_FOLDERS, range(NUM_CLASSES)))
index_to_name = {index: name for name, index in label_map.items()}

# --- Echo the effective configuration ---
print("Configuration loaded:")
for caption, setting in (
    ("Base Path", BASE_DRIVE_PATH),
    ("Train Data Path", TRAIN_DATA_DIR),
    ("Test Data Path", TEST_DATA_DIR),
    ("Feature Path", FEATURES_OUTPUT_DIR),
    ("Model Save Path", MODEL_SAVE_DIR),
    ("Classes", CLASS_FOLDERS),
    ("Label Map", label_map),
    ("Frame Skip", FRAME_SKIP),
    ("Sequence Length", SEQUENCE_LENGTH),
):
    print(f"  {caption}: {setting}")


# ==============================================================================
# Cell 3: Data Loading & Preprocessing - Frame & Feature Extraction
# ==============================================================================

# Shared MediaPipe Pose estimator used by the feature-extraction code below.
mp_pose = mp.solutions.pose
pose_options = {
    "static_image_mode": False,  # frames are treated as a video stream
    "model_complexity": 1,
    "smooth_landmarks": True,
    "enable_segmentation": False,
    "min_detection_confidence": 0.5,
    "min_tracking_confidence": 0.5,
}
pose = mp_pose.Pose(**pose_options)

# Drawing helpers, kept available for visualisation.
mp_drawing = mp.solutions.drawing_utils

def extract_keypoints_from_video(video_path, frame_skip, sequence_length, pose_estimator=None):
    """
    Extract a fixed-length sequence of MediaPipe pose keypoints from a video.

    Every frame whose index is a multiple of ``frame_skip`` is run through the
    pose estimator. A detected pose contributes one row of
    ``[x, y, visibility]`` per landmark; a frame without a detection, or one
    whose processing raised, contributes a row of ``INPUT_SIZE`` zeros.

    Args:
        video_path: Path of the video file to read.
        frame_skip: Sampling stride over frame indices.
        sequence_length: Number of rows in the returned array. Shorter
            sequences are zero-padded at the end; longer ones keep only the
            first ``sequence_length`` sampled frames.
        pose_estimator: Optional MediaPipe ``Pose`` instance to use. When
            omitted, the module-level ``pose`` object is used, so existing
            three-argument callers are unaffected.

    Returns:
        A ``float32`` array of shape ``(sequence_length, INPUT_SIZE)``, or
        ``None`` if the video cannot be opened or yields no sampled frame.
    """
    # Resolved at call time so the global is only needed when no estimator
    # is supplied explicitly.
    estimator = pose if pose_estimator is None else pose_estimator

    cap = cv2.VideoCapture(video_path)
    keypoints_sequence = []
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None

        frame_count = 0
        while cap.isOpened():
            # Rows beyond `sequence_length` would be discarded below, so stop
            # decoding as soon as enough have been collected.
            if 0 < sequence_length <= len(keypoints_sequence):
                break

            success, frame = cap.read()
            if not success:
                break

            if frame_count % frame_skip == 0:
                try:
                    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image_rgb.flags.writeable = False
                    results = estimator.process(image_rgb)
                    if results.pose_landmarks:
                        frame_keypoints = []
                        for landmark in results.pose_landmarks.landmark:
                            frame_keypoints.extend([landmark.x, landmark.y, landmark.visibility])
                        keypoints_sequence.append(frame_keypoints)
                    else:
                        # No pose detected: keep the time step, filled with zeros.
                        keypoints_sequence.append([0.0] * INPUT_SIZE)
                except Exception as e:
                    # Best effort: a single bad frame must not discard the video.
                    print(f"Error processing frame {frame_count} in {video_path}: {e}")
                    keypoints_sequence.append([0.0] * INPUT_SIZE)

            frame_count += 1
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

    if not keypoints_sequence:
        print(f"Warning: No keypoints extracted from {video_path}")
        return None

    # Build the result in one place so that both the padded and the truncated
    # case have the same dtype (float32) and shape.
    final_sequence = np.zeros((sequence_length, INPUT_SIZE), dtype=np.float32)
    kept_rows = keypoints_sequence[:sequence_length]
    if kept_rows:
        final_sequence[:len(kept_rows)] = np.asarray(kept_rows, dtype=np.float32)

    return final_sequence


def process_data_folder(data_directory, class_folders, label_map, feature_output_dir):
    """
    Extract and cache keypoint features for every video under a data folder.

    For each class sub-folder of ``data_directory`` the ``*.mp4`` and ``*.avi``
    files are processed in sorted order. Features for ``<name>.<ext>`` are
    pickled to ``<feature_output_dir>/<class>/<name>.pkl``; an existing file at
    that path is reused instead of re-extracting.

    Args:
        data_directory: Folder containing one sub-folder per class.
        class_folders: Class sub-folder names to visit, in order.
        label_map: Mapping from class name to integer label.
        feature_output_dir: Root folder for the cached feature files.

    Returns:
        ``(feature_paths, labels)``: two parallel lists with one entry per
        video whose features are available on disk.

    Raises:
        KeyError: If a name in ``class_folders`` is missing from ``label_map``.
    """
    all_feature_files = []
    all_labels = []
    print(f"\nProcessing data source: {data_directory}")

    for class_name in class_folders:
        class_label = label_map[class_name]
        video_folder_abs = os.path.join(data_directory, class_name)
        # NOTE(review): the cache path depends only on class and file stem, so
        # a train video and a test video with the same name (or `a.mp4` and
        # `a.avi`) share one .pkl and the second silently reuses the first.
        feature_class_dir = os.path.join(feature_output_dir, class_name)

        print(f"\n  Processing class: {class_name} (Label: {class_label})")
        print(f"    Video folder: {video_folder_abs}")

        # glob returns entries in arbitrary order; sort for reproducible runs.
        video_files = sorted(
            glob.glob(os.path.join(video_folder_abs, '*.mp4'))
            + glob.glob(os.path.join(video_folder_abs, '*.avi'))  # Add other extensions
        )

        if not video_files:
            print(f"    Warning: No video files found in {video_folder_abs}")
            continue

        # Do not rely on the caller having created the class sub-folder.
        os.makedirs(feature_class_dir, exist_ok=True)

        for video_path in video_files:
            video_filename = os.path.basename(video_path)
            feature_filename = os.path.splitext(video_filename)[0] + '.pkl'
            feature_save_path = os.path.join(feature_class_dir, feature_filename)

            if os.path.exists(feature_save_path):
                print(f"    Skipping {video_filename}, features already exist at {feature_save_path}")
                all_feature_files.append(feature_save_path)
                all_labels.append(class_label)
                continue

            print(f"    Processing {video_filename}...")
            keypoints_data = extract_keypoints_from_video(video_path, FRAME_SKIP, SEQUENCE_LENGTH)

            if keypoints_data is None:
                print(f"      Failed to extract features for {video_filename}.")
                continue

            # Write to a temporary name and rename afterwards, so an
            # interrupted run cannot leave a truncated .pkl that later runs
            # would treat as a valid cached feature file.
            temp_save_path = feature_save_path + '.tmp'
            try:
                with open(temp_save_path, 'wb') as f:
                    pickle.dump(keypoints_data, f)
                os.replace(temp_save_path, feature_save_path)
            except Exception as e:
                print(f"      Error saving features for {video_filename}: {e}")
                try:
                    os.remove(temp_save_path)
                except OSError:
                    # Nothing to clean up, or the cleanup itself failed.
                    pass
                continue

            all_feature_files.append(feature_save_path)
            all_labels.append(class_label)
            print(f"      Saved features to {feature_save_path}")

    return all_feature_files, all_labels


# --- Run feature extraction for both data splits ---
print("\nStarting Feature Extraction...")
start_time = time.time()

# Training split
train_feature_files, train_labels = process_data_folder(
    data_directory=TRAIN_DATA_DIR,
    class_folders=CLASS_FOLDERS,
    label_map=label_map,
    feature_output_dir=FEATURES_OUTPUT_DIR,
)

# Testing split
test_feature_files, test_labels = process_data_folder(
    data_directory=TEST_DATA_DIR,
    class_folders=CLASS_FOLDERS,
    label_map=label_map,
    feature_output_dir=FEATURES_OUTPUT_DIR,
)

end_time = time.time()
elapsed_seconds = end_time - start_time
print(f"\nFeature extraction completed in {elapsed_seconds:.2f} seconds.")
print(f"Total training feature files processed/loaded: {len(train_feature_files)}")
print(f"Total testing feature files processed/loaded: {len(test_feature_files)}")

# The shared estimator is released here; it must not be used after this point.
pose.close()


# ==============================================================================
# Cell 4: Dataset and DataLoader Preparation
# ==============================================================================

# --- Sanity-check the extraction results ---
# Training data is mandatory; missing test data only produces a warning.
if not train_feature_files:
    raise ValueError(
        "No training feature files were processed or found. "
        "Check train_data paths and preprocessing steps."
    )
if not test_feature_files:
    print(
        "Warning: No testing feature files were processed or found. "
        "Check test_data paths. "
        "Proceeding without test data evaluation during training."
    )

# --- Define PyTorch Dataset ---
class PoseSequenceDataset(Dataset):
    """Dataset that lazily loads pickled keypoint sequences from disk.

    ``feature_paths`` and ``labels`` are parallel sequences: item ``i`` is the
    sequence unpickled from ``feature_paths[i]`` paired with ``labels[i]``.
    """

    def __init__(self, feature_paths, labels):
        self.labels = labels
        self.feature_paths = feature_paths

    def __len__(self):
        return len(self.feature_paths)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` as float32 and int64 tensors.

        If the file cannot be read or converted, the error is printed and an
        all-zero sequence with label 0 is returned instead.
        """
        path = self.feature_paths[idx]
        target = self.labels[idx]
        try:
            with open(path, 'rb') as handle:
                raw_sequence = pickle.load(handle)
            return (
                torch.tensor(raw_sequence, dtype=torch.float32),
                torch.tensor(target, dtype=torch.long),
            )
        except Exception as e:
            print(f"Error loading or processing file {path}: {e}")
            # Placeholder with the shape and dtypes a real sample would have.
            placeholder = torch.zeros((SEQUENCE_LENGTH, INPUT_SIZE), dtype=torch.float32)
            return placeholder, torch.tensor(0, dtype=torch.long)

# --- Wrap the feature files in datasets and batch loaders ---
# The train and test splits come from separate folders, so no random split is made.
train_dataset = PoseSequenceDataset(feature_paths=train_feature_files, labels=train_labels)
test_dataset = PoseSequenceDataset(feature_paths=test_feature_files, labels=test_labels)

# Training batches are reshuffled every epoch; the test loader keeps file
# order and is used for the per-epoch evaluation in the training loop.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Datasets and DataLoaders created.")
for split_name, split_dataset in (("Training", train_dataset), ("Testing", test_dataset)):
    print(f"  {split_name} samples: {len(split_dataset)}")


# ==============================================================================
# Cell 5: Model Building (RNN - LSTM/GRU)
# ==============================================================================

class FallDetectionRNN(nn.Module):
    """Recurrent sequence classifier: stacked LSTM or GRU plus a linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each recurrent layer.
        num_layers: Number of stacked recurrent layers.
        num_classes: Number of output logits.
        rnn_type: ``'LSTM'`` (default) or ``'GRU'``.

    Raises:
        ValueError: If ``rnn_type`` is neither ``'LSTM'`` nor ``'GRU'``.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes, rnn_type='LSTM'):
        super().__init__()
        if rnn_type not in ('LSTM', 'GRU'):
            raise ValueError("Unsupported RNN type. Choose 'LSTM' or 'GRU'.")

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type

        recurrent_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.rnn = recurrent_cls(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            # Dropout between recurrent layers is only enabled for stacks.
            dropout=0.2 if num_layers > 1 else 0,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits computed from the last time step of ``x``.

        ``x`` is indexed as ``(batch, time, features)`` because the recurrent
        layer is built with ``batch_first=True``.
        """
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device)
        if self.rnn_type == 'LSTM':
            # An LSTM takes a (hidden, cell) pair as its initial state.
            initial_state = (h0, torch.zeros(state_shape, device=x.device))
        else:
            initial_state = h0

        sequence_out, _ = self.rnn(x, initial_state)
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

# --- Build the classifier and move it to the selected device ---
model = FallDetectionRNN(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    rnn_type='LSTM',  # or 'GRU'
)
model.to(DEVICE)

print("Model Architecture:")
print(model)
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
total_params = sum(p.numel() for p in trainable_parameters)
print(f"Total trainable parameters: {total_params}")


# ==============================================================================
# Cell 6: Training Setup - Loss Function and Optimizer
# ==============================================================================

# Cross-entropy over the class logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Loss function and optimizer defined.")
print(f"  Criterion: {criterion}")
print(f"  Optimizer: Adam (LR={LEARNING_RATE})")

# ==============================================================================
# Cell 7: Model Training Loop
# ==============================================================================

# Trains for NUM_EPOCHS epochs. After every epoch the model is evaluated on
# test_loader (used here as a validation set) and the weights with the highest
# test accuracy so far are written to model_save_path.
print("\nStarting Training...")
best_test_accuracy = 0.0 # Saving based on test set performance
# Tracks whether a checkpoint exists. Without it, a run whose test accuracy
# never rises above 0% would never write the file that Cell 9 loads.
best_model_saved = False
model_save_path = os.path.join(MODEL_SAVE_DIR, MODEL_NAME)

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()

    # --- Training Phase ---
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for sequences, labels in train_loader:
        sequences, labels = sequences.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(sequences)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Mean of the per-batch losses (not weighted by batch size).
    train_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct_train / total_train

    # --- Testing Phase (using the test_data folder) ---
    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    # A DataLoader over an empty dataset has length 0 and is therefore falsy.
    if test_loader:
        with torch.no_grad():
            for sequences, labels in test_loader:
                sequences, labels = sequences.to(DEVICE), labels.to(DEVICE)
                outputs = model(sequences)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total_test += labels.size(0)
                correct_test += (predicted == labels).sum().item()

        test_loss /= len(test_loader)
        test_accuracy = 100 * correct_test / total_test
    else: # Handle case where there's no test data
        test_loss = 0.0
        test_accuracy = 0.0

    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] ({epoch_duration:.2f}s) | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")

    # --- Save the best model based on test accuracy ---
    # The first evaluated epoch always produces a checkpoint; afterwards only
    # strict improvements replace it.
    if test_loader and (not best_model_saved or test_accuracy > best_test_accuracy):
        best_test_accuracy = test_accuracy
        torch.save(model.state_dict(), model_save_path)
        best_model_saved = True
        print(f"  -> New best model saved to {model_save_path} (Test Acc: {best_test_accuracy:.2f}%)")

print("\nTraining Finished.")
if test_loader:
    print(f"Best Test Accuracy achieved during training: {best_test_accuracy:.2f}%")
if best_model_saved:
    print(f"Model saved at: {model_save_path}")
else:
    # No checkpoint was written: either there is no test data or no epoch ran.
    if not test_loader:
        print("Training finished, but no test data was available for evaluation.")
    torch.save(model.state_dict(), model_save_path) # Save the last model state
    print(f"Last model state saved to {model_save_path}")


# ==============================================================================
# Cell 8: Testing Function
# ==============================================================================

def predict_video(video_path, trained_model, device, frame_skip, sequence_length, class_index_to_name):
    """
    Classify a single video with a trained model.

    Keypoints are extracted with ``extract_keypoints_from_video`` using a Pose
    estimator created for this call, turned into a batch of one sequence, and
    passed through ``trained_model`` in eval mode.

    Args:
        video_path: Path of the video to classify.
        trained_model: Model mapping a ``(1, sequence_length, features)``
            float tensor to class logits.
        device: Torch device the input tensor is moved to.
        frame_skip: Frame sampling stride passed to the extractor.
        sequence_length: Sequence length passed to the extractor.
        class_index_to_name: Mapping from predicted index to class name;
            unknown indices are reported as ``"Unknown"``.

    Returns:
        The predicted class name, or ``None`` if no features could be
        extracted from the video.
    """
    global pose

    print(f"\nTesting video: {video_path}")
    mp_pose_pred = mp.solutions.pose
    pose_pred = mp_pose_pred.Pose(static_image_mode=False, model_complexity=1, smooth_landmarks=True,
                                  enable_segmentation=False, min_detection_confidence=0.5, min_tracking_confidence=0.5)

    # extract_keypoints_from_video reads the module-level `pose`, which the
    # preprocessing cell closes once feature extraction is done. Using that
    # closed estimator makes every frame fail, and the extractor's per-frame
    # error handling turns those failures into all-zero rows. Point the global
    # at the fresh estimator for the duration of the call, then restore it.
    shared_pose = globals().get('pose')
    pose = pose_pred
    try:
        keypoints_data = extract_keypoints_from_video(video_path, frame_skip, sequence_length)
    finally:
        pose = shared_pose
        pose_pred.close()

    if keypoints_data is None:
        print("  -> Failed to extract features from the video.")
        return None

    # Add the batch dimension expected by the model.
    sequence_tensor = torch.tensor(keypoints_data, dtype=torch.float32).unsqueeze(0).to(device)

    trained_model.eval()
    with torch.no_grad():
        outputs = trained_model(sequence_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)
        predicted_class_index = predicted_idx.item()
        prediction_confidence = confidence.item()

    predicted_class_name = class_index_to_name.get(predicted_class_index, "Unknown")
    print(f"  -> Prediction: {predicted_class_name} (Confidence: {prediction_confidence:.4f})")
    return predicted_class_name


# ==============================================================================
# Cell 9: Example Usage of Testing Function
# ==============================================================================

# --- Restore the checkpoint written by the training cell ---
print("\nLoading the trained model for testing...")
model_load_path = os.path.join(MODEL_SAVE_DIR, MODEL_NAME)

# The architecture arguments must match the ones used for training.
test_model = FallDetectionRNN(
    INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NUM_CLASSES, rnn_type='LSTM'
)

try:
    # Map the stored tensors onto whichever device is available.
    map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    saved_state = torch.load(model_load_path, map_location=map_location)
    test_model.load_state_dict(saved_state)
    test_model.to(DEVICE)
    test_model.eval()
    print("Model loaded successfully.")
except FileNotFoundError:
    print(f"Error: Model file not found at {model_load_path}. Cannot perform testing.")
    test_model = None
except Exception as e:
    print(f"Error loading model state_dict: {e}")
    test_model = None

# --- Video to classify ---
TEST_VIDEO_PATH = '/content/drive/MyDrive/dataset/test_data/side_fall/example_side_fall_test_video.mp4'


# --- Run the prediction when both the model and the video are available ---
if test_model is None:
    print("\nSkipping testing because the model could not be loaded.")
elif not os.path.exists(TEST_VIDEO_PATH):
    print(f"\nError: Test video file not found at {TEST_VIDEO_PATH}. Please verify the path.")
else:
    predicted_class = predict_video(
        video_path=TEST_VIDEO_PATH,
        trained_model=test_model,
        device=DEVICE,
        frame_skip=FRAME_SKIP,
        sequence_length=SEQUENCE_LENGTH,
        class_index_to_name=index_to_name,
    )
    test_video_name = os.path.basename(TEST_VIDEO_PATH)
    if predicted_class:
        print(f"\nFinal Prediction for {test_video_name}: {predicted_class}")
    else:
        print(f"\nCould not get prediction for {test_video_name}.")

print("\n--- End of Script ---")

Found existing installation: numpy 2.2.4
Uninstalling numpy-2.2.4:
  Successfully uninstalled numpy-2.2.4
Found existing installation: torch 2.6.0
Uninstalling torch-2.6.0:
  Successfully uninstalled torch-2.6.0
Found existing installation: torchvision 0.21.0
Uninstalling torchvision-0.21.0:
  Successfully uninstalled torchvision-0.21.0
Found existing installation: torchaudio 2.6.0
Uninstalling torchaudio-2.6.0:
  Successfully uninstalled torchaudio-2.6.0
Collecting numpy
  Downloading numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metada

Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully.
Installing libraries...
Libraries already installed or installation not needed in this environment.
All necessary libraries are imported.
Using device: cpu
Configuration loaded:
  Base Path: /content/drive/MyDrive/dataset
  Train Data Path: /content/drive/MyDrive/dataset/train_data
  Test Data Path: /content/drive/MyDrive/dataset/test_data
  Feature Path: /content/drive/MyDrive/dataset/processed_features
  Model Save Path: /content/drive/MyDrive/dataset/trained_models
  Classes: ['backward_fall', 'forward_fall', 'side_fall', 'non_fall']
  Label Map: {'backward_fall': 0, 'forward_fall': 1, 'side_fall': 2, 'non_fall': 3}
  Frame Skip: 3
  Sequence Length: 30

Starting Feature Extraction...

Processing data source: /content/drive/MyDrive/dataset/train_data

  Processing class: backward_fall (Label: 0)
 