# In [None]:
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as T

# ---------- CONFIG ----------
# Path to the saved model checkpoint.
MODEL_PATH = "hand_model_mudra_green_10_me.pth"
# Side length (in pixels) that camera frames are resized to.
INPUT_SIZE = 416
# True when the Apple Metal (MPS) backend can be used.
# torch.backends.mps.is_available() is the documented availability check and
# exists on every PyTorch release that ships MPS; torch.mps.is_available() is
# only present on recent releases and raises AttributeError on older ones.
USE_GPU = torch.backends.mps.is_available()
# -----------------------------




class MudraCNN(nn.Module):
    """Five-stage convolutional image classifier.

    Every stage is two (3x3 conv -> ReLU -> BatchNorm) groups followed by a
    max-pool. The pooled feature map is flattened and passed to a single
    linear layer producing ``num_classes`` logits.

    The linear layer is sized for a 256 x 3 x 3 feature map and no adaptive
    pooling is applied, so the input resolution must reduce to 3 x 3 through
    the pooling chain (a 416 x 416 input does: 416 -> 208 -> 104 -> 34 -> 11 -> 3).
    """

    # (in_channels, out_channels, max-pool kernel) for each stage, in order.
    _STAGES = (
        (3, 32, 2),
        (32, 64, 2),
        (64, 128, 3),
        (128, 256, 3),
        (256, 256, 3),
    )

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return two stacked 3x3 conv -> ReLU -> BatchNorm groups."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]
        return nn.Sequential(*layers)

    def __init__(self, num_classes):
        super().__init__()

        # Alternate conv stage / pooling layer so the module indices inside
        # ``net`` (and therefore the state_dict keys) are 0..9.
        stack = []
        for in_ch, out_ch, pool_kernel in self._STAGES:
            stack.append(self._double_conv(in_ch, out_ch))
            stack.append(nn.MaxPool2d(pool_kernel))

        self.net = nn.Sequential(*stack)
        self.fc = nn.Linear(256 * 3 * 3, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.net(x)
        flat = features.view(features.size(0), -1)  # keep batch dim, flatten the rest
        return self.fc(flat)



def load_model():
    """
    Rebuild the classifier from the checkpoint stored at MODEL_PATH.

    The checkpoint is read onto the CPU and must contain a "classes" entry
    (its length sets the number of model outputs) and a "state_dict" entry
    with the weights.

    Returns a (model, classes) pair. The model is in eval mode and is moved
    to the "mps" device when USE_GPU is set.
    """
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(MODEL_PATH, map_location="cpu")
    class_names = ckpt["classes"]

    net = MudraCNN(num_classes=len(class_names))
    net.load_state_dict(ckpt["state_dict"])
    net.eval()

    if USE_GPU:
        net = net.to("mps")
    return net, class_names


# Per-frame preprocessing: array -> PIL image -> square resize -> tensor.
# NOTE(review): presumably mirrors the training-time pipeline - confirm.
_TARGET_HW = (INPUT_SIZE, INPUT_SIZE)
transform = T.Compose(
    [
        T.ToPILImage(),
        T.Resize(_TARGET_HW),
        T.ToTensor(),
    ]
)


@torch.no_grad()
def predict_frame(model, classes, frame_bgr):
    """
    Classify a single camera frame.

    frame_bgr is converted from BGR to RGB, run through the module-level
    ``transform``, given a batch dimension of one and moved to "mps" when
    USE_GPU is set, then fed to ``model``.

    Returns (label, confidence): the entry of ``classes`` at the index of the
    highest softmax probability, and that probability as a Python float.
    """
    # OpenCV delivers BGR; the preprocessing pipeline is fed RGB.
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # [C, H, W] -> [1, C, H, W]
    batch = transform(rgb).unsqueeze(0)
    if USE_GPU:
        batch = batch.to("mps")

    scores = model(batch)                     # [1, num_classes]
    probabilities = scores.softmax(dim=1)[0]  # [num_classes]

    best_prob, best_idx = probabilities.max(dim=0)
    return classes[best_idx.item()], best_prob.item()


def main():
    """
    Run live classification on the default webcam.

    Loads the model, opens camera 0, and for every captured frame draws the
    predicted label and confidence onto the frame and shows it in a window
    titled "Live Prediction". The loop ends when a frame cannot be read or
    the user presses 'q'.

    The camera and the display windows are released in a ``finally`` block so
    they are cleaned up even if prediction/drawing raises or the user
    interrupts the program with Ctrl+C.
    """
    print("Loading model...")
    model, classes = load_model()
    print("Model loaded. Opening camera...")

    cap = cv2.VideoCapture(0)  # 0 = default camera
    if not cap.isOpened():
        print("Error: Could not open webcam.")
        return

    print("Press 'q' to quit.")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame")
                break

            # Predict on current frame
            label, conf = predict_frame(model, classes, frame)

            # Draw prediction on the frame
            text = f"{label} ({conf*100:.1f}%)"
            cv2.putText(
                frame, text,
                (10, 30),                  # position
                cv2.FONT_HERSHEY_SIMPLEX,  # font
                1.0,                       # font scale
                (0, 255, 0),               # color (B, G, R)
                2,                         # thickness
                cv2.LINE_AA
            )

            # Show the frame
            cv2.imshow("Live Prediction", frame)

            # Mask to the low byte so the key code compares cleanly with 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always hand the camera back and close the preview window.
        cap.release()
        cv2.destroyAllWindows()


# Start the live webcam demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
