In [1]:
import torch
from torch import nn
from torchvision import transforms as trns
from torchvision.transforms import ToTensor
from PIL import Image, ImageDraw
import numpy as np
from numpy import random
import os.path as osp
import import_ipynb
from model import resnet18
from config import checkpoint_folder, label_list
from dataset import create_datasets

importing Jupyter notebook from model.ipynb
importing Jupyter notebook from config.ipynb
importing Jupyter notebook from dataset.ipynb


In [2]:
def expand(img, background=(128, 128, 128)):
    """Paste ``img`` at a random position on a larger, solid-colour canvas.

    The canvas is ``ratio`` times the size of the image, with ``ratio`` drawn
    from ``uniform(1, 2)``. The left/top offsets are drawn from
    ``uniform(0.3 * size, size * ratio - size)`` along each axis, and any draw
    whose placement would not fit inside the canvas is rejected and redrawn.

    Args:
        img: Anything ``torchvision.transforms.ToPILImage`` can convert; the
            converted image must have three dimensions (height, width,
            channels) once turned into an array.
        background: Fill colour for the canvas; it is broadcast over the
            channel axis, so its length must match the image's channel count.

    Returns:
        A ``uint8`` ``numpy.ndarray`` of shape
        ``(int(height * ratio), int(width * ratio), channels)`` containing the
        background with ``img`` pasted onto it.
    """
    # Normalise the input to an HWC uint8 array by round-tripping through PIL.
    img = np.array(trns.ToPILImage()(img)).astype(np.uint8)
    height, width, depth = img.shape

    # Rejection sampling: redraw the scale and both offsets until the image
    # fits entirely inside the canvas.
    while True:
        ratio = random.uniform(1, 2)
        canvas_h, canvas_w = int(height * ratio), int(width * ratio)
        # Horizontal offset of the image's left edge.
        left = int(random.uniform(0.3 * width, width * ratio - width))
        # Vertical offset of the image's top edge; it must be derived from the
        # height, otherwise non-square images are placed (or rejected) wrongly.
        top = int(random.uniform(0.3 * height, height * ratio - height))
        if left + width <= canvas_w and top + height <= canvas_h:
            break

    # Build the canvas and fill it with the background colour.
    expand_img = np.empty((canvas_h, canvas_w, depth), dtype=img.dtype)
    expand_img[:, :, :] = background

    # Paste the original image at the sampled position.
    expand_img[top:top + height, left:left + width] = img

    # Return the expanded canvas, not the untouched input.
    return expand_img

In [9]:
if __name__ == '__main__':
    # Rebuild the network with a 4-output head (one value per box coordinate)
    # and restore the trained weights.
    net = resnet18()
    net.linear = nn.Linear(in_features=512, out_features=4, bias=True)
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine; the model and the input below both stay on the CPU.
    state_dict = torch.load(
        osp.join(checkpoint_folder, "reg.pth"), map_location="cpu"
    )
    net.load_state_dict(state_dict)
    net.eval()
    totensor = ToTensor()

    # Load the test image. Forcing RGB keeps grayscale / RGBA files compatible
    # with the 3-value background used by expand() and the RGB outline colour.
    img_path = "./img/plane.jpg"
    with Image.open(img_path) as src:
        img = np.array(src.convert("RGB"))
    expand_img = expand(img)
    height, width = expand_img.shape[:2]

    # Run the model on a 32x32 copy of the image; no gradients are needed for
    # inference.
    inp = totensor(Image.fromarray(expand_img).resize((32, 32))).unsqueeze(0)
    with torch.no_grad():
        out = net(inp)

    # Scale the four predicted values by the image size to get pixel
    # coordinates, as plain Python floats rather than 0-dim tensors.
    xmin, ymin, xmax, ymax = out.view(-1).tolist()
    xmin, ymin, xmax, ymax = (
        xmin * width,
        ymin * height,
        xmax * width,
        ymax * height,
    )
    # Order the corners so an inverted prediction still yields a valid box;
    # newer Pillow releases reject a rectangle whose second corner precedes
    # the first.
    xmin, xmax = sorted((xmin, xmax))
    ymin, ymax = sorted((ymin, ymax))

    # Draw the predicted box on the image and display it.
    expand_img = Image.fromarray(expand_img)
    draw = ImageDraw.Draw(expand_img)
    draw.rectangle([(xmin, ymin), (xmax, ymax)], outline=(0, 255, 0), width=10)
    expand_img.show()