In [None]:
!pip install albumentations
!pip install torchmetrics
!pip install paddleocr>=2.0.1 paddlepaddle
!pip install opencv-python-headless==4.5.3.56
!pip install segmentation_models_pytorch

### IMPORTING LIBRARIES

In [36]:
## STANDARD LIBRARY AND PIPELINE MODULES IMPORT


##IMPORTING IMAGE LOADING AND PLOTTING LIBS
import os
import re
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.spatial import distance as dist

## IMPORTING TORCH FOR ML-PART
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import torch.nn.functional as F

## IMPORT ALBUMENTATION AS DATA AUGMENTATION LIBRARY
import albumentations as A
from albumentations.pytorch import ToTensorV2

## IMPORT UNET STRUCTURE FOR SEMI-SUPERVISED SEGMENTATION
import segmentation_models_pytorch as smp

## IMPORTING PADDLEOCR RELATED MODULES FOR MONITOR-TEXT DETECTION AND OCR
from paddleocr import PaddleOCR

## stop logging
import logging
logger = logging.getLogger('paddle')
logger.disabled = True


### SOME UTILITIES

In [2]:

DEVICE = "cpu"  ##using cpu for inference 

##input image dimension to be used for segmenation
SEG_IN_IMAGE_HEIGHT = 320
SEG_IN_IMAGE_WIDTH = 640

##input image dimension to be used for classification
CLASS_IN_IMAGE_HEIGHT = 360
CLASS_IN_IMAGE_WIDTH = 640

##input image dimension to be used for most ocr
OCR_IN_IMAGE_HEIGHT = 180
OCR_IN_IMAGE_WIDTH = 320


## LOADING TRANSFORMATION TO BE APPLIED TO IMAGE (AT INFERENCE TIME)

test_transform = A.Compose([
    A.Normalize (mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p = 1.0), 
    ToTensorV2()
])

### LOADING INDIVIDUAL PIPLEINE COMPONENTS

#### 1. MONITOR SEGMENTATION

In [4]:
## loading unet model
seg_model = smp.Unet(
    encoder_name="resnet34",        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=1,
)

seg_model = seg_model.to(DEVICE) ## adding to device
seg_model.load_state_dict(torch.load("weights/unet.ckpt", map_location=DEVICE)) ## loading model checkpoints (from our semi-supervised training)


## function for predicting mask and upsampling it to original image size
def maskPred(img):
    seg_model.eval()
    (orig_H, orig_W, _) = img.shape
    
    #resizing it to desired segmentation input
    img = cv2.resize(img, (SEG_IN_IMAGE_WIDTH, SEG_IN_IMAGE_HEIGHT))
    ## TRANSFORMING IMAGE
    img = test_transform(image=img)["image"]

    ## RUNNING THROUGH MODEL
    mask = seg_model(img.to(DEVICE).unsqueeze(0))
    mask = mask.cpu().squeeze()
    
    ## RESIZING MASK (USAMPLING IT TO ORIGINAL SIZE)
    mask = cv2.resize(np.uint8(mask>0)*255, (orig_W, orig_H))
    return mask




#### 2. PERSPECTIVE TRANSFORMATION

In [9]:
def order_points(pts):
	
    xSorted = pts[np.argsort(pts[:, 0]), :]
        
    leftMost = xSorted[:2, :]
    rightMost = xSorted[2:, :]
        
    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    (tl, bl) = leftMost
        
    D = dist.cdist(tl[np.newaxis], rightMost, "euclidean")[0]
    (br, tr) = rightMost[np.argsort(D)[::-1], :]
        
    return np.array([tl, tr, br, bl], dtype="float32")


def correctPerspective(data, mask):

    kernel = np.ones((20,20), np.uint8)  
    mask = cv2.erode(mask, kernel, iterations=10)  
    mask = cv2.dilate(mask, kernel, iterations=11)  

    _, mask = cv2.threshold(mask, 70, 255, cv2.THRESH_BINARY)

    mask = mask.astype(np.uint8)

    cnts = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    cnts = sorted(cnts, key=cv2.contourArea, reverse=True)

    for count in cnts:
        cnt = cv2.convexHull(count, False)
        epsilon = 0.01 * cv2.arcLength(cnt, True)
        approximations = cv2.approxPolyDP(cnt, epsilon, True)
        for ep in range(1, 5):
            epsilon = ep*0.01 * cv2.arcLength(cnt, True)
            approximations = cv2.approxPolyDP(cnt, epsilon, True)
            if len(approximations) == 4:
                break
        # img = cv2.drawContours(img, [approximations], 0, (0, 255, 0), 2)
        
        h, w = CLASS_IN_IMAGE_HEIGHT, CLASS_IN_IMAGE_WIDTH
        orig_h, orig_w = 720,1280
        h_, w_ = OCR_IN_IMAGE_HEIGHT, OCR_IN_IMAGE_WIDTH
        pt1 = np.float32([approximations[0][0], approximations[1][0], approximations[2][0], approximations[3][0]])
        pt1 = order_points(pt1)
        pt2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
        orig_pt = np.float32([[0, 0], [orig_w, 0], [orig_w, orig_h], [0, orig_h]])
        es_pt = np.float32([[0, 0], [w_, 0], [w_, h_], [0, h_]])
        matrix = cv2.getPerspectiveTransform(pt1, pt2)
        orig_matrix = cv2.getPerspectiveTransform(pt1, orig_pt)
        es_matrix = cv2.getPerspectiveTransform(pt1, es_pt)
        orig_op = cv2.warpPerspective(data, orig_matrix, (orig_w, orig_h))
        shrink_op = cv2.warpPerspective(data, matrix, (w, h))
        es_op = cv2.warpPerspective(data, es_matrix, (w_, h_))

        return (es_op, shrink_op, orig_op)


#### 3. MONITOR CLASSIFICATION

In [41]:
class Classifier:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.model = None
    
    def loadResnet18Classifier(self, chkpt_path):
        self.model = models.resnet18()
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, self.num_classes)
        checkpoint = torch.load(chkpt_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(DEVICE)
        self.model.eval()
        return

    def predict(self, img):
        img = test_transform(image=img)["image"]
        
        ## RUNNING THROUGH MODEL
        img = img.to(DEVICE).unsqueeze(0)
        outputs = self.model(img)
        _, preds = torch.max(F.softmax(outputs, dim = 1), 1)

        return preds.cpu().numpy()[0]

co = Classifier(4)
co.loadResnet18Classifier("weights/resnet18_weights")    

def classification(img):
    return co.predict(img)
    

#### 4. MONITOR TEXT DETECTION AND OCR

In [33]:
## cell for executing function predicting texts bounding boxes using paddleocr
ocr_m = PaddleOCR(lang='en', use_gpu=(DEVICE=="gpu"), det_db_box_thresh=0.6, drop_score = 0.4) # need to run only once to download and load model into memory

def get_text(crop_img, det=False):
    return ocr_m.ocr(crop_img, det=det)


[2023/02/06 00:48:59] ppocr DEBUG: Namespace(alpha=1.0, benchmark=False, beta=1.0, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/drobot/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, crop_res_save_dir='./output', det=True, det_algorithm='DB', det_box_type='quad', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/drobot/.paddleocr/whl/det/en/en_PP-OCRv3_det_infer', det_pse_box_thresh=0.85, det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_score_thresh=0.5, draw_img_save_dir='./inference_results', drop_score=0.4, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_score_thresh=0.5, e2e_p

#### 6. RULE BASED CLASSIFCATION OF VITALS 

In [8]:
class Rules:   

    def __init__(self):
        pass

    def _txt_len(self, text):
        ##avoiding string containing long texts as it will be generally something else
        if len(text)>7:
            return False
        return True

    def _is_alpha(self, text):
        ##avoiding string containing alphabets
        if re.search('\d+', text)==None:
            return False
        return True
    def _pre_check(self, text):
        ## pre checking for valid characters
        valid = True
        valid = (self._txt_len(text) and valid)
        valid = (self._is_alpha(text) and valid)
        return valid

    def BPRule(self, text):
        sbp = 0
        dbp = 0
        pos = text.find("/")
        # print("POS - ", pos)
        if pos!=-1:
            sbp = text[max(pos-3, 0):pos]
            dbp = text[pos+1:pos+4]
            return True, [sbp, dbp]
        else:
            return False, None

    def check_green(self, cropped):
        normalizedImg = np.zeros((cropped.shape[0], cropped.shape[1]))
        cropped = cv2.normalize(cropped,  normalizedImg, 0, 255, cv2.NORM_MINMAX)
        hsv = cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV)
        mask = cv2.inRange(hsv, (40, 40, 50), (80, 255,255))
        imask = mask>0
        area = cropped.shape[0]*cropped.shape[1]
        thres = 0.16
#         print(imask.sum()/area)
        if (imask.sum()/area) > thres and area > self.min_area:    
            return True
        return False
    
    def check_map(self, text):
        if re.search('\(.*\d+\)', text)==None:
            return False
        return True
    
    def check_yellow(self, cropped):
        normalizedImg = np.zeros((cropped.shape[0], cropped.shape[1]))
        cropped = cv2.normalize(cropped,  normalizedImg, 0, 255, cv2.NORM_MINMAX)
        hsv = cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV)
        mask = cv2.inRange(hsv, (20, 50, 50), (40, 255, 255))
        imask = mask>0
        area = cropped.shape[0]*cropped.shape[1]
#         print(imask.sum()/area)
        thres = 0.15
        if (imask.sum()/area) > thres and area > self.min_area:    
            return True
        return False
    def check_cyan(self, cropped):
        normalizedImg = np.zeros((cropped.shape[0], cropped.shape[1]))
        cropped = cv2.normalize(cropped,  normalizedImg, 0, 255, cv2.NORM_MINMAX)
        hsv = cv2.cvtColor(cropped, cv2.COLOR_RGB2HSV)
        mask = cv2.inRange(hsv, (80, 50, 60), (100, 255,255))
        imask = mask>0
        area = cropped.shape[0]*cropped.shape[1]
        thres = 0.15        
        if (imask.sum()/area) > thres and area > self.min_area:    
            return True
        return False
    
    def class_pred(self, text, cropped, label):
        self.min_area = 300
        if label == 1:
            self.min_area = 700
        valid = self._pre_check(text)
        
        if valid:
            hr = self.check_green(cropped)
            map_check = self.check_map(text)
            yellow = self.check_yellow(cropped)
            cyan = self.check_cyan(cropped)
            bp = self.BPRule(text)
            if bp[0]:
                return "BP", [("SBP", bp[1][0]), ("DBP", bp[1][1])]
            elif hr:
                return "HR", [text]
            # elif map_check:
            #     return "MAP", int(text[1:len(text)-1])
            elif yellow:
                return "yellow", (text)
            elif cyan:
                return 'cyan', (text)
            else:
                return ""     
        else:
            return ""

def find_nearby(slash_box, boxes):
    X, Y, W, H, _ = slash_box
    ans_index = -1
    min_del_y = 1e5
    for i in range(len(boxes)):
        x,y,w,h, _ = boxes[i]
        if x < X:
            if min_del_y > abs(y-Y):
                min_del_y = abs(y-Y)
                ans_index = i
    return ans_index

ru = Rules()

### FINAL INFERENCE FUNCTION

In [34]:
def inference(image_path):

    '''
    Function responsible for inference.
    Args: 
      image_path: str, path to image file. eg. "input/aveksha_micu_mon--209_2023_1_17_12_0_34.jpeg"
    Returns:
      result: dict, final output dictionary. eg. {"HR":"80", "SPO2":"98", "RR":"15", "SBP":"126", "DBP":"86"}
    '''
    result = {}

    ### put your code here
    
    ## EVALUATION ON EACH IMAGE

    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    ocr_io = cv2.resize(image, (320, 180))
    class_io = cv2.resize(image, (640, 360))


    # mask = maskPred(image)
    # # print(mask)
    # (ocr_io, class_io, orig_io) =  correctPerspective(image, mask)
    label = classification(class_io) 

    #label = 2
    if label==1:
        candidates =  get_text(class_io, det=True)
        ocr_io = class_io
    elif label == 3:
        kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
        ocr_io = cv2.filter2D(ocr_io, -1, kernel)
        
        candidates =  get_text(ocr_io[:, :115], det=True)
    else:
        # print("here")
        candidates =  get_text(ocr_io, det=True)
    can = candidates[0]
    # plt.imshow(ocr_io[:, :115])
    # plt.show()
    hr_boxes = []

    boxes = []
    yellow_boxes = []
    slash_box = []
    for line in can:
        bbox = line[0]
        txt = line[1][0]
        confidence = line[1][1]
        
        x1, y1, x2, y2 = int(bbox[0][0]),int(bbox[0][1]) , int(bbox[2][0]),int(bbox[2][1])
        cropped = ocr_io[y1: y2, x1:x2]
        # print(x1, y1,x2,y2)
        # print(txt)
        # plt.imshow(cropped)
        # plt.show()
        pred = ru.class_pred(txt, cropped, label)
        if len(re.findall("^/\d+", txt.strip()))>0:
            slash_box = [(x1+x2)/2, (y1+y2)/2, x2-x1,y2-y1, txt]
        else:
            boxes.append([(x1+x2)/2, (y1+y2)/2, x2-x1,y2-y1, txt ])
        if len(pred)>0:
            clas = pred[0]
            if(clas=="BP"):
                result[pred[1][0][0]] = pred[1][0][1]
                result[pred[1][1][0]] = pred[1][1][1]
            elif(clas == 'HR'):
                hr_boxes.append([cropped.shape[0]*cropped.shape[1], [x1,y1,x2,y2], pred])
            elif(clas == 'cyan'):
                if label == 1:
                    result['RR'] = pred[1]
                elif label==0 or label == 2: 
                    result['SPO2'] = pred[1]
            elif(clas=='yellow'):
                if label == 1:
                    yellow_boxes.append([(y1+y2)/2, cropped.shape[0]*cropped.shape[1], [x1,y1,x2,y2], txt])
                elif label == 0 or label == 2:
                    result['RR'] = txt
                elif label == 3:
                    result['SPO2'] = txt
                    
    #         print(txt ,"-", len(pred))
    yellow_boxes = sorted(yellow_boxes, key=lambda x: x[0], reverse=True)
    hr_boxes = sorted(hr_boxes, key=lambda x: x[0], reverse=True)

    if label == 1 and len(yellow_boxes) > 0:
        result['SPO2'] = yellow_boxes[0][3]

    if len(hr_boxes)>0:
        x1,y1,x2,y2 = hr_boxes[0][1]
        hr = (hr_boxes[0][2][1][0])
        a = re.findall('^\d+', hr)
        hr = int(a[0])
        if int(a[0][:3]) > 220:
            hr = int(a[0][:2])
        elif int(a[0]) > 300:
            hr = int(a[0][:3])
        result['HR'] = hr
    #         plt.imshow(ocr_io[y1: y2, x1:x2])
    #         plt.show()

    if len(slash_box) > 0:
        index = find_nearby(slash_box, boxes)
        x, y, w, h, pred = boxes[index]
        
        result['SBP'] = pred.strip()
        result['DBP'] = slash_box[-1][1:]

    boxes = sorted(boxes, key= lambda x: x[1], reverse=True)
    if label == 3:
        for (x,y,w,h,pred) in boxes:
            if w*h > 200:
                if pred[1] == 'T':
                    pred = pred[0] + '7'
                result['RR'] = pred
                break



    # print(result)
        
    return result

In [38]:
%%time
img_classes = [ 'BPL-EliteView-EV10-B_Meditec-England-A', 'BPL-EliteView-EV100-C' , 'BPL-Ultima-PrimeD-A' , 'Nihon-Kohden-lifescope-A']
TEST_IMG_DIR = f"val/{img_classes[0]}"
imlis = os.listdir(TEST_IMG_DIR)
imloc = f"{TEST_IMG_DIR}/{imlis[5]}"

inference(imloc)

[2023/02/06 00:58:55] ppocr DEBUG: dt_boxes num : 12, elapse : 0.04934191703796387
[2023/02/06 00:58:56] ppocr DEBUG: rec_res num  : 12, elapse : 0.5578830242156982
CPU times: user 1.26 s, sys: 44.7 ms, total: 1.31 s
Wall time: 699 ms


{'SBP': '115', 'DBP': '70', 'SPO2': '95', 'RR': '17', 'HR': 87}