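## Ensemble

Pixel-wise hard-voting ensemble of three segmentation models' RLE predictions: a pixel is kept in the final mask when at least two of the three models predict it.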

In [1]:
# python native
import os
import json
import random
import datetime
from functools import partial

# external library
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.model_selection import GroupKFold
import albumentations as A

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models

# visualization
import matplotlib.pyplot as plt

In [2]:
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded string into a binary mask.

    rle: whitespace-separated string of alternating 1-based run starts and
         run lengths over the flattened (row-major) mask.
    height, width: output mask size.
    Returns a (height, width) uint8 array with 1 inside the runs, 0 elsewhere.
    """
    tokens = rle.split()
    # Even tokens are 1-based starts; shift them to 0-based array offsets.
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1

    return flat.reshape(height, width)

In [3]:
# Encode the mask maps produced by inference as RLE strings.

def encode_mask_to_rle(mask):
    '''
    Run-length encode a binary mask.

    mask: numpy array binary mask (1 - mask, 0 - background)
    Returns a space-separated string of alternating 1-based run starts and
    run lengths over the mask flattened in row-major order ('' for an empty mask).
    '''
    # Zero-pad both ends so every run has a rising and a falling edge.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    change_points = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd slots hold run ends; turn them into lengths relative to the preceding start.
    change_points[1::2] -= change_points[::2]
    return ' '.join(map(str, change_points))

In [6]:
# Load the per-model inference outputs to be ensembled.
# Each CSV is assumed to have the image_name / class / rle schema shown for data_a below,
# with rows in the same (image_name, class) order across files — TODO confirm per model.
INFERENCE_DIR = "/opt/ml/input/inference"

data_a = pd.read_csv(os.path.join(INFERENCE_DIR, "317_deeplabv3plus_r152_dice_bright_resize1024_output.csv"))
data_b = pd.read_csv(os.path.join(INFERENCE_DIR, "344_unet2plus_r152_Adam_dicefocal_bright_1e-3_CosineAnnealingLR_resized1024_output.csv"))
data_c = pd.read_csv(os.path.join(INFERENCE_DIR, "TTAclahe_319_deeplabv3plus_r101_Adam_dice_clahe1_None_resize1024_output.csv"))


In [7]:
# Pixel-wise hard voting: a pixel is foreground in the ensemble when at least
# MIN_VOTES of the three models mark it.
# NOTE: this cell overwrites data_a['rle']; re-run the CSV-loading cell before re-running it.
MASK_HEIGHT = MASK_WIDTH = 2048
MIN_VOTES = 2

# Row-wise voting is only meaningful if all files list the same (image, class) pairs
# in the same order; fail loudly instead of silently combining unrelated masks.
for other in (data_b, data_c):
    assert len(other) == len(data_a), "prediction CSVs have different row counts"
    assert (other['image_name'].to_numpy() == data_a['image_name'].to_numpy()).all(), "image_name order differs"
    assert (other['class'].to_numpy() == data_a['class'].to_numpy()).all(), "class order differs"

ensembled_rles = []
n_missing = 0  # non-string RLE entries (e.g. NaN from an empty CSV cell)
for rles in tqdm(zip(data_a['rle'], data_b['rle'], data_c['rle']), total=len(data_a)):
    votes = np.zeros((MASK_HEIGHT, MASK_WIDTH), dtype=np.uint8)
    for rle in rles:
        if isinstance(rle, str):
            votes += decode_rle_to_mask(rle, MASK_HEIGHT, MASK_WIDTH)
        else:
            # A missing prediction counts as a "background" vote for every pixel,
            # rather than skipping the row and keeping a single model's mask.
            n_missing += 1
    ensembled_rles.append(encode_mask_to_rle((votes >= MIN_VOTES).astype(np.uint8)))

# Assign the whole column at once. The previous `data_a.iloc[i]['rle'] = res` was chained
# assignment into a temporary row copy, so data_a was never actually updated.
data_a['rle'] = ensembled_rles
print(f"missing RLE entries treated as empty masks: {n_missing}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8700.0), HTML(value='')))




In [8]:
# Inspect the first 30 rows of data_a.
data_a.head(n=30)

Unnamed: 0,image_name,class,rle
0,image1661319116107.png,finger-1,1814994 7 1817040 14 1819086 20 1821132 24 182...
1,image1661319116107.png,finger-2,2079293 3 2081336 12 2083382 17 2085429 19 208...
2,image1661319116107.png,finger-3,2536091 9 2538135 15 2540180 19 2542226 22 254...
3,image1661319116107.png,finger-4,826069 6 828114 14 830160 18 832206 22 834253 ...
4,image1661319116107.png,finger-5,1082113 4 1084157 9 1086201 14 1088247 17 1090...
5,image1661319116107.png,finger-6,1409821 12 1411865 18 1413911 20 1415957 24 14...
6,image1661319116107.png,finger-7,2016059 13 2018102 24 2020148 29 2022194 33 20...
7,image1661319116107.png,finger-8,637913 9 639957 17 642003 20 644050 22 646097 ...
8,image1661319116107.png,finger-9,916411 13 918456 22 920503 28 920557 9 922550 ...
9,image1661319116107.png,finger-10,1319910 8 1321952 22 1323960 16 1323996 29 132...


In [9]:
# Inspect the last 30 rows of data_a.
data_a.tail(n=30)

Unnamed: 0,image_name,class,rle
8670,image1667354405140.png,Ulna,3220582 1 3222628 8 3224675 11 3226723 12 3228...
8671,image1667354424553.png,finger-1,1584692 9 1586737 15 1588782 20 1590828 23 159...
8672,image1667354424553.png,finger-2,1861093 4 1863136 14 1865182 18 1867228 22 186...
8673,image1667354424553.png,finger-3,2268510 10 2270555 15 2272602 19 2274649 21 22...
8674,image1667354424553.png,finger-4,550144 6 552188 13 554233 18 556280 21 558327 ...
8675,image1667354424553.png,finger-5,822501 6 824548 11 826596 15 826637 8 828643 2...
8676,image1667354424553.png,finger-6,1133778 10 1135824 15 1137870 19 1139917 23 11...
8677,image1667354424553.png,finger-7,1739956 16 1742001 23 1744046 29 1746093 33 17...
8678,image1667354424553.png,finger-8,388110 10 390155 17 392202 20 394249 23 396296...
8679,image1667354424553.png,finger-9,672815 11 674859 17 676905 20 678949 25 680966...


In [10]:
# Write data_a to CSV without the DataFrame index column.
data_a.to_csv(path_or_buf="/opt/ml/input/ensemble.csv", index=False)