In [1]:
import cv2
from PIL import Image
import clip
import torch
import numpy as np
import math
import pandas
import plotly.express as px
import datetime
from IPython.core.display import HTML
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths to the example video clips, one variable per clip.
# Appliances
microwave = "./examples/micro.mp4"
kettle = "./examples/kettle.mp4"
light = "./examples/light_switch.mp4"
# Cabinets
slide_cabinet = "./examples/slide_cabinet.mp4"
hinge_cabinet = "./examples/hinge_cabinet.mp4"
# Burners
bottom_burner = "./examples/bottom_burner.mp4"
top_burner = "./examples/top_burner.mp4"

In [3]:
# all_tasks = [microwave, kettle, light, slide_cabinet, hinge_cabinet, bottom_burner, top_burner]

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP model onto that device, together with the image
# preprocessing callable that is applied to each frame before encoding.
model, preprocess = clip.load("ViT-B/32", device=device)

In [6]:
def ask_task(task_name, prompt, plot=True, show_frames=False, friend_frames=1, batch_size=24):
    """Score every frame of a video against a text prompt with CLIP.

    The video is decoded frame by frame, each frame is embedded with
    ``model.encode_image`` and compared with the embedding of ``prompt``
    from ``model.encode_text``. Scores are soft-maxed across frames, so for
    each prompt they sum to 1 over the whole clip.

    Note:
        ``ask_task`` is defined again in a later cell of this notebook; after
        a top-to-bottom run the later definition is the one bound to the name.

    Args:
        task_name: Path of the video, handed to ``cv2.VideoCapture``.
        prompt: Text (or list of texts) handed to ``clip.tokenize``.
        plot: When true, show a plotly heat-map and a matplotlib line plot of
            the per-frame scores.
        show_frames: When true, display the best-matching frames (for the
            first prompt) with their timestamps.
        friend_frames: Number of best-matching frames to select. Clamped to
            the number of frames actually decoded.
        batch_size: Number of frames encoded per forward pass.

    Returns:
        None. Results are printed / displayed.

    Raises:
        ValueError: If no frame could be read from ``task_name``.
    """
    import html  # local import: only needed to escape the link target below

    frames = []
    video_cv2 = cv2.VideoCapture(task_name)
    try:
        fps = video_cv2.get(cv2.CAP_PROP_FPS)
        while True:
            # read() already advances to the next frame, so no explicit
            # CAP_PROP_POS_FRAMES seek is needed between reads.
            ret, frame = video_cv2.read()
            if not ret:
                break
            # OpenCV yields BGR; reverse the channel axis to get RGB for PIL.
            frames.append(Image.fromarray(frame[:, :, ::-1]))
    finally:
        # Always free the decoder, even if decoding raised.
        video_cv2.release()

    print(f"Frames extracted: {len(frames)}")
    if not frames:
        raise ValueError(f"No frames could be read from {task_name!r}")

    video_batch = math.ceil(len(frames) / batch_size)
    feature_batches = []
    for i in range(video_batch):
        print(f"Processing batch {i+1}/{video_batch}")
        frame_batch = frames[i * batch_size:(i + 1) * batch_size]
        preprocess_batch = torch.stack([preprocess(frame) for frame in frame_batch]).to(device)
        with torch.no_grad():
            # Cast to float32 so the result does not depend on whether the
            # model runs in half precision on the current device.
            batch_features = model.encode_image(preprocess_batch).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        feature_batches.append(batch_features)

    # Concatenate once; this also avoids hard-coding the embedding width.
    image_features = torch.cat(feature_batches)
    print(f"Features: {image_features.shape}")

    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(prompt).to(device)).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Normalise over the frame axis (dim=0). With a single prompt the score
    # matrix has one column, so a softmax over the last axis would turn every
    # entry into 1.0 and make the top-k selection meaningless.
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=0)
    top_k = min(friend_frames, len(frames))
    _, index_frame = similarity.topk(top_k, dim=0)

    if plot:
        fig = px.imshow(similarity.cpu().numpy(), aspect='auto', color_continuous_scale='cividis')
        fig.update_layout(coloraxis_showscale=True)
        fig.update_xaxes(showticklabels=True)
        fig.update_yaxes(showticklabels=True)
        # fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        fig.show()
        plt.plot(similarity.cpu().numpy())
        print()

    if show_frames:
        # Column 0 holds the ranking for the first (or only) prompt.
        for frame_index in index_frame[:, 0].tolist():
            display(frames[frame_index])
            # Guard against containers that report an FPS of 0.
            seconds = round(frame_index / fps) if fps > 0 else 0
            # The original markup had no href and an unbalanced quote, so the
            # anchor was inert. Point it at the video with a media-fragment
            # time offset instead.
            # NOTE(review): whether this relative path resolves depends on how
            # the notebook front-end serves files - confirm.
            link = html.escape(f"{task_name}#t={seconds}", quote=True)
            display(HTML(f'{datetime.timedelta(seconds=seconds)} (<a target="_blank" href="{link}">link</a>)'))
    

In [19]:
def ask_task(task_name, prompt, plot=True, show_frames=False, friend_frames=1, batch_size=24):
    """Score every frame of a video against a text prompt with CLIP.

    The video is decoded frame by frame, each frame is embedded with
    ``model.encode_image`` and compared with the embedding of ``prompt``
    from ``model.encode_text``. Scores are soft-maxed across frames, so for
    each prompt they sum to 1 over the whole clip.

    Note:
        This re-defines ``ask_task`` from an earlier cell of this notebook and
        replaces that earlier binding when the cell runs.

    Args:
        task_name: Path of the video, handed to ``cv2.VideoCapture``.
        prompt: Text (or list of texts) handed to ``clip.tokenize``.
        plot: When true, show a plotly heat-map and a matplotlib line plot of
            the per-frame scores.
        show_frames: When true, display the best-matching frames (for the
            first prompt) with their timestamps.
        friend_frames: Number of best-matching frames to select. Clamped to
            the number of frames actually decoded.
        batch_size: Number of frames encoded per forward pass.

    Returns:
        None. Results are printed / displayed.

    Raises:
        ValueError: If no frame could be read from ``task_name``.
    """
    import html  # local import: only needed to escape the link target below

    frames = []
    video_cv2 = cv2.VideoCapture(task_name)
    try:
        fps = video_cv2.get(cv2.CAP_PROP_FPS)
        while True:
            # read() already advances to the next frame, so no explicit
            # CAP_PROP_POS_FRAMES seek is needed between reads.
            ret, frame = video_cv2.read()
            if not ret:
                break
            # OpenCV yields BGR; reverse the channel axis to get RGB for PIL.
            frames.append(Image.fromarray(frame[:, :, ::-1]))
    finally:
        # Always free the decoder, even if decoding raised.
        video_cv2.release()

    print(f"Frames extracted: {len(frames)}")
    if not frames:
        raise ValueError(f"No frames could be read from {task_name!r}")

    video_batch = math.ceil(len(frames) / batch_size)
    feature_batches = []
    for i in range(video_batch):
        print(f"Processing batch {i+1}/{video_batch}")
        frame_batch = frames[i * batch_size:(i + 1) * batch_size]
        preprocess_batch = torch.stack([preprocess(frame) for frame in frame_batch]).to(device)
        with torch.no_grad():
            # Cast to float32 so the result does not depend on whether the
            # model runs in half precision on the current device.
            batch_features = model.encode_image(preprocess_batch).float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        feature_batches.append(batch_features)

    # Concatenate once; this also avoids hard-coding the embedding width.
    image_features = torch.cat(feature_batches)
    print(f"Features: {image_features.shape}")

    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize(prompt).to(device)).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Normalise over the frame axis (dim=0). With a single prompt the score
    # matrix has one column, so a softmax over the last axis would turn every
    # entry into 1.0 and make the top-k selection meaningless.
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=0)
    top_k = min(friend_frames, len(frames))
    _, index_frame = similarity.topk(top_k, dim=0)

    if plot:
        fig = px.imshow(similarity.cpu().numpy(), aspect='auto', color_continuous_scale='cividis')
        fig.update_layout(coloraxis_showscale=True)
        fig.update_xaxes(showticklabels=True)
        fig.update_yaxes(showticklabels=True)
        # fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
        fig.show()
        plt.plot(similarity.cpu().numpy())
        print()

    if show_frames:
        # Column 0 holds the ranking for the first (or only) prompt.
        for frame_index in index_frame[:, 0].tolist():
            display(frames[frame_index])
            # Guard against containers that report an FPS of 0.
            seconds = round(frame_index / fps) if fps > 0 else 0
            # The original markup had no href and an unbalanced quote, so the
            # anchor was inert. Point it at the video with a media-fragment
            # time offset instead.
            # NOTE(review): whether this relative path resolves depends on how
            # the notebook front-end serves files - confirm.
            link = html.escape(f"{task_name}#t={seconds}", quote=True)
            display(HTML(f'{datetime.timedelta(seconds=seconds)} (<a target="_blank" href="{link}">link</a>)'))
    

In [8]:
# Decode every frame of the kettle clip into a list of RGB PIL images.
frames = []
video_cv2 = cv2.VideoCapture(kettle)
frame_cv2 = video_cv2.get(cv2.CAP_PROP_FPS)  # frame rate reported for the clip
current_frame = 0  # number of frames decoded so far
try:
    while True:
        # read() already advances to the next frame, so no explicit
        # CAP_PROP_POS_FRAMES seek is needed between reads.
        ret, frame = video_cv2.read()
        if not ret:
            break
        # OpenCV yields BGR; reverse the channel axis to get RGB for PIL.
        frames.append(Image.fromarray(frame[:, :, ::-1]))
        current_frame += 1
finally:
    # Free the decoder once all frames are in memory (or decoding failed).
    video_cv2.release()

print(f"Frames extracted: {len(frames)}")

Frames extracted: 70


In [16]:
# Preprocess each extracted frame and stack the resulting tensors into one
# batch on the target device. torch.stack works on the tensors directly, so
# there is no need for a NumPy round-trip (and it matches what ask_task does).
image_input = torch.stack([preprocess(frame) for frame in frames]).to(device)

In [17]:
# Embed the whole frame batch with the CLIP image encoder, without tracking
# gradients, and cast the result to float32.
with torch.no_grad():
    image_features = model.encode_image(image_input).to(torch.float32)

In [18]:
# Scale every embedding row to unit L2 length, in place.
image_features /= image_features.norm(p=2, dim=-1, keepdim=True)

In [20]:
# Query the kettle clip for the three frames that best match the prompt,
# showing both the similarity plots and the selected frames.
ask_task(
    task_name=kettle,
    prompt="push the kettle",
    friend_frames=3,
    show_frames=True,
    plot=True,
)

: 

: 