In [None]:
import timm
import torch
from torchvision import transforms
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from training_utils import *
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer CUDA when a GPU is visible to torch; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every image before it reaches the ViT backbone:
#   1. resize straight to 224x224 (a fixed (h, w) size, so aspect ratio is not kept),
#   2. convert the image to a float tensor (ToTensor),
#   3. normalize each of the 3 channels as (x - 0.5) / 0.5.
# NOTE(review): timm's own eval pipeline for vit_base_patch16_224 may use a
# different interpolation / center-crop; compare with
# timm.data.resolve_data_config(model=...) if preprocessing parity matters.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Class name -> integer label. Ids follow the position in this list
# (badger=0 ... other=12), so append new classes at the end to keep ids stable.
# NOTE(review): the confusion matrices printed by the CV cells are 7x7 while
# 13 labels are defined here; confirm how classes absent from the data are
# handled by the reporting helper.
label_map = {
    class_name: class_id
    for class_id, class_name in enumerate(
        [
            'badger', 'bird', 'boar', 'butterfly', 'cat', 'dog', 'fox',
            'lizard', 'podolic_cow', 'porcupine', 'weasel', 'wolf', 'other',
        ]
    )
}

In [3]:
# Pretrained ViT-B/16 used as a feature extractor, not a classifier:
# reset_classifier(0) drops the classification head, so the forward pass yields
# the backbone's pooled features, which the sklearn models below consume.
# eval() switches off training-only behaviour such as dropout.
# NOTE(review): gradients are not disabled in this cell; presumably
# cross_validation runs inference under torch.no_grad() - confirm.
model_feat = timm.create_model('vit_base_patch16_224', pretrained=True)
model_feat.reset_classifier(0)
model_feat = model_feat.to(device).eval()

In [4]:
# Logistic regression on the ViT features, scored by cross_validation over the
# splits in data/fold_split (5 is presumably the number of folds).
# max_iter is raised above sklearn's default of 100 so lbfgs has room to converge.
# NOTE(review): warnings are globally ignored in the first cell, so a
# ConvergenceWarning from this model would not be visible.
clf = LogisticRegression(max_iter=1000)
clf_results = cross_validation(
    "data/fold_split", clf, 5, device, transform, label_map, model_feat
)

print("---------------------LogisticRegression-------------------------")
print_cross_validation_results(clf_results)

---------------------LogisticRegression-------------------------
Accuracy: 0.9389 ± 0.0153
Precision: 0.7918 ± 0.2899
Recall: 0.8090 ± 0.3025
f1score: 0.7929 ± 0.2894
Confusion matrix:
 38.8±1.2    0.0±0.0    0.0±0.0    0.6±0.8    0.0±0.0    0.0±0.0    0.6±0.8  
  0.2±0.4   18.8±0.7    0.0±0.0    0.2±0.4    0.2±0.4    0.4±0.5    0.4±0.5  
  0.0±0.0    0.0±0.0    2.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.4±0.5    0.0±0.0    8.8±0.7    0.2±0.4    0.4±0.5    0.2±0.4  
  0.0±0.0    0.8±0.7    0.0±0.0    0.2±0.4   54.6±1.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.8±0.7    0.2±0.4    1.0±0.9    0.4±0.5  
  0.2±0.4    0.4±0.5    0.0±0.0    1.0±1.5    0.2±0.4    0.0±0.0    1.6±1.2  


In [5]:
# Single decision tree on the same ViT features and folds as the other models.
# random_state is fixed so the tree's randomized split search, and therefore
# the reported scores, are reproducible across runs.
decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree_results = cross_validation(
    "data/fold_split", decision_tree, 5, device, transform, label_map, model_feat
)

print("---------------------Decision Tree-------------------------")
print_cross_validation_results(decision_tree_results)

---------------------Decision Tree-------------------------
Accuracy: 0.8286 ± 0.0230
Precision: 0.6048 ± 0.3386
Recall: 0.5814 ± 0.3005
f1score: 0.5846 ± 0.3103
Confusion matrix:
 36.4±1.4    0.8±0.4    0.0±0.0    1.2±1.2    0.4±0.5    0.2±0.4    1.0±0.9  
  0.8±0.7   15.6±1.0    0.0±0.0    1.4±0.8    1.6±0.5    0.2±0.4    0.6±0.5  
  0.6±0.5    0.0±0.0    1.0±0.6    0.4±0.5    0.2±0.4    0.0±0.0    0.2±0.4  
  1.2±0.7    0.4±0.5    0.2±0.4    5.4±0.5    1.2±0.4    0.2±0.4    1.6±1.0  
  0.4±0.5    2.4±1.9    0.2±0.4    0.2±0.4   51.4±2.2    0.4±0.8    0.6±0.8  
  0.2±0.4    0.4±0.5    0.0±0.0    0.8±0.7    0.2±0.4    0.8±0.4    0.0±0.0  
  0.2±0.4    1.4±1.0    0.0±0.0    0.8±0.7    0.0±0.0    0.4±0.8    0.6±0.5  


In [6]:
# 1-nearest-neighbour baseline: each evaluated sample takes the label of its
# single closest training sample in feature space (sklearn's default metric).
# Uses the same folds, preprocessing and ViT features as the other models.
knn = KNeighborsClassifier(n_neighbors=1)
knn_results = cross_validation(
    "data/fold_split", knn, 5, device, transform, label_map, model_feat
)

print("---------------------KNN k=1-------------------------")
print_cross_validation_results(knn_results)

---------------------KNN k=1-------------------------
Accuracy: 0.9419 ± 0.0099
Precision: 0.8512 ± 0.1917
Recall: 0.8357 ± 0.2060
f1score: 0.8317 ± 0.1906
Confusion matrix:
 39.4±0.8    0.0±0.0    0.0±0.0    0.4±0.5    0.0±0.0    0.0±0.0    0.2±0.4  
  0.0±0.0   19.4±0.5    0.0±0.0    0.2±0.4    0.0±0.0    0.2±0.4    0.4±0.5  
  0.0±0.0    0.0±0.0    2.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.8±0.4    0.0±0.0    7.6±1.0    0.4±0.5    0.8±0.7    0.4±0.5  
  0.0±0.0    1.0±0.6    0.0±0.0    0.4±0.5   54.2±0.7    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.6±0.8    0.0±0.0    1.4±0.5    0.4±0.5  
  0.2±0.4    0.6±0.5    0.0±0.0    0.4±0.5    0.0±0.0    0.2±0.4    2.0±0.6  
