# Unsupervised PPE Detection (Colab)

## 1. Setup

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -r requirements.txt

import sys
import os
from pathlib import Path

# --- Project Root Resolution ---
# The notebook is expected to run from the project root, so the current
# working directory is taken as the root of the project.
project_root = Path.cwd()

# Put the root at the front of the import path so that modules located
# there take priority when they are imported.
sys.path.insert(0, str(project_root))

print(f"Project root set to: {project_root}")
print(f"System path updated: {sys.path[0]}")


from config import CONFIG
from data_utils import prepare_dataset
from unsupervised_trainer import UnsupervisedTrainer
from discovery_processor import DiscoveryProcessor
from violation_processor import ViolationProcessor

# --- Re-anchor Config Paths on the Detected Project Root ---
# Whatever root path the config file ships with is replaced by the root
# detected at runtime, and absolute data/output/checkpoint locations are
# derived from it.
CONFIG['project_root_path'] = str(project_root)
root = Path(CONFIG['project_root_path'])
out_base_abs = root / CONFIG['output_dir']

# The checkpoint directory is nested under the absolute output directory.
CONFIG.update(
    data_dir_abs=root / CONFIG['data_dir'],
    output_dir_abs=out_base_abs,
    checkpoint_dir_abs=out_base_abs / CONFIG['checkpoint_dir'],
)

print(f"CONFIG 'data_dir_abs' updated to: {CONFIG['data_dir_abs']}")

## 2. Training

In [None]:
# Values listed here take precedence over the ones loaded from the config.
training_overrides = dict(
    frozen_layers=12,   # example: override a model parameter
    data_fraction=0.5,  # example: use only a subset of the data
)

# A key that is absent from the overrides keeps its configured value.
CONFIG['model']['frozen_layers'] = training_overrides.get(
    'frozen_layers', CONFIG['model']['frozen_layers']
)
CONFIG['training']['data_fraction'] = training_overrides.get(
    'data_fraction', CONFIG['training']['data_fraction']
)

# Collect the image paths and labels for the configured data fraction.
image_paths, labels = prepare_dataset(CONFIG['training']['data_fraction'])

# Create a PyTorch dataset and dataloader
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples into a batch.

    The image tensors are stacked along a new leading dimension, while the
    labels are returned as a plain list in sample order (they are not
    converted or stacked).
    """
    tensors = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]
    return torch.stack(tensors, dim=0), targets

class PpeDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs loaded from image files.

    Args:
        image_paths: Indexable collection of image file paths.
        labels: Indexable collection of labels, aligned with ``image_paths``
            by position.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it with its label.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed once convert() has produced its own copy of the data.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Image preprocessing: each image is resized to 518x518 and converted to a tensor.
from torchvision import transforms
transform = transforms.Compose(
    [
        transforms.Resize((518, 518)),
        transforms.ToTensor(),
    ]
)

# Wrap the prepared paths and labels in a dataset and a shuffling, batched loader.
dataset = PpeDataset(image_paths, labels, transform=transform)
data_loader = DataLoader(
    dataset,
    batch_size=CONFIG['training']['batch_size'],
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Training only runs while this flag is True.
run_training = True
if run_training:
    trainer = UnsupervisedTrainer(CONFIG)
    trainer.train(data_loader)

## 3. Discovery and Mapping

In [None]:
# Load the fine-tuned model (optional)
# NOTE(review): this relative path is resolved against the current working
# directory and is not derived from CONFIG['checkpoint_dir_abs'] — confirm it
# matches where checkpoints are actually written.
model_path = 'output/checkpoints/latest_checkpoint.pt'
discovery_processor = DiscoveryProcessor(CONFIG, model_path=model_path)

# Load a sample image from the validation set for discovery
import cv2
import random

# str() lets the substring test work for plain strings and Path objects alike.
valid_image_paths = [p for p in image_paths if 'valid' in str(p)]
if valid_image_paths:
    sample_image_path = random.choice(valid_image_paths)
    print(f"Using sample image for discovery: {sample_image_path}")
    sample_image = cv2.imread(str(sample_image_path))
    if sample_image is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising, so check before handing it to cvtColor.
        print(f"Could not read image at {sample_image_path}. Skipping mask generation.")
    else:
        sample_image_rgb = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)
        # NOTE(review): the BGR array from cv2.imread is passed here, not the
        # RGB copy — confirm which channel order generate_object_masks expects.
        masks = discovery_processor.generate_object_masks(sample_image)

        # --- Visualize the results ---
        import matplotlib.pyplot as plt

        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(sample_image_rgb)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(sample_image_rgb)
        plt.imshow(masks, cmap='jet', alpha=0.5) # Overlay masks
        plt.title('Image with Discovered Masks')
        plt.axis('off')

        plt.show()
else:
    print("No validation images found. Skipping mask generation.")

# Manual class mapping (example) - This would be done after visualizing masks
class_map = {1: 'person', 2: 'helmet', 3: 'vest'} # Example
CONFIG['discovery']['class_map'] = class_map


## 4. Inference and Violation Detection

In [None]:
violation_processor = ViolationProcessor(CONFIG)

# --- Video Processing ---
# To process a video, provide the path to your video file below.
# The code will loop through each frame, detect objects, and check for violations.

video_path = 'path/to/your/video.mp4' # <--- CHANGE THIS PATH

if not os.path.exists(video_path):
    print(f"Video file not found at {video_path}. Skipping inference.")
else:
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released even if processing a
    # frame raises part-way through the video.
    try:
        if not cap.isOpened():
            # The file exists but could not be opened as a video, so do not
            # report the processing as complete.
            print(f"Could not open video at {video_path}. Skipping inference.")
        else:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    # No further frame could be read.
                    break

                # 1. Discover objects in the frame (this is a placeholder for the actual discovery output)
                # In a real pipeline, you would run your object discovery model on the 'frame'
                # and get a list of discovered objects.
                discovered_objects = [] # Placeholder

                # 2. Apply class map
                labeled_objects = discovery_processor.apply_class_map(discovered_objects)

                # 3. Process violations
                violations = violation_processor.process_violations(labeled_objects, frame)

                # 4. (Optional) Visualize the output
                # You can draw bounding boxes and violation alerts on the frame here.

            print("Video processing complete.")
    finally:
        cap.release()

## 5. Validation

In [None]:
# --- Validation ---
# This cell evaluates the model's performance on the validation set.

def calculate_iou(boxA, boxB):
    """Return the intersection-over-union (IoU) of two boxes.

    Each box is indexed as ``[x1, y1, x2, y2]``. Widths and heights are
    computed with a ``+ 1`` term, i.e. both corner coordinates are counted
    as belonging to the box.

    Returns:
        The intersection area divided by the union area. When the union
        area is not positive (degenerate boxes), ``0.0`` is returned rather
        than raising ``ZeroDivisionError``.
    """
    # Corners of the intersection rectangle.
    inter_x1 = max(boxA[0], boxB[0])
    inter_y1 = max(boxA[1], boxB[1])
    inter_x2 = min(boxA[2], boxB[2])
    inter_y2 = min(boxA[3], boxB[3])

    # Clamp each side at zero so disjoint boxes yield an empty intersection.
    inter_area = max(0, inter_x2 - inter_x1 + 1) * max(0, inter_y2 - inter_y1 + 1)

    # Areas of the two input boxes.
    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    # Union = sum of both areas minus the intersection counted twice.
    union_area = float(area_a + area_b - inter_area)
    if union_area <= 0:
        return 0.0
    return inter_area / union_area

def evaluate(model, data_loader, device):
    """Compute and print the mean best-match IoU over a data loader.

    For every target box, the highest IoU against any predicted box of the
    same sample is recorded (0 when there are no predicted boxes); the mean
    of those values is printed and returned.

    Args:
        model: Object with an ``eval()`` method that, when called with a list
            of images, returns one prediction per image, each indexable by
            ``'boxes'``. This call signature is a placeholder for the real
            model's prediction logic.
        data_loader: Iterable yielding ``(images, targets)`` batches, where
            each target is indexable by ``'boxes'``.
        device: Device each image is moved to via ``.to(device)``.

    Returns:
        The mean IoU, or ``0`` when no target boxes were seen.
    """
    # Imported here so the function does not depend on a notebook-level `np`.
    import numpy as np

    model.eval() # Set the model to evaluation mode
    all_ious = []

    with torch.no_grad():
        for images, targets in data_loader:
            images = [img.to(device) for img in images]

            # Note: This is a placeholder for your model's prediction logic.
            # The output format will depend on your actual model.
            # For this example, we assume the model returns boxes in the same format as the target.
            predictions = model(images) # This line is illustrative

            for i, target in enumerate(targets):
                pred_boxes = predictions[i]['boxes']
                target_boxes = target['boxes']
                for t_box in target_boxes:
                    # Keep the best overlap any prediction achieves for this target box.
                    best_iou = 0
                    for p_box in pred_boxes:
                        iou = calculate_iou(t_box, p_box)
                        if iou > best_iou:
                            best_iou = iou
                    all_ious.append(best_iou)

    mean_iou = np.mean(all_ious) if all_ious else 0
    print(f"Mean IoU on Validation Set: {mean_iou:.4f}")
    return mean_iou

# Create a new DataLoader for the validation set
# Paths and labels are filtered in lockstep so each kept label stays paired
# with its image; str() allows Path objects as well as plain strings.
valid_dataset = PpeDataset(
    [path for path in image_paths if 'valid' in str(path)],
    [label for path, label in zip(image_paths, labels) if 'valid' in str(path)],
    transform=transform,
)
validation_loader = DataLoader(
    valid_dataset,
    batch_size=CONFIG['training']['batch_size'],
    collate_fn=custom_collate_fn,
)

# Run evaluation
# Note: The 'evaluate' function needs a model that produces predictions.
# The current UnsupervisedTrainer doesn't have a prediction method, so this is a template.
print("Validation logic is set up. You would call the 'evaluate' function with your trained model.")
# evaluate(trainer.model, validation_loader, device='cuda')