# 简易的从者数据库构建

## 环境

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize

### 0. 必要函数定义

In [2]:
# Define the difference between two images
def image_diff(im_1, im_2, mask=None):
    """
    Mean absolute element-wise difference between two images.

    Variables
    ---------
    
    im_1, im_2: np.array
        Two images that need to be compared
    
    mask: np.array
        Mask on two images
        `Zero' or `False' value in mask means that we don't compare `im_1' and `im_2' at these pixels
        If omitted, every element is compared
    
    Returns
    -------
    
    Sum of |im_1 - im_2| over the unmasked elements divided by the number
    of unmasked elements
    
    Raises
    ------
    
    ValueError
        If `im_1' and `im_2' do not have exactly the same shape
    
    Note
    ----
    
    Shape of `im_1', `im_2', `mask' should be the same
    (the shape of `mask' itself is not validated here)
    A mask that selects nothing makes the denominator zero
    """
    # Subtracting alone is not a sufficient check: NumPy would silently
    # broadcast compatible-but-different shapes, e.g. (2, 1) against (2, 2).
    if im_1.shape != im_2.shape:
        raise ValueError(
            "im_1 and im_2 must have the same shape, got {} and {}".format(
                im_1.shape, im_2.shape))
    if mask is None:
        mask = np.ones(im_1.shape)
    else:
        # Any positive value counts as "compare this element"
        mask = (mask > 0)
    return (np.abs((im_1 - im_2) * mask)).sum() / mask.sum()

In [3]:
# Find the indices of the local minima within the parts of a line that lie below a threshold
def find_lmin(line, thresh=None):
    """
    Return the indices of local minima of `line' restricted to the region
    strictly below `thresh'.
    
    Variables
    ---------
    
    line: np.array
        the line in which we want to find the local minima
    
    thresh: float
        only the parts of the line that are below `thresh' are considered;
        defaults to the maximum of `line'
    
    Returns
    -------
    
    list of int, in increasing order
    
    Note
    ----
    
    All values in `line' should be larger than zero
    as well as `line' should be float and be acceptable for np.nan
    
    The first and last positions are never reported, and a position is only
    reported when both of its neighbours are below `thresh' as well
    (comparisons against the nan placeholders are always False)
    """
    cutoff = np.max(line) if thresh is None else thresh
    # 1.0 where the line is below the cutoff, nan everywhere else
    below = np.array(line < cutoff, dtype=float)
    below[below == 0.] = np.nan
    masked = line * below
    return [
        k for k in range(1, len(masked) - 1)
        if masked[k] <= masked[k-1] and masked[k] <= masked[k+1]
    ]

## 1. 读取数据

这里不使用从右上角的编号读取数据的做法

从者数据都是通过截图的时候让第一排尽量靠上，用从者坐标的像素数与图片的编号确定从者编号

从者所在的横向的像素：67, 255, 442, 630, 817, 1005

从者所在的纵向像素分割线：500

截图区域：以匹配到的从者标识（servant indicator）左上角为基准，从其右上方 1 像素位置开始，向右上方截取 (156\*130)（宽\*高）大小的区域

In [4]:
# Everything that does not depend on the individual screenshot is prepared
# once here instead of being rebuilt for each of the 14 screenshots.
# 1. read servant indicator (one template image per file)
im_servant_indicator = []
for indicator_name in ("gold", "silver", "bronze"):
    indicator_path = "database/servant_indicator/" + indicator_name + ".png"
    im_servant_indicator.append(plt.imread(indicator_path)[:,:,:3])
# Only columns [0, 45) and [111, end) of the template take part in the
# comparison; the columns in between are masked out.
im_servant_indicator_mask = np.zeros(im_servant_indicator[0].shape)
im_servant_indicator_mask[:,:45,:] += 1
im_servant_indicator_mask[:,111:,:] += 1
# Left edge (in pixels) of the six columns of the box view
col_estimate = np.array([67, 255, 442, 630, 817, 1005])
# Candidate top rows of the 19-pixel-high indicator window
row_start = 330
row_index = range(row_start, 720-19)

for img_index in range(14):
    img_index_with_0 = "{:02d}".format(img_index+1)
    # 2. read box picture (drop a possible alpha channel)
    img_box_path = "database/my_box_20180715/" + img_index_with_0 + ".png"
    img_box = plt.imread(img_box_path)[:,:,:3]
    # Only height/width == 9/16 screenshots are accepted
    if img_box.shape[0] / img_box.shape[1] != 0.5625:
        raise ValueError("Please import figures correctly!")
    img_box = resize(img_box, (720,1280), mode="reflect")
    # 3. locate servant
    servant_location = []
    for i in col_estimate:
        for rarity in range(3):
            # Masked difference between the template and the window starting at every candidate row
            col_diff = np.array([image_diff(img_box[j:j+19, i:i+158, :], im_servant_indicator[rarity], \
                                            im_servant_indicator_mask) for j in row_index ])
            col_lmin = find_lmin(col_diff, 0.07)
            for j in col_lmin:
                # find_lmin returns positions inside row_index; shift back to image rows
                servant_location.append([i, j+row_start])
    # 4. dump servant figure
    for loc in servant_location:
        # 12 servants per screenshot: 6 columns, a hit below row 500 is counted as the second row
        servant_index = img_index*12 + col_estimate.searchsorted(loc[0]) + 1
        if loc[1] > 500: servant_index += 6
        # Crop 130 rows x 156 columns ending one row above the match
        plt.imsave("database/servant_database/" + "{:03d}".format(servant_index) + ".png", \
                   img_box[loc[1]-131:loc[1]-1, loc[0]+1:loc[0]+157, :])