In [1]:
!pip install --upgrade torch torchvision opencv-python matplotlib
!apt-get install ffmpeg -y > /dev/null 2>&1


Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting opencv-python
  Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Download

In [2]:
# --- Run configuration -------------------------------------------------------
# Trained Faster R-CNN weights, consumed by load_model().
model_path = "/content/faster_rcnn_checkpoint_epoch_86.pth"

# Inputs: a folder of still images and a video clip.
image_input_path = "/content/test_images"
mp4_video_path = "/content/video1.mp4"
# Optional second clip; main() also routes ".webm" files to the video pipeline.
# webm_video_path = "/content/video2.webm"

# Detections must score strictly above this value to be kept. It is much lower
# than the 0.2 default of main()/inference(); the recorded video run below
# reports roughly 80 detections per frame at this setting.
threshold = 0.05


In [13]:
# Dependencies (torch, torchvision, opencv-python, matplotlib, ffmpeg) are
# installed by the first cell. Re-running `pip install --upgrade` here on every
# execution of this definitions cell is redundant, and it could replace torch
# on disk while an older version is already imported in the kernel.

import torch
import torchvision.transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  # Fixed import issue
import cv2
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# ✅ Define class names (Modify if needed)
# Label id -> display name. A name's position in this list is its label id,
# so keep the order in sync with the labels used for training.
# NOTE(review): torchvision's Faster R-CNN counts background in num_classes and
# typically treats label 0 as background, so id 0 may never be predicted —
# confirm how labels were numbered during training.
# NOTE(review): "Booken piece" looks like a typo for "Broken piece"; kept as-is
# because this exact string is drawn on the outputs.
CLASS_NAMES = dict(enumerate([
    "Booken piece", "Conglomerate", "XL", "Empty", "Fissile shale",
    "LA", "Non-core", "Rip-up", "SSD", "Wavy bedding",
    "BMM", "BSM", "Concretion", "Current ripple",
    "Intraclast", "MM", "MS", "MD", "PL",
]))  # Update if needed

# ✅ Load Faster R-CNN model
def load_model(model_path, device, num_classes=19):
    """Build a ResNet-50 FPN Faster R-CNN and load trained weights into it.

    Args:
        model_path: File written with ``torch.save``. It may hold either a bare
            state dict or a checkpoint dict with a ``"model_state_dict"`` entry.
        device: ``torch.device`` to map the weights onto and move the model to.
        num_classes: Size of the detection head; should match the checkpoint.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    model = fasterrcnn_resnet50_fpn(weights=None, num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=device)

    # Training checkpoints often wrap the weights alongside other state.
    if "model_state_dict" in state_dict:
        print("⚠️ Warning: Model file contains extra keys! Loading only 'model_state_dict'.")
        state_dict = state_dict["model_state_dict"]

    # strict=False tolerates key mismatches, but ignoring them silently can leave
    # whole layers at their initial (untrained) values. Report any mismatch so a
    # wrong checkpoint is noticed instead of producing meaningless detections.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"⚠️ Warning: {len(incompatible.missing_keys)} model keys were not found in the "
              f"checkpoint and keep their initial values, e.g. {incompatible.missing_keys[:3]}")
    if incompatible.unexpected_keys:
        print(f"⚠️ Warning: {len(incompatible.unexpected_keys)} checkpoint keys were ignored, "
              f"e.g. {incompatible.unexpected_keys[:3]}")

    model.to(device)
    model.eval()
    return model

# ✅ Preprocess image for model inference
def preprocess_image(image):
    """Turn one PIL image into a batch of size one for the detector.

    ``ToTensor`` produces a ``(C, H, W)`` float tensor; ``unsqueeze(0)`` prepends
    the batch dimension the model expects.
    """
    to_tensor = T.ToTensor()
    single = to_tensor(image)
    return single.unsqueeze(0)

# ✅ Inference function (Detect objects)
def inference(model, image_tensor, device, threshold=0.2):
    """Run the detector on one batched image and keep the confident detections.

    Args:
        model: Callable returning a list of dicts with ``boxes``, ``scores`` and
            ``labels`` tensors; only the first (single-image) entry is used.
        image_tensor: Input batch, moved to ``device`` before the forward pass.
        device: Device the model lives on.
        threshold: Detections whose score is strictly greater than this are kept.

    Returns:
        ``(boxes, scores, labels)`` as three parallel Python lists holding NumPy
        rows/scalars, in the order the model produced them.
    """
    with torch.no_grad():
        first = model(image_tensor.to(device))[0]

    boxes, scores, labels = (first[key].cpu().numpy() for key in ("boxes", "scores", "labels"))

    # One boolean mask selects the same detections across all three arrays.
    keep = scores > threshold
    return list(boxes[keep]), list(scores[keep]), list(labels[keep])

# ✅ Visualize results (For images)
def visualize_results(image, boxes, scores, labels, save_path=None):
    """Draw detections on an image with matplotlib and optionally save the figure.

    The figure is always closed before returning (even if drawing or saving
    fails), so nothing is rendered inline; pass ``save_path`` to keep the result.

    Args:
        image: Anything ``ax.imshow`` accepts (callers pass a PIL RGB image).
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        scores: Confidence score per box.
        labels: Class id per box, looked up in ``CLASS_NAMES``.
        save_path: Optional output file; its parent directory is created if needed.
    """
    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(image)

        for box, score, label in zip(boxes, scores, labels):
            x_min, y_min, x_max, y_max = box
            rect = patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                     linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            class_name = CLASS_NAMES.get(label, "Unknown")
            ax.text(x_min, y_min - 5, f"{class_name}, Score: {score:.2f}",
                    color='red', fontsize=10, bbox=dict(facecolor='white', alpha=0.5))

        if save_path:
            save_dir = os.path.dirname(save_path)
            # A bare filename has no directory part, and os.makedirs("") raises.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight', dpi=300)
            print(f"✅ Image saved: {save_path}")
    finally:
        # Called once per image in a loop, so always release the figure.
        plt.close(fig)

# ✅ Process images
def process_images(folder_path, model, device, threshold,
                   output_folder="output_images",
                   extensions=(".jpg", ".png", ".jpeg")):
    """Detect objects in images and save an annotated copy of each one.

    Args:
        folder_path: Directory of images, or the path of a single image file.
        model: Detection model, forwarded to ``inference``.
        device: Device the model lives on.
        threshold: Score cut-off forwarded to ``inference``.
        output_folder: Directory for the results, named ``detected_<original name>``.
        extensions: Lower-case file suffixes treated as images.
    """
    os.makedirs(output_folder, exist_ok=True)

    # main() sends every non-video path here, so accept a lone image as well as
    # a folder (os.listdir on a file would raise NotADirectoryError).
    if os.path.isfile(folder_path):
        image_dir = os.path.dirname(folder_path)
        image_names = [os.path.basename(folder_path)]
    else:
        image_dir = folder_path
        # os.listdir order is arbitrary; sort for a reproducible processing order.
        image_names = sorted(os.listdir(folder_path))

    suffixes = tuple(extensions)
    for image_name in image_names:
        if not image_name.lower().endswith(suffixes):
            continue
        image_path = os.path.join(image_dir, image_name)
        if not os.path.isfile(image_path):
            continue  # e.g. a sub-directory whose name ends in ".jpg"

        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

        boxes, scores, labels = inference(model, preprocess_image(image), device, threshold)

        output_image_path = os.path.join(output_folder, f"detected_{image_name}")
        visualize_results(image, boxes, scores, labels, save_path=output_image_path)

# ✅ Process video
def _detected_video_path(video_path, output_folder):
    """Return ``<output_folder>/<stem>_detected<ext>`` for ``video_path``.

    Splitting on the real extension also covers upper-case suffixes and the
    .avi/.mov inputs that main() accepts; str.replace on ".mp4"/".webm" left
    those without the ``_detected`` marker.
    """
    stem, ext = os.path.splitext(os.path.basename(video_path))
    return os.path.join(output_folder, f"{stem}_detected{ext}")


def _draw_detections(frame, boxes, scores, labels):
    """Draw green boxes and ``<class>, <score>`` captions onto a BGR frame in place."""
    for box, score, label in zip(boxes, scores, labels):
        x_min, y_min, x_max, y_max = [int(v) for v in box]
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        class_name = CLASS_NAMES.get(label, "Unknown")
        cv2.putText(frame, f"{class_name}, {score:.2f}", (x_min, y_min - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


def process_video(video_path, model, device, threshold, default_fps=30.0):
    """Run detection on every frame of a video and write an annotated copy.

    The output goes to ``output_videos/<stem>_detected<ext>``. It is encoded
    with VP80 for .webm inputs and mp4v otherwise.

    Args:
        video_path: Input video file.
        model: Detection model, forwarded to ``inference``.
        device: Device the model lives on.
        threshold: Score cut-off forwarded to ``inference``.
        default_fps: Frame rate used when the input does not report a usable one.
    """
    cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)  # Force FFmpeg for better compatibility
    if not cap.isOpened():
        print("❌ Error: Cannot open video. Try converting it to .mp4 using ffmpeg.")
        return

    output_folder = "output_videos"
    os.makedirs(output_folder, exist_ok=True)
    output_video_path = _detected_video_path(video_path, output_folder)

    if video_path.lower().endswith(".webm"):
        fourcc = cv2.VideoWriter_fourcc(*"VP80")
    else:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Keep fractional rates (e.g. 29.97); int() would change playback speed.
    # Fall back when the container reports 0 or NaN for an unknown rate.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = default_fps

    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out.isOpened():
        # Without this check a missing codec silently writes nothing.
        cap.release()
        print(f"❌ Error: Cannot create output video {output_video_path} (codec may be unavailable).")
        return

    frame_count = 0
    total_detections = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            boxes, scores, labels = inference(model, preprocess_image(image), device, threshold)

            frame_detections = len(boxes)
            total_detections += frame_detections
            print(f"🔎 Frame {frame_count}: {frame_detections} objects detected")

            _draw_detections(frame, boxes, scores, labels)
            out.write(frame)
            frame_count += 1
    finally:
        # Release both handles even if inference fails part-way through, so the
        # input is not left open and the output file is closed.
        cap.release()
        out.release()

    if total_detections > 0:
        print(f"✅ Processed video saved: {output_video_path} with {total_detections} detections.")
    else:
        print("⚠️ No objects detected. Try lowering the confidence threshold.")

# ✅ Main function
def main(input_path, model_path, threshold=0.2):
    """Load the detector, then run it on a video file or on image input.

    Args:
        input_path: A path ending in .mp4/.avi/.mov/.webm (case-insensitive) is
            treated as a video; anything else is handed to process_images().
        model_path: Checkpoint passed to load_model().
        threshold: Score cut-off forwarded to the chosen processing function.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"🚀 Using device: {device}")

    model = load_model(model_path, device)

    video_suffixes = (".mp4", ".avi", ".mov", ".webm")
    is_video = input_path.lower().endswith(video_suffixes)
    runner = process_video if is_video else process_images
    runner(input_path, model, device, threshold)




In [14]:
# Annotate every image in image_input_path; results are written to output_images/.
main(input_path=image_input_path, model_path=model_path, threshold=threshold)

🚀 Using device: cuda
✅ Image saved: output_images/detected_image7.jpg
✅ Image saved: output_images/detected_image8.jpg
✅ Image saved: output_images/detected_image10.jpg
✅ Image saved: output_images/detected_image6.jpg
✅ Image saved: output_images/detected_image9.jpg
✅ Image saved: output_images/detected_image2.jpg
✅ Image saved: output_images/detected_image5.jpg
✅ Image saved: output_images/detected_image3.jpg
✅ Image saved: output_images/detected_image4.jpg
✅ Image saved: output_images/detected_image1.jpg


In [15]:
# Annotate the MP4 clip frame by frame; the result is written to output_videos/.
main(input_path=mp4_video_path, model_path=model_path, threshold=threshold)

🚀 Using device: cuda
🔎 Frame 0: 80 objects detected
🔎 Frame 1: 78 objects detected
🔎 Frame 2: 78 objects detected
🔎 Frame 3: 78 objects detected
🔎 Frame 4: 82 objects detected
🔎 Frame 5: 82 objects detected
🔎 Frame 6: 80 objects detected
🔎 Frame 7: 83 objects detected
🔎 Frame 8: 82 objects detected
🔎 Frame 9: 80 objects detected
🔎 Frame 10: 80 objects detected
🔎 Frame 11: 79 objects detected
🔎 Frame 12: 81 objects detected
🔎 Frame 13: 81 objects detected
🔎 Frame 14: 80 objects detected
🔎 Frame 15: 81 objects detected
🔎 Frame 16: 80 objects detected
🔎 Frame 17: 81 objects detected
🔎 Frame 18: 80 objects detected
🔎 Frame 19: 80 objects detected
🔎 Frame 20: 80 objects detected
🔎 Frame 21: 80 objects detected
🔎 Frame 22: 80 objects detected
🔎 Frame 23: 80 objects detected
🔎 Frame 24: 80 objects detected
🔎 Frame 25: 80 objects detected
🔎 Frame 26: 80 objects detected
🔎 Frame 27: 80 objects detected
🔎 Frame 28: 80 objects detected
🔎 Frame 29: 79 objects detected
🔎 Frame 30: 80 objects detect