<a href="https://colab.research.google.com/github/Syed-MuhammadTaha/dental-assistant/blob/main/dental_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Dental Radiograph Analysis Pipeline
# End-to-End Implementation

import os
import json
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_l
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Seed the global PyTorch and NumPy random number generators so that weight
# initialisation and NumPy-based sampling repeat from run to run.
# NOTE(review): Python's built-in `random` module and cuDNN determinism
# settings are not configured in this cell -- confirm whether that is needed.
torch.manual_seed(42)
np.random.seed(42)

In [None]:
#############################################
# 1. DATASET CLASSES
#############################################

class SegmentationDataset(Dataset):
    """Panoramic radiographs paired with per-pixel tooth-number masks.

    Every ``.png`` / ``.jpg`` file in ``img_dir`` that has a same-stem
    ``.json`` file in ``json_dir`` becomes one sample. The JSON is expected
    to hold a ``shapes`` list; each polygon shape whose ``label`` parses as an
    integer is rasterised into the mask using that integer as the pixel
    value (0 is background). Later shapes overwrite earlier ones where they
    overlap.

    Args:
        img_dir: Directory containing the radiograph images.
        json_dir: Directory containing the annotation files.
        transform: Optional callable applied to the image. It receives a PIL
            image, which is what the inference code in this file passes to the
            same transform and what torchvision's ``Resize`` expects.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg')

    def __init__(self, img_dir, json_dir, transform=None):
        self.img_dir = img_dir
        self.json_dir = json_dir
        self.transform = transform

        # Sorted so that index-based train/val splits are reproducible
        # regardless of the order os.listdir happens to return.
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.endswith(self.IMAGE_EXTENSIONS)
            and os.path.exists(self._annotation_path(f))
        )

    def _annotation_path(self, img_name):
        """Return the annotation path for ``img_name``.

        ``os.path.splitext`` strips only the final extension, so stems that
        contain dots (e.g. ``scan.v2.png``) map to ``scan.v2.json``.
        """
        return os.path.join(self.json_dir, os.path.splitext(img_name)[0] + '.json')

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``image`` (transformed image, or a float CHW tensor
            scaled to [0, 1] when no transform is set), ``mask`` (long tensor
            of tooth numbers), ``tooth_info`` (list of ``tooth_id`` /
            ``points`` dicts) and ``img_path``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure with None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = np.zeros(image.shape[:2], dtype=np.int32)
        tooth_info = []

        with open(self._annotation_path(img_name), 'r') as f:
            ann_data = json.load(f)

        for shape in ann_data['shapes']:
            label = shape['label']

            # Only polygons with at least three vertices can be filled.
            if shape['shape_type'] != 'polygon' or len(shape['points']) < 3:
                continue

            try:
                tooth_id = int(label)
            except (ValueError, TypeError):
                print(f"Skipping invalid tooth label: {label}")
                continue

            points = np.array(shape['points'], dtype=np.int32)
            # Paint the polygon directly with the tooth number.
            cv2.fillPoly(mask, [points], tooth_id)
            tooth_info.append({'tooth_id': tooth_id, 'points': points})

        mask_tensor = torch.from_numpy(mask).long()

        if self.transform:
            image_out = self.transform(Image.fromarray(image))
        else:
            image_out = torch.from_numpy(image.transpose((2, 0, 1))).float() / 255.0

        # If the transform changed the spatial size (e.g. Resize), bring the
        # mask to the same size. Nearest-neighbour keeps labels as integers.
        if torch.is_tensor(image_out) and image_out.dim() >= 2:
            target_size = tuple(image_out.shape[-2:])
            if target_size != tuple(mask_tensor.shape):
                mask_tensor = torch.nn.functional.interpolate(
                    mask_tensor[None, None].float(), size=target_size, mode='nearest'
                )[0, 0].long()

        return {
            'image': image_out,
            'mask': mask_tensor,
            'tooth_info': tooth_info,
            'img_path': img_path
        }

In [None]:
class ClassificationDataset(Dataset):
    """Cropped single-tooth images labelled with one of seven conditions.

    Every ``.png`` / ``.jpg`` file in ``img_dir`` that has a same-stem
    ``.json`` file in ``json_dir`` becomes one sample.

    Args:
        img_dir: Directory containing the tooth crops.
        json_dir: Directory containing the per-crop annotation files.
        transform: Optional callable applied to the PIL image.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg')
    MAX_CONDITION_ID = 6

    def __init__(self, img_dir, json_dir, transform=None):
        self.img_dir = img_dir
        self.json_dir = json_dir
        self.transform = transform

        # Sorted so that index-based train/val splits are reproducible.
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.endswith(self.IMAGE_EXTENSIONS)
            and os.path.exists(self._annotation_path(f))
        )

        # Condition ID -> human readable name.
        self.condition_classes = {
            0: 'Tooth without anomalies',
            1: 'Tooth with fillings',
            2: 'Tooth with RCT',
            3: 'Tooth with crown',
            4: 'Tooth with caries',
            5: 'Residual root',
            6: 'Tooth with RCT and crown'
        }

        # Reverse lookup: condition name -> ID.
        self.condition_name_to_id = {v: k for k, v in self.condition_classes.items()}

    def _annotation_path(self, img_name):
        """Return the annotation path for ``img_name``.

        Uses ``os.path.splitext`` so only the final extension is removed and
        stems containing dots are preserved.
        """
        return os.path.join(self.json_dir, os.path.splitext(img_name)[0] + '.json')

    @staticmethod
    def _to_int(value, default=0):
        """Convert ``value`` to int, returning ``default`` when it cannot be parsed."""
        try:
            return int(value)
        except (ValueError, TypeError):
            return default

    @staticmethod
    def _resolve_condition(ann_data, img_name):
        """Work out the condition ID for one sample.

        Sources are tried in this order and the first one that is present wins:
        top-level ``group_id`` (``None`` means "no anomaly"), top-level
        ``condition``, top-level ``attributes['condition']``, the first
        shape's ``attributes['condition']``, and finally a trailing
        ``_<digits>`` suffix in the file name. The result is clamped to the
        valid ID range.
        """
        to_int = ClassificationDataset._to_int
        shapes = ann_data.get('shapes') or []
        label = 0  # Default: tooth without anomalies

        if 'group_id' in ann_data:
            if ann_data['group_id'] is not None:
                label = to_int(ann_data['group_id'])
        elif 'condition' in ann_data:
            label = to_int(ann_data['condition'])
        elif 'attributes' in ann_data and 'condition' in ann_data['attributes']:
            label = to_int(ann_data['attributes']['condition'])
        elif len(shapes) > 0 and 'attributes' in shapes[0]:
            if 'condition' in shapes[0]['attributes']:
                label = to_int(shapes[0]['attributes']['condition'])
        else:
            # File names of the form "..._<conditionID>.<ext>".
            parts = os.path.splitext(img_name)[0].split('_')
            if len(parts) > 1 and parts[-1].isdigit():
                label = int(parts[-1])

        return max(0, min(label, ClassificationDataset.MAX_CONDITION_ID))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``image`` (transformed image, or the PIL image when
            no transform is set), ``condition`` (int ID), ``condition_name``
            and ``img_path``.
        """
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        with open(self._annotation_path(img_name), 'r') as f:
            ann_data = json.load(f)

        condition_label = self._resolve_condition(ann_data, img_name)

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'condition': condition_label,
            'condition_name': self.condition_classes[condition_label],
            'img_path': img_path
        }

In [None]:
#############################################
# 2. MODEL DEFINITIONS
#############################################

class YOLOv9SegmentationModel(nn.Module):
    """U-Net style encoder/decoder producing per-pixel class logits.

    Despite the name, no YOLOv9 weights or layers are used here; the network
    is four double-convolution encoder stages, three 2x down-samplings, and a
    mirrored decoder with skip connections.

    Args:
        num_classes: Number of output channels. The default of 49 covers
            class indices 0..48, i.e. background plus FDI numbers up to 48
            when the tooth number itself is used as the class index.
            NOTE(review): a supernumerary label of 91 would not fit in 49
            channels -- confirm how such labels are handled upstream.
    """

    def __init__(self, num_classes=49):
        super(YOLOv9SegmentationModel, self).__init__()

        # Contracting path: channel width doubles at every stage.
        self.enc1 = self._make_layer(3, 64)
        self.enc2 = self._make_layer(64, 128)
        self.enc3 = self._make_layer(128, 256)
        self.enc4 = self._make_layer(256, 512)

        # Expanding path: the input width of dec3/dec2/dec1 is the sum of the
        # up-sampled feature map and the skip connection it is joined with.
        self.dec4 = self._make_layer(512, 256)
        self.dec3 = self._make_layer(512, 128)
        self.dec2 = self._make_layer(256, 64)
        self.dec1 = self._make_layer(128, 32)

        # 1x1 convolution mapping features to class logits.
        self.final = nn.Conv2d(32, num_classes, kernel_size=1)

    def _make_layer(self, in_channels, out_channels):
        """Build two consecutive Conv3x3 -> BatchNorm -> ReLU units."""
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return logits with the same spatial size as ``x``.

        The skip connections only line up when the input height and width
        survive three halvings and doublings unchanged.
        """
        functional = nn.functional

        def halve(features):
            return functional.max_pool2d(features, 2)

        def double(features):
            return functional.interpolate(
                features, scale_factor=2.0, mode='bilinear', align_corners=True
            )

        # Encoder: keep every stage output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(halve(skip1))
        skip3 = self.enc3(halve(skip2))
        bottleneck = self.enc4(halve(skip3))

        # Decoder: up-sample, concatenate the matching skip, convolve.
        up = self.dec4(bottleneck)
        up = self.dec3(torch.cat([double(up), skip3], dim=1))
        up = self.dec2(torch.cat([double(up), skip2], dim=1))
        up = self.dec1(torch.cat([double(up), skip1], dim=1))

        return self.final(up)

In [None]:
class EfficientNetClassificationModel(nn.Module):
    """EfficientNetV2-L backbone with a replacement classification head.

    Args:
        num_classes: Number of output logits (7 dental conditions, IDs 0-6).
    """

    def __init__(self, num_classes=7):
        super(EfficientNetClassificationModel, self).__init__()

        # Load ImageNet weights through the `weights` argument. The boolean
        # `pretrained=True` flag is deprecated in the torchvision releases that
        # ship EfficientNetV2 and resolves to these same IMAGENET1K_V1 weights.
        self.efficientnet = efficientnet_v2_l(weights="IMAGENET1K_V1")

        # Swap the ImageNet head for one sized to our label set, reusing the
        # feature width of the original final linear layer.
        in_features = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features=in_features, out_features=num_classes)
        )

    def forward(self, x):
        """Return the class logits produced by the backbone for batch ``x``."""
        return self.efficientnet(x)

In [None]:
#############################################
# 3. TRAINING FUNCTIONS
#############################################

def train_segmentation_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10, scheduler=None):
    """Train a segmentation model and track loss / IoU per epoch.

    After every epoch the model is evaluated on ``val_loader``; whenever the
    mean validation loss improves, the weights are written to
    ``best_segmentation_model.pth`` in the current working directory.

    Args:
        model: Network mapping an image batch to per-pixel class logits.
        train_loader: Iterable of batches with ``'image'`` and ``'mask'`` tensors.
        val_loader: Same batch format, used for validation.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped with
            the validation loss; any other scheduler is stepped without
            arguments, once per epoch.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss`` and ``val_iou``.

    Raises:
        ValueError: If either loader yields no batches.
    """
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        # Fail early with a clear message instead of dividing by zero later.
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    history = {'train_loss': [], 'val_loss': [], 'val_iou': []}
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_train_loss = running_loss / num_train_batches
        history['train_loss'].append(epoch_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_iou = 0.0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)

                outputs = model(images)
                val_loss += criterion(outputs, masks).item()

                # Predicted class per pixel is the arg-max over the class axis.
                preds = torch.argmax(outputs, dim=1)
                val_iou += calculate_iou(preds, masks)

        epoch_val_loss = val_loss / num_val_batches
        epoch_val_iou = val_iou / num_val_batches

        history['val_loss'].append(epoch_val_loss)
        history['val_iou'].append(epoch_val_iou)

        # Only ReduceLROnPlateau takes the monitored metric; for every other
        # scheduler a positional argument to step() would be read as an epoch.
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

        # Save best model
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), 'best_segmentation_model.pth')

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}, Val IoU: {epoch_val_iou:.4f}")

    return history

In [None]:
def calculate_iou(pred, target):
    """Return the mean Intersection-over-Union over the classes in ``target``.

    Each non-zero value present in ``target`` is treated as one class; class 0
    (background) is ignored. Classes that appear only in ``pred`` are not
    scored. The mean is taken over the number of foreground classes actually
    found, so the result is correct whether or not ``target`` contains any
    background pixels.

    Args:
        pred: Integer tensor of predicted class IDs.
        target: Integer tensor of ground-truth class IDs, same shape as ``pred``.

    Returns:
        Mean IoU as a Python float, or 0.0 when ``target`` has no foreground.
    """
    present = torch.unique(target)
    foreground = present[present != 0]
    if foreground.numel() == 0:
        return 0.0

    total_iou = 0.0
    for cls in foreground:
        pred_mask = pred == cls
        target_mask = target == cls

        intersection = (pred_mask & target_mask).sum()
        # Never zero: ``cls`` was taken from ``target``, so target_mask has
        # at least one set pixel.
        union = (pred_mask | target_mask).sum()

        total_iou += (intersection / union).item()

    return total_iou / foreground.numel()


def train_classification_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10, scheduler=None):
    """Train a classifier and track loss / accuracy per epoch.

    After every epoch the model is evaluated on ``val_loader``; whenever the
    validation accuracy improves, the weights are written to
    ``best_classification_model.pth`` in the current working directory.

    Args:
        model: Network mapping an image batch to class logits.
        train_loader: Iterable of batches with ``'image'`` and ``'condition'`` tensors.
        val_loader: Same batch format, used for validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped with
            the validation loss; any other scheduler is stepped without
            arguments, once per epoch.

    Returns:
        dict with per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``
        and ``val_acc`` (accuracies in percent).

    Raises:
        ValueError: If either loader yields no batches.
    """
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    if num_train_batches == 0 or num_val_batches == 0:
        # Fail early with a clear message instead of dividing by zero later.
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            images = batch['image'].to(device)
            labels = batch['condition'].to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Accuracy bookkeeping; detach so no graph is kept alive.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_train_loss = running_loss / num_train_batches
        epoch_train_acc = 100 * correct / total if total else 0.0

        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                images = batch['image'].to(device)
                labels = batch['condition'].to(device)

                outputs = model(images)
                val_loss += criterion(outputs, labels).item()

                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        epoch_val_loss = val_loss / num_val_batches
        epoch_val_acc = 100 * val_correct / val_total if val_total else 0.0

        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        # Only ReduceLROnPlateau takes the monitored metric; for every other
        # scheduler a positional argument to step() would be read as an epoch.
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), 'best_classification_model.pth')

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss:.4f}, "
              f"Train Acc: {epoch_train_acc:.2f}%, Val Loss: {epoch_val_loss:.4f}, "
              f"Val Acc: {epoch_val_acc:.2f}%")

    return history

In [None]:
#############################################
# 4. INFERENCE/PREDICTION FUNCTIONS
#############################################

def process_tooth_region(tooth_image, classification_model, device, transform):
    """Classify one cropped tooth image.

    Args:
        tooth_image: Array accepted by ``PIL.Image.fromarray``.
        classification_model: Model returning class logits for an image batch.
        device: Device the input tensor is moved to.
        transform: Callable turning a PIL image into a tensor.

    Returns:
        Tuple ``(class_id, confidence)`` where ``confidence`` is the softmax
        probability of the winning class.
    """
    # The transform pipeline works on PIL images; add a batch axis of size 1.
    batch = transform(Image.fromarray(tooth_image)).unsqueeze(0).to(device)

    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(batch)
        probabilities = torch.softmax(logits, dim=1)
        top_prob, top_class = probabilities.max(dim=1)

    return top_class.item(), top_prob.item()


def extract_tooth_regions(image, mask):
    """Crop one padded region per tooth ID found in ``mask``.

    The bounding box of a tooth spans every pixel carrying its ID, so a tooth
    whose mask is split into several disconnected fragments still yields a
    single crop that contains all of them.

    Args:
        image: Image array indexed as ``image[row, col, ...]``.
        mask: 2-D integer array, same height/width as ``image``; 0 is
            background and every other value is a tooth ID.

    Returns:
        List of dicts, in ascending ID order, with keys ``id`` (int),
        ``region`` (view into ``image``) and ``bbox``
        ``(x_min, y_min, x_max, y_max)`` with exclusive max coordinates.
    """
    tooth_regions = []

    for tooth_id in np.unique(mask):
        if tooth_id == 0:
            continue  # Skip background

        # Non-empty by construction: tooth_id was taken from the mask itself.
        rows, cols = np.nonzero(mask == tooth_id)
        x, y = int(cols.min()), int(rows.min())
        w = int(cols.max()) - x + 1
        h = int(rows.max()) - y + 1

        # Pad every side by 25% of the longer box side, clipped to the image,
        # so the crop includes some context around the tooth.
        margin = int(max(w, h) * 0.25)
        x_min = max(0, x - margin)
        y_min = max(0, y - margin)
        x_max = min(image.shape[1], x + w + margin)
        y_max = min(image.shape[0], y + h + margin)

        tooth_regions.append({
            'id': int(tooth_id),  # FDI tooth number
            'region': image[y_min:y_max, x_min:x_max],
            'bbox': (x_min, y_min, x_max, y_max)
        })

    return tooth_regions


def generate_results_visualization(image, mask, tooth_predictions):
    """Overlay coloured tooth masks, tooth numbers and condition markers.

    Args:
        image: Colour image array; it is copied, not modified.
        mask: Integer array of tooth IDs (0 = background) that can be used to
            index ``image``'s pixels.
        tooth_predictions: Iterable of dicts with ``tooth_id``,
            ``condition_name`` and ``bbox`` ``(x_min, y_min, x_max, y_max)``.

    Returns:
        New image with a 50% blended colour overlay, each tooth number drawn
        at its bbox centre and a filled condition-coloured dot 15 px below it.
    """
    palette = [
        (0, 255, 0),    # Green
        (255, 0, 0),    # Red
        (0, 0, 255),    # Blue
        (255, 255, 0),  # Yellow
        (255, 0, 255),  # Magenta
        (0, 255, 255),  # Cyan
        (128, 0, 0),    # Maroon
        (0, 128, 0),    # Dark Green
        (0, 0, 128),    # Navy
        (128, 128, 0),  # Olive
        (128, 0, 128),  # Purple
    ]

    # Marker colour per condition; unknown conditions fall back to white.
    marker_palette = {
        'Tooth without anomalies': (0, 255, 0),      # Green
        'Tooth with fillings': (0, 255, 255),        # Cyan
        'Tooth with RCT': (0, 0, 255),               # Blue
        'Tooth with crown': (255, 0, 255),           # Magenta
        'Tooth with caries': (255, 0, 0),            # Red
        'Residual root': (128, 0, 0),                # Maroon
        'Tooth with RCT and crown': (255, 255, 0)    # Yellow
    }

    canvas = image.copy()
    overlay = np.zeros_like(canvas)

    # The colour depends only on the last digit of the tooth number, so the
    # same position in different quadrants shares a colour. ID 91 is
    # special-cased to slot 9.
    for tooth_id in np.unique(mask):
        if tooth_id == 0:
            continue  # Skip background
        slot = 9 if tooth_id == 91 else tooth_id % 10
        overlay[mask == tooth_id] = palette[slot % len(palette)]

    # Blend the overlay and the original in equal parts.
    alpha = 0.5
    canvas = cv2.addWeighted(canvas, 1-alpha, overlay, alpha, 0)

    for pred in tooth_predictions:
        left, top, right, bottom = pred['bbox'][0], pred['bbox'][1], pred['bbox'][2], pred['bbox'][3]
        center_x = (left + right) // 2
        center_y = (top + bottom) // 2

        # Tooth number at the centre of the bounding box.
        cv2.putText(canvas, str(pred['tooth_id']), (center_x, center_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Filled dot just below the number, coloured by condition.
        dot_color = marker_palette.get(pred['condition_name'], (255, 255, 255))
        cv2.circle(canvas, (center_x, center_y + 15), 5, dot_color, -1)

    return canvas


def generate_dental_report(tooth_predictions):
    """Summarise per-tooth predictions as condition counts.

    Args:
        tooth_predictions: Iterable of dicts with ``tooth_id`` and
            ``condition_name`` (one of the seven known condition names).

    Returns:
        dict mapping report labels to counts. ``Missing tooth`` is the number
        of the 32 standard FDI teeth (11-18, 21-28, 31-38, 41-48) whose ID
        does not appear in the predictions. Healthy teeth are counted
        internally but are not part of the report.

    Raises:
        KeyError: If a prediction carries an unknown ``condition_name``.
    """
    known_conditions = (
        'Tooth without anomalies',
        'Tooth with fillings',
        'Tooth with RCT',
        'Tooth with crown',
        'Tooth with caries',
        'Residual root',
        'Tooth with RCT and crown',
    )
    tally = dict.fromkeys(known_conditions, 0)

    # Four quadrants (1-4), eight positions (1-8) each.
    expected_teeth = {
        quadrant * 10 + position
        for quadrant in range(1, 5)
        for position in range(1, 9)
    }
    detected_teeth = {pred['tooth_id'] for pred in tooth_predictions}

    for pred in tooth_predictions:
        tally[pred['condition_name']] += 1

    missing_count = len(expected_teeth - detected_teeth)

    # Report labels in display order, paired with the tally key they read.
    report_layout = (
        ('Tooth with fillings', 'Tooth with fillings'),
        ('Tooth with RCT', 'Tooth with RCT'),
        ('Tooth with crown', 'Tooth with crown'),
        ('Tooth with caries', 'Tooth with caries'),
        ('Residual root', 'Residual root'),
        ('Tooth with RCT + crown', 'Tooth with RCT and crown'),
    )

    report = {'Missing tooth': missing_count}
    for report_label, tally_key in report_layout:
        report[report_label] = tally[tally_key]

    return report


def predict_panoramic_radiograph(image_path, seg_model, cls_model, device, seg_transform, cls_transform):
    """Run segmentation, per-tooth classification and reporting on one image.

    Args:
        image_path: Path of the radiograph to analyse.
        seg_model: Model returning per-pixel class logits shaped
            ``(batch, classes, height, width)``.
        cls_model: Model returning condition logits for a tooth crop.
        device: Device both models run on.
        seg_transform: Transform turning a PIL image into the segmentation input.
        cls_transform: Transform turning a PIL tooth crop into the classifier input.

    Returns:
        dict with ``original_image`` (RGB array), ``segmentation_mask``
        (integer array at the original image resolution),
        ``result_visualization``, ``tooth_predictions`` and ``report``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure with None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Prepare for segmentation model
    input_tensor = seg_transform(Image.fromarray(image)).unsqueeze(0).to(device)

    # Segment teeth
    seg_model.eval()
    with torch.no_grad():
        seg_output = seg_model(input_tensor)
        labels = torch.argmax(seg_output, dim=1, keepdim=True)

        # The transform may have resized the image, leaving the prediction at
        # network resolution. Cropping and overlaying below index the
        # original image with this mask, so map it back first.
        # Nearest-neighbour keeps the values valid class IDs.
        original_size = tuple(image.shape[:2])
        if tuple(labels.shape[-2:]) != original_size:
            labels = torch.nn.functional.interpolate(
                labels.float(), size=original_size, mode='nearest'
            ).long()

        seg_mask = labels[0, 0].cpu().numpy()

    # Extract individual tooth regions
    tooth_regions = extract_tooth_regions(image, seg_mask)

    # Condition ID -> human readable name.
    condition_mapping = {
        0: 'Tooth without anomalies',
        1: 'Tooth with fillings',
        2: 'Tooth with RCT',
        3: 'Tooth with crown',
        4: 'Tooth with caries',
        5: 'Residual root',
        6: 'Tooth with RCT and crown'
    }

    # Process each tooth with classification model
    tooth_predictions = []
    for tooth in tooth_regions:
        condition_id, confidence = process_tooth_region(
            tooth['region'], cls_model, device, cls_transform
        )

        tooth_predictions.append({
            'tooth_id': tooth['id'],
            'condition': condition_id,
            'condition_name': condition_mapping[condition_id],
            'confidence': confidence,
            'bbox': tooth['bbox']
        })

    # Generate visualization
    result_image = generate_results_visualization(image, seg_mask, tooth_predictions)

    # Generate report
    report = generate_dental_report(tooth_predictions)

    return {
        'original_image': image,
        'segmentation_mask': seg_mask,
        'result_visualization': result_image,
        'tooth_predictions': tooth_predictions,
        'report': report
    }

In [None]:
def _resolve_shape_condition(ann_data, shape):
    """Return the condition ID (0-6) for one annotated shape.

    The first of these that is present and parses as an integer wins: the
    shape's own ``group_id`` (where labelme-style files store it), the
    file-level ``group_id``, then the shape's ``attributes['condition']``.
    Falls back to 0 (no anomalies).
    """
    attributes = shape.get('attributes')
    candidates = (
        shape.get('group_id'),
        ann_data.get('group_id'),
        attributes.get('condition') if isinstance(attributes, dict) else None,
    )
    for candidate in candidates:
        if candidate is None:
            continue
        try:
            return max(0, min(int(candidate), 6))
        except (ValueError, TypeError):
            continue
    return 0


def extract_teeth_from_panoramics(dataset_info, seg_transform=None, device=None):
    """Crop every annotated tooth out of the panoramic images and save it.

    For each polygon shape with an integer label, the padded bounding box of
    the polygon is cut from the radiograph and written to
    ``teeth_output_dir`` as ``<image stem>_<tooth id>_<condition>.png``,
    together with a JSON file of the same stem in ``teeth_ann_dir`` whose
    ``group_id`` holds the condition. The source image stem is part of the
    name so that the same tooth from different radiographs does not
    overwrite earlier crops.

    Args:
        dataset_info: dict with ``img_dir``, ``json_dir``,
            ``teeth_output_dir``, ``teeth_ann_dir`` and ``paired_files``
            (image file names that have annotations).
        seg_transform: Unused; accepted for backward compatibility. Crops are
            taken from the ground-truth polygons, not from model output.
        device: Unused; accepted for backward compatibility.

    Returns:
        Number of tooth crops written.
    """
    img_dir = dataset_info['img_dir']
    json_dir = dataset_info['json_dir']
    teeth_output_dir = dataset_info['teeth_output_dir']
    teeth_ann_dir = dataset_info['teeth_ann_dir']
    paired_files = dataset_info['paired_files']

    os.makedirs(teeth_output_dir, exist_ok=True)
    os.makedirs(teeth_ann_dir, exist_ok=True)

    extracted_count = 0

    for img_file in tqdm(paired_files, desc="Extracting teeth"):
        # Load image
        img_path = os.path.join(img_dir, img_file)
        image = cv2.imread(img_path)
        if image is None:
            print(f"Warning: Could not read image {img_path}")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load annotation
        img_stem = os.path.splitext(img_file)[0]
        json_path = os.path.join(json_dir, img_stem + '.json')
        with open(json_path, 'r') as f:
            ann_data = json.load(f)

        # Used as the file-name prefix; drop any directory part.
        name_prefix = os.path.basename(img_stem)

        for shape in ann_data['shapes']:
            if shape['shape_type'] != 'polygon' or len(shape['points']) < 3:
                continue

            try:
                tooth_id = int(shape['label'])
            except (ValueError, TypeError):
                print(f"Skipping invalid tooth label: {shape['label']}")
                continue

            try:
                points = np.array(shape['points'], dtype=np.int32)
            except ValueError:
                print(f"Skipping malformed polygon for tooth {tooth_id} in {img_file}")
                continue

            condition = _resolve_shape_condition(ann_data, shape)

            # Rasterise the polygon; fillPoly clips it to the image bounds.
            tooth_mask = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(tooth_mask, [points], 1)

            # Bounding box over every filled pixel, not just one contour.
            rows, cols = np.nonzero(tooth_mask)
            if rows.size == 0:
                continue  # Polygon lies entirely outside the image.
            x, y = int(cols.min()), int(rows.min())
            w = int(cols.max()) - x + 1
            h = int(rows.max()) - y + 1

            # Pad every side by 25% of the longer box side, clipped to the image.
            margin = int(max(w, h) * 0.25)
            x_min = max(0, x - margin)
            y_min = max(0, y - margin)
            x_max = min(image.shape[1], x + w + margin)
            y_max = min(image.shape[0], y + h + margin)

            tooth_region = image[y_min:y_max, x_min:x_max]

            # The last two "_" fields stay tooth ID and condition.
            tooth_img_name = f"{name_prefix}_{tooth_id}_{condition}.png"
            tooth_img_path = os.path.join(teeth_output_dir, tooth_img_name)
            if not cv2.imwrite(tooth_img_path, cv2.cvtColor(tooth_region, cv2.COLOR_RGB2BGR)):
                print(f"Warning: Could not write tooth image {tooth_img_path}")
                continue

            # Create and save annotation for this tooth
            tooth_ann = {
                "version": "1.0",
                "flags": {},
                "shapes": [],
                "imagePath": tooth_img_name,
                "imageHeight": tooth_region.shape[0],
                "imageWidth": tooth_region.shape[1],
                "group_id": condition  # Condition ID, read back by the classification dataset
            }

            tooth_ann_path = os.path.join(teeth_ann_dir, os.path.splitext(tooth_img_name)[0] + '.json')
            with open(tooth_ann_path, 'w') as f:
                json.dump(tooth_ann, f, indent=2)

            extracted_count += 1

    print(f"Extracted {extracted_count} individual teeth for classification training.")
    return extracted_count


def _build_split_loaders(dataset, batch_size, test_size=0.2, random_state=42):
    """Split *dataset* into train/validation index sets and build a loader for each.

    Args:
        dataset: Any sized, indexable dataset accepted by ``DataLoader``.
        batch_size: Batch size used for both loaders.
        test_size: Fraction of samples held out for validation.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        Tuple ``(train_loader, val_loader, train_indices, val_indices)``.
    """
    train_indices, val_indices = train_test_split(
        range(len(dataset)), test_size=test_size, random_state=random_state
    )

    train_loader = DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(train_indices),
        num_workers=4, pin_memory=True
    )
    val_loader = DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(val_indices),
        num_workers=4, pin_memory=True
    )
    return train_loader, val_loader, train_indices, val_indices


def dental_radiograph_analysis_pipeline(
    data_dir,
    output_dir,
    batch_size=8,
    num_epochs_seg=30,
    num_epochs_cls=30,
    learning_rate=0.001,
    device=None
):
    """Run the complete dental radiograph analysis pipeline.

    Steps: prepare the dataset, extract per-tooth crops when fewer than 100
    are present, obtain a segmentation model and a classification model
    (each is loaded from a ``best_*_model.pth`` file in the current working
    directory when that file exists, otherwise trained), then run the
    prediction pipeline on up to five segmentation-validation images and
    write the visualizations to ``output_dir``.

    Args:
        data_dir: Dataset root, forwarded to ``prepare_dataset``.
        output_dir: Directory that receives plots and result images.
        batch_size: Batch size for all data loaders.
        num_epochs_seg: Epochs for segmentation training.
        num_epochs_cls: Epochs for classification training.
        learning_rate: Adam learning rate for both models.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        Tuple ``(seg_model, cls_model)``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")

    # Prepare dataset
    dataset_info = prepare_dataset(data_dir, output_dir)

    # Plots and result images are written here later; prepare_dataset may
    # already have created it, in which case this is a no-op.
    os.makedirs(output_dir, exist_ok=True)

    # Define transforms
    seg_transform = transforms.Compose([
        transforms.Resize((512, 512)),  # Resize for consistency
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    cls_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets and extract individual teeth if needed
    # Check if we already have individual teeth extracted
    teeth_files = [f for f in os.listdir(dataset_info['teeth_output_dir'])
                   if f.endswith('.png') or f.endswith('.jpg')]

    if len(teeth_files) < 100:  # Arbitrary threshold - extract teeth if we don't have enough
        print("Extracting individual teeth from panoramic radiographs...")
        extract_teeth_from_panoramics(dataset_info, seg_transform, device)

    # 1. SEGMENTATION MODEL TRAINING

    # Create segmentation dataset
    seg_dataset = SegmentationDataset(
        img_dir=dataset_info['img_dir'],
        json_dir=dataset_info['json_dir'],
        transform=seg_transform
    )

    # Split dataset and create data loaders. The validation indices are kept
    # under their own name because they are needed again in step 3.
    seg_train_loader, seg_val_loader, _, seg_val_indices = _build_split_loaders(
        seg_dataset, batch_size
    )

    # Initialize segmentation model
    seg_model = YOLOv9SegmentationModel(num_classes=49).to(device)  # 48 teeth + background

    # Check if we should load pretrained weights
    if os.path.exists('best_segmentation_model.pth'):
        seg_model.load_state_dict(torch.load('best_segmentation_model.pth', map_location=device))
        print("Loaded pretrained segmentation model weights.")
    else:
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(seg_model.parameters(), lr=learning_rate)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
        # confirm against the installed version.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        # Train the model
        print("Training segmentation model...")
        seg_history = train_segmentation_model(
            seg_model, seg_train_loader, seg_val_loader, criterion, optimizer,
            device, num_epochs=num_epochs_seg, scheduler=scheduler
        )

        # Plot training history
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(seg_history['train_loss'], label='Train Loss')
        plt.plot(seg_history['val_loss'], label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(seg_history['val_iou'], label='Validation IoU')
        plt.xlabel('Epoch')
        plt.ylabel('IoU')
        plt.legend()
        plt.savefig(os.path.join(output_dir, 'seg_training_history.png'))

    # 2. CLASSIFICATION MODEL TRAINING

    # Create classification dataset
    cls_dataset = ClassificationDataset(
        img_dir=dataset_info['teeth_output_dir'],
        json_dir=dataset_info['teeth_ann_dir'],
        transform=cls_transform
    )

    # Split dataset and create data loaders
    cls_train_loader, cls_val_loader, _, _ = _build_split_loaders(
        cls_dataset, batch_size
    )

    # Initialize classification model
    cls_model = EfficientNetClassificationModel(num_classes=7).to(device)  # 7 dental conditions

    # Check if we should load pretrained weights
    if os.path.exists('best_classification_model.pth'):
        cls_model.load_state_dict(torch.load('best_classification_model.pth', map_location=device))
        print("Loaded pretrained classification model weights.")
    else:
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(cls_model.parameters(), lr=learning_rate)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
        # confirm against the installed version.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        # Train the model
        print("Training classification model...")
        cls_history = train_classification_model(
            cls_model, cls_train_loader, cls_val_loader, criterion, optimizer,
            device, num_epochs=num_epochs_cls, scheduler=scheduler
        )

        # Plot training history
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(cls_history['train_loss'], label='Train Loss')
        plt.plot(cls_history['val_loss'], label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(cls_history['train_acc'], label='Train Accuracy')
        plt.plot(cls_history['val_acc'], label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()
        plt.savefig(os.path.join(output_dir, 'cls_training_history.png'))

    # 3. TEST ON A SAMPLE IMAGE

    # Get sample test images from the segmentation validation set. The
    # indices index seg_dataset.img_files; the classification split's indices
    # refer to cropped-tooth images and must not be used to pick panoramics.
    val_imgs = [os.path.join(dataset_info['img_dir'], seg_dataset.img_files[i])
                for i in seg_val_indices[:5]]  # Get 5 sample validation images

    # Process each test image
    for test_img_path in val_imgs:
        print(f"Processing test image: {test_img_path}")

        # Run prediction pipeline
        results = predict_panoramic_radiograph(
            test_img_path, seg_model, cls_model, device, seg_transform, cls_transform
        )

        # Save visualization
        vis_path = os.path.join(output_dir, f"result_{os.path.basename(test_img_path)}")
        cv2.imwrite(vis_path, cv2.cvtColor(results['result_visualization'], cv2.COLOR_RGB2BGR))

        # Print report
        print("\nDental Report:")
        for condition, count in results['report'].items():
            print(f"  {condition}: {count}")

    print(f"\nAll results saved to {output_dir}")
    return seg_model, cls_model


def create_dental_report_dashboard(results, output_path):
    """Create an HTML dashboard for a dental report.

    Writes two files: the visualization image as ``visualization.png`` in the
    same directory as ``output_path``, and the HTML report at ``output_path``.

    Args:
        results: Mapping with the keys read here:
            ``'report'`` (condition -> count mapping),
            ``'result_visualization'`` (RGB image array passed to OpenCV) and
            ``'tooth_predictions'`` (iterable of dicts with ``'tooth_id'``,
            ``'condition_name'`` and a numeric ``'confidence'``).
        output_path: Destination path of the HTML file.

    Returns:
        None.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # Save results visualization.
    # NOTE(review): the image name is fixed, so several reports written to the
    # same directory share (and overwrite) one visualization.png.
    vis_path = os.path.join(os.path.dirname(output_path), 'visualization.png')
    cv2.imwrite(vis_path, cv2.cvtColor(results['result_visualization'], cv2.COLOR_RGB2BGR))

    # Collect HTML fragments and join once at the end.
    fragments = ["""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Dental Radiograph Analysis Report</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .container { display: flex; flex-wrap: wrap; }
            .image-container { flex: 1; min-width: 600px; margin-right: 20px; }
            .report-container { flex: 1; min-width: 400px; }
            table { border-collapse: collapse; width: 100%; margin-top: 20px; }
            th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
            th { background-color: #f2f2f2; }
            tr:nth-child(even) { background-color: #f9f9f9; }
            h2 { color: #333; }
            .tooth-list { columns: 2; }
        </style>
    </head>
    <body>
        <h1>Dental Radiograph Analysis Report</h1>
        <div class="container">
            <div class="image-container">
                <h2>Visualization</h2>
                <img src="visualization.png" alt="Dental Radiograph Analysis" style="max-width: 100%;">
            </div>
            <div class="report-container">
                <h2>Summary Report</h2>
                <table>
                    <tr>
                        <th>Condition</th>
                        <th>Count</th>
                    </tr>
    """]

    # Add rows for each condition (values escaped so markup characters in a
    # label cannot break the table).
    for condition, count in results['report'].items():
        fragments.append(f"""
                    <tr>
                        <td>{escape(str(condition))}</td>
                        <td>{escape(str(count))}</td>
                    </tr>
        """)

    fragments.append("""
                </table>

                <h2>Detailed Tooth Analysis</h2>
                <div class="tooth-list">
                    <ul>
    """)

    # Add details for each tooth
    for tooth in results['tooth_predictions']:
        tooth_id = escape(str(tooth['tooth_id']))
        condition_name = escape(str(tooth['condition_name']))
        fragments.append(f"""
                        <li>Tooth {tooth_id}: {condition_name} (Confidence: {tooth['confidence']:.2f})</li>
        """)

    fragments.append("""
                    </ul>
                </div>
            </div>
        </div>
    </body>
    </html>
    """)

    # Write HTML to file; the explicit encoding matches the declared charset
    # instead of depending on the platform default.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("".join(fragments))

    print(f"Dashboard created at {output_path}")

In [None]:
#############################################
# ENTRY POINT
#############################################

if __name__ == "__main__":
    # Command-line entry point: train/load both models, then optionally run
    # the prediction pipeline on a single user-supplied image.
    # NOTE(review): inside a notebook kernel `__name__` is also "__main__" and
    # argparse will parse the kernel's own argv — confirm this cell is meant
    # to be run as a script.
    import argparse

    parser = argparse.ArgumentParser(description='Dental Radiograph Analysis Pipeline')
    parser.add_argument('--data_dir', type=str, required=True, help='Path to dataset directory')
    parser.add_argument('--output_dir', type=str, default='output', help='Path to output directory')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
    parser.add_argument('--epochs_seg', type=int, default=30, help='Number of epochs for segmentation training')
    parser.add_argument('--epochs_cls', type=int, default=30, help='Number of epochs for classification training')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--test_img', type=str, default=None, help='Path to a test image (optional)')

    args = parser.parse_args()

    # Run the pipeline
    seg_model, cls_model = dental_radiograph_analysis_pipeline(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs_seg=args.epochs_seg,
        num_epochs_cls=args.epochs_cls,
        learning_rate=args.lr
    )

    # If a test image is provided, run the prediction
    if args.test_img:
        # Same device selection as the pipeline's default.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Define transforms (same values as used inside the pipeline)
        seg_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        cls_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Run prediction
        results = predict_panoramic_radiograph(
            args.test_img, seg_model, cls_model, device, seg_transform, cls_transform
        )

        test_img_name = os.path.basename(args.test_img)

        # Save visualization
        vis_path = os.path.join(args.output_dir, f"result_{test_img_name}")
        cv2.imwrite(vis_path, cv2.cvtColor(results['result_visualization'], cv2.COLOR_RGB2BGR))

        # Create dashboard. splitext strips only the final extension, so a
        # name such as "scan.v2.png" yields "report_scan.v2.html" rather than
        # being cut at the first dot.
        report_stem = os.path.splitext(test_img_name)[0]
        dashboard_path = os.path.join(args.output_dir, f"report_{report_stem}.html")
        create_dental_report_dashboard(results, dashboard_path)