PyTorchの初期化

In [1]:
import torch
import torchvision

model = torchvision.models.alexnet(pretrained=False)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 3)

モデルの読み込み

In [2]:
model.load_state_dict(torch.load('best_model.pth'))

カメラのフォーマットを学習済みモデルのフォーマットに合せ変換

In [3]:
import cv2
import numpy as np

mean = 255.0 * np.array([0.485, 0.456, 0.406])
stdev = 255.0 * np.array([0.229, 0.224, 0.225])

normalize = torchvision.transforms.Normalize(mean, stdev)

def preprocess(camera_value):
    global device, normalize
    x = camera_value
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    x = x.transpose((2, 0, 1))
    x = torch.from_numpy(x).float()
    x = normalize(x)
    x = x.to(device)
    x = x[None, ...]
    return x

カメラ画像の取得

In [4]:
from jetcam.csi_camera import CSICamera
import ipywidgets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg

camera = CSICamera(width=224, height=224)
image = camera.read()
image_widget = ipywidgets.Image(format='jpeg')
image_widget.value = bgr8_to_jpeg(image)
display(image_widget)

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x02\x01\x0…

カメラ画像の更新

In [5]:
camera.running = True

def update_image(change):
    image = change['new']
    image_widget.value = bgr8_to_jpeg(image)
    
camera.observe(update_image, names='value')

GPU

In [11]:
device = torch.device('cuda')
model = model.to(device)

推論

In [12]:
import torch.nn.functional as F
import time
import sys
def update(change):
    global blocked_slider, robot
    x = change['new'] 
    x = preprocess(x)
    y = model(x)
    
    # we apply the `softmax` function to normalize the output vector so it sums to 1 (which makes it a probability distribution)
    y = F.softmax(y, dim=1)
    
    one_blocked = float(y.flatten()[0])
    two_blocked = float(y.flatten()[1])
    three_blocked = float(y.flatten()[2])
    sys.stdout.write("\rone=%f,two=%f, three=%f" % (one_blocked, two_blocked, three_blocked))
    sys.stdout.flush()
   
    
    time.sleep(0.001)
        
update({'new': image})  # we call the function once to intialize

one=0.031782,two=0.754780, three=0.213438

In [13]:
camera.observe(update, names='value')  # this attaches the 'update' function to the 'value' traitlet of our camera

one=0.036304,two=0.846246, three=0.117450