# アノテーション

アノテーションでは、11_record_cameraで撮影した走行データにアノテーションを実施し、転移学習をおこないます。

Desktop版では、OSX（macOS）で学習した結果がJetRacerにうまく反映されない場合があるようです（GPUの違いによるFloat型などの扱いの差が原因と考えられます）。

In [15]:
import os

# Root folder that holds the camera/dataset/model folders.
# Edit WORKSPACE to match your environment; None means "use the current working directory".
WORKSPACE = None

# `is None` is the correct identity test (`== None` can be overridden by __eq__).
current_path = os.getcwd() if WORKSPACE is None else WORKSPACE

走行データはcameraフォルダに録画されています。ここでは、cameraフォルダのデータにアノテーションをおこないます。

In [16]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import re
import ipywidgets
from utils import preprocess
from ipywidgets import Button, Layout, Textarea, HBox, VBox, Label
#from jetcam.utils import bgr8_to_jpeg
import cv2
import torchvision.transforms as transforms
from xy_dataset import XYDataset
import time
import threading
import torch
import torchvision
import ipywidgets
from ipyevents import Event
from PIL import Image, ImageDraw, ImageFont
import io
import inspect

In [17]:
import cv2

def bgr8_to_jpeg(value, quality=75):
    """Encode an OpenCV BGR image into JPEG bytes for an ipywidgets.Image.

    Args:
        value: image array accepted by ``cv2.imencode``.
        quality: JPEG quality (0-100) passed to the encoder.

    Returns:
        The encoded JPEG as ``bytes``.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    # Pass the quality through; without the flag OpenCV uses its own default.
    ok, buf = cv2.imencode('.jpg', value, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return bytes(buf)

In [18]:
# Size of the annotation canvas in pixels.
IMG_WIDTH = 224
IMG_HEIGHT = 224

# Choices offered by the playback dropdowns.
SLEEP = [50, 100, 200, 300, 400, 500]   # delay between frames (ms)
SKIP = [1, 2, 3, 4, 5]                  # frames advanced per step

# Categories read from / written to an XYDataset.
LOAD_CATEGORIES = ['xy', 'speed']
SAVE_CATEGORIES = ['xy', 'speed']

# Dataset folder names; SAVE_DATASETS is filled in from disk further down.
LOAD_DATASETS = ['X', 'Y', 'Z']
SAVE_DATASETS = []

# Task folders directly under current_path.
LOAD_TASK = ['camera', 'dataset', 'interactive']
SAVE_TASK = ['dataset']

# State shared between the UI callbacks.
check_flag = False
running = False
sleep_time = 50

def get_dirs(path):
    """Return the sorted names of the sub-directories of *path*.

    Jupyter checkpoint folders, ``.DS_Store`` and any name ending in
    ``.zip`` are left out.  Raises OSError if *path* cannot be listed.
    """
    excluded = {".ipynb_checkpoints", ".DS_Store"}
    return sorted(
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
        and entry not in excluded
        and not entry.endswith('.zip')
    )

# Offer the existing folders under <current_path>/dataset as save targets;
# fall back to a single default name when there are none or the folder is
# missing/unreadable (get_dirs raises OSError in that case).
try:
    SAVE_DATASETS = get_dirs(os.path.join(current_path, "dataset")) or ['dataset1']
except OSError:
    SAVE_DATASETS = ['dataset1']

# Training-time preprocessing: colour jitter (augmentation), resize to the
# network input size, tensor conversion, then per-channel mean/std
# normalisation (the values match the commonly used ImageNet statistics).
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# One XYDataset per save-dataset name.  If SAVE_TASK held several tasks, the
# last task would win for each name (same as assigning in a nested loop).
datasets = {
    name: XYDataset(current_path + '/' + task + '/' + name, SAVE_CATEGORIES, TRANSFORMS, random_hflip=True)
    for name in SAVE_DATASETS
    for task in SAVE_TASK
}

# Dataset used by default until another one is selected.
dataset = datasets[SAVE_DATASETS[0]]

In [19]:
# Log box shown in the UI; write_log() puts the newest entry on top.
l = Layout(flex='0 1 auto', height='100px', min_height='100px', width='auto')
process_widget = ipywidgets.Textarea(description='ログ', value='', layout=l)

process_no = 0  # running message counter


def write_log(msg):
    """Prepend ``"<n>: <msg>"`` to ``process_widget``; *n* grows by one per call."""
    global process_no
    process_no += 1
    entry = str(process_no) + ": " + msg + "\n"
    process_widget.value = entry + process_widget.value

In [20]:
# Playback controls: delay between frames (ms) and how many frames each step skips.
sleep_dropdown = ipywidgets.Dropdown(options=SLEEP, description='sleep(ms)', index=0)
skip_dropdown = ipywidgets.Dropdown(options=SKIP, description='skip(枚)', index=1)
# NOTE(review): skip_movie_dropdown is not referenced anywhere in this section — confirm it is still needed.
skip_movie_dropdown = ipywidgets.Dropdown(options=SKIP, description='skip(枚)', index=1)

#picture_widget = ClickableImageWidget(width=224, height=224)
# Build a blank white 224x224 JPEG as the placeholder shown before any image is loaded.
image = Image.new('RGB', (224, 224), 'white')
image_byte_arr = io.BytesIO()
image.save(image_byte_arr, format='JPEG')
# Main image widget; load_img() replaces its value with the annotated frame.
picture_widget = ipywidgets.Image(
    value=image_byte_arr.getvalue(),
    format='jpg',
    width=224,
    height=224
)
# Index of the displayed image and the labels parsed from its file name.
no_widget = ipywidgets.IntText(description='no')
x_widget = ipywidgets.IntText(description='data x')
y_widget = ipywidgets.IntText(description='data y')
speed_widget = ipywidgets.IntText(description='data speed')
# Model predictions for the displayed image, converted to pixel units by load_img().
ai_x_widget = ipywidgets.IntText(description='AI　x')
ai_y_widget = ipywidgets.IntText(description='AI　y')
ai_speed_widget = ipywidgets.IntText(description='AI speed')
# NOTE(review): load_model() reads load_model_widget, not model_widget — this text box looks unused here.
model_widget = ipywidgets.Text(description='model')
model_widget.value = "model.pth"
load_model_button = ipywidgets.Button(description='load model')
# Vertical slider (0-224) whose value is saved as the label of a speed annotation.
speed_slider = ipywidgets.IntSlider(description='speed', min=0, max=224, step=1, value=0, orientation='vertical')
add_speed_button = ipywidgets.Button(description='速度追加')

In [21]:
from packaging import version

torchvision_version = version.parse(torchvision.__version__)

# Pick the backend: Apple MPS, then CUDA, else CPU.  torch.backends.mps only
# exists in newer PyTorch releases, so probe it with getattr instead of
# raising AttributeError on older installs.
_mps = getattr(torch.backends, "mps", None)
if _mps is not None and _mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Two regression outputs (x, y) per category; LOAD_CATEGORIES must be defined beforehand.
output_dim = 2 * len(LOAD_CATEGORIES)

# ResNet-18 with pretrained torchvision weights and a fresh linear head.
# torchvision 0.13 replaced `pretrained=True` with the `weights=` API.
if torchvision_version >= version.parse("0.13"):
    from torchvision.models.resnet import ResNet18_Weights, resnet18

    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
else:
    model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, output_dim)

model = model.to(device)

In [22]:
def load_model(c):
    """Rebuild the global ResNet-18 and optionally load saved weights into it.

    Click handler for the "load model" button.  The path selected in
    ``load_model_widget`` is loaded into a freshly built network; the special
    entry ``"[new]"`` starts from torchvision's pretrained weights instead.

    Args:
        c: the clicked button (unused).

    Side effects:
        Rebinds the globals ``model`` and ``device`` and writes to the log.
    """
    global torchvision_version, load_model_widget, model, device, output_dim

    model_name = load_model_widget.value
    start_fresh = model_name == "[new]"
    new_weights_api = torchvision_version >= version.parse("0.13")

    # Backend priority: MPS, CUDA, CPU.  torch.backends.mps is absent in
    # older PyTorch releases, hence the getattr probe.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # torchvision 0.13 replaced `pretrained=` with `weights=`.
    if new_weights_api:
        weights = torchvision.models.ResNet18_Weights.DEFAULT if start_fresh else None
        model = torchvision.models.resnet18(weights=weights)
    else:
        model = torchvision.models.resnet18(pretrained=start_fresh)
    model.fc = torch.nn.Linear(model.fc.in_features, output_dim)
    model = model.to(device)

    if start_fresh:
        if new_weights_api:
            write_log("[new]が選択されたのでresnet18の最新の重みから始めます(torchvision 0.13以降)。")
        else:
            write_log("[new]が選択されたのでresnet18のpretrainedから始めます。")
    else:
        # map_location lets a checkpoint saved on CUDA load on CPU/MPS (and
        # vice versa) instead of failing during deserialisation.
        model.load_state_dict(torch.load(model_name, map_location=device))
        suffix = "(torchvision 0.13以降)" if new_weights_api else ""
        write_log(model_name + "のモデルを読込ました" + suffix + "。")

    # NOTE(review): get_jetson_nano_memory_usage is not defined in this part of
    # the notebook — presumably a later cell defines it; confirm.
    get_jetson_nano_memory_usage()

def save_model(c):
    """Save the current model's state_dict to ./model/<save_model_name_widget>.

    Args:
        c: the clicked button (unused).
    """
    global save_model_name_widget, model, device
    path = "./model/"
    # Create the folder in-process rather than shelling out to `mkdir -p`
    # (which also depended on `subprocess` being imported by a later cell).
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), path + save_model_name_widget.value)
    write_log(path + save_model_name_widget.value + "に保存しました。")

In [23]:
import time

# Number of samples per training/evaluation batch.
BATCH_SIZE = 8

# Training controls; train_eval() decrements epochs_widget after each training epoch.
epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
# Running loss and progress of the current epoch, updated per batch.
loss_widget = ipywidgets.FloatText(description='loss')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

# Lowest epoch loss seen so far; train_eval() writes ./model/best_model.pth when it improves.
best_loss = float('inf')  

def _count_files(folder):
    """Return how many regular files *folder* contains (0 if it is not a directory)."""
    if not os.path.isdir(folder):
        return 0
    return sum(os.path.isfile(os.path.join(folder, entry)) for entry in os.listdir(folder))


def train_eval(is_training):
    """Train or evaluate the global ``model`` on the selected save dataset.

    Loads ``<current_path>/<save task>/<save dataset>`` as an XYDataset, keeps
    only samples that load without error, and computes the MSE between each
    sample's label and the output pair belonging to its category.  Training
    runs while ``epochs_widget.value > 0``, decrementing it after each epoch
    and saving ./model/best_model.pth whenever the epoch loss beats
    ``best_loss``.  Evaluation runs one pass (if ``epochs_widget.value > 0``)
    without building autograd graphs.

    Args:
        is_training: True to update the weights, False to only measure loss.
    """
    global BATCH_SIZE, model, dataset, optimizer, best_loss, current_path

    # A new optimizer per click: Adam state is not carried across runs.
    optimizer = torch.optim.Adam(model.parameters())

    task = save_task_widget.value
    name = save_datasets_widget.value
    dataset_dir = os.path.join(current_path, task, name)
    xy_file_count = _count_files(os.path.join(dataset_dir, "xy"))
    speed_file_count = _count_files(os.path.join(dataset_dir, "speed"))

    write_log("-------------------------")
    write_log("学習を開始します。")
    write_log("データセット: " + task + '/' + name)
    write_log("XYデータ数: " + str(xy_file_count) + " Speedデータ数: " + str(speed_file_count))
    write_log("-------------------------")
    dataset = XYDataset(current_path + '/' + task + '/' + name, SAVE_CATEGORIES, TRANSFORMS, random_hflip=True)
    train_button.disabled = True
    eval_button.disabled = True

    try:
        # Skip samples that fail to load so one bad file does not abort the run.
        valid_indices = []
        for idx in range(len(dataset)):
            try:
                item = dataset[idx]
            except Exception:
                continue
            if item is not None and item[0] is not None:
                valid_indices.append(idx)
        if not valid_indices:
            # Without data the loop below would leave loss_widget stale, and a
            # stale value could still be saved as a "best" model.
            write_log("有効なデータがありません。")
            return

        valid_dataset = torch.utils.data.Subset(dataset, valid_indices)
        n_samples = len(valid_dataset)
        train_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )
        time.sleep(1)

        model = model.train() if is_training else model.eval()
        epoch_count = 0

        while epochs_widget.value > 0:
            epoch_start_time = time.time()
            epoch_count += 1
            seen = 0
            sum_loss = 0.0
            # Autograd graphs are only needed when we back-propagate.
            with torch.set_grad_enabled(is_training):
                for images, category_idx, xy in train_loader:
                    if images is None or xy is None:
                        print("Warning: None type data found at index", seen)
                        continue
                    images = images.to(device)
                    xy = xy.to(device)

                    if is_training:
                        optimizer.zero_grad()

                    outputs = model(images)

                    # MSE on outputs[2*cat : 2*cat+2], the pair for each sample's category.
                    loss = 0.0
                    for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                        loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx + 2] - xy[batch_idx]) ** 2)
                    loss /= len(category_idx)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                    count = len(category_idx.flatten())
                    seen += count
                    # `loss` is a per-batch mean: weight it by the batch size so
                    # the displayed value is a true per-sample average.
                    sum_loss += float(loss) * count
                    # Progress is relative to the samples actually iterated.
                    progress_widget.value = seen / n_samples
                    loss_widget.value = sum_loss / seen

            epoch_duration = time.time() - epoch_start_time
            write_log(f"{epoch_count} Epoch目: {epoch_duration:.2f}秒")

            # Keep the weights with the lowest epoch loss seen so far.
            if is_training and loss_widget.value < best_loss:
                best_loss = loss_widget.value
                model_dir = './model'
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_dir + "/" + 'best_model.pth')
                write_log(f"新しいベストモデルが保存されました。Epoch loss: {best_loss:.4f}")

            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except Exception as e:
        # Report the failure instead of silently swallowing it.
        write_log(f"エラーが発生しました: {e}")
    finally:
        model = model.eval()
        train_button.disabled = False
        eval_button.disabled = False

# Both buttons share train_eval(); the flag selects training vs. evaluation.
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
# Training panel: epoch count, progress, loss and the two action buttons.
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox([train_button, eval_button])
])

display(train_eval_widget)

VBox(children=(IntText(value=1, description='epochs'), FloatProgress(value=0.0, description='progress', max=1.…

01_finw_pwmを実行して、pwmの値を設定してください。

In [29]:
from os.path import join
import subprocess
import datetime
import glob

def extract_numbers(filename):
    """Return the last integer in *filename*, for use as a sort key.

    Names with fewer than three digit runs sort last (``float('inf')``).
    """
    digit_runs = re.findall(r'(\d+)', filename)
    if len(digit_runs) < 3:
        return float('inf')
    return int(digit_runs[-1])

def get_file_names(path):
    """Return the full paths of the ``.jpg`` files (any case) in *path*.

    Files are ordered by ``extract_numbers`` applied to the base name; the
    sort is stable, so ties keep ``os.listdir`` order.
    """
    jpg_paths = [
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if os.path.splitext(entry)[1].lower() == ".jpg"
    ]
    jpg_paths.sort(key=lambda p: extract_numbers(os.path.basename(p)))
    return jpg_paths
    
def load_img(no):
    """Display image *no* of the selected load dataset with its labels and the AI prediction.

    The image list is ``<current_path>/<load task>/<load dataset>/xy`` as
    ordered by get_file_names().  Labels are parsed from file names starting
    with ``<x>_<y>_``: the xy file gives the annotated point (drawn green) and
    the file at the same index in the sibling ``speed`` folder gives the
    speed (its second number).  The model's prediction is drawn in blue with
    a guide grid and a speed bar.  The decoded frame is kept in the global
    ``img`` for the click/save handlers.

    Args:
        no: zero-based index into the sorted xy file list.
    """
    global running, img, load_flag, xy_filenames, play_num, current_path
    global image_byte_arr

    def _unlock_navigation():
        # next_pic/before_pic disable these before calling load_img; release
        # them again unless the background player is driving the view.
        if not running:
            next_image_button.disabled = False
            prev_image_button.disabled = False

    dataset_dir = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value)
    xy_imagenames = get_file_names(os.path.join(dataset_dir, "xy"))
    # Publish the listing so other handlers can map an index to a file.
    xy_filenames = xy_imagenames

    if no >= len(xy_imagenames):
        # Clamp the counter to the last valid index (0 for an empty dataset).
        no_widget.value = max(len(xy_imagenames) - 1, 0)
        write_log("ファイルが存在しません。" + str(len(xy_imagenames)-1) + "以内の値を設定してください。")
        running = False
        _unlock_navigation()
        return

    xy_name = xy_imagenames[no]
    basename = os.path.basename(xy_name)
    pattern = r'(\d+)_(\d+)_.*'

    # Label point; (0, 0) means "no label" and is not drawn.  Defaulting here
    # avoids an UnboundLocalError when the file name does not match.
    x, y = 0, 0
    xy_result = re.match(pattern, basename)
    if xy_result:
        x = int(xy_result.group(1))
        y = int(xy_result.group(2))
        x_widget.value = x
        y_widget.value = y
    else:
        write_log("正規表現にマッチしませんでした: " + basename)

    speed = 0
    try:
        speed_name = get_file_names(os.path.join(dataset_dir, "speed"))[no]
        # Match the base name: the full path starts with the directory and
        # can never match the pattern.
        speed_result = re.match(pattern, os.path.basename(speed_name))
        if speed_result:
            speed = int(speed_result.group(2))
            speed_slider.value = speed
    except (OSError, IndexError):
        # No speed folder, or fewer speed files than xy files.
        pass
    speed_widget.value = speed

    img = cv2.imread(xy_name)
    if img is None:
        write_log("Image could not be loaded: " + xy_name)
        _unlock_navigation()
        return

    marked_img = img.copy()

    black_color = (0, 0, 0)
    blue_color = (255, 0, 0)
    green_color = (0, 255, 0)
    guide_color = (100, 100, 255)
    centre_color = (255, 255, 0)

    if x != 0 or y != 0:
        marked_img = cv2.circle(marked_img, (x, y), 8, green_color, 3)

    try:
        preprocessed = preprocess(img)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(preprocessed).detach().cpu().numpy().flatten()
        # Outputs are (x, y) pairs per category; speed was saved in the y slot
        # of the second pair, hence index 3.  Values are mapped to pixels via
        # size * (v / 2 + 0.5).
        result_x = int(IMG_WIDTH * (output[0] / 2.0 + 0.5))
        result_y = int(IMG_HEIGHT * (output[1] / 2.0 + 0.5))
        result_speed = int(IMG_HEIGHT * (output[3] / 2.0 + 0.5))

        ai_x_widget.value = result_x
        ai_y_widget.value = result_y
        ai_speed_widget.value = result_speed

        marked_img = cv2.circle(marked_img, (result_x, result_y), 8, blue_color, 3)

        # Guide grid every 30 px around the centre line (112): vertical lines first, then horizontal.
        offsets = (112, 82, 52, 22, 142, 172, 202)
        for pos in offsets:
            color = centre_color if pos == 112 else guide_color
            marked_img = cv2.line(marked_img, (pos, 0), (pos, 224), color, 1)
        for pos in offsets:
            color = centre_color if pos == 112 else guide_color
            marked_img = cv2.line(marked_img, (0, pos), (224, pos), color, 1)

        # Predicted speed bar on the right edge, clamped to the drawable height.
        result_speed = min(max(result_speed, 0), 224)
        marked_img = cv2.line(marked_img, (218, 0), (218, 224), black_color, 5)
        marked_img = cv2.line(marked_img, (219, 224 - result_speed), (219, 224), blue_color, 3)
        marked_img = cv2.putText(marked_img, "speed:" + str(result_speed), (160, 215), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255))

        # Labelled speed bar on the left edge.
        if speed != 0:
            marked_img = cv2.line(marked_img, (1, 0), (1, 224), black_color, 5)
            marked_img = cv2.line(marked_img, (2, 224 - speed), (2, 224), green_color, 3)
    except Exception as e:
        write_log(f"エラーが発生しました: {e}")

    picture_widget.value = bgr8_to_jpeg(marked_img)

    if running == True:
        play_num += 1
        if play_num % 10 == 0:
            write_log(f"{play_num} 回目の再生")
    else:
        next_image_button.disabled = False
        prev_image_button.disabled = False
        write_log(str(no) + "枚目の" + xy_name + "を読込ました。")

def del_pic(c):
    """Delete the xy image at index ``no_widget.value`` of the selected load dataset.

    The file list is rebuilt from the load task/dataset widgets with
    get_file_names(), i.e. the same listing load_img() displays, so the index
    refers to the image currently shown.

    Args:
        c: the clicked button (unused).
    """
    global no
    no = no_widget.value
    xy_path = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value, "xy")
    names = get_file_names(xy_path)
    if not 0 <= no < len(names):
        write_log("ファイルが存在しません。" + str(len(names)-1) + "以内の値を設定してください。")
        return
    name = names[no]
    os.remove(name)
    write_log(name + "を削除しました。")
    
def load_dataset(c):
    """Open the dataset chosen in the load widgets and display its first image.

    Args:
        c: the clicked button (unused).
    """
    global img, load_flag, current_path
    dataset_path = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value)
    write_log("データセット: " + dataset_path + "を読込みます。")
    load_flag = True
    first_index = 0
    no_widget.value = first_index
    write_log("初回の読込みには時間がかかります。(30秒〜1分)")
    load_img(first_index)
    # NOTE(review): get_jetson_nano_memory_usage is not defined in this section — presumably a later cell; confirm.
    get_jetson_nano_memory_usage()

def next_pic(c):
    """Advance ``no_widget`` by the skip amount and display that image.

    Args:
        c: the clicked button (unused).
    """
    global no, load_flag, check_flag
    load_flag, check_flag = True, False
    no = int(no_widget.value) + skip_dropdown.value
    no_widget.value = no
    # Locked while loading; load_img() re-enables them when not playing.
    for button in (next_image_button, prev_image_button):
        button.disabled = True
    load_img(no)
    
def before_pic(c):
    """Step ``no_widget`` back by the skip amount (never below 0) and display that image.

    Args:
        c: the clicked button (unused).
    """
    global no, load_flag, check_flag
    load_flag, check_flag = True, False
    no = max(int(no_widget.value) - skip_dropdown.value, 0)
    no_widget.value = no
    # Locked while loading; load_img() re-enables them when not playing.
    for button in (next_image_button, prev_image_button):
        button.disabled = True
    load_img(no)

def file_count():
    """Show how many files the xy/speed folders of the selected save dataset hold.

    Missing folders count as 0; any error resets both counters to 0.
    """
    global current_path

    def count_files(folder):
        if not os.path.isdir(folder):
            return 0
        return sum(os.path.isfile(os.path.join(folder, entry)) for entry in os.listdir(folder))

    try:
        base = os.path.join(current_path, save_task_widget.value, save_datasets_widget.value)
        datasets_xy_count_widget.value = count_files(os.path.join(base, "xy"))
        datasets_speed_count_widget.value = count_files(os.path.join(base, "speed"))
    except Exception:
        datasets_xy_count_widget.value = 0
        datasets_speed_count_widget.value = 0
    
def save_snapshot(_, content, msg):
    """Save a clicked point on the current image as an xy annotation.

    Expects an event payload with ``content['event']`` and
    ``content['eventData']['offsetX'/'offsetY']``.  Acts only on clicks while
    ``load_flag`` is set, and clears the flag so each loaded image takes one click.
    """
    global img, x, y, load_flag, save_datasets_widget, save_task_widget, save_category_widget
    if content['event'] != 'click' or load_flag != True:
        return
    load_flag = False
    data = content['eventData']
    x = data['offsetX']
    y = data['offsetY']

    remarked_img = cv2.circle(img.copy(), (int(x), int(y)), 8, (0, 255, 0), 3)
    picture_widget.value = bgr8_to_jpeg(remarked_img)
    name = save_datasets_widget.value
    if save_task_widget.value == "":
        # Reported via write_log (write_out is not defined in this notebook section).
        write_log("データセット名を指定してください")
        return
    datasets[name].save_entry("xy", img, x, y)
    # Log only after the entry has actually been written.
    write_log("["+save_task_widget.value + "/" + name + "]の" + SAVE_CATEGORIES[0] + "カテゴリにデータを追加しました。")
    file_count()

def image_handle_event(event):
    """Save a click on the main image as an xy annotation of the current frame.

    Registered with ipyevents for clicks on ``picture_widget``; the event's
    ``offsetX``/``offsetY`` become the label point.  The point is drawn in
    green, and the frame plus (x, y) is saved to the "xy" category of the
    selected save dataset.
    """
    global img, x, y, load_flag, save_datasets_widget, save_task_widget, save_category_widget

    image_coordinates = (event['offsetX'], event['offsetY'])
    write_log(f'image coordinates: {image_coordinates}')
    x = event['offsetX']
    y = event['offsetY']

    # Clicking before a frame has been loaded would fail on img.copy().
    if globals().get('img') is None:
        write_log("画像が読み込まれていません。")
        return

    remarked_img = cv2.circle(img.copy(), (int(x), int(y)), 8, (0, 255, 0), 3)
    picture_widget.value = bgr8_to_jpeg(remarked_img)

    name = save_datasets_widget.value
    if save_task_widget.value == "":
        # Reported via write_log (write_out is not defined in this notebook section).
        write_log("データセット名を指定してください")
        return

    datasets[name].save_entry("xy", img, x, y)
    # Log only after the entry has actually been written.
    write_log("["+save_task_widget.value + "/" + name + "]の" + SAVE_CATEGORIES[0] + "カテゴリにデータを追加しました。")
    file_count()
def save_speed(c):
    """Save the current image to the "speed" category, labelled with the slider value.

    The label is written as (0, speed), i.e. speed occupies the y slot.

    Args:
        c: the clicked button (unused).
    """
    global img, speed_slider, save_datasets_widget, save_task_widget
    # Nothing to save until a frame has been loaded.
    if globals().get('img') is None:
        write_log("画像が読み込まれていません。")
        return
    speed = speed_slider.value
    name = save_datasets_widget.value
    datasets[name].save_entry("speed", img, 0, speed)
    write_log("["+save_task_widget.value + "/" + name + "]の" + SAVE_CATEGORIES[1] + "カテゴリにデータを追加しました。")

    file_count()
        
def live():
    """Playback loop run on a worker thread started by play().

    Advances the index by the global ``skip`` and shows each frame, sleeping
    ``sleep_time`` ms in between, until ``running`` is cleared (by stop() or
    by load_img() reaching the end of the dataset).
    """
    global no, running, skip, sleep_time, play_num, load_flag
    load_flag = True
    play_num = 0
    no = no_widget.value
    while running:
        no += skip
        no_widget.value = no
        try:
            load_img(no)
        except Exception:
            # str() is required: `no` is an int, and a TypeError here would
            # kill the playback thread.
            write_log("no: " + str(no) + "のファイルの読込に失敗")
        time.sleep(sleep_time / 1000)
    
def play(c):
    """Start background playback using the current skip/sleep dropdown values.

    Args:
        c: the clicked button (unused).
    """
    global running, execute_thread, skip, sleep_time, check_flag
    # A second click while playing would start another thread racing over
    # the same globals.
    if running:
        write_log("すでに再生中です。")
        return
    skip = skip_dropdown.value
    sleep_time = sleep_dropdown.value
    running = True
    check_flag = False
    # daemon=True so a forgotten player never blocks interpreter shutdown.
    execute_thread = threading.Thread(target=live, daemon=True)
    execute_thread.start()
    
def stop(c):
    """Stop playback and wait for the worker thread to finish.

    Args:
        c: the clicked button (unused).
    """
    global running, execute_thread, load_flag, check_flag
    running = False
    load_flag = True
    check_flag = False
    try:
        execute_thread.join()
    except NameError:
        # play() has never been pressed, so there is no thread to join.
        write_log("現在再生されていません。")
        return
    write_log("STOP")

def create_dataset(c):
    """Create (or re-register) the save dataset named in ``datasets_name_widget``.

    A new name is appended to SAVE_DATASETS and selected in the save
    dropdown.  An XYDataset for the name is stored in the module-level
    ``datasets`` dict and becomes the global ``dataset``.

    Args:
        c: the clicked button (unused).
    """
    global datasets_name_widget, save_datasets_widget, datasets, dataset
    new_dataset_name = datasets_name_widget.value

    if not new_dataset_name.strip():
        write_log("データセット名を指定してください")
        return

    if new_dataset_name not in SAVE_DATASETS:
        SAVE_DATASETS.append(new_dataset_name)
        # Pass a copy: handing the widget the same (mutated) list object may
        # not register as a change.
        save_datasets_widget.options = list(SAVE_DATASETS)
        save_datasets_widget.value = new_dataset_name

    # Update the module-level dict (a local rebuild would be discarded on return).
    for task in SAVE_TASK:
        datasets[new_dataset_name] = XYDataset(current_path + '/' + task + '/' + new_dataset_name, SAVE_CATEGORIES, TRANSFORMS, random_hflip=True)

    dataset = datasets[new_dataset_name]
    write_log("Datasetを作成しました：" + new_dataset_name)


# Capture clicks on the main image with ipyevents.
image_event_handler = Event(source=picture_widget, watched_events=['click'])
# Each click is handled by image_handle_event (saves an xy annotation).
image_event_handler.on_dom_event(image_handle_event)


# Image navigation / playback buttons.
play_button = ipywidgets.Button(description='▶')
stop_button = ipywidgets.Button(description='⏹')
next_image_button = ipywidgets.Button(description='>')
prev_image_button = ipywidgets.Button(description='<')
load_image_button = ipywidgets.Button(description='読込')
update_image_button = ipywidgets.Button(description='更新')
delete_image_button = ipywidgets.Button(description='削除')

save_model_button = ipywidgets.Button(description='save model')

# Model file selection (options are filled by model_list()) and save name.
load_model_widget = ipywidgets.Dropdown(options=[],description='読込モデル')
load_model_time_widget = ipywidgets.Text(description='作成日時')
save_model_name_widget = ipywidgets.Text(description='保存モデル名',value="model.pth")

# Creating a new save dataset by name.
dataset_create_button = ipywidgets.Button(description='Create dataset')
datasets_name_widget = ipywidgets.Text(description='Name')

dataset_create_button.on_click(create_dataset)

play_button.on_click(play)
stop_button.on_click(stop)

add_speed_button.on_click(save_speed)

load_model_button.on_click(load_model)
save_model_button.on_click(save_model)
load_image_button.on_click(load_dataset)
next_image_button.on_click(next_pic)
prev_image_button.on_click(before_pic)
delete_image_button.on_click(del_pic)

# Dataset selectors; LOAD_DATASETS is a placeholder that change_load_task() replaces with real folders.
load_datasets_widget = ipywidgets.Dropdown(options=LOAD_DATASETS, description='dataset', index=0)
save_datasets_widget = ipywidgets.Dropdown(options=SAVE_DATASETS, description='dataset')
# File counters for the selected save dataset, refreshed by file_count().
datasets_xy_count_widget = ipywidgets.IntText(description='XYデータ数')
datasets_speed_count_widget = ipywidgets.IntText(description='速度データ数')

def set_dataset(change):
    """Observer: (re)load the XYDataset of the newly selected save dataset into ``datasets``.

    Empty selections (None, or '' when the option list is empty) are ignored:
    building a path from them fails or points at the task folder itself.
    """
    name = change['new']
    if not name:
        return
    datasets[name] = XYDataset(current_path + '/' + save_task_widget.value + '/' + name, SAVE_CATEGORIES, TRANSFORMS, random_hflip=True)
save_datasets_widget.observe(set_dataset, names='value')

load_task_widget = ipywidgets.Dropdown(options=LOAD_TASK, description='task')
save_task_widget = ipywidgets.Dropdown(options=SAVE_TASK,  value=SAVE_TASK[0], description='task')

def change_load_task(change):
    """Observer: list the sub-folders of the chosen load task in the load-dataset dropdown."""
    global dataset, current_path, load_task_widget, load_datasets_widget
    path = os.path.join(current_path, load_task_widget.value)
    try:
        load_datasets_widget.options = get_dirs(path)
    except OSError:
        # get_dirs() raises OSError when the task folder is missing or unreadable.
        write_log(path + "が存在していません。")
        load_datasets_widget.options = []
load_task_widget.observe(change_load_task, names='value')
change_load_task(LOAD_TASK[0])

def change_save_task(change):
    """Observer: ensure the chosen save-task folder exists and list its datasets."""
    global dataset, current_path
    path = os.path.join(current_path, save_task_widget.value)
    try:
        # Create the folder in-process instead of shelling out to `mkdir -p`.
        os.makedirs(path, exist_ok=True)
        save_datasets_widget.options = get_dirs(path)
    except Exception as e:
        # Kept broad because assigning options also runs the dropdown's
        # observers; the cause is now logged instead of being hidden.
        write_log(f"エラーが発生しました: {e}")
        write_log(path + "が存在していません。")
        save_datasets_widget.options = ['']
save_task_widget.observe(change_save_task, names='value')
change_save_task(SAVE_TASK[0])

def change_save_dataset(change):
    """Observer: refresh the file counters whenever the save dataset changes."""
    file_count()


save_datasets_widget.observe(change_save_dataset, names='value')
change_save_dataset(SAVE_DATASETS[0])

def change_sleep(change):
    """Observer: apply the newly chosen playback delay (ms) immediately."""
    global sleep_time
    selected_delay = sleep_dropdown.value
    sleep_time = selected_delay


sleep_dropdown.observe(change_sleep, names='value')

def change_skip(change):
    """Observer: apply the newly chosen playback step size.

    live() reads the global ``skip`` on every iteration, so it must be
    rebound globally (a plain assignment would only create a local).
    """
    global skip
    skip = skip_dropdown.value
skip_dropdown.observe(change_skip, names='value')

def model_list(change):
    """Fill the model dropdown with "[new]" followed by ./model/*.pth in sorted order."""
    global load_model_widget
    try:
        # glob's order is arbitrary; sort so the dropdown is stable.
        # (recursive=True had no effect without a "**" in the pattern.)
        files = ["[new]"] + sorted(glob.glob('./model/*.pth'))
        load_model_widget.options = files
        load_model_time_widget.value = ""
    except Exception:
        load_model_widget.options = []
model_list("list")

def change_file(change):
    """Observer: show the timestamp of the selected model file ("" for "[new]" or a missing file).

    NOTE: os.path.getctime is the metadata-change time on Linux, not the
    creation time, despite the widget label.
    """
    global load_model_widget
    try:
        ts = os.path.getctime(load_model_widget.value)
    except (OSError, TypeError):
        # "[new]" does not exist on disk; None appears when options are empty.
        load_model_time_widget.value = ""
        return
    load_model_time_widget.value = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
load_model_widget.observe(change_file, names='value')

def update_image(change):
    """Redisplay the image at the index currently typed into ``no_widget``."""
    global load_flag, no
    load_flag, no = True, no_widget.value
    load_img(no)


update_image_button.on_click(update_image)

# Category selectors for loading/saving.
# NOTE(review): the selected values are not used by the visible handlers — confirm these are still needed.
load_category_widget = ipywidgets.Dropdown(options=LOAD_CATEGORIES, description='category')
save_category_widget = ipywidgets.Dropdown(options=SAVE_CATEGORIES, description='category')

In [30]:
import numpy as np
from functools import partial

# Thumbnail size (px) and number of thumbnails per page in the check view.
WIDTH = 80
HEIGHT = 80
SIZE = 8

# Check-view controls: page forward/backward by SIZE images, plus range/count displays.
check_image_button = ipywidgets.Button(description=f'{SIZE}個単位チェック')
check_next_images_button = ipywidgets.Button(description=f'[{SIZE}]>')
check_prev_images_button = ipywidgets.Button(description=f'<[-{SIZE}]')
check_start_index_widget = ipywidgets.IntText(description='開始位置')
check_end_index_widget = ipywidgets.IntText(description='終了位置')
check_image_count_widget = ipywidgets.IntText(description='最終画像位置')
check_update_button = ipywidgets.Button(description='更新')

# Widgets that display the thumbnails, and each thumbnail stacked under its edit button.
snapshot_widgets = []
snapshot_button_widgets = []


def edit_image(index, b):
    """Open thumbnail *index* of the current check page in the main viewer.

    Bound per thumbnail via ``partial(edit_image, i)``.

    Args:
        index: position of the thumbnail within the page (0..SIZE-1).
        b: the clicked button (unused).
    """
    global load_flag, no, check_flag
    try:
        no = check_no + index
    except NameError:
        # No page has been loaded yet (check_no is set by next/prev_images),
        # so the thumbnails are still placeholders.
        return
    load_flag = True
    check_flag = False
    # Set the counter before loading so load_img() can still correct it for
    # an out-of-range index.
    no_widget.value = no
    load_img(no)
    
# Build SIZE black placeholder thumbnails, each under an "edit" button bound to its position.
# NOTE(review): the loop variable `image` rebinds the module-level name used earlier for the PIL placeholder.
for i in range(SIZE):
    image = ipywidgets.Image(width=WIDTH, height=HEIGHT)
    edit_button = ipywidgets.Button(description="編集", layout=ipywidgets.Layout(width=f'{WIDTH}px', height=f'30px'))
    # partial() binds the thumbnail position so edit_image knows which one was pressed.
    edit_button.on_click(partial(edit_image, i))
    black_image = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
    image.value = bgr8_to_jpeg(black_image)
    snapshot_widgets.append(image)
    snapshot_button_widgets.append(VBox([edit_button,image]))

def get_load_dataset_length():
    """Return the number of xy images in the selected load dataset.

    The count is also shown in ``check_image_count_widget``.
    """
    global current_path
    xy_dir = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value, "xy")
    total = len(get_file_names(xy_dir))
    check_image_count_widget.value = total
    return total

def next_images(c):
    """Show the next page of SIZE thumbnails.

    The first press after navigating in the main viewer (``check_flag``
    False) starts the page at ``no_widget``; later presses advance by SIZE.
    A step past the last image is undone, so repeated presses at the end do
    not drift further out of range.

    Args:
        c: the clicked button (unused).
    """
    global check_no, last_no, check_flag
    check_next_images_button.disabled = True
    check_prev_images_button.disabled = True
    write_log(f"check_flag:"+str(check_flag))
    advanced = check_flag
    if check_flag == False:
        check_no = no_widget.value
        check_flag = True
    else:
        check_no += SIZE
        write_log(f"check_no: {check_no}")
    try:
        last_no = get_load_dataset_length()
        if check_no < last_no:
            # load_images() is responsible for re-enabling the buttons.
            load_images(check_no)
            return
        if advanced:
            check_no -= SIZE
    except Exception as e:
        write_log(f"エラーが発生しました: {e}")
    check_next_images_button.disabled = False
    check_prev_images_button.disabled = False
    
def prev_images(c):
    """Page the preview strip back by SIZE images.

    The first page action after editing (``check_flag`` is False) positions
    the window SIZE images before ``no_widget.value``. Later actions move
    ``check_no`` back by SIZE. The start index never goes below 0.

    Args:
        c: the clicked button, supplied by ipywidgets (unused).
    """
    global check_no, check_flag
    check_next_images_button.disabled = True
    check_prev_images_button.disabled = True
    write_log("check_flag:" + str(check_flag))
    if not check_flag:
        check_no = no_widget.value - SIZE
        check_flag = True
    else:
        check_no -= SIZE
        write_log(f"check_no: {check_no}")
    # Clamp before loading so check_no is never left negative, even on error.
    check_no = max(check_no, 0)
    try:
        # Refresh the image-count widget before loading the page.
        get_load_dataset_length()
        load_images(check_no)
    except Exception as e:
        # Previously a bare `except:` hid every error; surface it in the log.
        write_log(f"{e}")
    finally:
        # Re-enable paging however loading ended.
        check_next_images_button.disabled = False
        check_prev_images_button.disabled = False
    
def load_images(c):
    """Fill the preview strip with SIZE images from ``check_no``, overlaying model predictions.

    Each slot's image is read from the selected source dataset's ``xy``
    folder and run through ``model``. It is annotated with:

    - the predicted point (blue circle),
    - a speed bar on the right edge,
    - a ``speed:`` label.

    Slots past the end of the dataset, and slots whose image fails to
    process, show a black placeholder.

    Args:
        c: ignored. The start index always comes from the global ``check_no``;
           this function is also used directly as a button callback.
    """
    global check_no, last_no
    write_log("画像を" + str(SIZE) + "枚読込み、推論結果を付与します。")
    try:
        xy_path = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value, "xy")
        xy_filenames = get_file_names(xy_path)
        last_no = len(xy_filenames)
        check_image_count_widget.value = last_no
        write_log(f"{xy_path}のデータセットを呼び込みます。データ数(xy): {last_no}")
        # Encode the placeholder once and reuse it for every empty or failed slot.
        black_jpeg = bgr8_to_jpeg(np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8))
        for i in range(SIZE):
            now_no = check_no + i
            if now_no >= last_no:
                snapshot_widgets[i].value = black_jpeg
                continue
            try:
                img = cv2.imread(xy_filenames[now_no])
                preprocessed = preprocess(img)
                output = model(preprocessed).detach().cpu().numpy().flatten()
                # output[0], output[1] are x/y and output[3] is speed (same layout
                # as make_movie), presumably normalised to [-1, 1], hence the
                # /2 + 0.5 mapping onto pixel coordinates.
                result_x = int(IMG_WIDTH * (output[0] / 2.0 + 0.5))
                result_y = int(IMG_HEIGHT * (output[1] / 2.0 + 0.5))
                result_speed = int(IMG_HEIGHT * (output[3] / 2.0 + 0.5))
                marked_img = cv2.circle(img, (result_x, result_y), 8, (255, 0, 0), 3)
                marked_img = cv2.line(marked_img, (219, 224 - result_speed), (219, 224), (0, 140, 255), 3)
                marked_img = cv2.putText(marked_img, "speed:" + str(result_speed), (160, 215), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255))

                snapshot_widgets[i].value = bgr8_to_jpeg(marked_img)

                time.sleep(10 / 1000)
            except Exception as e:
                write_log(f"{e}")
                snapshot_widgets[i].value = black_jpeg
    except Exception as e:
        write_log(f"{e}")
    finally:
        # next_images/prev_images disable these buttons before calling here.
        # Re-enable them even when the dataset could not be listed; otherwise
        # paging would stay locked.
        check_next_images_button.disabled = False
        check_prev_images_button.disabled = False

def update_images(c):
    """Reload the preview strip from the index typed into ``check_start_index_widget``.

    Args:
        c: the clicked button, supplied by ipywidgets (unused).
    """
    global check_no
    start = check_start_index_widget.value
    check_no = start
    load_images(start)
    
# The preview window initially starts at the first image.
check_no = 0

# Connect each preview-strip button to its handler.
check_update_button.on_click(update_images)
check_next_images_button.on_click(next_images)
check_prev_images_button.on_click(prev_images)
check_image_button.on_click(load_images)

In [31]:
# Section 6 controls: output file name (without extension) and the create-video button.
movie_name_widget = ipywidgets.Text(value="run_video", description='動画名')
movie_button = ipywidgets.Button(description='動画の作成')

def make_movie(change):
    """Render an evaluation video of the model's predictions over the source dataset.

    Takes every ``skip_movie_dropdown.value``-th ``.jpg`` in the selected
    dataset's ``xy`` folder, ordered by ``extract_numbers`` on the file name,
    and runs it through ``model``. Each frame shows:

    - the predicted point as a blue circle,
    - the predicted speed as a vertical bar (clamped to 0..224 px),
    - a ``speed:`` text label.

    Frames are written to ``<current_path>/video/<name>.mp4`` at
    ``30 / skip`` fps.

    Args:
        change: the clicked button, supplied by ipywidgets (unused).
    """
    movie_name = movie_name_widget.value
    if not movie_name.strip():
        write_log("ファイル名を指定してください。")
        return
    write_log("動画を作成します。")

    video_dir = os.path.join(current_path, "video")
    # os.makedirs works on every OS, so there is no need to shell out to `mkdir -p`.
    os.makedirs(video_dir, exist_ok=True)
    output_path = os.path.join(video_dir, movie_name + ".mp4")

    # Collect the frames *before* opening the writer. That way a missing
    # dataset folder cannot leave an unreleased VideoWriter behind.
    xy_path = os.path.join(current_path, load_task_widget.value, load_datasets_widget.value, "xy")
    file_list = [
        os.path.join(xy_path, file_name)
        for file_name in os.listdir(xy_path)
        if file_name.endswith('.jpg')
    ]
    file_list.sort(key=lambda f: extract_numbers(os.path.basename(f)))

    skip_movie = skip_movie_dropdown.value
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = int(30 / skip_movie)
    outfh = cv2.VideoWriter(output_path, fourcc, fps, (224, 224))

    try:
        res_num = len(file_list)
        terminal_time = 1 / (30 / skip_movie)  # seconds of video per written frame
        current_time = 0
        total_process_time = 0
        frames_since_log = 0  # frames timed since the last progress message
        for i, file_name in enumerate(file_list):
            if i % skip_movie != 0:
                continue
            current_time += terminal_time
            img = cv2.imread(file_name)

            process_time = time.time()
            preprocessed = preprocess(img)
            pred = model(preprocessed).detach().cpu().numpy().flatten()
            result_x = int(IMG_WIDTH * (float(pred[0]) / 2.0 + 0.5))
            result_y = int(IMG_HEIGHT * (float(pred[1]) / 2.0 + 0.5))
            img = cv2.circle(img, (result_x, result_y), 8, (255, 0, 0), 3)

            # Speed bar: clamp to the 224 px frame height so the bar stays on
            # screen (previously the upper clamp was the typo 244).
            result_speed = int(IMG_WIDTH * (pred[3] / 2.0 + 0.5))
            result_speed = min(max(result_speed, 0), 224)
            img = cv2.line(img, (218, 0), (218, 224), (0, 0, 0), 5)
            img = cv2.line(img, (219, 224 - result_speed), (219, 224), (0, 140, 255), 3)
            img = cv2.putText(img, "speed:" + str(result_speed), (160, 215), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255))
            total_process_time += time.time() - process_time
            frames_since_log += 1

            if i % (skip_movie * 10) == 0:
                # Average over the frames actually timed. The first message
                # covers one frame; later ones cover ten.
                write_log(f"{current_time:.1f}秒まで完了, 推論処理平均: {total_process_time/frames_since_log*1000:.1f}ms, {int(i/skip_movie)}枚目/{int(res_num/skip_movie)}枚中を処理中")
                total_process_time = 0
                frames_since_log = 0
            outfh.write(img)
            del img
    finally:
        # Release the writer even if inference fails part-way.
        outfh.release()
        write_log("動画の出力が完了しました。")
        get_jetson_nano_memory_usage()

# Clicking "動画の作成" renders the evaluation video via make_movie.
movie_button.on_click(make_movie)

In [32]:
import subprocess
import re

# Memory read-outs filled in by get_jetson_nano_memory_usage, plus the button
# that refreshes them.
memory_button = ipywidgets.Button(description='使用メモリ量の取得')
total_memory_widget = ipywidgets.IntText(value=1, description='全メモリ')
used_memory_widget = ipywidgets.IntText(value=1, description='Useメモリ')

def get_jetson_nano_memory_usage(event=None):
    """Read RAM usage from ``tegrastats`` and show it in the memory widgets.

    Reads at most 10 lines of ``tegrastats`` output. The first line matching
    ``RAM <used>/<total>MB`` updates ``used_memory_widget`` and
    ``total_memory_widget``, and the reading is logged. Returns silently when
    ``tegrastats`` cannot be started (e.g. not running on a Jetson).

    Args:
        event: button instance when used as an ipywidgets callback (unused).
    """
    try:
        # Argument list instead of shell=True, so kill() below reaches
        # tegrastats itself rather than an intermediate shell.
        process = subprocess.Popen(
            ['tegrastats'],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # never read; avoid a pipe that can fill up
            universal_newlines=True,
        )
    except OSError:
        # tegrastats is not installed / not on PATH.
        return

    mem_usage_pattern = re.compile(r'RAM (\d+)/(\d+)MB')
    max_lines_to_read = 10
    try:
        for _ in range(max_lines_to_read):
            line = process.stdout.readline()
            if not line:
                break
            matches = mem_usage_pattern.search(line)
            if matches:
                used_memory_widget.value = int(matches.group(1))
                total_memory_widget.value = int(matches.group(2))
                write_log("使用メモリ： " + str(used_memory_widget.value) + "/" + str(total_memory_widget.value))
                return
    finally:
        # tegrastats streams forever: stop it, reap it (no zombie) and close the pipe.
        process.kill()
        process.wait()
        process.stdout.close()

# Hook up the refresh button, then take an initial memory reading when the cell runs.
memory_button.on_click(get_jetson_nano_memory_usage)
get_jetson_nano_memory_usage()

In [33]:
separator = ipywidgets.HTML('<hr style="border-color:gray;margin:10px 0"/>')
title1 = ipywidgets.HTML('<b>【1.使用する推論モデル】</b> [New]は新規モデル。')
title2 = ipywidgets.HTML('<b>【2.読込元データセット】</b> アノテーションを実施するデータセットを選択。')
title3 = ipywidgets.HTML('<b>【3.保存先データセット】</b> データセットの保存先を選択。')
title4 = ipywidgets.HTML('<b>【4.アノテーションの実施】</b> 緑◯がアノテーション, 青◯がAIでの推論。車両の走らせたい場所で、画面をクリックすると保存先データセットのxyにデータが登録されます。Speedは[速度追加]で追加します。')
title5 = ipywidgets.HTML('<b>【5.学習】</b> EPOCH指定で学習できます。')
title6 = ipywidgets.HTML('<b>【6.評価動画の作成】</b> 動画を作成します。')


def _layout_section(title, *rows):
    """Return the children of one numbered UI section.

    Every section is laid out the same way: a separator line, its heading,
    the given body rows, and then a common footer (memory read-outs with
    their refresh button, followed by ``process_widget``).
    """
    footer = [
        ipywidgets.HBox([used_memory_widget, total_memory_widget, memory_button]),
        process_widget,
    ]
    return [separator, title, *rows, *footer]


data_collection_widget = ipywidgets.VBox(
    _layout_section(
        title1,
        ipywidgets.HBox([load_model_widget, load_model_time_widget, load_model_button]),
    )
    + _layout_section(
        title2,
        ipywidgets.HBox([load_datasets_widget, load_task_widget, load_image_button]),
    )
    + _layout_section(
        title3,
        ipywidgets.HBox([save_datasets_widget, save_task_widget]),
        ipywidgets.HBox([datasets_xy_count_widget, datasets_speed_count_widget]),
        ipywidgets.HBox([Label('datasetの新規作成'), datasets_name_widget, dataset_create_button]),
    )
    + _layout_section(
        title4,
        ipywidgets.HBox([no_widget, update_image_button, delete_image_button]),
        skip_dropdown,
        save_datasets_widget,
        ipywidgets.HBox([
            picture_widget,
            ipywidgets.VBox([speed_slider, add_speed_button]),
            ipywidgets.VBox([play_button, prev_image_button, Label(f'{SIZE}個単位での処理'), check_prev_images_button]),
            ipywidgets.VBox([stop_button, next_image_button, Label(''), check_next_images_button]),
        ]),
        ipywidgets.HBox(snapshot_button_widgets),
        ipywidgets.HBox([save_datasets_widget, save_task_widget]),
        ipywidgets.HBox([datasets_xy_count_widget, datasets_speed_count_widget]),
    )
    + _layout_section(
        title5,
        ipywidgets.HBox([save_datasets_widget, save_task_widget]),
        ipywidgets.HBox([datasets_xy_count_widget, datasets_speed_count_widget]),
        ipywidgets.HBox([epochs_widget, train_button, eval_button]),
        ipywidgets.HBox([progress_widget, loss_widget]),
        ipywidgets.HBox([save_model_name_widget, save_model_button]),
    )
    + _layout_section(
        title6,
        ipywidgets.HBox([load_datasets_widget, load_task_widget]),
        ipywidgets.HBox([movie_name_widget, skip_movie_dropdown, movie_button]),
    )
)
display(data_collection_widget)

VBox(children=(HTML(value='<hr style="border-color:gray;margin:10px 0"/>'), HTML(value='<b>【1.使用する推論モデル】</b> […