# CELL 1: Environment & GPU Setup

In [None]:
# ============================================================================
# CELL 1: Environment & GPU Setup
# ============================================================================
# Install/reinstall main packages to ensure compatibility
!pip install torch torchvision torchaudio
!pip install numpy==1.26.4 scipy==1.12.0 scikit-learn==1.5.2

# Import libraries
import torch
import numpy as np
import random
import pickle
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import cv2
import mediapipe as mp
import os
import glob
from tqdm import tqdm
import json
import time
import re
import matplotlib.pyplot as plt
import seaborn as sns

# Reproducibility: seed the PyTorch, NumPy and stdlib RNGs with the same value.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# The video inputs and skeleton outputs (see Cell 2 paths) live on Google Drive.
from google.colab import drive
drive.mount('/content/drive')


# CELL 2: Configuration & Constants

In [None]:
# ============================================================================
# CELL 2: Configuration & Constants
# ============================================================================
# Paths
BASE_INPUT_DIR = "/content/drive/MyDrive/KSL_Project/KSL_Videos"
OUTPUT_SKELETON_DIR = "/content/drive/MyDrive/KSL_Project/KSL_JOINT_STREAM"

# Model parameters
SEQUENCE_LENGTH = 32  # frames sampled per video
HAND_LANDMARKS = 21   # landmarks per MediaPipe hand (left block, then right block)
POSE_INDICES = [0, 11, 12, 13, 14]  # Nose, Shoulders, Elbows
# Joint layout per frame: left hand, right hand, then the selected pose points.
# Derived rather than hard-coded so it cannot drift from POSE_INDICES (= 47).
NUM_JOINTS = 2 * HAND_LANDMARKS + len(POSE_INDICES)

# Retry settings (used when opening a video fails)
MAX_RETRIES = 5
INITIAL_DELAY_SECONDS = 2


# CELL 3: MediaPipe Holistic Setup & Landmark Extraction

In [None]:
# ============================================================================
# CELL 3: MediaPipe Holistic Setup & Landmark Extraction
# ============================================================================
mp_holistic = mp.solutions.holistic

def extract_landmarks_from_frame(results):
    """Flatten one MediaPipe Holistic result into a (NUM_JOINTS, 3) float32 array.

    Row layout: left-hand landmarks from row 0, right-hand landmarks from
    row 21, then the pose landmarks listed in POSE_INDICES from row 42.
    Each row is (x, y, z). Body parts that were not detected stay all-zero.
    """
    coords = np.zeros((NUM_JOINTS, 3), dtype=np.float32)

    # Hands: (first row, landmark list or None) for each hand.
    hand_slots = ((0, results.left_hand_landmarks), (21, results.right_hand_landmarks))
    for base, hand in hand_slots:
        if not hand:
            continue
        for offset, point in enumerate(hand.landmark):
            coords[base + offset] = (point.x, point.y, point.z)

    # Selected pose landmarks fill the rows after both hands.
    if results.pose_landmarks:
        pose_points = results.pose_landmarks.landmark
        for row, pose_idx in enumerate(POSE_INDICES, start=42):
            point = pose_points[pose_idx]
            coords[row] = (point.x, point.y, point.z)

    return coords


# CELL 4: Video Preprocessing Function

In [None]:
# ============================================================================
# CELL 4: Video Preprocessing Function
# ============================================================================
def _open_capture(video_path):
    """Open ``video_path`` with OpenCV, retrying with exponential backoff.

    Waits INITIAL_DELAY_SECONDS after the first failure and doubles the wait
    each time (presumably to ride out transient failures on the mounted
    Drive). No sleep follows the final attempt.

    Returns:
        An opened ``cv2.VideoCapture``, or ``None`` after MAX_RETRIES failures.
    """
    delay = INITIAL_DELAY_SECONDS
    for attempt in range(1, MAX_RETRIES + 1):
        cap = cv2.VideoCapture(video_path)
        if cap.isOpened():
            return cap
        cap.release()
        if attempt == MAX_RETRIES:
            break
        print(f"[Attempt {attempt}/{MAX_RETRIES}] Failed to open {video_path}, retrying in {delay}s...")
        time.sleep(delay)
        delay *= 2
    return None


def preprocess_video(video_path):
    """Extract a fixed-length landmark sequence of SEQUENCE_LENGTH frames from one video.

    SEQUENCE_LENGTH frame indices are spread evenly over the clip with
    ``np.linspace``. Each distinct sampled frame goes through MediaPipe
    Holistic once. On clips shorter than SEQUENCE_LENGTH the repeated indices
    reuse those landmarks, so short clips are stretched in time instead of
    being cut short and zero-padded. If frames cannot be decoded (e.g. the
    reported frame count is too high), the remaining tail is zero-padded.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        float32 array of shape (SEQUENCE_LENGTH, NUM_JOINTS, 3). All zeros
        when the video cannot be opened or yields no frames.
    """
    joint_data = np.zeros((SEQUENCE_LENGTH, NUM_JOINTS, 3), dtype=np.float32)

    cap = _open_capture(video_path)
    if cap is None:
        print(f"[FATAL] Skipped video: {video_path}")
        return joint_data

    landmarks_by_frame = {}
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Unknown or zero frame count: nothing to sample.
            return joint_data
        sample_indices = np.linspace(0, total_frames - 1, SEQUENCE_LENGTH, dtype=int).tolist()
        wanted = set(sample_indices)  # O(1) membership; duplicates collapse here
        last_wanted = sample_indices[-1]

        with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
            frame_no = 0
            # Frames are decoded sequentially; stop right after the last sampled one.
            while frame_no <= last_wanted:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_no in wanted:
                    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # MediaPipe's examples mark the image read-only before process().
                    image.flags.writeable = False
                    landmarks_by_frame[frame_no] = extract_landmarks_from_frame(holistic.process(image))
                frame_no += 1
    finally:
        # Release even if decoding or MediaPipe raises.
        cap.release()

    # Indices are ascending and reading stops at the first failure, so any
    # missing frames form a suffix, which becomes the zero padding.
    sampled = [landmarks_by_frame[i] for i in sample_indices if i in landmarks_by_frame]
    if sampled:
        joint_data[:len(sampled)] = np.stack(sampled)
    return joint_data


# CELL 5: Main Video Processing Loop

In [None]:
# ============================================================================
# CELL 5: Main Video Processing Loop
# ============================================================================
def process_all_videos():
    """Extract landmark sequences for every .mp4 under BASE_INPUT_DIR and pickle them.

    Each video's class label is its parent directory name parsed as an int.
    Videos whose label is not numeric, that raise during processing, or whose
    extracted sequence is all zeros are skipped and counted. The pickle at
    OUTPUT_SKELETON_DIR/KSL77_joint_stream_47pt.pkl holds ``(features, labels)``:
    features is float32 with shape (N, 3, SEQUENCE_LENGTH, NUM_JOINTS), and
    labels is int64 with shape (N,). Nothing is written if no video is usable.
    """
    if not os.path.exists(BASE_INPUT_DIR):
        print(f"Input directory not found: {BASE_INPUT_DIR}")
        return

    os.makedirs(OUTPUT_SKELETON_DIR, exist_ok=True)
    print(f"Output directory ready: {OUTPUT_SKELETON_DIR}")

    # Sort so dataset order (and therefore the seeded train/val/test split)
    # is reproducible; raw glob order depends on the filesystem.
    video_files = sorted(glob.glob(os.path.join(BASE_INPUT_DIR, '**', '*.[mM][pP]4'), recursive=True))
    if not video_files:
        print(f"No videos found in {BASE_INPUT_DIR}")
        return
    print(f"Found {len(video_files)} video files.")

    all_features, all_labels = [], []
    skipped = 0
    for video_path in tqdm(video_files, desc="Processing Videos"):
        try:
            class_label = int(os.path.basename(os.path.dirname(video_path)))
            joint_data = preprocess_video(video_path)
        except Exception as e:
            # Best-effort batch job: report the failure and continue.
            print(f"Error processing {video_path}: {e}")
            skipped += 1
            continue
        if np.all(joint_data == 0):
            # Video could not be opened/decoded, or nothing was detected.
            skipped += 1
            continue
        all_features.append(joint_data.transpose(2, 0, 1))  # (T, J, 3) -> (3, T, J)
        all_labels.append(class_label)

    if not all_features:
        # Don't overwrite an earlier pickle with an empty, wrongly shaped dataset.
        print(f"No usable videos ({skipped} skipped); nothing saved.")
        return

    final_features = np.array(all_features, dtype=np.float32)
    final_labels = np.array(all_labels, dtype=np.int64)

    output_pkl_path = os.path.join(OUTPUT_SKELETON_DIR, "KSL77_joint_stream_47pt.pkl")
    with open(output_pkl_path, 'wb') as f:
        pickle.dump((final_features, final_labels), f)

    print(f"\nProcessed {final_features.shape[0]} videos ({skipped} skipped). Data saved to {output_pkl_path}")

# Run the processing
if __name__ == '__main__':
    process_all_videos()


# CELL 6: Load Processed Data & Create Datasets

In [None]:
# ============================================================================
# CELL 6: Load Processed Data & Create Datasets
# ============================================================================
# Same file process_all_videos writes; derived from the Cell 2 constant so the
# two cells cannot point at different locations.
data_path = os.path.join(OUTPUT_SKELETON_DIR, "KSL77_joint_stream_47pt.pkl")
# NOTE: pickle.load can execute arbitrary code; only load pickles this
# pipeline produced.
with open(data_path, "rb") as f:
    features, labels = pickle.load(f)

print("Loaded features:", features.shape)
print("Loaded labels:", labels.shape)

# as_tensor reuses the NumPy buffers when dtypes already match (float32 /
# int64 as saved), avoiding a full copy of the dataset.
features = torch.as_tensor(features, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.long)

class PoseDataset(Dataset):
    """PyTorch Dataset over KSL joint tensors, with optional train-time augmentation.

    Each ``X[idx]`` is expected to be (3, T, J): coordinate channel (x, y, z),
    time, joint, matching the transpose applied in ``process_all_videos``.
    """

    # Joint permutation that mirrors the 47-point layout written by
    # extract_landmarks_from_frame. It swaps rows 0-20 (left hand) with
    # rows 21-41 (right hand), keeps the nose (42), and swaps the shoulders
    # (43 <-> 44, pose 11/12) and elbows (45 <-> 46, pose 13/14).
    _MIRROR_PERM = list(range(21, 42)) + list(range(0, 21)) + [42, 44, 43, 46, 45]

    def __init__(self, X, y, augment=False):
        self.X = X
        self.y = y
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Clone so augmentation can never modify the stored tensor.
        x = self.X[idx].clone()
        y = self.y[idx]
        if self.augment:
            x = self.apply_augmentations(x)
        return x, y

    @classmethod
    def _horizontal_flip(cls, x):
        """Return a left-right mirrored copy of a (3, T, J) sample.

        Detected joints get x -> 1 - x (MediaPipe's normalized image
        coordinates). Joints that are all-zero (undetected or padded) stay
        zero. For the 47-joint layout, the left/right slots are also swapped,
        so a mirrored left hand lands in the right-hand rows.
        """
        missing = (x == 0).all(dim=0)  # (T, J) mask of undetected joints
        out = x.clone()
        out[0] = torch.where(missing, x[0], 1.0 - x[0])
        if out.shape[-1] == len(cls._MIRROR_PERM):
            out = out[..., cls._MIRROR_PERM]
        return out

    def apply_augmentations(self, x):
        """Randomly mirror, jitter, rescale and time-shift a (3, T, J) sample."""
        # Mirror first, while undetected joints are still exactly zero.
        if random.random() < 0.3:
            x = self._horizontal_flip(x)
        if random.random() < 0.5:
            x = x + torch.randn_like(x) * 0.01
        if random.random() < 0.3:
            x = x * (1.0 + (random.random() - 0.5) * 0.1)
        if random.random() < 0.3:
            # dims=1 is the time axis; roll wraps frames around.
            x = torch.roll(x, shifts=random.randint(-2, 2), dims=1)
        return x

# Train/Test split (stratified by class, fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, stratify=labels, random_state=42)

# Remap the original class ids to contiguous 0..num_classes-1 for CrossEntropyLoss.
unique_labels = torch.unique(labels)  # sorted ascending (torch.unique default)
label_map = {old.item(): new for new, old in enumerate(unique_labels)}
reverse_label_map = {v: k for k, v in label_map.items()}
num_classes = len(unique_labels)
# Because unique_labels is sorted, searchsorted returns each label's position,
# i.e. exactly label_map[label], without a Python loop.
y_train = torch.searchsorted(unique_labels, torch.as_tensor(y_train))
y_test = torch.searchsorted(unique_labels, torch.as_tensor(y_test))

# Carve a stratified validation set out of the training portion.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, stratify=y_train, random_state=42)

# DataLoaders. drop_last on the training loader: the classifier head contains
# BatchNorm1d, which raises on a final batch of one sample in train mode.
train_loader = DataLoader(PoseDataset(X_train, y_train, augment=True), batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(PoseDataset(X_val, y_val), batch_size=32, shuffle=False)
test_loader = DataLoader(PoseDataset(X_test, y_test), batch_size=32, shuffle=False)

print(f"DataLoaders ready: Train {len(train_loader)}, Val {len(val_loader)}, Test {len(test_loader)}")


# CELL 7: Model Definition - PoseCNN_LSTM_Attn

In [None]:
# ============================================================================
# CELL 7: Model Definition - PoseCNN_LSTM_Attn
# ============================================================================
import torch.nn as nn
import torch.nn.functional as F

class PoseCNN_LSTM_Attn(nn.Module):
    """CNN + BiLSTM + attention-pooling classifier for KSL joint sequences.

    Input is (batch, 3, T, num_joints): coordinate channels, time, joints.
    The steps are:
    - Two (1, k) convolutions mix adjacent joint indices within each frame.
    - A (1, 2) max-pool halves the joint axis.
    - A (3, 1) convolution mixes adjacent frames.
    - A bidirectional LSTM runs over time.
    - A learned softmax over time steps pools the LSTM outputs into one
      context vector, which the MLP head classifies.

    Args:
        num_classes: Number of output logits.
        num_joints: Joints per frame (last input dimension). Defaults to 47,
            the layout produced by extract_landmarks_from_frame.
    """
    def __init__(self, num_classes, num_joints=47):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(1, 5), padding=(0, 2))
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(1, 3), padding=(0, 1))
        self.bn2 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d((1, 2))
        self.dropout = nn.Dropout(0.3)
        self.temp_conv = nn.Conv2d(128, 128, kernel_size=(3, 1), padding=(1, 0))
        self.bn_temp = nn.BatchNorm2d(128)
        # After pooling, each time step carries 128 channels x (num_joints // 2) joints.
        self.lstm = nn.LSTM(input_size=128 * (num_joints // 2), hidden_size=128, num_layers=1, batch_first=True, bidirectional=True)
        self.attn = nn.Sequential(nn.Linear(128 * 2, 128), nn.Tanh(), nn.Linear(128, 1))
        self.fc = nn.Sequential(nn.BatchNorm1d(128 * 2), nn.Linear(128 * 2, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes))

    def forward(self, x):
        """Map (batch, 3, T, num_joints) inputs to (batch, num_classes) logits."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = F.relu(self.bn_temp(self.temp_conv(x)))    # (B, 128, T, J')
        x = x.permute(0, 2, 1, 3).contiguous()         # (B, T, 128, J')
        x = x.view(x.size(0), x.size(1), -1)           # (B, T, 128*J')
        out, _ = self.lstm(x)                          # (B, T, 256)
        attn_scores = self.attn(out)                   # (B, T, 1)
        attn_weights = torch.softmax(attn_scores, dim=1)  # normalise over time
        context = torch.sum(attn_weights * out, dim=1)    # (B, 256)
        return self.fc(context)

# Instantiate the classifier on the selected device, plus loss, optimiser and LR schedule.
model = PoseCNN_LSTM_Attn(num_classes=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Cosine annealing with a half-period of 20 scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20)


# CELL 8: Training Loop

In [None]:
# ============================================================================
# CELL 8: Training Loop
# ============================================================================
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module to train (switched to train mode).
        dataloader: Yields ``(inputs, labels)`` batches.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss returning the batch-mean loss.
        device: Device each batch is moved to.

    Returns:
        Mean per-sample loss over the samples actually seen this epoch (batch
        means weighted by batch size). This stays correct with ``drop_last``
        or custom samplers. Returns 0.0 if the loader yields nothing.
    """
    model.train()
    running_loss = 0.0
    seen = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size
    return running_loss / seen if seen else 0.0

def evaluate(model, dataloader, device=None):
    """Return the classification accuracy of ``model`` on ``dataloader``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes);
            switched to eval mode.
        dataloader: Yields ``(inputs, labels)`` batches.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so no module-level
            ``device`` is required.

    Returns:
        Fraction of correctly classified samples in [0, 1]; 0.0 when the
        loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0

# Train model with early stopping on validation accuracy, checkpointing the best epoch.
CHECKPOINT_PATH = "best_model.pt"
# Start below any achievable accuracy so epoch 0 always writes a checkpoint.
best_acc, patience, trigger = -1.0, 10, 0
num_epochs = 100
# NOTE(review): the scheduler uses T_max=20 < num_epochs, so after epoch 20 the
# cosine LR climbs back up; consider T_max=num_epochs if that is not intended.

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    # The cosine scheduler is stepped once per epoch.
    scheduler.step()
    val_acc = evaluate(model, val_loader)
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print(f"Epoch {epoch}: train loss {train_loss:.4f}, new best accuracy {best_acc:.4f}")
        trigger = 0
    else:
        trigger += 1
        print(f"Epoch {epoch}: train loss {train_loss:.4f}, val acc {val_acc:.4f}. "
              f"No improvement. Early stop counter: {trigger}/{patience}")
        if trigger >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Continue with the best checkpoint, not the last (post-patience) weights.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
print(f"Restored best model (val acc {best_acc:.4f}) from {CHECKPOINT_PATH}")


# CELL 9: Test & Evaluation

In [None]:
# ============================================================================
# CELL 9: Test & Evaluation
# ============================================================================
# Run the test set through the model once; accuracy, report and confusion
# matrix below all reuse these predictions.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        X_batch = X_batch.to(device)
        preds = model(X_batch).argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y_batch.cpu().numpy())

all_preds = np.asarray(all_preds)
all_labels = np.asarray(all_labels)

# Test accuracy
correct = int((all_preds == all_labels).sum())
total = len(all_labels)
print(f"Test Accuracy: {100*correct/total:.2f}%")

# Classification report
from sklearn.metrics import classification_report, confusion_matrix
# zero_division=0 yields the same 0.0 sklearn reports by default for never-
# predicted classes, without emitting a warning per class.
report_dict = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
# Per-class keys are the mapped label ids as strings; skip the summary rows.
perfect_mapped_classes = [int(cls) for cls, metrics in report_dict.items()
                          if cls.isdigit() and metrics['precision'] == 1.0 and metrics['recall'] == 1.0]
perfect_original_classes = [reverse_label_map[c] for c in perfect_mapped_classes]

print("Perfectly predicted classes:", perfect_original_classes)

# Confusion Matrix: fix the axis to every class so rows/cols line up with
# label ids even if some class is absent from the test predictions and labels.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

# Save final model
save_path = "/content/drive/MyDrive/KSL_Project/best_model2.pt"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")
