In [None]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import os
import io
from IPython.display import display, Image, clear_output
from torchvision import models, transforms
from PIL import Image as PILImage

# ==========================================
# 1. CONFIGURATION
# ==========================================
# WARNING: UPDATE YOUR IP HERE (check the camera app on your phone for the
# current address and port).
# Despite its name, this holds the full stream URL rather than a bare IP; it
# is handed unchanged to cv2.VideoCapture further down.
# Ensure '/video' is at the end of the URL.
PHONE_IP = "http://192.168.29.229:8080/video"

# File name of the checkpoint to load (ensure this matches your uploaded file
# name). It is looked up in the working directory first, then in 'notebooks/',
# and finally under a hard-coded absolute project path.
WEIGHTS_FILE = "hqcnn_unfrozen_best.pth"

# ==========================================
# 2. MODEL SETUP & ARCHITECTURE
# ==========================================
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"üöÄ Device: {device}")

# GTSRB class index (0-42) -> short display label.
# The labels are listed in index order and numbered with enumerate(), so the
# position of a name in the tuple below IS its class id.
CLASSES = dict(enumerate((
    # 0-8: speed limits
    'Speed 20', 'Speed 30', 'Speed 50', 'Speed 60', 'Speed 70',
    'Speed 80', 'End 80', 'Speed 100', 'Speed 120',
    # 9-17: prohibitions and priority
    'No Passing', 'No Truck Passing', 'Priority Crossroad', 'Priority Road',
    'Yield', 'STOP', 'No Vehicles', 'No Trucks', 'No Entry',
    # 18-31: warnings
    'Caution', 'Curve Left', 'Curve Right', 'Double Curve',
    'Bumpy Road', 'Slippery', 'Narrow Road', 'Road Work',
    'Signals', 'Pedestrians', 'Children', 'Bicycles',
    'Ice/Snow', 'Wild Animals',
    # 32: end of restrictions
    'End Speed Limit',
    # 33-40: mandatory directions
    'Turn Right', 'Turn Left', 'Ahead Only', 'Straight/Right',
    'Straight/Left', 'Keep Right', 'Keep Left', 'Roundabout',
    # 41-42: end of passing restrictions
    'End No Pass', 'End No Truck Pass',
)))

# Define Architecture (Must match training exactly)
class QuantumLayer(nn.Module):
    """Trainable phase-shift activation.

    Holds one learnable angle per "qubit" in ``theta`` (shape ``(n_qubits,)``)
    and applies ``cos(x) * sin(theta) + sin(x) * cos(theta)`` element-wise,
    which by the angle-addition identity is ``sin(x + theta)``. ``theta`` is
    combined with ``x`` through ordinary broadcasting. No circuit simulation
    takes place in this class; it is a purely trigonometric layer.

    Args:
        n_qubits: Number of learnable angles.
    """

    def __init__(self, n_qubits=8):
        super().__init__()
        self.n_qubits = n_qubits
        # Small random starting angles (standard normal scaled by 0.1).
        initial_angles = 0.1 * torch.randn(n_qubits)
        self.theta = nn.Parameter(initial_angles)

    def forward(self, x):
        """Return the phase-shifted sine of ``x`` (same shape as the broadcast of ``x`` and ``theta``)."""
        sin_theta = torch.sin(self.theta)
        cos_theta = torch.cos(self.theta)
        # Kept in expanded form rather than sin(x + theta) so the numerics
        # stay identical to what the checkpoint was trained with.
        return torch.cos(x) * sin_theta + torch.sin(x) * cos_theta

class HQCNN(nn.Module):
    """ResNet-18 feature extractor followed by a small trigonometric head.

    Pipeline: ResNet-18 (final ``fc`` replaced by identity) -> ``bridge``
    (512 -> 8 linear) -> ``QuantumLayer`` (8 angles) -> ``classifier``
    (8 -> ``n_classes`` linear). The attribute names and their creation order
    are kept as-is so that saved state dicts keep matching this module.

    Args:
        n_classes: Number of output logits.
    """

    def __init__(self, n_classes=43):
        super().__init__()
        # Backbone is built without pretrained weights; they are expected to
        # come from the checkpoint loaded later.
        backbone = models.resnet18(weights=None)
        backbone.fc = nn.Identity()
        self.base_model = backbone
        self.bridge = nn.Linear(512, 8)
        self.quantum_layer = QuantumLayer(n_qubits=8)
        self.classifier = nn.Linear(8, n_classes)

    def forward(self, x):
        """Map an image batch ``x`` to raw class logits (no softmax applied)."""
        features = self.base_model(x)
        angles = self.bridge(features)
        activated = self.quantum_layer(angles)
        return self.classifier(activated)

# Resolve the checkpoint location. Candidates are probed in order -- the
# working directory, then a 'notebooks' subdirectory -- and the first one that
# exists wins. If neither exists, fall back to the absolute project path
# (from the author's WSL setup) without checking that it exists.
path = next(
    (
        candidate
        for candidate in (WEIGHTS_FILE, os.path.join('notebooks', WEIGHTS_FILE))
        if os.path.exists(candidate)
    ),
    f'/home/akash_kishore/HQCNN_Project/notebooks/{WEIGHTS_FILE}',
)

print(f"üìÇ Loading weights from: {path}")

# Load the model: build the architecture on the selected device, then restore
# the trained parameters from the checkpoint and switch to inference mode.
try:
    model = HQCNN(n_classes=43).to(device)
    # map_location=device remaps every stored tensor onto the device chosen
    # above. This covers the CPU-only case and also checkpoints that were
    # saved from a CUDA index this machine does not have, which would
    # otherwise fail to deserialize.
    # NOTE(security): weights_only=False allows arbitrary pickled objects to
    # be executed while loading. Only load checkpoints from a trusted source;
    # a plain state dict should also load with weights_only=True.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt)
    # eval() freezes BatchNorm statistics (and disables dropout) for inference.
    model.eval()
    print("‚úÖ Model Loaded Successfully!")
except Exception as e:
    print(f"‚ùå Error loading model: {e}")
    # We don't raise here to allow debugging, but logic below might fail if
    # model isn't loaded. Note that if the failure happened after HQCNN() was
    # constructed, `model` exists but still holds randomly initialised weights.

# ==========================================
# 3. BROWSER VIDEO LOOP (INLINE DISPLAY)
# ==========================================
# Input pipeline applied to each cropped ROI before inference:
#   1. resize to 224x224,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with a fixed mean/std.
# NOTE(review): the mean/std are the values commonly used with ImageNet-trained
# backbones -- presumably identical to the training-time transform; confirm.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Open the phone's video stream using the URL configured in PHONE_IP.
print(f"üì° Connecting to Phone at {PHONE_IP}...")
cap = cv2.VideoCapture(PHONE_IP)
# Low-latency setting: ask the capture backend to buffer a single frame.
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

# Report whether the capture could be opened; execution continues either way.
if cap.isOpened():
    print("‚úÖ Video Stream Started! (Press the Square 'Stop' button in Jupyter toolbar to end)")
else:
    print("‚ùå Connection Failed! Check IP address or restart app.")

# Main loop: grab a frame, classify the centred square region of interest,
# draw the result onto the frame and show it inline in the notebook.
# Runs until the stream ends, the user interrupts the kernel, or an error
# occurs; the capture is always released on the way out.
try:
    while True:
        ret, frame = cap.read()
        if not ret or frame is None:
            print("‚ö†Ô∏è Stream Ended / Frame Drop")
            break

        # ROI Logic (Green Box): a centred square of at most 300 px per side.
        # The side is clamped to the frame size so x1/y1 can never be
        # negative -- a negative start index would make the NumPy slice wrap
        # around and silently crop the wrong part of the image.
        h, w = frame.shape[:2]
        box = min(300, h, w)
        x1, y1 = (w - box) // 2, (h - box) // 2
        x2, y2 = x1 + box, y1 + box

        roi = frame[y1:y2, x1:x2]

        if roi.size > 0:
            # Inference: the frame is converted from BGR to RGB and wrapped in
            # a PIL image, which is what the preprocess pipeline consumes.
            roi_pil = PILImage.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB))
            input_tensor = preprocess(roi_pil).unsqueeze(0).to(device)

            with torch.no_grad():
                out = model(input_tensor)
                probs = torch.nn.functional.softmax(out, dim=1)
                score, idx = torch.max(probs, 1)
                raw_label = CLASSES[idx.item()]
                conf = score.item() * 100  # top-1 probability as a percentage

            # --- CONFIDENCE THRESHOLD LOGIC ---
            # Only show a class name when the model is more than 90% sure.
            if conf > 90.0:
                label_text = f"{raw_label}: {conf:.0f}%"
                color = (0, 255, 0)  # Green
            else:
                label_text = "Scanning..."
                color = (128, 128, 128)  # Grey

            # Draw UI
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
            # Label background: a 40 px strip as wide as the box. It sits
            # directly above the box when there is room, otherwise inside the
            # top edge so it stays visible on short frames.
            label_h = 40
            label_top = y1 - label_h if y1 >= label_h else y1
            cv2.rectangle(frame, (x1, label_top), (x2, label_top + label_h), color, -1)
            # Text, with its baseline 10 px above the bottom of the strip.
            cv2.putText(frame, label_text, (x1 + 10, label_top + label_h - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Display in Notebook (Inline): JPEG-encode the annotated frame and
        # replace the previous cell output with it.
        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            # Skip a frame that could not be encoded instead of aborting the
            # whole stream on the missing buffer.
            continue
        clear_output(wait=True)
        display(Image(data=buffer.tobytes()))

except KeyboardInterrupt:
    print("üõë Stream Stopped by User")
except Exception as e:
    print(f"‚ùå Runtime Error: {e}")
finally:
    cap.release()
    print("üîå Camera Released")