In [1]:
import os
# Set in the first cell, before any other import runs, so the flag is already in the
# process environment. Presumably read by the YOLO-based detector pulled in through
# utils.tracking to silence its per-frame logging — TODO confirm.
os.environ['YOLO_VERBOSE'] = 'False'

In [2]:
import sys
import cv2
from collections import defaultdict

In [3]:
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

In [4]:
# Inference is pinned to the CPU. To use a GPU when one is available, switch to
# torch.device('cuda:0' if torch.cuda.is_available() else 'cpu').
device = torch.device("cpu")

In [5]:
# Load the pickled classifier onto the selected device.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only
# load checkpoints from a trusted source. Recent PyTorch releases default to
# weights_only=True, which rejects fully pickled models — confirm the installed version.
model = torch.load('models/model.pt', map_location=device)
# Switch to inference mode so dropout / batch-norm layers behave deterministically;
# assigning the result keeps the notebook from echoing the module repr.
model = model.eval()

In [6]:
from IPython.display import display, Image as IMG  # Import IPython display functionality
import ipywidgets as widgets  # Import the ipywidgets library for creating interactive widgets
import threading  # Import the threading library for multithreading

In [7]:
# Toggle button the user flips to request that the streaming loop stop.
stopButton = widgets.ToggleButton(
    description='Stop',      # label rendered on the button
    icon='square',           # icon rendered next to the label
    button_style='danger',   # "danger" styling (red)
    tooltip='Description',   # hover text
    value=False,             # starts un-pressed
    disabled=False,          # starts clickable
)

In [8]:
# Maps the classifier's output index to a human-readable activity label.
# The list order defines the index, so entries must not be reordered.
class_names = dict(enumerate([
    'calling',
    'clapping',
    'cycling',
    'dancing',
    'drinking',
    'eating',
    'fighting',
    'hugging',
    'laughing',
    'listening_to_music',
    'running',
    'sitting',
    'sleeping',
    'texting',
    'using_laptop',
]))

In [9]:
from utils.streaming import VStream
from utils.tracking import track

In [10]:
# Shared video-stream object; later cells call vstream.read() for frames and
# vstream.capture.release() when the stop button is pressed.
vstream = VStream()
# Presumably begins grabbing frames from the capture device — confirm in utils.streaming.
vstream.start()

In [11]:
# Preprocessing applied to each PIL crop before it is fed to the classifier:
# centre-crop to 256x256, convert to a float tensor, then normalise per channel.
# The mean/std values match the commonly used ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

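Define a decorator that repeatedly calls the wrapped frame-producing function, draws the predicted labels on each frame, and displays the resulting video stream until the stop button is pressed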

In [12]:
def showStream(func):
    """Decorator that turns a single-frame producer into a live display loop.

    ``func`` must return ``(button, frame, state, coords)`` where ``button``
    exposes a boolean ``value``, ``frame`` is an image array accepted by
    ``cv2.putText`` / ``cv2.imencode``, ``state`` is a sequence of keys into
    ``class_names`` and ``coords`` the matching text anchor points (paired
    positionally with ``state``).

    The wrapped function is called repeatedly; each frame is annotated,
    JPEG-encoded and pushed to a notebook display. When the button's value
    becomes truthy the capture is released, the output is cleared and the
    loop exits.

    Returns:
        The wrapper, which itself returns the display handle once stopped.
    """
    def inner(*args, **kwargs):
        # Display handle that is updated in place for every frame.
        display_handle = display(None, display_id=True)
        # Remember the button from the last successful call so the loop can
        # still be stopped if later calls to ``func`` start failing.
        button = None
        while True:
            try:
                button, the_frame, state, coords = func(*args, **kwargs)
                for st, coord in zip(state, coords):
                    cv2.putText(the_frame, class_names[st], coord,
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, 1)
                _, jpeg = cv2.imencode('.jpg', the_frame)  # Encode the frame as JPEG
                display_handle.update(IMG(data=jpeg.tobytes()))  # Refresh the image
            except Exception:
                # Best-effort streaming: a bad frame or failed inference must not
                # kill the loop; fall through to the stop check and try again.
                pass

            if button is not None and button.value:
                # Stop requested: free the camera, clear the output and leave
                # the loop so the worker thread can actually finish.
                vstream.capture.release()
                cv2.destroyAllWindows()
                display_handle.update(None)
                break
        return display_handle
    return inner

Grab the latest frame from the stream, run the tracker on it, and classify the activity in a crop around each detected box

In [13]:
@showStream
def getStream(button, transform, device, model):
    """Produce one annotated-frame payload for the display loop.

    Reads the current frame from ``vstream``, runs ``track`` on it, and for
    every returned box classifies a 256x256 window centred on the box.

    Args:
        button: Stop control; passed through unchanged to the caller.
        transform: Callable applied to the PIL crop to obtain the model input.
        device: Torch device the input tensor is moved to.
        model: Classifier called on a batch of one image.

    Returns:
        ``(button, current_frame, state, coords)`` where ``state[i]`` is the
        predicted class index for the detection whose top-left corner is
        ``coords[i]``. Detections that cannot be classified are omitted from
        both lists so the two always stay aligned.
    """
    half = 128  # half the side length of the classification window
    coords, state = [], []
    current_frame = vstream.read()
    boxes = track(current_frame, persist=False)
    # assumes track() returns an (N, 4) tensor of x1, y1, x2, y2 — TODO confirm
    detections = boxes.cpu().numpy()
    nrows, _ = detections.shape
    for i in range(nrows):
        x1, y1, x2, y2 = detections[i].astype(int)
        try:
            orgx, orgy = int((x1 + x2) / 2), int((y1 + y2) / 2)
            # Clamp the window start at 0: a negative start would wrap around
            # to the far edge of the frame and yield an empty crop.
            top, left = max(orgy - half, 0), max(orgx - half, 0)
            crop_img = current_frame[top:orgy + half, left:orgx + half]
            if crop_img.size == 0:
                continue

            # NOTE(review): if vstream yields OpenCV BGR frames, the classifier
            # sees swapped channels here — confirm whether an RGB conversion is needed.
            image = Image.fromarray(crop_img)
            img = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(img)
                _, st = torch.max(outputs, 1)
            label = st[0].item()
        except Exception:
            # Best-effort: skip a detection that cannot be cropped or classified
            # instead of dropping the whole frame.
            continue
        # Append both together so labels never shift onto the wrong box.
        coords.append((int(x1), int(y1)))
        state.append(label)
    return button, current_frame, state, coords

In [None]:
# Show the stop control, then run the streaming loop on a worker thread so the
# notebook stays responsive while frames are being displayed.
display(stopButton)
stream_args = (stopButton, transform, device, model)
thrd = threading.Thread(target=getStream, args=stream_args)
thrd.start()

In [15]:
# is_alive is a method: without the call the bound-method object is always
# truthy and this cell would report 'Still running!' even after the thread ended.
if thrd.is_alive():
    print('Still running!')
else:
    print('Completed')