In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms as pth_transforms

import numpy as np
from PIL import Image

from tqdm import *

import time
import timm

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from GPR1200 import GPR1200

### Define models that should be evaluated with timm

In [2]:
model_list = [
             "resnetv2_101x1_bitm",
             "resnetv2_101x1_bitm_in21k",
             "resnetv2_101x3_bitm",
             "resnetv2_101x3_bitm_in21k",
             "tf_efficientnetv2_l",
             "tf_efficientnetv2_l_in21ft1k",
             "tf_efficientnetv2_l_in21k",
             "vit_base_patch16_224",
             "vit_base_patch16_224_in21k",
             "vit_large_patch16_224",
             "vit_large_patch16_224_in21k",
             "deit_base_patch16_224",
             "deit_base_distilled_patch16_224",
             "swin_base_patch4_window7_224",
             "swin_base_patch4_window7_224_in22k",
             "swin_large_patch4_window7_224",
             "swin_large_patch4_window7_224_in22k"
            ]

### Create Dataset Class and GPR1200 Dataset Object

In [3]:
class TestDataset(torch.utils.data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, file_paths):
        'Initialization'
        self.file_paths = file_paths
        
  def __len__(self):
        'Denotes the total number of samples'
        return len(self.file_paths)

  def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        return ppc_image(self.file_paths[index])

In [4]:
GPR1200_dataset = GPR1200("/media/Data/images/GPR10x1200/images")
image_filepaths = GPR1200_dataset.image_files

### Start Evaluation of selected models

In [None]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")


for m_name in model_list:
    
    # create models and their respective preprocessing chain
    bb_model = timm.create_model(m_name, pretrained=True, num_classes=0)
    data_config = resolve_data_config({}, model=bb_model)
    transform = create_transform(**data_config)
    
    bb_model.to(device)
    bb_model.eval()
    
    
    # Preprocessing that will be run on each individuall test image
    def ppc_image(path):
    
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        img = transform(img)

        return img
    
    # dataloader parameters
    batch_size = 32
    params = {'batch_size': batch_size,
          'shuffle': False,
          'num_workers': 6}

    gpr1200_loader = torch.utils.data.DataLoader(TestDataset(image_filepaths), **params)
    
    
    # some addtional info
    time_start = time.time()
    fv_list = []
    
    pbar = tqdm(enumerate(gpr1200_loader), position=0, leave=True, total=(int(len(image_filepaths) / batch_size)))
    
    with torch.set_grad_enabled(False):
        for i, local_batch in pbar:

            local_batch = local_batch.to(device)
            fv = bb_model(local_batch)
                
            fv = fv / torch.norm(fv, dim=-1, keepdim=True)
           
            fv_list += list(fv.cpu().numpy())
            pbar.update()
    
        print(fv.shape)
    
    # display some addtional info
    fv_list = np.array(fv_list).astype(float)
    print("---------name: {} -- dim: {}---------".format(m_name, fv_list.shape))
    time_needed = np.round((time.time() - time_start) / len(image_filepaths) * 1000, 2)
    dim = fv_list.shape[-1]
    input_size = data_config["input_size"]
    
    
    # run this line to evaluate dataset embeddings
    gpr, lm, iNat, ims, instre, sop, faces = GPR1200_dataset.evaluate(fv_list, compute_partial=True)
    print("GPR1200 mAP: {}".format(gpr))
    print("Landmarks: {}, IMSketch: {}, iNat: {}, Instre: {}, SOP: {}, faces: {}".format(lm, ims, iNat, instre, sop, faces))
    print()
    
    del bb_model

100%|██████████| 375/375 [01:41<00:00,  3.70it/s]


torch.Size([32, 2048])
---------name: resnetv2_101x1_bitm -- dim: (12000, 2048)---------
GPR1200 mAP: 0.5559
Landmarks: 0.8221, IMSketch: 0.4709, iNat: 0.4298, Instre: 0.5292, SOP: 0.861, faces: 0.2227



100%|██████████| 375/375 [00:24<00:00, 15.01it/s]


torch.Size([32, 2048])
---------name: resnetv2_101x1_bitm_in21k -- dim: (12000, 2048)---------
GPR1200 mAP: 0.5494
Landmarks: 0.8112, IMSketch: 0.4113, iNat: 0.4197, Instre: 0.5181, SOP: 0.8695, faces: 0.2668



100%|██████████| 375/375 [08:01<00:00,  1.28s/it]


torch.Size([32, 6144])
---------name: resnetv2_101x3_bitm -- dim: (12000, 6144)---------
GPR1200 mAP: 0.5694
Landmarks: 0.8297, IMSketch: 0.5292, iNat: 0.4012, Instre: 0.5564, SOP: 0.8722, faces: 0.2273



100%|██████████| 375/375 [02:14<00:00,  2.79it/s]


torch.Size([32, 6144])
---------name: resnetv2_101x3_bitm_in21k -- dim: (12000, 6144)---------
