In [1]:
import timm
import torch
from torchvision import transforms
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from training_utils import *
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Prefer the GPU for ViT feature extraction; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-image preprocessing:
#   1. resize to the fixed 224x224 input of the vit_base_patch16_224 backbone,
#   2. convert the image to a float tensor in [0, 1],
#   3. normalise each channel as (x - 0.5) / 0.5, mapping values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Class name -> integer target id used by the classifiers.
# Each id is the class's position in the tuple below, so append new classes
# at the end to keep existing ids stable.
# NOTE(review): names presumably match the class folders/annotations under
# "data/augmented_fold_split" read by training_utils — confirm.
label_map = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'badger',
        'bird',
        'boar',
        'butterfly',
        'cat',
        'dog',
        'fox',
        'lizard',
        'podolic_cow',
        'porcupine',
        'weasel',
        'wolf',
        'other',
    ))
}

In [3]:
# Pretrained ViT-B/16 (224x224 input) used as the feature extractor that the
# scikit-learn classifiers below are trained on.
model_feat = timm.create_model('vit_base_patch16_224', pretrained=True)

# num_classes=0 replaces the classification head with an identity, so the
# forward pass yields pooled embeddings instead of ImageNet logits.
model_feat.reset_classifier(0)

# eval() switches off train-time behaviour such as dropout; both eval() and
# to() return the module itself, so the calls can be chained.
model_feat = model_feat.eval().to(device)

In [4]:
# Logistic regression on the ViT embeddings. max_iter is raised well above
# scikit-learn's default (100) so the solver has room to converge.
clf = LogisticRegression(max_iter=1000)

# NOTE(review): the positional 5 is presumably the number of CV folds —
# confirm against training_utils.cross_validation.
clf_results = cross_validation(
    "data/augmented_fold_split",
    clf,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------LogisticRegression-------------------------")
print_cross_validation_results(clf_results)

---------------------LogisticRegression-------------------------
Accuracy: 0.9757 ± 0.0061
Precision: 0.9718 ± 0.0569
Recall: 0.9766 ± 0.0577
f1score: 0.9724 ± 0.0451
Confusion matrix:
 10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0   39.0±0.9    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    0.8±1.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.0±0.0   19.0±1.4    0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.8    0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.5  
  0.0±0.0    0.0±0.0    0.0±0.0   10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    9.8±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.4±

In [5]:
# Single decision tree on the same ViT embeddings. random_state fixes the
# estimator's internal randomness (feature permutation at each split) so
# repeated runs give the same trees.
decision_tree = DecisionTreeClassifier(random_state=42)

# NOTE(review): the positional 5 is presumably the number of CV folds —
# confirm against training_utils.cross_validation.
decision_tree_results = cross_validation(
    "data/augmented_fold_split",
    decision_tree,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------Decision Tree-------------------------")
print_cross_validation_results(decision_tree_results)

---------------------Decision Tree-------------------------
Accuracy: 0.8194 ± 0.0180
Precision: 0.7729 ± 0.1679
Recall: 0.7718 ± 0.2062
f1score: 0.7623 ± 0.1786
Confusion matrix:
  8.0±1.7    0.0±0.0    0.6±0.8    0.0±0.0    0.0±0.0    0.2±0.4    0.0±0.0    0.4±0.5    0.0±0.0    0.6±0.8    0.0±0.0    0.2±0.4  
  0.2±0.4   36.0±2.2    0.4±0.5    0.4±0.8    0.6±0.8    0.2±0.4    0.4±0.5    0.4±0.8    1.0±0.9    0.0±0.0    0.4±0.8    0.0±0.0  
  0.0±0.0    0.4±0.5   15.4±1.4    0.0±0.0    0.6±0.5    0.2±0.4    1.6±0.8    0.8±0.7    0.4±0.5    0.6±0.8    0.0±0.0    0.2±0.4  
  0.0±0.0    0.0±0.0    0.0±0.0    9.6±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    8.6±1.2    0.2±0.4    0.4±0.5    0.0±0.0    0.2±0.4    0.0±0.0    0.4±0.5    0.0±0.0  
  0.2±0.4    0.2±0.4    0.6±0.8    0.4±0.5    0.0±0.0    8.4±1.2    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.8±0.4    1.2±0.4  

In [6]:
# 1-nearest-neighbour: each evaluated embedding takes the label of its single
# closest training embedding.
knn = KNeighborsClassifier(n_neighbors=1)

# NOTE(review): the positional 5 is presumably the number of CV folds —
# confirm against training_utils.cross_validation.
knn_results = cross_validation(
    "data/augmented_fold_split",
    knn,
    5,
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------KNN k=1-------------------------")
print_cross_validation_results(knn_results)

---------------------KNN k=1-------------------------
Accuracy: 0.9631 ± 0.0050
Precision: 0.9493 ± 0.0651
Recall: 0.9622 ± 0.0877
f1score: 0.9531 ± 0.0676
Confusion matrix:
  9.6±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.5    0.0±0.0    0.0±0.0  
  0.0±0.0   39.0±0.9    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.5    0.6±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.0±0.0   19.0±1.3    0.0±0.0    0.2±0.4    0.0±0.0    0.4±0.8    0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4  
  0.0±0.0    0.0±0.0    0.0±0.0   10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   10.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  1.2±0.4    0.2±0.4    0.4±