# 05 — Real-Time Screw Detection (Webcam / Video Stream)
---
##  Notebook Overview 

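In this notebook, we load our trained custom ResNet50 model and use it to classify screw defects in real time, frame by frame, on a video source (a webcam, a video file, or an IP stream). The predicted class is drawn on each processed frame.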


---
# Setup & Imports

In [36]:

import os
import time
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms


# Configuration


In [46]:
# Compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Checkpoint holding the trained weights -- replace with your own model path.
weights_path = "best_model.pth"

# Number of output classes of the classifier; adjust if you change the dataset.
num_classes = 6

# Video source. Set video_path to one of:
#   * 0                      -> default webcam
#   * a path to a video file
#   * an IP stream URL
# Examples:
#   video_path = 0
#   video_path = r"E:\Final project of Ats\videos\screws_belt.mp4"
#   video_path = "http://192.168.0.101:4747/video"
video_path = r"C:\Users\A.C\Pictures\Test video for project .mp4"






# Load Trained Model (ResNet50)



In [38]:
# Rebuild the ResNet50 architecture with randomly initialised weights; the
# trained parameters come from the checkpoint loaded below. Calling resnet50()
# without arguments avoids the `pretrained=` keyword, which torchvision has
# deprecated in favour of `weights=` (and warns about on recent releases),
# while still working on older releases.
model = models.resnet50()

# Swap the final fully connected layer for one with `num_classes` outputs so
# the architecture matches the keys/shapes stored in the checkpoint.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# NOTE(security): torch.load unpickles the file -- only load checkpoints you
# trust. map_location lets a checkpoint saved on GPU be loaded on a CPU-only
# machine.
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)

model = model.to(device)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()

print(" Model loaded and ready!")




 Model loaded and ready!


# Preprocessing (same as training)



In [39]:
# Preprocessing pipeline applied to every frame before it is fed to the model.
transform = transforms.Compose(
    [
        # PIL image -> float tensor (channels first) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5, i.e. rescale [0, 1] to [-1, 1].
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Class Labels 

In [40]:
# Class labels: the entry at position i is the name displayed when the model
# predicts class index i.
# NOTE(review): presumably this order must match the class order used during
# training -- confirm against the training notebook.
class_names = [
    "good",
    "manipulated_front",
    "scratch_head",
    "scratch_neck",
    "thread_side",
    "thread_top",
]


# Run Real-Time Stream


In [48]:
# Video Capture
#
# Reads frames from `video_path`, classifies every `frame_skip`-th frame with
# `model`, draws the predicted class name on that frame and shows it in an
# OpenCV window. Press 'q' in the window to stop early.

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail loudly: without this check a wrong path / unavailable camera makes
    # the loop below a silent no-op.
    cap.release()
    raise RuntimeError(f"Could not open video source: {video_path!r}")

frame_skip = 3   # process 1 of every 3 frames (speed boost)
frame_id = 0

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of file, or the stream stopped delivering frames.
            break

        frame_id += 1
        if frame_id % frame_skip != 0:
            # Skipped frames are neither classified nor displayed.
            continue

        # Preprocess frame: resize to 224x224, convert OpenCV's BGR to RGB,
        # then apply the tensor transform and add a batch dimension.
        img = cv2.resize(frame, (224, 224))
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        img_tensor = transform(img_pil).unsqueeze(0).to(device)

        # Prediction: index of the highest-scoring class for the single
        # image in the batch. no_grad() skips autograd bookkeeping.
        with torch.no_grad():
            outputs = model(img_tensor)
            pred_idx = outputs.argmax(dim=1).item()
        label = class_names[pred_idx]

        # Draw label on the original (full-resolution) frame.
        cv2.putText(frame, f"Pred: {label}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        cv2.imshow("Defect Detection", frame)

        # waitKey also services the GUI event loop; 'q' quits.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the capture and close the window, even when the loop is
    # interrupted (e.g. kernel interrupt) or a frame raises an exception.
    cap.release()
    cv2.destroyAllWindows()