In [None]:
!cp -r ../input/pytorch-segmentation-models-lib/ ./
!pip config set global.disable-pip-version-check true
!pip install -q ./pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl

In [None]:
!cp -r ../input/einops-041-wheel/ ./
!pip config set global.disable-pip-version-check true
!pip install -q ../input/einops-041-wheel/einops-0.4.1-py3-none-any.whl

In [None]:
import timm
import importlib
from timeit import default_timer as timer

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import tifffile 
import cv2
import os
import gc
from tqdm.notebook import tqdm
import rasterio
from rasterio.windows import Window

from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader
import albumentations as A # Augmentations
import warnings
warnings.filterwarnings("ignore")
import glob
import copy

In [None]:
import numpy as np
import cv2

class dotdict(dict):
	"""A dict whose items can also be read, written and deleted as attributes.

	`d.key` is equivalent to `d['key']`; a missing key surfaces as
	AttributeError (not KeyError) so getattr()/hasattr() keep working.
	"""

	def __getattr__(self, name):
		# Only called when normal attribute lookup has already failed.
		try:
			value = self[name]
		except KeyError:
			raise AttributeError(name)
		return value

	# Attribute assignment and deletion are routed straight to the item protocol.
	__delattr__ = dict.__delitem__
	__setattr__ = dict.__setitem__


#--- helper ----------
def time_to_str(t, mode='min'):
	"""Format a duration given in seconds as a short fixed-width string.

	Args:
		t: Duration in seconds. It is truncated with int() before formatting.
		mode: 'min' -> '%2d hr %02d min' (hours and minutes),
		      'sec' -> '%2d min %02d sec' (minutes and seconds).

	Returns:
		The formatted string.

	Raises:
		NotImplementedError: if `mode` is neither 'min' nor 'sec'.
	"""
	# Integer arithmetic throughout: no reliance on '%d' truncating floats and
	# no local named `min` shadowing the builtin.
	if mode == 'min':
		total_minutes = int(t) // 60
		hours, minutes = divmod(total_minutes, 60)
		return '%2d hr %02d min' % (hours, minutes)

	if mode == 'sec':
		minutes, seconds = divmod(int(t), 60)
		return '%2d min %02d sec' % (minutes, seconds)

	raise NotImplementedError('unsupported mode: %r' % (mode,))
	
def image_show(name, image, type='bgr', resize=1):
	"""Show `image` in an OpenCV window called `name`.

	Args:
		name: Window title / identifier passed to OpenCV.
		image: Array whose first two axes are height and width.
		type: When 'rgb', the last axis is reversed before display; any other
			value leaves the image untouched.
		resize: Factor applied to the image width/height to size the window.
	"""
	if type == 'rgb':
		image = np.ascontiguousarray(image[:, :, ::-1])
	height, width = image.shape[:2]

	cv2.namedWindow(name, cv2.WINDOW_GUI_NORMAL)
	cv2.imshow(name, image)
	cv2.resizeWindow(name, round(resize * width), round(resize * height))

In [None]:
import sys, os

import tifffile as tiff
import json
import cv2
import pandas as pd
import math
import numpy as np



##--------------------------------------------------------------------------------------
# Per-organ metadata, keyed by organ name.
#   label : integer class id (1..5)
#   um    : presumably the pixel size in micrometres -- TODO confirm against the data source
#   ftu   : name of the structure of interest for that organ (used as a text label
#           in result_to_overlay)
organ_meta = dotdict(
	kidney = dotdict(
		label = 1,
		um    = 0.5000,
		ftu   ='glomeruli',
	),
	prostate = dotdict(
		label = 2,
		um    = 6.2630,
		ftu   ='glandular acinus',
	),
	largeintestine = dotdict(
		label = 3,
		um    = 0.2290,
		ftu   ='crypt',
	),
	spleen = dotdict(
		label = 4,
		um    = 0.4945,
		ftu   ='white pulp',
	),
	lung = dotdict(
		label = 5,
		um    = 0.7562,
		ftu   ='alveolus',
	),
)



# Lookup tables between organ name and integer class label, derived from organ_meta.
organ_to_label = {name: meta.label for name, meta in organ_meta.items()}
label_to_organ = dict((label, name) for name, label in organ_to_label.items())
num_organ=5
#['kidney', 'prostate', 'largeintestine', 'spleen', 'lung']


# def read_tiff(image_file, mode='rgb'):
# 	image = tiff.imread(image_file)
# 	image = image.squeeze()
# 	if image.shape[0] == 3:
# 		image = image.transpose(1, 2, 0)
# 	if mode=='bgr':
# 		image = image[:,:,::-1]
# 	image = np.ascontiguousarray(image)
# 	return image

def read_json_as_list(json_file):
	"""Load the JSON document stored at path `json_file` and return the parsed object."""
	with open(json_file) as handle:
		return json.load(handle)


# # --- rle ---------------------------------
def rle_decode(rle, height, width , fill=255, dtype=np.uint8):
	"""Decode a run-length string into a (height, width) mask.

	Args:
		rle: Whitespace-separated "start length start length ..." string with
			1-based starts that index the mask flattened column by column.
		height, width: Shape of the decoded mask.
		fill: Value written into the encoded runs (everything else is 0).
		dtype: dtype of the returned array.

	Returns:
		C-contiguous array of shape (height, width).
	"""
	tokens = rle.split()
	starts = np.asarray(tokens[0::2], dtype=int)
	lengths = np.asarray(tokens[1::2], dtype=int)
	starts -= 1  # convert 1-based positions to 0-based indices

	flat = np.zeros(height*width, dtype=dtype)
	for begin, run in zip(starts, lengths):
		flat[begin:begin+run] = fill

	# Runs go down the columns, hence reshape to (width, height) and transpose.
	return np.ascontiguousarray(flat.reshape(width, height).T)


def rle_encode(mask):
	"""Encode a 2-D mask as a run-length string.

	The mask is flattened column by column; the result is a space-separated
	"start length start length ..." string with 1-based starts, one pair per
	run of equal non-background values. An all-zero mask yields ''.
	"""
	flat = mask.T.flatten()
	# Zero sentinels on both ends so runs touching the borders are detected.
	padded = np.concatenate([[0], flat, [0]])
	edges = np.where(padded[1:] != padded[:-1])[0] + 1
	# Every second entry is a run end: turn it into a run length.
	edges[1::2] -= edges[::2]
	return ' '.join(map(str, edges))


#
# # --draw ------------------------------------------
def mask_to_inner_contour(mask):
	"""Return the inner boundary of a 2-D mask as a boolean array.

	A pixel is part of the contour when it is foreground (mask > 0.5) and at
	least one of its four direct neighbours is background. The mask is padded
	by reflection, so foreground touching the array border is not marked as
	contour merely because it sits on the border.

	Args:
		mask: 2-D array; values > 0.5 count as foreground.

	Returns:
		Boolean array with the same shape as `mask`.
	"""
	foreground = mask > 0.5
	# np.pad is the public entry point; the np.lib.pad alias is gone in NumPy 2.
	padded = np.pad(foreground, ((1, 1), (1, 1)), 'reflect')
	centre = padded[1:-1, 1:-1]
	neighbour_differs = (
		(centre != padded[:-2, 1:-1])      # pixel above
		| (centre != padded[2:, 1:-1])     # pixel below
		| (centre != padded[1:-1, :-2])    # pixel to the left
		| (centre != padded[1:-1, 2:])     # pixel to the right
	)
	return foreground & neighbour_differs


def draw_contour_overlay(image, mask, color=(0,0,255), thickness=1):
	"""Draw the inner contour of `mask` onto `image` in place and return `image`.

	Args:
		image: Array that is modified in place.
		mask: 2-D mask handed to mask_to_inner_contour.
		color: Value written at contour pixels / passed to cv2.circle.
		thickness: 1 paints exactly the contour pixels; any other value draws a
			circle of radius max(1, thickness//2) around every contour pixel.
	"""
	contour = mask_to_inner_contour(mask)
	if thickness != 1:
		radius = max(1, thickness//2)
		# argwhere yields (row, col); OpenCV wants the centre as (x, y).
		for row, col in np.argwhere(contour):
			cv2.circle(image, (col, row), radius, color, lineType=cv2.LINE_4)
		return image

	image[contour] = color
	return image

def result_to_overlay(image, mask, probability=None, **kwargs):
	"""Build a side-by-side visualisation: [annotated overlay | input | mask/probability].

	Args:
		image: H x W x C array. It is multiplied by 255 before the uint8 cast, so it
			is presumably a float image in [0, 1] -- TODO confirm. The final hstack
			with the 3-channel panels only works when C == 3.
		mask: H x W array, or None (treated as all zeros).
		probability: H x W array, or None (treated as all zeros).
		**kwargs:
			dice_score: drawn as text on the overlay when >= 0.
			d: optional record; `d['id']`, `d.organ` and `d.pixel_size` are read to
				annotate the overlay and to draw a 100 um scale bar.

	Returns:
		uint8 array of the three panels stacked horizontally.

	Note:
		draw_shadow_text is not defined in the part of the notebook visible here;
		it is assumed to be defined elsewhere before this function is called.
	"""
	H,W,C= image.shape
	if mask is None:
		mask = np.zeros((H,W),np.float32)
	if probability is None:
		probability = np.zeros((H,W),np.float32)
		
	# Panel o1: mask in channel 2, probability in channel 1.
	o1 = np.zeros((H,W,3),np.float32)
	o1[:,:,2] = mask
	o1[:,:,1] = probability
	
	# Panel o2: dimmed input with probability blended into channel 1 and the
	# mask contour drawn on top; contour thickness grows with image height.
	o2 = image.copy()
	o2 = o2*0.5
	o2[:,:,1] += 0.5*probability
	o2 = draw_contour_overlay(o2, mask, color=(0,0,1), thickness=max(3,int(7*H/1500)))
	
	#---
	# Convert all panels to uint8 before drawing text (text colours below are 0..255).
	o2,image,o1 = [(m*255).astype(np.uint8) for m in [o2,image,o1]]
	if kwargs.get('dice_score',-1)>=0:
		draw_shadow_text(o2,'dice=%0.5f'%kwargs.get('dice_score'),(20,80),2.5,(255,255,255),5)
	if kwargs.get('d',None) is not None:
		d = kwargs.get('d')
		draw_shadow_text(o2,d['id'],(20,140),1.5,(255,255,255),3)
		draw_shadow_text(o2,d.organ+'(%s)'%(organ_meta[d.organ].ftu),(20,190),1.5,(255,255,255),3)
		draw_shadow_text(o2,'%0.1f um'%(d.pixel_size),(20,240),1.5,(255,255,255),3)
		# Scale bar: 100 um expressed in pixels, anchored near the bottom-right corner.
		s100 = int(100/d.pixel_size)
		sx,sy = W-s100-40,H-80
		cv2.rectangle(o2,(sx,sy),(sx+s100,sy+s100//2),(0,0,0),-1)
		draw_shadow_text(o2,'100um',(sx+8,sy+40),1,(255,255,255),2)
		pass
	
	#draw_shadow_text(image,'input',(5,15),0.6,(1,1,1),1)
	#draw_shadow_text(im_paste,'predict',(5,15),0.6,(1,1,1),1)

	overlay = np.hstack([o2,image,o1])
	return overlay

# --lb metric ------------------------------------------
# https://www.kaggle.com/competitions/hubmap-organ-segmentation/overview/supervised-ml-evaluation

def compute_dice_score(probability, mask):
	"""Per-sample Dice score after thresholding both inputs at 0.5.

	Args:
		probability: Array whose first axis indexes samples.
		mask: Array with the same leading axis; values > 0.5 are foreground.

	Returns:
		1-D array of length len(probability) with 2*|P & T| / (|P| + |T| + 1e-4).
	"""
	batch = len(probability)
	pred = probability.reshape(batch, -1) > 0.5
	truth = mask.reshape(batch, -1) > 0.5

	total = pred.sum(-1) + truth.sum(-1)
	hits = (pred * truth).sum(-1)
	return 2*hits/(total+0.0001)

In [None]:
# CoAT-5level

# https://github.com/mlpc-ucsd/CoaT/blob/main/src/models/coat.py

"""
CoaT architecture.

Modified from timm/models/vision_transformer.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from einops import rearrange
from functools import partial
from torch import nn, einsum
import pdb

def init_weight(m):
    """Initialise module `m` in place, dispatching on its class name.

    * name contains 'Conv'      -> Kaiming-normal weight (fan_in, relu), zero bias
    * name contains 'Batch'     -> weight ~ N(1, 0.02), zero bias
    * name contains 'Linear'    -> orthogonal weight, zero bias
    * name contains 'Embedding' -> orthogonal weight
    Anything else is left untouched. The checks run in that order.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Batch' in kind:
        m.weight.data.normal_(1,0.02)
        m.bias.data.zero_()
    elif 'Linear' in kind:
        nn.init.orthogonal_(m.weight, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Embedding' in kind:
        nn.init.orthogonal_(m.weight, gain=1)
        
class LayerNorm2d(nn.Module):
	"""Layer normalisation over the channel axis of a 4-D (N, C, H, W) tensor.

	Each spatial position is normalised across its C channels, then scaled and
	shifted by per-channel learnable `weight` and `bias`.
	"""
	def __init__(self, dim, eps=1e-6):
		super().__init__()
		self.dim = dim
		self.eps = eps
		self.weight = nn.Parameter(torch.ones(dim))
		self.bias = nn.Parameter(torch.zeros(dim))

	def forward(self, x):
		_, _, _, _ = x.shape  # only 4-D input is accepted
		mean = x.mean(1, keepdim=True)
		var = (x - mean).pow(2).mean(1, keepdim=True)   # biased variance over channels
		normed = (x - mean) / torch.sqrt(var + self.eps)
		return self.weight[:, None, None] * normed + self.bias[:, None, None]
#---------------------------------------------

def _cfg_coat(url='', **kwargs):
	"""Return the default model config dict for CoaT; `kwargs` override or extend it."""
	cfg = dict(
		url=url,
		num_classes=1000,
		input_size=(3, 224, 224),
		pool_size=None,
		crop_pct=.9,
		interpolation='bicubic',
		mean=IMAGENET_DEFAULT_MEAN,
		std=IMAGENET_DEFAULT_STD,
		first_conv='patch_embed.proj',
		classifier='head',
	)
	cfg.update(kwargs)
	return cfg


class Mlp(nn.Module):
	""" Feed-forward network (FFN, a.k.a. MLP): Linear -> activation -> dropout -> Linear -> dropout.

	`hidden_features` and `out_features` fall back to `in_features` when falsy.
	The same dropout module is applied after the activation and after the output layer.
	"""
	def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
		super().__init__()
		width_hidden = hidden_features if hidden_features else in_features
		width_out = out_features if out_features else in_features
		self.fc1 = nn.Linear(in_features, width_hidden)
		self.act = act_layer()
		self.fc2 = nn.Linear(width_hidden, width_out)
		self.drop = nn.Dropout(drop)

	def forward(self, x):
		hidden = self.drop(self.act(self.fc1(x)))
		return self.drop(self.fc2(hidden))


class ConvRelPosEnc(nn.Module):
	""" Convolutional relative position encoding. """
	def __init__(self, Ch, h, window):
		"""
		Args:
			Ch: Channels per head.
			h: Number of heads.
			window: Either an int (one window size shared by all `h` heads) or a
				dict {window size: number of heads using it}. One depthwise
				convolution is created per dict entry, in dict order.

		Raises:
			ValueError: if `window` is neither an int nor a dict.
		"""
		super().__init__()

		if isinstance(window, int):
			window = {window: h}          # same window size for every head
		elif not isinstance(window, dict):
			raise ValueError()
		self.window = window

		self.conv_list = nn.ModuleList()
		self.head_splits = []
		dilation = 1                      # fixed; kept explicit for the padding formula
		for kernel, n_heads in window.items():
			channels = n_heads * Ch
			# "same" padding for the (possibly dilated) kernel.
			pad = (kernel + (kernel - 1) * (dilation - 1)) // 2
			self.conv_list.append(
				nn.Conv2d(
					channels, channels,
					kernel_size=(kernel, kernel),
					padding=(pad, pad),
					dilation=(dilation, dilation),
					groups=channels,
				)
			)
			self.head_splits.append(n_heads)
		self.channel_splits = [n * Ch for n in self.head_splits]

	def forward(self, q, v, size):
		"""Return q * depthwise_conv(v) for the image tokens, with a zero row for the CLS token.

		q, v: [B, h, N, Ch] with N == 1 + H*W; size: (H, W). Output: [B, h, N, Ch].
		"""
		B, h, N, Ch = q.shape
		H, W = size
		assert N == 1 + H * W

		# Drop the CLS token (index 0) and lay the value tokens out as an image.
		q_img = q[:, :, 1:, :]                                                         # [B, h, H*W, Ch]
		v_img = rearrange(v[:, :, 1:, :], 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)    # [B, h*Ch, H, W]

		# Each head group goes through the convolution with its own window size.
		groups = torch.split(v_img, self.channel_splits, dim=1)
		convolved = torch.cat([conv(part) for conv, part in zip(self.conv_list, groups)], dim=1)
		convolved = rearrange(convolved, 'B (h Ch) H W -> B h (H W) Ch', h=h)        # [B, h, H*W, Ch]

		EV_hat_img = q_img * convolved
		cls_row = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device)
		return torch.cat((cls_row, EV_hat_img), dim=2)                                 # [B, h, N, Ch]


class FactorAtt_ConvRelPosEnc(nn.Module):
	""" Factorized attention with convolutional relative position encoding class.

	The output is `scale * q @ (softmax_N(k)^T @ v) + crpe(q, v)`, followed by a
	linear projection. `shared_crpe` is a callable `(q, v, size=...)` supplied by
	the caller.
	"""
	def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):
		super().__init__()
		self.num_heads = num_heads
		self.scale = qk_scale or (dim // num_heads) ** -0.5

		self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
		self.attn_drop = nn.Dropout(attn_drop)      # constructed but never applied in forward
		self.proj = nn.Linear(dim, dim)
		self.proj_drop = nn.Dropout(proj_drop)

		# Position-encoding module handed in by the caller.
		self.crpe = shared_crpe

	def forward(self, x, size):
		"""x: [B, N, C]; size: (H, W) forwarded to the position encoding. Returns [B, N, C]."""
		B, N, C = x.shape
		heads = self.num_heads

		# Project once, then split into per-head Q, K, V of shape [B, h, N, Ch].
		qkv = self.qkv(x).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
		q, k, v = qkv.unbind(0)

		# Factorized attention: softmax over the token axis of K, aggregate V
		# into a [Ch, Ch] summary per head, then apply it to Q.
		context = einsum('b h n k, b h n v -> b h k v', k.softmax(dim=2), v)     # [B, h, Ch, Ch]
		factor_att = einsum('b h n k, b h k v -> b h n v', q, context)           # [B, h, N, Ch]

		# Convolutional relative position encoding.
		crpe = self.crpe(q, v, size=size)                                         # [B, h, N, Ch]

		# Merge heads back: [B, h, N, Ch] -> [B, N, C].
		merged = (self.scale * factor_att + crpe).transpose(1, 2).reshape(B, N, C)

		return self.proj_drop(self.proj(merged))


class ConvPosEnc(nn.Module):
	""" Convolutional Position Encoding.

	Adds a depthwise k x k convolution of the image tokens to the tokens
	themselves (residual); the CLS token is passed through unchanged.
	Note: This module is similar to the conditional position encoding in CPVT.
	"""
	def __init__(self, dim, k=3):
		super(ConvPosEnc, self).__init__()
		self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim)

	def forward(self, x, size):
		"""x: [B, 1 + H*W, C] with the CLS token first; size: (H, W). Returns the same shape."""
		B, N, C = x.shape
		H, W = size
		assert N == 1 + H * W

		cls_token = x[:, :1]                                   # [B, 1, C]
		feat = x[:, 1:].transpose(1, 2).view(B, C, H, W)       # image tokens as a feature map

		# Depthwise convolution with a residual connection, back to token layout.
		tokens = (self.proj(feat) + feat).flatten(2).transpose(1, 2)

		return torch.cat((cls_token, tokens), dim=1)


class SerialBlock(nn.Module):
	""" Serial block class.

	One conv-attention sub-layer followed by one FFN (MLP) sub-layer, each with
	pre-normalisation and a residual connection. The position-encoding modules
	are supplied by the caller via `shared_cpe` / `shared_crpe`.
	"""
	def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
	             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
	             shared_cpe=None, shared_crpe=None):
		super().__init__()

		# Conv-attention sub-layer.
		self.cpe = shared_cpe
		self.norm1 = norm_layer(dim)
		self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
			dim,
			num_heads=num_heads,
			qkv_bias=qkv_bias,
			qk_scale=qk_scale,
			attn_drop=attn_drop,
			proj_drop=drop,
			shared_crpe=shared_crpe,
		)
		# Stochastic depth is only instantiated when it is actually enabled.
		self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

		# MLP sub-layer.
		self.norm2 = norm_layer(dim)
		self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

	def forward(self, x, size):
		"""x: token tensor including the CLS token; size: (H, W) of the image tokens."""
		# Position encoding first, then attention with a residual connection.
		x = self.cpe(x, size)
		attended = self.factoratt_crpe(self.norm1(x), size)
		x = x + self.drop_path(attended)

		# Feed-forward with a residual connection.
		x = x + self.drop_path(self.mlp(self.norm2(x)))
		return x


class ParallelBlock(nn.Module):
	""" Parallel block class (five-level variant).

	Levels 2-5 each run conv-attention, exchange features with other levels via
	bilinear resampling, then run a shared MLP. Level 1 is passed through
	untouched.

	Args:
		dims: Per-level embedding dims; dims[1:] must all be equal.
		num_heads: Attention heads per level.
		mlp_ratios: Per-level MLP expansion ratios; only mlp_ratios[1] is used
			to size the shared MLP (ratios 1..3 are asserted equal).
		shared_cpes / shared_crpes: Lists of five position-encoding modules
			supplied by the caller; index 0 is unused.
		Remaining arguments are forwarded to the attention / MLP / norm layers.
	"""
	def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
	             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
	             shared_cpes=None, shared_crpes=None):
		super().__init__()

		# Conv-Attention.
		self.cpes = shared_cpes

		self.norm12 = norm_layer(dims[1])
		self.norm13 = norm_layer(dims[2])
		self.norm14 = norm_layer(dims[3])
		self.norm15 = norm_layer(dims[4])

		self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
			dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
			shared_crpe=shared_crpes[1]
		)
		self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
			dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
			shared_crpe=shared_crpes[2]
		)
		self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
			dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
			shared_crpe=shared_crpes[3]
		)
		self.factoratt_crpe5 = FactorAtt_ConvRelPosEnc(
			dims[4], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
			shared_crpe=shared_crpes[4]
		)

		self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

		# MLP.
		self.norm22 = norm_layer(dims[1])
		self.norm23 = norm_layer(dims[2])
		self.norm24 = norm_layer(dims[3])
		self.norm25 = norm_layer(dims[4])

		# In parallel block, we assume dimensions are the same and share the linear transformation.
		assert dims[1] == dims[2] == dims[3] ==dims[4]
		# NOTE(review): mlp_ratios[4] is not checked although level 5 shares this MLP too.
		assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
		mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
		# One Mlp instance registered under four names (weights are shared).
		self.mlp2 = self.mlp3 = self.mlp4 =self.mlp5= Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

	def upsample(self, x, output_size, size):
		""" Feature map up-sampling. """
		return self.interpolate(x, output_size=output_size, size=size)

	def downsample(self, x, output_size, size):
		""" Feature map down-sampling. """
		return self.interpolate(x, output_size=output_size, size=size)

	def interpolate(self, x, output_size, size):
		""" Feature map interpolation.

		x: [B, 1 + H*W, C] tokens with the CLS token first; size: current (H, W);
		output_size: target (H, W). The CLS token is carried over unchanged and
		the image tokens are resampled bilinearly.
		"""
		B, N, C = x.shape
		H, W = size
		assert N == 1 + H * W

		cls_token  = x[:, :1, :]
		img_tokens = x[:, 1:, :]

		img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
		img_tokens = F.interpolate(img_tokens, size=output_size, mode='bilinear')  # FIXME: May have alignment issue.
		img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

		out = torch.cat((cls_token, img_tokens), dim=1)

		return out

	def forward(self, x1, x2, x3, x4, x5,sizes):
		"""Run one parallel block. `sizes` lists the (H, W) of levels 1..5; returns (x1, ..., x5)."""
		_, (H2, W2), (H3, W3), (H4, W4),(H5,W5) = sizes

		# Conv-Attention.
		x2 = self.cpes[1](x2, size=(H2, W2))  # Note: x1 is ignored.
		x3 = self.cpes[2](x3, size=(H3, W3))
		x4 = self.cpes[3](x4, size=(H4, W4))
		x5 = self.cpes[4](x5, size=(H5, W5))

		cur2 = self.norm12(x2)
		cur3 = self.norm13(x3)
		cur4 = self.norm14(x4)
		cur5 = self.norm15(x5)

		cur2 = self.factoratt_crpe2(cur2, size=(H2,W2))
		cur3 = self.factoratt_crpe3(cur3, size=(H3,W3))
		cur4 = self.factoratt_crpe4(cur4, size=(H4,W4))
		# NOTE(review): level 5 goes through factoratt_crpe4; factoratt_crpe5 is
		# constructed above but never called. This looks like a copy-paste slip,
		# but it decides which weights are used, so it is left as is -- confirm
		# against the training code / checkpoints before changing it.
		cur5 = self.factoratt_crpe4(cur5, size=(H5,W5))

		# Cross-level exchange. Only the resampled maps that enter the sums
		# below are computed; the previously computed-but-unused maps
		# (5->2, 5->3, 3->4, 3->5) were dropped, which does not change the output.
		# NOTE(review): confirm that leaving those four terms out of the fusion is intended.
		upsample3_2   = self.upsample(cur3, output_size=(H2,W2), size=(H3,W3))
		upsample4_2   = self.upsample(cur4, output_size=(H2,W2), size=(H4,W4))
		upsample4_3   = self.upsample(cur4, output_size=(H3,W3), size=(H4,W4))
		downsample2_3 = self.downsample(cur2, output_size=(H3,W3), size=(H2,W2))
		upsample5_4   = self.upsample(cur5, output_size=(H4,W4), size=(H5,W5))
		downsample2_4 = self.downsample(cur2, output_size=(H4,W4), size=(H2,W2))
		downsample4_5 = self.downsample(cur4, output_size=(H5,W5), size=(H4,W4))
		downsample2_5 = self.downsample(cur2, output_size=(H5,W5), size=(H2,W2))

		cur2 = cur2  + upsample3_2   + upsample4_2
		cur3 = cur3  + upsample4_3   + downsample2_3
		cur4 = cur4  + upsample5_4   + downsample2_4
		cur5 = cur5  + downsample4_5 + downsample2_5

		x2 = x2 + self.drop_path(cur2)
		x3 = x3 + self.drop_path(cur3)
		x4 = x4 + self.drop_path(cur4)
		x5 = x5 + self.drop_path(cur5)

		# MLP.
		cur2 = self.norm22(x2)
		cur3 = self.norm23(x3)
		cur4 = self.norm24(x4)
		cur5 = self.norm25(x5)

		cur2 = self.mlp2(cur2)
		cur3 = self.mlp3(cur3)
		cur4 = self.mlp4(cur4)
		cur5 = self.mlp5(cur5)

		x2 = x2 + self.drop_path(cur2)
		x3 = x3 + self.drop_path(cur3)
		x4 = x4 + self.drop_path(cur4)
		x5 = x5 + self.drop_path(cur5)

		return x1, x2, x3, x4, x5


class PatchEmbed(nn.Module):
	""" Image to Patch Embedding.

	A strided convolution (kernel == stride == patch size) followed by
	LayerNorm over the embedding dimension.
	"""
	def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
		super().__init__()
		self.patch_size = to_2tuple(patch_size)
		self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
		self.norm = nn.LayerNorm(embed_dim)

	def forward(self, x):
		"""x: [B, C, H, W]. Returns (tokens [B, out_H*out_W, embed_dim], (out_H, out_W))."""
		_, _, in_h, in_w = x.shape
		patch_h, patch_w = self.patch_size
		grid = (in_h // patch_h, in_w // patch_w)

		tokens = self.proj(x).flatten(2).transpose(1, 2)
		return self.norm(tokens), grid


class CoaT(nn.Module):
	""" CoaT backbone, extended here to five resolution levels.

	forward() returns a list of [B, C, H, W] feature maps: four maps (levels
	1-4, passed through `out_norm`) when `parallel_depth == 0`, otherwise the
	levels named in `out_features` taken after the parallel blocks.

	Args:
		patch_size: Patch size of the first patch embedding (later levels use 2).
		in_chans: Number of input image channels.
		embed_dims / serial_depths / mlp_ratios: Per-level settings. Indices
			0..4 are read, so five entries are required (the signature defaults
			are placeholders and not usable as-is).
		parallel_depth: Number of ParallelBlocks (0 disables the parallel stage).
		num_heads: Attention heads, shared by all levels.
		return_interm_layers: Must be True when parallel_depth > 0; the
			classification-head path is not implemented in this class.
		out_features: Names ('x1_nocls' .. 'x5_nocls') of the maps to return
			after the parallel stage.
		crpe_window: Window spec passed to every ConvRelPosEnc.
		pretrain: Stored on the instance only; nothing is loaded here.
		out_norm: Factory called with the channel count for levels 1-4.

	NOTE(review): several level-4/5 wirings below look like copy-paste slips
	from the four-level original (marked inline). They are deliberately left
	untouched because they determine which parameters are used and which
	state_dict keys exist, so changing them would alter results with weights
	trained on the current wiring. Confirm against the training code first.
	"""
	def __init__(self, patch_size=16, in_chans=3, embed_dims=[0, 0, 0, 0],
	             serial_depths=[0, 0, 0, 0], parallel_depth=0,
	             num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
	             drop_path_rate=0.,
	             norm_layer=partial(nn.LayerNorm, eps=1e-6),
	             return_interm_layers=True,
	             out_features=['x1_nocls','x2_nocls','x3_nocls','x4_nocls','x5_nocls'],
	             crpe_window={3:2, 5:3, 7:3},
	             pretrain=None,
	             out_norm = nn.Identity, #use nn.Identity, nn.BatchNorm2d, LayerNorm2d
	             **kwargs):
		super().__init__()
		self.return_interm_layers = return_interm_layers
		self.pretrain     = pretrain
		self.embed_dims   = embed_dims
		self.out_features = out_features

		# Patch embeddings.
		self.patch_embed1 = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
		self.patch_embed2 = PatchEmbed(patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
		self.patch_embed3 = PatchEmbed(patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
		self.patch_embed4 = PatchEmbed(patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
		self.patch_embed5 = PatchEmbed(patch_size=2, in_chans=embed_dims[3], embed_dim=embed_dims[4])

		# Class tokens.
		self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
		self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
		self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
		self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
		self.cls_token5 = nn.Parameter(torch.zeros(1, 1, embed_dims[4]))

		# Convolutional position encodings.
		self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
		self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
		self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
		self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)
		self.cpe5 = ConvPosEnc(dim=embed_dims[4], k=3)

		# Convolutional relative position encodings.
		self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
		self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
		self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
		self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)
		self.crpe5 = ConvRelPosEnc(Ch=embed_dims[4] // num_heads, h=num_heads, window=crpe_window)

		# Stochastic depth: the same rate is used for every block.
		dpr = drop_path_rate

		def serial_stage(level, cpe, crpe):
			# Stack of SerialBlocks for one level, sharing that level's position encodings.
			return nn.ModuleList([
				SerialBlock(
					dim=embed_dims[level], num_heads=num_heads, mlp_ratio=mlp_ratios[level],
					qkv_bias=qkv_bias, qk_scale=qk_scale,
					drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
					shared_cpe=cpe, shared_crpe=crpe
				)
				for _ in range(serial_depths[level])]
			)

		self.serial_blocks1 = serial_stage(0, self.cpe1, self.crpe1)
		self.serial_blocks2 = serial_stage(1, self.cpe2, self.crpe2)
		self.serial_blocks3 = serial_stage(2, self.cpe3, self.crpe3)
		self.serial_blocks4 = serial_stage(3, self.cpe4, self.crpe4)
		# NOTE(review): level 5 shares cpe4/crpe4 rather than cpe5/crpe5.
		self.serial_blocks5 = serial_stage(4, self.cpe4, self.crpe4)

		# Parallel blocks.
		self.parallel_depth = parallel_depth
		if self.parallel_depth > 0:
			self.parallel_blocks = nn.ModuleList([
				ParallelBlock(
					dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
					drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
					shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4,  self.cpe5],
					# NOTE(review): the last entry is cpe5, not crpe5. It becomes
					# factoratt_crpe5.crpe and therefore part of the state_dict keys.
					shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4,  self.cpe5]
				)
				for _ in range(parallel_depth)]
			)

		# Output norms exist for levels 1-4 only; they are applied on the
		# serial-only path (parallel_depth == 0).
		self.out_norm = nn.ModuleList(
			[ out_norm(embed_dims[i]) for i in range(4)]
		)

		# Initialize weights.
		trunc_normal_(self.cls_token1, std=.02)
		trunc_normal_(self.cls_token2, std=.02)
		trunc_normal_(self.cls_token3, std=.02)
		trunc_normal_(self.cls_token4, std=.02)
		trunc_normal_(self.cls_token5, std=.02)
		self.apply(self._init_weights)

	def _init_weights(self, m):
		"""Truncated-normal init for Linear weights; zeros/ones for biases and LayerNorm."""
		if isinstance(m, nn.Linear):
			trunc_normal_(m.weight, std=.02)
			if m.bias is not None:
				nn.init.constant_(m.bias, 0)
		elif isinstance(m, nn.LayerNorm):
			nn.init.constant_(m.bias, 0)
			nn.init.constant_(m.weight, 1.0)

	@torch.jit.ignore
	def no_weight_decay(self):
		"""Names of parameters to exclude from weight decay (all five class tokens)."""
		return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4', 'cls_token5'}

	def insert_cls(self, x, cls_token):
		""" Insert CLS token. """
		cls_tokens = cls_token.expand(x.shape[0], -1, -1)
		x = torch.cat((cls_tokens, x), dim=1)
		return x

	def remove_cls(self, x):
		""" Remove CLS token. """
		return x[:, 1:, :]

	def _tokens_to_map(self, x, B, H, W):
		"""Drop the CLS token and reshape [B, 1 + H*W, C] tokens to a [B, C, H, W] map."""
		return self.remove_cls(x).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

	def _run_serial_stage(self, feat, patch_embed, cls_token, blocks):
		"""Patch-embed `feat`, prepend `cls_token`, run `blocks`; returns (tokens, (H, W))."""
		x, (H, W) = patch_embed(feat)
		x = self.insert_cls(x, cls_token)
		for blk in blocks:
			x = blk(x, size=(H, W))
		return x, (H, W)

	def forward(self, x0):
		"""x0: [B, in_chans, H, W] image batch. Returns a list of feature maps (see class docstring)."""
		B = x0.shape[0]

		# Serial blocks 1-3.
		x1, (H1, W1) = self._run_serial_stage(x0, self.patch_embed1, self.cls_token1, self.serial_blocks1)
		x1_nocls = self._tokens_to_map(x1, B, H1, W1)

		x2, (H2, W2) = self._run_serial_stage(x1_nocls, self.patch_embed2, self.cls_token2, self.serial_blocks2)
		x2_nocls = self._tokens_to_map(x2, B, H2, W2)

		x3, (H3, W3) = self._run_serial_stage(x2_nocls, self.patch_embed3, self.cls_token3, self.serial_blocks3)
		x3_nocls = self._tokens_to_map(x3, B, H3, W3)

		# Serial blocks 4.
		# NOTE(review): level 4 uses cls_token5 (and level 5 uses cls_token4 below).
		x4, (H4, W4) = self._run_serial_stage(x3_nocls, self.patch_embed4, self.cls_token5, self.serial_blocks4)
		x4_nocls = self._tokens_to_map(x4, B, H4, W4)

		# Serial blocks 5.
		# NOTE(review): level 5 re-applies patch_embed4; patch_embed5 is never called.
		x5, (H5, W5) = self._run_serial_stage(x4_nocls, self.patch_embed4, self.cls_token4, self.serial_blocks5)

		# Only serial blocks: Early return.
		# NOTE(review): this path returns four maps; the level-5 map is not included.
		if self.parallel_depth == 0:
			x1_nocls = self.out_norm[0](x1_nocls)
			x2_nocls = self.out_norm[1](x2_nocls)
			x3_nocls = self.out_norm[2](x3_nocls)
			x4_nocls = self.out_norm[3](x4_nocls)
			return [x1_nocls,x2_nocls,x3_nocls,x4_nocls]

		if not self.return_interm_layers:
			# The classification path needs norm2/norm3/norm4/aggregate modules
			# that this class never creates; fail with a clear message instead
			# of an AttributeError after running the whole backbone.
			raise NotImplementedError(
				'CoaT: return_interm_layers=False is not supported when parallel_depth > 0 '
				'(no classification head is defined in this implementation)')

		# Parallel blocks.
		for blk in self.parallel_blocks:
			x1, x2, x3, x4,x5 = blk(x1, x2, x3, x4,x5, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4), (H5, W5)])

		# Remove cls and return the requested features, always in level order.
		levels = (
			('x1_nocls', x1, H1, W1),
			('x2_nocls', x2, H2, W2),
			('x3_nocls', x3, H3, W3),
			('x4_nocls', x4, H4, W4),
			('x5_nocls', x5, H5, W5),
		)
		return [
			self._tokens_to_map(tokens, B, H, W)
			for name, tokens, H, W in levels
			if name in self.out_features
		]

class coat_parallel_small_plus (CoaT):
	"""CoaT preset: patch size 4, dims [152, 320, 320, 320, 320], two serial blocks
	per level, six parallel blocks, eight heads, MLP ratio 4 on every level.

	Extra keyword arguments are forwarded to CoaT.
	"""
	def __init__(self, **kwargs):
		super().__init__(
			patch_size=4,
			embed_dims=[152, 320, 320, 320, 320],
			serial_depths=[2, 2, 2, 2, 2],
			parallel_depth=6,
			num_heads=8,
			mlp_ratios=[4, 4, 4, 4, 4],
			pretrain='coat_small_7479cf9b.pth',
			**kwargs)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MixUpSample(nn.Module):
	"""Upsample by `scale_factor` as a learnable blend of bilinear and nearest interpolation.

	output = mixing * bilinear(x) + (1 - mixing) * nearest(x), where `mixing`
	is a scalar parameter initialised to 0.5.
	"""
	def __init__( self, scale_factor=2):
		super().__init__()
		assert(scale_factor!=1)

		self.mixing = nn.Parameter(torch.tensor(0.5))
		self.scale_factor = scale_factor

	def forward(self, x):
		smooth = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
		blocky = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
		return self.mixing * smooth + (1 - self.mixing) * blocky

#https://github.com/lhoyer/DAFormer/blob/master/mmseg/models/decode_heads/daformer_head.py
def Conv2dBnReLU(in_channel, out_channel, kernel_size=3, padding=1,stride=1, dilation=1):
	"""Return nn.Sequential(Conv2d without bias, BatchNorm2d, in-place ReLU)."""
	conv = nn.Conv2d(
		in_channel,
		out_channel,
		kernel_size=kernel_size,
		padding=padding,
		stride=stride,
		dilation=dilation,
		bias=False,   # the following BatchNorm supplies the shift
	)
	layers = [conv, nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)]
	return nn.Sequential(*layers)

class ASPP(nn.Module):
	"""Parallel dilated conv branches whose outputs are concatenated and fused.

	One ``Conv2dBnReLU`` branch is built per entry of ``dilation``: a 1x1 conv
	when the rate is 1, otherwise a 3x3 conv with ``padding == dilation`` so
	the spatial size is kept. The concatenated branches go through a final
	3x3 ``Conv2dBnReLU``.

	Args:
		in_channel: channels of the input feature map.
		channel: channels produced by every branch and by the fused output.
		dilation: iterable of dilation rates, one branch each.
	"""

	def __init__(self, in_channel, channel, dilation):
		super().__init__()

		self.conv = nn.ModuleList([
			Conv2dBnReLU(
				in_channel,
				channel,
				kernel_size=1 if rate == 1 else 3,
				dilation=rate,
				padding=0 if rate == 1 else rate,
			)
			for rate in dilation
		])

		self.out = Conv2dBnReLU(
			len(dilation) * channel,
			channel,
			kernel_size=3,
			padding=1,
		)

	def forward(self, x):
		branches = [branch(x) for branch in self.conv]
		return self.out(torch.cat(branches, dim=1))

#DepthwiseSeparable
class DSConv2d(nn.Module):
	"""Two-stage conv block: a spatial conv followed by a 1x1 channel-mixing conv.

	Each stage is Conv2d -> BatchNorm2d -> ReLU(inplace).

	NOTE(review): the first stage is stored as ``depthwise`` but its Conv2d is
	created without ``groups``, so it is a regular dense convolution. Adding
	``groups=in_channel`` would change the weight shape and therefore the
	checkpoints this module can load, so it is left as written.

	Args:
		in_channel: input channels (also the width of the first stage).
		out_channel: output channels of the 1x1 stage.
		kernel_size, stride, padding, dilation: forwarded to the first conv.
	"""

	def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1):
		super().__init__()

		spatial_conv = nn.Conv2d(
			in_channel,
			in_channel,
			kernel_size,
			stride=stride,
			padding=padding,
			dilation=dilation,
		)
		self.depthwise = nn.Sequential(
			spatial_conv,
			nn.BatchNorm2d(in_channel),
			nn.ReLU(inplace=True),
		)

		channel_mix = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)
		self.pointwise = nn.Sequential(
			channel_mix,
			nn.BatchNorm2d(out_channel),
			nn.ReLU(inplace=True),
		)

	def forward(self, x):
		return self.pointwise(self.depthwise(x))

class DSASPP(nn.Module):
	"""ASPP variant that uses ``DSConv2d`` for the dilated branches.

	A dilation rate of 1 yields a 1x1 ``Conv2dBnReLU`` branch; every other
	rate yields a 3x3 ``DSConv2d`` branch with ``padding == dilation``. Branch
	outputs are concatenated along channels and fused by a 3x3
	``Conv2dBnReLU``.

	Args:
		in_channel: channels of the input feature map.
		channel: channels produced by every branch and by the fused output.
		dilation: iterable of dilation rates, one branch each.
	"""

	def __init__(self, in_channel, channel, dilation):
		super().__init__()

		self.conv = nn.ModuleList()
		for rate in dilation:
			if rate == 1:
				branch = Conv2dBnReLU(
					in_channel,
					channel,
					kernel_size=1,
					dilation=rate,
					padding=0,
				)
			else:
				branch = DSConv2d(
					in_channel,
					channel,
					kernel_size=3,
					dilation=rate,
					padding=rate,
				)
			self.conv.append(branch)

		self.out = Conv2dBnReLU(
			len(dilation) * channel,
			channel,
			kernel_size=3,
			padding=1,
		)

	def forward(self, x):
		stacked = torch.cat([branch(x) for branch in self.conv], dim=1)
		return self.out(stacked)

	
##############################################################################################

class DaformerDecoder(nn.Module):
	"""Fuse multi-level encoder features into a single decoder feature map.

	Every encoder level is projected to ``decoder_dim`` channels with a 1x1
	conv (optionally followed by BatchNorm + ReLU). Level ``i`` is then
	upsampled by ``2**i`` with ``MixUpSample`` (level 0 is left untouched), the
	levels are concatenated along channels and passed through the ``fuse``
	layer.

	Args:
		encoder_dim: channel count of each encoder level, in order.
		decoder_dim: channel count of every projected level and of the output.
		dilation: dilation rates, only used by the ``'aspp'`` / ``'ds-aspp'`` fusers.
		use_bn_mlp: if True the projection is conv(no bias)-BN-ReLU, otherwise
			a single conv with bias.
		fuse: one of ``'conv1x1'``, ``'conv3x3'``, ``'aspp'``, ``'ds-aspp'``.

	Raises:
		ValueError: if ``fuse`` is not one of the supported names. (Previously
			the module was built without a ``fuse`` layer and only failed with
			an AttributeError on the first forward pass.)
	"""

	SUPPORTED_FUSE = ('conv1x1', 'conv3x3', 'aspp', 'ds-aspp')

	def __init__(
			self,
			encoder_dim = [32, 64, 160, 256],
			decoder_dim = 256,
			dilation = [1, 6, 12, 18],
			use_bn_mlp  = True,
			fuse = 'conv3x3',
	):
		super().__init__()
		if fuse not in self.SUPPORTED_FUSE:
			raise ValueError(
				'unknown fuse %r, expected one of %s' % (fuse, ', '.join(self.SUPPORTED_FUSE))
			)

		# Per-level projection; layer indices inside each Sequential are kept
		# stable so existing checkpoints still load.
		self.mlp = nn.ModuleList()
		for i, dim in enumerate(encoder_dim):
			if use_bn_mlp:
				# follow mmseg: conv-bn-relu
				projection = [
					nn.Conv2d(dim, decoder_dim, 1, padding=0, bias=False),
					nn.BatchNorm2d(decoder_dim),
					nn.ReLU(inplace=True),
				]
			else:
				projection = [nn.Conv2d(dim, decoder_dim, 1, padding=0, bias=True)]
			# assumes level i is 2**i times smaller than level 0 so that all
			# levels can be concatenated in forward() — TODO confirm per encoder
			resample = MixUpSample(2**i) if i != 0 else nn.Identity()
			self.mlp.append(nn.Sequential(*projection, resample))

		fused_dim = len(encoder_dim) * decoder_dim
		if fuse == 'conv1x1':
			self.fuse = nn.Sequential(
				nn.Conv2d(fused_dim, decoder_dim, 1, padding=0, bias=False),
				nn.BatchNorm2d(decoder_dim),
				nn.ReLU(inplace=True),
			)
		elif fuse == 'conv3x3':
			self.fuse = nn.Sequential(
				nn.Conv2d(fused_dim, decoder_dim, 3, padding=1, bias=False),
				nn.BatchNorm2d(decoder_dim),
				nn.ReLU(inplace=True),
			)
		elif fuse == 'aspp':
			self.fuse = ASPP(fused_dim, decoder_dim, dilation)
		else:  # 'ds-aspp'
			self.fuse = DSASPP(fused_dim, decoder_dim, dilation)

	def forward(self, feature):
		"""Project, resample and fuse the encoder levels.

		Args:
			feature: sequence of feature maps, one per entry of ``encoder_dim``.

		Returns:
			(fused, projected): the fused map and the list of per-level
			projected maps that were concatenated to produce it.
		"""
		out = [self.mlp[i](f) for i, f in enumerate(feature)]
		x = self.fuse(torch.cat(out, dim=1))
		return x, out


class daformer_conv3x3(DaformerDecoder):
	"""``DaformerDecoder`` preset with ``fuse='conv3x3'``; other kwargs pass through."""

	def __init__(self, **kwargs):
		super().__init__(fuse='conv3x3', **kwargs)
class daformer_conv1x1(DaformerDecoder):
	"""``DaformerDecoder`` preset with ``fuse='conv1x1'``; other kwargs pass through."""

	def __init__(self, **kwargs):
		super().__init__(fuse='conv1x1', **kwargs)

class daformer_aspp(DaformerDecoder):
	"""``DaformerDecoder`` preset with ``fuse='aspp'``; other kwargs pass through."""

	def __init__(self, **kwargs):
		super().__init__(fuse='aspp', **kwargs)

In [None]:
import torch
import numpy as np
import tifffile as tiff
from torch import nn

class dotdict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    Reading a missing attribute returns ``None`` (it is a ``dict.get`` lookup);
    deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]


def rle_encode(mask):
	"""Run-length encode a binary mask in column-major order.

	Args:
		mask: 2-D array, non-zero where the mask is set.

	Returns:
		Space-separated ``start length`` pairs with 1-based starts; an empty
		string for an all-zero mask.
	"""
	# Zero sentinels on both ends turn every run into a pair of value changes.
	padded = np.concatenate([[0], mask.T.flatten(), [0]])
	edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
	# Every second edge is a run end: convert it to a run length.
	edges[1::2] -= edges[::2]
	return ' '.join(map(str, edges))

def read_tiff(image_file, mode='rgb'):
	"""Read a TIFF into a contiguous array with channels last.

	Singleton dimensions are squeezed out; when the first remaining axis has
	length 3 it is treated as the channel axis and moved to the end.

	Args:
		image_file: path handed to ``tifffile.imread``.
		mode: ``'bgr'`` reverses the last axis; any other value leaves the
			channel order as stored.

	Returns:
		C-contiguous ``numpy.ndarray``.
	"""
	image = tiff.imread(image_file).squeeze()
	channels_first = image.shape[0] == 3
	if channels_first:
		image = image.transpose(1, 2, 0)
	if mode == 'bgr':
		image = image[:, :, ::-1]
	# Slicing/transposing leaves a strided view; materialise it for consumers
	# that need contiguous memory.
	return np.ascontiguousarray(image)


# Per-organ metadata: integer class label, `um` value and the name of the
# functional tissue unit (ftu).
# NOTE(review): `um` is presumably the pixel size in micrometres — confirm
# against the dataset description.
organ_meta = dotdict({
	organ: dotdict(label=label, um=um, ftu=ftu)
	for organ, label, um, ftu in (
		('kidney',         1, 0.5000, 'glomeruli'),
		('prostate',       2, 6.2630, 'glandular acinus'),
		('largeintestine', 3, 0.2290, 'crypt'),
		('spleen',         4, 0.4945, 'white pulp'),
		('lung',           5, 0.7562, 'alveolus'),
	)
})

# Forward and reverse lookup between organ name and integer label.
organ_to_label = {organ: meta.label for organ, meta in organ_meta.items()}
label_to_organ = {label: organ for organ, label in organ_to_label.items()}

def image_to_tensor(image, mode='rgb'):
    """Convert an (H, W, C) array into a (C, H, W) ``torch.Tensor`` copy.

    Args:
        image: 3-D array with channels on the last axis.
        mode: ``'bgr'`` reverses the channel axis before conversion; any other
            value keeps the channel order.
    """
    channels_last = image[:, :, ::-1] if mode == 'bgr' else image
    channels_first = np.ascontiguousarray(channels_last.transpose(2, 0, 1))
    return torch.tensor(channels_first)

In [None]:
class RGB(nn.Module):
    """Channel-wise input normalisation: ``(x - mean) / std``.

    ``mean`` and ``std`` are registered as buffers of shape (1, 3, 1, 1), so
    they follow the module across devices and are part of its state dict.

    Args:
        nor: truthy selects the ``*_7`` statistics, falsy the ``*_4`` ones.
    """

    IMAGE_RGB_MEAN_4 = [0.485, 0.456, 0.406]  # [0.5, 0.5, 0.5]
    IMAGE_RGB_STD_4 = [0.229, 0.224, 0.225]  # [0.5, 0.5, 0.5]
    IMAGE_RGB_MEAN_7 = [0.7720342, 0.74582646, 0.76392896]
    IMAGE_RGB_STD_7 = [0.24745085, 0.26182273, 0.25782376]

    def __init__(self, nor=True):
        super().__init__()
        if nor:
            chosen_mean, chosen_std = self.IMAGE_RGB_MEAN_7, self.IMAGE_RGB_STD_7
        else:
            chosen_mean, chosen_std = self.IMAGE_RGB_MEAN_4, self.IMAGE_RGB_STD_4
        self.IMAGE_RGB_MEAN = chosen_mean
        self.IMAGE_RGB_STD = chosen_std

        broadcast_shape = (1, 3, 1, 1)
        self.register_buffer('mean', torch.FloatTensor(chosen_mean).view(broadcast_shape))
        self.register_buffer('std', torch.FloatTensor(chosen_std).view(broadcast_shape))

    def forward(self, x):
        return (x - self.mean) / self.std



def init_weight(m):
    """Initialise one module in place; intended for ``model.apply(init_weight)``.

    The scheme is chosen from the module's class name:
      * contains ``Conv``      -> Kaiming-normal weight (fan_in, relu), zero bias
      * contains ``Batch``     -> weight ~ N(1, 0.02), zero bias
      * contains ``Linear``    -> orthogonal weight, zero bias
      * contains ``Embedding`` -> orthogonal weight

    Modules that match by name but carry no ``weight`` of their own (container
    classes such as ``DSConv2d``, or BatchNorm with ``affine=False``) are
    skipped instead of raising ``AttributeError``; their children are still
    visited individually by ``Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if weight is None:
        return

    if 'Conv' in classname:
        nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu')
        if bias is not None:
            bias.data.zero_()
    elif 'Batch' in classname:
        weight.data.normal_(1, 0.02)
        if bias is not None:
            bias.data.zero_()
    elif 'Linear' in classname:
        nn.init.orthogonal_(weight, gain=1)
        if bias is not None:
            bias.data.zero_()
    elif 'Embedding' in classname:
        nn.init.orthogonal_(weight, gain=1)



class Net(nn.Module):
    """Segmentation network: RGB normalisation -> encoder -> decoder -> 1-channel head.

    Args:
        encoder: encoder module exposing ``embed_dims``. A class may be passed
            instead of an instance, in which case it is instantiated with
            ``encoder_cfg`` as keyword arguments (the original code stored the
            class itself and failed as soon as it was used).
        decoder: decoder factory called as
            ``decoder(encoder_dim=..., decoder_dim=...)``.
        encoder_cfg: keyword arguments for the encoder class; only used when
            ``encoder`` is a class.
        decoder_cfg: optional mapping; only its ``'decoder_dim'`` entry
            (default 320) is read.
        nor: forwarded to ``RGB`` to select the normalisation statistics.
            With the default ``True`` this matches what the previous
            ``RGB()`` call used.
    """

    def __init__(self,
                 encoder=coat_parallel_small_plus,
                 decoder=daformer_conv1x1,
                 encoder_cfg=None,
                 decoder_cfg=None,
                 nor=True,
                 ):

        super(Net, self).__init__()
        # None instead of {} as default avoids sharing one dict across instances.
        encoder_cfg = {} if encoder_cfg is None else encoder_cfg
        decoder_cfg = {} if decoder_cfg is None else decoder_cfg
        decoder_dim = decoder_cfg.get('decoder_dim', 320)

        if isinstance(encoder, type):
            encoder = encoder(**encoder_cfg)
        self.encoder = encoder

        self.rgb = RGB(nor=nor)

        encoder_dim = self.encoder.embed_dims

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
        # 1x1 conv to a single logit map, then x4 bilinear upsampling.
        self.logit = nn.Sequential(
            nn.Conv2d(decoder_dim, 1, kernel_size=1),
            nn.Upsample(scale_factor = 4, mode='bilinear', align_corners=False),
        )

    def forward(self, batch):
        """Run the network on an image batch.

        Args:
            batch: image tensor; it is normalised by ``self.rgb`` first.
                # assumes (B, 3, H, W) with values already scaled to [0, 1] —
                # TODO confirm against the data loader

        Returns:
            dict with key ``'probability'``: sigmoid of the upsampled logits.
        """
        x = self.rgb(batch)
        features = self.encoder(x)

        last, _ = self.decoder(features)
        logit = self.logit(last)

        return {'probability': torch.sigmoid(logit)}

def criterion_aux_loss(logit, mask):
    """Binary cross-entropy between ``logit`` and ``mask`` resized to match it.

    The mask is resampled with nearest-neighbour interpolation to the last two
    dimensions of ``logit`` before the loss is computed.
    """
    target = F.interpolate(mask, size=logit.shape[-2:], mode='nearest')
    return F.binary_cross_entropy_with_logits(logit, target)

In [None]:
### import glob
import copy
class CFG:
    """Static configuration namespace for the inference notebook."""

    # --- runtime ---------------------------------------------------------
    seed = 42
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    num_worker = 0  # if debug

    # --- paths -----------------------------------------------------------
    data_path = "../input/hubmap-organ-segmentation"
    ckpt_paths = "../input/coat-model/inference-cfg/65model.pth"

    # --- data ------------------------------------------------------------
    n_fold = 5
    img_size = [768, 768]
    train_bs = 1
    valid_bs = 2 * train_bs

    # --- model -----------------------------------------------------------
    backbone = 'swin+upernet'
    num_classes = 1

    # --- optimisation ----------------------------------------------------
    epoch = 12
    lr = 1e-3
    wd = 1e-5
    lr_drop = 8

    # --- post-processing -------------------------------------------------
    # Probability cut-off, looked up as thr[data_source][organ].
    thr = {
        "Hubmap": {
            "kidney": 0.3,
            "prostate": 0.3,
            "largeintestine": 0.3,
            "spleen": 0.3,
            "lung": 0.04,
        },
        "HPA": {
            "kidney": 0.4,
            "prostate": 0.4,
            "largeintestine": 0.4,
            "spleen": 0.4,
            "lung": 0.1,
        },
    }
    tta = True

def rle_encode_less_memory(img):
    """Run-length encode a binary mask (column-major) without padding it.

    Instead of concatenating sentinel zeros, the first and last pixel of the
    flattened copy are forced to 0, so a mask pixel at either of those two
    positions is dropped from the encoding. ``img`` itself is not modified.

    Returns:
        Space-separated ``start length`` pairs with 1-based starts.
    """
    flat = img.T.flatten()
    flat[0] = 0
    flat[-1] = 0
    # +2: one for the shift introduced by comparing neighbours, one for
    # 1-based positions.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def rle_encode(img):
    '''
    Run-length encode a binary mask, scanning column by column.

    img: numpy array, 1 - mask, 0 - background
    Returns the encoding as a string of space-separated "start length"
    pairs (1-based starts); an all-background mask yields ''.
    '''
    column_major = img.T.flatten()
    # Sentinel zeros guarantee that runs touching either end are closed.
    padded = np.concatenate([[0], column_major, [0]])
    changes = np.nonzero(padded[1:] != padded[:-1])[0] + 1
    changes[1::2] -= changes[::2]
    return ' '.join([str(value) for value in changes])

# Ref: https://www.kaggle.com/code/paulorzp/rle-functions-run-lenght-encode-decode/script
def rle_decode(mask_rle, shape):
    '''
    Decode a column-major run-length string into a binary mask.

    mask_rle: run-length as string formated (start length), 1-based starts
    shape: (height,width) of array to return
    Returns numpy uint8 array of shape (height, width), 1 - mask, 0 - background

    The flat buffer is laid out column by column, so it must be viewed as
    (width, height) before transposing. The previous ``reshape(shape).T``
    only gave the right answer for square masks; for non-square ones it
    returned a (width, height) array with misplaced pixels.
    '''
    height, width = shape[0], shape[1]
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based
    lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(height * width, dtype=np.uint8)
    for lo, hi in zip(starts, starts + lengths):
        flat[lo:hi] = 1
    return flat.reshape((width, height)).T  # Needed to align to RLE direction


def build_transforms(CFG):
    """Return the albumentations pipelines keyed by split name.

    Only ``"valid_test"`` is defined: a resize to ``CFG.img_size`` using
    nearest-neighbour interpolation.
    """
    resize_only = A.Compose(
        [A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST)],
        p=1.0,
    )
    return {"valid_test": resize_only}


class build_dataset(Dataset):
    """Dataset over a dataframe of TIFF images.

    Rows are addressed with ``df.loc[index, column]`` and must provide
    ``image_path``, ``img_height``, ``img_width``, ``organ``, ``id`` and
    ``data_source`` (plus ``rle`` when ``label`` is true).

    Args:
        df: dataframe described above.
        label: if true items are ``(image, mask)``; otherwise
            ``(image, img_height, img_width, id, organ, data_source)``.
        transforms: optional callable used as ``transforms(image=..., mask=...)``
            returning a dict with ``'image'`` (and ``'mask'``).
    """

    def __init__(self, df, label=True, transforms=None):
        self.df, self.label, self.transforms = df, label, transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        def field(column):
            return self.df.loc[index, column]

        img_path = field('image_path')
        img_height = field('img_height')
        img_width = field('img_width')
        organs = field('organ')
        id_ = field('id')
        img = read_tiff(img_path)
        sours = field('data_source')

        if self.label:
            return self._labelled_item(img, field('rle'), img_height, img_width)
        return self._inference_item(img, img_height, img_width, id_, organs, sours)

    def _labelled_item(self, img, rle_mask, img_height, img_width):
        # Training/validation sample: image plus decoded mask.
        mask = rle_decode(rle_mask, (img_height, img_width))
        if self.transforms:
            augmented = self.transforms(image=img, mask=mask)
            # NOTE(review): the image is only scaled by 1/255 when transforms
            # are given — confirm whether the un-transformed path should be too.
            img = augmented['image'] / 255
            mask = augmented['mask']

        mask = np.expand_dims(mask, axis=0)   # (1, h, w)
        img = np.transpose(img, (2, 0, 1))    # (c, h, w)
        return torch.tensor(img), torch.tensor(mask)

    def _inference_item(self, img, img_height, img_width, id_, organs, sours):
        # Test sample: resized image plus the metadata needed to restore the
        # original size and pick a threshold.
        if self.transforms:
            img = self.transforms(image=img)['image']

        img = np.transpose(img, (2, 0, 1)) / 255   # (c, h, w)
        return torch.tensor(img), img_height, img_width, id_, organs, sours

def build_model(CFG, test_flag=True, nor=True, number=0):
    """Build a ``Net`` and load one of the fine-tuned checkpoints into it.

    The CoaT encoder is first initialised from the pretrained backbone file,
    then the whole network is overwritten (non-strictly) with the checkpoint
    selected by ``number``.

    Fixes over the previous version:
      * ``number`` had no default although it followed defaulted parameters,
        which is a SyntaxError; it now defaults to 0.
      * an unknown ``number`` raises ``ValueError`` instead of falling through
        and calling ``torch.load`` on an already-loaded dict.
      * the model is moved to ``CFG.device`` rather than unconditionally to
        CUDA, and checkpoints are mapped to that device when loaded.
      * ``test_flag`` is honoured: the model is switched to eval mode so that
        normalisation/dropout layers run in inference mode.

    Args:
        CFG: configuration object; ``CFG.device`` is read.
        test_flag: if true the returned model is in ``eval()`` mode.
        nor: accepted for call-site compatibility; not forwarded, so ``Net``
            keeps using its default normalisation statistics as before.
        number: which fine-tuned checkpoint to load (0, 1 or 2).

    Returns:
        The loaded ``Net`` on ``CFG.device``.

    Raises:
        ValueError: if ``number`` is not 0, 1 or 2.
    """
    # insert in the checkpoint the folder destination with your folds
    fold_checkpoints = {
        0: '../input/coat-model/Fifth/00000280.model.pth',
        1: '../input/coat-model/Fourth/00000770.model.pth',
        2: '../input/coat-model/ky/55.model.pth',
    }
    if number not in fold_checkpoints:
        raise ValueError(
            'unknown checkpoint number %r, expected one of %s'
            % (number, sorted(fold_checkpoints))
        )

    encoder = coat_parallel_small_plus()
    pretrained = torch.load('../input/hubmapsmall/coat_small_7479cf9b.pth', map_location='cpu')
    encoder.load_state_dict(pretrained['model'], strict=False)

    model = Net(encoder=encoder).to(CFG.device)

    finetuned = torch.load(fold_checkpoints[number], map_location=CFG.device)
    model.load_state_dict(finetuned['state_dict'], strict=False)

    if test_flag:
        model.eval()
    return model

@torch.no_grad()
def test_one_epoch(test_loader,CFG):
    """Run the three-model ensemble over ``test_loader`` and RLE-encode the masks.

    For every batch the probabilities of the three models are averaged, each
    sample is thresholded with ``CFG.thr[data_source][organ]``, resized back
    to its original ``(height, width)`` with nearest-neighbour interpolation
    and encoded with ``rle_encode_less_memory``.

    Fixes over the previous version:
      * ``build_model`` was called with a positional argument after keyword
        arguments (SyntaxError); ``number`` is now passed by keyword.
      * the three checkpoints were reloaded for every batch; they are now
        loaded once before the loop.
      * the models are explicitly put in eval mode.
      * the threshold is looked up per sample rather than from the first
        sample of the batch (identical for ``batch_size=1``).
      * ``msk`` is ``None`` instead of undefined when the loader is empty.

    Args:
        test_loader: yields ``(images, heights, widths, ids, organs, sources)``.
        CFG: configuration object; ``device`` and ``thr`` are read.

    Returns:
        ``(pred_ids, pred_rles, msk)`` where ``msk`` is the last resized
        binary mask (or ``None`` if nothing was processed).
    """
    models = [
        build_model(CFG, test_flag=True, nor=False, number=fold)
        for fold in range(3)
    ]
    for model in models:
        model.eval()

    pred_ids = []
    pred_rles = []
    msk = None

    pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Test: ')
    for _, (images, heights, widths, ids,organs,sours) in pbar:
        images  = images.to(CFG.device, dtype=torch.float) # [b, c, h, w]

        # Plain average of the ensemble probabilities.
        y_pred = models[0](images)["probability"]
        for model in models[1:]:
            y_pred = y_pred + model(images)["probability"]
        y_pred = y_pred / len(models)                      # [b, c, h, w]

        for idx in range(y_pred.shape[0]):
            thr = CFG.thr[sours[idx]][organs[idx]]
            binary = (y_pred[idx] > thr).permute((1, 2, 0)).to(torch.uint8).cpu().numpy()  # [h, w, c]
            binary = np.ascontiguousarray(binary.squeeze())

            height = heights[idx].item()
            width = widths[idx].item()
            id_ = ids[idx].item()
            # cv2 expects dsize as (width, height).
            msk = cv2.resize(binary, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
            pred_rles.append(rle_encode_less_memory(msk))
            pred_ids.append(id_)

    return pred_ids, pred_rles,msk

if __name__ == '__main__':
    # Build the test table and attach the image path of every id.
    df = pd.read_csv(os.path.join(CFG.data_path, "test.csv"))
    df['image_path'] = df['id'].apply(lambda x: os.path.join(CFG.data_path, 'test_images', str(x) + '.tiff'))

    data_transforms = build_transforms(CFG)
    test_dataset = build_dataset(df, label=False, transforms=data_transforms['valid_test'])

    test_loader  = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False)

    pred_ids, pred_rles,msk = test_one_epoch(test_loader,CFG)

    # Show the last predicted mask as a visual sanity check. The previous
    # code called msk.astype(np.float32) and discarded the result.
    if msk is not None:
        plt.imshow(msk.astype(np.float32))

    pred_df = pd.DataFrame({
        "id":pred_ids,
        "rle":pred_rles
    })
    pred_df.to_csv('submission.csv',index=False)