In [None]:
%matplotlib inline
import os
import random
from glob import glob

import cv2
import numpy as np
from matplotlib import pyplot as plt
from numpy.linalg import inv
from PIL import Image

In [None]:
def get_pic_list(dir):
    """Return the paths of all ``.jpg`` files directly inside `dir`, sorted.

    ``glob`` yields matches in a filesystem-dependent order; sorting makes the
    list, and therefore any seeded random sampling from it, reproducible
    across machines and runs. Returns an empty list when nothing matches.
    """
    return sorted(glob(os.path.join(dir, "*.jpg")))

In [None]:
def get_random_pic(img_list):
    """Pick a uniformly random position in ``range(len(img_list))``.

    Draws from the module-level ``random`` generator, so ``random.seed``
    controls the result. Raises ValueError when `img_list` is empty.
    """
    return random.randrange(len(img_list))

In [None]:
# The returned homography H is a 3x3 matrix.
def compute_Perspective_transform(four_points, perturbed_four_points):
    """Return the 3x3 perspective transform mapping `four_points` onto `perturbed_four_points`.

    Both arguments are sequences of four (x, y) pairs. They are cast to
    float32, the dtype ``cv2.getPerspectiveTransform`` expects.
    """
    src_pts = np.float32(four_points)
    dst_pts = np.float32(perturbed_four_points)
    return cv2.getPerspectiveTransform(src_pts, dst_pts)

In [None]:
# Builds one sample: two grayscale patches (original and warped) plus the
# perturbed corner points, from which the homography between them can be
# recomputed. Because the original four corners are fixed, only the perturbed
# corners need to be stored.
# `i` is the sample number; it names the output files so they are easy to find.
def get_datum(img_dir, img_list, index, i,
              rho=50, patch_size=128, begin=130, img_size=400):
    """Generate one patch pair and append its perturbed corners to a CSV file.

    Parameters
    ----------
    img_dir : str
        Output sub-directory under ``./DataSet/``. It is concatenated directly
        into file paths, so it should end with a slash (e.g. ``"train/"``).
    img_list : list of str
        Paths of candidate source images.
    index : int
        Index as returned by ``get_random_pic``. The image actually used is
        ``img_list[index - 1]``, so ``index == 0`` selects the last image.
        Over ``get_random_pic``'s range every image remains reachable.
    i : int
        Sample number, used in the output file names and as the first CSV field.
    rho : int, optional
        Maximum absolute random offset, in pixels, added to each corner
        coordinate (independently for x and y).
    patch_size : int, optional
        Side length, in pixels, of the square patch.
    begin : int, optional
        x and y coordinate of the patch's top-left corner.
    img_size : int, optional
        Side length the grayscale source image is resized to.

    Side effects
    ------------
    Creates the output directory if needed.
    Writes ``<i>_1.jpg`` (patch from the source image) and ``<i>_2.jpg``
    (same region of the warped image).
    Appends to ``./DataSet/<img_dir>Points.csv`` a newline followed by
    ``<i>, x1, y1, x2, y2, x3, y3, x4, y4, <image path>``.

    Raises
    ------
    ValueError
        If a perturbed corner could fall outside the resized image.
    OSError
        If OpenCV fails to write one of the patch images.
    """
    if begin - rho < 0 or begin + patch_size + rho >= img_size:
        raise ValueError(
            "patch of size %d at %d with rho=%d does not fit in a %dx%d image"
            % (patch_size, begin, rho, img_size, img_size))

    out_dir = "./DataSet/" + img_dir
    points_file = "./DataSet/" + img_dir + "Points.csv"
    # cv2.imwrite signals a missing directory by returning False rather than
    # raising, so make sure the directory exists before writing anything.
    os.makedirs(os.path.dirname(points_file), exist_ok=True)

    # Load the source image as grayscale, resized to img_size x img_size.
    img_path = img_list[index - 1]
    with Image.open(img_path) as pil_img:
        gray = np.asarray(pil_img.convert('L').resize((img_size, img_size)))

    # Fixed patch corners in (x, y) order: top-left, top-right,
    # bottom-right, bottom-left.
    end = begin + patch_size
    corners = [(begin, begin), (end, begin), (end, end), (begin, end)]

    # Shift every corner by a random offset in [-rho, rho] on each axis
    # (x drawn before y, corner by corner).
    perturbed = [(x + random.randint(-rho, rho), y + random.randint(-rho, rho))
                 for x, y in corners]

    # H maps the fixed corners onto the perturbed ones. warpPerspective with
    # inv(H) samples the source at H(x, y), so the patch region of `warped`
    # shows what the perturbed quadrilateral covered in the source image.
    h = compute_Perspective_transform(corners, perturbed)
    # NOTE(review): a degenerate perturbation (e.g. three collinear corners)
    # could make h singular so inv() raises LinAlgError; rare with defaults.
    warped = cv2.warpPerspective(gray, inv(h), (img_size, img_size))

    patch_1 = gray[begin:end, begin:end]
    patch_2 = warped[begin:end, begin:end]

    for suffix, patch in (("_1.jpg", patch_1), ("_2.jpg", patch_2)):
        out_path = out_dir + str(i) + suffix
        if not cv2.imwrite(out_path, patch):
            raise OSError("cv2.imwrite failed for " + out_path)

    # One CSV row: sample number, the eight perturbed coordinates, source path.
    # The row starts with a newline (rather than ending with one) to stay
    # compatible with Points.csv files already produced in this format.
    coords = "".join(str(v) + ", " for point in perturbed for v in point)
    with open(points_file, "a") as f:
        f.write("\n" + str(i) + ", " + coords + img_path)
    

In [None]:
# Dataset generation: for each split, sample source images at random (with
# replacement) and write `count` patch pairs plus one Points.csv row each.

# Set to an int to make the generated dataset reproducible. This only holds if
# get_pic_list returns a stable order. None keeps the unseeded behaviour.
SEED = None
if SEED is not None:
    random.seed(SEED)


def generate_split(split_name, src_dir, out_subdir, count, report_pct=5):
    """Generate `count` samples for one split.

    Each sample picks a random .jpg from `src_dir` and hands it to get_datum,
    which writes under ``./DataSet/<out_subdir>``. Progress is printed after
    every `report_pct` percent of `count` has been completed, ending at 100%.

    Raises FileNotFoundError if `src_dir` contains no .jpg files. Without this
    check, get_random_pic would fail with an opaque "empty range" ValueError.
    """
    print(split_name + " set")
    img_list = get_pic_list(src_dir)
    if not img_list:
        raise FileNotFoundError("no .jpg images found in " + repr(src_dir))

    # Samples between progress reports; at least 1 so small splits never
    # divide by zero.
    step = max(1, count * report_pct // 100)
    for i in range(count):
        index = get_random_pic(img_list)
        get_datum(out_subdir, img_list, index, i)
        done = i + 1
        if done % step == 0 or done == count:
            print(split_name + " " + str(done * 100 // count) + "%")


train = "train/"
val = "val/"
test = "test/"

generate_split("train", "./train2017", train, 500000)
generate_split("val", "./val2017", val, 20000)
generate_split("test", "./test2017", test, 250000)