In [1]:
import traitlets
import ipywidgets.widgets as widgets
from ipywidgets import HTML
from IPython.display import display
import threading

# from matplotlib import pyplot as plt
import numpy as np

import cv2
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image


from learn import SimpleCNN,SimpleCNN_1,CustomResNet,CustomCircleDataset
# Camera and Motor Interface for JetBot
from jetbot import Robot, Camera, bgr8_to_jpeg

In [2]:
# Camera handle; its `value` trait is streamed into the widgets below via traitlets.dlink.
camera = Camera()

# Preview size in pixels, shared by both image widgets.
widget_width, widget_height = 224, 224

_image_widget_kwargs = dict(format='jpg', width=widget_width, height=widget_height)

# Left: raw camera feed. Right: output of display_xy (thresholded frame + predicted circle).
camera_widget = widgets.Image(**_image_widget_kwargs)
target_widget = widgets.Image(**_image_widget_kwargs)
image_layout = widgets.HBox([camera_widget, target_widget])

In [3]:
# HSV threshold controls; apply_hsv_threshold reads their values on every frame.
# Hue tops out at 179 because OpenCV stores 8-bit hue as H/2 (0-179);
# saturation and value span the full 0-255 range.
def _make_hsv_slider(description, max_value, initial):
    """Return a unit-step IntSlider over [0, max_value] with the given label and start value."""
    return widgets.IntSlider(description=description, min=0, max=max_value, value=initial, step=1)


low_h_slider = _make_hsv_slider('low h', 179, 90)
high_h_slider = _make_hsv_slider(' high h', 179, 100)
low_s_slider = _make_hsv_slider(' low s', 255, 140)
high_s_slider = _make_hsv_slider(' high s', 255, 255)
low_v_slider = _make_hsv_slider(' low v', 255, 0)
high_v_slider = _make_hsv_slider(' high v', 255, 255)

# One row per channel (low | high), stacked vertically into a single panel.
h_slider, s_slider, v_slider = (
    widgets.HBox([low, high])
    for low, high in (
        (low_h_slider, high_h_slider),
        (low_s_slider, high_s_slider),
        (low_v_slider, high_v_slider),
    )
)
slider = widgets.VBox([h_slider, s_slider, v_slider])

In [4]:
# Shared 128x64 px box for the snapshot button and the numeric readout.
layout = widgets.Layout(width='128px', height='64px')
snap_button = widgets.Button(
    description='snapshot',
    button_style='success',
    layout=layout,
)

# Integer readout that appears intended for the yellow-pixel percentage;
# NOTE(review): no visible cell currently writes to it.
percent = widgets.IntText(value=0, layout=layout)

# NOTE(review): gui_layout is assembled but the visible display cell shows image_layout instead.
gui_layout = widgets.HBox([image_layout, percent])

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the network; its layers must match the checkpoint loaded below
# (load_state_dict is strict by default and raises on a mismatch).
model = SimpleCNN()

# Load the trained weights (remapped onto `device`), move the model there, and
# put it in inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load('train/model_best.pth', map_location=device))
model.to(device)
model.eval()

SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=100352, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=3, bias=True)
)

In [7]:
# Running index used to name snapshot files. dir_path is prepended verbatim,
# so include a trailing separator if you point it at a directory.
i = 0
dir_path = ''


def save_snapshot(change):
    """Write the current target_widget JPEG to ``<dir_path>sample<i>.jpg`` and advance i.

    `change` is whatever on_click passes to its handler; it is not used.
    NOTE(review): i restarts at 0 with the kernel, so files from an earlier
    session are overwritten.
    """
    global i
    snapshot_index = i
    i = snapshot_index + 1
    snapshot_path = ''.join((dir_path, 'sample', str(snapshot_index), '.jpg'))
    with open(snapshot_path, 'wb') as jpeg_file:
        jpeg_file.write(target_widget.value)


snap_button.on_click(save_snapshot)
        
def apply_hsv_threshold(image):
    """Black out every pixel whose HSV value lies outside the slider ranges.

    The six bounds are read from the HSV sliders on each call. cv2.inRange
    treats both ends of each range as inclusive.

    Args:
        image: 3-channel image array (converted with cv2.COLOR_RGB2HSV).

    Returns:
        A new array of the same shape and dtype as `image`. In-range pixels keep
        their original values and all other pixels are zero.

    NOTE(review): the frames reaching this function are elsewhere passed to
    bgr8_to_jpeg, i.e. they look BGR, yet they are converted with
    COLOR_RGB2HSV. That swaps red and blue before the hue is computed. The
    default hue window (90-100) appears to be tuned for yellow under that swap,
    so do not switch to COLOR_BGR2HSV without re-tuning the sliders.
    """
    lower = np.array([low_h_slider.value, low_s_slider.value, low_v_slider.value])
    upper = np.array([high_h_slider.value, high_s_slider.value, high_v_slider.value])

    # inRange yields 255 where all three channels are within bounds, else 0.
    in_range = cv2.inRange(cv2.cvtColor(image, cv2.COLOR_RGB2HSV), lower, upper)

    thresholded = image.copy()
    thresholded[in_range != 255] = 0
    return thresholded

# Camera intrinsics from an offline calibration run: focal lengths (fx, fy) on
# the diagonal and principal point (cx, cy) in the last column.
# Presumably measured at the 224x224 widget resolution — TODO confirm.
# Built once at import time rather than on every frame.
_CAMERA_MATRIX = np.array([[108.16614089, 0, 111.55104361],
                           [0, 143.70743912, 121.87531702],
                           [0, 0, 1]])
# Distortion coefficients in OpenCV's 5-term order (k1, k2, p1, p2, k3).
_DISTORTION_COEFFICIENTS = np.array([-0.33140309, 0.12007599, -0.00293275, 0.00091844, -0.02038985])


def calibration(image):
    """Return a lens-undistorted copy of `image` using the fixed calibration above.

    Args:
        image: camera frame as a numpy array.

    Returns:
        The undistorted image, with the same size and type as the input
        (cv2.undistort does not modify `image`).
    """
    return cv2.undistort(image, _CAMERA_MATRIX, _DISTORTION_COEFFICIENTS)

def calculate_yellow_percentage(image):
    """Return the percentage (0-100) of pixels in `image` that are not black.

    Meant for the output of the HSV threshold step, where every out-of-range
    pixel has been zeroed, so "non-black" means "inside the colour range".

    Args:
        image: a single-channel mask of shape (H, W), or a multi-channel image
            of shape (H, W, C). For multi-channel input, a pixel counts once if
            any of its channels is non-zero.

    Returns:
        The percentage as a float; 0.0 for an image with no pixels.
    """
    if image.ndim > 2:
        # Collapse the channel axis first. Counting raw non-zero elements would
        # count a pixel up to C times and could report more than 100%.
        lit = np.any(image, axis=tuple(range(2, image.ndim)))
    else:
        lit = image

    total_pixels = image.shape[0] * image.shape[1]
    if total_pixels == 0:
        return 0.0

    return np.count_nonzero(lit) / total_pixels * 100

# Network preprocessing, built lazily on first use and then reused instead of
# re-creating the pipeline for every camera frame.
_predict_transform = None


def _get_predict_transform():
    """Return the cached torchvision pipeline: resize to 224x224, to tensor, scale to [-1, 1]."""
    global _predict_transform
    if _predict_transform is None:
        _predict_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    return _predict_transform


def display_xy(camera_image):
    """Predict the target circle in a camera frame and return an annotated JPEG.

    Used as the traitlets.dlink transform that feeds target_widget.

    Processing steps:
      1. Undistort the frame.
      2. HSV-threshold it.
      3. Run `model` on the result; its three outputs are read as (x, y, r),
         normalised to the frame size.
      4. Draw the circle in green on that same undistorted, thresholded image,
         so the overlay lines up with what the model saw.

    Args:
        camera_image: frame from `camera.value`. Assumed to be H x W x 3 BGR,
            since the same frames are passed to bgr8_to_jpeg. It is not modified.

    Returns:
        JPEG-encoded bytes of the annotated frame.
    """
    height, width = camera_image.shape[:2]

    # Keep the preprocessed array so the overlay is drawn in the same
    # (undistorted) coordinate frame the prediction refers to.
    thresholded = apply_hsv_threshold(calibration(np.copy(camera_image)))

    pil_image = Image.fromarray(cv2.cvtColor(thresholded, cv2.COLOR_BGR2RGB))
    batch = _get_predict_transform()(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        pos_x, pos_y, pos_r = model(batch).cpu().numpy()[0]

    # Scale the normalised predictions back to pixels. 1.414 ~ sqrt(2), so for
    # a square frame the radius is relative to the frame diagonal.
    center = (int(pos_x * width), int(pos_y * height))
    radius = max(0, int(pos_r * (width * 1.414)))

    # TODO: optionally update the readout with the share of in-range pixels:
    #   percent.value = int(calculate_yellow_percentage(thresholded))
    # (percent is an IntText, hence int()).

    overlay = cv2.circle(thresholded, center, radius, (0, 255, 0), 2)
    return bgr8_to_jpeg(overlay)

# One-way links: every new camera.value is passed through `transform` and
# assigned to the widget's value. The left widget gets the raw JPEG; the right
# widget gets the annotated prediction. The links are kept in variables so they
# can be stopped with .unlink(), and so the cell does not print a link repr.
camera_link = traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)
target_link = traitlets.dlink((camera, 'value'), (target_widget, 'value'), transform=display_xy)
# The snapshot click handler is registered right after save_snapshot is defined.
# Registering it again here would write two snapshots per click.


In [8]:
# Show the camera/prediction pair, the HSV threshold sliders, and the snapshot button, in that order.
display(image_layout, slider, snap_button)

HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C…

VBox(children=(HBox(children=(IntSlider(value=90, description='low h', max=179), IntSlider(value=100, descript…

Button(button_style='success', description='snapshot', layout=Layout(height='64px', width='128px'), style=Butt…