In [91]:
import json
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from ultralytics.yolo.utils.plotting import Annotator


In [92]:
class LeyanDataset:
    def __init__(self, root_dir: Path, train=True, task='quadrant_enumeration'):
        self.root_dir = root_dir

        self.img_paths = list(root_dir.glob('*.jpg'))
        self.img_names = [i.name for i in self.img_paths]

        self.json_paths = list(root_dir.glob('*.json'))

        tooth_annotations = {}
        anomaly_annotations = {}
        for json_path in self.json_paths:
            image_name = json_path.with_suffix('.jpg').name
            with open(json_path) as f:
                json_content = json.load(f)
                shapes = json_content['shapes']

            tooth_annotations[image_name] = {}
            anomaly_annotations[image_name] = []

            for shape in shapes:
                label = shape['label']
                bbox = np.hstack(shape['points'])

                if label.isdigit():
                    tooth_annotations[image_name][label] = bbox
                else:
                    anomaly_annotations[image_name].append((label, bbox))

        self.tooth_annotations = tooth_annotations
        self.anomaly_annotations = anomaly_annotations

    def __len__(self):
        """Returns the number of x-ray images in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Returns the x-ray image and annotations at the specified index."""
        img_name = self.img_names[idx]
        img_path = self.img_paths[idx]

        img = cv2.imread(str(img_path))

        tooth_annotation = self.tooth_annotations[img_name]
        anomaly_annotation = self.anomaly_annotations[img_name]
        return img, tooth_annotation, anomaly_annotation

    def plot(self, image_name):
        im = cv2.imread(str(self.root_dir / image_name))

        annotator = Annotator(im, line_width=3, example=image_name)
        for tooth_number, xyxy in self.tooth_annotations[image_name].items():
            annotator.box_label(xyxy, str(tooth_number), color=(255, 0, 0))

        im1 = annotator.result()
        plt.imshow(im1)
        plt.show()


In [93]:
load_dotenv()
data_dir = Path(os.getenv('DATASET_DIR')) / 'phase-3'

root_dir = data_dir

a = LeyanDataset(root_dir=data_dir)


In [94]:
b = sorted(list(a.tooth_annotations.keys()))

# a.plot(b[30])

for i in range(len(b)):
    c = a.tooth_annotations[b[i]]

    print(b[i], len(c.keys()))



00008199.jpg 27
00008200.jpg 28
00008207.jpg 22
00008210.jpg 28
00008217.jpg 22
00008218.jpg 23
00008223.jpg 25
00008225.jpg 30
00008227.jpg 16
00008228.jpg 28
00008231.jpg 30
00008238.jpg 32
00008239.jpg 1
00008240.jpg 27
00008241.jpg 1
00008243.jpg 1
00008245.jpg 26
00008246.jpg 28
00008250.jpg 31
00008254.jpg 15
00008256.jpg 25
00008257.jpg 27
00008270.jpg 30
00008271.jpg 26
00008272.jpg 32
00008276.jpg 0
00008279.jpg 0
00008280.jpg 25
00008290.jpg 21
00008298.jpg 30
00008306.jpg 26
00008309.jpg 26
00008314.jpg 28
00008317.jpg 29
00008318.jpg 28
00008321.jpg 32
00008322.jpg 32
00008323.jpg 14
00008325.jpg 23
00008328.jpg 31
00008329.jpg 17
00008332.jpg 26
00008333.jpg 32
00008335.jpg 30
00008336.jpg 28
00008338.jpg 29
00008339.jpg 28
00008345.jpg 32
00008351.jpg 31
00008355.jpg 31
00008356.jpg 26
00008358.jpg 27
00008370.jpg 24
00008373.jpg 31
00008376.jpg 30
00008377.jpg 26
00008378.jpg 25
00008380.jpg 25
00008384.jpg 20
00008391.jpg 26
00008393.jpg 30
00008396.jpg 29
00008397.jpg 

In [96]:
a.tooth_annotations


{'00008421.jpg': {'23': array([     1341.5,      622.97,      1449.1,      379.03]),
  '27': array([     1702.1,      597.21,      1812.7,      404.79]),
  '33': array([     1323.3,      650.24,        1390,      903.27]),
  '37': array([     1638.5,      644.18,      1829.4,      871.45]),
  '43': array([     1086.9,      654.79,      1038.5,      891.15]),
  '47': array([     744.49,      637.72,      569.15,      841.38]),
  '13': array([     1024.8,      629.03,         943,      385.09]),
  '17': array([     659.67,      603.27,      579.36,      368.42]),
  '11': array([     1092.2,      402.92,      1199.9,      632.15]),
  '12': array([     1027.6,      390.62,      1115.3,      618.31]),
  '14': array([     844.54,      390.62,      976.85,      622.92]),
  '15': array([     773.77,      387.54,      879.92,      624.46]),
  '16': array([     678.38,      395.23,      799.92,      615.23]),
  '18': array([     459.92,      379.85,      578.38,      609.08]),
  '21': array([   