In [None]:
import ast
import operator
from collections import Counter

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from model import SimpleCNN
from utils.preprocess import Preprocess

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 11 outputs, one per entry of class_names (digits 0-9 plus the 'ok' gesture).
model = SimpleCNN(num_classes=11).to(device)
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint saved
# from a CUDA session fails to load on the CPU fallback chosen above.
model.load_state_dict(torch.load('simple_cnn_1220_epoch30.pth', map_location=device))
# Inference mode: switches layers such as dropout / batch-norm (if SimpleCNN has any)
# to their evaluation behaviour.
model.eval()

# Per-frame preprocessing before the classifier: resize the PIL image to
# (height=240, width=320), then convert it to a tensor.
_frame_pipeline = [transforms.Resize((240, 320)), transforms.ToTensor()]
transform = transforms.Compose(_frame_pipeline)

# Classifier output index -> gesture label. The order is not sorted; it presumably
# mirrors the label indices used when the checkpoint was trained, so do not reorder.
class_names = '8 5 4 9 ok 1 7 6 3 2 0'.split()

# Open the default camera (index 0) and create the hand-segmentation helper.
cap = cv2.VideoCapture(0)
processor = Preprocess()

if not cap.isOpened():
    print("Error: Could not open video stream.")
    # `exit()` is an interactive convenience injected by the `site` module. It may be
    # absent (e.g. `python -S`, frozen apps) and exits with status 0. Raising
    # SystemExit(1) always works and signals the failure to the caller.
    raise SystemExit(1)

# Recent digit-classifier outputs; the main loop empties it whenever the class changes.
predictions = []
# Equation text built from recognized digits and touched operator blocks.
equation = ''
# Frame countdown for a displayed result: negative while no result is pending,
# set to 120 after '=' and decremented once per frame.
count_limit = -1

# Per operator block: frames counted toward its 60-frame trigger, and the symbol
# last recorded for that block (None until it is first touched).
_operator_keys = ('+', '-', '*', '/', 'c')
block_stability = dict.fromkeys(_operator_keys, 0)
last_predicted_symbol = dict.fromkeys(_operator_keys)

# Operator "buttons" in the top-right corner of the mirrored frame: 50x50 squares
# placed 60 px apart. Holding the hand over one for about _STABLE_FRAMES consecutive
# frames appends that operator to the equation ('c' clears the equation instead).
symbols = ['+', '-', '*', '/', 'c']
square_size = 50
left_top = [(330 + 60 * k, 10) for k in range(len(symbols))]
right_bottom = [(x + square_size, y + square_size) for x, y in left_top]

_STABLE_FRAMES = 60   # frames a digit gesture / operator touch must persist to register
_RESULT_FRAMES = 120  # frames the "=result" stays up before the equation is cleared

_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _evaluate_equation(text):
    """Return the value of an integer +, -, *, / expression, or None if it is invalid.

    Used instead of eval(): the equation comes from camera gestures, so it can be
    empty, end in an operator, divide by zero, or contain '**' (a doubled '*').
    With eval() those either crash the loop or start an arbitrarily slow
    exponentiation. Only integer literals and the operators above are accepted.
    """
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](walk(node.operand))
        raise ValueError(f'unsupported expression: {text!r}')

    try:
        return walk(ast.parse(text, mode='eval'))
    except (SyntaxError, ValueError, ZeroDivisionError, OverflowError):
        return None


# try/finally so the camera and windows are released even if the loop raises
# (including Ctrl-C, which the old bare `except:` used to swallow).
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        num_enable = True  # cleared when the hand is on an operator block
        frame = cv2.flip(frame, 1)  # mirror so on-screen motion matches the user's

        try:
            preprocessed_img = processor.hsv_segmentation(frame)
            preprocessed_img = processor.largest_connected_component(preprocessed_img)
            preprocessed_img = processor.gray_level(preprocessed_img)
        except Exception:
            # No hand in view: connected-component labelling can fail. Still refresh
            # the window and poll the keyboard, otherwise the display freezes and
            # 'q' cannot quit while no hand is visible.
            cv2.imshow('Video', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            continue

        for sym, lt, rb in zip(symbols, left_top, right_bottom):
            cv2.rectangle(frame, lt, rb, (255, 255, 255), 2)
            cv2.putText(frame, sym, (lt[0] + 10, lt[1] + 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA)

            (x1, y1), (x2, y2) = lt, rb
            # Any non-zero pixel of the segmented hand inside the square is a touch.
            if not np.any(preprocessed_img[y1:y2, x1:x2] != 0):
                # Hand left this block: forget it so the next touch starts a fresh
                # streak (frames from separate touches must not add up to a trigger).
                last_predicted_symbol[sym] = None
                continue

            num_enable = False  # hand is on an operator: skip digit recognition
            if last_predicted_symbol[sym] == sym:
                block_stability[sym] += 1
            else:
                block_stability[sym] = 0
            last_predicted_symbol[sym] = sym

            if block_stability[sym] == _STABLE_FRAMES:
                equation = '' if sym == 'c' else equation + sym
                block_stability = dict.fromkeys(symbols, 0)

        # Digit recognition runs only when no operator is touched and no result is shown.
        if num_enable and count_limit < 0:
            image = cv2.cvtColor(preprocessed_img, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            image = transform(image).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(image)
                _, predicted = torch.max(outputs, 1)
                predicted_class = class_names[predicted.item()]

            cv2.putText(frame, predicted_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

            # Keep only an unbroken run of identical predictions.
            if predictions and predicted_class != predictions[-1]:
                predictions = []
            predictions.append(predicted_class)

            # A gesture registers after _STABLE_FRAMES identical predictions in a row.
            if len(predictions) == _STABLE_FRAMES:
                most_common_class, _ = Counter(predictions).most_common(1)[0]
                if most_common_class == 'ok':
                    result = _evaluate_equation(equation)
                    most_common_class = '=' + ('Error' if result is None else str(result))
                    count_limit = _RESULT_FRAMES
                equation += most_common_class
                predictions = []

        cv2.putText(frame, f'Equation: {equation}', (10, 420), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow('Video', frame)

        # Result countdown: clear the finished equation once it has been shown long enough.
        count_limit -= 1
        if count_limit == 0:
            equation = ''

        # Press q to exit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()

