# Scientifically Sound Unsupervised PPE Detection (Colab)

## 1. Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -r requirements.txt

import sys
import os
from pathlib import Path
import torch
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from scipy.optimize import linear_sum_assignment

# --- Path Setup ---
# Put the current working directory at the front of sys.path so the project
# modules imported below resolve from it. Any earlier copy of the entry is
# dropped first, so re-running this cell does not accumulate duplicates.
project_dir = str(Path.cwd())
sys.path[:] = [entry for entry in sys.path if entry != project_dir]
sys.path.insert(0, project_dir)

from config import CONFIG
from data_utils import prepare_dataset
from unsupervised_trainer import UnsupervisedTrainer
from discovery_processor import DiscoveryProcessor
from violation_processor import ViolationProcessor

# Report the project root recorded in CONFIG.
# NOTE(review): sys.path above is extended with Path.cwd(), not with this
# value — confirm both point at the same directory.
print(f"Project root set to: {CONFIG['project_root_path']}")

## 2. Data Preparation and Training

In [None]:
# Prepare the dataset: image paths and their labels, sized by the fraction
# configured under CONFIG['training']['data_fraction'].
data_fraction = CONFIG['training']['data_fraction']
image_paths, labels = prepare_dataset(data_fraction)

# Custom Dataset and DINO Augmentations
class DataAugmentationDINO(object):
    # ... (omitted for brevity - same as before) ...

class PpeDataset(Dataset):
    # ... (omitted for brevity - same as before) ...

# Wrap the prepared paths/labels in a dataset with the DINO augmentations
# and feed it through a shuffling loader.
transform = DataAugmentationDINO()
dataset = PpeDataset(image_paths, labels, transform=transform)
data_loader = DataLoader(
    dataset,
    batch_size=CONFIG['training']['batch_size'],
    shuffle=True,
)

# --- Training ---
# Flip to True to run training in this session; when False the training step
# is skipped and a pretrained model is used instead.
run_training = False
if not run_training:
    print("Skipping training. A pretrained DINOv2 model will be used for discovery.")
else:
    trainer = UnsupervisedTrainer(CONFIG)
    trainer.train(data_loader)

## 3. Unsupervised Object Discovery

In [None]:
# Load the model for discovery.
# The fine-tuned checkpoint is passed only when training ran in this session
# and the checkpoint file exists; otherwise model_path=None is passed, which
# (per the message in the training cell) selects the pretrained DINOv2 model.
model_path = CONFIG['checkpoint_dir_abs'] / 'latest_checkpoint.pt'
use_finetuned = run_training and model_path.exists()
discovery_processor = DiscoveryProcessor(CONFIG, model_path=model_path if use_finetuned else None)

# Defined up front so later cells can test for a missing sample image instead
# of hitting a NameError when no validation image is found.
sample_image_rgb = None

# Select a sample image from the validation set
# NOTE(review): this is a substring match on the whole path, so any path that
# merely contains "valid" is selected — confirm against the dataset layout.
valid_image_paths = [p for p in image_paths if 'valid' in str(p)]
if valid_image_paths:
    sample_image_path = random.choice(valid_image_paths)
    # str() so pathlib.Path entries are accepted as well as plain strings.
    sample_image = cv2.imread(str(sample_image_path))
    # cv2.imread signals failure by returning None rather than raising.
    if sample_image is None:
        raise FileNotFoundError(f"Could not read image: {sample_image_path}")
    sample_image_rgb = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)

    # Discover objects using the new method
    discovered_objects, masks = discovery_processor.discover_objects(sample_image_rgb, n_clusters=4)

    # Visualize the results: original image, segment map, and boxes side by side.
    fig, (ax_image, ax_masks, ax_boxes) = plt.subplots(1, 3, figsize=(18, 6))

    ax_image.imshow(sample_image_rgb)
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    ax_masks.imshow(masks, cmap='viridis')
    ax_masks.set_title('Discovered Segments')
    ax_masks.axis('off')

    ax_boxes.imshow(sample_image_rgb)
    for obj in discovered_objects:
        # Boxes are unpacked as (x1, y1, x2, y2) corner coordinates.
        x1, y1, x2, y2 = obj['box']
        ax_boxes.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, edgecolor='red', facecolor='none', lw=2))
        ax_boxes.text(x1, y1 - 5, f"Cluster {obj['cluster_id']}", color='white', backgroundcolor='red')
    ax_boxes.set_title('Discovered Bounding Boxes')
    ax_boxes.axis('off')
    plt.show()
else:
    print("No validation images found.")

## 4. Validation: Mapping Clusters to Classes and Evaluating

In [None]:
def calculate_iou(boxA, boxB):
    # ... (omitted for brevity - same as before) ...

def map_clusters_to_classes(discovery_processor, image_paths, labels):
    # ... (Implementation in next cell)
    pass

def evaluate_discovery(discovery_processor, cluster_class_map, image_paths, labels):
    # ... (Implementation in next cell)
    pass

# --- Run Validation ---
print("Mapping unsupervised clusters to semantic classes...")
valid_indices = [i for i, p in enumerate(image_paths) if 'valid' in str(p)]
valid_images = [image_paths[i] for i in valid_indices]
valid_labels = [labels[i] for i in valid_indices]

# This is a compute-intensive step, so we'll run on a subset
subset_size = 50
cluster_class_map = map_clusters_to_classes(discovery_processor, valid_images[:subset_size], valid_labels[:subset_size])
print(f"Cluster to Class Map: {cluster_class_map}")

print("\nEvaluating discovery performance...")
mean_iou = evaluate_discovery(discovery_processor, cluster_class_map, valid_images[:subset_size], valid_labels[:subset_size])
print(f"\nMean IoU on Validation Set: {mean_iou:.4f}")

In [None]:
# --- Implementation of Validation Functions ---

def map_clusters_to_classes(discovery_processor, image_paths, labels, n_clusters=4):
    num_classes = len(CONFIG['discovery']['class_map'])
    cost_matrix = np.zeros((n_clusters, num_classes))

    for img_path, label_list in zip(image_paths, labels):
        img = cv2.imread(img_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        discovered_objects, _ = discovery_processor.discover_objects(img_rgb, n_clusters=n_clusters)
        
        for gt_obj in label_list:
            gt_box = gt_obj['box']
            gt_class_id = gt_obj['class_id']
            for pred_obj in discovered_objects:
                pred_box = pred_obj['box']
                iou = calculate_iou(gt_box, pred_box)
                cost_matrix[pred_obj['cluster_id'], gt_class_id] -= iou # Negative because we want to maximize IoU

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    return {r: c for r, c in zip(row_ind, col_ind)}

def evaluate_discovery(discovery_processor, cluster_class_map, image_paths, labels, n_clusters=4):
    total_iou = 0
    gt_box_count = 0

    for img_path, label_list in zip(image_paths, labels):
        img = cv2.imread(img_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        discovered_objects, _ = discovery_processor.discover_objects(img_rgb, n_clusters=n_clusters)
        
        # Map predicted cluster IDs to class IDs
        for obj in discovered_objects:
            obj['class_id'] = cluster_class_map.get(obj['cluster_id'], -1)

        for gt_obj in label_list:
            gt_box = gt_obj['box']
            gt_class_id = gt_obj['class_id']
            best_iou = 0
            for pred_obj in discovered_objects:
                if pred_obj['class_id'] == gt_class_id:
                    iou = calculate_iou(gt_box, pred_obj['box'])
                    if iou > best_iou:
                        best_iou = iou
            total_iou += best_iou
            gt_box_count += 1

    return total_iou / gt_box_count if gt_box_count > 0 else 0

## 5. End-to-End Inference and Violation Detection

In [None]:
violation_processor = ViolationProcessor(CONFIG)

# The cluster-to-class map is produced by the validation section; fail early
# with a clear message if that section has not produced a mapping yet.
if not isinstance(cluster_class_map, dict):
    raise RuntimeError("cluster_class_map is not set; run the validation section (Section 4) first.")


def detect_violations(frame_rgb, n_clusters=4):
    """Run discovery on one RGB frame and check the result for violations.

    Discovered objects get a 'class_id' looked up from cluster_class_map
    (-1 when the cluster has no mapping) before being handed to the
    violation processor.

    Returns:
        (discovered_objects, violations) for the frame.
    """
    # 1. Discover objects
    objects = discovery_processor.discover_objects(frame_rgb, n_clusters=n_clusters)[0]

    # 2. Map cluster IDs to class IDs
    for obj in objects:
        obj['class_id'] = cluster_class_map.get(obj['cluster_id'], -1)

    # 3. Process for violations
    return objects, violation_processor.process_violations(objects, frame_rgb)


# --- Video Processing ---
video_path = 'path/to/your/video.mp4' # <--- CHANGE THIS PATH

if not os.path.exists(video_path):
    print(f"Video file not found at {video_path}. Using single image for inference demo.")
    # Use the sample image from the discovery section for a single-frame demo.
    # Looked up via globals() because the name may be missing when no
    # validation image was available in that section.
    demo_image_rgb = globals().get('sample_image_rgb')
    if demo_image_rgb is None:
        print("No sample image available; run the discovery section with at least one validation image.")
    else:
        discovered_objects, violations = detect_violations(demo_image_rgb)
        print(f"Violations found in sample image: {violations}")

else:
    cap = cv2.VideoCapture(video_path)
    # One entry per processed frame, so the results remain available after the loop.
    frame_violations = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            discovered_objects, violations = detect_violations(frame_rgb)
            frame_violations.append(violations)

            # 4. (Optional) Visualize the output
            # ... (visualization logic can be added here) ...
    finally:
        # Release the capture even if processing a frame raised.
        cap.release()
    print("Video processing complete.")
    print(f"Frames processed: {len(frame_violations)}")