# Description
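This kernel runs inference on the PANDA (prostate-cancer-grade-assessment) dataset: tiles from each slide are embedded with CTransPath and a TransMIL model predicts the slide-level ISUP grade.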

In [1]:
!cp -r /kaggle/input/customtimm/timm-0.5.4 /kaggle
!pip install /kaggle/timm-0.5.4
!pip install --no-index /kaggle/input/nystrom-attn/einops-0.7.0-py3-none-any.whl
!pip install --no-index /kaggle/input/nystrom-attn/nystrom_attention-0.0.11-py3-none-any.whl

Processing /kaggle/timm-0.5.4
  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25l- done
[?25h  Created wheel for timm: filename=timm-0.5.4-py3-none-any.whl size=431521 sha256=ca99ab16c3d014897f580d21c49a0d77f57e00fdc22eb8c1d0d0b7ec45c4e868
  Stored in directory: /root/.cache/pip/wheels/cd/16/62/903e9b342497a308c0b62f74fe47e09f2315c459cd4e39e86e
Successfully built timm
Installing collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.9.12
    Uninstalling timm-0.9.12:
      Successfully uninstalled timm-0.9.12
Successfully installed timm-0.5.4
Processing /kaggle/input/nystrom-attn/einops-0.7.0-py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.7.0
Processing /kaggle/input/nystrom-attn/nystrom_attention-0.0.11-py3-none-any.whl
Installing collected packages: nystrom-attention
Successfully installed nystr

# Import Packages

In [2]:
import sys
import cv2
from tqdm import tqdm_notebook as tqdm
# import fastai
# from fastai.vision import *
import os
from mish_activation import *
import warnings
warnings.filterwarnings("ignore")
import skimage.io
from skimage import color
from scipy.ndimage.morphology import binary_fill_holes
from skimage.transform import rescale,resize
# sys.path.insert(0, '../input/semisupervised-imagenet-models/semi-supervised-ImageNet1K-models-master/')
# from hubconf import *
import numpy as np
import pandas as pd
import torch, torchvision
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from timm.models.layers.helpers import to_2tuple
import timm
import torch.nn as nn
from nystrom_attention import NystromAttention
from openslide import OpenSlide

# Model

### CtransPath

In [3]:
class ConvStem(nn.Module):
    """Convolutional patch-embedding stem, passed to timm as ``embed_layer``.

    Two stride-2 3x3 Conv/BatchNorm/ReLU stages followed by a 1x1 projection
    downsample the input by 4 along each spatial axis and emit ``embed_dim``
    channels.

    Args:
        img_size: Expected input height/width (int or pair).
        patch_size: Must be 4, because the stem always downsamples by 4.
        in_chans: Number of input image channels.
        embed_dim: Output embedding width; must be divisible by 8.
        norm_layer: Optional callable invoked as ``norm_layer(embed_dim)``;
            ``nn.Identity`` is used when it is None.
        flatten: When True the output is (B, N, C) tokens, otherwise (B, C, H, W).
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()

        assert patch_size == 4
        assert embed_dim % 8 == 0

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        stem = []
        # Honour `in_chans` instead of assuming RGB; the default (3) keeps the
        # layer shapes identical to the previous hard-coded behaviour.
        input_dim, output_dim = in_chans, embed_dim // 8
        for _ in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed a (B, C, H, W) batch; H and W must equal ``self.img_size``."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

def ctranspath():
    """Return an untrained Swin-Tiny (patch 4, window 7, 224 px) whose patch embedding is ConvStem."""
    return timm.create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=False,
        embed_layer=ConvStem,
    )


# Build the backbone and replace its classification head with an identity so
# the forward pass returns the features that would have fed the head.
featextractor = ctranspath()
featextractor.head = nn.Identity()
# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from; the model is moved to the GPU later, after loading.
td = torch.load(r'/kaggle/input/model-weights/ctranspath.pth', map_location='cpu')
featextractor.load_state_dict(td['model'], strict=True)


<All keys matched successfully>

### MIL model

In [4]:
class TransLayer(nn.Module):
    """Pre-norm residual block wrapped around a Nystrom attention layer.

    Args:
        norm_layer: Normalisation constructor, called as ``norm_layer(dim)``.
        dim: Token embedding width.
    """

    def __init__(self, norm_layer=nn.LayerNorm, dim=128):
        super().__init__()
        self.norm = norm_layer(dim)
        attn_config = dict(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,  # number of landmarks
            pinv_iterations=6,       # Moore-Penrose iterations for the approximate pinverse (6 recommended by the paper)
            residual=True,           # extra residual with the value; supposedly faster convergence if turned on
            dropout=0.1,
        )
        self.attn = NystromAttention(**attn_config)

    def forward(self, x, return_attn=False):
        """Return ``x + attn(norm(x))``; also return the attention values when ``return_attn`` is truthy."""
        normed = self.norm(x)
        if not return_attn:
            return x + self.attn(normed, return_attn=return_attn)
        attn_out, attn_vals = self.attn(normed, return_attn=return_attn)
        return x + attn_out, attn_vals


class PEG(nn.Module):
    """Positional encoding from three depthwise convolutions (k x k, 5x5, 3x3) over the token grid.

    Args:
        dim: Channel count of the tokens (also the number of conv groups).
        k: Kernel size of the largest depthwise convolution.
    """

    def __init__(self, dim=256, k=7):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, kernel_size=k, stride=1, padding=k // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x, H, W):
        """Apply the convolutions to tokens 1..N-1 laid out as an H x W grid; token 0 passes through unchanged."""
        batch, _, channels = x.shape
        cls_part = x[:, 0].unsqueeze(1)
        # (B, N-1, C) -> (B, C, H, W) so the convolutions see a 2-D layout.
        grid = x[:, 1:].transpose(1, 2).view(batch, channels, H, W)
        mixed = self.proj(grid) + grid + self.proj1(grid) + self.proj2(grid)
        tokens = mixed.flatten(2).transpose(1, 2)
        return torch.cat((cls_part, tokens), dim=1)

'''
Code copied from https://github.com/szc19990412/HVTSurv/blob/main/models/TransMIL.py and slightly modified for this codebase
'''
class TransMIL_peg(nn.Module):
    """TransMIL-style bag classifier: linear reduction, two TransLayers with a PEG in between, MLP head.

    Args:
        n_classes: Number of output logits.
        dim: Internal token width; the 768-wide input features are projected to it.
    """

    def __init__(self, n_classes, dim=128):
        super(TransMIL_peg, self).__init__()
        self.pos_layer = PEG(dim)
        self._fc1 = nn.Sequential(nn.Linear(768, dim), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.n_classes = n_classes
        self.layer1 = TransLayer(dim=dim)
        self.layer2 = TransLayer(dim=dim)
        self.norm = nn.LayerNorm(dim)
        self._fc2 = nn.Sequential(*[nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(0.25), nn.Linear(dim, n_classes)])

    def forward(self, x, return_attn=False):
        """Classify one bag of instance features.

        Args:
            x: Tensor of shape [n, 768]; a batch axis of size 1 is added internally.
            return_attn: When truthy, also return the attention values of the second TransLayer.

        Returns:
            ``logits`` of shape [1, n_classes], or ``(logits, attn_vals)`` when ``return_attn`` is truthy.
        """
        h = x.float().unsqueeze(0)  # [1, n, 768]

        # ---->Dimensionality reduction first
        h = self._fc1(h)  # [1, n, dim]

        # ---->Pad the sequence to the next perfect square by repeating its first
        # tokens, so the PEG can lay the tokens out on a square grid.
        n_tokens = h.shape[1]
        side = int(np.ceil(np.sqrt(n_tokens)))
        _H, _W = side, side
        add_length = _H * _W - n_tokens
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [1, N, dim]

        # ---->Prepend the class token. Follow the input's device rather than
        # forcing CUDA so the module also runs on CPU (or a non-default GPU).
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device)
        h = torch.cat((cls_tokens, h), dim=1)

        # ---->Translayer x1
        h = self.layer1(h)  # [1, N+1, dim]

        # ---->PPEG
        h = self.pos_layer(h, _H, _W)  # [1, N+1, dim]

        # ---->Translayer x2
        if return_attn:
            h, attn_vals = self.layer2(h, return_attn)
        else:
            h = self.layer2(h)

        # Only the class token feeds the prediction head.
        h = self.norm(h)[:, 0]

        # ---->predict output
        logits = self._fc2(h)
        if return_attn:
            return logits, attn_vals
        else:
            return logits

# Slide-level MIL classifier with 5 output logits; inference() turns them into
# a single grade by summing their sigmoids and rounding.
mil_model = TransMIL_peg(n_classes=5, dim=512)

In [5]:
# Dataset split to run on; it selects both the image folder and the CSV below.
mode = 'test'
# mode = 'train'
# Folder of whole-slide .tiff images, and the CSV whose `image_id` column lists the slides to process.
DATA = f'../input/prostate-cancer-grade-assessment/{mode}_images'
TEST = f'../input/prostate-cancer-grade-assessment/{mode}.csv'
# TEST = "/kaggle/input/testing/testsplit.csv"
SAMPLE = '../input/prostate-cancer-grade-assessment/sample_submission.csv'

# Checkpoint for the MIL model, resolved relative to the working directory.
MODEL_WEIGHT = "fastrecov_trained_model.pt"


# Tile edge length in pixels. tile() cuts sz*2 tiles from the slide image and
# rescales them by 0.5 (presumably 0.5 mpp source -> 1 mpp tiles; confirm).
sz = 224
# NOTE(review): nworkers is not referenced by any code visible in this notebook.
nworkers = 4

In [6]:
# Deserialize the checkpoint on the CPU so loading does not depend on the device
# it was saved from; the models are moved to the GPU right after.
mil_model.load_state_dict(torch.load(MODEL_WEIGHT, map_location='cpu'))
featextractor.cuda()
mil_model.cuda()
# eval() disables dropout and uses the BatchNorm running statistics during inference.
featextractor.eval()
mil_model.eval()

TransMIL_peg(
  (pos_layer): PEG(
    (proj): Conv2d(512, 512, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=512)
    (proj1): Conv2d(512, 512, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=512)
    (proj2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)
  )
  (_fc1): Sequential(
    (0): Linear(in_features=768, out_features=512, bias=True)
    (1): ReLU()
  )
  (layer1): TransLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (attn): NystromAttention(
      (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
      (to_out): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): Dropout(p=0.1, inplace=False)
      )
      (res_conv): Conv2d(8, 8, kernel_size=(33, 1), stride=(1, 1), padding=(16, 0), groups=8, bias=False)
    )
  )
  (layer2): TransLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (attn): NystromAttention(
      (t

# Data

In [7]:
def tile(path):
    """Cut a slide into tissue tiles and return them rescaled by 0.5.

    A tissue mask is computed on the lowest-resolution pyramid level, the slide
    image is split into a grid of ``sz*2`` x ``sz*2`` tiles (``sz`` is the
    module-level tile size), tiles whose mask coverage is below 80% are
    dropped, and the remaining tiles are rescaled by 0.5.

    Args:
        path: Path to the slide file.

    Returns:
        np.ndarray stacking the kept tiles along axis 0 (channels last).

    Raises:
        ValueError: From ``np.stack`` when no tile passes the tissue threshold.
    """
    # Read the lowest-resolution level for masking and always release the
    # OpenSlide handle; this function runs once per slide, so an unclosed
    # handle would accumulate over the whole dataset.
    scan = OpenSlide(path)
    try:
        level_dimensions = scan.level_dimensions
        image_array = np.asarray(
            scan.read_region((0, 0), len(level_dimensions) - 1, level_dimensions[-1]).convert('RGB')
        )
    finally:
        scan.close()

    # NOTE(review): the grid arithmetic below assumes `img` is `downsample_factor`
    # times larger than the mask level; confirm that MultiImage(path)[-1] yields
    # the full-resolution image with the installed scikit-image version.
    img = skimage.io.MultiImage(path)[-1]
    shape = img.shape

    # Tissue mask: pixels whose LAB a-channel exceeds the image mean by 10%,
    # then hole filling and a morphological closing to smooth the regions.
    threshold = 0.1
    a_channel = color.rgb2lab(image_array)[..., 1]
    mask = (a_channel > (1 + threshold) * np.mean(a_channel)).astype(np.uint8)
    fill_mask_kernel_size = 9
    mask = binary_fill_holes(mask)
    kernel = np.ones((fill_mask_kernel_size, fill_mask_kernel_size), np.uint8)
    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    mask = (mask > 0) * 1
    downsample_factor = int(level_dimensions[0][0] / level_dimensions[-1][0])

    # Crop both arrays to a whole number of tiles, then reshape into tile grids.
    sz_big = sz * 2
    lim0, lim1 = shape[0] - shape[0] % sz_big, shape[1] - shape[1] % sz_big
    sz_mask = int(sz_big / downsample_factor)
    img = img[:lim0, :lim1, :]
    mask = mask[:int(lim0 // downsample_factor), :int(lim1 // downsample_factor)]
    img = img.reshape(img.shape[0] // sz_big, sz_big, img.shape[1] // sz_big, sz_big, 3)
    mask = mask.reshape(mask.shape[0] // sz_mask, sz_mask, mask.shape[1] // sz_mask, sz_mask, 1)
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, sz_big, sz_big, 3)
    mask = mask.transpose(0, 2, 1, 3, 4).reshape(-1, sz_mask, sz_mask, 1)

    # Keep tiles whose mask is at least 80% tissue.
    idxs = np.where(mask.reshape(mask.shape[0], -1).sum(-1) / float(sz_mask * sz_mask) >= 0.8)[0]
    assert mask.shape[0] == img.shape[0]
    img = img[idxs]

    # For 1MPP extraction: halve the resolution of every kept tile.
    return np.stack([rescale(tile_img, 0.5, channel_axis=-1) for tile_img in img])

class PandaDataset(Dataset):
    """Dataset yielding ``(tiles, image_id)`` for every slide listed in a CSV.

    Args:
        path: Directory containing ``<image_id>.tiff`` slides.
        test: CSV file with an ``image_id`` column.
    """

    def __init__(self, path, test):
        self.path = path
        self.names = pd.read_csv(test)["image_id"].tolist()
        # Per-channel statistics used to normalise the tiles.
        self.mean = torch.tensor([0.485, 0.456, 0.406])
        self.std = torch.tensor([0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Tile slide ``idx``, normalise it, and return ``(tiles as N x C x H x W, image_id)``."""
        name = self.names[idx]
        slide_path = os.path.join(self.path, name + '.tiff')
        # tile() output is used as-is (no /255 scaling); presumably rescale()
        # already yields floats in [0, 1] -- confirm.
        raw_tiles = torch.Tensor(tile(slide_path))
        normalized = (raw_tiles - self.mean) / self.std
        return normalized.permute(0, 3, 1, 2), name


In [8]:
def inference(tiles):
    """Predict the grade of one slide from its tiles.

    Tiles are embedded by ``featextractor`` in batches of 256, the features of
    *all* batches are concatenated into one bag, and ``mil_model`` scores that
    bag. The prediction is the rounded sum of the sigmoids of the logits.

    Args:
        tiles: Tensor of tiles for one slide, indexed along axis 0.

    Returns:
        CPU tensor with one rounded prediction per bag (a single element here).

    Raises:
        ValueError: If ``tiles`` is empty.
    """
    # Run on whatever device the feature extractor lives on instead of assuming CUDA.
    device = next(featextractor.parameters()).device
    dataloader = torch.utils.data.DataLoader(tiles, batch_size=256)
    batch_feats = []
    with torch.no_grad():
        for data in dataloader:
            batch_feats.append(featextractor(data.to(device)))
        if not batch_feats:
            raise ValueError("inference() received no tiles")
        # The MIL model must see every tile of the slide, not just the last batch.
        all_feats = torch.cat(batch_feats, dim=0)
        logits = mil_model(all_feats)
        pred_sig = torch.sigmoid(logits)
        predicted = pred_sig.sum(dim=1).round().cpu()
    return predicted

In [9]:
# Start from the sample submission so a valid file is still produced when the
# image folder is not available.
sub_df = pd.read_csv(SAMPLE)
if os.path.exists(DATA):
    ds = PandaDataset(DATA, TEST)
    names, preds = [], []

    with torch.no_grad():
        for idx in tqdm(range(len(ds))):
            # Index the dataset once: every ds[idx] re-reads and re-tiles the slide.
            tiles, name = ds[idx]
            prediction = inference(tiles)
            names.append(name)
            preds.append(prediction)

    preds = np.asarray(torch.cat(preds).numpy(), int)
    sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})

# Written unconditionally; previously the sample submission loaded above was
# never written when DATA was missing.
sub_df.to_csv('submission.csv', index=False)