In [1]:
from MyDataset import TGS_Dataset
from Model import Res34Unetv4
import os
import torch
from torch.autograd import Variable
import pandas as pd
import numpy as np
from Evaluation import do_length_decode, do_length_encode
import cv2
import matplotlib.pyplot as plt
import warnings
# Silence every Python warning for the rest of the kernel session (blanket filter).
# NOTE(review): this also hides potentially useful deprecation/runtime warnings;
# consider narrowing it with a category= argument.
warnings.filterwarnings("ignore")
print('load finish')

load finish


In [2]:
# Build a dataset over the test images.
# NOTE(review): load_net_and_predict constructs its own TGS_Dataset from its
# test_path argument, so this object is not used by the cells visible here;
# remove it if nothing else reads `test_dataset`.
TEST_PATH = './Data/test'
test_dataset = TGS_Dataset(TEST_PATH)
print('finish')

finish


In [3]:
def average_fold_prediction(path_list, H=101, W=101, fill_value=255, threshold=0.5):
    '''Ensemble several RLE submission CSVs into a single submission.

    Each CSV in ``path_list`` must have ``id`` and ``rle_mask`` columns. Every
    mask is decoded with ``do_length_decode`` (which is passed ``fill_value``
    as the foreground value), the decoded masks are averaged per image across
    files, divided by ``fill_value`` to bring them back to [0, 1], thresholded
    at ``threshold`` and re-encoded with ``do_length_encode``.

    Rows are matched by ``id`` using the order of the first file, so the CSVs
    do not need to list images in the same order.

    Parameters
    ----------
    path_list : sequence of str
        Paths of the submission CSVs to combine; must be non-empty.
    H, W : int
        Mask height and width forwarded to ``do_length_decode``.
    fill_value : number
        Foreground value forwarded to ``do_length_decode``; must be non-zero.
    threshold : float
        Cut-off applied to the rescaled average. Assuming the decoder writes
        ``fill_value`` on foreground and 0 elsewhere, this is the fraction of
        files that must mark a pixel for it to be kept.

    Returns
    -------
    pandas.DataFrame
        Columns ``id`` and ``rle_mask``, one row per id of the first CSV.

    Raises
    ------
    ValueError
        If ``path_list`` is empty or ``fill_value`` is zero.
    KeyError
        If a later CSV lacks an id present in the first CSV.
    '''
    if len(path_list) == 0:
        raise ValueError('path_list must contain at least one CSV path')
    if fill_value == 0:
        raise ValueError('fill_value must be non-zero')

    ids = None
    folds = []
    # decode
    for p in path_list:
        fold_df = pd.read_csv(p)
        if ids is None:
            ids = list(fold_df['id'])
        # Look masks up by id so files listing images in different orders are
        # still averaged image-by-image (raises KeyError on a missing id).
        rle_by_id = dict(zip(fold_df['id'], fold_df['rle_mask']))
        folds.append([do_length_decode(str(rle_by_id[i]), H, W, fill_value) for i in ids])

    # Average, then rescale to [0, 1]: without dividing by fill_value a single
    # file marking a pixel (e.g. 255 / 5 = 51) would already exceed 0.5.
    avg = np.mean(folds, axis=0) / fill_value
    masks = avg > threshold

    # encode + create sub
    rle = [do_length_encode(m) for m in masks]
    return pd.DataFrame(dict(id=ids, rle_mask=rle))

def tta_transform(images, mode):
    """Mirror-flip test-time augmentation.

    ``images`` is a 4-D numpy batch (presumably (N, C, H, W)); when ``mode``
    is ``'out'`` it is instead a stack whose first element is such a batch and
    only ``images[0]`` is used. Each image is mirrored along the batch's last
    axis (presumably width), and the flipped batch is returned with an extra
    leading axis of length 1.
    """
    batch = images[0] if mode == 'out' else images
    # Move axis 1 to the end so np.fliplr (which flips each image's axis 1)
    # mirrors what was the batch's last axis.
    channels_last = batch.transpose((0, 2, 3, 1))
    mirrored = np.transpose([np.fliplr(img) for img in channels_last], (0, 3, 1, 2))
    return np.asarray([mirrored])

def load_net_and_predict(net, test_path, load_paths, batch_size=32, tta_transform=None, threshold=0.5, min_size=0,
                         num_workers=11):
    '''Predict the test set with several checkpoints of ``net`` and ensemble them.

    For every checkpoint path the state dict is loaded into ``net`` and
    ``net.predict`` is run on the test loader with ``return_rle=False``. The
    returned ``'pred'`` arrays are averaged over checkpoints, thresholded at
    ``threshold``, optionally filtered by ``min_size`` and RLE-encoded with
    ``do_length_encode``.

    Parameters
    ----------
    net : model object exposing ``load_state_dict`` and ``predict``.
    test_path : str
        Directory passed to ``TGS_Dataset``.
    load_paths : sequence of str
        Checkpoint files to ensemble; must be non-empty.
    batch_size : int
        Batch size of the test loader.
    tta_transform : callable or None
        Forwarded unchanged to ``net.predict``.
    threshold : float
        Forwarded to ``net.predict`` and used to binarise the averaged prediction.
    min_size : int
        Masks with fewer than ``min_size`` foreground pixels are emptied;
        0 (the default) disables the filter.
    num_workers : int
        Number of loader workers for the test set.

    Returns
    -------
    pandas.DataFrame
        Columns ``id`` (from the last ``net.predict`` call) and ``rle_mask``.

    Raises
    ------
    ValueError
        If ``load_paths`` is empty.
    '''
    if len(load_paths) == 0:
        raise ValueError('load_paths must contain at least one checkpoint path')

    test_dataset = TGS_Dataset(test_path)
    test_loader, _ = test_dataset.selfdefine_dataloader(data='test', num_workers=num_workers, batch_size=batch_size)

    # predict: sum per-checkpoint predictions, divide once at the end
    total = None
    for path in load_paths:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto the model's own device.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        net.load_state_dict(torch.load(path, map_location='cpu'))
        p = net.predict(test_loader, threshold=threshold, tta_transform=tta_transform, return_rle=False)
        pred = np.asarray(p['pred'], dtype=np.float64)
        if total is None:
            total = pred.copy()
        else:
            total += pred
    masks = total / len(load_paths) > threshold

    if min_size > 0:
        # assumes one mask per leading index of 'pred' - TODO confirm layout (N, H, W)
        sizes = masks.reshape(len(masks), -1).sum(axis=1)
        masks[sizes < min_size] = False

    # encode + create sub
    rle = [do_length_encode(m) for m in masks]
    return pd.DataFrame(dict(id=p['id'], rle_mask=rle))


In [4]:
if __name__ == '__main__':
    # Inference configuration
    TEST_PATH = './Data/test'
    net = Res34Unetv4()
    NET_NAME = type(net).__name__
    THRESHOLD = 0.5
    MIN_SIZE = 0  # forwarded to load_net_and_predict as min_size
    BATCH_SIZE = 32
    # One directory holds the per-fold checkpoints and receives the submission CSV.
    SAVE_DIR = os.path.join('./Saves', NET_NAME + '_two_rounds_training')

    # One checkpoint per CV fold (names appear to encode fold, epoch and val score)
    LOAD_PATHS = [os.path.join(SAVE_DIR, name) for name in (
        'Fold0_Epoch58_Val0.851',
        'Fold1_Epoch7_Val0.843',
        'Fold2_Epoch47_Val0.849',
        'Fold3_Epoch57_Val0.874',
        'Fold4_Epoch44_Val0.851',
    )]

    df = load_net_and_predict(net, TEST_PATH, LOAD_PATHS, tta_transform=tta_transform,
                              batch_size=BATCH_SIZE, threshold=THRESHOLD, min_size=MIN_SIZE)
    df.to_csv(os.path.join(SAVE_DIR, '{}_5foldAvg.csv'.format(NET_NAME)), index=False)
    print('finish')

finish
