<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2022notebooks/2022_1210bit_line_bisection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BIT 線分二等分線の試作機

# 1. 準備作業

In [None]:
# Bootstrap cell: make the `bit` package importable, then load the common libraries.
# This cell may need to be run twice (note carried over from the original author).
%config InlineBackend.figure_format = 'retina'
try:
    import bit
except ImportError:
    # `bit` is not importable yet: install ipynbname and clone the bit repository
    # into the working directory, then retry the import.
    !pip install ipynbname --upgrade > /dev/null
    !git clone https://github.com/ShinAsakawa/bit.git > /dev/null
    import bit

# Environment values exposed by the bit package.
# presumably isColab is truthy when running on Google Colab -- confirm in bit
isColab = bit.isColab
HOME = bit.HOME

if isColab:
    # As of 2022-09-16 the preinstalled PIL is old and has trouble rendering
    # truetype fonts, so upgrade Pillow (version 9.2.0 or later is wanted).
    !pip install --upgrade Pillow

# Font lookups provided by bit; the results are stored for later cells.
fonts_jp = bit.get_notojp_fonts()
fonts_en = bit.get_notoen_fonts()

import torch
import PIL
# Show which Pillow version is actually loaded in this kernel.
print(f'PIL.__version__:{PIL.__version__}')
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

## 1.1 準備作業 続き

In [None]:
# Import japanize_matplotlib, installing it with pip on first use.
try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

# Import the helper module PIL_util; when it is not importable, download
# PIL_util.py into the working directory and retry.
try:
    import PIL_util
except ImportError:
    !wget https://ShinAsakawa.github.io/2022notebooks/PIL_util.py -O PIL_util.py
    import PIL_util

# Same pattern for bit_utils (provides torch_to_pil, plot_img_bbox and apply_nms,
# which a later cell imports from it).
try:
    import bit_utils
except ImportError:
    !wget https://ShinAsakawa.github.io/2022notebooks/bit_utils.py -O bit_utils.py
    import bit_utils

In [None]:
import os
from glob import glob
from PIL import ImageFont

# Collect every file matching '*otf' under the local 'fonts' directory.
noto_font_dir = 'fonts'
notofonts_fnames = glob(os.path.join(noto_font_dir,'*otf'))

# Key each font by its file name without directory and without anything after
# the first dot; each entry records the path it came from.
notofonts = {}
for fname in notofonts_fnames:
    fontname = fname.split('/')[-1].split('.')[0]
    notofonts[fontname] = {'fname': fname}

# Second pass: load each registered font file with PIL's default truetype settings.
for fontname, entry in notofonts.items():
    entry['data'] = ImageFont.truetype(entry['fname'])

# Register the characters: keep the symbols exposed by a BIT instance built
# from these fonts.
symbols = bit.BIT(fontdata=notofonts).symbols

## 1.2 準備作業 続き 2

In [None]:
import os
import sys
import shutil
import typing
import cv2
import glob
from tqdm.notebook import tqdm

# Colab-only setup: install pycocotools and copy torchvision's detection reference
# helpers next to the notebook so that `engine`, `utils` and `transforms` can be
# imported as plain modules by a later cell.
if isColab:
    from PIL import ImageFont
    # NOTE: this rebinds `glob` to the function, replacing the `import glob` module
    # binding made just above (only when running on Colab).
    from glob import glob

    !pip install pycocotools --quiet
    !git clone https://github.com/pytorch/vision.git
    # NOTE(review): no `cd vision` precedes this checkout, so it appears to run in the
    # notebook's working directory rather than inside the clone -- confirm which
    # revision of the reference files is actually copied below.
    !git checkout v0.3.0

    # Download TorchVision repo to use some files from references/detection
    # (using os.symlink(src, dst) instead of copying might be better)
    !cp vision/references/detection/utils.py ./
    !cp vision/references/detection/transforms.py ./
    !cp vision/references/detection/coco_eval.py ./
    !cp vision/references/detection/engine.py ./
    !cp vision/references/detection/coco_utils.py ./

    !pip install japanize_matplotlib

## 1.3 Torch ライブラリなどの準備作業

In [None]:
# ライブラリのインポート
# torchvision ライブラリ
import torch
import torchvision
from torchvision import transforms as torchtrans
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
#from torchvision.models.detection import fasterrcnn_resnet50_fpn

import os
import random
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib


# ヘルパライブラリをインポート
from engine import train_one_epoch, evaluate
import utils
import transforms as T

## 1.4 作業モデルの変更作業

In [None]:
def get_object_detection_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box-predictor head.

    See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

    The detector is created with ``pretrained=True`` (MS-COCO weights, per the
    original comment), and its ``roi_heads.box_predictor`` is replaced by a new
    ``FastRCNNPredictor`` producing ``num_classes`` outputs. The ROI heads are
    printed before and after the replacement.

    Args:
        num_classes: number of classes passed to the new predictor head.

    Returns:
        The modified detection model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    heads = detector.roi_heads

    # Width of the feature vector fed to the existing classification layer.
    n_inputs = heads.box_predictor.cls_score.in_features
    print(f'変換前 model.roi_heads:{heads}')

    # Swap the pretrained head for a new one sized for our classes.
    heads.box_predictor = FastRCNNPredictor(n_inputs, num_classes)
    print(f'変換後 model.roi_heads:{heads}')

    return detector

# One output class per registered symbol.
# NOTE(review): detection heads usually reserve one class for background --
# confirm that `symbols` accounts for it.
num_classes = len(symbols)
#num_classes = len(bit.symbols)
# Build the detector; this also prints the ROI heads before and after the head swap.
bit_model = get_object_detection_model(num_classes)
print(f'num_classes:{num_classes}, bit.symbols:{symbols}')
#print(f'num_classes:{num_classes}, bit.symbols:{bit.symbols}')
#bit_model.roi_heads

# 2. 訓練済パラメータの読み込み

訓練済パラメータを Google Drive から取得するため，Google アカウントでの認証作業が必要となる。



In [None]:
# Authenticate with Google (uses google.colab, so this cell only works on Colab)
# and download the fine-tuned checkpoint from Google Drive via PyDrive.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Select the Drive file by its ID and fetch it.
# NOTE: the first assignment is overwritten by the second one immediately, so only
# the file with ID '1csAtAOsrRv0YUTp1LIMhmckX8AG45Cu6' is downloaded.
download = drive.CreateFile({'id': '1pn9VafOaSL4OCxE-oFg_t4LH9OU2xj5F'})
download = drive.CreateFile({'id': '1csAtAOsrRv0YUTp1LIMhmckX8AG45Cu6'})
#download = drive.CreateFile({'id': '1KhP4iAP_tc28EV5fyo95pKuQrNOAX-bT'})

# Save the content under the file name that the loading cell reads.
download.GetContentFile('2022_0620fine_tuned_bit_line_bisection.cpt')

In [None]:
# 上で認証した訓練済パラメータの読み込み
# Load the fine-tuned parameters downloaded by the authentication cell above.
pretrained_fname = '2022_0620fine_tuned_bit_line_bisection.cpt'
# map_location='cpu' lets a checkpoint that was saved from a GPU session load on a
# CPU-only runtime as well; load_state_dict then copies the values into bit_model's
# own parameters on whatever device they live on.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(pretrained_fname, map_location='cpu')
# The checkpoint is a dict; the weights are stored under its 'model' key.
bit_model.load_state_dict(checkpoint['model'])

In [None]:
from bit_utils import torch_to_pil
from bit_utils import plot_img_bbox
from bit_utils import apply_nms

import matplotlib.pyplot as plt
# Pick the compute device name: 'cuda' when a CUDA device is available, else 'cpu'.
# (A later cell rebinds `device` to a torch.device object.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build a BIT instance from the registered fonts and generate line-bisection
# stimuli (presumably N images with n_lines lines each -- confirm in bit).
_bit = bit.BIT(fontdata=notofonts)
images, bboxes = _bit.make_line_bisection_task_images(N=5, n_lines=3)
# Keep the image at index 4 as the working example.
img = images[4]

In [None]:
# plt.figure(figsize=(8,8))
# plt.axis('off')
# plt.imshow(img)

In [None]:
import PIL
# `import PIL` alone does not guarantee that the Image / ImageDraw submodules are
# loaded, so import them explicitly.
from PIL import Image, ImageDraw

def draw_center_mark(img_pt:torch.Tensor=None,
                     prediction:dict=None,
                     check_mark_offset:int=6,
                     check_mark_width:int=4,
                     check_mark_color:tuple=(0,255,0),
                     title=None,
                     img:PIL.Image=None,
                    ):
    """Draw a short diagonal stroke at the center of every predicted box.

    Args:
        img_pt: image tensor; converted with ``torch_to_pil`` and used as the
            canvas only when ``img`` is not given.
        prediction: dict with a ``'boxes'`` entry iterable as
            ``(left, top, right, bottom)`` tensors.
        check_mark_offset: half-extent of the stroke in pixels; the stroke runs
            from ``(cx - offset, cy - offset)`` to ``(cx + offset, cy + offset)``.
        check_mark_width: stroke width passed to ``ImageDraw.line``.
        check_mark_color: fill passed to ``ImageDraw.line`` (callers in this
            notebook also pass color names such as ``'green'``).
        title: unused; kept for backward compatibility.
        img: PIL image to draw on. When given it is modified in place and
            ``img_pt`` is ignored.

    Returns:
        ``(img, draw)``: the image that was drawn on and its ``ImageDraw`` object.

    Raises:
        ValueError: if ``prediction`` is missing, or neither ``img`` nor
            ``img_pt`` is given.
    """
    if prediction is None:
        raise ValueError('prediction must be a dict with a "boxes" entry')

    # Identity check: `== None` would dispatch to the image's own __eq__.
    if img is None:
        if img_pt is None:
            raise ValueError('either img or img_pt must be given')
        img = torch_to_pil(img_pt)
    _draw = ImageDraw.Draw(img)

    for box in prediction['boxes']:
        # detach().cpu() keeps this working for boxes that live on a GPU or
        # still carry autograd history; .numpy() alone raises on those.
        left, top, right, bottom = box.detach().cpu().numpy()
        h_center = int((right - left)/2 + left)
        v_center = int((bottom - top)/2 + top)

        x0 = h_center - check_mark_offset
        y0 = v_center - check_mark_offset
        x1 = h_center + check_mark_offset
        y1 = v_center + check_mark_offset
        _draw.line(xy=[(x0,y0),(x1,y1)], fill=check_mark_color, width=check_mark_width, joint=None)

    return img, _draw

In [None]:
# Generate fresh stimuli, pick one at random, and run the detector on it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

N=25
images, bboxes = _bit.make_line_bisection_task_images(N=N, n_lines=3)

# Random pick (unseeded, so the chosen image differs between runs).
num = np.random.choice(N)
img = images[num]
print(bboxes)

# Preprocess: channel swap, resize to 224x224, scale to [0, 1], then HWC -> CHW.
# NOTE(review): `img` comes from the BIT generator; if it is an RGB PIL image,
# COLOR_BGR2RGB actually swaps it to BGR order -- kept as is; confirm it matches
# the preprocessing used for training.
img_rgb = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB).astype(np.float32)
# NOTE(review): the third positional parameter of cv2.resize is `dst`, not
# `interpolation`, so INTER_AREA is probably not applied here -- kept unchanged;
# confirm against training before switching to interpolation=cv2.INTER_AREA.
img_res = cv2.resize(img_rgb, (224, 224), cv2.INTER_AREA)
img_res /= 255.0
img_pt = torch.Tensor(img_res).permute(2,0,1)

# The input is moved to `device`, so the model has to live there too; otherwise
# the forward pass raises a device-mismatch error on GPU runtimes.
bit_model.to(device)
bit_model.eval()
with torch.no_grad():
    pred = bit_model([img_pt.to(device)])[0]
# Bring the outputs back to the CPU so later cells can call .numpy() on them.
pred = {key: (value.cpu() if torch.is_tensor(value) else value)
        for key, value in pred.items()}
pred

In [None]:
# Filter the raw detections with bit_utils.apply_nms using a very low IoU threshold
# (presumably boxes that overlap a kept box by more than 1% IoU are dropped --
# confirm in bit_utils).
nms_prediction = apply_nms(pred, iou_thresh=0.01)
print(nms_prediction)

In [None]:
#nms_prediction
#nms_prediction
# Show the filtered boxes on the network input; the tensor is converted from
# CHW back to HWC for plotting.
plot_img_bbox(img_pt.numpy().transpose(1,2,0),
              nms_prediction)

In [None]:
import copy

def deviate_prediction(prediction:dict=None,
                       factor:float=0.2):
    """Naive deviation model: pull each box's left edge to the right.

    For every box ``(left, top, right, bottom)`` the left edge becomes
    ``left + factor * (right - left)``; the other coordinates are unchanged.
    The input dict is not modified: a deep copy with the adjusted ``'boxes'``
    tensor is returned.

    Note: a later cell in this notebook defines another function with this same
    name (using ``left - ...``); whichever definition ran last is the one in effect.
    """
    shifted = copy.deepcopy(prediction)
    boxes = shifted['boxes'].clone()
    for row in boxes:
        # Each row is a view into `boxes`, so writing to it updates the tensor.
        x_min, _top, x_max, _bottom = row
        span = x_max - x_min
        span *= factor
        row[0] = x_min + span

    shifted['boxes'] = boxes
    return shifted

def deviate_prediction_left(prediction:dict=None,
                            factor:float=0.2):
    """Naive deviation model, leftward variant: pull each box's right edge to the left.

    For every box ``(left, top, right, bottom)`` the right edge becomes
    ``right - factor * (right - left)``, so the box center moves left by
    ``factor * width / 2``. The previous implementation used ``left - ...``
    here, which put the right edge to the left of the left edge and produced
    an inverted box.

    Args:
        prediction: dict with a ``'boxes'`` tensor whose rows are
            ``(left, top, right, bottom)``.
        factor: fraction of the box width by which the right edge is moved.

    Returns:
        A deep copy of ``prediction`` with the adjusted ``'boxes'``; the input
        is left untouched.
    """
    _prediction = copy.deepcopy(prediction)
    _boxes = _prediction['boxes'].clone()
    for _box in _boxes:
        # Each _box is a view into _boxes, so the assignment updates the tensor.
        left, top, right, bottom = _box
        _len = (right - left) * factor
        _box[2] = right - _len

    _prediction['boxes'] = _boxes
    return _prediction

# Compare the raw boxes with the deviated ones, then plot center marks.
factor = 0.4
_prediction = deviate_prediction(pred, factor=factor)
print(f'prediction:{pred["boxes"]}')
print(f'_prediction:{_prediction["boxes"]}')

# NOTE: the marks drawn here come from nms_prediction, not from the deviated
# _prediction computed above, although the title refers to the deviation model.
_img, _draw = draw_center_mark(img_pt=img_pt, prediction=nms_prediction, check_mark_color='green') #, img=img)
#_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green', img=img)
plt.figure(figsize=(5,5))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()

In [None]:
#img, img_pt, pred = make_a_prediction(n_lines=2)

#factor = 0.4
#img, _ = draw_center_mark(img_pt=img_pt, img=img, prediction=pred, check_mark_color='blue', check_mark_width=4)
#img, _draw = draw_center_mark(img_pt=img, prediction=pred, check_mark_color='green', img=img)
#_pred = deviate_prediction(pred, factor=factor)
#plt.imshow(_img)
#pred, __img, draw = make_a_stim_then_predict(n_lines=2, isDraw=False, verbose=False)
#plt.imshow(__img)
#print(f'pred:{pred}')


In [None]:
# Plot center marks for the deviated version of the NMS-filtered boxes.
#images, bboxes = _bit.make_line_bisection_task_images(N=10, n_lines=3)
#prediction, _img, draw = make_a_stim_then_predict(isDraw=False, verbose=False)

factor = 0.4
# NOTE: the result of this first call is overwritten below. No `img` is passed,
# so each call converts img_pt again (presumably into a fresh image), and the
# undeviated marks therefore do not appear in the final figure.
_img, _draw = draw_center_mark(img_pt=img_pt, prediction=nms_prediction, check_mark_color='green') #, img=images[-1])
_prediction = deviate_prediction(nms_prediction, factor=factor)
#print(f'prediction:{prediction["boxes"]}')
#print(f'_prediction:{_prediction["boxes"]}')

_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green') #, img=img)
plt.figure(figsize=(5,5))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()

In [None]:
# Same steps as the previous cell (this cell is a verbatim repeat of it): plot
# center marks for the deviated version of the NMS-filtered boxes.
#images, bboxes = _bit.make_line_bisection_task_images(N=10, n_lines=3)
#prediction, _img, draw = make_a_stim_then_predict(isDraw=False, verbose=False)

factor = 0.4
# NOTE: the result of this first call is overwritten below, so only the
# deviated marks end up in the displayed image.
_img, _draw = draw_center_mark(img_pt=img_pt, prediction=nms_prediction, check_mark_color='green') #, img=images[-1])
_prediction = deviate_prediction(nms_prediction, factor=factor)
#print(f'prediction:{prediction["boxes"]}')
#print(f'_prediction:{_prediction["boxes"]}')

_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green') #, img=img)
plt.figure(figsize=(5,5))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()

In [None]:
import copy

def deviate_prediction(prediction:dict=None,
                       factor:float=0.2):
    """Naive deviation model: push each box's left edge further to the left.

    For every box ``(left, top, right, bottom)`` the left edge becomes
    ``left - factor * (right - left)``; the other coordinates are unchanged.
    The input dict is not modified: a deep copy with the adjusted ``'boxes'``
    tensor is returned.

    Note: this redefines a function of the same name from an earlier cell
    (which used ``left + ...``); once this cell has run, this version is the
    one in effect.
    """
    shifted = copy.deepcopy(prediction)
    boxes = shifted['boxes'].clone()
    for row in boxes:
        # Each row is a view into `boxes`, so writing to it updates the tensor.
        x_min, _top, x_max, _bottom = row
        span = x_max - x_min
        span *= factor
        row[0] = x_min - span

    shifted['boxes'] = boxes
    return shifted

def deviate_prediction_left(prediction:dict=None,
                            factor:float=0.2):
    """Naive deviation model, leftward variant: pull each box's right edge to the left.

    For every box ``(left, top, right, bottom)`` the right edge becomes
    ``right - factor * (right - left)``, so the box center moves left by
    ``factor * width / 2``. The previous implementation used ``left - ...``
    here, which put the right edge to the left of the left edge and produced
    an inverted box.

    Args:
        prediction: dict with a ``'boxes'`` tensor whose rows are
            ``(left, top, right, bottom)``.
        factor: fraction of the box width by which the right edge is moved.

    Returns:
        A deep copy of ``prediction`` with the adjusted ``'boxes'``; the input
        is left untouched.
    """
    _prediction = copy.deepcopy(prediction)
    _boxes = _prediction['boxes'].clone()
    for _box in _boxes:
        # Each _box is a view into _boxes, so the assignment updates the tensor.
        left, top, right, bottom = _box
        _len = (right - left) * factor
        _box[2] = right - _len

    _prediction['boxes'] = _boxes
    return _prediction

# Apply the redefined deviate_prediction to the raw (not NMS-filtered) detections
# and draw the resulting center marks.
factor = 0.4
_prediction = deviate_prediction(pred, factor=factor)
print(f'prediction:{pred["boxes"]}')
print(f'_prediction:{_prediction["boxes"]}')

# `img` is passed, so the marks are drawn directly onto that PIL image (in place)
# and img_pt is ignored.
# NOTE(review): the boxes were predicted on the 224x224 resized input -- confirm
# that `img` has the same size, otherwise the marks are misplaced.
_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green', img=img)
#_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green', img=img)
plt.figure(figsize=(5,5))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()

In [None]:
# Generate a new batch of stimuli and show the image at index 4.
# NOTE: this rebinds `images`, `bboxes` and `img`; `pred` and `img_pt` from the
# inference cell still refer to the previously chosen image.
images, bboxes = _bit.make_line_bisection_task_images(N=5, n_lines=3)
img = images[4]
plt.imshow(img)

In [None]:
# Draw the raw and then the deviated center marks on `img` and show both stages.
# NOTE(review): when the cells are run top to bottom, `img` is the image
# regenerated in the previous cell, while `pred` was computed on an earlier
# image -- confirm this pairing is intended.
factor = 0.4
_prediction = deviate_prediction(pred, factor=factor)
print(f'prediction:{pred["boxes"]}')
#print(f'_prediction:{_pred["boxes"]}')

# `img` is passed both as img_pt and as img; draw_center_mark only uses `img`
# when it is given, and draws on it in place.
_img, _draw = draw_center_mark(img_pt=img, prediction=pred, check_mark_color='green', img=img)
#_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_pred, check_mark_color='green', img=img)
plt.figure(figsize=(7,7))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()

# Second pass on the same `img`: the deviated marks are added on top of the
# marks drawn above, so this figure shows both sets.
_img, _draw = draw_center_mark(img_pt=img, prediction=_prediction, check_mark_color='green', img=img)
#_img, _draw = draw_center_mark(img_pt=img_pt, prediction=_prediction, check_mark_color='green', img=img)
plt.figure(figsize=(7,7))
plt.title(f'素朴 偏位モデル: 偏位因子:{factor}')
plt.imshow(_img)
plt.show()