In [1]:
import sys
sys.path.append('/absolute-pathtofolder/snakeCLEF/training_scripts')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

from src.core import models, training, data

DATA_DIR = '../../'

MODEL_ARCH_V_NV='vit_small_384'
MODEL_NAME_V_NV='vornot-clef2023_vit_small_384_efocal_05-22-2023_16-25-34'

MODEL_ARCH_V = 'vit_small_384'
MODEL_NAME_V ='venom-clef2023_vit_small_384_ensemble_focal_05-22-2023_05-09-10'

MODEL_ARCH_NV='vit_small_384'
MODEL_NAME_NV ='non_v-clef2023_vit_small_384_ensemble_focal_05-22-2023_12-38-44'


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


In [2]:
# load test metadata
test_df = pd.read_csv('../../snake_csv_files/SnakeCLEF2023-PubTestMetadata.csv')

print(f'Test set length: {len(test_df):,d}')
test_df.head()


Test set length: 14,071


Unnamed: 0,observation_id,captive,endemic,code,image_path
0,5954638,False,False,US,pubtest/7437408.jpg
1,3074568,False,False,ID,pubtest/3543837.jpg
2,37207473,False,False,US,pubtest/58945588.jpg
3,37280098,False,False,TH,pubtest/59067915.jpg
4,116571,False,False,GY,pubtest/166608.JPG


In [3]:
#test dataset 

class SnakeInferenceDataset(Dataset):
    def __init__(self, data, transform = None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        image = self.data.iloc[index]
        img = Image.open(DATA_DIR+image.image_path).convert("RGB")

        if transform is not None:
            img = self.transform(img)

        return img

In [4]:
# create fine-tuned network
model_nv = models.get_model(MODEL_ARCH_NV, 1457, pretrained=False)
training.load_model(model_nv, MODEL_NAME_NV, path='../results/models/')

model_v = models.get_model(MODEL_ARCH_V, 327, pretrained=False)
training.load_model(model_v, MODEL_NAME_V, path='../results/models/')

model_v_nv = models.get_model(MODEL_ARCH_V_NV, 2, pretrained=False)
training.load_model(model_v_nv, MODEL_NAME_V_NV, path='../results/models/')

'''
All three models are same so config is same
'''

model_config = model_v_nv.pretrained_config

batch_size = 1


def get_transforms(*, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Create basic image transforms for training or validation dataset."""
    train_tfms = T.Compose([
        T.RandomResizedCrop((size, size), scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomApply(torch.nn.ModuleList([
            T.ColorJitter(brightness=0.2, contrast=0.2)
        ]), p=0.2),  # random brightness contrast
        T.ToTensor(),
        T.Normalize(mean, std)])
    valid_tfms = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean, std)])
    return train_tfms, valid_tfms


# create transforms
_, test_tfms = get_transforms(
    size=model_config['input_size'], mean=model_config['image_mean'],
    std=model_config['image_std'])

from torchvision import transforms

transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.Resize((384, 384)),
    transforms.ToTensor()
])

In [17]:
'''
CPU run
'''
prediction_list = []

test_dataset = SnakeInferenceDataset(test_df, transform = transform) 
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

loop=tqdm(test_dataloader)
for batch, X in enumerate(loop):

  device = torch.device(device)
  X = X.to(device)
  
  with torch.no_grad():
    preds = model_v_nv(X)
    final_result = torch.argmax(preds, axis=1)

    if final_result==1:
      pred = model_v(X)
      res = torch.argmax(pred, axis=1)

    elif final_result==0:
      pred = model_nv(X)
      res = torch.argmax(pred, axis=1)

    prediction_list.append(res.tolist())


100%|██████████| 14071/14071 [1:11:44<00:00,  3.27it/s]


In [18]:
flat_list = [item for sublist in prediction_list for item in sublist]


In [19]:
df_prediction = pd.DataFrame(flat_list)
df_prediction.to_csv('test_prediction.csv')

In [20]:
df_prediction.columns = ['class_id']
df_prediction.shape


(14071, 1)

In [21]:
test_df.shape

(14071, 5)

In [22]:
final_df = pd.concat([ test_df['observation_id'], df_prediction], axis=1)

In [23]:
print(final_df.shape)
final_df.head()


(14071, 2)


Unnamed: 0,observation_id,class_id
0,5954638,725
1,3074568,1249
2,37207473,709
3,37280098,230
4,116571,118


In [24]:
df = final_df.drop_duplicates('observation_id', keep='last')


In [25]:
df.head()

Unnamed: 0,observation_id,class_id
0,5954638,725
1,3074568,1249
2,37207473,709
3,37280098,230
4,116571,118


In [26]:
df.shape

(7811, 2)

In [27]:
df.to_csv('snake_prediction.csv', index=False)