In [1]:
import torch
import torchvision.transforms as transforms
import copy
from utils.utils import channel12ToImgSize, bgr8_to_jpeg, imgXYToChannel12
from utils.ImageList import ImageList

In [2]:
# Dataset selection passed to ImageList: project folder, sub-dataset and label category.
project = "0422school"
dataset = "A"
category = "choicest"

# Image size in pixels; used for the preview widgets, the coordinate conversion
# and (below) the resize step of the preprocessing pipeline.
width = 224
height = 224

# Augmentation + preprocessing pipeline handed to ImageList.
TRANSFORMS = transforms.Compose([
    # Randomly jitter brightness, contrast, saturation and hue (argument order of ColorJitter).
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    # Resize to (height, width); uses the constants above instead of a second hard-coded 224.
    transforms.Resize((height, width)),
    # Convert the image to a tensor.
    transforms.ToTensor(),
    # Normalise each of the three channels with its own mean / standard deviation.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [3]:
# Build the image collection for the configured project/category/dataset. TRANSFORMS and
# random_hflip are forwarded to ImageList (presumably random_hflip enables a random
# horizontal-flip augmentation — confirm in utils.ImageList).
image_list = ImageList(project, category, dataset, transform=TRANSFORMS, random_hflip=True)

2000 files in the project folder


In [4]:
import ipywidgets
from IPython.display import display
import threading
import time

In [5]:
# Preview of the currently selected image.
camera_widget = ipywidgets.Image(width=width, height=height)

# Read-only fields describing the loaded dataset.
dataset_widget = ipywidgets.Text(value=dataset, description='dataset', disabled=True, style={'description_width': 'initial'})
category_widget = ipywidgets.Text(value=category, description='category', disabled=True, style={'description_width': 'initial'})
count_widget = ipywidgets.IntText(description='count', disabled=True, value=len(image_list), style={'description_width': 'initial'})

# Read-only fields for the labelled target of the current image: the two channel
# values (bounded to [-1, 1]) and the pixel position of the marker.
CH1_widget = ipywidgets.BoundedFloatText(description='CH1', value=0, min=-1, max=1, disabled=True, style={'description_width': 'initial'})
CH2_widget = ipywidgets.BoundedFloatText(description='CH2', value=0, min=-1, max=1, disabled=True, style={'description_width': 'initial'})
# X is bounded by the image width and Y by the image height. The original bounded both
# by `height`, which only worked because the image is square.
view_X_widget = ipywidgets.BoundedIntText(description='X', value=0, min=0, max=width, step=1, disabled=True, style={'description_width': 'initial'})
view_Y_widget = ipywidgets.BoundedIntText(description='Y', value=0, min=0, max=height, step=1, disabled=True, style={'description_width': 'initial'})
#-------------------------------------------------------------------------------------------------------------------------------
def set_camera_widgets(image_dict):
    """Refresh the preview image and the label read-outs from one image record.

    Reads ``img_value``, ``img_XY`` (keys ``X``/``Y``) and ``standard_channel``
    (keys ``1``/``2``) from ``image_dict``. The marker is drawn at the labelled
    position and the ruler overlay is added when ``view_ruler_widget`` is ticked.
    """
    position = image_dict["img_XY"]
    img_x, img_y = position["X"], position["Y"]
    marked = image_list.draw_circle(image_value=image_dict["img_value"], img_x=img_x, img_y=img_y)
    # Optionally overlay the ruler before encoding the frame for the Image widget.
    frame = image_list.draw_ruler(marked) if view_ruler_widget.value else marked
    camera_widget.value = bgr8_to_jpeg(frame)
    view_X_widget.value, view_Y_widget.value = img_x, img_y
    channels = image_dict["standard_channel"]
    CH1_widget.value = channels["1"]
    CH2_widget.value = channels["2"]

# Play/stop toggle for the slideshow.
play_button = ipywidgets.ToggleButton(description='play', value=False, disabled=False)
# Flag polled by the background loop below; playClick updates it when the button changes.
change_bool = play_button.value


def play():
    """Background slideshow loop.

    Polls ``change_bool`` every 0.2 s. While it is true, each iteration advances
    ``image_list`` by one image and refreshes the preview, the index field, the
    filename dropdown and the prediction widgets. Never returns.
    """
    while True:
        if change_bool:
            image_dict = image_list.get_next()
            # Shared refresh helper: unlike the hand-copied code it replaces, it also
            # updates the CH1/CH2 read-outs, so they no longer go stale during playback.
            set_camera_widgets(image_dict)
            index_widget.value = image_list.get_index()
            filename_list_widget.value = image_dict["img_name"]
            set_trained_widget(image_dict)
        time.sleep(0.2)


# Daemon thread: the loop never terminates, so a non-daemon thread would keep the
# kernel process alive at shutdown.
execute_thread = threading.Thread(target=play, daemon=True)
execute_thread.start()

def playClick(change):
    """Observer for ``play_button.value``: switch playback on/off and relabel the button.

    Args:
        change: traitlets change object; ``change.new`` is the new toggle state.
    """
    global change_bool
    change_bool = change.new
    # Derive the label from the new state rather than flipping the previous label, so the
    # label stays correct even when other code has already overwritten the description.
    play_button.description = "stop" if change.new else "play"


play_button.observe(playClick, names='value')

forward_button = ipywidgets.Button(description='next')


def forward(c):
    """Click handler for the "next" button.

    Advances ``image_list`` by ``step_widget.value`` images, then refreshes the
    preview, the index field, the filename dropdown and the prediction widgets.
    The click argument ``c`` is unused.
    """
    record = image_list.get_next(step_widget.value)
    set_camera_widgets(record)
    index_widget.value = image_list.get_index()
    filename_list_widget.value = record["img_name"]
    set_trained_widget(record)


forward_button.on_click(forward)

back_button = ipywidgets.Button(description='back')


def back(c):
    """Click handler for the "back" button.

    Moves ``image_list`` by minus ``step_widget.value`` images, then refreshes the
    preview, the index field, the filename dropdown and the prediction widgets.
    The click argument ``c`` is unused.
    """
    record = image_list.get_next(-step_widget.value)
    set_camera_widgets(record)
    index_widget.value = image_list.get_index()
    filename_list_widget.value = record["img_name"]
    set_trained_widget(record)


back_button.on_click(back)

# Number of images the next/back buttons move by; bounded to [0, number of images].
step_widget = ipywidgets.BoundedIntText(value=1, min=0, max=count_widget.value, step=1)


filename_list = image_list.get_filename_list()
filename_list_widget = ipywidgets.Dropdown(
    options=filename_list,
    value=filename_list[0],
    description='img path',
    style={'description_width': 'initial'},
)


def filename_list_update(change):
    """Observer for the filename dropdown: jump to the selected file.

    Looks up the position of the chosen name in the current filename list, makes
    it the active index of ``image_list`` and refreshes the index field and preview.
    """
    names = image_list.get_filename_list()
    image_list.index = names.index(change.new)
    record = image_list.get_value()
    index_widget.value = image_list.get_index()
    set_camera_widgets(record)


filename_list_widget.observe(filename_list_update, names='value')

index_widget = ipywidgets.IntText(description='index', value=0, style={'description_width': 'initial'})


def index_change(change):
    """Observer for ``index_widget.value``: jump to the requested image.

    The requested index is wrapped into ``[0, len(image_list))`` with a modulo.
    The original check (``> len`` followed by a single subtraction) let
    ``index == len`` through and did not fully wrap values beyond ``2 * len``.
    Indexes are treated as 0-based positions, matching the dropdown handler,
    which assigns ``list.index(...)`` results to ``image_list.index``.

    Args:
        change: traitlets change object; ``change.new`` is the typed index.
    """
    total = len(image_list)
    if total == 0:
        # Nothing to show; also avoids a modulo by zero.
        return
    image_list.index = change.new % total
    image_dict = image_list.get_value()
    set_camera_widgets(image_dict)
    filename_list_widget.value = image_dict["img_name"]


index_widget.observe(index_change, names='value')


view_ruler_widget = ipywidgets.Checkbox(
    value=True,
    description="draw ruler",
    indent=True,
    disabled=False,
    style={'description_width': 'initial'},
)


def view_draw_ruler(change):
    """Observer for the "draw ruler" checkbox: redraw the preview of the current image.

    The marker is drawn at the position currently shown in the X/Y fields; the
    ruler overlay is added only when the checkbox has just been ticked.
    """
    record = image_list.get_value()
    marked = image_list.draw_circle(
        image_value=record["img_value"],
        img_x=view_X_widget.value,
        img_y=view_Y_widget.value,
    )
    frame = image_list.draw_ruler(marked) if change.new else marked
    camera_widget.value = bgr8_to_jpeg(frame)


view_ruler_widget.observe(view_draw_ruler, names='value')

In [6]:
# Initial state: show the image that image_list currently points at.
image_dict = image_list.get_value()
# Use the shared refresh helper instead of a hand-copied version of it; this also
# fills the CH1/CH2 read-outs, which the copy left at their default of 0.
set_camera_widgets(image_dict)
index_widget.value = image_list.get_index()
filename_list_widget.value = image_dict["img_name"]

In [7]:
from utils.train import set_model, save_model, load_model, preprocess

In [8]:
# Training setup.

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the notebook
# still runs on machines without one (a hard-coded 'cuda' fails there at model.to()).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build the model: RESNET_18 with output_dim=2.
model = set_model(model_name="RESNET_18", output_dim=2)
# Move the model to the selected device.
model_state = model.to(device)
# Mini-batch size used by the DataLoader in traineval.
BATCH_SIZE = 64

# Optimizer (background reading: https://zhuanlan.zhihu.com/p/32338983)
optimizer = torch.optim.Adam(model_state.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Number of epochs: taken from epochs_widget, created with the other training widgets.
# epochs_widget = ipywidgets.IntText(description='epochs', value=10)

# Train / evaluate mode flag.
train_eval = "train"
# train_eval = "eval"

In [9]:
# Keyword arguments shared by every labelled widget below.
_description_style = {'description_width': 'initial'}

# Preview of the current image with the model's predicted position marked.
train_view_widget = ipywidgets.Image(width=width, height=height)

# Training controls and read-outs.
loss_widget = ipywidgets.FloatText(description='loss', style=_description_style, disabled=True)
epochs_widget = ipywidgets.BoundedIntText(
    description='epochs', value=1, step=1, min=0, max=200, style=_description_style)
progress_widget = ipywidgets.FloatProgress(
    min=0.0, max=1.0, description='progress', disabled=True, style=_description_style)
model_path_widget = ipywidgets.Text(
    description='model path', value='road_following_model.pth', style=_description_style)

# Read-only fields for the predicted channel values (bounded to [-1, 1]) and pixel position.
trained_CH1_widget = ipywidgets.BoundedFloatText(
    description='CH1', value=0, min=-1, max=1, disabled=True, style=_description_style)
trained_CH2_widget = ipywidgets.BoundedFloatText(
    description='CH2', value=0, min=-1, max=1, disabled=True, style=_description_style)
trained_X_widget = ipywidgets.IntText(description='X', value=0, disabled=True, style=_description_style)
trained_Y_widget = ipywidgets.IntText(description='Y', value=0, disabled=True, style=_description_style)

# Ruler toggle for the prediction preview.
trained_ruler_widget = ipywidgets.Checkbox(
    value=True, description="draw ruler", indent=True, disabled=False, style=_description_style)

def set_trained_widget(image_dict):
    """Show the stored model prediction for one image record in the "trained" widgets.

    Does nothing until the module-level ``evaluated`` flag is True. Reads
    ``trained_channel`` (keys ``CH1``/``CH2``), ``trained_XY`` (keys ``X``/``Y``)
    and ``img_value`` from ``image_dict``.

    Failures are reported and the call returns normally. The original used a bare
    ``except`` followed by ``exit()``, which in an IPython kernel asks the kernel
    to shut down — far too drastic for a display refresh.
    """
    if evaluated is not True:
        return
    try:
        CH1_value, CH2_value = image_dict["trained_channel"]["CH1"], image_dict["trained_channel"]["CH2"]
        img_x, img_y = image_dict["trained_XY"]["X"], image_dict["trained_XY"]["Y"]
        image = image_list.draw_circle(image_value=image_dict["img_value"], img_x=img_x, img_y=img_y)
        if trained_ruler_widget.value:
            image = image_list.draw_ruler(image)
        train_view_widget.value = bgr8_to_jpeg(image)
        trained_X_widget.value = img_x
        trained_Y_widget.value = img_y
        trained_CH1_widget.value = CH1_value
        trained_CH2_widget.value = CH2_value
    except Exception as exc:
        # Best effort: keep the UI (and the playback thread) alive, but say what went wrong.
        print("error in set_trained_widget: {!r}".format(exc))


def trained_ruler(change):
    """Observer for ``trained_ruler_widget.value``: redraw the prediction preview.

    Draws the marker at the stored predicted position of the current image and
    adds the ruler overlay when the checkbox has just been ticked.

    Fixes over the original: the marker used ``img_x`` for both coordinates, and
    the JPEG bytes were assigned to the checkbox itself instead of
    ``train_view_widget``.

    Args:
        change: traitlets change object; ``change.new`` is the new checkbox state.
    """
    image_dict = image_list.get_value()
    if "trained_XY" not in image_dict:
        # No prediction stored for this image yet (no evaluation pass has run): nothing to draw.
        return
    img_x, img_y = image_dict["trained_XY"]["X"], image_dict["trained_XY"]["Y"]
    image = image_list.draw_circle(image_value=image_dict["img_value"], img_x=img_x, img_y=img_y)
    if change.new:
        image = image_list.draw_ruler(image)
    train_view_widget.value = bgr8_to_jpeg(image)


trained_ruler_widget.observe(trained_ruler, names='value')

In [10]:
def train_disabled(state):
    """Lock (``state=True``) or unlock (``state=False``) every interactive control.

    Whatever ``state`` is, playback that is currently running is switched off and
    the play button is relabelled "play".
    """
    controls = (
        model_path_widget,
        epochs_widget,
        train_button,
        eval_button,
        index_widget,
        filename_list_widget,
        play_button,
        forward_button,
        back_button,
        model_load_button,
        model_save_button,
        step_widget,
    )
    for control in controls:
        control.disabled = state

    if play_button.value is True:
        play_button.value = False
        play_button.description = "play"

In [11]:
model_load_button = ipywidgets.Button(description='load model')


def load_model_widget(c):
    """Click handler for "load model".

    Loads weights from the path in ``model_path_widget`` into ``model_state``,
    enables the save and evaluate buttons, then runs ``traineval(False)``.
    The click argument ``c`` is unused.
    """
    checkpoint_path = model_path_widget.value
    load_model(model_state, checkpoint_path)
    for unlocked in (model_save_button, eval_button):
        unlocked.disabled = False
    traineval(False)


model_load_button.on_click(load_model_widget)

model_save_button = ipywidgets.Button(description='save model', disabled=True)


def save_model_widget(c):
    """Click handler for "save model": pass ``model_state`` and the path typed in
    ``model_path_widget`` to ``save_model``. The click argument ``c`` is unused.
    """
    target_path = model_path_widget.value
    save_model(model_state, target_path)


model_save_button.on_click(save_model_widget)


eval_button = ipywidgets.Button(description='evaluate', disabled=True)
train_button = ipywidgets.Button(description='train')


def traineval(is_training):
    """Run the model over the whole dataset, either training it or only measuring the loss.

    Args:
        is_training: True  -> optimise the model, one pass per remaining epoch;
                     ``epochs_widget.value`` is decremented after each pass.
                     False -> a single pass with gradients disabled and no weight update.

    Side effects: stops playback, locks all controls while running, updates
    ``progress_widget`` and ``loss_widget`` after every batch, leaves the model in
    eval mode and unlocks the controls again — also when an error occurs. Errors are
    printed with a traceback instead of being silently discarded.
    """
    import traceback

    global model_state

    try:
        train_loader = torch.utils.data.DataLoader(
            image_list,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        # Stop the playback loop.
        play_button.value = False
        play_button.description = "play"
        # Lock every control while the pass is running.
        train_disabled(True)

        # Presumably gives the playback thread (which polls every 0.2 s) time to go idle.
        time.sleep(1)

        model_state = model_state.train() if is_training else model_state.eval()

        # NOTE(review): an evaluation pass is also skipped when the epoch counter is 0
        # (e.g. right after a finished training run) — confirm this is intended.
        while epochs_widget.value > 0:
            seen = 0
            weighted_loss = 0.0
            for images, category_idx, xy in train_loader:
                # Send the batch to the device.
                images = images.to(device)
                xy = xy.to(device)

                if is_training:
                    # Zero the gradients of the parameters.
                    optimizer.zero_grad()

                # Only track gradients when they are needed for backpropagation.
                with torch.set_grad_enabled(is_training):
                    outputs = model_state(images)

                    # MSE over the (x, y) pair that belongs to each sample's category.
                    loss = 0.0
                    for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):
                        loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)
                    loss /= len(category_idx)

                if is_training:
                    # Backpropagate and update the parameters.
                    loss.backward()
                    optimizer.step()

                # Progress and running loss. `loss` is a per-batch mean, so weight it by
                # the batch size before dividing by the number of samples seen; the
                # original divided the plain sum of batch means by the sample count.
                count = len(category_idx.flatten())
                seen += count
                weighted_loss += float(loss) * count
                progress_widget.value = seen / len(image_list)
                loss_widget.value = weighted_loss / seen

            if is_training:
                epochs_widget.value -= 1
            else:
                break
    except Exception:
        # Keep the notebook usable, but make the failure visible.
        traceback.print_exc()
    finally:
        model_state = model_state.eval()
        train_disabled(False)


    
# Set to True once an evaluation pass has stored predictions; set_trained_widget checks it.
evaluated = False


def evla_img_list(change):
    """Click handler for "evaluate": run the model on every image and store the predictions.

    For each record from ``image_list.iter_img_list()`` the preprocessed image is
    fed to the model, the two outputs are converted to a pixel position with
    ``channel12ToImgSize`` and both are written to ``image_list.json_dict`` under
    ``trained_channel`` / ``trained_XY``. ``evaluated`` becomes True only when the
    whole pass succeeds. The controls are unlocked and the progress bar reset even
    if inference raises; the exception itself still propagates.
    """
    global model_state, evaluated
    train_disabled(True)
    progress_widget.value = 0
    try:
        # The mode does not change between images, so switch once up front.
        model_state = model_state.eval()
        total = len(image_list)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for done, single_dict in enumerate(image_list.iter_img_list(), start=1):
                preprocessed = preprocess(single_dict["img_value"])
                local_output = model_state(preprocessed).detach().cpu().numpy().flatten()
                CH1 = local_output[0]
                CH2 = local_output[1]
                img_x, img_y = channel12ToImgSize(CH1, CH2, width, height)
                # Store plain Python floats: numpy scalars are rejected by the json module.
                # Presumably json_dict is serialised later — TODO confirm.
                image_list.json_dict[single_dict["count"]].update({
                    "trained_channel": {"CH1": float(CH1), "CH2": float(CH2)},
                    "trained_XY": {"X": img_x, "Y": img_y}})
                progress_widget.value = done / total
        evaluated = True
    finally:
        train_disabled(False)
        progress_widget.value = 0


eval_button.on_click(evla_img_list)
train_button.on_click(lambda c: traineval(is_training=True))

In [15]:
# Widget sizing. Each group sets the width on one widget's Layout object and then
# shares that same Layout object with the other widgets of the group.

# Quarter-width playback controls under the preview.
play_button.layout.width = '{}px'.format((width/4) - 3)
for follower in (forward_button, step_widget, back_button):
    follower.layout = play_button.layout

# Half-width X/Y and channel read-outs.
view_X_widget.layout.width = '{}px'.format((width/2) - 3)
for follower in (view_Y_widget, CH1_widget, CH2_widget, trained_CH1_widget, trained_CH2_widget):
    follower.layout = view_X_widget.layout

# Full-width dataset information fields.
dataset_widget.layout.width = '{}px'.format(width)
for follower in (category_widget, filename_list_widget, count_widget, index_widget):
    follower.layout = dataset_widget.layout

# Full-width training fields, each with its own layout.
model_path_widget.layout.width = '{}px'.format(width)
progress_widget.layout.width = '{}px'.format(width)

# Full-width loss / epochs / ruler checkbox.
loss_widget.layout.width = '{}px'.format(width)
for follower in (epochs_widget, view_ruler_widget):
    follower.layout = loss_widget.layout

# Half-width training buttons and predicted X/Y read-outs.
eval_button.layout.width = '{}px'.format((width/2) - 2)
for follower in (train_button, model_save_button, model_load_button, trained_X_widget, trained_Y_widget):
    follower.layout = eval_button.layout

In [16]:
# Side-by-side rows reused by both tabs.
view_XY_text_widget = ipywidgets.HBox([view_X_widget, view_Y_widget])
trained_XY_text_widget = ipywidgets.HBox([trained_X_widget, trained_Y_widget])
channel_text_widget = ipywidgets.HBox([CH1_widget, CH2_widget])
trained_channel_text_widget = ipywidgets.HBox([trained_CH1_widget, trained_CH2_widget])
play_botton_widget = ipywidgets.HBox([play_button, forward_button, step_widget, back_button])

# Narrow spacer column between the left and right halves of a tab.
middle_widget = ipywidgets.VBox([
    ipywidgets.Label("", layout=ipywidgets.Layout(width='20px', height='80px')),
])

# "View Mode": preview with its controls on the left, dataset information on the right.
view_mode_preview = ipywidgets.VBox(
    [camera_widget, play_botton_widget, view_XY_text_widget, channel_text_widget])
view_mode_info = ipywidgets.VBox([
    filename_list_widget,
    index_widget,
    count_widget,
    dataset_widget,
    category_widget,
    view_ruler_widget,
])
view_widget_single = ipywidgets.HBox([view_mode_preview, middle_widget, view_mode_info])

# "Train Mode": labelled image on the left, model prediction and training controls on the right.
# NOTE(review): trained_ruler_widget is not placed in either tab, so it cannot be toggled from the UI.
train_mode_left = ipywidgets.VBox([
    camera_widget,
    view_XY_text_widget,
    channel_text_widget,
    play_botton_widget,
    filename_list_widget,
    index_widget,
    count_widget,
    dataset_widget,
    category_widget,
    view_ruler_widget,
])
train_mode_right = ipywidgets.VBox([
    train_view_widget,
    trained_XY_text_widget,
    trained_channel_text_widget,
    loss_widget,
    epochs_widget,
    progress_widget,
    ipywidgets.HBox([train_button, eval_button]),
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button]),
])

# Assemble the tabs; a third 'Change Mode' tab is currently disabled.
children = [
    view_widget_single,
    ipywidgets.HBox([train_mode_left, middle_widget, train_mode_right]),
]
titles = ['View Mode', 'Train Mode']
tab = ipywidgets.Tab()
tab.children = children
for i, title in enumerate(titles):
    tab.set_title(i, title)
display(tab)
    
# tab.selected_index = 0
# time.sleep(3)
# tab.selected_index = 1
# time.sleep(3)
# tab.selected_index = 0
# display(tab)

Tab(children=(HBox(children=(VBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01…

注：change mode当前的双XY/CH值分割显示还有问题 待更改

In [None]:
# print(image_list.json_dict[107])