# EFN+ResNext+ViT Ensemble Inference
If this helps, please upvote this notebook and the original reference notebook 👍🏼

**Reference code**<br>
- [Pytorch Efficientnet Baseline [Inference] TTA](https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-inference-tta)

# Import Library

In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from  torch.cuda.amp import autocast, GradScaler

import sklearn
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import cv2
from sklearn.metrics import log_loss

!pip install ../input/timmpackagelatestwhl/timm-0.3.4-py3-none-any.whl
import timm
import warnings 
warnings.filterwarnings('ignore')

# CFG Setting

In [None]:
# Global configuration shared by the cells below.
CFG = {
    # Passed to seed_everything() before inference starts.
    'seed': 719,
    # timm architecture names; the main loop iterates over them in this order.
    'model_arch': ['tf_efficientnet_b4_ns','resnext50_32x4d','vit_base_patch16_384'],
    # Checkpoint file names, sorted alphabetically. The main loop takes two
    # consecutive entries per architecture, so the sorted order must line up
    # with the order of 'model_arch' above.
    'weight_path': sorted(os.listdir('../input/cassava-ensemble-model')),
    # Side length (pixels) of the random crop produced by the inference transforms.
    'img_size': 512,
    # NOTE(review): 'train_bs' and 'lr' are not read by any code visible in
    # this notebook — presumably left over from the training config.
    'train_bs': 64,
    # Batch size of the test DataLoader.
    'valid_bs': 64,
    'lr': 1e-4,
    # Worker processes used by the test DataLoader.
    'num_workers': 4,
    # Device string handed to torch.device().
    'device': 'cuda',
    # Number of augmented inference passes run per checkpoint.
    'tta': 2,
    # Ensemble weights; each is divided by the sum of the whole list when
    # predictions are accumulated.
    'weights': [1,1,1,1,1,1]
}

# Load Training Metadata (train.csv)

In [None]:
# The training metadata is loaded even though this notebook only runs
# inference: the main cell uses train.label.nunique() as the number of output
# classes when it builds each model.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
train.head()

# Helper Functions

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED``, and configures cuDNN for
    deterministic execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick a convolution algorithm per
    # input shape, and that choice can differ between runs. It has to be off
    # for the deterministic flag above to give reproducible results.
    torch.backends.cudnn.benchmark = False

def get_img(path):
    """Load an image from disk and return it in RGB channel order.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The array returned by ``cv2.imread`` with its last axis reversed
        (OpenCV loads BGR; reversing yields RGB).

    Raises:
        FileNotFoundError: If OpenCV could not read an image from ``path``.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the slice below fails with an opaque TypeError
        # that does not mention which file was at fault.
        raise FileNotFoundError("Could not read image: {}".format(path))
    return im_bgr[:, :, ::-1]

# DataSet for Loader

In [None]:
class CassavaDataset(Dataset):
    """Dataset that serves images listed in a DataFrame.

    Each row of ``df`` must provide an ``image_id`` column (file name relative
    to ``data_root``) and, when ``output_label`` is truthy, a ``label`` column.

    Args:
        df: DataFrame describing the samples. A copy with a fresh 0..n-1 index
            is stored, so the caller's frame is never modified.
        data_root: Directory that contains the image files.
        transforms: Optional callable invoked as ``transforms(image=img)`` and
            expected to return a mapping with an ``'image'`` entry.
        output_label: When truthy, ``__getitem__`` returns ``(image, label)``;
            otherwise it returns only the image.
    """

    def __init__(
        self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        """Return the number of samples (rows of the stored frame)."""
        return len(self.df)

    def __getitem__(self, index: int):
        """Load, optionally transform, and return the sample at ``index``."""
        # Fetch the row once instead of indexing the frame per column.
        row = self.df.iloc[index]
        img = get_img("{}/{}".format(self.data_root, row['image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']
        # One truthiness check decides both whether the label is read and
        # whether it is returned (previously `if flag` vs `flag == True`).
        if self.output_label:
            return img, row['label']
        return img

# Data Augmentations

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize)
from albumentations.pytorch import ToTensorV2

def get_inference_transforms():
    """Build the randomised augmentation pipeline used for test-time augmentation.

    The pipeline takes a random resized crop of ``CFG['img_size']`` pixels per
    side, applies random transpose/flip and colour jitter, normalises the
    channels and converts the result to a tensor.

    Returns:
        An albumentations ``Compose`` that is always applied (``p=1``).
    """
    side = CFG['img_size']
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        # Spatial augmentations.
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        # Colour augmentations.
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        # Model input preparation.
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

def get_inference_transforms_vit():
    """Build the test-time-augmentation pipeline for the ViT model.

    Same steps as the generic inference pipeline, except that the random crop
    of ``CFG['img_size']`` pixels is resized to 384x384 right after cropping.

    Returns:
        An albumentations ``Compose`` that is always applied (``p=1``).
    """
    crop_side = CFG['img_size']

    crop_and_resize = [
        RandomResizedCrop(crop_side, crop_side),
        Resize(384, 384),
    ]
    random_flips = [
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ]
    colour_jitter = [
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    to_model_input = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(crop_and_resize + random_flips + colour_jitter + to_model_input, p=1.)

# Create Model

In [None]:
class EnsembleModel(nn.Module):
    """timm backbone whose final linear layer is replaced to emit ``n_class`` outputs.

    Args:
        model_arch: timm architecture name; must be a key of ``_HEAD_ATTR``.
        n_class: Number of output units of the new classification layer.
        pretrained: Forwarded to ``timm.create_model``.

    Raises:
        ValueError: If ``model_arch`` is not a supported architecture.
    """

    # Attribute on the timm model that holds the final linear layer, per
    # supported architecture.
    _HEAD_ATTR = {
        'tf_efficientnet_b4_ns': 'classifier',
        'resnext50_32x4d': 'fc',
        'vit_base_patch16_384': 'head',
    }

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        # Validate before building the backbone: an unknown name used to fall
        # through every branch and silently keep the backbone's original head.
        if model_arch not in self._HEAD_ATTR:
            raise ValueError(
                "Unsupported model_arch {!r}; expected one of {}".format(
                    model_arch, sorted(self._HEAD_ATTR)))
        head_attr = self._HEAD_ATTR[model_arch]

        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = getattr(self.model, head_attr).in_features
        setattr(self.model, head_attr, nn.Linear(n_features, n_class))

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        return self.model(x)

# Inference Function

In [None]:
def inference_one_epoch(model, data_loader, device):
    """Run one pass over ``data_loader`` and return softmax probabilities.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        data_loader: Iterable yielding image batches only (no labels).
        device: Device the batches are moved to before the forward pass.

    Returns:
        A NumPy array with the per-batch softmax outputs (taken over dim 1)
        concatenated along axis 0 — presumably (num_images, num_classes),
        assuming the model emits (batch, classes) logits.
    """
    model.eval()
    batch_probs = []
    # Disable autograd here so the function is safe to call on its own; under
    # an outer torch.no_grad() this is simply a no-op.
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).cpu().numpy())
    return np.concatenate(batch_probs, axis=0)

# Main - Inference

In [None]:
if __name__ == '__main__':
    # Run every checkpoint of every architecture with test-time augmentation
    # and combine all passes into a single weighted-average prediction.
    seed_everything(CFG['seed'])

    WEIGHT_DIR = '../input/cassava-ensemble-model'
    TEST_DIR = '../input/cassava-leaf-disease-classification/test_images/'
    # The listing does not depend on the architecture: build it once so every
    # model scores the images in exactly the same order.
    test = pd.DataFrame(); test['image_id'] = list(os.listdir(TEST_DIR))

    device = torch.device(CFG['device'])
    n_class = train.label.nunique()
    weight_total = sum(CFG['weights'])

    # Accumulated across ALL architectures. Resetting this list per
    # architecture made the final prediction come from the last model only.
    tst_preds = []
    for ix, model_arch in enumerate(CFG['model_arch']):
        if model_arch == 'vit_base_patch16_384':
            arch_transforms = get_inference_transforms_vit()
        else:
            arch_transforms = get_inference_transforms()
        testset = CassavaDataset(test, TEST_DIR, transforms=arch_transforms, output_label=False)
        tst_loader = DataLoader(testset, batch_size=CFG['valid_bs'], num_workers=CFG['num_workers'], shuffle=False, pin_memory=False,)

        model = EnsembleModel(model_arch, n_class).to(device)
        # Two consecutive checkpoints belong to each architecture.
        for i, weight in enumerate(CFG['weight_path'][ix*2:ix*2+2]):
            # Use the checkpoint's global position so each of the entries in
            # CFG['weights'] maps to exactly one checkpoint.
            ckpt_weight = CFG['weights'][ix*2 + i]
            model.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, weight))['model'])
            with torch.no_grad():
                for _ in range(CFG['tta']):
                    tst_preds += [ckpt_weight/weight_total/CFG['tta']*inference_one_epoch(model, tst_loader, device)]

        # Free GPU memory before the next architecture is built.
        del model
        torch.cuda.empty_cache()

    # Every entry is already scaled by weight / sum(weights) / tta, so the sum
    # over all entries is the weighted mean of the ensemble.
    avg_tst_preds = np.sum(tst_preds, axis=0)

    os.makedirs('./total_preds', exist_ok=True)
    np.save('./total_preds/total_preds.npy', tst_preds)
    os.makedirs('./mean_preds', exist_ok=True)
    np.save('./mean_preds/mean_preds.npy', avg_tst_preds)

# Make Submission File

In [None]:
# The predicted class of each image is the column with the highest averaged
# score; write it next to the image ids as the submission file.
predicted_labels = avg_tst_preds.argmax(axis=1)
test['label'] = predicted_labels
test.to_csv('submission.csv', index=False)
test.head()