In [1]:
import cv2
import torch
import torch.nn as nn
from torchvision import transforms
from model import SimpleCNN
from PIL import Image
from utils.preprocess import Preprocess
from collections import Counter
import numpy as np

# 設定設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加載模型和權重
model = SimpleCNN(num_classes=11).to(device)
model.load_state_dict(torch.load('simple_cnn_1220_epoch30.pth'))
model.eval()

# 定義轉換
transform = transforms.Compose([
    transforms.Resize((240, 320)),
    transforms.ToTensor(),
])

# 類別名稱列表
class_names = ['8', '5', '4', '9', 'ok', '1', '7', '6', '3', '2', '0']

# 初始化攝影機
cap = cv2.VideoCapture(0)

# 檢查攝影機是否打開
if not cap.isOpened():
    print("Error: Could not open video stream.")
    exit()

# 創建 Preprocess 對象
processor = Preprocess()

# 儲存預測結果的列表
predictions = []
equation = ''
count_limit = -1

# 儲存每個 block 記錄的穩定幀數量
block_stability = {symbol: 0 for symbol in ['+', '-', '*', '/', 'c']}

# 儲存每個 block 的當前判斷結果
last_predicted_symbol = {symbol: None for symbol in ['+', '-', '*', '/', 'c']}

while True:
    # 捕獲視頻幀
    ret, frame = cap.read()
    
    if not ret:
        break

    num_enable = 1 #當檢查到block有非黑色pixel則設為0

    # 反轉幀 (1 表示水平反轉，0 表示垂直反轉，-1 表示水平和垂直反轉)
    frame = cv2.flip(frame, 1)
    # 定義每個正方形的大小和位置
    square_size = 50

    # 定義符號和位置
    symbols = ['+', '-', '*', '/', 'c']
    left_top = [(330, 10), (390, 10), (450, 10), (510, 10), (570, 10)]
    right_bottom = [(380, 60), (440, 60), (500, 60), (560, 60), (620, 60)]

    try:
        # 圖像預處理
        preprocessed_img = processor.hsv_segmentation(frame)
        preprocessed_img = processor.largest_connected_component(preprocessed_img)
        preprocessed_img = processor.gray_level(preprocessed_img)
    except:
        continue #如果沒有物件，跑CCL會出錯

    # 繪製正方形和添加符號
    for i, (lt, rb) in enumerate(zip(left_top, right_bottom)):
        cv2.rectangle(frame, lt, rb, (255, 255, 255), 2)  # 繪製白色的正方形
        cv2.putText(frame, symbols[i], (lt[0] + 10, lt[1] + 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA)  # 在正方形內添加符號
        # 獲取當前 block 的範圍
        x1, y1 = lt
        x2, y2 = rb

        # 從 preprocessed_img 中提取當前 block 的區域
        block = preprocessed_img[y1:y2, x1:x2]

        # 檢查 block 中是否有不為黑的像素（灰階圖中值不為0）
        current_symbol = None
        if np.any(block != 0):
            current_symbol = symbols[i]
            num_enable = 0

        # 如果 current_symbol 不為 None，則處理其穩定幀計數
        if current_symbol:
            # 如果該 block 中的符號與前一次相同，增加穩定幀數
            if current_symbol == last_predicted_symbol[symbols[i]]:
                block_stability[current_symbol] += 1
            else:
                block_stability[current_symbol] = 0

            # 如果穩定了 60 幀，將該符號加入 equation
            if block_stability[current_symbol] == 60:
                equation += current_symbol

                if symbols[i] == 'c': equation = ''

                # 清空穩定幀計數器
                block_stability = {symbol: 0 for symbol in symbols}

            # 更新最後一次的符號判斷
            last_predicted_symbol[symbols[i]] = current_symbol

    if num_enable and count_limit < 0:
        # 將幀轉換為PIL圖像，然後應用轉換
        image = cv2.cvtColor(preprocessed_img, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = transform(image).unsqueeze(0).to(device)

        # 進行預測
        with torch.no_grad():
            outputs = model(image)
            _, predicted = torch.max(outputs, 1)
            predicted_class = class_names[predicted.item()]

        # 在幀上顯示結果
        cv2.putText(frame, predicted_class, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
        
        # 如果 predictions 不是空的，檢查是否有不同的預測值
        if predictions and predicted_class != predictions[-1]:
            predictions = []  # 發現不同預測時清空列表

        # 將預測結果加入列表
        predictions.append(predicted_class)
        
        # 每60幀顯示最頻繁的類別
        if len(predictions) == 60:
            # 計算最頻繁的類別
            counter = Counter(predictions)
            most_common_class, count = counter.most_common(1)[0]
            if most_common_class == 'ok': 
                most_common_class = '='
                result = eval(equation)
                most_common_class += str(result)
                count_limit = 120

            equation += most_common_class
            
            # 清空預測列表
            predictions = []
    
    # 顯示結果
    cv2.putText(frame, f'Equation: {equation}', (10, 420), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2, cv2.LINE_AA)

    # 顯示視頻幀
    cv2.imshow('Video', frame)
    count_limit -= 1

    if count_limit == 0: equation = ''

    # 按下'q'鍵退出
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# 釋放攝影機和關閉所有窗口
cap.release()
cv2.destroyAllWindows()


  model.load_state_dict(torch.load('simple_cnn_1220_epoch30.pth'))
