In [6]:
import torch
import matplotlib.pyplot as plt
import timm

In [1]:
import cv2


In [13]:

# Progress marker printed when this cell starts executing.
print('load_midas')
# Module-level switch read by Midas.predict: when True, the input image and
# the predicted depth map are shown with matplotlib on every prediction.
PLOT = True

class Midas:
    """Thin wrapper around a MiDaS depth model loaded through ``torch.hub``.

    The model, its matching input transform and the torch device are loaded
    once in the constructor; ``predict`` then turns a BGR image into a depth
    map resized to the image's own height and width.
    """

    # Model loaded when no explicit model type is given.
    # "DPT_BEiT_L_512"  MiDaS v3.1 - Large (For highest quality - 3.2023)
    # "DPT_Large"       MiDaS v3 - Large     (highest accuracy, slowest inference speed)
    # "DPT_Hybrid"      MiDaS v3 - Hybrid    (medium accuracy, medium inference speed)
    # "MiDaS_small"     MiDaS v2.1 - Small   (lowest accuracy, highest inference speed)
    DEFAULT_MODEL_TYPE = "DPT_BEiT_L_512"

    # Pixel positions of the black guide lines drawn over the depth plot.
    # NOTE(review): these look like the region borders used by the depth
    # interpreter elsewhere in this notebook - keep them in sync if so.
    PLOT_Y_LINES = (10, 230)
    PLOT_X_LINES = (22, 112, 272, 362)

    def __init__(self, model_type=DEFAULT_MODEL_TYPE):
        """Load the model, transform and device.

        Args:
            model_type: name of the MiDaS hub entry point to load; defaults to
                ``DEFAULT_MODEL_TYPE``.
        """
        midas, transform, device = self.load_model_midas(model_type)
        self.midas = midas
        self.transform = transform
        self.device = device

    def load_model_midas(self, model_type=DEFAULT_MODEL_TYPE):
        """Load a MiDaS model and the input transform that belongs to it.

        Args:
            model_type: name of the MiDaS hub entry point to load.

        Returns:
            Tuple ``(midas, transform, device)``: the model (moved to the
            device and switched to eval mode), the input transform and the
            ``torch.device`` in use (CUDA when available, otherwise CPU).
        """
        print('load_midas')

        midas = torch.hub.load("intel-isl/MiDaS", model_type)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        midas.to(device)
        midas.eval()
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

        # Each model family needs its own preprocessing; anything not listed
        # here falls back to the small-model transform.
        if model_type in ("DPT_Large", "DPT_Hybrid"):
            transform = midas_transforms.dpt_transform
        elif model_type == "DPT_BEiT_L_512":
            transform = midas_transforms.beit512_transform
        else:
            transform = midas_transforms.small_transform
        return midas, transform, device

    def predict(self, img):
        """Predict a depth map for one image.

        Args:
            img: image array in BGR channel order (it is converted to RGB
                before being passed to the model transform).

        Returns:
            numpy array holding the model prediction, interpolated to the
            image's height and width and squeezed.

        Raises:
            ValueError: if ``img`` is None, e.g. because the image file could
                not be read by the caller.
        """
        if img is None:
            raise ValueError("img is None - the image could not be loaded")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        input_batch = self.transform(img).to(self.device)
        with torch.no_grad():
            prediction = self.midas(input_batch)

            # Resize the model output back to the input image resolution.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        # Convert once and reuse for both plotting and the return value.
        depth = prediction.cpu().numpy()
        if PLOT:
            self._plot_prediction(img, depth)
        return depth

    def _plot_prediction(self, img, depth):
        """Show the RGB input image and the depth map with guide lines."""
        plt.figure()
        plt.imshow(img)
        plt.axis('off')
        plt.show()

        plt.figure()
        plt.imshow(depth)
        plt.axis('off')
        for y in self.PLOT_Y_LINES:
            plt.axhline(y, 0, 1, color='black')
        for x in self.PLOT_X_LINES:
            plt.axvline(x, 0, 1, color='black')
        plt.show()
# Shared model instance for the notebook; constructing it loads the MiDaS
# model and its transforms through torch.hub.
midas = Midas()

load_midas
load_midas


Using cache found in /home/olek/.cache/torch/hub/intel-isl_MiDaS_master
Using cache found in /home/olek/.cache/torch/hub/intel-isl_MiDaS_master


In [14]:
import numpy as np
class MidasInterpreter2:
    """Decides which of three image regions (left, middle, right) are free.

    A depth map is cut into three side-by-side boxes. Each box is reported
    as free only when all three checks pass:

    * the highest mean over ``GROUP_SIZE`` x ``GROUP_SIZE`` tiles is below
      ``MIN_SAFE_DISTANCE``,
    * fewer than ``MAX_GROUP_COUNT`` tiles have a mean above
      ``MIN_SAFE_DISTANCE_MIN``,
    * the mean of the largest ``MEAN_PIXEL_COUNT_RATIO`` share of pixels is
      below ``MIN_SAFE_DISTANCE_MEAN``.

    NOTE(review): larger depth values are treated as "obstacle present";
    confirm this matches the value convention of the depth model in use.
    """

    MIN_SAFE_DISTANCE_MIN = 7250   #20 #7000
    MIN_SAFE_DISTANCE =     7500  #25 #7500
    MIN_SAFE_DISTANCE_MEAN = 7000 # 21  #6000
    RESOLUTION = (384, 384)
    GROUP_SIZE = 10
    MEAN_PIXEL_COUNT_RATIO = 0.1
    MEAN_PIXEL_COUNT = int(RESOLUTION[0] * 0.35 * RESOLUTION[1] * 0.2 * MEAN_PIXEL_COUNT_RATIO)
    Y_BOX_POSITION = (10, 230)#330) # split into 10 - 320 - 54
    X_BOX_POSITION = (22, 112, 272, 362) # split into 22 - 90 - 160 - 90 - 22
    # A box stops counting as free once this many tiles exceed MIN_SAFE_DISTANCE_MIN.
    MAX_GROUP_COUNT = 10

    def __init__(self):
        """Start with every box marked as not free."""
        self.free_boxes = np.array([False, False, False])

    def find_obstacles(self,depth_image):
        """Evaluate the left, middle and right boxes of a depth map.

        Args:
            depth_image: 2-D array indexed as ``[row, column]``; it must be
                large enough to contain ``Y_BOX_POSITION`` / ``X_BOX_POSITION``.

        Returns:
            Boolean numpy array ``[left_free, mid_free, right_free]``; the same
            array is stored in ``self.free_boxes``.
        """
        self.free_boxes = np.array([False, False, False])
        top, bottom = self.Y_BOX_POSITION
        borders = self.X_BOX_POSITION

        depths, counts, averages, free = [], [], [], []
        # Consecutive X borders delimit the left, middle and right boxes.
        for x_start, x_end in zip(borders[:-1], borders[1:]):
            part = depth_image[top:bottom, x_start:x_end]
            depth, count = self.look_for_grouping(part)
            average = self.mean_biggest_values(part)
            depths.append(depth)
            counts.append(count)
            averages.append(average)
            free.append(bool(
                depth < self.MIN_SAFE_DISTANCE
                and count < self.MAX_GROUP_COUNT
                and average < self.MIN_SAFE_DISTANCE_MEAN
            ))

        print(f"Depth prediction:\n"
              f"Mean Distances: left={averages[0]} middle={averages[1]} right={averages[2]}\n"
              f"Max Distances: left={depths[0]} middle={depths[1]} right={depths[2]}\n"
              f"Distances Count: left={counts[0]} middle={counts[1]} right={counts[2]}\n"
              f"is_free: left={free[0]} middle={free[1]} right={free[2]}")

        self.free_boxes = np.array(free)
        return self.free_boxes

    @staticmethod
    def look_for_grouping(array):
        """Scan a 2-D array in ``GROUP_SIZE`` x ``GROUP_SIZE`` tiles.

        Args:
            array: 2-D numpy array.

        Returns:
            Tuple ``(best_mean, count)``: the highest tile mean found (0.0 when
            no tile mean exceeds zero) and the number of tiles whose mean is
            above ``MIN_SAFE_DISTANCE_MIN``.
        """
        step = MidasInterpreter2.GROUP_SIZE
        threshold = MidasInterpreter2.MIN_SAFE_DISTANCE_MIN
        best_mean = 0.0
        count = 0
        for row in range(0, array.shape[0], step):
            for col in range(0, array.shape[1], step):
                # Tiles at the bottom/right edge may be smaller than step x step.
                mean = array[row:row + step, col:col + step].mean()
                if mean > threshold:
                    count += 1
                if mean > best_mean:
                    best_mean = mean
        return best_mean, count

    @staticmethod
    def mean_biggest_values(array, ratio=None):
        """Average the largest values of an array.

        Args:
            array: numpy array of values.
            ratio: share of elements (by count) to average; defaults to
                ``MEAN_PIXEL_COUNT_RATIO``. At least one element is always used.

        Returns:
            Mean of the ``ratio`` largest elements.

        Raises:
            ValueError: if ``array`` is empty.
        """
        if ratio is None:
            ratio = MidasInterpreter2.MEAN_PIXEL_COUNT_RATIO
        flat = np.asarray(array).ravel()
        if flat.size == 0:
            raise ValueError("cannot average the biggest values of an empty array")
        # Clamp to [1, size] so tiny arrays still yield their maximum.
        pixel_count = min(flat.size, max(1, int(flat.size * ratio)))
        # After partitioning at -pixel_count, the last pixel_count entries are
        # the largest values (in no particular order). The slice must be
        # [-pixel_count:]; [pixel_count:] would keep ~90% of the array.
        # NOTE(review): MIN_SAFE_DISTANCE_MEAN may have been tuned against the
        # old, much lower averages - re-check the threshold after this fix.
        biggest = np.partition(flat, -pixel_count)[-pixel_count:]
        return np.average(biggest)
# Shared interpreter instance that classifies depth maps into free / blocked boxes.
midas_interpreter = MidasInterpreter2()

In [16]:
import os
print('get images')
imgs_path = '../../Integration/img/frame/'

# os.listdir returns names in arbitrary order; sort them so that skipping the
# first six entries always skips the same files.
for img in sorted(os.listdir(imgs_path))[6:]:
    frame = cv2.imread(os.path.join(imgs_path, img))
    if frame is None:
        # cv2.imread returns None for files it cannot decode (non-images, broken files).
        print(f'skipping unreadable image: {img}')
        continue
    depth_frame = midas.predict(frame)
    # find_obstacles needs the predicted depth map; calling it without an
    # argument raises TypeError.
    midas_interpreter.find_obstacles(depth_frame)



get images


KeyboardInterrupt: 