# Description


In [None]:
import os
import gc
import cv2
import rasterio
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import tifffile as tiff
import matplotlib.pyplot as plt
from fastai.vision.all import *
from rasterio.windows import Window
from torch.utils.data import Dataset, DataLoader
import warnings; warnings.filterwarnings("ignore")

# Data

In [None]:
# functions to convert encoding to mask and mask to encoding
def enc2mask(encs, shape):
    """Decode a list of run-length encodings into a single label mask.

    Args:
        encs: iterable of RLE strings of the form
            ``"start length start length ..."`` with 1-based starts.
            A float NaN entry is skipped (no pixels for that label).
        shape: two-element sequence; a flat buffer of
            ``shape[0] * shape[1]`` pixels is filled, reshaped to ``shape``
            and then transposed.

    Returns:
        ``np.uint8`` array of shape ``(shape[1], shape[0])``. Pixels covered
        by the m-th encoding (0-based) hold ``m + 1``; everything else is 0.
        Where encodings overlap, the later one wins.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # ``np.float`` (an alias of the builtin) was removed in NumPy 1.24 and
        # raises AttributeError there; the builtin check is equivalent and
        # also matches ``np.float64``, which subclasses ``float``.
        if isinstance(enc, float) and np.isnan(enc):
            continue
        tokens = enc.split()
        # Tokens alternate start/length; a trailing unpaired token is ignored.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1  # RLE starts are 1-based
            img[begin:begin + int(length)] = 1 + m
    return img.reshape(shape).T

def mask2enc(mask, n=1):
    """Run-length encode each label ``1..n`` of a label mask.

    The mask is transposed and flattened before encoding, and starts are
    1-based.

    Args:
        mask: 2-D array of integer labels.
        n: highest label to encode.

    Returns:
        List of length ``n``. Entry ``i - 1`` is the RLE string
        (``"start length ..."``) for label ``i``, or ``np.nan`` when that
        label does not occur in the mask.
    """
    flat = mask.T.flatten()
    encodings = []
    for label in range(1, n + 1):
        hits = (flat == label).astype(np.int8)
        if hits.sum() == 0:
            encodings.append(np.nan)
            continue
        # Zero-pad both ends so every run has a rising and a falling edge.
        padded = np.concatenate([[0], hits, [0]])
        edges = np.where(padded[1:] != padded[:-1])[0] + 1
        # Even slots are run starts; turn each following end into a length.
        edges[1::2] -= edges[::2]
        encodings.append(' '.join(map(str, edges)))
    return encodings

#https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with transposed mask
def rle_encode_less_memory(img):
    """Run-length encode a binary mask with 1-based starts.

    The mask is transposed and flattened first. The first and last pixels of
    the flattened copy are forced to 0 (a simplification of this method), so
    a foreground pixel at either position is dropped from the encoding. The
    input array itself is not modified.

    Args:
        img: 2-D binary array.

    Returns:
        String ``"start length start length ..."``; empty when there is no
        foreground.
    """
    flat = img.T.flatten()  # flatten() copies, so the caller's array is safe

    # This simplified method needs background at both ends of the buffer.
    flat[0] = flat[-1] = 0
    changes = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Even slots are run starts; turn each following end into a length.
    changes[1::2] -= changes[::2]

    return ' '.join(map(str, changes))

# Model

In [None]:
class FPN(nn.Module):
    """Project several feature maps to few channels and stack them at one size.

    One branch is built per (input, output) channel pair:
    3x3 conv to twice the output width -> ReLU -> BatchNorm -> 3x3 conv to
    the output width.

    Args:
        input_channels: channel count of each incoming feature map.
        output_channels: channel count each branch produces.
    """
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        branches = []
        for n_in, n_out in zip(input_channels, output_channels):
            hidden = n_out * 2
            branches.append(nn.Sequential(
                nn.Conv2d(n_in, hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(hidden),
                nn.Conv2d(hidden, n_out, kernel_size=3, padding=1),
            ))
        # Attribute name and layer order define the state_dict keys.
        self.convs = nn.ModuleList(branches)

    def forward(self, xs:list, last_layer):
        """Upsample every branch output and concatenate with ``last_layer``.

        Branch ``i`` is bilinearly upsampled by ``2 ** (num_branches - i)``,
        so earlier entries of ``xs`` are enlarged the most. ``last_layer`` is
        appended unchanged and everything is concatenated along dim 1.
        """
        depth = len(self.convs)
        outputs = []
        for level, (branch, feat) in enumerate(zip(self.convs, xs)):
            factor = 2 ** (depth - level)
            outputs.append(F.interpolate(branch(feat), scale_factor=factor, mode='bilinear'))
        outputs.append(last_layer)
        return torch.cat(outputs, dim=1)

class UnetBlock(Module):
    """Decoder block: upsample the deeper feature, merge with a skip, convolve.

    Args:
        up_in_c: channels of the deeper (to-be-upsampled) input.
        x_in_c: channels of the skip-connection input.
        nf: output channels; defaults to ``max(up_in_c // 2, 32)``.
        blur: forwarded to ``PixelShuffle_ICNR``.
        self_attention: when true, ``SelfAttention(nf)`` is passed as the
            ``xtra`` layer of the second ``ConvLayer``.
        **kwargs: forwarded to ``PixelShuffle_ICNR`` and both ``ConvLayer``s.
    """
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        half_up = up_in_c // 2
        self.shuf = PixelShuffle_ICNR(up_in_c, half_up, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        if nf is None:
            nf = max(half_up, 32)
        # The first conv sees the upsampled channels plus the skip channels.
        self.conv1 = ConvLayer(half_up + x_in_c, nf, norm_type=None, **kwargs)
        attention = SelfAttention(nf) if self_attention else None
        self.conv2 = ConvLayer(nf, nf, norm_type=None, xtra=attention, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in:Tensor, left_in:Tensor) -> Tensor:
        """Return ``conv2(conv1(relu(cat[shuf(up_in), bn(left_in)])))``."""
        upsampled = self.shuf(up_in)
        skip = self.bn(left_in)
        merged = self.relu(torch.cat([upsampled, skip], dim=1))
        return self.conv2(self.conv1(merged))
        
        
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head.

    Branches: one 1x1 ``_ASPPModule``, one 3x3 grouped (``groups=4``)
    ``_ASPPModule`` per dilation rate, and a global max-pool branch. Their
    outputs are concatenated along channels and fused by a 1x1 conv + BN +
    ReLU.

    Args:
        inplanes: input channels.
        mid_c: channels produced by every branch.
        dilations: dilation rates of the 3x3 branches.
        out_c: output channels; defaults to ``mid_c``.
    """
    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        branches = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)]
        branches.extend(
            _ASPPModule(inplanes, mid_c, 3, padding=rate, dilation=rate, groups=4)
            for rate in dilations)
        self.aspps = nn.ModuleList(branches)
        self.global_pool = nn.Sequential(
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
            nn.BatchNorm2d(mid_c),
            nn.ReLU())
        if out_c is None:
            out_c = mid_c
        # Pooled branch + 1x1 branch + one branch per dilation rate.
        cat_c = mid_c * (2 + len(dilations))
        self.out_conv = nn.Sequential(
            nn.Conv2d(cat_c, out_c, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True))
        # Not used in forward(); kept registered so its state_dict key exists.
        self.conv1 = nn.Conv2d(cat_c, out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        """Run all branches, concatenate on dim 1, and fuse with ``out_conv``."""
        pooled = self.global_pool(x)
        branch_outs = [branch(x) for branch in self.aspps]
        # Stretch the 1x1 pooled map to the spatial size of the other branches.
        target_hw = branch_outs[0].size()[2:]
        pooled = F.interpolate(pooled, size=target_hw, mode='bilinear', align_corners=True)
        stacked = torch.cat([pooled] + branch_outs, dim=1)
        return self.out_conv(stacked)
    
    def _init_weight(self):
        """Kaiming-normal init for convs; BatchNorm weight=1, bias=0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)

In [None]:
from torchvision.models.resnet import ResNet, Bottleneck
sys.path.append("../input/resnest-package/resnest-0.0.6b20200701/resnest")
from resnest.torch import resnest50, resnest101, resnest200, resnest269

class UneXt50(nn.Module):
    """U-Net style segmenter: ResNeSt-269 encoder, ASPP bridge, FPN head.

    This variant uses 256 mid channels in the ASPP block. Attribute names
    define the checkpoint keys and must stay as they are.

    Args:
        stride: multiplier for the ASPP dilation rates (1x..4x).
        **kwargs: accepted for signature compatibility; not used.
    """
    def __init__(self, stride=1, **kwargs):  # TODO: double-check stride
        super().__init__()
        # encoder (weights come from a checkpoint, not from pretraining)
        backbone = resnest269(pretrained=False)

        self.enc0 = nn.Sequential(backbone.conv1, backbone.bn1, nn.ReLU(inplace=True))
        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1)
        self.enc1 = nn.Sequential(pool, backbone.layer1)  # 256
        self.enc2 = backbone.layer2  # 512
        self.enc3 = backbone.layer3  # 1024
        self.enc4 = backbone.layer4  # 2048
        # aspp with customized dilations
        rates = [stride * k for k in (1, 2, 3, 4)]
        self.aspp = ASPP(2048, 256, out_c=512, dilations=rates)
        self.drop_aspp = nn.Dropout2d(0.5)
        # decoder
        self.dec4 = UnetBlock(512, 1024, 256)
        self.dec3 = UnetBlock(256, 512, 128)
        self.dec2 = UnetBlock(128, 256, 64)
        self.dec1 = UnetBlock(64, 128, 32)  # ResNeSt-101 and larger
#         self.dec1 = UnetBlock(64,64,32) # resneXt, resnest50
        self.fpn = FPN([512, 256, 128, 64], [16] * 4)
        self.drop = nn.Dropout2d(0.3)
        # 32 channels from dec1 plus 16 from each of the four FPN branches
        self.final_conv = ConvLayer(32 + 16 * 4, 1, ks=1, norm_type=None, act_cls=None)
        
    def forward(self, x):
        """Return single-channel logits, upsampled 2x after the final conv."""
        e0 = self.enc0(x)
        e1 = self.enc1(e0)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        bridge = self.aspp(self.enc4(e3))
        d3 = self.dec4(self.drop_aspp(bridge), e3)
        d2 = self.dec3(d3, e2)
        d1 = self.dec2(d2, e1)
        d0 = self.dec1(d1, e0)
        fused = self.fpn([bridge, d3, d2, d1], d0)
        logits = self.final_conv(self.drop(fused))
        return F.interpolate(logits, scale_factor=2, mode='bilinear')

In [None]:
class UneXt50_(nn.Module):
    """U-Net style segmenter: ResNeSt-269 encoder, ASPP bridge, FPN head.

    This variant uses 1024 mid channels in the ASPP block (raised from 256).
    Attribute names define the checkpoint keys and must stay as they are.

    Args:
        stride: multiplier for the ASPP dilation rates (1x..4x).
        **kwargs: accepted for signature compatibility; not used.
    """
    def __init__(self, stride=1, **kwargs):  # TODO: double-check stride
        super().__init__()
        # encoder (weights come from a checkpoint, not from pretraining)
        backbone = resnest269(pretrained=False)

        self.enc0 = nn.Sequential(backbone.conv1, backbone.bn1, nn.ReLU(inplace=True))
        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1)
        self.enc1 = nn.Sequential(pool, backbone.layer1)  # 256
        self.enc2 = backbone.layer2  # 512
        self.enc3 = backbone.layer3  # 1024
        self.enc4 = backbone.layer4  # 2048
        # aspp with customized dilations; mid channels changed (256 -> 1024)
        rates = [stride * k for k in (1, 2, 3, 4)]
        self.aspp = ASPP(2048, 1024, out_c=512, dilations=rates)
        self.drop_aspp = nn.Dropout2d(0.5)
        # decoder
        self.dec4 = UnetBlock(512, 1024, 256)
        self.dec3 = UnetBlock(256, 512, 128)
        self.dec2 = UnetBlock(128, 256, 64)
        self.dec1 = UnetBlock(64, 128, 32)  # ResNeSt-101 and larger
#         self.dec1 = UnetBlock(64,64,32) # resneXt, resnest50
        self.fpn = FPN([512, 256, 128, 64], [16] * 4)
        self.drop = nn.Dropout2d(0.3)
        # 32 channels from dec1 plus 16 from each of the four FPN branches
        self.final_conv = ConvLayer(32 + 16 * 4, 1, ks=1, norm_type=None, act_cls=None)
        
    def forward(self, x):
        """Return single-channel logits, upsampled 2x after the final conv."""
        e0 = self.enc0(x)
        e1 = self.enc1(e0)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        bridge = self.aspp(self.enc4(e3))
        d3 = self.dec4(self.drop_aspp(bridge), e3)
        d2 = self.dec3(d3, e2)
        d1 = self.dec2(d2, e1)
        d0 = self.dec1(d1, e0)
        fused = self.fpn([bridge, d3, d2, d1], d0)
        logits = self.final_conv(self.drop(fused))
        return F.interpolate(logits, scale_factor=2, mode='bilinear')

# Ensemble code

In [None]:
# Batch size handed to every DataLoader built by create_mask / create_mask_w.
bs = 2
# TH = 0.3 -- the threshold is now set in the prediction cell instead

# Inference mode read by the model-loading and inference cells:
#   'w' = whole-image models only
#   's' = sliding-window models only
#   'e' = ensemble of both
# mode = 'w'
# mode = 's'
mode = 'e'

# Directory holding the test images as <id>.tiff.
DATA = '../input/hubmap-organ-segmentation/test_images/'

# MODELS = glob.glob("../input/ensemble/*.pth")\
# +glob.glob("../input/resnest101-256-4-multi-scale/*.pth")

# Checkpoint lists are named MODELS_<tile size>_<factor>. For the sliding
# models the inference cell passes these two numbers as `sz` and `reduce`.
# NOTE(review): the meaning of the second number for the whole-image models
# is not visible in this file -- confirm with the training code.

# slide (sliding-window models)

MODELS_128_8 = ['../input/last256/ok128-8.pth']
MODELS_128_12 = ['../input/last256/ok128-12.pth']

MODELS_256_3 = ['../input/last256/ok256-3.pth']
MODELS_256_4 = ['../input/last1024/ok256-4.pth']
MODELS_256_6 = ['../input/last1024/ok256-6.pth',
                '../input/last1024/ok256-6s.pth']

MODELS_384_4 = ['../input/last256/ok384-4.pth']

MODELS_512_3 = ['../input/last256/ok512-3.pth']

MODELS_768_2 = ['../input/last256/ok768-2.pth']

# whole (whole-image models)
MODELS_128_24 = ['../input/last256/ok128-24.pth']

MODELS_256_12 = ['../input/last1024/ok256-12.pth',
                 '../input/last1024/ok256-12s.pth']
MODELS_256_18 = ['../input/last256/ok256-18.pth']           

MODELS_384_8 = ['../input/last256/ok384-8.pth']
MODELS_384_12 = ['../input/last256/ok384-12.pth']

MODELS_512_6 = ['../input/last256/ok512-6.pth']
MODELS_512_9 = ['../input/last256/ok512-9.pth']

MODELS_640_5 = ['../input/last256/ok640-5.pth']

MODELS_768_4 = ['../input/last256/ok768-4.pth']

# hanbin whole (second set of whole-image models; the "_stain" files are
# separate checkpoints -- presumably trained with stain augmentation, confirm)

MODELS_256_bin = ['../input/lastbin/res256_ch1024.pth']
MODELS_256_bin_stain = ['../input/lastbin/res256_ch1024_stain.pth']

MODELS_384_bin = ['../input/lastbin/res384_ch1024.pth']
MODELS_384_bin_stain = ['../input/lastbin/res384_ch1024_stain_mdice0.80.pth']

MODELS_512_bin = ['../input/lastbin/res512_ch1024.pth']
MODELS_512_bin_stain = ['../input/lastbin/res512_ch1024_stain.pth']

MODELS_640_bin = ['../input/lastbin/res640_ch1024.pth']
MODELS_640_bin_stain = ['../input/lastbin/res640_ch1024_stain.pth']

MODELS_768_bin = ['../input/lastbin/res768_ch1024.pth']

MODELS_1024_bin = ['../input/lastbin/res1024_ch1024.pth']


# One row per test image; the inference cell reads the `id`, `organ` and
# `img_height` columns.
# df_sample = pd.read_csv('../input/hubmap-organ-segmentation/sample_submission.csv')
df_sample = pd.read_csv('../input/hubmap-organ-segmentation/test.csv')
# Torch device used by create_models and Model_pred.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Per-channel normalization statistics, applied as (img / 255 - mean) / std.

# jiseong: default statistics (these are the commonly used ImageNet values);
# also the default arguments of HuBMAPDataset_w.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# hanbin whole: one mean/std pair per input size, passed to create_mask_w
# together with the matching *_bin / *_bin_stain models.
# NOTE(review): presumably measured on the training set at each size -- confirm.
mean256 = np.array([0.73785705, 0.70808914, 0.72905603])
std256 = np.array([0.26174852, 0.27280234, 0.27136351])

mean384 = np.array([0.67643098, 0.64846221, 0.66760599])
std384 = np.array([0.31982929, 0.3231894, 0.32578941])

mean512 = np.array([0.63561364, 0.60888698, 0.62712141])
std512 = np.array([0.34821859, 0.34816096, 0.35253718])

mean640 = np.array([0.70691496, 0.67844708, 0.6984898])
std640 = np.array([0.29862373, 0.30497219, 0.30555107])

mean768 = np.array([0.6710477,  0.64848008, 0.66450543])
std768 = np.array([0.34888044, 0.35108358, 0.35293783])

mean1024 = np.array([0.69127252, 0.66775246, 0.68414704])
std1024 = np.array([0.33476045, 0.33858518, 0.33940973])


# hanbin slide: statistics for a 384 sliding-window model
# (not referenced anywhere else in this notebook as shown).

mean384s = np.array([0.71735743, 0.68700507, 0.70827026])
std384s = np.array([0.27728809, 0.28566926, 0.28593607])


In [None]:
# Whole-image mode: the tile size is scaled by a per-image ratio so the
# entire image is predicted in a single pass.
import math

# Identity affine passed as `transform` to rasterio.open in the datasets below.
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)

def img2tensor(img,dtype:np.dtype=np.float32):
    """Convert an (H, W) or (H, W, C) array into a (C, H, W) torch tensor.

    A 2-D input gets a trailing channel axis of size 1 first. The data is
    cast to ``dtype`` without copying when it already has that dtype.
    """
    channels_last = img[:, :, None] if img.ndim == 2 else img
    channels_first = channels_last.transpose(2, 0, 1)
    return torch.from_numpy(channels_first.astype(dtype, copy=False))

class HuBMAPDataset_w(Dataset):
    """Whole-image tile dataset: each tile is large enough to span the height.

    The tile side is ``ceil(height / sz) * sz`` in source pixels, so at most
    one tile is needed along axis 0. Each tile is zero-padded where it leaves
    the image, shrunk back to ``sz`` x ``sz`` with area interpolation and
    normalized.

    Args:
        idx: image id; the file ``DATA/<idx>.tiff`` is opened with rasterio.
        sz: side length of the tile handed to the model.
        mean: per-channel mean used for normalization.
        std: per-channel std used for normalization.

    Attributes read by callers: ``shape``, ``reduce``, ``sz`` (tile side in
    source pixels), ``pad0``/``pad1`` (total padding per axis) and
    ``n0max``/``n1max`` (tiles per axis).
    """
    def __init__(self, idx, sz=256, mean = mean, std = std):
        self.data = rasterio.open(os.path.join(DATA,idx+'.tiff'), transform = identity,
                                 num_threads='all_cpus')
        # Keep the statistics given by the caller. Previously they were
        # dropped and __getitem__ silently normalized with the module-level
        # ``mean``/``std``, ignoring per-model statistics.
        self.mean = mean
        self.std = std
        # some images have issues with their format 
        # and must be saved correctly before reading with rasterio
        if self.data.count != 3:
            subdatasets = self.data.subdatasets
            self.layers = []
            if len(subdatasets) > 0:
                for subdataset in subdatasets:
                    self.layers.append(rasterio.open(subdataset))
        self.shape = self.data.shape
        if self.shape[0] < sz:
            self.ratio = 1
        else:
            self.ratio = math.ceil(self.shape[0] / sz)
        # The whole image is squeezed into a single sz-sized tile along
        # axis 0 (padding is added as needed).

        self.reduce = self.ratio
        self.sz = self.reduce*sz
        # Padding that rounds each axis up to a multiple of the tile side.
        self.pad0 = (self.sz - self.shape[0]%self.sz)%self.sz
        self.pad1 = (self.sz - self.shape[1]%self.sz)%self.sz
        self.n0max = (self.shape[0] + self.pad0)//self.sz
        self.n1max = (self.shape[1] + self.pad1)//self.sz
        
    def __len__(self):
        return self.n0max*self.n1max
    
    def __getitem__(self, idx):
        """Return ``(normalized CxHxW float tensor, idx)`` for tile ``idx``."""
        n0,n1 = idx//self.n1max, idx%self.n1max

        # Tile origin in image coordinates. The parentheses matter:
        # ``-pad//2`` parses as ``(-pad)//2`` and rounds away from zero for an
        # odd pad, which put the image one pixel off from the ``pad//2`` crop
        # applied to the assembled mask.
        x0 = -(self.pad0//2) + n0*self.sz
        y0 = -(self.pad1//2) + n1*self.sz

        # Clamp the read window to the part of the tile inside the image.
        p00,p01 = max(0,x0), min(x0+self.sz,self.shape[0])
        p10,p11 = max(0,y0), min(y0+self.sz,self.shape[1])
        img = np.zeros((self.sz,self.sz,3),np.uint8)
        
        if self.data.count == 3:
            img[(p00-x0):(p01-x0),(p10-y0):(p11-y0)] = np.moveaxis(self.data.read([1,2,3],
                window=Window.from_slices((p00,p01),(p10,p11))), 0, -1)
        else:
            # One sub-dataset per channel.
            for i,layer in enumerate(self.layers):
                img[(p00-x0):(p01-x0),(p10-y0):(p11-y0),i] =\
                  layer.read(1,window=Window.from_slices((p00,p01),(p10,p11)))
        
        if self.reduce != 1:
            img = cv2.resize(img,(self.sz//self.reduce,self.sz//self.reduce),
                             interpolation = cv2.INTER_AREA)

        return img2tensor((img/255.0 - self.mean)/self.std), idx

In [None]:
class HuBMAPDataset_s(Dataset):
    """Sliding-window tile dataset with overlapping tiles.

    Tiles of ``reduce * sz`` source pixels are taken every
    ``reduce * stride`` pixels, zero-padded where they leave the image,
    shrunk to ``sz`` x ``sz`` with area interpolation and normalized with the
    module-level ``mean``/``std``.

    Args:
        idx: image id; the file ``DATA/<idx>.tiff`` is opened with rasterio.
        sz: side length of the tile handed to the model.
        reduce: downscale factor; forced to 1 when the image height is at
            most ``reduce * sz // 2``.
        stride: step between tiles, in model-input pixels.

    Attributes read by callers: ``shape``, ``reduce``, ``sz`` and ``stride``
    (both in source pixels), ``pad0``/``pad1`` (total padding per axis) and
    ``n0max``/``n1max`` (tiles per axis).
    """
    def __init__(self, idx, sz=256, reduce=4, stride=128):
        self.data = rasterio.open(os.path.join(DATA,idx+'.tiff'), transform = identity,
                                 num_threads='all_cpus')
        # some images have issues with their format 
        # and must be saved correctly before reading with rasterio
        if self.data.count != 3:
            subdatasets = self.data.subdatasets
            self.layers = []
            if len(subdatasets) > 0:
                for subdataset in subdatasets:
                    self.layers.append(rasterio.open(subdataset))
        self.shape = self.data.shape     
#         if self.shape[0] <= reduce*sz:
        if self.shape[0] <= (reduce*sz)//2: # finer-grained cut-off
            self.reduce = 1
        else:
            self.reduce = reduce
        self.sz = self.reduce*sz
        self.stride = self.reduce*stride

        # Pad each axis so a whole number of strides plus one tile covers it.
        # pad1 previously subtracted from shape[0] (copy-paste slip), which is
        # only correct for square images.
        self.pad0 = self._pad_for(self.shape[0])
        self.pad1 = self._pad_for(self.shape[1])

        self.n0max = (self.shape[0] + self.pad0 - self.sz)//self.stride + 1
        self.n1max = (self.shape[1] + self.pad1 - self.sz)//self.stride + 1

    def _pad_for(self, extent):
        """Return the total padding needed along an axis of length ``extent``."""
        steps = max(((extent-self.sz-1)//self.stride)+1, 0)
        return abs(extent - (steps*self.stride + self.sz))
        
    def __len__(self):
        return self.n0max*self.n1max
    
    def __getitem__(self, idx):
        """Return ``(normalized CxHxW float tensor, idx)`` for tile ``idx``."""
        n0,n1 = idx//self.n1max, idx%self.n1max

        # Tile origin in image coordinates. The parentheses matter:
        # ``-pad//2`` parses as ``(-pad)//2`` and rounds away from zero for an
        # odd pad, which put the image one pixel off from the ``pad//2`` crop
        # applied to the assembled mask.
        x0 = -(self.pad0//2) + n0*self.stride
        y0 = -(self.pad1//2) + n1*self.stride
        # make sure that the region to read is within the image
        p00,p01 = max(0,x0), min(x0+self.sz,self.shape[0])
        p10,p11 = max(0,y0), min(y0+self.sz,self.shape[1])
        img = np.zeros((self.sz,self.sz,3),np.uint8)
        # map the loaded region into the tile
        if self.data.count == 3:
            img[(p00-x0):(p01-x0),(p10-y0):(p11-y0)] = np.moveaxis(self.data.read([1,2,3],
                window=Window.from_slices((p00,p01),(p10,p11))), 0, -1)
        else:
            # One sub-dataset per channel.
            for i,layer in enumerate(self.layers):
                img[(p00-x0):(p01-x0),(p10-y0):(p11-y0),i] =\
                  layer.read(1,window=Window.from_slices((p00,p01),(p10,p11)))
        
        if self.reduce != 1:
            img = cv2.resize(img,(self.sz//self.reduce,self.sz//self.reduce),
                             interpolation = cv2.INTER_AREA)

        return img2tensor((img/255.0 - mean)/std), idx

In [None]:
#iterator like wrapper that returns predicted masks
class Model_pred:
    """Iterable wrapper that yields one predicted probability map per tile.

    For every batch the sigmoid outputs of all models are averaged,
    optionally together with three flip test-time augmentations, and then
    upsampled bilinearly by ``reduce``.

    Args:
        models: models to ensemble; each is called as ``model(x)``.
        dl: data loader yielding ``(x, y)`` batches, where ``y`` holds the
            tile indices.
        reduce: scale factor applied to the averaged prediction.
        tta: also average over x, y and xy flips.
        half: cast inputs to half precision before inference.

    Yields:
        ``(prob, index)`` with ``prob`` a CPU float tensor in channels-last
        layout (after ``permute(0,2,3,1)``) and ``index`` the matching entry
        of ``y``.
    """
    def __init__(self, models, dl, reduce, tta:bool=True, half:bool=False):
        self.models = models
        self.dl = dl
        self.tta = tta
        self.half = half
        self.reduce = reduce
        
    def __iter__(self):
        # x, y and xy flips used as TTA
        flips = [[-1],[-2],[-2,-1]]
        with torch.no_grad():
            for x,y in iter(self.dl):
                keep = y >= 0
                if keep.sum() == 0:
                    continue  # exclude empty images (negative index)
                x = x[keep].to(device)
                y = y[keep]
                if self.half: x = x.half()
                py = None
                for model in self.models:
                    p = torch.sigmoid(model(x)).detach()
                    if py is None: py = p
                    else: py += p
                if self.tta:
                    for f in flips:
                        xf = torch.flip(x,f)
                        for model in self.models:
                            # Undo the flip so the prediction lines up with x.
                            p = torch.flip(model(xf),f)
                            py += torch.sigmoid(p).detach()
                    py /= (1+len(flips))
                py /= len(self.models)

                # F.upsample is deprecated; F.interpolate is the same
                # operation, and align_corners=False is the default it used.
                py = F.interpolate(py, scale_factor=self.reduce, mode="bilinear",
                                   align_corners=False)
                py = py.permute(0,2,3,1).float().cpu()

                for prob, index in zip(py, y):
                    yield prob, index
                    
    def __len__(self):
        return len(self.dl.dataset)

In [None]:
def create_models(paths, model_name):
    """Build one ready-for-inference model per checkpoint path.

    Args:
        paths: iterable of checkpoint files holding a state dict.
        model_name: zero-argument callable (e.g. a model class) returning a
            fresh model instance.

    Returns:
        List of models in float precision and eval mode, moved to the
        module-level ``device``, in the order of ``paths``.
    """
    def _load(path):
        # Load on CPU first; the model is moved to ``device`` afterwards.
        weights = torch.load(path, map_location=torch.device('cpu'))
        net = model_name()
        net.load_state_dict(weights)
        net.float()
        net.eval()
        net.to(device)
        return net

    return [_load(path) for path in paths]
    

In [None]:
# Load the checkpoints needed for the selected mode. Each guard tests the
# same checkpoint list that is loaded under it, so clearing a list disables
# exactly that model group.
if mode == 'w' or mode == 'e':
    # whole-image models
    if MODELS_256_12:
        models_256_12 = create_models(MODELS_256_12, UneXt50_)
    if MODELS_256_18:
        models_256_18 = create_models(MODELS_256_18, UneXt50)
    if MODELS_768_4:
        models_768_4 = create_models(MODELS_768_4, UneXt50)
    if MODELS_128_24:
        models_128_24 = create_models(MODELS_128_24, UneXt50)
    if MODELS_512_6:
        models_512_6 = create_models(MODELS_512_6, UneXt50)
    if MODELS_512_9:
        models_512_9 = create_models(MODELS_512_9, UneXt50)
    if MODELS_640_5:
        models_640_5 = create_models(MODELS_640_5, UneXt50)
    if MODELS_384_8:
        models_384_8 = create_models(MODELS_384_8, UneXt50)
    if MODELS_384_12:
        models_384_12 = create_models(MODELS_384_12, UneXt50)
        
    if MODELS_256_bin:
        models_256_bin = create_models(MODELS_256_bin, UneXt50_)
    # Was guarded by MODELS_256_bin, the wrong list.
    if MODELS_256_bin_stain:
        models_256_bin_stain = create_models(MODELS_256_bin_stain, UneXt50_)
    if MODELS_384_bin:
        models_384_bin = create_models(MODELS_384_bin, UneXt50_)
#     if MODELS_384_bin_stain:
#         models_384_bin_stain = create_models(MODELS_384_bin_stain, UneXt50_)
#     if MODELS_512_bin:
#         models_512_bin = create_models(MODELS_512_bin, UneXt50_)
    # The stain checkpoints used to be stored under the name models_512_bin.
    if MODELS_512_bin_stain:
        models_512_bin_stain = create_models(MODELS_512_bin_stain, UneXt50_)
        # Old, misleading name kept as an alias for any code still using it.
        models_512_bin = models_512_bin_stain
#     if MODELS_640_bin:
#         models_640_bin = create_models(MODELS_640_bin, UneXt50_)
    # Was guarded by MODELS_640_bin, the wrong list.
    if MODELS_640_bin_stain:
        models_640_bin_stain = create_models(MODELS_640_bin_stain, UneXt50_)
#     if MODELS_768_bin:
#         models_768_bin = create_models(MODELS_768_bin, UneXt50_)
#     if MODELS_1024_bin:
#         models_1024_bin = create_models(MODELS_1024_bin, UneXt50_)


if mode == 's' or mode == 'e':
    # sliding-window models
    if MODELS_128_8:
        models_128_8 = create_models(MODELS_128_8, UneXt50)
    if MODELS_128_12:
        models_128_12 = create_models(MODELS_128_12, UneXt50)
    if MODELS_256_3:
        models_256_3 = create_models(MODELS_256_3, UneXt50)
    if MODELS_256_6:
        models_256_6 = create_models(MODELS_256_6, UneXt50_)
    if MODELS_256_4:
        models_256_4 = create_models(MODELS_256_4, UneXt50_)
    if MODELS_384_4:
        models_384_4 = create_models(MODELS_384_4, UneXt50)
    if MODELS_512_3:
        models_512_3 = create_models(MODELS_512_3, UneXt50)
    if MODELS_768_2:
        models_768_2 = create_models(MODELS_768_2, UneXt50)
    

In [None]:
def create_mask(idx, bs, sz, reduce, stride, models):
    """Predict a probability mask for one image with overlapping tiles.

    Tile predictions are accumulated on a padded canvas, divided by the
    number of tiles covering each pixel, and cropped back to the image size.

    Args:
        idx: image id passed to ``HuBMAPDataset_s``.
        bs: DataLoader batch size.
        sz: tile side fed to the models.
        reduce: downscale factor requested from the dataset.
        stride: tile step, in model-input pixels.
        models: models to ensemble via ``Model_pred``.

    Returns:
        Float32 tensor of shape ``ds.shape`` holding averaged probabilities.
    """
    ds = HuBMAPDataset_s(idx, sz, reduce, stride)
    dl = DataLoader(ds,bs,num_workers=0,shuffle=False,pin_memory=True)
    mp = Model_pred(models,dl, ds.reduce)

    canvas_shape = (ds.shape[0]+ds.pad0, ds.shape[1]+ds.pad1)
    mask = torch.zeros(canvas_shape, dtype=torch.float32)
    count_map = torch.zeros(canvas_shape, dtype=torch.float32)

    for p, i in iter(mp):
        # The dataset lays tiles out row-major with n1max tiles per row, so
        # decode with n1max. Using n0max here was only right when both
        # counts happened to be equal.
        row, col = divmod(i.item(), ds.n1max)
        x0 = row*ds.stride
        y0 = col*ds.stride
        mask[x0:x0+ds.sz, y0:y0+ds.sz] += p.squeeze(-1)
        count_map[x0:x0+ds.sz, y0:y0+ds.sz] += 1.

    mask = mask / count_map
    # Crop with explicit end indices: the former ``a:-(pad - a)`` slice
    # collapses to ``a:-0`` (an empty tensor) when no padding was needed.
    top, left = ds.pad0//2, ds.pad1//2
    return mask[top:top+ds.shape[0], left:left+ds.shape[1]]

In [None]:
def create_mask_w(idx, bs, sz, models, mean, std):
    """Predict a probability mask for one image in whole-image mode.

    Args:
        idx: image id passed to ``HuBMAPDataset_w``.
        bs: DataLoader batch size.
        sz: tile side fed to the models.
        models: models to ensemble via ``Model_pred``.
        mean: normalization mean forwarded to ``HuBMAPDataset_w``.
        std: normalization std forwarded to ``HuBMAPDataset_w``.

    Returns:
        Float32 tensor: the tile predictions stitched into one canvas with
        the padding cropped off both axes.
    """
    ds = HuBMAPDataset_w(idx, sz, mean, std)
    dl = DataLoader(ds,bs,num_workers=0,shuffle=False,pin_memory=True)
    mp = Model_pred(models, dl, ds.reduce)
    tiles = torch.zeros(len(ds),ds.sz,ds.sz,dtype=torch.float32)

    for prob, tile_idx in iter(mp):
        tiles[tile_idx.item()] = prob.squeeze(-1)

    # (rows, cols, h, w) -> (rows, h, cols, w) -> one (rows*h, cols*w) canvas
    rows, cols, side = ds.n0max, ds.n1max, ds.sz
    canvas = tiles.view(rows, cols, side, side).permute(0,2,1,3)
    canvas = canvas.reshape(rows*side, cols*side)

    # Remove pad//2 at the start and the remainder of the pad at the end.
    top, left = ds.pad0//2, ds.pad1//2
    bottom = rows*side - (ds.pad0 - top)
    right = cols*side - (ds.pad1 - left)
    return canvas[top:bottom, left:right]

# mmseg

In [None]:
# Offline installs from attached datasets (wheel files on disk, no network).
# Small dependencies first:
!pip install ../input/mmdetection/addict-2.4.0-py3-none-any.whl > /dev/null
!pip install ../input/mmdetection/yapf-0.31.0-py2.py3-none-any.whl > /dev/null
!pip install ../input/mmdetection/terminaltables-3.1.0-py3-none-any.whl > /dev/null

# einops, then mmcv-full (wheel built for CPython 3.7 / linux x86_64) and mmcls:
!pip install ../input/mmdetection/einops-0.4.1-py3-none-any.whl
!pip install ../input/mmsegmentation/mmcv-full/mmcv_full-1.5.3-cp37-cp37m-linux_x86_64.whl > /dev/null
!pip install ../input/openmmlab-essential-repositories/openmmlab-repos/src/mmcls-0.23.1-py2.py3-none-any.whl > /dev/null

# Copy the mmsegmentation sources to the working dir and install them in editable mode:
!cp -r ../input/mmsegm/mmsegmentation-master /kaggle/working/ && cd /kaggle/working/mmsegmentation-master && pip install -e . && cd ..

import numpy as np
import pandas as pd
import os
from glob import glob
from tqdm.notebook import tqdm
import sys
import gc
sys.path.append('/kaggle/working/mmsegmentation-master')
from mmseg.apis import init_segmentor, inference_segmentor
from mmcv.utils import config

# mmsegmentation config files; paired by position with `ckpts` below.
configs = [
    '../input/hanbinadd2/multiscale4.py',
]

# Checkpoint for each config above.
ckpts = [
     '../input/hanbin100/f4_epoch_70.pth',    
]




# Build one mmseg segmentor per (config, checkpoint) pair.
# NOTE: `models` here holds the mmseg segmentors used by the inference cell;
# the U-Net ensembles live in the separate models_<size>_<factor> variables.
models = []
for idx,(cfg, ckpt) in enumerate(zip(configs, ckpts)):
    cfg = config.Config.fromfile(cfg)
    # The device is hard-coded to the first GPU here.
    model = init_segmentor(cfg, ckpt, device='cuda:0')
    models.append(model)
    
    
# Same values as assigned in the ensemble configuration cell.
DATA = '../input/hubmap-organ-segmentation/test_images/'
df_sample = pd.read_csv('../input/hubmap-organ-segmentation/test.csv')

# inference

In [None]:
# Per-image inference: compute probability masks, then binarize with TH.
#   'w' = whole-image prediction
#   's' = sliding-window prediction
#   'e' = ensemble of both
# mode = 'w'
# mode = 's'
# mode = 'e'

def _dominant_foreground_mask(predict):
    """Return a binary mask of the most frequent non-zero label.

    ``predict`` is a per-pixel map of non-negative integer labels
    (NOTE(review): assumed to be what ``inference_segmentor(...)[0]``
    returns -- confirm). When every pixel is background the map is returned
    unchanged. ``np.bincount`` replaces a Python-level ``Counter`` over every
    pixel; the two can differ only when two labels have exactly equal counts.
    """
    counts = np.bincount(predict.ravel(), minlength=1)
    counts[0] = 0  # background never wins
    if not counts.any():
        return predict
    return (predict == counts.argmax()).astype(np.uint8)


def _blend_and_threshold(seg_masks, prob_map, th):
    """Average the binary mmseg masks with a probability map, then threshold.

    The previous code applied ``np.ceil`` to the average. Sigmoid outputs are
    strictly positive, so that marked (nearly) every pixel as foreground and
    made the threshold meaningless. It also called ``.astype`` and then
    ``.numpy()`` on the same object, which cannot both succeed.

    Returns a ``np.uint8`` array with 1 where the average exceeds ``th``.
    """
    if torch.is_tensor(prob_map):
        prob_map = prob_map.detach().cpu().numpy()
    layers = [np.asarray(m, dtype=np.float32) for m in seg_masks]
    layers.append(np.asarray(prob_map, dtype=np.float32))
    blended = sum(layers) / len(layers)
    return (blended > th).astype(np.uint8)


names,preds = [],[]
names_w, preds_w = [], []
names_s, preds_s = [], []
# Debug mode: a single-row test.csv; the per-mode RLEs are then kept for the
# visualization cell.
debug = len(df_sample) == 1

TH = 0.30  # binarization threshold -- change it here

for _, row in tqdm(df_sample.iterrows(),total=len(df_sample)):

    idx = str(row['id'])
    # Named "width" but read from img_height; used to pick the model groups.
    width = row['img_height']

    # The image must be located by its id. The DataFrame index was used
    # before, which points at a file that does not exist.
    img_path = os.path.join(DATA, idx + '.tiff')
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"could not read test image: {img_path}")

    # One binary mask per mmseg model.
    seg_masks = [_dominant_foreground_mask(inference_segmentor(model, img)[0])
                 for model in models]

    if mode == 'w' or mode == 'e':
        wcnt = 2

        mask = create_mask_w(idx, bs=bs, sz=256, models=models_256_12, mean=mean, std=std)
        # Accumulate: a plain assignment here used to discard the line above
        # while wcnt still counted both predictions.
        mask += create_mask_w(idx, bs=bs, sz=256, models=models_256_bin_stain, mean=mean256, std=std256)

        if width <= 256:
            mask += create_mask_w(idx, bs=bs, sz=128, models=models_128_24, mean=mean, std=std)
            mask += create_mask_w(idx, bs=bs, sz=256, models=models_256_18, mean=mean, std=std)

            mask += create_mask_w(idx, bs=bs, sz=256, models=models_256_bin, mean=mean256, std=std256)

            wcnt += 3

        if width > 256 and width <= 1536:
            mask += create_mask_w(idx, bs=bs, sz=384, models=models_384_8, mean=mean, std=std)
            mask += create_mask_w(idx, bs=bs, sz=512, models=models_512_6, mean=mean, std=std)

            mask += create_mask_w(idx, bs=bs, sz=384, models=models_384_bin, mean=mean384, std=std384)
            mask += create_mask_w(idx, bs=bs, sz=640, models=models_640_bin_stain, mean=mean640, std=std640)

            wcnt += 4

        # NOTE: at width == 1536 this branch runs in addition to the previous
        # one; wcnt counts both, so the average stays consistent.
        if width >= 1536:
            mask += create_mask_w(idx, bs=bs, sz=640, models=models_640_5, mean=mean, std=std)
            mask += create_mask_w(idx, bs=bs, sz=768, models=models_768_4, mean=mean, std=std)

            mask += create_mask_w(idx, bs=bs, sz=640, models=models_640_bin_stain, mean=mean640, std=std640)
            wcnt += 3

        mask_w = mask / wcnt
        pred_w = _blend_and_threshold(seg_masks, mask_w, TH)
        rle_w = rle_encode_less_memory(pred_w)

        if debug:
            names_w.append(idx)
            preds_w.append(rle_w)

        # ensemble weight of the whole-image prediction
        mask_w *= 1.2

    # sliding-window prediction
    scnt = 3
    if mode == 's' or mode == 'e':
        mask = create_mask(idx, bs=bs, sz=256, reduce=4, stride=128, models=models_256_4)
        mask += create_mask(idx, bs=bs, sz=256, reduce=6, stride=128, models=models_256_6)
        mask += create_mask(idx, bs=bs, sz=256, reduce=3, stride=128, models=models_256_3)

        if width < 256:
            mask += create_mask(idx, bs=bs, sz=128, reduce=8, stride=64, models=models_128_8)  # stride 64 is slow
            mask += create_mask(idx, bs=bs, sz=128, reduce=12, stride=64, models=models_128_12)
            scnt += 2
        if width >= 256 and width < 1536:
            mask += create_mask(idx, bs=bs, sz=384, reduce=4, stride=192, models=models_384_4)
            scnt += 1
        if width >= 1536:
            mask += create_mask(idx, bs=bs, sz=512, reduce=3, stride=256, models=models_512_3)  # 256 seems to work well
            mask += create_mask(idx, bs=bs, sz=768, reduce=2, stride=384, models=models_768_2)
            scnt += 2

        mask_s = mask / scnt
        pred_s = _blend_and_threshold(seg_masks, mask_s, TH)
        rle_s = rle_encode_less_memory(pred_s)

        if debug:
            names_s.append(idx)
            preds_s.append(rle_s)

        # ensemble weight of the sliding-window prediction
        mask_s *= 0.8

    if mode == 'e':
        # The weights 1.2 and 0.8 sum to 2, so this is a weighted mean.
        mask_e = (mask_w + mask_s) / 2
        # The ensemble is thresholded like the single-mode outputs; it was
        # previously encoded without applying TH at all.
        pred_e = _blend_and_threshold(seg_masks, mask_e, TH)

        rle = rle_encode_less_memory(pred_e)
        del mask


    ## when submitting a single mode only
    if mode == 'w':
        rle = rle_w
    if mode == 's':
        rle = rle_s

    names.append(idx)
    preds.append(rle)


    gc.collect()

In [None]:
# Assemble the submission: one row per test image with its id and RLE string.
df = pd.DataFrame({'id':names,'rle':preds})
# Written to the working directory, without the DataFrame index column.
df.to_csv('submission.csv',index=False)
# Bare expression so the notebook displays the frame.
df

In [None]:
def rle2mask(mask_rle, shape=(1600,256)):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: run-length as string formatted "start length start length ..."
              with 1-based starts
    shape: (width,height) of array to return
    Returns numpy uint8 array of shape (shape[1], shape[0]),
    1 - mask, 0 - background

    '''
    tokens = mask_rle.split()
    # Even tokens are 1-based starts, odd tokens are run lengths.
    begins = np.asarray(tokens[0::2], dtype=int) - 1
    stops = begins + np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(begins, stops):
        flat[begin:stop] = 1
    return flat.reshape(shape).T

# Visual check of the first prediction. The per-mode RLEs (preds_w / preds_s)
# are only collected in debug runs, so the plot is skipped otherwise instead
# of failing with an IndexError on an empty list.
if debug and preds:
    # Load the image the first RLE belongs to, not a hard-coded file.
    img_1 = tiff.imread(os.path.join(DATA, names[0] + '.tiff'))
    shape_1 = (img_1.shape[1], img_1.shape[0])  # (width, height) for rle2mask
    mask_1 = rle2mask(preds[0], shape_1)

    plt.figure(figsize=(15,15))
    plt.subplot(1,4,1)
    plt.imshow(img_1)

    if mode == 'w' or mode == 'e':
        overlay_w = rle2mask(preds_w[0], shape_1)
        plt.subplot(1,4,2)
        plt.title('whole')
        plt.imshow(img_1)
        plt.imshow(overlay_w, cmap='coolwarm', alpha=0.5)
        plt.axis("off")

    if mode == 's' or mode == 'e':
        overlay_s = rle2mask(preds_s[0], shape_1)
        plt.subplot(1,4,3)
        plt.title('slide')
        plt.imshow(img_1)
        plt.imshow(overlay_s, cmap='coolwarm', alpha=0.5)
        plt.axis("off")

    # Final submitted mask (equals the single-mode mask unless mode == 'e').
    plt.subplot(1,4,4)
    plt.title('ensemble')
    plt.imshow(img_1)
    plt.imshow(mask_1, cmap='coolwarm', alpha=0.5)
    plt.axis("off")

    

