## Description
This kernel performs inference for the [PANDA concat tile pooling starter](https://www.kaggle.com/iafoss/panda-concat-fast-ai-starter) kernel, using multiple models and 8-fold TTA. See that kernel for more training details. The image preprocessing pipeline is provided [here](https://www.kaggle.com/iafoss/panda-16x128x128-tiles).

In [None]:
import cv2
from tqdm import tqdm_notebook as tqdm
import fastai
from fastai.vision import *
import os
import torch
from torchvision.transforms import Compose
import torchvision.transforms as transforms
import torch.nn as nn
from mish_activation import *
import warnings
import glob
warnings.filterwarnings("ignore")
import skimage.io
import numpy as np
import pandas as pd
from efficientnet_pytorch import EfficientNet

In [None]:
# --- Competition inputs ------------------------------------------------------
DATA = '../input/prostate-cancer-grade-assessment/test_images'
TEST = '../input/prostate-cancer-grade-assessment/test.csv'
SAMPLE = '../input/prostate-cancer-grade-assessment/sample_submission.csv'

# --- Checkpoints -------------------------------------------------------------
# MODELS is not referenced by the code visible in this notebook; MODELS_EF is
# the list that is actually loaded into ENet instances.
MODELS = [
    '../input/efficientnetb0medianv6/bestmodel_0 (4).pth',
]
MODELS_EF = [
    '../input/efficientnetb0medianv14/bestmodel_5.pth',
]
# ImageNet-pretrained EfficientNet-B0 weights file.
wpath = '../input/efficientnetb0-pretrained/efficientnet-b0-355c32eb.pth'

# --- Inference hyper-parameters ----------------------------------------------
sz = 256       # side length (pixels) of each square tile cut by tile()
bs = 2         # DataLoader batch size
N = 25         # number of tiles kept per slide by tile()
nworkers = 2   # DataLoader worker processes

In [None]:
# Per-tile transform pipeline applied by the dataset: the tensor tile is turned
# into a PIL image, randomly rotated within [-90, 90] degrees, passed through a
# zero-degree RandomAffine, and converted back to a tensor.
data_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomRotation(degrees=(-90, 90)),
        transforms.RandomAffine(degrees=0),
        transforms.ToTensor(),
    ]
)

# Model

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> Mish) stages.

    Both convolutions use stride 2 with padding 1, so each stage halves the
    spatial resolution.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two stages; falls back to
            ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        first_stage = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, stride=(2, 2)),
            nn.BatchNorm2d(hidden),
            Mish(),
        ]
        second_stage = [
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, stride=(2, 2)),
            nn.BatchNorm2d(out_channels),
            Mish(),
        ]
        # Attribute name and layer order are kept so state_dict keys match.
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Downscaling step: 2x2 max-pooling followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name and layer order are kept so state_dict keys match.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then apply the double convolution."""
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Upscaling step: upsample, pad to the skip tensor's size, concat, DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to the DoubleConv.
        out_channels: channels produced by the DoubleConv.
        bilinear: when truthy, upsample with bilinear interpolation; otherwise
            use a learned stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            # Learned upsampling halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation keeps the channel count, so the DoubleConv narrows
            # it through a mid width of in_channels // 2.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` and convolve their concatenation.

        Args:
            x1: tensor to upsample; dims 2 and 3 are treated as height/width.
            x2: skip tensor whose height/width ``x1`` is padded to match.
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dims.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Pointwise (1x1) convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = pointwise

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        out = self.conv(x)
        return out


class Head(nn.Module):
    """Classification head: AdaptiveConcatPool2d -> Flatten -> Linear.

    Args:
        n_classes: output size of the linear layer.
        in_features: input size of the linear layer. Callers in this file pass
            twice the encoder width — presumably because AdaptiveConcatPool2d
            concatenates two pooled vectors; confirm against fastai.
    """

    def __init__(self, n_classes, in_features):
        super().__init__()
        layers = [
            AdaptiveConcatPool2d(),
            Flatten(),
            nn.Linear(in_features, n_classes),
        ]
        # Attribute name and layer order are kept so state_dict keys match.
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return the head's output for feature map ``x``."""
        out = self.head(x)
        return out

# Resnet34 encoder

In [None]:
class UNet(nn.Module):
    """ResNet34 encoder followed by a pooled linear head over all tiles.

    Despite the name, only an encoder and a classification head are built here.

    Args:
        n_channels: stored on the instance; not otherwise used by this class.
        n_classes: number of outputs of the classification head.
        bilinear: accepted for signature compatibility; unused.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.m = models.resnet34()
        # Width of the encoder = input width of resnet's final child (its fc layer).
        nc = list(self.m.children())[-1].in_features
        # Everything except that final child acts as the feature extractor.
        self.enc = nn.Sequential(*list(self.m.children())[:-1])
        self.head1 = Head(in_features=2 * nc, n_classes=n_classes)

    def print(self, x):
        """Debug helper: show channel 0 of the first 12 items of ``x`` on a 4x3 grid."""
        print(x[0].size())
        # NOTE(review): `py` is not defined by the imports visible in this file —
        # presumably meant to be a matplotlib.pyplot alias; confirm before use.
        fig, ax = py.subplots(4, 3)
        j = 0
        for ax1 in ax:
            for ax2 in ax1:
                plane = x[j, 0, :, :].detach().numpy()
                # ndarray.size is an int attribute (not callable); report the shape.
                print(plane.shape)
                ax2.imshow(plane)
                j += 1

    def forward(self, *x):
        """Classify a batch of slides, each given as a set of tiles.

        Args:
            *x: one 4-D tensor per tile position, all of the same shape; they
                are stacked along a new dim 1 and flattened into one batch.

        Returns:
            The output of the classification head.
        """
        shape = x[0].shape
        n_tiles = len(x)
        # Stack the tile tensors on dim 1, then fold tiles into the batch dim.
        x = torch.stack(x, 1).view(-1, shape[1], shape[2], shape[3])
        x1 = self.enc(x)
        shape1 = x1.size()
        # Regroup per slide, stacking the tiles along the height axis. The width
        # is taken from dim 3 so non-square feature maps are handled as well.
        # NOTE(review): this is a plain view of the contiguous buffer (no
        # permute); kept as-is because trained checkpoints presumably depend on
        # this exact layout.
        x1 = x1.view(-1, shape1[1], n_tiles * shape1[2], shape1[3])
        return self.head1(x1)

# EfficientNet B0 encoder

In [None]:
class ENet(nn.Module):
    """EfficientNet-B0 feature extractor followed by a pooled linear head over all tiles.

    Args:
        n_channels: stored on the instance; not otherwise used by this class.
        n_classes: number of outputs of the classification head.
        bilinear: accepted for signature compatibility; unused.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(ENet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.m = EfficientNet.from_pretrained(model_name='efficientnet-b0', num_classes=n_classes)
        # Read the feature width from the classifier layer by name rather than
        # by its position among children(), which varies between
        # efficientnet_pytorch releases.
        nc = self.m._fc.in_features
        self.head1 = Head(in_features=2 * nc, n_classes=n_classes)

    def print(self, x):
        """Debug helper: show channel 0 of the first 12 items of ``x`` on a 4x3 grid."""
        print(x[0].size())
        # NOTE(review): `py` is not defined by the imports visible in this file —
        # presumably meant to be a matplotlib.pyplot alias; confirm before use.
        fig, ax = py.subplots(4, 3)
        j = 0
        for ax1 in ax:
            for ax2 in ax1:
                plane = x[j, 0, :, :].detach().numpy()
                # ndarray.size is an int attribute (not callable); report the shape.
                print(plane.shape)
                ax2.imshow(plane)
                j += 1

    def forward(self, *x):
        """Classify a batch of slides, each given as a set of tiles.

        Args:
            *x: one 4-D tensor per tile position, all of the same shape; they
                are stacked along a new dim 1 and flattened into one batch.

        Returns:
            The output of the classification head.
        """
        shape = x[0].shape
        n_tiles = len(x)
        # Stack the tile tensors on dim 1, then fold tiles into the batch dim.
        x = torch.stack(x, 1).view(-1, shape[1], shape[2], shape[3])
        x1 = self.m.extract_features(x)
        shape1 = x1.size()
        # Regroup per slide, stacking the tiles along the height axis. The width
        # is taken from dim 3 so non-square feature maps are handled as well.
        # NOTE(review): this is a plain view of the contiguous buffer (no
        # permute); kept as-is because trained checkpoints presumably depend on
        # this exact layout.
        x1 = x1.view(-1, shape1[1], n_tiles * shape1[2], shape1[3])
        return self.head1(x1)

In [None]:

if not os.path.exists('/root/.cache/torch/checkpoints/'):
        os.makedirs('/root/.cache/torch/checkpoints/')
!cp /kaggle/input/efficientnetb0-pretrained/efficientnet-b0-355c32eb.pth /root/.cache/torch/checkpoints/



In [None]:
# Build one ENet per checkpoint in MODELS_EF, ready for GPU inference.
models_ef = []
for path in MODELS_EF:
    # Deserialize on CPU first; the model is moved to the GPU after loading.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model = ENet(3, 6)
    model.load_state_dict(state_dict)
    # Release the raw weights right away: this keeps only one state dict alive
    # at a time and avoids a NameError when MODELS_EF is empty.
    del state_dict
    model.float()
    model.eval()
    model.cuda()
    models_ef.append(model)

# Data

In [None]:
def tile(img, tile_size=None, n_tiles=None):
    """Cut an image into square tiles and keep the ones with the lowest pixel sum.

    The image is padded with 255 on both sides so each dimension becomes a
    multiple of the tile size, split into non-overlapping tiles, and the
    ``n_tiles`` tiles with the smallest pixel sum are returned, smallest
    first. If the image yields fewer tiles, all-255 tiles are appended.

    Args:
        img: array of shape (height, width, 3).
        tile_size: side length of each tile; defaults to the module-level ``sz``.
        n_tiles: number of tiles to return; defaults to the module-level ``N``.

    Returns:
        Array of shape (n_tiles, tile_size, tile_size, 3).
    """
    tile_size = sz if tile_size is None else tile_size
    n_tiles = N if n_tiles is None else n_tiles

    height, width = img.shape[0], img.shape[1]
    # Padding needed to reach the next multiple of tile_size (0 if already one).
    pad_h = (tile_size - height % tile_size) % tile_size
    pad_w = (tile_size - width % tile_size) % tile_size
    img = np.pad(
        img,
        [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
        constant_values=255,
    )

    # (rows, tile_size, cols, tile_size, 3) -> (rows * cols, tile_size, tile_size, 3)
    img = img.reshape(img.shape[0] // tile_size, tile_size, img.shape[1] // tile_size, tile_size, 3)
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)

    if len(img) < n_tiles:
        # Not enough tiles: top up with all-255 tiles.
        img = np.pad(img, [[0, n_tiles - len(img)], [0, 0], [0, 0], [0, 0]], constant_values=255)

    # Ascending pixel sum: the lowest-sum tiles come first.
    idxs = np.argsort(img.reshape(img.shape[0], -1).sum(-1))[:n_tiles]
    return img[idxs]

# Per-channel normalization constants used by PandaDataset.__getitem__ as
# (tile - mean) / std. The dataset feeds inverted tiles (1.0 - pixel/255), so
# the mean is written as 1.0 minus a per-channel value.
mean = torch.tensor([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304])
std = torch.tensor([0.36357649, 0.49984502, 0.40477625])

class PandaDataset(Dataset):
    """Dataset yielding the normalized tile stack and image id of each test slide.

    Args:
        path: directory containing the ``<image_id>.tiff`` files.
        test: path to a CSV with an ``image_id`` column listing the slides.
        transform: callable applied to every tile tensor; when omitted, the
            module-level ``data_transforms`` pipeline is used.
    """

    def __init__(self, path, test, transform=None):
        self.path = path
        self.names = list(pd.read_csv(test).image_id)
        # Honor the caller's transform; fall back to the module default so
        # callers that pass nothing keep the previous behavior.
        self.transform = data_transforms if transform is None else transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(tiles, name)`` for the slide at position ``idx``."""
        name = self.names[idx]
        # Read from the directory given to the constructor; index 1 selects the
        # second image of the multi-image TIFF.
        img = skimage.io.MultiImage(os.path.join(self.path, name + '.tiff'))[1]
        # Invert and scale to [0, 1], then go from (N, H, W, C) to (N, C, H, W).
        tiles = torch.Tensor(1.0 - tile(img) / 255.0).permute(0, 3, 1, 2)
        channel_mean = mean[..., None, None]
        channel_std = std[..., None, None]
        normalized = [(self.transform(t) - channel_mean) / channel_std for t in tiles]
        return torch.stack(normalized), name

# Prediction

In [None]:
# Start from the sample submission so that sub_df has an `isup_grade` column
# even when the test images are not available; test.csv is not in submission
# format, so the cells below would otherwise fail on the missing column.
sub_df = pd.read_csv(SAMPLE)
if os.path.exists(DATA):
    ds = PandaDataset(DATA, TEST, transform=data_transforms)
    dl = DataLoader(ds, batch_size=bs, shuffle=False, num_workers=nworkers)
    names, preds = [], []

    with torch.no_grad():
        for x, y in tqdm(dl):
            x = x.cuda()
            # (bs, n_tiles, C, H, W) -> (n_tiles, bs, C, H, W) so each tile
            # position can be unpacked as a separate model argument.
            x = x.permute(1, 0, 2, 3, 4)
            # Average the outputs of all models, then take the top class.
            outputs = torch.stack([model(*x) for model in models_ef], 0)
            avg = outputs.mean(dim=0)
            names.extend(y)
            preds.append(avg.argmax(-1).cpu())

    preds = torch.cat(preds).numpy()
    sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})
    sub_df['isup_grade'] = sub_df['isup_grade'].apply(int)
    sub_df.to_csv('submission.csv', index=False)

In [None]:
# Distribution of predicted grades (displayed as the cell output).
sub_df.isup_grade.value_counts()

In [None]:
# Write the submission file and preview its first rows.
sub_df.to_csv(path_or_buf="submission.csv", index=False)
sub_df.head(5)