In [1]:
import os
from pathlib import Path

import numpy as np
import timm
import shutil
import torch
from dotenv import load_dotenv
from timm.data import ImageDataset, create_loader
from torchvision import transforms
from ultralytics import YOLO

load_dotenv()

temp_dir = Path(os.getenv('TEMP_DIR')) / 'crop_tooth_image'
model_dir = Path(os.getenv('ViT_MODEL_DIR'))
data_dir = Path(os.getenv('DATASET_DIR')) / 'phase-2'
yolo_model_dir = Path(os.getenv('YOLO_MODEL_DIR'))
yolo_dir = yolo_model_dir / '..'


  warn(f"Failed to load image Python extension: {e}")


In [2]:
src = ['00006145.jpg', '00008026.jpg', '00008075.jpg']
src = [data_dir / i for i in src]

src


[PosixPath('/Users/lucyxu/PycharmProjects/datasets/phase-2/00006145.jpg'),
 PosixPath('/Users/lucyxu/PycharmProjects/datasets/phase-2/00008026.jpg'),
 PosixPath('/Users/lucyxu/PycharmProjects/datasets/phase-2/00008075.jpg')]

In [3]:
model = YOLO(yolo_model_dir / 'enumerate.pt')

results = model(src)



0: 320x640 1 11, 1 12, 2 13s, 2 14s, 2 15s, 2 16s, 1 17, 1 21, 1 22, 1 23, 1 26, 1 28, 1 31, 1 32, 1 33, 1 34, 1 35, 1 37, 1 41, 1 42, 1 43, 2 44s, 2 45s, 1 47, 2 48s, 1: 320x640 1 11, 1 12, 1 13, 1 14, 1 15, 1 16, 1 18, 1 21, 1 22, 1 23, 1 24, 1 25, 1 26, 1 27, 1 28, 1 31, 1 32, 1 33, 1 34, 1 35, 1 36, 1 37, 1 38, 1 41, 1 42, 1 43, 1 44, 1 45, 1 46, 1 47, 1 48, 2: 320x640 1 11, 1 12, 1 13, 1 14, 2 15s, 1 16, 1 17, 1 21, 1 22, 1 23, 1 24, 1 25, 1 26, 1 27, 1 28, 1 31, 1 32, 1 33, 1 34, 1 35, 1 36, 1 37, 1 41, 1 42, 1 43, 1 44, 1 45, 1 46, 2 47s, 140.1ms
Speed: 2.3ms preprocess, 46.7ms inference, 1.9ms postprocess per image at shape (1, 3, 640, 640)


In [4]:
for file in temp_dir.glob('*.jpg'):
    os.remove(file)

for result in results:
    filename = Path(result.path).stem
    result.save_crop(temp_dir, filename)

all_dir = list(temp_dir.glob('*'))

for im_path in list(temp_dir.glob('**/*.jpg')):
    filename = im_path.stem
    tooth_number = str(im_path).split('/')[-2]

    src_path = im_path
    dst_path = temp_dir / f'{filename}-{tooth_number}.jpg'

    shutil.move(src_path, dst_path)

for my_dir in all_dir:
    if os.path.isdir(my_dir):
        os.rmdir(my_dir)


In [16]:
len(list(temp_dir.glob("*.jpg")))


94

In [5]:
# Vit model loading
model_path = model_dir / 'yolov8-base.pt'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

vit_model = timm.create_model('swin_base_patch4_window7_224_in22k', num_classes=6)
vit_model.load_state_dict(torch.load(model_path, map_location=device))
vit_model.to(device)
vit_model.eval()


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
# Preprocess
transform = transforms.Compose([
    transforms.ToTensor(),
    # (lambda image: padding_to_size(image, 224)),
    transforms.Resize(size=(224, 224)),
    transforms.Normalize(mean=0.5, std=0.5),
])
target_transform = transforms.Compose([
    (lambda y: torch.Tensor(y)),
])
dataset = ImageDataset(temp_dir, transform=transform)

if torch.cuda.is_available():
    dataloader = create_loader(dataset, (3, 224, 224), 4)
else:
    dataloader = create_loader(dataset, (3, 224, 224), 4, use_prefetcher=False)

size = len(dataloader.dataset)


In [7]:
threshold = torch.Tensor([0.5, 0.85, 0.5, 0.5, 0.5, 0.5]).to(device)
pred_encodes = []
# target_labels = ['caries', 'endo', 'post', 'crown']
target_labels = ['R.R', 'caries', 'crown', 'endo', 'filling', 'post']
# target_labels = ['caries', 'crown', 'endo', 'filling', 'post']
with torch.no_grad():
    for batch, (X, _) in enumerate(dataloader):
        X = X.to(device)

        # Compute prediction error
        pred = vit_model(X)
        pred_encode = pred > threshold
        pred_encodes.append(pred_encode.cpu().numpy())

pred_encodes = np.vstack(pred_encodes)
pred_encodes = pred_encodes[:, 1:]
detected_list = [()] * len(pred_encodes)
for im_path, pred_encode in enumerate(pred_encodes):
    detected_list[im_path] = tuple((target_labels[j] for j, checker in enumerate(pred_encode) if checker))


In [12]:
pred_encodes


array([[False, False, False, False, False],
       [ True, False, False, False, False],
       [ True, False, False, False, False],
       [ True, False, False, False, False],
       [ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False,  True, False],
       [False, False,  True, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False],
       [ True, False, False, False, False],
       [ True, False, False, False, False],
       [False, False,  True, False, False],
       [ True, False, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, Fal