<a href="https://colab.research.google.com/github/RyosukeHanaoka/TechTeacher_New/blob/main/vit_eval_mode_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# Mount Google Drive at /content/drive; the model checkpoint and the
# hand-image folders used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
!pip install pyheif rembg torch timm

Collecting pyheif
  Downloading pyheif-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.8/9.8 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rembg
  Downloading rembg-2.0.57-py3-none-any.whl (33 kB)
Collecting onnxruntime (from rembg)
  Downloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m82.2 MB/s[0m eta [36m0:00:00[0m
Collecting pymatting (from rembg)
  Downloading PyMatting-1.1.12-py3-none-any.whl (52 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.0/53.0 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs (from onnxruntime->rembg)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00

In [16]:
import os
import cv2
import torch
import numpy as np
import pyheif
from PIL import Image
from rembg import remove
from torchvision import transforms
import timm
import torch.nn.functional as F

class RheumatoidArthritisDetector:
    """Estimate a per-hand rheumatoid-arthritis probability from hand photographs.

    Pipeline (see ``detect_rheumatoid_arthritis``): convert every file in each
    hand directory to JPEG, remove the background with ``rembg``, mirror the
    left-hand images (presumably so they match the right-hand orientation the
    model was trained on — TODO confirm), run a 2-class ViT
    (``vit_base_patch16_224.augreg_in21k``) on each image and average the
    class-1 probabilities per hand.

    Note that the JPEG conversion and background removal rewrite files inside
    the directories they are given.
    """

    def __init__(self, model_checkpoint, device=None):
        """Load the classifier and build the input transform.

        Args:
            model_checkpoint: Path to a ``state_dict`` file for the ViT.
            device: Torch device string; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.load_model(model_checkpoint)
        # Resize to the ViT's 224x224 input; ToTensor gives [0, 1] and
        # Normalize(0.5, 0.5) maps that to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def load_model(self, model_checkpoint):
        """Create the 2-class ViT, load weights from ``model_checkpoint`` and set eval mode.

        Returns:
            The model, moved to ``self.device`` and in evaluation mode.
        """
        model = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=False, num_classes=2)
        # NOTE(review): torch.load unpickles the file and can execute arbitrary code;
        # only load trusted checkpoints. Since a plain state_dict is expected, consider
        # torch.load(..., weights_only=True) if the installed torch supports it.
        checkpoint = torch.load(model_checkpoint, map_location=self.device)
        model.load_state_dict(checkpoint)
        model.to(self.device)
        model.eval()
        return model

    def convert_heic_to_image(self, heic_path):
        """Decode a HEIC file with ``pyheif`` and return it as a PIL image (file mode kept)."""
        heif_file = pyheif.read(heic_path)
        return Image.frombytes(
            heif_file.mode,
            heif_file.size,
            heif_file.data,
            "raw",
            heif_file.mode,
            heif_file.stride,
        )

    def convert_to_jpg(self, input_directory, output_directory):
        """Save every file in ``input_directory`` as an RGB JPEG in ``output_directory``.

        The output name is the input file's stem plus ``.jpg``. Sub-directories
        are skipped. When both directories are the same, existing ``.jpg`` files
        are re-encoded in place and non-JPEG originals stay next to their copies.

        NOTE: Pillow can write PDFs but cannot decode them, so a ``.pdf`` input
        makes ``Image.open`` raise ``PIL.UnidentifiedImageError``.
        """
        os.makedirs(output_directory, exist_ok=True)
        for filename in sorted(os.listdir(input_directory)):
            file_path = os.path.join(input_directory, filename)
            if not os.path.isfile(file_path):
                continue
            output_file_path = os.path.join(output_directory, os.path.splitext(filename)[0] + ".jpg")

            if filename.lower().endswith(".heic"):
                source = self.convert_heic_to_image(file_path)
            else:
                source = Image.open(file_path)
            try:
                # convert() materialises the pixels (RGBA/P/L -> RGB for JPEG), so the
                # source file can be closed before it is possibly overwritten below.
                rgb = source.convert("RGB")
            finally:
                source.close()
            try:
                rgb.save(output_file_path, "JPEG")
            finally:
                rgb.close()

    def remove_background(self, input_directory):
        """Replace every ``.jpg`` in ``input_directory`` with its ``rembg`` cutout, in place.

        ``remove`` may return a non-RGB (e.g. RGBA) image; it is converted to RGB
        before the file is overwritten.
        """
        for filename in sorted(os.listdir(input_directory)):
            if not filename.lower().endswith(".jpg"):
                continue
            file_path = os.path.join(input_directory, filename)
            with Image.open(file_path) as input_image:
                output_image = remove(input_image).convert("RGB")
            try:
                output_image.save(file_path)
            finally:
                output_image.close()

    def flip_images(self, input_directory):
        """Mirror every ``.jpg`` in ``input_directory`` left-to-right, overwriting the files.

        Not idempotent: a second call restores the original orientation, which is
        why ``detect_rheumatoid_arthritis`` mirrors in memory instead.
        """
        for filename in sorted(os.listdir(input_directory)):
            if not filename.lower().endswith(".jpg"):
                continue
            file_path = os.path.join(input_directory, filename)
            with Image.open(file_path) as img:
                flipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
            try:
                flipped_img.save(file_path)
            finally:
                flipped_img.close()

    def preprocess_image(self, image_path, flip=False):
        """Load one image as a ``(1, 3, 224, 224)`` tensor on ``self.device``.

        Args:
            image_path: Path to the image file.
            flip: When True, mirror the image left-to-right (in memory) first.
        """
        with Image.open(image_path) as img:
            image = img.convert("RGB")  # 3 channels required by the Normalize above
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        return self.transform(image).unsqueeze(0).to(self.device)

    def predict(self, image_tensor):
        """Return the softmax probability of class 1 (rheumatoid arthritis) for the first batch item."""
        with torch.no_grad():
            output = self.model(image_tensor)
            probabilities = F.softmax(output, dim=1)
        return probabilities[0][1].item()  # Rheumatoid arthritis class probability

    def detect_rheumatoid_arthritis(self, right_hand_dir, left_hand_dir):
        """Run the whole pipeline on both hand directories and average per-image probabilities.

        JPEG conversion and background removal rewrite files inside the given
        directories. Left-hand images are mirrored in memory only: flipping the
        files on disk (the previous behaviour) undid itself on every second call,
        so repeated runs produced different predictions.

        Returns:
            ``{"right_hand": float, "left_hand": float}`` — mean class-1 probability per hand.

        Raises:
            ValueError: If either directory contains no ``.jpg`` image.
        """
        self.convert_to_jpg(right_hand_dir, right_hand_dir)
        self.convert_to_jpg(left_hand_dir, left_hand_dir)

        self.remove_background(right_hand_dir)
        self.remove_background(left_hand_dir)

        right_hand_images = self._list_jpgs(right_hand_dir)
        left_hand_images = self._list_jpgs(left_hand_dir)

        if not right_hand_images:
            raise ValueError("右手の画像が存在しません。")
        if not left_hand_images:
            raise ValueError("左手の画像が存在しません。")

        return {
            "right_hand": self._mean_probability(right_hand_dir, right_hand_images, flip=False),
            "left_hand": self._mean_probability(left_hand_dir, left_hand_images, flip=True),
        }

    @staticmethod
    def _list_jpgs(directory):
        """Sorted names of the ``.jpg`` files (case-insensitive) directly inside ``directory``."""
        return sorted(name for name in os.listdir(directory) if name.lower().endswith(".jpg"))

    def _mean_probability(self, directory, filenames, flip):
        """Average ``predict`` over ``filenames`` in ``directory``; ``filenames`` must be non-empty."""
        probabilities = [
            self.predict(self.preprocess_image(os.path.join(directory, name), flip=flip))
            for name in filenames
        ]
        return sum(probabilities) / len(probabilities)

# Usage example: score the right- and left-hand photo folders on the mounted Drive.
MODEL_CHECKPOINT = "/content/drive/MyDrive/OptPhotoFiles/model.pth"
RIGHT_HAND_DIR = "/content/drive/MyDrive/image_righthand"
LEFT_HAND_DIR = "/content/drive/MyDrive/image_lefthand"

detector = RheumatoidArthritisDetector(model_checkpoint=MODEL_CHECKPOINT)
result = detector.detect_rheumatoid_arthritis(
    right_hand_dir=RIGHT_HAND_DIR,
    left_hand_dir=LEFT_HAND_DIR,
)
print(result)


{'right_hand': 0.6743829846382141, 'left_hand': 0.643915057182312}
