# Dental X-Ray Quadrant Detection

This notebook demonstrates quadrant detection in dental X-rays using Detectron2.

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from pycocotools.coco import COCO

from src.data.dataset import DentalDataset
from src.models.config import get_quadrant_config
from src.training.training import DentalTrainer, train_quadrant_phase
from src.utils.config import load_paths
from src.utils.visualization import DentalVisualizer, analyze_results

## 1. Load Configuration

In [None]:
# Read the project path configuration once; the cells below reuse `paths`.
paths = load_paths()

# Pull out the image-directory and annotation-file sub-sections, then
# unpack the train/val entries from each.
image_dirs = paths["data"]["images"]
annotation_files = paths["data"]["annotations"]

train_img_dir, val_img_dir = image_dirs["train"], image_dirs["val"]
train_json, val_json = annotation_files["train"], annotation_files["val"]

## 2. Dataset Registration

In [None]:
# Build the dataset wrappers for the training and validation splits.
train_dataset = DentalDataset(data_dir=train_img_dir, json_file=train_json)
val_dataset = DentalDataset(data_dir=val_img_dir, json_file=val_json)

# Register both splits with Detectron2.
# DatasetCatalog.register refuses a name that is already registered, so a
# stale entry is dropped first; this keeps the cell safe to re-run in a
# live kernel.
for split_name, split_dataset in (
    ("quadrant_train", train_dataset),
    ("quadrant_val", val_dataset),
):
    if split_name in DatasetCatalog.list():
        DatasetCatalog.remove(split_name)
    DatasetCatalog.register(split_name, split_dataset.get_dataset_dicts)
    MetadataCatalog.get(split_name).set(
        thing_classes=split_dataset.get_class_names()
    )

# Get metadata for visualization (bare expression so the notebook displays it)
metadata = MetadataCatalog.get("quadrant_train")
metadata

## 3. Data Visualization

In [None]:
# Create visualizer
visualizer = DentalVisualizer()

# Get training data
train_data = train_dataset.get_dataset_dicts()

# Visualize a few samples
for d in train_data[:3]:
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising, so fail here with the offending path instead of letting
    # cvtColor raise an opaque OpenCV assertion.
    img = cv2.imread(d["file_name"])
    if img is None:
        raise FileNotFoundError(f"Could not read image: {d['file_name']}")

    # OpenCV loads images as BGR; convert to RGB before display.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    visualizer.visualize_data_dict(img, d, metadata)

## 4. Model Configuration

In [None]:
# Build the quadrant-detection config
cfg = get_quadrant_config()

# Summarise the solver settings
print("Training Configuration:")
solver_cfg = cfg.SOLVER
print(f"Base LR: {solver_cfg.BASE_LR}")
print(f"Max Iterations: {solver_cfg.MAX_ITER}")
print(f"Batch Size: {solver_cfg.IMS_PER_BATCH}")

# Summarise the model settings
print("\nModel Architecture:")
model_cfg = cfg.MODEL
print(f"Backbone: {model_cfg.BACKBONE.NAME}")
print(f"Number of Classes: {model_cfg.ROI_HEADS.NUM_CLASSES}")

## 5. Training

In [None]:
# Output directory for the quadrant phase, taken from the path config
quadrant_output_dir = paths["models"]["quadrant"]["output"]

# Run the quadrant training phase; the returned value is used later as
# cfg.MODEL.WEIGHTS for inference.
weights_path = train_quadrant_phase(
    data_dir=train_img_dir,
    json_file=train_json,
    output_dir=quadrant_output_dir,
)

## 6. Results Analysis

In [None]:
# Load trained model for inference
cfg.MODEL.WEIGHTS = weights_path
predictor = DefaultPredictor(cfg)

# Get validation data
val_data = val_dataset.get_dataset_dicts()

# Run inference on a few validation samples
for d in val_data[:3]:
    # cv2.imread returns None (no exception) for a missing/unreadable file;
    # fail with the offending path instead of an opaque downstream error.
    img_bgr = cv2.imread(d["file_name"])
    if img_bgr is None:
        raise FileNotFoundError(f"Could not read image: {d['file_name']}")

    # DefaultPredictor takes the image in BGR order (as loaded by OpenCV)
    # and applies any cfg.INPUT.FORMAT conversion itself, so the array must
    # NOT be converted to RGB before inference.
    outputs = predictor(img_bgr)

    # Visualizer, in contrast, expects RGB; convert only for drawing.
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    v = Visualizer(img, metadata=metadata)
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.figure(figsize=(10, 10))
    plt.imshow(v.get_image())
    plt.axis("off")
    plt.show()

    # Analyze and print results
    analyze_results(outputs, metadata)