# **Cassava EfficientNet&TTA prediction**
2021/01/20 written by T.Yonezu

In [24]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, Dataset

import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import glob 
import os
from tqdm import tqdm
import copy

from cassava_dataset import *
import gc

import warnings
warnings.simplefilter('ignore')

#!pip install "../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0"
#!pip install "../input/timm-pytorch-image-models/pytorch-image-models-master"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
#input_dir = "../input/cassava-leaf-disease-classification"
input_dir = "../../input/cassava-leaf-disease-classification"

In [26]:
x = np.atleast_2d(glob.glob(os.path.join(input_dir, 'test_images',"*")))
x = np.atleast_2d(glob.glob(os.path.join(input_dir, 'train_images',"*"))[:100])

x = x.T
x = pd.DataFrame(columns=["image_path"], data=x)
tmp = x["image_path"].str.split(os.path.sep,expand=True)
x["image_id"] = tmp[len(tmp.columns)-1]
x["label"] = np.nan

test_dict = dict( zip(x["image_path"],x["label"]) )

In [27]:
size = (512,512)
mean = [0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
BATCH_SIZE = 10

transform = ImageTransform(size,mean,std)
test_data = CassavaDataset(test_dict,transform=transform,phase="test")
test_dataloader = DataLoader(test_data,batch_size=BATCH_SIZE)

del test_data
gc.collect()

9175

In [28]:
import timm
import torch.nn as nn
class CustomResNext(nn.Module):
    def __init__(self, model_name='resnext50_32x4d', pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        n_features = self.model.fc.in_features
        self.model.fc = nn.Linear(n_features, 5, bias=True)

    def forward(self, x):
        x = self.model(x)
        return x
    
class CustomEfficientNet(nn.Module):
    def __init__(self, model_name='tf_efficientnet_b4_ns', pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, 5)
        '''
        self.model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            #nn.Linear(n_features, hidden_size,bias=True), nn.ELU(),
            nn.Linear(n_features, n_class, bias=True)
        )
        '''
    def forward(self, x):
        x = self.model(x)
        return x

    
def get_network(model_name, use_pretrained=True):
    if model_name == 'vgg19':
        net = models.vgg19(pretrained=use_pretrained)
        net.classifier[6] = nn.Linear(in_features=4096, out_features=5, bias=True)
        update_param_names = ['classifier.6.weight', 'classifier.6.bias']
    elif model_name == 'resnet50':
        net = models.resnet50(pretrained=use_pretrained)
        net.fc = nn.Linear(in_features=2048, out_features=5, bias=True)
        update_param_names = ['fc.weight', 'fc.bias']
    elif model_name == 'resnext50_32x4d':
        net = CustomResNext(model_name=model_name, pretrained=use_pretrained)
        update_param_names = ['model.fc.weight', 'model.fc.bias']
    elif model_name == 'seresnext50_32x4d':
        net = CustomResNext(model_name=model_name, pretrained=use_pretrained)
        update_param_names = ['model.fc.weight', 'model.fc.bias']
    elif model_name == 'tf_efficientnet_b5_ns':
        net = CustomEfficientNet(model_name=model_name, pretrained=use_pretrained)
        update_param_names = ['model.classifier.weight', 'model.classifier.bias']
    elif model_name == 'tf_efficientnet_b4_ns':
        net = CustomEfficientNet(model_name=model_name, pretrained=use_pretrained)
        update_param_names = ['model.classifier.weight', 'model.classifier.bias']
    elif model_name == 'tf_efficientnet_b3_ns':
        net = CustomEfficientNet(model_name=model_name, pretrained=use_pretrained)
        update_param_names = ['model.classifier.weight', 'model.classifier.bias']
        
    return net, update_param_names

In [35]:
config = {
          "tta":{"num":8},
          "aug":{"size":512},
          "test_batch_size":1,
         }
def inference_test_data_tta(net, dataloader, device):
    net.eval()
    preds = []
#     bar = tqdm(dataloader)
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0]
            inputs = inputs.to(device)
            if config['tta']['num'] == 8:
                inputs = torch.stack([inputs, inputs.flip(-1), inputs.flip(-2), inputs.flip(-1,-2),
                                      inputs.transpose(-1,-2), inputs.transpose(-1,-2).flip(-1), 
                                      inputs.transpose(-1,-2).flip(-2), inputs.transpose(-1,-2).flip(-1,-2)], 0)
            elif config['tta']['num'] == 5:
                inputs = torch.stack([inputs, inputs.flip(-1), inputs.flip(-2), inputs.flip(-1,-2),
                                      inputs.transpose(-1,-2)], 0)
            inputs = inputs.view(-1, 3, config['aug']['size'], config['aug']['size'])
            logits = net(inputs)
            logits = logits.view(config['test_batch_size'], config['tta']['num'], -1).mean(1)
            preds += [torch.softmax(logits, 1).detach().cpu()]
        preds = torch.cat(preds).cpu().numpy()
    return preds

In [36]:
efnetb1_list = np.sort(glob.glob(os.path.join('..\\..\\input\\stacking_ensemble_data\\tf_efficientnet_b1_ns',"*.pth")))
efnetb4_list = np.sort(glob.glob(os.path.join('..\\..\\input\\stacking_ensemble_data\\tf_efficientnet_b4_ns',"*.pth")))
efnetb5_list = np.sort(glob.glob(os.path.join('..\\..\\input\\stacking_ensemble_data\\tf_efficientnet_b5_ns',"*.pth")))

In [40]:
test_preds = []
device = "cuda"

### pred resnet50 ###
resnet_list = np.sort(glob.glob(os.path.join('..\\..\\input\\stacking_ensemble_data\\resnet50',"*.pth")))
for model_path in tqdm(resnet_list):
    
    model, _ = get_network("resnet50", False)
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    
    print(inference_test_data_tta(model, test_dataloader, device))

[[9.76291358e-06 2.34830168e-05 4.17643241e-06 9.64368382e-07
  4.29539141e-05 1.89902827e-09 1.06963319e-06 1.98768657e-06
  3.60770673e-02 8.80454265e-10 1.65304732e-06 1.09254552e-06
  3.12038469e-06 9.38181347e-06 3.17444123e-04 3.56417416e-07
  1.08179648e-03 6.81946403e-07 7.70533370e-06 7.95258529e-07
  2.08597273e-09 2.85753607e-08 5.45941191e-07 6.74476206e-01
  9.88619853e-09 2.89002627e-10 2.54829260e-08 8.66556184e-06
  2.11110756e-01 1.43468981e-09 7.03812830e-08 8.30768130e-08
  8.62880726e-04 2.07936000e-06 1.95060449e-04 3.61430284e-05
  1.69233790e-05 7.44729448e-07 1.81374173e-06 7.60377443e-05
  2.66841147e-04 2.04424796e-05 1.61873243e-06 5.93363438e-07
  2.31662416e-05 4.57238247e-10 4.86213594e-06 5.99646626e-07
  7.53083304e-02 1.77872161e-10]
 [2.61289959e-13 2.36826447e-10 2.76976941e-10 1.21484417e-03
  4.73994054e-13 3.47443038e-12 7.03216319e-10 1.18085297e-09
  5.96629652e-05 3.84913889e-12 5.52367041e-11 1.62653960e-04
  2.45785198e-10 1.00365991e-08 1.266

In [39]:
np.shape(np.mean(test_preds,axis=0))

(10, 50)