In [None]:
import sys
import os
import glob
import cv2
import random
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import gc
import json

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

sys.path.append('/kaggle/input/timm-clone/pytorch-image-models/')
device = "cuda" if torch.cuda.is_available() else 'cpu'

import timm

In [None]:
# Install transformers 4.29.2 from a local wheel under /kaggle/input (-q keeps pip output quiet).
!pip install -q /kaggle/input/tlvmc-src/transformers-4.29.2-py3-none-any.whl

In [None]:

# -- configuration ----------------------------------------

# Cut-offs applied to the averaged probability map when it is binarised.
# Looked up first by data source, then by organ.
organ_threshold = {
    'Hubmap': dict(
        kidney=0.40,
        prostate=0.40,
        largeintestine=0.40,
        spleen=0.40,
        lung=0.10,
    ),
    'HPA': dict(
        kidney=0.50,
        prostate=0.50,
        largeintestine=0.50,
        spleen=0.50,
        lung=0.10,
    ),
}

# batch_size feeds the DataLoader; image_size is the side length (pixels)
# every image is resized to before it is handed to the network.
args = dict(batch_size=1, image_size=768)

# 'submission' -> run on test.csv, 'cv' -> run on train.csv and score locally.
submit_type = 'submission'

## Model

In [None]:
# class UNetDecoder(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.upsample = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners=True)
#         self.block1 = nn.Sequential(
#             nn.Conv2d(in_channels = dim, out_channels = dim // 2, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 2 , out_channels = dim // 2, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block2 = nn.Sequential(
#             nn.Conv2d(in_channels = dim // 2, out_channels= dim // 4, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 4, out_channels= dim // 4, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.block3 = nn.Sequential(
#             nn.Conv2d(in_channels = dim // 4, out_channels = dim // 8, kernel_size = 3, padding = "same"),
#             nn.ReLU(),
#             nn.Conv2d(in_channels = dim // 8, out_channels= dim // 8, kernel_size = 3, padding = "same"),
#             nn.ReLU()
#         )
#         self.last_conv = nn.Conv2d(in_channels = dim // 8, out_channels = 1, kernel_size = 1)
#     def forward(self, x):
#     #TODO Skip Connection
#         x = self.upsample(x)
#         x = self.block1(x)
#         x = self.upsample(x)
#         x = self.block2(x)
#         x = self.upsample(x)
#         x = self.block3(x)
#         x = self.last_conv(x)
#         #x = F.interpolate(x, size=(720,720))
#         return x

In [None]:
# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.encoder = timm.create_model('tf_efficientnet_b6',
#                                          pretrained = False, 
#                                          num_classes = 0,   
#                                          global_pool = '')  
        
#         dim = self.encoder.conv_head.out_channels # effnet_b4 = 1792
#         self.decoder = UNetDecoder(dim = dim)

#     def forward(self, batch):
#         x = self.encoder(batch['image'])
#         logit = self.decoder(x)
        
#         out = {}
        
#         if self.training :
#             out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target = batch['mask'])

#         else :
#             #out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target = batch['mask'])
#             out['probability'] = torch.sigmoid(logit)
        
#         return out

In [None]:
class Net(nn.Module):
    """Wrap a segmentation model and turn its logits into a loss or a probability map.

    Args:
        model: module whose forward output exposes a ``.logits`` tensor.
        image_size: side length of the square map the logits are resized to.
            When left as ``None`` the logits are resized to the spatial size
            of ``batch['image']`` instead, so ``Net(model)`` works without an
            explicit size.
    """

    def __init__(self, model, image_size=None):
        super().__init__()
        self.model = model
        self.image_size = image_size

    def forward(self, batch):
        """Run the wrapped model on ``batch['image']``.

        Returns:
            dict with ``'bce_loss'`` (computed against ``batch['mask']``) in
            training mode, or ``'probability'`` (sigmoid of the resized
            logits) in eval mode.
        """
        image = batch['image']
        logit = self.model(image).logits

        if self.image_size is None:
            # Fall back to the (height, width) of the input image.
            target_size = tuple(image.shape[-2:])
        else:
            target_size = (self.image_size, self.image_size)
        # Default interpolation mode is kept as in the original call.
        logit = F.interpolate(logit, target_size)

        out = {}
        if self.training:
            out['bce_loss'] = F.binary_cross_entropy_with_logits(input=logit, target=batch['mask'])
        else:
            out['probability'] = torch.sigmoid(logit)
        return out

In [None]:
# Choose the metadata CSV for the selected run mode. An unrecognised mode is
# rejected here instead of surfacing later as an undefined `valid_file`.
if submit_type == 'cv':
    valid_file = '../input/hubmap-organ-segmentation/train.csv'
elif submit_type == 'submission':
    valid_file = '../input/hubmap-organ-segmentation/test.csv'
else:
    raise ValueError(f"Unknown submit_type {submit_type!r}; expected 'cv' or 'submission'")

# Rows are ordered by id so they can later be paired with the id-sorted submission.
valid_df = pd.read_csv(valid_file).sort_values('id')
valid_df.head()

In [None]:
def do_local_validation():
    """Score ./submission.csv against ``valid_df`` with a per-image Dice coefficient.

    NOTE: a later cell of this notebook defines ``do_local_validation`` again,
    and that later definition replaces this one once it has run. Like that
    definition, this one returns the last decoded mask pair, which is what the
    ``t, p = do_local_validation()`` call site unpacks.

    Prints one line per organ group (and ``'all'``): the fraction of rows in
    the group, the group name and the mean Dice of the group.

    Returns:
        (t, p): decoded truth and prediction masks of the last scored row.

    Raises:
        ValueError: if the ids in submission.csv and ``valid_df`` do not line
            up row by row after sorting.
    """
    submit_df = pd.read_csv('./submission.csv').fillna('').sort_values('id')
    truth_df = valid_df.sort_values('id')

    # Rows are paired by position below, so the two id columns must agree.
    if submit_df['id'].tolist() != truth_df['id'].tolist():
        raise ValueError('ids in submission.csv do not match the ids in valid_df')

    lb_score = []
    for i in range(len(submit_df)):
        t_df = truth_df.iloc[i]
        p_df = submit_df.iloc[i]
        # Both masks are decoded with the same dimensions taken from the truth
        # row, so they are compared element for element.
        t = rle_decode(t_df.rle, t_df.img_height, t_df.img_width, 1)
        p = rle_decode(p_df.rle, t_df.img_height, t_df.img_width, 1)

        dice = 2 * (t * p).sum() / (p.sum() + t.sum())
        lb_score.append(dice)

    truth_df.loc[:, 'lb_score'] = lb_score
    for organ in ['all', 'kidney', 'prostate', 'largeintestine', 'spleen', 'lung']:
        if organ != 'all':
            d = truth_df[truth_df.organ == organ]
        else:
            d = truth_df
        print('\t%f\t%s\t%f' % (len(d) / len(truth_df), organ, d.lb_score.mean()))

    return t, p

## Dataset

In [None]:
class segDataset(Dataset):
    """Dataset yielding one resized, [0, 1]-scaled RGB image plus its metadata per row of ``df``.

    Args:
        df: DataFrame with at least the columns ``id``, ``data_source``,
            ``organ``, ``img_height`` and ``img_width``.
        augment: stored on the instance; not applied anywhere in this class.
        submit_type: ``'cv'`` reads from the train image folder,
            ``'submission'`` from the test image folder.
    """

    # Image folder for each supported run mode.
    _TIFF_DIRS = {
        'cv': '../input/hubmap-organ-segmentation/train_images',
        'submission': '../input/hubmap-organ-segmentation/test_images',
    }

    def __init__(self, df, augment, submit_type):
        self.df = df
        self.augment = augment
        self.submit_type = submit_type

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load the image for row ``index``.

        Returns:
            dict with ``id``, ``image`` (float tensor laid out c, h, w),
            ``data_source``, ``organ``, ``height`` and ``width``.

        Raises:
            ValueError: if ``submit_type`` is not a supported mode.
            FileNotFoundError: if the image file cannot be read.
        """
        d = self.df.iloc[index]

        try:
            tiff_dir = self._TIFF_DIRS[self.submit_type]
        except KeyError:
            raise ValueError(
                f"Unknown submit_type {self.submit_type!r}; expected 'cv' or 'submission'"
            ) from None

        path = f"{tiff_dir}/{d['id']}.tiff"
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, dsize=(args['image_size'], args['image_size']))
        image = image / 255.

        return {
            'id': d['id'],
            'image': torch.tensor(image).permute(2, 0, 1).float(),  # h, w, c -> c, h, w
            'data_source': d['data_source'],
            'organ': d['organ'],
            'height': d['img_height'],
            'width': d['img_width'],
        }

In [None]:
# Dataset over the selected split, with no augmentation.
valid_ds = segDataset(valid_df, None, submit_type)

# Sequential loader: rows keep their order and no trailing batch is dropped.
valid_dl = DataLoader(
    dataset=valid_ds,
    batch_size=args['batch_size'],
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

## Load model

In [None]:
from transformers import SegformerForSemanticSegmentation

In [None]:
# Build the SegFormer network from the local checkpoint directory with a
# single output label; mismatched weight shapes are tolerated on load.
_model = SegformerForSemanticSegmentation.from_pretrained(
    "/kaggle/input/d/methyl/hubmap-src/segformer-b4-finetuned-ade-512-512",
    num_labels=1,
    ignore_mismatched_sizes=True,
)
_model = _model.to(device)

In [None]:
# Net resizes the logits to a square of this side length; it must match the
# size the dataset resizes its images to.
model = Net(_model, image_size=args['image_size'])

In [None]:
# Checkpoint files to load; one entry is added to `models` per path.
model_path = ['/kaggle/input/d/hubmap-src/ep_19_segformer_model.pt']

In [None]:
import copy

# Build one independent network per checkpoint. Each member is a deep copy of
# `model`, so loading a later checkpoint cannot overwrite the weights of a
# member that is already in the list.
models = []
for checkpoint in model_path:
    member = copy.deepcopy(model)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    member.load_state_dict(torch.load(checkpoint, map_location=device))
    member.to(device)
    member.eval()
    models.append(member)

In [None]:
def rle_encode(img):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order and the output is a string of
    space-separated ``start length`` pairs with 1-based starts. A mask with
    no foreground yields an empty string.
    """
    # A zero on each side guarantees every run has both a rising and a falling edge.
    padded = np.pad(img.flatten(), 1)
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate run-start / run-end; turn each end into a length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def rle_decode(rle, width, height, fill=1, dtype=np.float32):
    """Decode a run-length string into a 2-D array.

    Args:
        rle: space-separated ``start length`` pairs with 1-based starts.
        width, height: the flat buffer has ``height * width`` entries and is
            reshaped to ``(height, width)`` without transposing.
        fill: value written into every encoded run.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``(height, width)``, zero outside the encoded runs.
    """
    tokens = rle.split()
    first = np.asarray(tokens[0::2], dtype=int) - 1  # convert to 0-based offsets
    count = np.asarray(tokens[1::2], dtype=int)
    last = first + count

    flat = np.zeros(height * width, dtype=dtype)
    for lo, hi in zip(first, last):
        flat[lo:hi] = fill
    return flat.reshape(height, width)

## Inference

In [None]:
results = []

for batch in tqdm(valid_dl, total=len(valid_dl)):
    # Everything below reads element 0 of each batch field, so a larger batch
    # would silently lose rows. Fail loudly instead.
    if len(batch['id']) != 1:
        raise ValueError('The inference loop expects batch_size == 1')

    # Plain ints for the resize target (the original image height and width).
    out_size = (int(batch['height'][0]), int(batch['width'][0]))

    with torch.no_grad(), amp.autocast(enabled=True):
        batch['image'] = batch['image'].to(device)

        # Average the full-resolution probability maps of all ensemble members.
        # NOTE: flip test-time augmentation is not implemented here.
        prob = 0
        for member in models:
            output = member(batch)
            prob += F.interpolate(output['probability'], size=out_size,
                                  mode='bilinear', align_corners=False, antialias=True)
        prob /= len(models)

    threshold = organ_threshold[batch['data_source'][0]][batch['organ'][0]]
    # Drop the batch and channel axes so the transpose below acts on a plain
    # 2-D (height, width) mask.
    mask = prob[0, 0].detach().cpu().numpy() > threshold
    # The mask is transposed before encoding, as in the original pipeline.
    rle = rle_encode(mask.T)
    results.append({'id': batch['id'][0].item(), 'rle': rle})

# Naming the columns keeps the CSV header intact even when there are no rows.
submit_df = pd.DataFrame(results, columns=['id', 'rle'])
submit_df.to_csv('submission.csv', index=False)

submit_df.head()

## Local Validation

In [None]:
def do_local_validation():
    """Score ./submission.csv against ``valid_df`` with a per-image Dice coefficient.

    Both frames are sorted by id and paired row by row. Prints one line per
    organ group (and ``'all'``): the fraction of rows in the group, the group
    name and the mean Dice of the group.

    Returns:
        (t, p): decoded truth and prediction masks of the last scored row.

    Raises:
        ValueError: if the ids in submission.csv and ``valid_df`` do not line
            up row by row after sorting.
    """
    submit_df = pd.read_csv('./submission.csv').fillna('').sort_values('id')
    truth_df = valid_df.sort_values('id')

    # Rows are paired by position below, so the two id columns must agree;
    # otherwise each prediction would be scored against another image's mask.
    if submit_df['id'].tolist() != truth_df['id'].tolist():
        raise ValueError('ids in submission.csv do not match the ids in valid_df')

    lb_score = []
    for i in tqdm(range(len(submit_df))):
        t_df = truth_df.iloc[i]
        p_df = submit_df.iloc[i]
        # Both masks are decoded with the same dimensions taken from the truth
        # row, so they are compared element for element.
        t = rle_decode(t_df.rle, t_df.img_height, t_df.img_width, 1)
        p = rle_decode(p_df.rle, t_df.img_height, t_df.img_width, 1)

        dice = 2 * (t * p).sum() / (p.sum() + t.sum())
        lb_score.append(dice)

    truth_df.loc[:, 'lb_score'] = lb_score
    for organ in ['all', 'kidney', 'prostate', 'largeintestine', 'spleen', 'lung']:
        if organ != 'all':
            d = truth_df[truth_df.organ == organ]
        else:
            d = truth_df
        print('\t%f\t%s\t%f' % (len(d) / len(truth_df), organ, d.lb_score.mean()))

    return t, p

In [None]:
# Local scoring only runs in 'cv' mode, where valid_df was loaded from train.csv.
# t and p keep the truth / prediction masks returned by do_local_validation().
if submit_type == 'cv':
    t, p = do_local_validation()
    