# Description
This kernel performs inference for the [PANDA concat tile pooling starter](https://www.kaggle.com/iafoss/panda-concat-fast-ai-starter) kernel using multiple models and 8-fold TTA. See that kernel for more training details. The image preprocessing pipeline is provided [here](https://www.kaggle.com/iafoss/panda-16x128x128-tiles).

In [1]:
#!pip install hubconf

In [2]:
import cv2
from tqdm import tqdm_notebook as tqdm
import fastai
from fastai.vision import *
import os
from mish_activation import *
import warnings
warnings.filterwarnings("ignore")
import skimage.io
import numpy as np
import pandas as pd
#sys.path.insert(0, '../input/semisupervised-imagenet-models/semi-supervised-ImageNet1K-models-master/')
# from hubconf import *

In [3]:
# Input locations used by the prediction code.
DATA = '../input/prostate-cancer-grade-assessment/test_images'  # directory holding the test <image_id>.tiff slides
TEST = '../input/prostate-cancer-grade-assessment/test.csv'  # CSV whose image_id column lists the test slides
SAMPLE = '../input/prostate-cancer-grade-assessment/sample_submission.csv'  # template submission file
# Checkpoint paths RNXT50_0.pth .. RNXT50_3.pth.
# NOTE(review): MODELS is not referenced by the code visible in this file, which
# loads a single checkpoint from a different path - confirm whether it is still needed.
MODELS = [f'../input/panda-starter-models/RNXT50_{i}.pth' for i in range(4)]

sz = 128  # tile side length in pixels
bs = 2  # DataLoader batch size (slides per batch)
# NOTE(review): N and sz are assigned again further down in this file (N = 16);
# tile() and the prediction loop read the globals at call time - confirm which value is intended.
N = 9  # number of tiles kept per slide
nworkers = 2  # DataLoader num_workers

In [None]:
############DATA

In [10]:
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile
import numpy as np

In [11]:
# Inputs and outputs of the tile-extraction (preprocessing) step.
TRAIN = '../input/prostate-cancer-grade-assessment/train_images/'  # directory of <name>.tiff training slides
MASKS = '../input/prostate-cancer-grade-assessment/train_label_masks/'  # directory of <name>_mask.tiff label masks
OUT_TRAIN = 'train.zip'  # archive that receives the image tiles
OUT_MASKS = 'masks.zip'  # archive that receives the mask tiles
sz = 128  # tile side length in pixels
# NOTE(review): this rebinds the global N that was set to 9 earlier in the file.
# tile() and the prediction loop read N at call time, so confirm that 16 is the
# value the loaded model expects before running prediction after this cell.
N = 16  # number of tiles kept per slide

In [12]:
def tile(img, mask):
    """Split an image and its mask into aligned ``sz`` x ``sz`` tiles and keep ``N`` of them.

    Both arrays are padded to a multiple of ``sz`` along the first two axes
    (image with 255, mask with 0), cut into a grid of 3-channel tiles, and
    topped up with blank tiles if fewer than ``N`` exist. The ``N`` tiles with
    the smallest image pixel sum are kept; the same selection is applied to
    the mask.

    Reads the module-level ``sz`` (tile side) and ``N`` (tile count).

    Returns:
        A list of ``N`` dicts ``{'img': tile, 'mask': tile, 'idx': rank}``
        ordered by ascending image pixel sum.
    """
    height, width = img.shape[0], img.shape[1]
    pad_h = (sz - height % sz) % sz
    pad_w = (sz - width % sz) % sz
    # Split the padding as evenly as possible between both sides of each axis.
    spatial_pad = [[pad_h // 2, pad_h - pad_h // 2],
                   [pad_w // 2, pad_w - pad_w // 2],
                   [0, 0]]
    img = np.pad(img, spatial_pad, constant_values=255)
    mask = np.pad(mask, spatial_pad, constant_values=0)

    def _to_tiles(arr):
        # (H, W, 3) -> (rows, sz, cols, sz, 3) -> (rows * cols, sz, sz, 3)
        grid = arr.reshape(arr.shape[0] // sz, sz, arr.shape[1] // sz, sz, 3)
        return grid.transpose(0, 2, 1, 3, 4).reshape(-1, sz, sz, 3)

    img, mask = _to_tiles(img), _to_tiles(mask)

    missing = N - len(img)
    if missing > 0:
        # Not enough tiles: append blank ones (255 for the image, 0 for the mask).
        filler = [[0, missing], [0, 0], [0, 0], [0, 0]]
        mask = np.pad(mask, filler, constant_values=0)
        img = np.pad(img, filler, constant_values=255)

    # Ascending pixel sum: the tiles with the least 255-valued background come first.
    keep = np.argsort(img.reshape(img.shape[0], -1).sum(-1))[:N]
    return [
        {'img': img_tile, 'mask': mask_tile, 'idx': rank}
        for rank, (img_tile, mask_tile) in enumerate(zip(img[keep], mask[keep]))
    ]

In [13]:
# Cut every training slide and its mask into tiles, store them as PNGs in two
# zip archives, and collect per-tile channel statistics along the way.
# x_tot / x2_tot accumulate, per tile, the channel-wise mean of x and of x**2
# (pixels scaled to [0, 1]).
x_tot,x2_tot = [],[]
# Drop the last 10 characters of each mask file name (the length of '_mask.tiff')
# to recover the slide name.
names = [name[:-10] for name in os.listdir(MASKS)]
with zipfile.ZipFile(OUT_TRAIN, 'w') as img_out,\
 zipfile.ZipFile(OUT_MASKS, 'w') as mask_out:
    for name in tqdm(names):
        # [-1] takes the last level of the multi-page TIFF
        # (presumably the lowest-resolution one - TODO confirm).
        img = skimage.io.MultiImage(os.path.join(TRAIN,name+'.tiff'))[-1]
        mask = skimage.io.MultiImage(os.path.join(MASKS,name+'_mask.tiff'))[-1]
        tiles = tile(img,mask)
        for t in tiles:
            # Note: img / mask are rebound here from the whole slide to a single tile.
            img,mask,idx = t['img'],t['mask'],t['idx']
            x_tot.append((img/255.0).reshape(-1,3).mean(0))
            x2_tot.append(((img/255.0)**2).reshape(-1,3).mean(0)) 
            #if read with PIL RGB turns into BGR
            # imencode returns (ok, buffer); [1] keeps the encoded PNG bytes.
            img = cv2.imencode('.png',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))[1]
            img_out.writestr(f'{name}_{idx}.png', img)
            # Only the first mask channel is stored.
            mask = cv2.imencode('.png',mask[:,:,0])[1]
            mask_out.writestr(f'{name}_{idx}.png', mask)

  0%|          | 0/10516 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Per-channel statistics over all exported tiles.
# Every tile has the same number of pixels, so averaging the per-tile means of
# x and x**2 gives the global E[x] and E[x**2]; then std = sqrt(E[x**2] - E[x]**2).
img_avr = np.array(x_tot).mean(0)
img_std = np.sqrt(np.array(x2_tot).mean(0) - img_avr**2)
# img_std is already the standard deviation. The previous version printed
# np.sqrt(img_std), i.e. the square root of the std, labelled as "std".
# NOTE(review): normalisation constants hard-coded elsewhere may have been copied
# from the old double-sqrt output - confirm before swapping them for these values,
# since a trained model depends on the constants it was trained with.
print('mean:', img_avr, ', std:', img_std)

In [None]:
############DATA

# Model

In [4]:
# !pip install timm

In [5]:
import timm
from torch.nn import AdaptiveAvgPool2d, Flatten, Linear
from mish_activation import Mish 

class AdaptiveConcatPool2d(nn.Module):
    """Adaptive average pooling and adaptive max pooling, concatenated on dim 1.

    The output therefore has twice as many channels as the input.

    Args:
        output_size: Target spatial size handed to both pooling layers.
    """

    def __init__(self, output_size=(1, 1)):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size)
        self.max_pool = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        """Return ``[avg_pool(x), max_pool(x)]`` joined along the channel axis."""
        pooled = (self.avg_pool(x), self.max_pool(x))
        return torch.cat(pooled, dim=1)

class Model(nn.Module):
    """Tile-based classifier: shared encoder per tile, pooled head over all tiles.

    Each tile is passed through the same encoder; the resulting feature maps of
    one slide are concatenated along the height axis into a single map, which
    is pooled and classified by the head.

    Args:
        arch: Architecture name passed to ``timm.create_model``.
        n: Number of output values produced by the final linear layer.
        pre: Whether ``timm`` should load pretrained weights.
    """

    def __init__(self, arch='resnext50_32x4d', n=6, pre=True):
        super().__init__()
        # Build the backbone with timm and drop its last two children; the
        # original last child provides the feature width via in_features.
        m = timm.create_model(arch, pretrained=pre)
        layers = list(m.children())
        self.enc = nn.Sequential(*layers[:-2])
        nc = layers[-1].in_features
        # Concat pooling doubles the channel count, hence 2 * nc.
        self.head = nn.Sequential(AdaptiveConcatPool2d(), Flatten(), nn.Linear(2 * nc, 512),
                                  Mish(), nn.BatchNorm1d(512), nn.Dropout(0.5), nn.Linear(512, n))

    def forward(self, *x):
        """Classify a batch of tile sets.

        Accepts either

        * N separate 4-D tensors, one per tile position, each ``bs x C x H x W``, or
        * a single 5-D tensor ``bs x N x C x H x W`` (the form produced by the
          prediction loop in this file).

        Previously only the first form was handled; a single 5-D tensor was
        treated as one tile and reshaped incorrectly.

        Returns:
            Tensor of shape ``bs x n``.
        """
        if len(x) == 1 and x[0].dim() == 5:
            tiles = x[0]
            n = tiles.shape[1]
            # Fold the tile axis into the batch axis.
            x = tiles.reshape(-1, *tiles.shape[2:])
        else:
            shape = x[0].shape
            n = len(x)
            x = torch.stack(x, 1).view(-1, shape[1], shape[2], shape[3])
        # x: bs*N x 3 x 128 x 128
        x = self.enc(x)
        # x: bs*N x C x 4 x 4
        shape = x.shape
        # concatenate the output for tiles into a single map
        x = x.view(-1, n, shape[1], shape[2], shape[3]).permute(0, 2, 1, 3, 4).contiguous() \
            .view(-1, shape[1], shape[2] * n, shape[3])
        # x: bs x C x N*4 x 4
        x = self.head(x)
        # x: bs x n
        return x

In [6]:
print("modelbefore")
# Load the full pickled model object. The prediction loop calls model(x), so
# this is expected to be an nn.Module rather than a state dict.
# NOTE: torch.load unpickles arbitrary Python objects - only load checkpoints
# from a trusted source.
model = torch.load('/kaggle/input/modelagain/resnext50_32x4d_final (1).pth')
# Put Dropout / BatchNorm layers into inference mode; torch.no_grad() alone
# does not do this, and a model pickled in training mode stays in training mode.
model.eval()
models = [model]
print("modelafter")

modelbefore
modelafter


# Data

In [7]:
def tile(img):
    """Cut an image into ``sz`` x ``sz`` tiles and return the ``N`` with the lowest pixel sum.

    The image is padded with 255 to a multiple of ``sz`` along its first two
    axes, split into a grid of 3-channel tiles, and topped up with all-255
    tiles when fewer than ``N`` exist.

    Reads the module-level ``sz`` (tile side) and ``N`` (tile count).

    Returns:
        Array of shape ``(N, sz, sz, 3)`` ordered by ascending pixel sum.
    """
    rows, cols = img.shape[0], img.shape[1]
    pad_r = (sz - rows % sz) % sz
    pad_c = (sz - cols % sz) % sz
    # Split the padding as evenly as possible between both sides of each axis.
    img = np.pad(
        img,
        [[pad_r // 2, pad_r - pad_r // 2], [pad_c // 2, pad_c - pad_c // 2], [0, 0]],
        constant_values=255,
    )
    # (H, W, 3) -> (rows, sz, cols, sz, 3) -> (rows * cols, sz, sz, 3)
    grid = img.reshape(img.shape[0] // sz, sz, img.shape[1] // sz, sz, 3)
    tiles = grid.transpose(0, 2, 1, 3, 4).reshape(-1, sz, sz, 3)

    shortfall = N - len(tiles)
    if shortfall > 0:
        tiles = np.pad(tiles, [[0, shortfall], [0, 0], [0, 0], [0, 0]], constant_values=255)

    # Ascending pixel sum: the tiles with the least 255-valued background come first.
    pixel_sums = tiles.reshape(tiles.shape[0], -1).sum(-1)
    return tiles[np.argsort(pixel_sums)[:N]]

# Per-channel normalisation constants applied to the tiles in
# PandaDataset.__getitem__, after the pixels are inverted as 1 - pixel/255.
# mean is written as 1 - <value> to match that inversion.
# NOTE(review): presumably measured on the training tiles - confirm these are
# the values the loaded model was trained with before changing them.
mean = torch.tensor([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304])
std = torch.tensor([0.36357649, 0.49984502, 0.40477625])

class PandaDataset():
    """Map-style dataset yielding normalised tiles for each test slide.

    Args:
        path: Directory containing the ``<image_id>.tiff`` slides.
        test: CSV path (or file-like object) with an ``image_id`` column
            listing the slides to load.
    """

    def __init__(self, path, test):
        self.path = path
        self.names = list(pd.read_csv(test).image_id)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(tiles, name)`` for slide ``idx``.

        ``tiles`` has the tile axis first and channels second
        (``tile x 3 x sz x sz`` after the permute).
        """
        name = self.names[idx]
        # Use the directory given to the constructor; the previous version
        # ignored self.path and always read from the global DATA.
        # [-1] takes the last level of the multi-page TIFF.
        img = skimage.io.MultiImage(os.path.join(self.path, name + '.tiff'))[-1]
        # Invert so that 255-valued background becomes 0, then normalise per channel.
        tiles = torch.Tensor(1.0 - tile(img)/255.0)
        tiles = (tiles - mean)/std
        return tiles.permute(0,3,1,2), name

# Prediction

In [9]:
# Start from the sample submission so that sub_df exists even when the test
# images are absent and the prediction cell below is skipped.
sub_df = pd.read_csv(SAMPLE)

sub_df

Unnamed: 0,image_id,isup_grade
0,005700be7e06878e6605e7a5a39de1b2,0
1,005c6e8877caf724c600fdce5d417d40,0
2,0104f76634ff89bfff1ef0804a95c380,0


In [None]:
# Run prediction only when the test images are available; otherwise sub_df
# keeps whatever was loaded before this cell.
if os.path.exists(DATA):
    ds = PandaDataset(DATA,TEST)
    dl = DataLoader(ds, batch_size=bs, num_workers=nworkers, shuffle=False)
    names,preds = [],[]

    # Dropout / BatchNorm must be in inference mode; no_grad() does not do that.
    for model in models:
        model.eval()

    with torch.no_grad():
        for x,y in tqdm(dl):
            # Actual number of slides in this batch. The last batch can be
            # smaller than bs, so bs must not be used for the reshape below.
            n_img = x.shape[0]
            x = x.cuda()
            #dihedral TTA: the 8 flip / transpose variants of every tile
            x = torch.stack([x,x.flip(-1),x.flip(-2),x.flip(-1,-2),
              x.transpose(-1,-2),x.transpose(-1,-2).flip(-1),
              x.transpose(-1,-2).flip(-2),x.transpose(-1,-2).flip(-1,-2)],1)
            # n_img x 8 x N x 3 x sz x sz -> (n_img*8) x N x 3 x sz x sz
            x = x.view(-1,N,3,sz,sz)
            p = [model(x) for model in models]
            # (n_img*8) x n_models x n_classes
            p = torch.stack(p,1)
            # Average over the 8 TTA variants and all models, then take the top class.
            p = p.view(n_img,8*len(models),-1).mean(1).argmax(-1).cpu()
            names.append(y)
            preds.append(p)

    names = np.concatenate(names)
    preds = torch.cat(preds).numpy()
    sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})
    sub_df.to_csv('submission.csv', index=False)

In [9]:
# Write the submission file (the sample submission if prediction was skipped)
# and show its first rows.
sub_df.to_csv("submission.csv", index=False)
sub_df.head()

Unnamed: 0,image_id,isup_grade
0,005700be7e06878e6605e7a5a39de1b2,0
1,005c6e8877caf724c600fdce5d417d40,0
2,0104f76634ff89bfff1ef0804a95c380,0
