In [9]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import cv2
from PIL import Image
import torch.nn.functional as F

# -----------------------------
# 1. Define CNN architecture
# -----------------------------
class TrafficSignCNN(nn.Module):
    """Small three-stage CNN that maps 32x32 RGB images to class logits.

    Each conv stage ends in a 2x2 max-pool, so a 32x32 input is reduced to a
    4x4 map with 128 channels; the classifier head therefore expects
    128 * 4 * 4 features. Inputs whose spatial size does not reduce to 4x4
    will raise a shape error in the first ``nn.Linear``.

    Submodule names (``network``, ``fc``) and their ``nn.Sequential`` indices
    are kept fixed so checkpoints saved via ``state_dict()`` keep loading.

    Args:
        num_classes: Number of output logits. Defaults to 43, the size of the
            GTSRB label set used by this script.
    """

    def __init__(self, num_classes: int = 43):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-softmaxed) logits of shape (batch, num_classes).

        Args:
            x: Image batch, expected shape (batch, 3, 32, 32).
        """
        x = self.network(x)
        # flatten(x, 1) keeps the batch dim and, unlike view(), also works on
        # non-contiguous activations (e.g. channels_last), with the same
        # logical (C, H, W) element order.
        x = torch.flatten(x, 1)
        return self.fc(x)

# -----------------------------
# 2. Load trained model
# -----------------------------
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network, move its parameters to `device` (Module.to is in-place),
# restore the trained weights, and switch to inference mode (disables Dropout).
model = TrafficSignCNN()
model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
model.load_state_dict(
    torch.load("traffic_sign_model.pt", map_location=device)
)
model.eval()

# -----------------------------
# 3. Preprocessing
# -----------------------------
# Per-frame preprocessing applied before inference.
# NOTE(review): this presumably mirrors the training-time transform — confirm,
# since any mismatch silently degrades predictions.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # the network head is sized for 32x32 input
        transforms.ToTensor(),  # PIL image -> float CHW tensor
        # Single mean/std value, broadcast over every channel.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# -----------------------------
# 4. GTSRB Class Labels
# -----------------------------
# Class index -> human-readable GTSRB sign name. A name's position in the tuple
# is its class index (0..42); this is assumed to match the label order used
# when the model was trained.
class_labels = dict(enumerate((
    "Speed limit 20 km/h",
    "Speed limit 30 km/h",
    "Speed limit 50 km/h",
    "Speed limit 60 km/h",
    "Speed limit 70 km/h",
    "Speed limit 80 km/h",
    "End of speed limit 80 km/h",
    "Speed limit 100 km/h",
    "Speed limit 120 km/h",
    "No passing",
    "No passing for vehicles over 3.5 metric tons",
    "Right-of-way at the next intersection",
    "Priority road",
    "Yield",
    "Stop",
    "No vehicles",
    "Vehicles over 3.5 metric tons prohibited",
    "No entry",
    "General caution",
    "Dangerous curve left",
    "Dangerous curve right",
    "Double curve",
    "Bumpy road",
    "Slippery road",
    "Road narrows on the right",
    "Road work",
    "Traffic signals",
    "Pedestrians",
    "Children crossing",
    "Bicycles crossing",
    "Beware of ice/snow",
    "Wild animals crossing",
    "End of all speed and passing limits",
    "Turn right ahead",
    "Turn left ahead",
    "Ahead only",
    "Go straight or right",
    "Go straight or left",
    "Keep right",
    "Keep left",
    "Roundabout mandatory",
    "End of no passing",
    "End of no passing by vehicles over 3.5 metric tons",
)))

# -----------------------------
# 5. Real-time webcam prediction
# -----------------------------
# Open the default camera (device index 0).
cap = cv2.VideoCapture(0)
confidence_threshold = 0.6  # Minimum softmax probability to report a sign

# try/finally guarantees the camera and the window are released even if the
# loop raises or is interrupted (e.g. Ctrl+C / notebook "interrupt kernel");
# otherwise the camera stays locked until the process exits.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # read() fails when the camera is unavailable or the stream ended.
            print("Could not read a frame from the camera; stopping.")
            break

        # Preprocess frame: OpenCV returns BGR, the transform expects an RGB
        # PIL image. unsqueeze(0) adds the batch dimension.
        # NOTE(review): the whole frame is resized to the model input, not a
        # detected sign region, so predictions are only meaningful when a sign
        # fills most of the frame — confirm this is intended.
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img_rgb)
        img_tensor = transform(pil_img).unsqueeze(0).to(device)

        # Predict: top-1 class and its softmax probability.
        with torch.no_grad():
            output = model(img_tensor)
            probs = F.softmax(output, dim=1)
            max_prob, pred = torch.max(probs, 1)

        confidence = max_prob.item()
        if confidence < confidence_threshold:
            label = "No Sign Detected"
        else:
            label = class_labels[pred.item()]

        # Overlay the prediction in red ((0, 0, 255) in OpenCV's BGR order).
        cv2.putText(frame, f"{label} ({confidence:.2f})", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        cv2.imshow("Traffic Sign Recognition", frame)

        # Quit on 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
