In [0]:
# Mount Google Drive at /content/drive so that trained models and the
# validation split can be persisted outside the Colab VM.
# NOTE(review): later cells use the relative path './drive/My Drive/...',
# which presumably assumes the working directory is /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Install the archive-extraction helpers into the Colab runtime.
# NOTE(review): versions are not pinned; patool is presumably the backend
# that pyunpack delegates to — confirm.
!pip install pyunpack
!pip install patool

# Unpack the contest archive into the current directory (later cells read
# the extracted files from ./final_data/).
from pyunpack import Archive
Archive('./final_data.rar').extractall('./')

In [0]:
import pandas as pd

# Join every user annotation row with the ground-truth row of the same itemId.
data = pd.read_csv('./final_data/train_data.csv')
answers = pd.read_csv('./final_data/train_answers.csv')
merged = data.merge(answers, on='itemId')
# The DataFrame index is written as the first CSV column here;
# get_data_from_merged_file relies on that, unpacking it as the leading
# `id_line` field of every row. Do not pass index=False without updating it.
merged.to_csv('./final_data/merged.csv')

# Обучение модели
Задание гиперпараметров и названия модели

In [0]:
# Training configuration.
run_name ='MRCTest3'  # prefix of every artefact written to save_to (.h5 model, weights, validation id list)
data_file = './final_data/merged.csv'  # merged annotations + ground truth produced by the merge cell
save_to = './drive/My Drive/MailRuContest/'  # output directory on the mounted Drive
val_rate = 0.2  # fraction of itemIds held out for validation
mask_size = 256  # passed as input_size to the network and as mask_size to the generators
num_channels = 32  # passed as `channels` to md.create_network
num_epochs = 5  # epochs for fit_generator; also the period used by cosine_annealing
batch_size = 4  # generator batch size
delta = 1.1  # forwarded to gn.Generator; presumably a box-scaling margin — TODO confirm in generator_v2

In [0]:
import tensorflow as tf
import keras
import numpy as np
import model as md
import generator_v2 as gn
import os
import random
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from skimage import io
from skimage import measure
from skimage.transform import resize
from keras.models import load_model

Объявление функций для парсинга данных, лосса, метрики и уменьшения скорости обучения с каждой эпохой

In [0]:
def get_data_from_merged_file(data_file):
    """Parse the merged training CSV into user boxes and ground-truth boxes.

    The first line of the file is treated as a header and skipped. Every
    other non-blank line must contain exactly eleven comma-separated fields:
    row index, userId, itemId, the user's Xmin, Ymin, Xmax, Ymax and the
    ground-truth Xmin, Ymin, Xmax, Ymax. Blank lines (for example a trailing
    empty line) are ignored.

    Args:
        data_file: path to the merged CSV file.

    Returns:
        A tuple ``(gtruth, usr_coords, itemIds)`` where

        * ``gtruth`` maps itemId -> dict with integer 'Xmin', 'Ymin', 'Xmax',
          'Ymax' taken from the first row seen for that item,
        * ``usr_coords`` maps itemId -> list of such dicts, one per user row,
        * ``itemIds`` lists the item ids in order of first appearance.

    Raises:
        ValueError: if a non-blank row does not have eleven fields or a
            coordinate cannot be converted to int.
    """
    def to_box(xmin, ymin, xmax, ymax):
        # Coordinates arrive as strings; store them as ints under fixed keys.
        return {'Xmin': int(xmin), 'Ymin': int(ymin), 'Xmax': int(xmax), 'Ymax': int(ymax)}

    gtruth = {}
    usr_coords = {}
    itemIds = []
    with open(data_file, mode='r') as f:
        print('Parsing annotation files')
        f.readline()  # skip the header row

        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of failing on unpacking.
                continue
            (_id_line, _userId, itemId,
             Xmin, Ymin, Xmax, Ymax,
             Xmin_true, Ymin_true, Xmax_true, Ymax_true) = line.split(',')
            user_box = to_box(Xmin, Ymin, Xmax, Ymax)
            if itemId not in usr_coords:
                # Ground truth is identical for every row of an item, so it is
                # recorded only the first time the item is seen.
                gtruth[itemId] = to_box(Xmin_true, Ymin_true, Xmax_true, Ymax_true)
                usr_coords[itemId] = []
                itemIds.append(itemId)
            usr_coords[itemId].append(user_box)
    return gtruth, usr_coords, itemIds


def iou_loss(y_true, y_pred):
    """Soft IoU loss computed over the whole batch at once.

    Both tensors are flattened to 1-D, so intersection and union are summed
    across every element of the batch. A constant 1 is added to numerator and
    denominator so the ratio stays defined when both tensors are all zeros.

    Returns:
        ``1 - smoothed_iou`` as a scalar tensor.
    """
    flat_true = tf.reshape(y_true, [-1])
    flat_pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    union = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) - overlap
    return 1 - (overlap + 1.) / (union + 1.)


def iou_bce_loss(y_true, y_pred):
    """Weighted training loss: 0.8 * binary cross-entropy + 0.2 * soft IoU loss."""
    bce_term = keras.losses.binary_crossentropy(y_true, y_pred)
    iou_term = iou_loss(y_true, y_pred)
    return 0.8 * bce_term + 0.2 * iou_term
#       return iou_loss(y_true, y_pred)


def mean_iou(y_true, y_pred):
    """Metric: mean over the batch of the per-sample smoothed IoU.

    Predictions are rounded to the nearest integer before comparison. Sums
    run over axes 1, 2 and 3, so both inputs are expected to be rank-4 with
    the batch on axis 0. A tensor of ones is added to numerator and
    denominator of every sample's ratio.
    """
    reduce_axes = [1, 2, 3]
    hard_pred = tf.round(y_pred)
    overlap = tf.reduce_sum(y_true * hard_pred, axis=reduce_axes)
    total = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(hard_pred, axis=reduce_axes)
    ones = tf.ones(tf.shape(overlap))
    per_sample_iou = (overlap + ones) / (total - overlap + ones)
    return tf.reduce_mean(per_sample_iou)


def cosine_annealing(x):
    """Learning-rate schedule: half-cosine decay from 0.001 at epoch 0.

    Args:
        x: epoch index supplied by the learning-rate scheduler callback.

    Returns:
        ``0.001 * (cos(pi * x / num_epochs) + 1) / 2``, where ``num_epochs``
        is read from the notebook's global configuration at call time.
    """
    base_lr = 0.001
    total_epochs = num_epochs
    cosine_term = np.cos(np.pi * x / total_epochs) + 1.
    return base_lr * cosine_term / 2

def area(box):
    """Return the area of ``box`` given as (x_min, y_min, x_max, y_max, ...).

    Only the first four entries are read; no ordering check is performed, so
    the sign of the result follows the signs of the two side lengths.
    """
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height

def intersection_over_union(boxes):
    """Compute IoU of two axis-aligned boxes packed into one 8-element sequence.

    Args:
        boxes: ``[ax1, ay1, ax2, ay2, bx1, by1, bx2, by2]`` — the first four
            values describe box A, the last four box B.

    Returns:
        0 if either box has zero area, otherwise
        ``intersection / (area_a + area_b - intersection)`` as a float.

    Raises:
        AssertionError: if ``boxes`` does not contain exactly eight values.
    """
    assert(len(boxes) == 8)
    first, second = boxes[:4], boxes[4:]

    first_area = area(first)
    second_area = area(second)

    # A degenerate box cannot overlap anything; this also avoids 0/0 below.
    if first_area == 0 or second_area == 0:
        return 0

    # Corners of the overlap rectangle (may be inverted when boxes are disjoint).
    left = max(first[0], second[0])
    upper = max(first[1], second[1])
    right = min(first[2], second[2])
    lower = min(first[3], second[3])

    # Clamp each side at 0 so disjoint boxes yield an empty intersection.
    overlap = max(0, right - left) * max(0, lower - upper)

    return overlap / float(first_area + second_area - overlap)

Парсинг данных и подготовка к обучению

In [0]:
# Load every annotated item (ground truth + all user boxes) from the merged CSV.
gtruth, usr_coords, itemIds = get_data_from_merged_file(data_file)

# Random train/validation split: after shuffling, the first
# int(len(itemIds) * val_rate) ids form the validation set.
# NOTE(review): the shuffle is unseeded, so every run produces a different split.
random.shuffle(itemIds)
n_valid_samples = int(len(itemIds)*val_rate)
train_itemIds = itemIds[n_valid_samples:]
valid_itemIds = itemIds[:n_valid_samples]

# Persist the validation ids (one per line) next to the model so the
# evaluation cells can rebuild exactly the same validation set later.
with open(os.path.join(save_to, run_name + '.txt'), mode ='w') as f:
    for itemId in valid_itemIds:
        f.write(itemId + '\n')

print('n train samples', len(train_itemIds))
print('n valid samples', len(valid_itemIds))

# Build the network, run the global variable initializer in the Keras
# backend session, and compile with the combined BCE/IoU loss.
model = md.create_network(input_size=mask_size, channels=num_channels, n_blocks=16, depth=3)
keras.backend.get_session().run(tf.global_variables_initializer())
model.compile(optimizer='adam',
              loss=iou_bce_loss,
              metrics=['accuracy', mean_iou])

# Callbacks: cosine learning-rate schedule and early stopping on the
# validation mean IoU (higher is better, patience of 2 epochs, best weights restored).
# NOTE(review): these come from tf.keras while the model is compiled through
# the standalone keras package — confirm the two are compatible in this runtime.
learning_rate = tf.keras.callbacks.LearningRateScheduler(cosine_annealing)
stop_training = tf.keras.callbacks.EarlyStopping(monitor='val_mean_iou', mode='max', patience=2, restore_best_weights=True, verbose=1)

# Batch generators: the training one shuffles and augments, the validation
# one keeps a fixed order and does not pass `augment`.
train_gen = gn.Generator(train_itemIds, usr_coords, gtruth, delta, batch_size=batch_size, mask_size=mask_size,
                         shuffle=True, augment=True, predict=False)
valid_gen = gn.Generator(valid_itemIds, usr_coords, gtruth, delta, batch_size=batch_size, mask_size=mask_size,
                         shuffle=False, predict=False)


# model.summary()

Визуализация масок пользовательских bbox'ов и истинных (ground-truth) bbox'ов перед обучением

In [0]:
# Show the first validation batch: each panel displays the (transposed)
# user-box mask with the bounding boxes of the ground-truth mask regions
# outlined in blue.
for usr_msks, msks in valid_gen:
    # One panel per sample in the batch instead of a hard-coded 4, so a
    # different batch_size no longer indexes past the end of the axes array.
    # squeeze=False keeps the result 2-D even for a single panel, so ravel()
    # always works.
    f, axarr = plt.subplots(1, len(usr_msks), figsize=(20,15), squeeze=False)
    axarr = axarr.ravel()
    for ax, usr_msk, msk in zip(axarr, usr_msks, msks):
        ax.imshow(usr_msk[:, :, 0].T, cmap='gray')
        # Binarise the ground-truth mask and outline every connected region.
        comp = measure.label(msk[:, :, 0] > 0.5)
        for region in measure.regionprops(comp):
            x, y, x2, y2 = region.bbox
            ax.add_patch(patches.Rectangle((x, y), x2 - x, y2 - y,
                                           linewidth=2, edgecolor='b', facecolor='none'))
    plt.show()
    break  # only the first batch is visualised

Обучение модели и сохранение результата

In [0]:
# Train inside a dedicated TF session: variables are (re)initialised in it
# before fitting, and the model is saved before the session is closed.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    history = model.fit_generator(train_gen, validation_data=valid_gen, callbacks=[learning_rate, stop_training],
                              epochs=num_epochs)
    # Full model (architecture + weights + optimizer state) and a weights-only copy.
    model.save(save_to+run_name+'.h5')
    model.save_weights(save_to+run_name+'weights_only.h5')

# Результат обучения
Подгрузка модели

In [0]:
# Evaluation configuration (lets this section run without re-running the training cells).
run_name ='MRCTest3'  # must match the run whose .h5 model and .txt validation list are loaded
data_file = './final_data/merged.csv'  # merged annotations + ground truth
save_to = './drive/My Drive/MailRuContest/'  # directory holding the saved artefacts
# NOTE(review): 400 differs from the mask_size of 256 used in the training
# and submission configuration — confirm this is intentional.
mask_size = 400
batch_size = 4  # validation generator batch size
delta = 1.1  # forwarded to gn.Generator, same value as in training

In [0]:
from keras.models import load_model

# Restore the trained network; the custom loss and metric have to be
# supplied by name so Keras can deserialise the compiled model.
model = load_model(
    save_to + run_name + '.h5',
    custom_objects={'iou_bce_loss': iou_bce_loss, 'mean_iou': mean_iou},
)

gtruth, usr_coords, itemIds = get_data_from_merged_file(data_file)

# Re-read the validation ids saved at training time (one id per line,
# everything from the first newline onwards is dropped).
valid_itemIds = []
with open(os.path.join(save_to, run_name + '.txt'), mode='r') as f:
    valid_itemIds.extend(raw_line.split('\n')[0] for raw_line in f)

print(len(valid_itemIds))

valid_gen = gn.Generator(
    valid_itemIds, usr_coords, gtruth, delta,
    batch_size=batch_size, mask_size=mask_size,
    shuffle=False, predict=False,
)

Parsing annotation files
94


Визуализация батча масок юзерских bbox'ов с true-box'ами и предсказаниями

In [0]:
def _outline_regions(binary_mask, edgecolor, *axes):
    """Draw the bounding box of every connected region of binary_mask on each of axes."""
    for region in measure.regionprops(measure.label(binary_mask)):
        x, y, x2, y2 = region.bbox
        width = x2 - x
        height = y2 - y
        for ax in axes:
            # A patch can belong to one axes only, hence a new Rectangle per axes.
            ax.add_patch(patches.Rectangle((x, y), width, height,
                                           linewidth=2, edgecolor=edgecolor, facecolor='none'))


# First validation batch only: for every sample, the left panel shows the
# user-box mask with ground-truth (blue) and predicted (red) boxes, the
# right panel shows the raw prediction with the predicted boxes (red).
for usr_msks, msks in valid_gen:
    preds = model.predict(usr_msks)
    f, axarr = plt.subplots(2, 4, figsize=(20,15))
    axarr = axarr.ravel()
    for sample_idx, (usr_msk, msk, pred) in enumerate(zip(usr_msks, msks, preds)):
        input_ax = axarr[2 * sample_idx]
        pred_ax = axarr[2 * sample_idx + 1]

        input_ax.imshow(usr_msk[:, :, 0].T, cmap='gray')
        _outline_regions(msk[:, :, 0] > 0.5, 'b', input_ax)

        pred_ax.imshow(pred[:, :, 0].T, cmap='gray')
        _outline_regions(pred[:, :, 0] > 0.8, 'r', pred_ax, input_ax)
    plt.show()
    break

Подбор порога бинаризации (threshold), дающего лучший средний IoU на валидационной выборке



In [0]:
# Threshold grid: trh starts at `bottom` and grows by `step` while trh < top.
bottom = 0.5
top = 0.7
step = 0.1


def _largest_region_box(binary_mask):
    """Return [x, y, x2, y2] of the connected region whose bounding box is largest.

    Falls back to [0, 0, 0, 0] when the mask contains no region with a
    positive bounding-box area.
    """
    best_area = 0
    best_box = [0, 0, 0, 0]
    for region in measure.regionprops(measure.label(binary_mask)):
        x, y, x2, y2 = region.bbox
        box = [x, y, x2, y2]
        box_area = area(box)
        if box_area > best_area:
            best_area = box_area
            best_box = box
    return best_box


# The network output does not depend on the threshold, so predict once and
# reuse the raw maps for every threshold instead of re-running the model on
# the whole validation set per threshold. The ground-truth box is likewise
# threshold-independent and is extracted a single time.
val_samples = []  # (ground-truth box, raw prediction map) per validation sample
for usr_msks, msks in valid_gen:
    preds = model.predict(usr_msks)
    for msk, pred in zip(msks, preds):
        val_samples.append((_largest_region_box(msk[:, :, 0] > 0.5), pred[:, :, 0]))

if not val_samples:
    raise ValueError('valid_gen produced no samples; cannot tune the threshold')

trh = bottom
max_mean_iou = 0
best_trh = 0
while trh < top:
    ious = 0
    for true_box, pred_map in val_samples:
        ious += intersection_over_union(_largest_region_box(pred_map > trh) + true_box)
    # Deliberately NOT named `mean_iou`: that name is the Keras metric function
    # which later cells pass to load_model(custom_objects=...); rebinding it to
    # a float here would break them.
    cur_mean_iou = ious / len(val_samples)

    if cur_mean_iou > max_mean_iou:
        max_mean_iou = cur_mean_iou
        best_trh = trh
        print(cur_mean_iou)
        print(trh)
    else:
        print('pass --> '+ str(trh))
    trh += step

print(best_trh)

# Формирование тестового сабмита
Подгрузка модели и формирование предсказаний

In [0]:
# Submission configuration.
run_name ='MRCTest3'  # run whose saved .h5 model is loaded
data_file = './final_data/test_data.csv'  # test annotations (no ground truth)
save_to = './drive/My Drive/MailRuContest/'  # directory holding the saved model
result_file ='./sample_data.csv'  # output CSV with the predicted boxes
mask_size = 256  # passed to pr.Preprocessor; same value as in the training configuration
batch_size = 4  # number of masks sent to model.predict at once
delta = 1.1  # passed to pr.Preprocessor, same value as in training
# Binarisation threshold for predicted masks. The misspelling is kept because
# it is also the keyword-argument name of prep.get_coords_from_mask.
# NOTE(review): 0.82 lies outside the 0.5–0.7 range scanned by the
# threshold-search cell — presumably tuned separately; confirm.
trashhold = 0.82

In [0]:
def get_data_from_test_file(data_file):
    """Parse the test CSV into per-item lists of user boxes.

    The first line of the file is treated as a header and skipped. Every
    other non-blank line must contain exactly six comma-separated fields:
    userId, itemId, Xmin, Ymin, Xmax, Ymax. Blank lines (for example a
    trailing empty line) are ignored.

    Args:
        data_file: path to the test CSV file.

    Returns:
        A tuple ``(usr_coords, itemIds)`` where ``usr_coords`` maps
        itemId -> list of dicts with integer 'Xmin', 'Ymin', 'Xmax', 'Ymax'
        (one per row) and ``itemIds`` lists the item ids in order of first
        appearance.

    Raises:
        ValueError: if a non-blank row does not have six fields or a
            coordinate cannot be converted to int.
    """
    usr_coords = {}
    itemIds = []
    with open(data_file, mode='r') as f:
        print('Parsing annotation files')
        f.readline()  # skip the header row

        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of failing on unpacking.
                continue
            (_userId, itemId, Xmin, Ymin, Xmax, Ymax) = line.split(',')
            user_box = {'Xmin': int(Xmin), 'Ymin': int(Ymin), 'Xmax': int(Xmax), 'Ymax': int(Ymax)}
            if itemId not in usr_coords:
                usr_coords[itemId] = []
                itemIds.append(itemId)
            usr_coords[itemId].append(user_box)
    return usr_coords, itemIds

from keras.models import load_model
import preprocessor as pr

# Restore the trained network together with its custom loss and metric.
model = load_model(
    save_to + run_name + '.h5',
    custom_objects={'iou_bce_loss': iou_bce_loss, 'mean_iou': mean_iou},
)

usr_coords, itemIds = get_data_from_test_file(data_file)
prep = pr.Preprocessor(usr_coords, delta, mask_size)

help_info = {}   # itemId -> [mask, koef, xmin, ymin, dw, dh] as returned by prep.get_mask
result = []      # rows [itemId, x1, y1, x2, y2] for the submission file
result_row = []  # pairs [itemId, raw prediction map], used for visualisation

# Build the input mask and the values needed to map predictions back for every item.
for itemId in itemIds:
    msk, koef, xmin, ymin, dw, dh = prep.get_mask(itemId)
    help_info[itemId] = [msk, koef, xmin, ymin, dw, dh]

# Predict in chunks of batch_size; slicing clamps the last chunk to the list end.
for start in range(0, len(itemIds), batch_size):
    batch_ids = itemIds[start:start + batch_size]
    batch = [help_info[batch_id][0] for batch_id in batch_ids]

    preds = model.predict(np.array(batch))

    for j, itemId in enumerate(batch_ids):
        # Everything after the mask in help_info is forwarded positionally.
        x1, y1, x2, y2 = prep.get_coords_from_mask(preds[j], *help_info[itemId][1:],
                                                   trashhold=trashhold)
        result.append([itemId, x1, y1, x2, y2])
        result_row.append([itemId, preds[j]])

Визуализация предсказаний

In [0]:
# Preview up to 16 test items: each panel shows the (transposed) input mask
# with the bounding boxes of the thresholded prediction outlined in red.
f, axarr = plt.subplots(4, 4, figsize=(20,15))
axarr = axarr.ravel()

# zip stops at whichever is shorter — the 16 panels or the available
# predictions — so a test set with fewer than 16 items no longer raises
# IndexError.
for ax, (itemId, pred) in zip(axarr, result_row):
    ax.imshow(help_info[itemId][0][:, :, 0].T, cmap='gray')
    comp = measure.label(pred[:, :, 0] > trashhold)
    for region in measure.regionprops(comp):
        x, y, x2, y2 = region.bbox
        ax.add_patch(patches.Rectangle((x, y), x2 - x, y2 - y,
                                       linewidth=2, edgecolor='r', facecolor='none'))
plt.show()

Запись тестового сабмита

In [0]:
from csv import writer
print(len(result))

# Write one row per item: itemId, x1, y1, x2, y2 (no header row).
# newline='' is what the csv module requires for files handed to a writer;
# without it, platforms that translate line endings emit an extra blank
# line after every row.
with open(result_file, mode='w', encoding='utf-8', newline='') as f:
    wrtr = writer(f, delimiter=',')
    wrtr.writerows(result)