# In [None]:
import numpy as np
import cv2
import torch
from enum import Enum
from PIL import Image
import torchvision.models as models
import torch.nn as nn
import matplotlib.pyplot as plt

def draw_grid(image):
    """Return a copy of ``image`` with two vertical and two horizontal black lines.

    The lines sit at one third and two thirds of the width/height, splitting
    the picture into a 3x3 grid. The input array is not modified.

    Args:
        image: array whose shape unpacks into exactly three values
            (height, width, channels); anything else raises ValueError.

    Returns:
        A new array with the grid lines (colour (0, 0, 0), thickness 5) drawn.
    """
    height, width, _channels = image.shape
    cell_h = height // 3
    cell_w = width // 3
    canvas = np.copy(image)

    black = (0, 0, 0)
    thickness = 5
    # Vertical lines first, then horizontal ones (same order as before).
    for x in (cell_w, 2 * cell_w):
        cv2.line(canvas, (x, 0), (x, height), black, thickness)
    for y in (cell_h, 2 * cell_h):
        cv2.line(canvas, (0, y), (width, y), black, thickness)
    return canvas

def vibration_matrix(mask):
    """Summarise a binary mask as a colour-coded 3x3 grid image.

    The mask is divided into a 3x3 grid of cells of size
    ``(HEIGHT // 3, WIDTH // 3)``. When a dimension is not divisible by 3,
    the leftover rows/columns at the bottom/right are not counted.

    Each cell is coloured by how many of its pixels are exactly 1:
    - below one fifth of the cell area: filled with ``(0, 255, 0)``,
      green when shown as BGR by OpenCV;
    - otherwise: filled with ``(0, 0, 255)``, red in BGR.

    The per-cell counts are printed one grid row per line, top row first.

    Args:
        mask: 2-D array; only pixels exactly equal to 1 are counted.

    Returns:
        A float64 array of shape ``(HEIGHT, WIDTH, 3)`` holding the filled
        cells with grid lines drawn on top by ``draw_grid``.
    """
    HEIGHT, WIDTH = mask.shape
    cell_h = HEIGHT // 3
    cell_w = WIDTH // 3

    # A cell switches colour once at least a fifth of its pixels are set.
    COUNT_THRESH = (cell_h * cell_w) / 5
    # Indexed [row][col] so the printed matrix matches the on-screen layout
    # (previously it was indexed [col][row] and printed transposed).
    freq_matrix = np.zeros((3, 3))

    output = np.zeros((HEIGHT, WIDTH, 3))

    for row in range(3):
        starty = cell_h * row
        endy = starty + cell_h
        for col in range(3):
            startx = cell_w * col
            endx = startx + cell_w

            box = mask[starty:endy, startx:endx]
            freq_matrix[row][col] = np.sum(box == 1)

            clr = (0, 255, 0) if freq_matrix[row][col] < COUNT_THRESH else (0, 0, 255)
            cv2.rectangle(output, (startx, starty), (endx, endy), clr, thickness=cv2.FILLED)

    output = draw_grid(output)
    for grid_row in freq_matrix:
        print(grid_row)

    return output


class ModelType(Enum):
    """MiDaS checkpoints selectable by this script.

    Each member's value is the entry-point name handed to ``torch.hub.load``
    when the model is constructed.
    """

    DPT_LARGE = "DPT_Large"
    DPT_HYBRID = "DPT_Hybrid"
    MIDAS_SMALL = "MiDaS_small"
    
class Midas:
    """Wrapper around a MiDaS depth-estimation model loaded via ``torch.hub``.

    Intended sequence:
    1. Construct the object (loads the model).
    2. Optionally call ``useCUDA()`` to pick a device.
    3. Call ``transform()`` to load the matching input transform and ``THRESH``.
    4. Call ``predict(frame)`` per frame.

    ``predict`` also works if steps 2-3 were skipped. It then runs on CPU
    and loads the transform lazily.
    """

    # One hub repo spelling for both the model and its transforms, so the
    # repository is downloaded/cached only once.
    HUB_REPO = "intel-isl/MiDaS"

    def __init__(self, modelType: ModelType = ModelType.DPT_LARGE):
        self.midas = torch.hub.load(self.HUB_REPO, modelType.value)
        self.modelType = modelType
        # Threshold callers compare the depth map against; set by transform().
        self.THRESH = 0
        # Defaults so predict() does not fail if useCUDA() was never called.
        self.device = torch.device("cpu")
        self.midas.eval()
        # Input transform loaded by transform(); kept separate from the method
        # name so transform() can be called more than once.
        self._input_transform = None

    def useCUDA(self):
        """Move the model to CUDA when available (else CPU) and set eval mode."""
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.midas.to(self.device)
        self.midas.eval()

    def transform(self):
        """Load the hub input transform and depth threshold for ``modelType``.

        Both DPT variants use ``dpt_transform``; any other model uses
        ``small_transform``. Safe to call repeatedly.
        """
        midas_transforms = torch.hub.load(self.HUB_REPO, "transforms")
        if self.modelType.value == "DPT_Large":
            self._input_transform = midas_transforms.dpt_transform
            self.THRESH = 26
        elif self.modelType.value == "DPT_Hybrid":
            self._input_transform = midas_transforms.dpt_transform
            self.THRESH = 1900
        else:
            self._input_transform = midas_transforms.small_transform
            self.THRESH = 850

    def predict(self, frame):
        """Estimate a depth map for a BGR frame.

        Args:
            frame: image in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB`` before the transform).

        Returns:
            A numpy array resized (bicubic) to the frame's height and width,
            then squeezed. For a single frame this is presumably ``(H, W)``;
            confirm the transform yields a batch of one.
        """
        if self._input_transform is None:
            self.transform()
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_batch = self._input_transform(img).to(self.device)
        with torch.no_grad():
            prediction = self.midas(input_batch)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depthMap = prediction.cpu().numpy()
        return depthMap

def process_frame(frame, midasObj):
    """Segment a frame, estimate depth, and flag close segmented pixels.

    This function uses the module-level names ``model`` (segmentation
    network) and ``device``, which are not defined in this cell.
    NOTE(review): confirm they are defined before this is called.

    Args:
        frame: BGR image, e.g. from ``cv2.VideoCapture.read``.
        midasObj: a ``Midas`` instance; its ``THRESH`` should already be set
            by ``transform()``.

    Returns:
        ``(frame, output)``:
        - ``frame`` is the 256x256 resized frame, with flagged pixels painted
          red in BGR.
        - ``output`` is the grid image from ``vibration_matrix``.
    """
    # Resize the frame to 256x256 (this is a new array; the caller's frame is untouched).
    frame = cv2.resize(frame, (256, 256))

    # HWC -> 1xCxHxW float tensor. NOTE(review): pixel values are not
    # normalised here; presumably this matches how `model` expects input.
    input_image = torch.tensor(frame).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Inference only: no_grad avoids building an autograd graph every frame.
    with torch.no_grad():
        predicted_mask = model(input_image).squeeze(0).cpu().numpy()
    predicted_mask = np.uint8((predicted_mask > 0.9) * 255)

    # Keep a 2-D mask: take the first channel if the model returned several.
    seg_mask = predicted_mask[0] if predicted_mask.ndim == 3 else predicted_mask

    # Depth prediction on the (unpainted) resized frame.
    depth_map = midasObj.predict(frame)

    # Pixels inside the segmentation whose depth value exceeds the threshold.
    hit = (seg_mask > 0) & (depth_map > midasObj.THRESH)

    # The frame is BGR, so (0, 0, 255) is red (the original [255, 0, 0] was blue).
    frame[hit] = (0, 0, 255)

    output = vibration_matrix(hit.astype(np.uint8))

    return frame, output


def run_on_webcam(modelType: ModelType):
    """Run segmentation + depth on webcam 0 and display the results.

    Shows two windows per frame: the processed frame and the vibration-matrix
    image. The loop stops when 'q' is pressed or a frame cannot be read.
    The capture is released and the windows are closed even if processing
    raises.

    Args:
        modelType: which MiDaS checkpoint to load.
    """
    midasObj = Midas(modelType)
    midasObj.useCUDA()
    midasObj.transform()

    cap = cv2.VideoCapture(0)
    try:
        if not cap.isOpened():
            print("Could not open webcam (device 0)")
            return

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Process the frame for segmentation and depth prediction
            processed_frame, vibration_output = process_frame(frame, midasObj)

            # Display the processed frame and the vibration matrix output
            cv2.imshow('Processed Frame', processed_frame)
            cv2.imshow('Vibration Matrix', vibration_output)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close windows, including on exceptions.
        cap.release()
        cv2.destroyAllWindows()

if __name__ == "__main__":
    # Script entry point: run the webcam demo with the DPT_Hybrid checkpoint.
    selected_model = ModelType.DPT_HYBRID
    run_on_webcam(selected_model)