In [None]:
!pip install transformers
!pip install albumentations
!pip install torchmetrics

### IMPORTING LIBRARIES

In [None]:
## STANDARD LIBRARY AND PIPELINE MODULES IMPORT


##IMPORTING IMAGE LOADING AND PLOTTING LIBS
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.spatial import distance as dist

## IMPORTING TORCH FOR ML-PART
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import torch.nn.functional as F

## IMPORT ALBUMENTATION AS DATA AUGMENTATION LIBRARY
import albumentations as A
from albumentations.pytorch import ToTensorV2

## IMPORT UNET STRUCTURE FOR SEMI-SUPERVISED SEGMENTATION
import segmentation_models_pytorch as smp

## IMPORTING CRAFT DETECTOR RELATED MODULES FOR MONITOR-TEXT DETECTION
from craft_text_detector import (
    read_image,
    load_craftnet_model,
    load_refinenet_model,
    get_prediction
)

## IMPORT OCR RELATED MODULES
from transformers import TrOCRProcessor, VisionEncoderDecoderModel


### SOME UTILITIES

In [None]:

## Device every model and input tensor is moved to (CPU inference)
DEVICE = "cpu"

## (height, width) the frame is resized to before segmentation
SEG_IN_IMAGE_HEIGHT, SEG_IN_IMAGE_WIDTH = 320, 640

## (height, width) of the rectified monitor image fed to the classifier
CLASS_IN_IMAGE_HEIGHT, CLASS_IN_IMAGE_WIDTH = 360, 640

## Inference-time preprocessing: per-channel normalisation (values match the
## commonly used ImageNet statistics), then conversion to a torch tensor.
test_transform = A.Compose(
    [
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(),
    ]
)

### LOADING INDIVIDUAL PIPELINE COMPONENTS

#### 1. MONITOR SEGMENTATION

In [None]:
## Build the U-Net used for monitor segmentation.
## encoder_weights=None: load_state_dict below (strict by default) overwrites every
## parameter, so fetching the ImageNet encoder weights first would only add a
## download that fails on machines without network access.
seg_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
)

## Load the checkpoint from our semi-supervised training.
## map_location remaps tensors saved on another device (e.g. a GPU) onto DEVICE;
## without it torch.load tries to restore them onto the device they were saved from.
seg_model.load_state_dict(torch.load("weights/unet.ckpt", map_location=DEVICE))
seg_model = seg_model.to(DEVICE)  ## adding to device


## function for predicting mask and upsampling it to original image size
def maskPred(img):
    """Predict the monitor mask for an image and scale it back to the input size.

    Args:
        img: H x W x 3 image array (inference() passes an RGB image read by OpenCV).

    Returns:
        uint8 array of shape (H, W). The model output is thresholded at 0 into a
        0/255 mask at the segmentation resolution and then resized to the input
        resolution with cv2.resize's default interpolation, so edge pixels may
        hold intermediate values.
    """
    seg_model.eval()
    orig_H, orig_W = img.shape[:2]

    # Resize to the segmentation input size, then normalise and convert to a tensor.
    resized = cv2.resize(img, (SEG_IN_IMAGE_WIDTH, SEG_IN_IMAGE_HEIGHT))
    batch = test_transform(image=resized)["image"].unsqueeze(0).to(DEVICE)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = seg_model(batch)

    # Positive logit -> foreground. Convert explicitly to a numpy uint8 0/255 mask.
    binary = (logits.squeeze() > 0).cpu().numpy().astype(np.uint8) * 255

    # Upsample the mask to the original image size.
    return cv2.resize(binary, (orig_W, orig_H))




#### 2. PERSPECTIVE TRANSFORMATION

In [None]:
## cell containing functions for improving perception and orientation of extracted monitor segments

def order_points(pts):
    """Order four corner points as top-left, top-right, bottom-right, bottom-left.

    Args:
        pts: array of shape (4, 2) holding (x, y) coordinates.

    Returns:
        float32 array of shape (4, 2) in the order [tl, tr, br, bl].
    """
    # Split into the two left-most and the two right-most points by x.
    by_x = pts[np.argsort(pts[:, 0])]
    left_pair, right_pair = by_x[:2], by_x[2:]

    # Of the left pair, the smaller y is the top-left corner.
    top_left, bottom_left = left_pair[np.argsort(left_pair[:, 1])]

    # Of the right pair, the point farther from top-left is the bottom-right corner.
    gaps = dist.cdist(top_left[np.newaxis], right_pair, "euclidean")[0]
    bottom_right, top_right = right_pair[np.argsort(gaps)[::-1]]

    return np.array(
        [top_left, top_right, bottom_right, bottom_left], dtype="float32"
    )


## function for correcting the perspective (based on the principle of contour detection and convex hull)
def correctPerspective(data, mask):
    """Warp the segmented monitor so that it fills an axis-aligned rectangle.

    The mask is cleaned with an erode/dilate pass and re-thresholded. The largest
    external contour is taken as the monitor, and its convex hull is simplified to
    four corners that drive two perspective warps of `data`.

    Args:
        data: source image array (inference() passes the H x W x 3 RGB frame).
        mask: single-channel mask for `data` (inference() passes maskPred's output).

    Returns:
        (orig_op, shrink_op): the monitor warped to the size of `data`, and warped
        to CLASS_IN_IMAGE_WIDTH x CLASS_IN_IMAGE_HEIGHT.
        If the cleaned mask contains no contour at all, the frame is used as-is:
        (copy of data, data resized to the classifier size). Previously this case
        returned None implicitly and broke tuple-unpacking callers.
    """
    h, w = CLASS_IN_IMAGE_HEIGHT, CLASS_IN_IMAGE_WIDTH
    orig_h, orig_w = data.shape[:2]

    # Morphological clean-up of the mask, then re-binarise it.
    kernel = np.ones((20, 20), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=10)
    mask = cv2.dilate(mask, kernel, iterations=11)
    _, mask = cv2.threshold(mask, 70, 255, cv2.THRESH_BINARY)
    mask = mask.astype(np.uint8)

    # findContours returns 2 values or 3 depending on the OpenCV version.
    found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = found[0] if len(found) == 2 else found[1]
    if len(cnts) == 0:
        # Nothing survived the clean-up: fall back to the unwarped frame.
        return (data.copy(), cv2.resize(data, (w, h)))

    # Only the largest contour is treated as the monitor.
    hull = cv2.convexHull(max(cnts, key=cv2.contourArea), False)
    perimeter = cv2.arcLength(hull, True)

    # Loosen the approximation tolerance step by step until the hull collapses
    # to exactly four vertices.
    corners = None
    for ep in range(1, 5):
        approx = cv2.approxPolyDP(hull, ep * 0.01 * perimeter, True)
        if len(approx) == 4:
            corners = approx.reshape(4, 2)
            break
    if corners is None:
        # No 4-vertex approximation: use the minimum-area bounding rectangle
        # instead of indexing a polygon that may have fewer or more than 4 points.
        corners = cv2.boxPoints(cv2.minAreaRect(hull))

    pt1 = order_points(np.float32(corners))
    pt2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    orig_pt = np.float32([[0, 0], [orig_w, 0], [orig_w, orig_h], [0, orig_h]])

    orig_matrix = cv2.getPerspectiveTransform(pt1, orig_pt)
    matrix = cv2.getPerspectiveTransform(pt1, pt2)
    orig_op = cv2.warpPerspective(data, orig_matrix, (orig_w, orig_h))
    shrink_op = cv2.warpPerspective(data, matrix, (w, h))

    return (orig_op, shrink_op)

#### 3. MONITOR CLASSIFICATION

In [None]:
## cell containing resnet-18 for classifying 4 different monitor layouts

class Classifier:
    """Thin wrapper around a ResNet-18 whose final layer has `num_classes` outputs."""

    def __init__(self, num_classes):
        """Store the class count; the network itself is built by loadResnet18Classifier."""
        self.num_classes = num_classes
        self.model = None
        self.device = DEVICE

    def loadResnet18Classifier(self, chkpt_path):
        """Build the ResNet-18 and load its weights.

        Args:
            chkpt_path: path to a checkpoint readable by torch.load that holds
                the weights under the "model_state_dict" key.
        """
        self.model = models.resnet18()
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, self.num_classes)
        checkpoint = torch.load(chkpt_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()  # this wrapper is only used for inference
        return

    def predict(self, img):
        """Return the index of the predicted class for one image.

        Args:
            img: image array accepted by `test_transform` (inference() passes the
                warped RGB monitor image).

        Returns:
            numpy integer: index of the highest-scoring class.

        Raises:
            RuntimeError: if loadResnet18Classifier has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("call loadResnet18Classifier() before predict()")

        self.model.eval()
        batch = test_transform(image=img)["image"].unsqueeze(0).to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(batch)

        # Softmax is monotonic, so the arg-max of the raw logits is the predicted class.
        preds = outputs.argmax(dim=1)
        return preds.cpu().numpy()[0]
    
_LAYOUT_CLASSIFIER = None  ## built on first use, then shared by every call


def classification(img):
    """Classify the monitor layout of `img` into one of 4 classes.

    The classifier is created and its weights are read from disk only on the
    first call; later calls reuse it instead of reloading the checkpoint.

    Args:
        img: image passed on to Classifier.predict.

    Returns:
        Whatever Classifier.predict returns (the predicted class index).
    """
    global _LAYOUT_CLASSIFIER
    if _LAYOUT_CLASSIFIER is None:
        co = Classifier(4)
        co.loadResnet18Classifier("weights/resnet18_weights")
        _LAYOUT_CLASSIFIER = co
    return _LAYOUT_CLASSIFIER.predict(img)

#### 4. MONITOR TEXT DETECTION

In [None]:
## cell for executing function predicting texts bounding boxes
_CRAFT_MODELS = None  ## (refine_net, craft_net), loaded on first use and reused


def get_boxes(image):
    """Detect text regions in `image` with the CRAFT detector.

    Args:
        image: image passed to craft_text_detector.get_prediction (inference()
            passes the perspective-corrected monitor image).

    Returns:
        The 'boxes' entry of the prediction result.
    """
    global _CRAFT_MODELS

    # One flag for model loading and prediction. The original loaded the models
    # with cuda=False but predicted with cuda=True, which cannot work on a
    # CPU-only DEVICE.
    use_cuda = str(DEVICE).startswith("cuda")

    # Loading the networks is expensive; do it once and reuse them afterwards.
    if _CRAFT_MODELS is None:
        _CRAFT_MODELS = (
            load_refinenet_model(cuda=use_cuda),
            load_craftnet_model(cuda=use_cuda),
        )
    refine_net, craft_net = _CRAFT_MODELS

    prediction_result = get_prediction(
        image=image,
        craft_net=craft_net,
        refine_net=refine_net,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=use_cuda,
        long_size=1280
    )
    return prediction_result['boxes']

#### 5. OCR

In [None]:
## Load the TrOCR processor and encoder-decoder model from the same checkpoint
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-printed')
ocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-printed')

## Move the OCR model to DEVICE. The original read `model.to(DEVICE)`, but no
## `model` is defined in this notebook, so it raised NameError on a fresh kernel.
ocr_model = ocr_model.to(DEVICE)

def trOCR(img, locs):
    """Read the text inside one box of `img` with TrOCR.

    Args:
        img: image array indexed as [row, column, ...].
        locs: (x_center, y_center, width, height) of the box, each given as a
            fraction of the image width / height.

    Returns:
        str: the decoded text, or "" when the box does not overlap the image.
    """
    img_h, img_w = img.shape[:2]
    x, w = locs[0] * img_w, locs[2] * img_w
    y, h = locs[1] * img_h, locs[3] * img_h

    # Clamp the box to the image. A negative start index would otherwise wrap
    # around in numpy slicing and select the wrong (usually empty) region.
    top, bottom = max(int(y - h / 2), 0), min(int(y + h / 2), img_h)
    left, right = max(int(x - w / 2), 0), min(int(x + w / 2), img_w)
    if bottom <= top or right <= left:
        return ""

    crop = img[top:bottom, left:right]

    # The processor already returns a tensor; move it instead of re-wrapping it
    # in torch.tensor(), which copies and emits a warning.
    pixel_values = processor(crop, return_tensors="pt").pixel_values.to(DEVICE)

    generated_ids = ocr_model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

#### 6. RULE BASED CLASSIFICATION OF VITALS

### FINAL INFERENCE FUNCTION

In [None]:
TEST_IMG_DIR = "test_data"
SAMPLE_IMG_INDEX = 9  ## position (in sorted order) of the image used as the sample

## os.listdir returns entries in arbitrary order; sort so that the same index
## selects the same file on every run and machine.
imlis = sorted(os.listdir(TEST_IMG_DIR))
if not imlis:
    raise FileNotFoundError(f"no files found in '{TEST_IMG_DIR}'")

## Clamp the index so a directory with fewer files still yields a sample path.
imloc = f"{TEST_IMG_DIR}/{imlis[min(SAMPLE_IMG_INDEX, len(imlis) - 1)]}"

def inference(image_path):
    '''
    Function responsible for inference.
    Args:
      image_path: str, path to image file. eg. "input/aveksha_micu_mon--209_2023_1_17_12_0_34.jpeg"
    Returns:
      result: dict, final output dictionary. eg. {"HR":"80", "SPO2":"98", "RR":"15", "SBP":"126", "DBP":"86"}
      NOTE: the OCR / vitals-mapping stage is not implemented yet, so the
      dictionary is currently returned empty.
    Raises:
      FileNotFoundError: image_path does not point to an existing file.
      ValueError: the file exists but OpenCV could not decode it as an image.
    '''
    result = {}

    # cv2.imread signals failure by returning None instead of raising; check
    # explicitly so the error names the offending path.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"image file not found: {image_path}")
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"could not decode image: {image_path}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # 1. segment the monitor, 2. rectify it
    mask = maskPred(image)
    (detector_io, class_io) = correctPerspective(image, mask)

    # 3. monitor layout class, 4. text boxes on the full-size rectified image
    label = classification(class_io)
    bounding_boxes = get_boxes(detector_io)

    # TODO: run OCR on each entry of `bounding_boxes` and use `label` to map the
    # recognised text to the vitals keys of `result` (stages 5 and 6).

    return result