In [None]:
import sys
sys.path.append('../train')

import ipywidgets.widgets as widgets
from IPython.display import display
import traitlets
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import cv2
import threading

from learn import SimpleCNN,SimpleCNN_1,CustomResNet,CustomCircleDataset
from jetbot import Robot, Camera, bgr8_to_jpeg

## GUI

In [None]:
# Pixel size shared by both JPEG preview panes.
widget_width = 224
widget_height = 224

# Two side-by-side JPEG preview panes built from identical settings.
_preview_kwargs = dict(format='jpg', width=widget_width, height=widget_height)
camera_widget = widgets.Image(**_preview_kwargs)
target_widget = widgets.Image(**_preview_kwargs)
image_layout = widgets.HBox([camera_widget, target_widget])


def _int_slider(description, max_value, value):
    """Build an IntSlider that starts at 0 and moves in steps of 1."""
    return widgets.IntSlider(description=description, min=0, max=max_value, value=value, step=1)


# HSV mask thresholds; apply_hsv_threshold() reads their current values.
# OpenCV uses 0-179 for hue and 0-255 for saturation / value.
low_h_slider = _int_slider('low h', 179, 90)
high_h_slider = _int_slider(' high h', 179, 100)
low_s_slider = _int_slider(' low s', 255, 140)
high_s_slider = _int_slider(' high s', 255, 255)
low_v_slider = _int_slider(' low v', 255, 0)
high_v_slider = _int_slider(' high v', 255, 255)

# One row (low, high) per HSV channel, stacked vertically.
h_slider, s_slider, v_slider = (
    widgets.HBox([low, high])
    for low, high in (
        (low_h_slider, high_h_slider),
        (low_s_slider, high_s_slider),
        (low_v_slider, high_v_slider),
    )
)
slider = widgets.VBox([h_slider, s_slider, v_slider])

# Robot parameter sliders; their consumers are not part of this chunk.
far_slider = _int_slider('far rad', 100, 20)
near_slider = _int_slider(' near rad', 100, 30)
speed_slider = widgets.FloatSlider(description='speed', min=0, max=1.0, value=0, step=0.1)
interval_slider = widgets.FloatSlider(description=' interval', min=0, max=5.0, value=0.5, step=0.1)
robot_slider = widgets.VBox([far_slider, near_slider, speed_slider, interval_slider])

# Controls sharing a fixed 128x64 px footprint, plus the top-level GUI row.
layout = widgets.Layout(width='128px', height='64px')
snap_button = widgets.Button(description='snapshot', button_style='success', layout=layout)
percent = widgets.IntText(layout=layout, value=0)
gui_layout = widgets.HBox([image_layout, percent])

## モデル

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Trained weights, relative to the notebook's working directory.
MODEL_WEIGHTS_PATH = 'weight/model_2.pth'
# Build the network and load the trained weights.
model = SimpleCNN()
# map_location remaps tensors saved from a GPU run onto `device`; without it a
# CPU-only machine (the fallback chosen above) fails to deserialize CUDA tensors.
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location=device))
model.to(device)  # Move the model onto the selected device (GPU or CPU).
# Evaluation mode (affects layers such as dropout / batch-norm, if the model has any).
model.eval()

## 関数定義

In [None]:
def model_input(image, model, device):
    """Run the circle-regression model on a single OpenCV frame.

    Args:
        image: HxWx3 image array in BGR channel order (it is converted
            BGR -> RGB before preprocessing).
        model: network whose output has one row per batch item holding
            exactly three values.
        device: torch device the model lives on; the input tensor is moved there.

    Returns:
        Tuple ``(pos_x, pos_y, pos_r)`` from the first (only) output row.
        The caller scales x/y by the frame width/height, so these are presumably
        normalised coordinates. TODO confirm against the training code.
    """
    rgb_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Maps [0, 1] pixel values to [-1, 1] per channel.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Batch of one, placed on the same device as the model.
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        first_row = model(batch).cpu().numpy()[0]  # back to host memory for numpy
        pos_x, pos_y, pos_r = first_row
    return pos_x, pos_y, pos_r
        

# Mask-processing helper.
def apply_hsv_threshold(image):
    """Black out every pixel whose HSV value lies outside the slider ranges.

    The six threshold sliders (low/high for h, s and v) are read on every call.
    Bounds are inclusive, as with ``cv2.inRange``.

    Args:
        image: HxWx3 image array.

    Returns:
        Array with the same shape and dtype as ``image``. It keeps the original
        pixels inside the range and is zero elsewhere.

    NOTE(review): the conversion is ``COLOR_RGB2HSV``, but frames elsewhere in
    this notebook are treated as BGR (``bgr8_to_jpeg``, ``COLOR_BGR2RGB`` in
    model_input). Red and blue are therefore swapped before conversion, and the
    hue thresholds are tuned to that swapped space. For example, an earlier
    "yellow" range of 91-103 sits where OpenCV places cyan. Re-tune the sliders
    if this is ever switched to ``COLOR_BGR2HSV``.
    """
    lower_bound = np.array([low_h_slider.value, low_s_slider.value, low_v_slider.value])
    upper_bound = np.array([high_h_slider.value, high_s_slider.value, high_v_slider.value])

    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # inRange yields 255 inside the bounds and 0 outside.
    inside = cv2.inRange(hsv_image, lower_bound, upper_bound) == 255

    # Broadcast the 2-D mask across the colour channels.
    return np.where(inside[..., np.newaxis], image, np.zeros_like(image))



## カメラ

In [None]:
camera = Camera.instance()


class ProcessedFrame(traitlets.HasTraits):
    """Observable holder for the latest processed (annotated) frame.

    ``traitlets.dlink`` needs a ``HasTraits`` object exposing a named trait as
    its source. A bare ``traitlets.Any()`` is only a trait descriptor and has
    no observable ``value``.
    """
    value = traitlets.Any()


# Latest frame after image processing. It is seeded with the current camera
# frame so that the initial sync of a dlink(..., transform=bgr8_to_jpeg) has
# an image to encode instead of None.
img_with_circle = ProcessedFrame(value=camera.value)



In [None]:
def camera_processing_thread():
    """Endlessly mask camera frames, predict a circle and publish the result.

    Each iteration:
      1. reads ``camera.value``;
      2. masks it with ``apply_hsv_threshold``;
      3. predicts ``(x, y, r)`` with ``model_input``;
      4. scales the prediction to pixel units;
      5. draws the circle in green on the masked frame;
      6. stores the frame in ``img_with_circle.value``.

    The loop never returns; it is presumably meant to run as the target of a
    background ``threading.Thread``.
    """
    while True:
        frame = camera.value
        processed_image = apply_hsv_threshold(frame)

        # model and device are only read here, so no `global` declaration is needed.
        x, y, r = model_input(processed_image, model, device)

        height, width = processed_image.shape[:2]
        x = int(x * width)
        y = int(y * height)
        # Radius is scaled by width * 1.414 (~sqrt(2)); presumably this mirrors the
        # label normalisation used in training. TODO confirm.
        # cv2.circle rejects negative radii, so clamp at 0.
        r = max(int(r * (width * 1.414)), 0)

        # cv2.circle draws in place and returns the same array. Assigning to the
        # trait (rather than rebinding a global name) is what notifies dlink observers.
        img_with_circle.value = cv2.circle(processed_image, (x, y), r, (0, 255, 0), 2)


traitlets.dlink((img_with_circle, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)

