In [None]:
import numpy as np
import pandas as pd 

import torch.utils.data as data
import torchvision.transforms as transforms 
import torch

import torch.optim as optim 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os 

import torchvision.models as models
import torch.nn as nn
from PIL import Image, ImageOps
import sys
sys.path.append('../input/pytorchimagemodels/pytorch-image-models-master')
import timm
import torch.nn as nn
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
sys.path.append('../input/toolbelt/pytorch-toolbelt-develop')
from pytorch_toolbelt.inference import tta

In [None]:
class TimmBackbone(nn.Module):
    """Thin ``nn.Module`` wrapper around a timm model with a resized classifier head.

    Args:
        model_name: timm architecture name forwarded to ``timm.create_model``.
        pretrained: forwarded to ``timm.create_model`` as ``pretrained``.
        target_size: number of outputs the classifier head is reset to.
    """

    def __init__(self, model_name= 'mobilenetv3_large_100', pretrained=False, target_size = 5):
        super().__init__()
        net = timm.create_model(model_name, pretrained=pretrained)
        # Replace the default head so the output width matches target_size.
        net.reset_classifier(target_size)
        # Registered under the attribute name "model"; checkpoint keys depend on it.
        self.model = net

    def forward(self, x):
        """Return the raw logits produced by the wrapped timm model."""
        return self.model(x)

    
class EnsembleWeight(nn.Module):
    """Learnable per-class weighted average over ensemble members.

    ``forward`` multiplies stacked member predictions, laid out as
    (batch, model, class), by a (model, class) weight matrix. It sums over the
    model axis and rescales so each row sums to 1 over the class axis.

    Args:
        model_size: number of ensemble members (size of the ``model`` axis).
        target_size: number of classes (size of the ``class`` axis).

    Raises:
        ValueError: if either size is smaller than 1.
    """

    # Hard-coded weights for a 5-member x 5-class ensemble, indexed
    # [model, class]; each column sums to ~1. Presumably fitted offline —
    # TODO confirm their provenance.
    _DEFAULT_WEIGHTS = (
        (0.6528, 0.7624, 0.8109, 0.7850, 0.8738),
        (0.0281, 0.0388, 0.0609, 0.0287, 0.0526),
        (0.0225, 0.0260, 0.0194, 0.0242, 0.0147),
        (0.2687, 0.1444, 0.0838, 0.1314, 0.0436),
        (0.0279, 0.0285, 0.0250, 0.0307, 0.0153),
    )

    def __init__(self, model_size = 5, target_size = 5):
        super().__init__()
        if model_size < 1 or target_size < 1:
            raise ValueError(
                f'model_size and target_size must be >= 1, got {model_size} and {target_size}'
            )
        default = torch.tensor(self._DEFAULT_WEIGHTS)
        if default.shape == (model_size, target_size):
            init = default
        else:
            # No hard-coded weights for this shape (previously the arguments were
            # ignored and forward failed to broadcast). Start from a uniform mean
            # whose columns also sum to 1.
            init = torch.full((model_size, target_size), 1.0 / model_size)
        self.w = nn.Parameter(init)

    def forward(self, x):
        """Combine member predictions.

        Args:
            x: tensor of shape (batch, model, class).

        Returns:
            Tensor of shape (batch, class) whose rows sum to 1.
        """
        weighted = torch.sum(self.w * x, dim=1)
        # The normaliser is detached, so no gradient flows through the
        # rescaling; kept as in the original — rationale not visible here.
        return weighted / weighted.sum(dim=-1, keepdim=True).detach()

In [None]:
class TestDataset(Dataset):
    """Test-time image dataset yielding normalised float tensors.

    Each item is read with OpenCV, converted BGR -> RGB, resized to a square,
    normalised with ImageNet mean/std and converted to a tensor.

    Args:
        df: DataFrame with an ``image_id`` column of file names.
        image_size: side length of the square resize; defaults to the
            notebook-level ``IMAGE_SIZE``.
        test_path: directory containing the images; defaults to the
            notebook-level ``TEST_PATH``.
    """

    def __init__(self, df, image_size=None, test_path=None):
        self.df = df
        self.file_names = df['image_id'].values
        # Resolve the notebook-level defaults once, at construction, and keep
        # them on the instance instead of re-reading globals per item.
        self.image_size = IMAGE_SIZE if image_size is None else image_size
        self.test_path = TEST_PATH if test_path is None else test_path
        self.transform = A.Compose([
            A.Resize(self.image_size, self.image_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and return image ``idx`` as a float tensor.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        file_name = self.file_names[idx]
        file_path = f'{self.test_path}/{file_name}'

        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising;
            # without this check cvtColor fails with an unhelpful assertion.
            raise FileNotFoundError(f'could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image=image)["image"]

        return image.float()

In [None]:
def _load_models(model_list, device):
    """Build one eval-mode predictor per checkpoint listed in ``model_list``.

    ``model_list`` maps ``"<timm model name>;<checkpoint dir>"`` to a list of
    checkpoint file names inside that directory. Predictors are returned in
    dict order, then list order, matching the stacking order used downstream.
    """
    predictors = []
    for key, filenames in model_list.items():
        # Split only on the first ';' so a ';' inside the path cannot break it.
        backbone, model_base_path = key.split(';', 1)
        for filename in filenames:
            model = TimmBackbone(model_name=backbone).to(device=device)
            model_path = os.path.join(model_base_path, filename)
            # map_location lets GPU-saved checkpoints load on a CPU-only device.
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()
            if TTA:
                predictors.append(tta.TTAWrapper(model, tta.d4_image2label))
            else:
                predictors.append(model)
    return predictors


def inference(model_list, test_loader, device):
    """Run the checkpoint ensemble over ``test_loader`` and return class probabilities.

    All checkpoints are loaded once, up front; previously they were
    re-instantiated and re-read from disk for every batch. All of them stay
    resident on ``device`` for the whole pass.

    Args:
        model_list: mapping ``"<timm model name>;<checkpoint dir>" -> [file names]``.
        test_loader: iterable yielding image batches.
        device: torch device to run the models on.

    Returns:
        CPU tensor of shape (num_images, num_classes). Per-member softmax
        probabilities are combined by a plain mean when ``METHOD == "mean"``,
        otherwise by ``WEIGHT_MODEL``.

    Raises:
        ValueError: if ``model_list`` lists no checkpoints or ``test_loader``
            yields no batches.
    """
    predictors = _load_models(model_list, device)
    if not predictors:
        raise ValueError('model_list does not contain any checkpoints')

    probs = []
    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            member_probs = [
                predictor(images).softmax(dim=-1).to('cpu') for predictor in predictors
            ]

            if METHOD == "mean":
                # Simple unweighted mean over ensemble members.
                batch_probs = torch.mean(torch.stack(member_probs), dim=0)
            else:
                # WEIGHT_MODEL expects (batch, model, class).
                stacked = torch.stack(member_probs, dim=1).to(device)
                batch_probs = WEIGHT_MODEL(stacked).to('cpu')
            probs.append(batch_probs)

    if not probs:
        raise ValueError('test_loader yielded no batches')
    return torch.cat(probs, dim=0)

In [None]:
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'
OUTPUT_DIR = './'
# "mean" averages member probabilities; any other value uses WEIGHT_MODEL.
METHOD = "mean"
# Keys are "<timm model name>;<checkpoint directory>"; values are checkpoint
# file names inside that directory (one per CV fold).
ENSEMBLE_MODELS ={
#     "mobilenetv3_large_100;../input/mobilenetv3-large-100/mobilenetv3_large_100"
#     : ['fold0_best.pth', 'fold1_best.pth', 'fold2_best.pth', 'fold3_best.pth', 'fold4_best.pth'],
    "tf_efficientnet_b3;../input/eb3-test2-label-smooth-strong-fix-09/tf_efficientnet_b3" :
    ['fold0_best.pth', 'fold1_best.pth', 'fold2_best.pth', 'fold3_best.pth', 'fold4_best.pth']
}

NUM_CLASSES = 5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Total number of ensemble members across all backbones.
model_size = sum(len(filenames) for filenames in ENSEMBLE_MODELS.values())
WEIGHT_MODEL = EnsembleWeight(model_size, NUM_CLASSES).to(DEVICE)
TTA = False
IMAGE_SIZE = 512
BATCH_SIZE = 64
# Cap at the available cores: extra loader workers beyond that only add overhead.
NUM_WORKERS = min(8, os.cpu_count() or 1)

test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
test_dataset = TestDataset(test)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
# Run the ensemble, take the arg-max class per image and write the submission file.
predictions = inference(ENSEMBLE_MODELS, test_loader, DEVICE)
predicted_labels = predictions.argmax(dim=-1).numpy()
test['label'] = predicted_labels
submission_path = OUTPUT_DIR + 'submission.csv'
test[['image_id', 'label']].to_csv(submission_path, index=False)
test.head()