# データセットの収集と転移学習


In [None]:
import Jetson.GPIO as GPIO

# Settings for each supported board name reported by Jetson.GPIO:
# (message printed on detection, I2C bus number, MODE value).
# MODE is later handed to `nvpmodel -m`.
_BOARD_SETTINGS = {
    "JETSON_NX": ("Jetson Xavier NXを認識", 8, 3),
    "JETSON_XAVIER": ("Jetson AGX Xavierを認識", 8, 2),
    "JETSON_NANO": ("Jetson Nanoを認識", 1, 0),
    "JETSON_ORIN": ("Jetson AGX Orinを認識", 7, 0),
    "JETSON_ORIN_NANO": ("Jetson Orin Nanoを認識", 7, 0),
}

BOARD_NAME = GPIO.gpio_pin_data.get_data()[0]

# A board that is not in the table is left unconfigured: nothing is printed
# and I2C_BUSNUM / MODE are not defined.
if BOARD_NAME in _BOARD_SETTINGS:
    _board_message, I2C_BUSNUM, MODE = _BOARD_SETTINGS[BOARD_NAME]
    print(_board_message)

In [None]:
# Feed the sudo password on stdin (`sudo -S`) and run `nvpmodel -m` with the MODE value chosen for this board.
# NOTE(review): the password is hard-coded in plain text — confirm this is acceptable for the target device image.
!echo "jetson" | sudo -S nvpmodel -m $MODE

In [None]:
# Run `nvpmodel -q` through sudo (password fed on stdin) — presumably to print the mode that is now active; confirm against the nvpmodel documentation.
!echo "jetson" | sudo -S nvpmodel -q

In [None]:
# Run `jetson_clocks` through sudo (password fed on stdin).
# NOTE(review): presumably pins the clocks to their maximum for the selected power mode — confirm against the NVIDIA documentation.
!echo "jetson" | sudo -S jetson_clocks

### ログ用のWidget

In [None]:
import ipywidgets
from ipywidgets import Button, Layout, Textarea, HBox, VBox, Label

# Layout of the log area: fixed 100px height, automatic width.
l = Layout(
    flex='0 1 auto',
    height='100px',
    min_height='100px',
    width='auto',
)
process_widget = ipywidgets.Textarea(description='ログ', value='', layout=l)

# Number of the most recently written log entry.
process_no = 0


def write_log(msg):
    """Prepend a numbered entry to the log text area.

    Each call increments the global entry counter and puts
    ``"<number>: <msg>"`` on a new first line of ``process_widget``, so the
    newest message is always at the top.

    Args:
        msg: Text of the log entry (a string).
    """
    global process_no
    process_no += 1
    entry = str(process_no) + ": " + msg + "\n"
    process_widget.value = entry + process_widget.value

### カメラの初期化

In [None]:
from jetcam.csi_camera import CSICamera
# from jetcam.usb_camera import USBCamera

# Open the CSI camera at 224x224. The USB camera alternative is kept
# commented out below (its import is commented out as well).
camera = CSICamera(width=224, height=224)
# camera = USBCamera(width=224, height=224)

# Switch the camera on.
# NOTE(review): presumably this starts continuous capture into `camera.value` — confirm against jetcam.
camera.running = True

### データセットのタスクの定義

In [None]:
import torchvision.transforms as transforms
from xy_dataset import XYDataset
import os
import subprocess

# Task folders: TASKS[0] holds the DATASETS entries, TASKS[1] holds the
# DATASETS_TRAIN entries.
TASKS = ['interactive', 'train']

# Categories handed to every XYDataset.
CATEGORIES = ['xy', 'speed']

DATASETS = ['A', 'B', 'C', 'D']
DATASETS_TRAIN = ['RE_A', 'RE_B', 'RE_C', 'RE_D']

# Make sure the folder of every interactive dataset exists before it is
# opened. os.makedirs(..., exist_ok=True) does what `mkdir -p` did without
# spawning a process per folder.
current_path = os.getcwd()
for name in DATASETS:
    path = current_path + "/" + TASKS[0] + "/" + name + "/"
    os.makedirs(path, exist_ok=True)

# Image pipeline shared by all datasets: colour jitter, resize to 224x224,
# tensor conversion and per-channel normalisation.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Open every dataset once, keyed by its name, and print the folder used.
datasets = {}
for name in DATASETS:
    datasets[name] = XYDataset(TASKS[0] + '/' + name, CATEGORIES, TRANSFORMS, random_hflip=True)
    print(TASKS[0] + '/' + name)
for name in DATASETS_TRAIN:
    datasets[name] = XYDataset(TASKS[1] + '/' + name, CATEGORIES, TRANSFORMS, random_hflip=True)
    print(TASKS[1] + '/' + name)

## エディタ作成

In [None]:
import re
# Shared layouts, all derived from the camera preview width.
description_style = {'description_width': 'initial'}
widget_width = ipywidgets.Layout(width=str(camera.width)+'px')
# Integer division keeps the CSS value a whole pixel count ('112px' instead
# of '112.0px'), in line with widget_width_third.
widget_width_half = ipywidgets.Layout(width=f'{camera.width // 2}px')
widget_width_third = ipywidgets.Layout(width=f'{camera.width // 3}px')

def get_xy(path):
    """Split an annotated image path into directory, coordinates and id.

    The path must contain at least one directory followed by a file name of
    the form ``<x>_<y>_<uuid>.jpg``, for example
    ``interactive/A/xy/103_90_e7227980-4ca6-11ee-8ab3-d03745aec1a2.jpg``.

    Args:
        path: Path of the image file.

    Returns:
        A tuple ``(dirpath, x, y, uuid)`` where ``dirpath`` keeps its
        trailing slash and ``x`` / ``y`` are integers. When ``path`` does
        not have the expected form the result is ``('', 0, 0, '')``.
    """
    matched = re.match(r'^(.*/)(\d+)_(\d+)_(\w+-\w+-\w+-\w+-\w+)\.jpg$', path)
    if matched is None:
        return '', 0, 0, ''

    dirpath, x_text, y_text, uuid = matched.groups()
    return dirpath, int(x_text), int(y_text), uuid

### データ収集

In [None]:
import cv2
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget
import os

# The first dataset in DATASETS is the active one to begin with.
dataset = datasets[DATASETS[0]]

# Enable the edit view from the start when saved images already exist.
is_snapshot_active = len(dataset) > 0

# unobserve all callbacks from camera in case we are running this cell for second time
camera.unobserve_all()

# create image preview: camera_widget mirrors camera.value (converted by
# bgr8_to_jpeg through the dlink below); snapshot_widget is filled in by the
# click and edit callbacks.
camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)
snapshot_widget = ClickableImageWidget(width=camera.width, height=camera.height)
traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)
# Number of the image the edit view refers to (written by the save/load callbacks).
no_widget = ipywidgets.IntText(description='no', style=description_style, layout=widget_width)

# create widgets
task_widget = ipywidgets.Dropdown(options=TASKS, description='task', layout=widget_width)
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset', layout=widget_width)
category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category', layout=widget_width)
count_xy_widget = ipywidgets.IntText(description='xy count', layout=widget_width)
count_speed_widget = ipywidgets.IntText(description='speed count', layout=widget_width)

# manually update counts at initialization
count_xy_widget.value = dataset.get_count(CATEGORIES[0])
count_speed_widget.value = dataset.get_count(CATEGORIES[1])

# sets the active dataset
def set_dataset(change):
    """Make the current task/dataset selection the active dataset.

    Observer of ``dataset_widget``. Opens ``<task>/<dataset>`` as a new
    ``XYDataset``, stores it in the global ``dataset`` and refreshes both
    count widgets. Any failure (for example no dataset selected) is
    reported in the log instead of being raised.

    Args:
        change: traitlets change notification (unused; the widgets are read
            directly).
    """
    global dataset
    try:
        dataset = XYDataset(task_widget.value + '/' + dataset_widget.value, CATEGORIES, TRANSFORMS, random_hflip=True)
        count_xy_widget.value = dataset.get_count(CATEGORIES[0])
        count_speed_widget.value = dataset.get_count(CATEGORIES[1])
    except Exception:
        # `except Exception` rather than a bare `except` so KeyboardInterrupt
        # and SystemExit are not swallowed.
        write_log("データセットが存在していません。")
dataset_widget.observe(set_dataset, names='value')

# Keep the counters in sync when another category is selected.
def update_counts(change):
    """Refresh both per-category counters from the active dataset.

    Observer of ``category_widget``; ``change`` itself is not used.
    """
    counters = ((count_xy_widget, 0), (count_speed_widget, 1))
    for counter, category_index in counters:
        counter.value = dataset.get_count(CATEGORIES[category_index])
category_widget.observe(update_counts, names='value')


def change_task(change):
    """List the dataset folders of the selected task in ``dataset_widget``.

    Observer of ``task_widget``. The sub-directories of ``<cwd>/<task>/``
    become the options of the dataset dropdown, sorted by name. When the
    folder cannot be listed the problem is logged and the dropdown is
    emptied.

    Args:
        change: traitlets change notification (unused; the widget is read
            directly).
    """
    # Built before the try block so the error message can always refer to it.
    path = os.getcwd() + "/" + task_widget.value + "/"
    try:
        dirs = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
    except OSError:
        write_log(path + "が存在していません。")
        dirs = []
    dataset_widget.options = dirs
task_widget.observe(change_task, names='value')
change_task(TASKS[0])

def save_snapshot(_, content, msg):
    """Save the current camera frame, labelled with the clicked position.

    Message handler of ``camera_widget``. On a click event the frame is
    stored in the active dataset under the selected category together with
    the clicked pixel coordinates; the same frame, with a circle drawn at
    that position, is shown in ``snapshot_widget`` and the counters and log
    are updated.

    Args:
        _: The widget that sent the message (unused).
        content: Message payload; ``content['event']`` is the event type and
            ``content['eventData']`` carries ``offsetX`` / ``offsetY``.
        msg: Unused.
    """
    global is_snapshot_active
    if content['event'] != 'click':
        return

    data = content['eventData']
    # Clicked position in pixel coordinates.
    x = data['offsetX']
    y = data['offsetY']

    # Read the camera once: reading camera.value separately for the file and
    # for the preview could pick up two different frames.
    frame = camera.value
    snapshot = frame.copy()

    # save to disk
    dataset.save_entry(category_widget.value, frame, x, y)

    # display saved snapshot
    snapshot = cv2.circle(snapshot, (x, y), 8, (0, 255, 0), 3)
    snapshot_widget.value = bgr8_to_jpeg(snapshot)
    count_xy_widget.value = dataset.get_count(CATEGORIES[0])
    write_log("count:"+str(count_xy_widget.value))
    count_speed_widget.value = dataset.get_count(CATEGORIES[1])
    write_log("xy (" + str(x) + "," + str(y) + ")のポイントでデータを登録しました。")
    # From now on the snapshot widget accepts edit clicks.
    is_snapshot_active = True
    no_widget.value = len(dataset)

camera_widget.on_msg(save_snapshot)

# Layout: live preview and snapshot side by side on top; below them the
# dataset selector with both counters, next to the task selector.
data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget, snapshot_widget]),
    ipywidgets.HBox([ipywidgets.VBox([dataset_widget,count_xy_widget,count_speed_widget]),task_widget]),
])

display(data_collection_widget)

## スナップショットウィジェットで座標を編集できるようにします。

In [None]:
def _show_img(no):
    """Show dataset image number ``no`` (1-based) with its stored marker.

    Writes ``no`` to ``no_widget``, reads the image named by the matching
    annotation, draws a circle at the (x, y) encoded in its file name and
    puts the result into ``snapshot_widget``. An unreadable file is reported
    in the log instead of raising.
    """
    no_widget.value = no
    name = dataset.annotations[no - 1]['image_path']
    write_log(str(no) + "枚目の" + name + "を読込みます。")

    dirpath, x, y, uuid = get_xy(name)
    write_log("x,y,name: {},{}, {}".format(x,y,name))
    img = cv2.imread(name)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        write_log(str(no) + "枚目の" + name + "を読込めませんでした。")
        return
    marked_img = cv2.circle(img.copy(), (int(x), int(y)), 8, (0, 255, 0), 3)
    snapshot_widget.value = bgr8_to_jpeg(marked_img)
    write_log(str(no) + "枚目の" + name + f"を読込みました。({x}, {y})")


def load_img(no):
    """Show image number ``no``, wrapping around at both ends.

    Args:
        no: File number from 1 to n. A value above n wraps to the first
            image, a value below 1 wraps to the last one.
    """
    if len(dataset) == 0:
        no_widget.value = 0
        write_log("データファイルが存在しません。")
        return
    if no > len(dataset):
        no = 1
    if no < 1:
        no = len(dataset)

    _show_img(no)


def delete_load_img(no):
    """Show image number ``no`` after a deletion, clamping instead of wrapping.

    Unlike ``load_img``, a number past the end selects the last image, so
    that deleting the last file shows the new last file rather than the
    first one. The xy counter is set to the dataset length first.

    Args:
        no: File number from 1 to n; values outside that range are clamped.
    """
    count_xy_widget.value = len(dataset)

    if len(dataset) == 0:
        no_widget.value = 0
        write_log("データファイルが存在しません。")
        return
    if no >= len(dataset):
        no = len(dataset)
    if no < 1:
        no = 1

    _show_img(no)


def delete_img(no):
    """Delete the image file of dataset entry ``no`` and refresh the dataset.

    Args:
        no: File number from 1 to n; values outside that range are clamped
            to the nearest valid number.
    """
    total = len(dataset)
    if total == 0:
        no_widget.value = 0
        write_log("データファイルが存在しません。")
        return

    # Clamp into 1..total.
    no = max(1, min(no, total))
    no_widget.value = no

    image_path = dataset.annotations[no - 1]['image_path']
    label = str(no) + "枚目の" + image_path
    write_log(label + "を削除します。")

    # Remove the file if it is still there; report it otherwise.
    if os.path.exists(image_path):
        os.remove(image_path)
        write_log(label + "を削除しました。")
    else:
        write_log(label + "は存在しません。")
    dataset.refresh()

def delete_pic(c):
    """Button handler: delete the image shown and display its neighbour.

    Args:
        c: The clicked button (unused).
    """
    global is_snapshot_active
    is_snapshot_active = True
    current = no_widget.value
    delete_img(current)
    delete_load_img(current)

def prev_pic(c):
    """Button handler: show the image before the current one.

    Args:
        c: The clicked button (unused).
    """
    global is_snapshot_active
    is_snapshot_active = True
    load_img(int(no_widget.value) - 1)

def next_pic(c):
    """Button handler: show the image after the current one.

    Args:
        c: The clicked button (unused).
    """
    global is_snapshot_active
    is_snapshot_active = True
    load_img(int(no_widget.value) + 1)

def save_edit(_, content, msg):
    """Re-label the displayed image with the position clicked in the snapshot.

    Message handler of ``snapshot_widget``. The coordinates are stored in
    the file name, so the edit renames the image file, redraws the marker
    and refreshes the dataset. Nothing happens unless the event is a click
    and the snapshot view is active. Problems are written to the log.

    Args:
        _: The widget that sent the message (unused).
        content: Message payload; ``content['event']`` is the event type and
            ``content['eventData']`` carries ``offsetX`` / ``offsetY``.
        msg: Unused.
    """
    if content['event'] != 'click' or not is_snapshot_active:
        return
    try:
        data = content['eventData']
        x = data['offsetX']
        y = data['offsetY']

        # no_widget can be edited by hand; without this check 0 would
        # silently address the last image through index -1.
        no = no_widget.value
        if not 1 <= no <= len(dataset.annotations):
            write_log(f"{no}枚目のデータファイルは存在しません。")
            return

        old_file_path = dataset.annotations[no - 1]['image_path']
        dirpath, old_x, old_y, uuid = get_xy(old_file_path)
        if not uuid:
            # get_xy could not parse the name; renaming would move the file
            # to a meaningless path.
            write_log("ファイル名から座標を取得できません: {}".format(old_file_path))
            return
        write_log("old_file_path: {}".format(old_file_path))
        new_file_path = '%s%03d_%03d_%s' % (dirpath, x, y, uuid) + '.jpg'
        write_log("new_file_path: {}".format(new_file_path))
        os.rename(old_file_path, new_file_path)

        # display saved remarked_img
        remarked_img = cv2.imread(new_file_path)
        remarked_img = cv2.circle(remarked_img, (int(x), int(y)), 8, (0, 255, 0), 3)
        snapshot_widget.value = bgr8_to_jpeg(remarked_img)
        dataset.refresh()
        write_log(f"新しい座標で保存しました。({x}, {y})")
    except Exception as e:
        write_log(f"{e}")

snapshot_widget.on_msg(save_edit)
# prev / next / delete buttons, each using the one-third-width layout.
prev_pic_button = ipywidgets.Button(description='prev', layout=widget_width_third)
next_pic_button = ipywidgets.Button(description='next', layout=widget_width_third)
delete_pic_button = ipywidgets.Button(description='delete', layout=widget_width_third)

# Connect each button to its handler.
prev_pic_button.on_click(prev_pic)
next_pic_button.on_click(next_pic)
delete_pic_button.on_click(delete_pic)


# When saved images already exist, show the first one in snapshot_widget.
if is_snapshot_active:
    load_img(1)


In [None]:
# Stack the snapshot widget, its caption, the number of the file on display
# and the file-operation buttons vertically.
vb_snapshot_widget = ipywidgets.VBox([
    snapshot_widget,
    ipywidgets.Label('edit data'),
    no_widget,
    ipywidgets.HBox([prev_pic_button, delete_pic_button, next_pic_button])],
    layout=ipywidgets.Layout(align_items='center')
)

### 転移学習用の学習済みモデルを読み込み

In [None]:
import subprocess
import datetime
import os
import glob

# Model to load, its timestamp label, and the file name used when saving.
model_load_widget = ipywidgets.Dropdown(options=[],description='読込モデル')
model_load_time_widget = ipywidgets.Label(description='作成日時：')
model_save_name_widget = ipywidgets.Text(description='保存モデル名',value="model.pth")

def model_list(change):
    """Fill the load dropdown with the ``.pth`` files found in ``./model``.

    The placeholder entry ``"[new]"`` always comes first, followed by the
    model files in sorted order; the timestamp label is reset. On failure
    the dropdown is emptied and the error is logged.

    Args:
        change: Unused.
    """
    try:
        # sorted() gives a stable order; glob.glob returns files in
        # arbitrary order.
        files = sorted(glob.glob('./model/*.pth'))
        files.insert(0,"[new]")
        model_load_widget.options = files
        model_load_time_widget.value = '作成日時：'
    except Exception as e:
        write_log(f"{e}")
        model_load_widget.options = []
model_list("list")

def change_file(change):
    """Show the timestamp of the model file selected in the load dropdown.

    Observer of ``model_load_widget``. When the selection is not an existing
    file (the ``"[new]"`` placeholder, or nothing selected) the label is
    reset to its empty form.

    Args:
        change: traitlets change notification (unused; the widget is read
            directly).
    """
    try:
        file = model_load_widget.value
        # NOTE(review): on Linux getctime is the last metadata change, not
        # the creation time the label suggests — confirm whether getmtime
        # would be the better choice.
        ts = os.path.getctime(file)
        d = datetime.datetime.fromtimestamp(ts)
        s = d.strftime('%Y-%m-%d %H:%M:%S')
        model_load_time_widget.value = f'作成日時：{s}'
    except (OSError, TypeError, ValueError):
        # OSError: not an existing file; TypeError: no selection (None).
        model_load_time_widget.value = '作成日時：'
model_load_widget.observe(change_file, names='value')

In [None]:
import torch
import torchvision
import os
import subprocess

# Device the model and the training tensors are moved to.
device = torch.device('cuda')
# Size of the model's output layer: one (x, y) pair per dataset category.
output_dim = 2 * len(dataset.categories)  # x, y coordinate for each category

def pretrained_model():
    """Build the network through torchvision's ``pretrained=True`` argument.

    Returns:
        A ResNet-18 created with ``pretrained=True`` whose final fully
        connected layer is replaced by a new ``Linear(512, output_dim)``.
    """
    # Alternative backbones: swap the two RESNET 18 lines below for one of
    # these pairs.
    #
    # ALEXNET
    #   model = torchvision.models.alexnet(pretrained=True)
    #   model.classifier[-1] = torch.nn.Linear(4096, output_dim)
    # SQUEEZENET
    #   model = torchvision.models.squeezenet1_1(pretrained=True)
    #   model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
    #   model.num_classes = len(dataset.categories)
    # RESNET 34
    #   model = torchvision.models.resnet34(pretrained=True)
    #   model.fc = torch.nn.Linear(512, output_dim)
    # DENSENET 121
    #   model = torchvision.models.densenet121(pretrained=True)
    #   model.classifier = torch.nn.Linear(model.classifier.in_features, output_dim)

    # RESNET 18
    network = torchvision.models.resnet18(pretrained=True)
    network.fc = torch.nn.Linear(512, output_dim)
    return network

def weights_model():
    """Build the network through torchvision's ``weights=`` argument.

    Returns:
        A ResNet-18 created with ``ResNet18_Weights.DEFAULT`` whose final
        fully connected layer is replaced by a new
        ``Linear(512, output_dim)``.
    """
    # Alternative backbones: swap the RESNET 18 lines below for one of these.
    #
    # ALEXNET
    #   model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
    #   model.classifier[-1] = torch.nn.Linear(4096, output_dim)
    # SQUEEZENET
    #   model = torchvision.models.squeezenet1_1(weights=torchvision.models.SqueezeNet1_1_Weights.DEFAULT)
    #   model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
    #   model.num_classes = len(dataset.categories)
    # RESNET 34
    #   model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT)
    #   model.fc = torch.nn.Linear(512, output_dim)
    # DENSENET 121
    #   model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT)
    #   model.classifier = torch.nn.Linear(model.classifier.in_features, output_dim)

    # RESNET 18
    default_weights = torchvision.models.ResNet18_Weights.DEFAULT
    network = torchvision.models.resnet18(weights=default_weights)
    network.fc = torch.nn.Linear(512, output_dim)
    return network

# Installed torchvision version string.
version_str = torchvision.__version__

# Only major and minor decide which API to use, so the patch part is not
# required to be present.
match = re.match(r'(\d+)\.(\d+)', version_str)
if match:
    major, minor = map(int, match.groups())
    # From 0.13 on `pretrained` is deprecated in favour of `weights=`:
    # https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/
    use_weights_api = (major, minor) >= (0, 13)
else:
    write_log("Unable to parse torchvision version")
    # Fall back to feature detection so `model` is always defined.
    use_weights_api = hasattr(torchvision.models, 'ResNet18_Weights')

if use_weights_api:
    model = weights_model()
else:
    model = pretrained_model()

model = model.to(device)

model_save_button = ipywidgets.Button(description='save model', layout=widget_width_half)
model_load_button = ipywidgets.Button(description='load model', layout=widget_width_half)

def load_model(c):
    """Button handler: load the selected state dict into ``model``.

    The ``"[new]"`` entry of the dropdown is a placeholder, not a file, so
    it is rejected with a log message, as is an empty selection. Load
    failures are logged instead of being raised inside the widget callback.

    Args:
        c: The clicked button (unused).
    """
    path = model_load_widget.value
    if not path or path == "[new]":
        write_log("読み込むモデルファイルを選択してください。")
        return
    try:
        # map_location puts the tensors on the device the model lives on,
        # whatever device they were saved from.
        # NOTE: torch.load unpickles the file — only load trusted models.
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:
        write_log(f"{path}の読み込みに失敗しました: {e}")
        return
    write_log(path + "を読み込みました。evaluateボタンを1回実行してください。")
model_load_button.on_click(load_model)
    
def save_model(c):
    """Button handler: save the model's state dict under ``./model/``.

    The file name comes from ``model_save_name_widget``; the folder is
    created when missing.

    Args:
        c: The clicked button (unused).
    """
    path = "./model/"
    # Same effect as `mkdir -p`, without spawning a process.
    os.makedirs(path, exist_ok=True)
    target = path + model_save_name_widget.value
    torch.save(model.state_dict(), target)
    write_log("学習結果を" + target + "に保存しました。")
model_save_button.on_click(save_model)

### ライブにデータセット作成と学習を実行

In [None]:
import threading
from utils import preprocess
import torch.nn.functional as F

# stop/live switch and the image that shows the prediction.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop', style=description_style)
state_widget.style.button_width='50px'
prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)

def live(state_widget, model, camera, prediction_widget):
    """Show the model's prediction on the camera image until stopped.

    Loops while ``state_widget.value`` is ``'live'``: takes the current
    frame, runs the model, picks the output pair of the selected category,
    maps it to pixel coordinates and shows the frame with a circle at that
    point.

    Args:
        state_widget: Toggle whose value controls the loop.
        model: Network called on the preprocessed frame.
        camera: Frame source (``value``, ``width``, ``height``).
        prediction_widget: Image widget that receives the JPEG result.
    """
    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        # Pure inference: no autograd graph is needed, which saves memory
        # and time on every frame.
        with torch.no_grad():
            output = model(preprocessed).cpu().numpy().flatten()
        category_index = dataset.categories.index(category_widget.value)
        x = output[2 * category_index]
        y = output[2 * category_index + 1]

        # Scale by the frame size: an output of -1 maps to 0 and +1 to the
        # full width/height.
        x = int(camera.width * (x / 2.0 + 0.5))
        y = int(camera.height * (y / 2.0 + 0.5))

        prediction = image.copy()
        prediction = cv2.circle(prediction, (x, y), 8, (255, 0, 0), 3)
        prediction_widget.value = bgr8_to_jpeg(prediction)
            
def start_live(change):
    """Start the prediction loop in a thread when the state becomes 'live'.

    Observer of ``state_widget``.

    Args:
        change: traitlets change notification; ``change['new']`` is the
            newly selected state.
    """
    if change['new'] != 'live':
        return
    worker = threading.Thread(
        target=live,
        args=(state_widget, model, camera, prediction_widget),
    )
    worker.start()

state_widget.observe(start_live, names='value')

# Prediction image with the stop/live toggle underneath.
live_execution_widget = ipywidgets.VBox([
    prediction_widget,
    state_widget
])

display(live_execution_widget)

In [None]:
import time

# Number of samples per batch of the training/evaluation loader.
BATCH_SIZE = 8

# Adam with its default settings; the SGD alternative is kept commented out.
optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Controls and read-outs of the training panel.
epochs_widget = ipywidgets.IntText(description='epochs', value=1, layout=widget_width)
eval_button = ipywidgets.Button(description='evaluate', layout=widget_width_half)
train_button = ipywidgets.Button(description='train', layout=widget_width_half)
loss_widget = ipywidgets.FloatText(description='loss', layout=widget_width)
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress', layout=widget_width)

def _batch_loss(outputs, category_idx, xy):
    """Mean squared error over the (x, y) pair of each sample's category.

    For sample ``b`` with category ``c`` only ``outputs[b][2*c:2*c+2]`` is
    compared with ``xy[b]``; the per-sample means are averaged over the
    batch.
    """
    loss = 0.0
    for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
        loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
    return loss / len(category_idx)


def _run_epoch(loader, is_training):
    """Run one pass over ``loader``, updating the progress and loss widgets."""
    seen = 0
    sum_loss = 0.0
    for images, category_idx, xy in loader:
        # send data to device
        images = images.to(device)
        xy = xy.to(device)

        if is_training:
            # zero gradients of parameters
            optimizer.zero_grad()

        # Gradients are only tracked while training; evaluation does not
        # need the autograd graph.
        with torch.set_grad_enabled(is_training):
            outputs = model(images)
            loss = _batch_loss(outputs, category_idx, xy)

        if is_training:
            # backpropagate, then adjust the parameters
            loss.backward()
            optimizer.step()

        # increment progress
        count = len(category_idx.flatten())
        seen += count
        # `loss` is a per-sample mean of the batch, so weight it by the
        # batch size before dividing by the samples seen; otherwise the
        # displayed value is too small by roughly the batch size.
        sum_loss += float(loss) * count
        progress_widget.value = seen / len(dataset)
        loss_widget.value = sum_loss / seen


def train_eval(is_training):
    """Train or evaluate the model on the active dataset.

    Stops the live view, disables both buttons and then either trains for
    as many epochs as ``epochs_widget`` holds (counting it down to 0) or
    runs a single evaluation pass. Evaluation does not depend on the epoch
    counter, so it still works after training has counted it down.
    Afterwards the model is put back into eval mode, the buttons are
    re-enabled and the live view is switched on.

    Args:
        is_training: True to train, False to evaluate.
    """
    global model

    try:
        if is_training:
            write_log("学習を開始します(50枚で10秒/1epoch, 100枚で20秒/1epoch,150枚で30秒/1epoch)")
        else:
            write_log("評価を開始します。")
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        # Give a running live loop time to notice the 'stop' state.
        time.sleep(1)
        if is_training:
            model = model.train()
        else:
            model = model.eval()
        total_start_time = time.time()
        if is_training:
            epoch_count = 1
            while epochs_widget.value > 0:
                start_time = time.time()
                _run_epoch(train_loader, is_training)
                end_time = time.time() - start_time
                write_log(str(epoch_count)+"epoch目が完了(処理時間:" + str(round(end_time,2)) + "秒)")
                epochs_widget.value = epochs_widget.value - 1
                epoch_count += 1
        else:
            _run_epoch(train_loader, is_training)
        total_end_time = time.time() - total_start_time
        if is_training:
            write_log("学習を完了しました:(処理時間合計:" + str(round(total_end_time,2)) + "秒)")
        else:
            write_log("評価を完了しました。")
    except Exception as e:
        # Report what went wrong instead of a bare "Error".
        write_log(f"Error: {e}")
    finally:
        # Always leave the UI usable, whatever happened above.
        model = model.eval()

        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'

train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
# Training panel: epoch count, progress bar, loss read-out, then the
# train/evaluate buttons side by side.
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    ipywidgets.HBox([train_button, eval_button])
])

display(train_eval_widget)

### ファイルの読込保存

### 全部まとめて表示

以下のウィジェットは，複数クラスのx, yデータセットのラベル付けに使用できます． 1つの画像に対して，各クラスのインスタンスを1つだけラベリングすることをサポートしていますが（例：犬は1つだけ），1つの画像に対して複数のクラス（例：犬，猫，馬）をラベリングすることも可能です．

左上の画像をクリックすると、その時点のカメラ画像が、クリックした位置の座標とともに、``category``のデータとして``dataset``に保存されます。

| Widget | Description |
|--------|-------------|
| dataset | datasetを選択する |
| category | categoryを選択する |
| epochs | 転移学習のためのエポック数の数値を設定する |
| train | エポック数で指定された回数だけ選択されたデータセットを学習する  |
| evaluate | 選択したデータセットの1エポックあたりの精度を評価する |
| model path | モデルのパス名を指定 |
| load | model pathで指定した、学習済みモデルを読込 |
| save | model pathで指定した、学習済みモデルを保存 |
| stop | ライブデモを停止する |
| live | ライブデモを起動する |

In [None]:
# Stack the camera widget, its caption and the selectors with the file
# counts vertically.
vb_data_collection_widget = ipywidgets.VBox([
        ipywidgets.HBox([camera_widget]),
        ipywidgets.Label('click to collect data'),
        ipywidgets.VBox([task_widget,dataset_widget,count_xy_widget,count_speed_widget]),
    ], layout=ipywidgets.Layout(align_items='center'))


all_widget = ipywidgets.VBox([
    # Place the three panels side by side.

    ipywidgets.HBox([vb_data_collection_widget, 
                     vb_snapshot_widget,
                     live_execution_widget]), 
    train_eval_widget,
    ipywidgets.HBox([model_load_widget,model_load_time_widget,model_load_button]),
    ipywidgets.HBox([model_save_name_widget,model_save_button]),
    process_widget,
])

display(all_widget)

### カメラの終了処理(必須)

次のNotebookに進む際に必ずカメラの終了処理を実行してください。<br>
エラーがでますが、エラーがでた場合でも、処理が正しく実行されています。

In [None]:
# Switch the camera off, then release its `cap` object.
# NOTE(review): presumably `cap` is the underlying OpenCV capture handle — confirm against jetcam.
camera.running = False
camera.cap.release()