In [1]:
from PIL import Image
import torch
import os 
import matplotlib.pyplot as plt
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import ImageDraw
from PIL import ImageFont

# One checkpoint name feeds both loaders so the processor and the detector
# always come from the same pretrained OWL-ViT release.
OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"

# `processor` builds the model inputs from text prompts + image and later
# post-processes the raw outputs; `model` is the detector itself.
processor = OwlViTProcessor.from_pretrained(OWLVIT_CHECKPOINT)
model = OwlViTForObjectDetection.from_pretrained(OWLVIT_CHECKPOINT)



In [2]:
# Base prompts describing the pile-like shapes the detector should look for.
stand_texts = [["a hill", "a mountain", "a peak", "a pile of gravel", "a pyramid", "a triangle of dirt", "a mound of dirt", "a cone", "a mound of trash", "a pile of rocks"]]


def _with_size(phrase, size):
    """Insert `size` right after the leading article: "a hill" -> "a small hill".

    The article is everything before the first space; the remainder of the
    phrase is kept verbatim. (Re-joining the *whole* phrase after the adjective
    would repeat the article, e.g. "a small a hill".)
    """
    article, _, rest = phrase.partition(" ")
    return article + " " + size + " " + rest


# Size-qualified variants of every base prompt.
t_small = [_with_size(text, "small") for text in stand_texts[0]]
t_large = [_with_size(text, "large") for text in stand_texts[0]]

# Prompts actually sent to the detector: one list of queries for one image.
texts = [t_small]  # alternative: [stand_texts[0] + t_small + t_large]


def cutout(image, rows=(0, 1000), cols=(600, 1400)):
    """Crop a rectangular window out of an image.

    Parameters
    ----------
    image : array-like
        Anything `np.array` accepts (a NumPy array, a PIL image, ...). The
        first axis is indexed by `rows`, the second by `cols`.
    rows : tuple of int, optional
        `(start, stop)` slice bounds along the first axis. Defaults to the
        window originally hard-coded here, `(0, 1000)`.
    cols : tuple of int, optional
        `(start, stop)` slice bounds along the second axis. Defaults to
        `(600, 1400)`.

    Returns
    -------
    numpy.ndarray
        The cropped pixels. Normal slice semantics apply, so an input smaller
        than the window yields a correspondingly smaller crop.
    """
    pixels = np.array(image)
    row_start, row_stop = rows
    col_start, col_stop = cols
    return pixels[row_start:row_stop, col_start:col_stop]

# Fonts already loaded, keyed by point size, so the font file is not re-read for every frame.
_STATUS_FONT_CACHE = {}


def _status_font(size=75):
    """Return the font used for the status label, loading it at most once per size."""
    if size not in _STATUS_FONT_CACHE:
        try:
            _STATUS_FONT_CACHE[size] = ImageFont.truetype("arial", size)
        except OSError:
            # Arial is not available on every machine; fall back to PIL's
            # built-in font rather than aborting the whole analysis.
            _STATUS_FONT_CACHE[size] = ImageFont.load_default()
    return _STATUS_FONT_CACHE[size]


def _fill_status(height):
    """Map the y-coordinate of the detected box's top edge to a (message, colour) pair."""
    if height < 50:
        return "Overfull!!", "red"
    if height < 200:
        return "Filling up!", "orange"
    return "All good :)", "green"


def analyse_frame(frame, last_height=(False, 0)):
    """Detect the pile in one frame and annotate the cropped image with its fill status.

    The frame is cropped with `cutout`, run through the OWL-ViT `processor` /
    `model` pair using the module-level `texts` prompts, and the highest-scoring
    box is kept. The y-coordinate of that box's top edge is the "height"
    (smaller value = closer to the top of the crop).

    Parameters
    ----------
    frame : array-like
        Image to analyse; it is handed to `cutout` and then to
        `Image.fromarray`, so channels are interpreted in RGB order.
    last_height : tuple, optional
        `(has_previous, previous_height)`. When `has_previous` is truthy the new
        height is averaged with `previous_height` to smooth the estimate.

    Returns
    -------
    (PIL.Image.Image, float)
        The cropped image with the box and status text drawn on it, and the
        (possibly smoothed) height.
    """
    image = Image.fromarray(cutout(frame))
    inputs = processor(text=texts, images=image, return_tensors="pt")

    # Pure inference: no autograd graph is needed, so skip building one per frame.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processor wants (height, width).
    target_sizes = torch.Tensor([image.size[::-1]])
    # NOTE(review): newer transformers releases appear to deprecate `post_process`
    # in favour of `post_process_object_detection` — confirm before upgrading.
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)

    # Only one image is processed per call, so only results[0] is relevant.
    detections = results[0]
    best = int(torch.argmax(detections["scores"]))
    x0, y0, x1, y1 = detections["boxes"][best].tolist()

    height = y0
    if last_height[0]:
        height = (height + last_height[1]) / 2

    txt, color = _fill_status(height)

    draw = ImageDraw.Draw(image)
    # The smoothed top edge replaces the raw one; clamp it so it can never pass
    # the bottom edge, which would make the rectangle inverted.
    draw.rectangle([x0, min(height, y1), x1, y1], outline=color, width=10)
    draw.text((300, 700), text=txt, fill=color, font=_status_font(75))

    return image, height

# for img in os.listdir("test_imgs"):

#     image = Image.open("test_imgs/" + img)
#     print(image.size)
#     plt.imshow(image)
#     # break

#     # texts = [["a pile", "pile"]]
#     image, _ = analyse_frame(image)
    
    
#     # draw bounding box on image

#     plt.imshow(image)
#     # break
#     plt.show()
#     break

In [3]:
import cv2

def show(video_path='videos/trashpile_combined.mp4', output_path='output_oneshot.avi',
         frame_stride=20, output_fps=3):
    """Analyse a video with `analyse_frame` and write the annotated frames to a new video.

    Parameters
    ----------
    video_path : str, optional
        Video to read.
    output_path : str, optional
        XVID-encoded file to write.
    frame_stride : int, optional
        Only every `frame_stride`-th frame is analysed and written.
    output_fps : int, optional
        Frame rate of the written video.

    Notes
    -----
    * Prints ``actual <shape>`` for each written frame, ``no frame`` when the
      input is exhausted, and ``done`` at the end.
    * No preview window is opened, so there is no key polling or window cleanup.
    """
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    # Opened lazily from the first annotated frame so the writer's size always
    # matches the frames it receives (a size mismatch may drop frames).
    out = None

    frame_index = 0
    last_height = (False, 0)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                print("no frame")
                break

            frame_index += 1
            if frame_index % frame_stride != 0:
                continue

            # OpenCV decodes to BGR, while analyse_frame builds a PIL image
            # (RGB order) for the detector — convert before analysing.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            annotated, height = analyse_frame(rgb_frame, last_height)
            last_height = (True, height)

            # Back from the PIL (RGB) image to an OpenCV (BGR) array for writing.
            frame = cv2.cvtColor(np.array(annotated), cv2.COLOR_RGB2BGR)
            print("actual", frame.shape)

            if out is None:
                frame_rows, frame_cols = frame.shape[:2]
                out = cv2.VideoWriter(output_path, fourcc, output_fps, (frame_cols, frame_rows))
            out.write(frame)
    finally:
        # Release both handles even if analysis fails; releasing the writer is
        # what finalises the output file.
        cap.release()
        if out is not None:
            out.release()

    print("done")


# Run the pipeline with its defaults: read the input video, analyse the sampled
# frames and write the annotated output video.
show()



actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
actual (1000, 800, 3)
no frame
done
