In [6]:
from torch import nn
import torch
from torchvision import models
from typing import Union
import cv2
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from inference import InferenceDs, InferenceModel, Predictor
import numpy as np

In [7]:
# Sample submission, plus the full path of every image it lists.
test_dir = "../input/cassava-leaf-disease-classification/test_images"
test_df = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
test_df["filePath"] = [os.path.join(test_dir, image_id) for image_id in test_df["image_id"]]

train_df = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")

# Per-class weight for labels 0..4: one minus the share of training rows that
# carry the label, then rescaled so the five weights sum to 1.
n_train = len(train_df)
weights = np.array(
    [1 - (int((train_df["label"] == cls).sum()) / n_train) for cls in range(5)]
)
weights /= sum(weights)
print(weights)

[0.23729962 0.22442398 0.22212226 0.09626349 0.21989064]


In [12]:
# Checkpoint file holding the saved model weights.
WEIGHTS_PATH = "../input/leafdiseaseclassificationmodelweights/weights_fold0.pt"

# File names present in the test-image directory (in os.listdir order).
image_list = os.listdir('../input/cassava-leaf-disease-classification/test_images/')

# Feature extractor: ResNeXt-50 (32x4d), built with default constructor arguments.
classifier = models.resnext50_32x4d()

# Layer widths of the head; its input width is the output size of the
# extractor's final `fc` layer.
num_ftrs = classifier.fc.out_features
h1 = 512
h2 = h1 // 2

# Head: three BatchNorm1d -> ReLU -> Dropout -> Linear stages, listed as
# (input width, output width, dropout probability). Layers are created in the
# same order as they appear in the Sequential, so parameter names are stable.
head_layers = [
    layer
    for in_dim, out_dim, drop_p in ((num_ftrs, h1, 0.25), (h1, h2, 0.25), (h2, 5, 0.5))
    for layer in (
        nn.BatchNorm1d(in_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(drop_p),
        nn.Linear(in_dim, out_dim),
    )
]
base_model = nn.Sequential(*head_layers)

# Test-time augmentations: random geometric and colour jitter, followed by
# normalisation and conversion to a tensor.
geometric_augs = [
    A.RandomResizedCrop(224, 224),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
]
colour_augs = [
    A.HueSaturationValue(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
]
tensor_steps = [
    A.Normalize(max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
]
test_augs = A.Compose([*geometric_augs, *colour_augs, *tensor_steps], p=1.0)


def inference_one_epoch(model, data_loader, device):
    """
    Run one full inference pass over ``data_loader`` and return softmax scores.

    The model is put into eval mode and every batch is evaluated with autograd
    disabled, so no computation graph is built even when the caller does not
    wrap the call in ``torch.no_grad()``.

    Args:
        model: module called as ``model(imgs)``; its output is soft-maxed
            along dim 1.
        data_loader: sized iterable yielding ``(imgs, ids)`` pairs. Only
            ``imgs`` is used; ``ids`` is ignored.
        device: device each image batch is moved to before the forward pass.

    Returns:
        numpy.ndarray: per-batch softmax outputs concatenated along axis 0,
        in the order the loader produced them.

    Raises:
        ValueError: raised by ``np.concatenate`` when the loader yields no
            batches.
    """
    model.eval()

    batch_probs = []

    pbar = tqdm(data_loader, total=len(data_loader))

    # Gradients are never needed for inference; disabling them here keeps
    # memory use flat regardless of how the function is called.
    with torch.no_grad():
        for imgs, _ids in pbar:
            imgs = imgs.to(device)
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).cpu().numpy())

    return np.concatenate(batch_probs, axis=0)

In [16]:
if __name__ == '__main__':
    # Number of test-time-augmentation passes whose predictions are averaged.
    tta_folds = 5

    # Always a torch.device, on GPU when available and CPU otherwise.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Test dataframe: one row per file in the test directory, with its full path.
    test = pd.DataFrame({"image_id": image_list})
    test["filePath"] = [os.path.join(test_dir, image_id) for image_id in image_list]

    # Dataloader; shuffle=False keeps predictions aligned with the rows of `test`.
    tst_loader = torch.utils.data.DataLoader(InferenceDs(test, test_augs), shuffle=False, batch_size=128)

    # Init model. map_location remaps the checkpoint's tensors onto `device`,
    # so a checkpoint saved on GPU still loads on a CPU-only machine.
    model = InferenceModel(classifier=classifier, base=base_model)
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
    model.to(device)

    # One class-weighted prediction array per TTA pass.
    tta_preds = []

    with torch.no_grad():
        for _ in tqdm(range(tta_folds), desc="Eval"):
            # `weights` scales each class column. The averaging over passes is
            # done once, by np.mean below (previously each pass was also divided
            # by tta_folds, scaling the result twice; argmax is unaffected).
            tta_preds.append(weights * inference_one_epoch(model, tst_loader, device))

    # Average over the TTA passes.
    tst_preds = np.mean(tta_preds, axis=0)

    # Release the model and any cached GPU memory.
    del model
    torch.cuda.empty_cache()

HBox(children=(FloatProgress(value=0.0, description='Evaluating', max=5.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))





In [5]:
# Predicted class = column index of the highest score in each row of tst_preds.
# assign/drop return a new frame, so the frame created earlier is not mutated,
# and errors="ignore" lets the cell be re-run after "filePath" is already gone.
test = test.assign(label=np.argmax(tst_preds, axis=1)).drop(columns=["filePath"], errors="ignore")
test.head()

Unnamed: 0,image_id,label
0,2216849948.jpg,2


In [None]:
# Save `test` as the submission file; the dataframe index is not written.
test.to_csv(path_or_buf='submission.csv', index=False)