In [1]:
# Fetch the MedShieldAI repository into the current working directory.
!git clone https://github.com/Karmast/MedShieldAI.git

Cloning into 'MedShieldAI'...
remote: Enumerating objects: 40, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 40 (delta 9), reused 10 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (40/40), 5.27 MiB | 19.69 MiB/s, done.
Resolving deltas: 100% (9/9), done.


In [2]:
!pip install ultralytics
!pip install torch
!pip install torchvision
!pip install torchaudio
!pip install easyocr
!pip install transformers
!pip install sklearn
!pip install gradio

Collecting ultralytics
  Downloading ultralytics-8.3.70-py3-none-any.whl.metadata (35 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8.0->ultralytics)
  Downloading nv

In [3]:
# Work from inside the cloned repository so that relative paths used later
# (e.g. "./model/best.pt") resolve against it.
%cd /content/MedShieldAI

/content/MedShieldAI


In [4]:
import cv2
import numpy as np
from ultralytics import YOLO
import easyocr
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
import torch
import re

class StampDetector:
    """Detect stamps on a prescription image and compare their text with the rest of the page.

    The pipeline is:
      1. run a YOLO model to obtain stamp bounding boxes,
      2. OCR every (padded) stamp crop with EasyOCR,
      3. OCR the page with the stamp regions blacked out,
      4. score each stamp text against the page text using the cosine
         similarity of mean-pooled ``bert-base-multilingual-cased`` embeddings.
    """

    # Pixels added on every side of a detected box before cropping.
    margin = 10
    # Cosine similarity strictly above this value is labelled "Similar".
    similarity_threshold = 0.4

    def __init__(self, model_path, margin=None, similarity_threshold=None):
        """Load the detection, OCR and embedding models.

        Args:
            model_path: Path to the YOLO weights file.
            margin: Optional override for the crop margin in pixels (default 10).
            similarity_threshold: Optional override for the "Similar" cut-off
                (default 0.4).
        """
        self.model = YOLO(model_path)
        self.reader = easyocr.Reader(['fa'])  # Persian language
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        self.model_bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        if margin is not None:
            self.margin = margin
        if similarity_threshold is not None:
            self.similarity_threshold = similarity_threshold

    def detect_and_process_stamps(self, image):
        """Detect stamps, OCR them, and compare each with the prescription text.

        Args:
            image: The input image; anything ``np.array`` can convert to an
                array indexed as ``[row, column, ...]``.

        Returns:
            A list with one dict per detected stamp, holding ``stamp_id``
            (0-based, unique per stamp), ``text``, ``similarity_percentage``
            and ``confidence``, followed by a final dict
            ``{'prescription_text': ...}``.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("No image provided.")
        image = np.array(image)
        height, width = image.shape[:2]

        # Run YOLO inference
        results = self.model(image)
        image_results = []  # One entry per detected stamp
        stamps_boxes = []  # Padded bounding boxes of the detected stamps

        for result in results:
            for box in result.boxes.xyxy:  # Extract bounding box coordinates
                x_min, y_min, x_max, y_max = map(int, box)  # Convert to integers

                # Pad the box for better cropping, clamped to the image bounds
                x_min = max(0, x_min - self.margin)
                y_min = max(0, y_min - self.margin)
                x_max = min(width, x_max + self.margin)
                y_max = min(height, y_max + self.margin)

                # Number stamps by detection order. The index of the enclosing
                # `result` would give every stamp on the same image the same id.
                stamp_id = len(stamps_boxes)
                stamps_boxes.append((x_min, y_min, x_max, y_max))

                # OCR the cropped stamp and join the fragments without newlines
                cropped = image[y_min:y_max, x_min:x_max]
                stamp_text = " ".join(self.reader.readtext(cropped, detail=0))

                image_results.append({
                    'stamp_id': stamp_id,
                    'text': stamp_text
                })

        # Extract text excluding stamp regions
        prescription_text = self.extract_prescription_text(image, stamps_boxes)

        # Compare text from each stamp with the prescription text
        for stamp in image_results:
            similarity_percentage, confidence = self.compare_texts(stamp['text'], prescription_text)
            stamp['similarity_percentage'] = similarity_percentage
            stamp['confidence'] = confidence

        image_results.append({
            'prescription_text': prescription_text
        })

        return image_results

    def extract_prescription_text(self, image, stamps_boxes, max_length=500):
        """Extract text from the prescription excluding the stamp regions.

        Args:
            image: Image array indexed as ``[row, column, ...]``.
            stamps_boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes
                to black out before OCR.
            max_length: Maximum number of raw OCR characters kept before
                punctuation is stripped.

        Returns:
            The OCR text with every character that is neither a word character
            nor whitespace removed.
        """
        # Black out the stamp areas on a copy so the caller's array is untouched.
        masked_image = image.copy()
        for (x_min, y_min, x_max, y_max) in stamps_boxes:
            masked_image[y_min:y_max, x_min:x_max] = 0

        # Use EasyOCR to extract text from the remaining part of the image
        text = " ".join(self.reader.readtext(masked_image, detail=0))

        # Truncate overly long text. No "..." marker is appended because the
        # punctuation clean-up below would strip it again anyway.
        text = text[:max_length]

        return re.sub(r"[^\w\s]", "", text)  # Removing special characters

    def _embed(self, text):
        """Return the mean-pooled last-hidden-state BERT embedding of ``text``."""
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            return self.model_bert(**inputs).last_hidden_state.mean(dim=1)

    def compare_texts(self, stamp_text, prescription_text):
        """Compare the text from stamp and prescription.

        Args:
            stamp_text: OCR text of one stamp.
            prescription_text: OCR text of the page outside the stamps.

        Returns:
            A ``(similarity_percentage, confidence)`` tuple: the cosine
            similarity formatted as a percentage string with two decimals, and
            ``"Similar"`` when the similarity exceeds ``similarity_threshold``,
            otherwise ``"Not Similar"``.
        """
        embeddings_stamp = self._embed(stamp_text)
        embeddings_prescription = self._embed(prescription_text)

        similarity = cosine_similarity(embeddings_stamp.cpu().numpy(), embeddings_prescription.cpu().numpy())
        score = similarity[0][0]

        similarity_percentage = f"{score * 100:.2f}%"
        confidence = "Similar" if score > self.similarity_threshold else "Not Similar"

        return similarity_percentage, confidence


# Gradio interface function to handle user input
# Detectors keyed by model path. Building a StampDetector loads the YOLO,
# EasyOCR and BERT models, so each one is built once and reused across requests.
_DETECTOR_CACHE = {}


def process_image(image):
    """Gradio handler: run stamp detection on one uploaded image.

    Args:
        image: The image supplied by the Gradio ``Image`` input.

    Returns:
        The list of result dicts produced by
        ``StampDetector.detect_and_process_stamps``.
    """
    model_path = "./model/best.pt"  # Path to your YOLO model (adjust as needed)

    # Build the detector lazily on the first request, then reuse it.
    detector = _DETECTOR_CACHE.get(model_path)
    if detector is None:
        detector = StampDetector(model_path)
        _DETECTOR_CACHE[model_path] = detector

    return detector.detect_and_process_stamps(image)


# Set up Gradio interface: a single image input whose process_image result is shown as JSON.
# share=True exposes the app through a temporary public gradio.live URL;
# debug=True keeps this cell running in Colab so errors and logs stay visible
# until the cell is interrupted.
gr.Interface(fn=process_image, inputs=gr.Image(), outputs="json").launch(share=True, debug=True)


Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://2ed117e2f17b4814cf.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Progress: |██████████████████████████████████████████████████| 100.0% Complete



Progress: |██████████████████████████████████████████████████| 100.0% Complete

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]


0: 480x640 1 stamp, 72.1ms
Speed: 37.3ms preprocess, 72.1ms inference, 619.9ms postprocess per image at shape (1, 3, 480, 640)
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://2ed117e2f17b4814cf.gradio.live


