# Dental Disease Detection & Segmentation Training
## Optimized for Google Colab

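This notebook sets up the workflow to train **two separate models** for the D-Chart ML Service. The annotation parsing and training loops are placeholders that you fill in for your own dataset.
1. **Anatomy Model**: Segments each tooth and assigns it an ISO 3950 (FDI) tooth number (11-48).
2. **Pathology Model**: Detects diseases and restorations (14 classes plus background, 15 in total).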

### Post-Training Steps
1. Download `anatomy_model.pth` and `pathology_model.pth` generated by this notebook.
2. Place them in your `ml_service/app/models/` directory.
3. Update the service code to load these weights.

In [None]:
# 1. Setup Environment
!pip install -U torch torchvision pycocotools opencv-python-headless matplotlib

import os
import json
import random
import numpy as np
import cv2
import torch
import torch.utils.data
from PIL import Image
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Use the GPU if one is available (Colab: Runtime > Change runtime type > GPU), otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# 2. Dataset Definition
class DentalDataset(torch.utils.data.Dataset):
    def __init__(self, root, mode='pathology', transforms=None):
        """
        mode: 'pathology' (Diseases) or 'anatomy' (Teeth)
        """
        self.root = root
        self.transforms = transforms
        self.mode = mode
        # sorted() already returns a list; sorting keeps indices stable across runs
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        
        # --- CRITICAL: DATASET PARSING LOGIC ---
        # You must implement the logic to load YOUR annotations (COCO JSON or Masks).
        
        # FOR ANATOMY MODE:
        # Expected Labels: 1 to 32 (corresponding to ISO 11-48 as per anatomy.py mapping).
        # Example: Label 1 = Tooth 18, Label 16 = Tooth 28.
        # The model LEARNS to identify "Tooth 18" by seeing examples of Class 1.
        
        # FOR PATHOLOGY MODE:
        # Expected Labels: 1 to 14 (Caries, Filling, Missing, etc.); 0 is reserved for background.
        # Example: Label 13 = Missing. The model LEARNS to find "Gaps" labeled as Class 13.
        
        target = {}
        # REPLACE THESE DUMMY VALUES with your actual loaded annotations.
        # Masks must match the image's height and width, with one mask per box.
        w, h = img.size
        masks = torch.zeros((1, h, w), dtype=torch.uint8)
        masks[0, 0:10, 0:10] = 1  # dummy object filling the dummy box below
        target["boxes"] = torch.as_tensor([[0, 0, 10, 10]], dtype=torch.float32)  # [x1, y1, x2, y2]
        target["labels"] = torch.as_tensor([1], dtype=torch.int64)
        target["masks"] = masks

        target["image_id"] = torch.tensor([idx])
        target["area"] = torch.as_tensor([100], dtype=torch.float32)  # box area: 10 * 10
        target["iscrowd"] = torch.zeros((1,), dtype=torch.int64)

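        if self.transforms is not None:
            # Must be a joint transform: takes (img, target) and returns (tensor_img, target)
            img, target = self.transforms(img, target)
        else:
            # Mask R-CNN expects a float tensor in [0, 1] with shape (C, H, W), not a PIL image
            img = torchvision.transforms.functional.to_tensor(img)

        return img, target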

    def __len__(self):
        return len(self.imgs)
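The annotation-loading step in `__getitem__` can be sketched for COCO-style polygon annotations. This minimal version assumes each annotation dict has `bbox` as `[x, y, width, height]`, a `category_id` that already equals the model label (1 to 32 for anatomy, 1 to 14 for pathology), and `segmentation` as a list of flat polygons `[x1, y1, x2, y2, ...]`. RLE segmentations are not handled. It uses only NumPy and Pillow, and the helper name `coco_to_target` is hypothetical.

```python
import numpy as np
from PIL import Image, ImageDraw


def coco_to_target(annotations, width, height):
    """Convert COCO-style polygon annotations for one image into Mask R-CNN target arrays.

    Hypothetical helper: assumes `category_id` already holds the model label and
    `segmentation` is a list of flat polygons (RLE is not supported here).
    """
    boxes, labels, masks, areas = [], [], [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]
        if w <= 0 or h <= 0:
            continue  # degenerate boxes break the box-regression loss
        boxes.append([x, y, x + w, y + h])  # COCO xywh -> Mask R-CNN x1y1x2y2
        labels.append(ann["category_id"])

        # Rasterize every polygon belonging to this object into one binary mask
        mask_img = Image.new("L", (width, height), 0)
        draw = ImageDraw.Draw(mask_img)
        for poly in ann["segmentation"]:
            points = list(zip(poly[0::2], poly[1::2]))
            draw.polygon(points, outline=1, fill=1)
        masks.append(np.array(mask_img, dtype=np.uint8))  # shape (height, width)
        areas.append(w * h)  # box area, the same convention as the dummy target

    return {
        "boxes": np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
        "labels": np.asarray(labels, dtype=np.int64),
        "masks": np.asarray(masks, dtype=np.uint8).reshape(-1, height, width),
        "area": np.asarray(areas, dtype=np.float32),
        "iscrowd": np.zeros(len(labels), dtype=np.int64),
    }
```

Inside `__getitem__`, each array would then be wrapped with `torch.as_tensor(...)` before being assigned to `target`.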

In [None]:
# 3. Model Helper
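def get_model(num_classes):
    # COCO-pretrained Mask R-CNN (`pretrained=True` is deprecated since torchvision 0.13)
    model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    # Replace the box classifier so it predicts `num_classes` classes (background included)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # Replace the mask head the same way
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model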

## 4. Train Anatomy Model (Teeth Detection)
This model learns to identify individual teeth (e.g. Tooth 18, Tooth 46).
**Classes**: Background (0) + 32 Teeth = **33 Classes**.
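The label-to-tooth mapping lives in `anatomy.py`, which is not shown here. The sketch below assumes a dental-chart ordering. The upper arch follows the documented examples (label 1 = tooth 18, label 16 = tooth 28). The lower-arch order (48 to 41, then 31 to 38) is a guess that must be checked against `anatomy.py` before use.

```python
# ASSUMPTION: upper arch matches the documented examples (1 -> 18, 16 -> 28);
# the lower-arch order is a guess and must be verified against anatomy.py.
UPPER = [18, 17, 16, 15, 14, 13, 12, 11, 21, 22, 23, 24, 25, 26, 27, 28]
LOWER = [48, 47, 46, 45, 44, 43, 42, 41, 31, 32, 33, 34, 35, 36, 37, 38]

# Model label (1..32) -> FDI tooth number, and the inverse
LABEL_TO_FDI = {label: fdi for label, fdi in enumerate(UPPER + LOWER, start=1)}
FDI_TO_LABEL = {fdi: label for label, fdi in LABEL_TO_FDI.items()}

NUM_CLASSES_ANATOMY = len(LABEL_TO_FDI) + 1  # 32 teeth + background (label 0)

print(LABEL_TO_FDI[1], LABEL_TO_FDI[16])  # 18 28
print(FDI_TO_LABEL[46])  # 19 under the assumed lower-arch order
```

Annotations can then store FDI numbers and convert them with `FDI_TO_LABEL` when building targets, while the service converts predicted labels back with `LABEL_TO_FDI`.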

In [None]:
# --- ANATOMY TRAINING ---
print("Training Anatomy Model...")
num_classes_anatomy = 33 # 0=BG, 1..32=Teeth
model_anatomy = get_model(num_classes_anatomy)
model_anatomy.to(device)

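# Set up the DataLoader and optimizer, then train (placeholder: fill in for your data)
# dataset_anatomy = DentalDataset('path/to/teeth_data', mode='anatomy')
# ... training loop specific to Anatomy ...

# Save. Without a training loop above, this file holds only the COCO-pretrained
# backbone plus randomly initialized 33-class heads.
torch.save(model_anatomy.state_dict(), "anatomy_model.pth")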
print("Saved anatomy_model.pth")

## 5. Train Pathology Model (Disease Detection)
This model learns to identify diseases and restorations (e.g., caries, pulpitis, fillings).
**Classes**: Background (0) + 14 Pathologies = **15 Classes**.
Classes (labels 1 to 14, in this order): Caries, Apical, Pulpitis, Filling, Crown, Implant, RootFrag, Fracture, Tartar, BoneLoss, Supernumerary, Impacted, Missing, EnamelDefect
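Keeping the label order in a single list helps the annotation files, `num_classes_pathology`, and the service agree. This sketch assumes the labels follow the order of the class list above, which matches the dataset comment that label 13 is Missing. Confirm it against the mapping that `pathology.py` actually uses.

```python
# ASSUMPTION: labels follow the order of the class list; 0 is reserved for background.
PATHOLOGY_CLASSES = [
    "Caries", "Apical", "Pulpitis", "Filling", "Crown", "Implant", "RootFrag",
    "Fracture", "Tartar", "BoneLoss", "Supernumerary", "Impacted", "Missing",
    "EnamelDefect",
]

LABEL_TO_NAME = {i: name for i, name in enumerate(PATHOLOGY_CLASSES, start=1)}
NAME_TO_LABEL = {name: i for i, name in LABEL_TO_NAME.items()}
NUM_CLASSES_PATHOLOGY = len(PATHOLOGY_CLASSES) + 1  # 14 classes + background

print(NAME_TO_LABEL["Missing"], NUM_CLASSES_PATHOLOGY)  # 13 15
```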

In [None]:
# --- PATHOLOGY TRAINING ---
print("Training Pathology Model...")
num_classes_pathology = 15 # 0=BG, 1..14=Diseases
model_pathology = get_model(num_classes_pathology)
model_pathology.to(device)

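# Set up the DataLoader and optimizer, then train (placeholder: fill in for your data)
# dataset_pathology = DentalDataset('path/to/disease_data', mode='pathology')
# ... training loop specific to Pathology ...

# Save. Without a training loop above, this file holds only the COCO-pretrained
# backbone plus randomly initialized 15-class heads.
torch.save(model_pathology.state_dict(), "pathology_model.pth")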
print("Saved pathology_model.pth")

## 6. Deployment Instructions
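1. **Download**: In Colab, run `from google.colab import files`, then `files.download('anatomy_model.pth')` and `files.download('pathology_model.pth')`.
2. **Deploy**: Put both files in `ml_service/app/models/`.
3. **Update Code**: In `anatomy.py` and `pathology.py`, update `_load_model` so it builds the model exactly as `get_model()` does here (33 classes for anatomy, 15 for pathology), then loads the weights and switches to evaluation mode. `map_location="cpu"` lets weights trained on a Colab GPU load on a CPU-only server:
   ```python
   model.load_state_dict(torch.load("app/models/anatomy_model.pth", map_location="cpu"))
   model.eval()  # pathology.py: load "app/models/pathology_model.pth" instead
   ```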
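After `model.eval()`, calling the model on a list of image tensors inside `torch.no_grad()` returns one dict per image with `boxes`, `labels`, `scores`, and soft `masks` of shape `(N, 1, H, W)`. The service still has to choose which detections to keep. This sketch filters by score and binarizes the masks. The 0.5 thresholds and the `postprocess` name are assumptions to tune per model.

```python
import torch


def postprocess(output, score_thresh=0.5, mask_thresh=0.5):
    """Filter one Mask R-CNN prediction dict and binarize its soft masks.

    `output` is one element of the list returned by model(images) in eval mode:
    boxes (N, 4), labels (N,), scores (N,), masks (N, 1, H, W) with values in [0, 1].
    """
    keep = output["scores"] >= score_thresh
    masks = output["masks"][keep, 0] >= mask_thresh  # (K, H, W) boolean masks
    return {
        "boxes": output["boxes"][keep],
        "labels": output["labels"][keep],
        "scores": output["scores"][keep],
        "masks": masks,
    }


# Synthetic prediction standing in for model([image])[0]
fake = {
    "boxes": torch.tensor([[0.0, 0.0, 4.0, 4.0], [1.0, 1.0, 3.0, 3.0]]),
    "labels": torch.tensor([5, 13]),
    "scores": torch.tensor([0.9, 0.2]),
    "masks": torch.rand(2, 1, 8, 8),
}
result = postprocess(fake)
print(result["labels"].tolist(), tuple(result["masks"].shape))  # [5] (1, 8, 8)
```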