In [1]:
import cv2
import numpy as np
import os
import math
from tqdm import trange
import random
import seaborn as sns
import multiprocessing as mp
import csv
import matplotlib.pyplot as plt
from scipy import stats
import shutil
from skimage.filters import gaussian
import skimage

## 处理训练集

#### 1. 从图片中抠出“抖音号：xxxxxxxx”

In [2]:
# 处理过曝光的像素，超过255还原成255
def set_max(array):
    for i in range(len(array)):
        for j in range(len(array[0])):
            if array[i,j] > 255:
                array[i,j] = 255
    return array

In [3]:
# 改成自适应的垂直投影
def getVProjection(image, threshold=255, flag=0):
    if threshold >= 255:
        threshold = 250
    vProjection = np.zeros(image.shape,np.uint8);
    #图像高与宽
    (h,w) = image.shape
    #长度与图像宽度一致的数组
    w_ = [0]*w
    #循环统计每一列白色像素的个数
    for x in range(w):
        for y in range(h):
            if image[y,x] > threshold:
                w_[x]+=1
    #绘制垂直平投影图像
    for x in range(w):
        for y in range(h-w_[x],h):
            vProjection[y,x] = 255
    if flag == 1:
        cv2.imshow('vProjection',vProjection)
    return w_

In [4]:
# 垂直投影右侧边界加两个像素，从右向左遍历找到第一个连续超过2的区间
def get_right_bound(split_w):
    flag = 0
    length = len(split_w)
    bound = length - 1
    for i in range(length):
        # 已经有一个了
        if flag == 1:
            if split_w[length - i - 1] != 0:
                bound = length - i + 3
                break
            else:
                flag = 0
        else:
            if split_w[length - i - 1] != 0:
                flag = 1
    return bound

In [5]:
# 训练集用第200张做模板  (v0200fg10000cb9gnojc77u62n91odlg_2_.jpg)  使用cv2.matchTemplate做模板匹配
path = "D:/dataset_all/bytedance/ocr_data_split/data/train_random_8w/train_set_random"
file_name_list = os.listdir(path)
img = cv2.imread(path + '/' + file_name_list[200], cv2.IMREAD_GRAYSCALE)

template = img[-60:-45,-187:-135]

In [6]:
path = "D:/dataset_all/bytedance/ocr_data_split/data/train_random_8w/train_set_random"
file_name_list = os.listdir(path)

for k in trange(len(file_name_list)):
    
    # 保存文件用，请自行去掉
    if k == 0:
        break
        
    img = cv2.imread(path + '/' + file_name_list[k], cv2.IMREAD_GRAYSCALE)

    # 处理竖屏图片
    if len(img)/len(img[0]) > 1.5:
        img_high = cv2.resize(img, [400, 720], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-70:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:70] * 255.0).astype(np.uint8)

    # 处理横屏图片
    elif len(img[0])/len(img) > 1.5:
        img_high = cv2.resize(img, [720, 400], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-50:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:50] * 255.0).astype(np.uint8)

    # 处理方形图片
    else:
        img_high = cv2.resize(img, [600, 600], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-70:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:70] * 255.0).astype(np.uint8)


    # 增加曝光试试，让所有的像素都变白
    img_top = set_max(img_top * 1.1).astype(np.uint8)
    img_bottom = set_max(img_bottom * 1.1).astype(np.uint8)

    # 模板匹配
    h, w = template.shape[:2]

    # 处理top部分
    ret = cv2.matchTemplate(img_top, template, cv2.TM_CCOEFF_NORMED)
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(ret)

    # 处理bottom部分
    ret_b = cv2.matchTemplate(img_bottom, template, cv2.TM_CCOEFF_NORMED)
    min_val_b, max_val_b, min_loc_b, max_loc_b = cv2.minMaxLoc(ret_b)
    
    # 判断模板匹配结果是在top还是在bottom
    if max_val > max_val_b:
        draw_img = img_top.copy()
        ret = cv2.rectangle(draw_img, max_loc, (max_loc[0]+w, max_loc[1]+h), (80, 80, 80), 1)
    else:
        draw_img = img_bottom.copy()
        ret = cv2.rectangle(draw_img, max_loc_b, (max_loc_b[0]+w, max_loc_b[1]+h), (80, 80, 80), 1)

    # 取得倒数第二步结果    
    if max_val > max_val_b:
        img_next = img_top[max_loc[1]:max_loc[1] + h, max_loc[0]:]
    else:
        img_next = img_bottom[max_loc_b[1]:max_loc_b[1] + h, max_loc_b[0]:]

    # 自适应垂直投影确定右侧边界
    split_w = getVProjection(img_next)
    img_final = img_next[:,:get_right_bound(split_w)]
    
    # 保存图片
    cv2.imwrite("D:/dataset_all/bytedance/ocr_data_split/data/train_random_8w/train_set_final/" + file_name_list[k], img_final)

  0%|                                                                                                                                                                             | 0/78699 [00:00<?, ?it/s]


#### 2. 将切得不好的图片从训练集中删除 8w变成6w  并修改图片名称为该图片的标签

In [7]:
# 宽高比大于16或小于8的都删除     8w条删除了2w条
filepath = "D:/dataset_all/bytedance/ocr_data_split/data/dataset_train_random.csv"

path = "D:/dataset_all/bytedance/ocr_data_split/data/train_random_8w/train_set_final"
file_name_list = os.listdir(path)

with open(filepath, 'r') as csvfile:
    reader = csv.reader(csvfile)
    data = [row for row in reader]

for i in trange(len(file_name_list)):
    
    # 保存文件用，请自行去掉
    if i == 0:
        break
    
    img = cv2.imread(path + '/' + file_name_list[i], cv2.IMREAD_GRAYSCALE)
    h = len(img)
    w = len(img[0])
    if 8 < w/h < 16:
        for j in range(len(data)):
            if data[j][0] == file_name_list[i]:
                os.rename(path + '/' + file_name_list[i], path + '/' + data[j][1] + '.jpg')
                break
    else:
        os.unlink(path + '/' + file_name_list[i])

  0%|                                                                                                                                                                             | 0/58028 [00:00<?, ?it/s]


#### 3. 将训练数据分为训练集和测试集

In [8]:
path = './data_total'
file_path_list = os.listdir(path)
file_list_sub = random.sample(file_path_list, int(len(file_path_list)*0.3))
for i in range(len(file_list_sub)):
    shutil.move(path+'/'+file_list_sub[i], './data_val/'+file_list_sub[i])

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: './data_total'

#### 4. 训练集和验证集分开处理
<br>
---训练集包括原始训练集图片 + 缩放1/2再放大成原始尺寸的图片 + 增加高斯模糊的图片（共3倍于原始训练集）
<br>
---验证集只包括缩放1/2再放大成原始尺寸的图片  
<br>
这样做是因为缩放1/2再放大成原始尺寸的图片质量最接近测试集图片，以此类样本为验证集，并在模型在验证集取得最佳acc时保存模型，可以更好的用在测试集上

In [9]:
# 处理训练集
path = "./data_train"
file_name_list = os.listdir(path)

for i in trange(len(file_name_list)):
    img = cv2.imread(path + '/' + file_name_list[i], cv2.IMREAD_GRAYSCALE)
    h = len(img)
    w = len(img[0])

    img_low = cv2.resize(img, [int(0.5*w), int(0.5*h)])
    img_low_high = cv2.resize(img_low, [200, 24])

    img = cv2.resize(img, [200,24])
    img_gaussian = cv2.GaussianBlur(img, [3,3], 0)

    cv2.imwrite('./data_train/' + file_name_list[i].split('.')[0] + '_1.jpg', img_low_high)
    cv2.imwrite('./data_train/' + file_name_list[i].split('.')[0] + '_2.jpg', img_gaussian)

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: './data_train'

In [10]:
# 处理验证集
path = "./data_train"
file_name_list = os.listdir(path)

for i in trange(len(file_name_list)):
    img = cv2.imread(path + '/' + file_name_list[i], cv2.IMREAD_GRAYSCALE)
    h = len(img)
    w = len(img[0])

    img_low = cv2.resize(img, [int(0.5*w), int(0.5*h)])
    img_low_high = cv2.resize(img_low, [200, 24])

    img = cv2.resize(img, [200,24])
    img_gaussian = cv2.GaussianBlur(img, [3,3], 0)

    cv2.imwrite('./data_train/' + file_name_list[i].split('.')[0] + '_1.jpg', img_low_high)
    cv2.imwrite('./data_train/' + file_name_list[i].split('.')[0] + '_2.jpg', img_gaussian)

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: './data_train'

## 处理测试集

#### 1. 从图中抠出“抖音号：xxxxxxxx”

In [11]:
# 测试集用的第203张图片做的模板 (v0200fg10000cb9kkfjc77u9h5cho240_1_.jpg)
path = "D:/dataset_all/bytedance/ocr_data_split/data/test_random_2w/test_set_random"
file_name_list = os.listdir(path)
img = cv2.imread(path + '/' + file_name_list[203], cv2.IMREAD_GRAYSCALE)
img_high = cv2.resize(img, [400, 720], interpolation=cv2.INTER_CUBIC)
img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
img_bottom = (img_high_restoration[-70:] * 255.0).astype(np.uint8)

template = img_bottom[-33:-15,8:65]

In [12]:
# 处理测试集的时候忘记是用的原始垂直投影还是自适应垂直投影了，两个中acc差5%左右
def getVProjection_initial(image,flag=0):
    vProjection = np.zeros(image.shape,np.uint8);
    #图像高与宽
    (h,w) = image.shape
    #长度与图像宽度一致的数组
    w_ = [0]*w
    #循环统计每一列白色像素的个数
    for x in range(w):
        for y in range(h):
            if image[y,x] == 255:
                w_[x]+=1
    #绘制垂直平投影图像
    for x in range(w):
        for y in range(h-w_[x],h):
            vProjection[y,x] = 255
    if flag == 1:
        cv2.imshow('vProjection',vProjection)
    return w_

In [13]:
path = "D:/dataset_all/bytedance/ocr_data_split/data/test_random_2w/test_set_random"
file_name_list = os.listdir(path)

for k in trange(len(file_name_list)):
    
        # 保存文件用，请自行去掉
    if k == 0:
        break
    
    img = cv2.imread(path + '/' + file_name_list[k], cv2.IMREAD_GRAYSCALE)

    # 先只处理竖屏图片
    if len(img)/len(img[0]) > 1.5:
        img_high = cv2.resize(img, [400, 720], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-70:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:70] * 255.0).astype(np.uint8)

    # 处理横屏图片
    elif len(img[0])/len(img) > 1.5:
        img_high = cv2.resize(img, [720, 400], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-50:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:50] * 255.0).astype(np.uint8)

    # 处理方形图片
    else:
        img_high = cv2.resize(img, [600, 600], interpolation=cv2.INTER_CUBIC)
        img_high_restoration = skimage.restoration.denoise_tv_chambolle(img_high, weight=0.0001)
        img_bottom = (img_high_restoration[-70:] * 255.0).astype(np.uint8)
        img_top = (img_high_restoration[:70] * 255.0).astype(np.uint8)


    # 增加曝光试试，让所有的像素都变白
    img_top = set_max(img_top * 1.1).astype(np.uint8)
    img_bottom = set_max(img_bottom * 1.1).astype(np.uint8)

    # 模板匹配
    h, w = template.shape[:2]

    # 处理top部分
    ret = cv2.matchTemplate(img_top, template, cv2.TM_CCOEFF_NORMED)
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(ret)

    # 处理bottom部分
    ret_b = cv2.matchTemplate(img_bottom, template, cv2.TM_CCOEFF_NORMED)
    min_val_b, max_val_b, min_loc_b, max_loc_b = cv2.minMaxLoc(ret_b)

    if max_val > max_val_b:
        draw_img = img_top.copy()
        ret = cv2.rectangle(draw_img, max_loc, (max_loc[0]+w, max_loc[1]+h), (80, 80, 80), 1)
    else:
        draw_img = img_bottom.copy()
        ret = cv2.rectangle(draw_img, max_loc_b, (max_loc_b[0]+w, max_loc_b[1]+h), (80, 80, 80), 1)

    # 取得倒数第二步结果    
    if max_val > max_val_b:
        img_next = img_top[max_loc[1]:max_loc[1] + h, max_loc[0]:]
    else:
        img_next = img_bottom[max_loc_b[1]:max_loc_b[1] + h, max_loc_b[0]:]

    # 垂直投影
    split_w = getVProjection(img_next)
    img_final = img_next[:,:get_right_bound(split_w)]
    
    # 保存图片
    cv2.imwrite("D:/dataset_all/bytedance/ocr_data_split/data/test_random_2w/test_set_final/" + file_name_list[k], img_final)

  0%|                                                                                                                                                                             | 0/19675 [00:00<?, ?it/s]


## 最后的测试集预测结果写入result.csv中
### 其中测试集的gt.txt和result_test.txt由文字识别模块生成

In [14]:
# 结果对应
import csv
import numpy as np
from tqdm import trange
path1 = "D:/dataset_all/bytedance/ocr_data_split/data/test_random_2w/result_test.txt"
path2 = "D:/dataset_all/bytedance/ocr_data_split/data/test_random_2w/gt.txt"
with open(path1, 'r') as file:
    data = [row for row in file]
label = []
pred = []
for i in range(len(data)):
    label.append(data[i].split('\t')[0])
    tmp = data[i].split('\t')[1].split('\n')[0]
    if len(tmp) > 4:
        pred.append(tmp[4:])
    else:
        pred.append(tmp)

img = []
label2 = []
with open(path2, 'r') as file:
    data = [row for row in file]
for i in range(len(data)):
    img.append(data[i].split('\t')[0].split('/')[1])
    label2.append(data[i].split('\t')[1].split('\n')[0])
    

file_data = []
file = "D:/dataset_all/bytedance/ocr_data_split/data/submit_sample.csv"
with open(file,'r') as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        file_data.append(row)
file_data = np.array(file_data)
file_data = file_data[1:]

for i in trange(len(file_data)):
    for j in range(len(img)):
        if file_data[i,0] == img[j]:
            for k in range(len(label)):
                if label2[j] == label[k]:
                    file_data[i,1] = pred[k]
                    
with open('./result.csv','w',newline='') as csvfile:
    writer = csv.writer(csvfile)
    for row in file_data:
        writer.writerow(row)

  1%|█▉                                                                                                                                                                 | 241/19675 [00:02<03:52, 83.43it/s]


KeyboardInterrupt: 