In [2]:
import sys
sys.path.append('/absolute-path/snakeCLEF/training_scripts')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import src
from src.core import models, data
from src.utils import  io
import torch.nn as nn

# --- Paths and checkpoint identifiers ---------------------------------------
# Prefix that is prepended to every relative `image_path` from the metadata.
DATA_DIR = '../../'

# Architecture key passed to `models.get_model`, and the checkpoint file that
# is read back with `torch.load`.
MODEL_ARCH = 'vit_small_384'
MODEL_NAME = 'ensemble-model-clef2023_vit_small_384_ensemble_focal_05-14-2023_02-18-03'

# --- Compute device ---------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

Device: cpu


In [3]:
# Read the public test-set metadata; every row carries an `observation_id`
# and the relative `image_path` of one image.
test_metadata_csv = '../../snake_csv_files/SnakeCLEF2023-PubTestMetadata.csv'
test_df = pd.read_csv(test_metadata_csv)

# Report the row count with thousands separators, then preview the table.
print(f'Test set length: {len(test_df):,d}')
test_df.head()

Test set length: 14,071


Unnamed: 0,observation_id,captive,endemic,code,image_path
0,5954638,False,False,US,pubtest/7437408.jpg
1,3074568,False,False,ID,pubtest/3543837.jpg
2,37207473,False,False,US,pubtest/58945588.jpg
3,37280098,False,False,TH,pubtest/59067915.jpg
4,116571,False,False,GY,pubtest/166608.JPG


In [4]:
# Test-time dataset: yields images only (no labels)

class SnakeInferenceDataset(Dataset):
    """Label-free image dataset used for inference.

    Each item is the image referenced by the ``image_path`` field of one row
    of ``data``, opened as RGB and, if a transform was given, passed through it.

    Args:
        data: Table whose rows are read with ``.iloc`` and counted with
            ``.shape[0]`` (a pandas DataFrame in this notebook). Every row
            must expose an ``image_path`` attribute relative to ``DATA_DIR``.
        transform: Optional callable applied to the loaded PIL image. When it
            is ``None`` the RGB PIL image itself is returned.
    """

    def __init__(self, data, transform = None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Load the image of row ``index`` and apply the transform, if any."""
        row = self.data.iloc[index]

        # `DATA_DIR` is the module-level path prefix. The context manager
        # releases the file handle; `convert` returns a new, fully loaded
        # image, so the result stays valid after the file is closed.
        with Image.open(DATA_DIR + row.image_path) as source:
            img = source.convert("RGB")

        # Check this dataset's own transform -- not a global that merely
        # happens to share the name `transform`.
        if self.transform is not None:
            img = self.transform(img)

        return img

In [5]:
# Number of output classes requested from the backbone; the same value is the
# width of the final classifier layer used by EnsembleNet.
NUM_CLASSES = 1784
model_b = models.get_model(MODEL_ARCH, NUM_CLASSES, pretrained=True)


class EnsembleNet(nn.Module):
    """Backbone followed by an MLP head over the duplicated backbone logits.

    ``forward`` concatenates the backbone output with itself along dim 1, so
    the head's input width is ``2 * num_classes``.

    Args:
        num_classes: Number of logits the backbone emits and the head
            predicts. The default (1784) reproduces the original fixed layer
            sizes 3568 -> 2048 -> 1784.
        backbone: Module mapping a batch to ``(batch, num_classes)`` logits.
            When ``None`` the module-level ``model_b`` is used, which is what
            the class always did before this parameter existed.
    """

    def __init__(self, num_classes = 1784, backbone = None):
        super().__init__()

        # The attribute names `model_a` and `classifier` (and the layer order
        # inside the Sequential) define the state-dict keys; they are kept
        # unchanged so existing checkpoints still load.
        self.model_a = model_b if backbone is None else backbone

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(2 * num_classes, 2048, bias=True),
            nn.Tanh(),
            nn.Dropout(p=0.1),
            nn.Linear(2048, num_classes, bias=True)
        )

    def forward(self, x):
        """Return head logits of shape ``(batch, num_classes)`` for batch ``x``."""
        logits_a = self.model_a(x)
        # NOTE(review): the same logits are concatenated twice. This looks
        # like the remnant of a two-backbone ensemble; it is kept because the
        # classifier's first layer expects this doubled width.
        concatenated_vectors = torch.cat((logits_a, logits_a), dim=1)
        return self.classifier(concatenated_vectors)

# Build the network, then restore the trained weights onto the CPU.
model = EnsembleNet()

# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
state_dict = torch.load(MODEL_NAME, map_location='cpu')

# Last expression of the cell: its result reports whether all keys matched.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Configuration attached to the backbone (input size and normalisation stats
# are read from it below).
model_config = model_b.pretrained_config
# NOTE(review): the DataLoader in the inference cell uses a literal batch
# size of 32, not this value.
batch_size = 128

# Project-defined transforms built from the backbone configuration; only the
# second returned element is kept.
_, test_tfms = data.get_transforms(
    std=model_config['image_std'],
    mean=model_config['image_mean'],
    size=model_config['input_size'],
)

from torchvision import transforms

# NOTE(review): the inference cell passes `transform` (defined here), not
# `test_tfms`, so images are only resized and converted to tensors -- no
# mean/std normalisation is applied. Confirm this matches training.
_inference_steps = [
    # further transformations can be appended to this list
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_inference_steps)

In [8]:
# One entry per batch: the list of predicted class ids for that batch.
prediction_list = []

test_dataset = SnakeInferenceDataset(test_df, transform = transform)
# shuffle=False keeps predictions in the same order as the rows of test_df.
test_dataloader = DataLoader(test_dataset, batch_size = 32, shuffle = False)

# The model must live on the same device as the batches sent to it;
# without this the loop only works when `device` is the CPU.
model.to(device)
# Switch to evaluation mode so the Dropout layers are disabled; in the
# default training mode predictions would be randomly perturbed.
model.eval()

loop = tqdm(test_dataloader)
# Gradients are not needed for inference.
with torch.no_grad():
    for X in loop:
        X = X.to(device)
        preds = model(X)
        # Index of the highest logit per sample = predicted class id.
        final_result = torch.argmax(preds, dim=1)
        prediction_list.append(final_result.tolist())


100%|██████████| 440/440 [41:00<00:00,  5.59s/it]


In [9]:
# Flatten the per-batch lists into one list of class ids, keeping the order.
flat_list = []
for batch_predictions in prediction_list:
    flat_list.extend(batch_predictions)


In [10]:
# Wrap the flat predictions in a one-column frame and dump it to disk.
# With default `to_csv` settings the index is written too, and the column
# still carries its auto-generated label at this point.
df_prediction = pd.DataFrame(data=flat_list)
df_prediction.to_csv(path_or_buf='test_prediction.csv')

In [11]:
# Give the single prediction column its name, then show (rows, columns).
df_prediction = df_prediction.set_axis(['class_id'], axis=1)
df_prediction.shape


(14071, 1)

In [12]:
# Sanity check: the row count shown here should equal df_prediction's, since
# the two frames are concatenated column-wise next.
test_df.shape

(14071, 5)

In [13]:
# Pair every observation id with the class predicted for the same row.
# pd.concat(axis=1) aligns on the index, so mismatched indexes would silently
# yield NaN-padded rows; fail loudly instead.
assert test_df.index.equals(df_prediction.index), (
    f'index mismatch: {len(test_df)} metadata rows vs '
    f'{len(df_prediction)} predictions')
final_df = pd.concat([ test_df['observation_id'], df_prediction], axis=1)

In [14]:
# Print (rows, columns) as plain text, then display the first rows.
print(final_df.shape)
final_df.head()


(14071, 2)


Unnamed: 0,observation_id,class_id
0,5954638,861
1,3074568,1513
2,37207473,860
3,37280098,995
4,116571,144


In [15]:
# Reduce to one row per observation: where an observation id occurs on several
# rows, only its last row (and that row's prediction) is kept.
# NOTE(review): predictions of the other rows are discarded, not aggregated --
# confirm this is the intended submission rule.
df = final_df.drop_duplicates(subset='observation_id', keep='last')


In [16]:
# Preview the de-duplicated frame.
df.head()

Unnamed: 0,observation_id,class_id
0,5954638,861
1,3074568,1513
2,37207473,860
3,37280098,995
4,116571,144


In [17]:
# (rows, columns) after de-duplication: one row per distinct observation_id.
df.shape

(7811, 2)

In [18]:
# Write the per-observation predictions; the index is left out so the file
# holds only the observation_id and class_id columns.
df.to_csv(path_or_buf='snake_prediction.csv', index=False)