In [1]:
"""
Live ASL classifier with MediaPipe hand tracking + a ResNet-18 checkpoint
(a SimpleCNN architecture is also defined as an alternative backbone)
Requirements:
    pip install torch torchvision opencv-python pillow mediapipe
"""

import cv2
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
import mediapipe as mp
import time

In [2]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel images.

    Two conv -> ReLU -> max-pool stages (32, then 64 channels) each halve the
    spatial size; the flattened feature map is fed to a dropout-regularised
    two-layer fully connected head.

    The first fully connected layer is sized for a 64 x 56 x 56 feature map,
    so the network expects 224 x 224 inputs (224 / 2 / 2 = 56).

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor; 3x3 kernels with padding 1 keep the
        # spatial size, so only the pooling layers shrink it.
        conv_stages = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*conv_stages)

        # Classification head operating on the flattened feature map.
        flat_dim = 64 * 56 * 56
        head_layers = [
            nn.Dropout(p=0.5),
            nn.Linear(in_features=flat_dim, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch ``x``."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)

# -----------------------------
# Load trained model
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the ResNet-18 architecture without pretrained weights: every
# parameter is overwritten by the checkpoint loaded below.
model = resnet18(weights=None)

# Swap the stock classification head for a 27-way one, matching the 27 entries
# of the label table used in this notebook (A–Z plus 'Blank').
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=27)
model = model.to(device)

# Restore the trained parameters. The checkpoint is expected to be a dict that
# stores the weights under 'model_state_dict'.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
ckpt = torch.load(r"best_asl_resnet_checkpoint_5.pth", map_location=device)
model.load_state_dict(ckpt['model_state_dict'])

# Inference mode: freezes batch-norm statistics and disables dropout.
model.eval()
print("✅ Loaded ResNet checkpoint successfully.")

# -----------------------------
# Class labels
# -----------------------------
# Maps the model's output index (0..26) to a label name. The labels are the
# 26 letters plus 'Blank', numbered in plain string-sort order — which is why
# 'Blank' sits at index 2, between 'B' and 'C'.
# NOTE(review): presumably this mirrors the class ordering used at training
# time — confirm against the training notebook.
idx_to_class = dict(enumerate(sorted([*"ABCDEFGHIJKLMNOPQRSTUVWXYZ", "Blank"])))


# -----------------------------
# Preprocessing transform
# -----------------------------
# Turns a PIL crop into a normalised float tensor:
#   1. resize to 224 x 224 (aspect ratio is not preserved),
#   2. convert to a CHW tensor scaled to [0, 1],
#   3. normalise each channel with the standard ImageNet mean / std.
# NOTE(review): presumably identical to the evaluation transform used during
# training — confirm.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# -----------------------------
# MediaPipe hand detection
# -----------------------------
# Short aliases for the MediaPipe solution modules used by the webcam loop.
mp_hands = mp.solutions.hands
mp_draw = mp.solutions.drawing_utils

# Video-mode hand tracker (static_image_mode=False) restricted to a single
# hand, with a 0.5 minimum detection confidence.
hands = mp_hands.Hands(
    max_num_hands=1,
    min_detection_confidence=0.5,
    static_image_mode=False,
)

# -----------------------------
# Start webcam loop
# -----------------------------
# Per frame: detect one hand with MediaPipe, crop a padded box around its
# landmarks, classify the crop with `model`, and overlay the predicted label
# and a rolling FPS estimate. Press 'q' in the video window to stop.

BBOX_PAD = 20    # pixels added on each side of the landmark extent
FPS_WINDOW = 10  # number of frames averaged per FPS update

cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise RuntimeError("Could not open webcam!")

print("✅ Press 'q' to quit.")
fps, frame_count, t_start = 0, 0, time.time()

# try/finally guarantees the camera and the window are released even when the
# loop is interrupted (kernel interrupt) or a frame raises, so the webcam is
# not left locked for the next run.
try:
    with torch.no_grad():
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Flip if you want mirror effect
            # frame = cv2.flip(frame, 1)

            # MediaPipe expects RGB; OpenCV delivers BGR.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands.process(img_rgb)
            label = "No hand"

            if results.multi_hand_landmarks:
                h, w, _ = frame.shape
                hand_landmarks = results.multi_hand_landmarks[0]
                lm = hand_landmarks.landmark

                # Landmarks are normalised to the frame size; scale them to
                # pixels, pad, and clamp to the frame bounds.
                xs = [p.x for p in lm]
                ys = [p.y for p in lm]
                x_min = max(int(min(xs) * w) - BBOX_PAD, 0)
                x_max = min(int(max(xs) * w) + BBOX_PAD, w)
                y_min = max(int(min(ys) * h) - BBOX_PAD, 0)
                y_max = min(int(max(ys) * h) + BBOX_PAD, h)

                # Landmark coordinates may fall outside [0, 1] when the hand
                # is partly off-screen, which can leave an empty box after
                # clamping. Skip classification then (label stays "No hand")
                # instead of crashing on an empty crop.
                if x_max > x_min and y_max > y_min:
                    # Crop from the RGB frame already computed above: no second
                    # colour conversion, and the overlays drawn on `frame`
                    # never leak into the classifier input.
                    crop = img_rgb[y_min:y_max, x_min:x_max]
                    pil_img = Image.fromarray(crop)
                    inp = transform(pil_img).unsqueeze(0).to(device)

                    # inference
                    out = model(inp)
                    pred = torch.argmax(out, 1).item()
                    label = idx_to_class[pred]

                    cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0,255,0), 2)

                # optional: draw hand landmarks
                mp_draw.draw_landmarks(frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)

            # FPS calculation: refreshed once every FPS_WINDOW frames.
            frame_count += 1
            if frame_count == FPS_WINDOW:
                now = time.time()
                elapsed = now - t_start
                if elapsed > 0:  # guard against a zero-length interval
                    fps = FPS_WINDOW / elapsed
                t_start = now
                frame_count = 0

            # overlay label + FPS
            cv2.putText(frame, f"Pred: {label}", (10,40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0,255,0), 2)
            cv2.putText(frame, f"FPS: {fps:.1f}", (10,80),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,0), 2)

            cv2.imshow("ASL Live Inference (MediaPipe + CNN)", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    cap.release()
    cv2.destroyAllWindows()

✅ Loaded ResNet checkpoint successfully.
✅ Press 'q' to quit.
