In [1]:
import timm
import torch
from torchvision import transforms
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from training_utils import *
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Image preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each of the three channels as (x - 0.5) / 0.5.
# NOTE(review): 224x224 presumably matches the input size implied by the
# 'vit_base_patch16_224' model name used in this notebook — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Class name -> integer label id. Ids are assigned from the position of each
# name in the tuple below (0 through 12), so the order here is significant.
label_map = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'badger',
        'bird',
        'boar',
        'butterfly',
        'cat',
        'dog',
        'fox',
        'lizard',
        'podolic_cow',
        'porcupine',
        'weasel',
        'wolf',
        'other',
    ))
}

In [3]:
# Pre-trained ViT (base, patch 16, 224 input) from timm, used as a feature extractor.
model_feat = timm.create_model(model_name='vit_base_patch16_224', pretrained=True)

# Resetting the classifier to 0 classes removes the classification head, so the
# model returns features instead of class logits.
model_feat.reset_classifier(0)

# Move the model to the selected device and put it in evaluation mode
# (both calls return the module itself, so they can be chained).
model_feat = model_feat.to(device).eval()

In [4]:
# Logistic-regression classifier with the iteration cap set to 1000.
clf = LogisticRegression(max_iter=1000)

# Evaluate the classifier with the project's cross_validation helper.
# presumably 5 is the number of folds stored under the data directory and the
# helper extracts features with model_feat — confirm in training_utils.
clf_results = cross_validation(
    "data/augmented_fold_split",
    clf,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

# Banner first, then the aggregated metrics.
print("---------------------LogisticRegression-------------------------")
print_cross_validation_results(clf_results)

---------------------LogisticRegression-------------------------
Accuracy: 0.9811 ± 0.0115
Precision: 0.9779 ± 0.0462
Recall: 0.9823 ± 0.0441
f1score: 0.9793 ± 0.0382
Confusion matrix:
 15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0   39.2±1.0    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    0.4±0.5    0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.0±0.0   27.0±0.9    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    0.2±0.4    0.2±0.4    0.0±0.0    0.0±0.0    0.4±0.5  
  0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.4±

In [5]:
# Decision tree with a fixed seed so repeated runs build the same tree.
decision_tree = DecisionTreeClassifier(random_state=42)

# Evaluate the classifier with the project's cross_validation helper.
# presumably 5 is the number of folds stored under the data directory and the
# helper extracts features with model_feat — confirm in training_utils.
decision_tree_results = cross_validation(
    "data/augmented_fold_split",
    decision_tree,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

# Banner first, then the aggregated metrics.
print("---------------------Decision Tree-------------------------")
print_cross_validation_results(decision_tree_results)

---------------------Decision Tree-------------------------
Accuracy: 0.8303 ± 0.0138
Precision: 0.8072 ± 0.1314
Recall: 0.8056 ± 0.1538
f1score: 0.8015 ± 0.1334
Confusion matrix:
 11.6±2.1    0.0±0.0    0.2±0.4    0.2±0.4    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0    0.2±0.4    1.6±1.0    0.0±0.0    1.0±1.5  
  0.0±0.0   34.6±1.6    0.8±0.7    0.0±0.0    1.0±1.1    0.4±0.8    0.8±1.2    0.4±0.5    0.8±0.7    0.4±0.8    0.0±0.0    0.8±0.7  
  0.0±0.0    1.2±1.0   22.8±1.5    0.0±0.0    0.6±0.5    0.0±0.0    2.0±0.6    0.0±0.0    1.4±0.5    0.2±0.4    0.0±0.0    0.0±0.0  
  0.2±0.4    0.2±0.4    0.0±0.0   13.2±0.7    0.0±0.0    0.0±0.0    0.0±0.0    1.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.4±0.5    0.0±0.0    0.2±0.4   13.2±1.3    0.0±0.0    0.4±0.8    0.0±0.0    0.4±0.5    0.0±0.0    0.0±0.0    0.4±0.5  
  0.0±0.0    0.2±0.4    0.4±0.5    0.4±0.5    0.0±0.0   12.4±2.3    0.0±0.0    1.0±0.6    0.2±0.4    0.0±0.0    0.0±0.0    0.4±0.5  
  0.4±0.5    1.2±1.2  

In [6]:
# Nearest-neighbour classifier that predicts from the single closest sample (k=1).
knn = KNeighborsClassifier(n_neighbors=1)

# Evaluate the classifier with the project's cross_validation helper.
# presumably 5 is the number of folds stored under the data directory and the
# helper extracts features with model_feat — confirm in training_utils.
# NOTE(review): the data directory is named "augmented_fold_split"; if augmented
# copies of one source image can land in different folds, 1-NN scores would be
# optimistic — verify how the folds were built.
knn_results = cross_validation(
    "data/augmented_fold_split",
    knn,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

# Banner first, then the aggregated metrics.
print("---------------------KNN k=1-------------------------")
print_cross_validation_results(knn_results)

---------------------KNN k=1-------------------------
Accuracy: 0.9705 ± 0.0147
Precision: 0.9635 ± 0.0584
Recall: 0.9728 ± 0.0681
f1score: 0.9665 ± 0.0533
Confusion matrix:
 14.8±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0  
  0.0±0.0   39.2±1.2    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0    0.2±0.4    0.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.0±0.0   26.6±1.2    0.0±0.0    0.2±0.4    0.0±0.0    0.4±0.5    0.2±0.4    0.0±0.0    0.2±0.4    0.0±0.0    0.4±0.5  
  0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  1.0±0.9    0.2±0.4    0.2±