In [1]:
import timm
import torch
from torchvision import transforms
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from training_utils import *
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image preprocessing for the ViT backbone:
#   1. resize to a fixed 224x224 (aspect ratio is NOT preserved),
#   2. convert to a float tensor in [0, 1],
#   3. normalize each channel with mean=std=0.5, mapping values to [-1, 1].
# NOTE(review): confirm these stats match the pretrained checkpoint, e.g. via
# timm.data.resolve_model_data_config(model_feat).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Class name -> integer label. A class's label is its position in the tuple
# below, so appending new classes keeps existing labels stable.
# NOTE(review): 13 labels are defined, but the printed confusion matrices are
# 7x7 — confirm which classes actually occur in data/fold_split.
label_map = {
    class_name: class_index
    for class_index, class_name in enumerate(
        (
            "badger",
            "bird",
            "boar",
            "butterfly",
            "cat",
            "dog",
            "fox",
            "lizard",
            "podolic_cow",
            "porcupine",
            "weasel",
            "wolf",
            "other",
        )
    )
}

In [3]:
# Pretrained ViT-B/16 used as a frozen feature extractor.
# reset_classifier(0) removes the classification head, so forward() returns the
# pooled embedding instead of class logits; the sklearn classifiers below are
# presumably fit on these embeddings inside training_utils.cross_validation.
model_feat = timm.create_model('vit_base_patch16_224', pretrained=True)
model_feat.reset_classifier(0)

# Switch to inference mode, then move to the selected device
# (both calls return the module itself).
model_feat = model_feat.eval().to(device)

In [4]:
# Logistic regression on the frozen ViT features, evaluated by cross-validation.
# NOTE(review): lbfgs on unscaled embeddings may stop at max_iter without
# converging; the global warnings filter would hide that ConvergenceWarning.
clf = LogisticRegression(max_iter=1000)

clf_results = cross_validation(
    "data/fold_split",  # root of the fold split (layout defined by training_utils)
    clf,
    5,                  # presumably the number of folds
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------LogisticRegression-------------------------")
print_cross_validation_results(clf_results)

---------------------LogisticRegression-------------------------
Accuracy: 0.9603 ± 0.0083
Precision: 0.9196 ± 0.1261
Recall: 0.9009 ± 0.1705
f1score: 0.8960 ± 0.1385
Confusion matrix:
 39.2±1.0    0.0±0.0    0.0±0.0    0.6±0.8    0.0±0.0    0.0±0.0    0.2±0.4  
  0.0±0.0   27.4±1.0    0.0±0.0    0.4±0.5    0.0±0.0    0.2±0.4    0.2±0.4  
  0.0±0.0    0.0±0.0    2.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.4±0.8    0.6±0.8    0.0±0.0   10.6±1.0    0.0±0.0    0.6±0.5    0.0±0.0  
  0.0±0.0    0.4±0.8    0.0±0.0    0.6±0.8   59.8±1.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.4±0.5    0.0±0.0    3.4±0.5    0.4±0.8  
  0.0±0.0    0.0±0.0    0.0±0.0    0.8±0.7    0.2±0.4    0.0±0.0    2.4±1.2  


In [5]:
# Single decision tree on the same ViT features; random_state pins tie-breaking
# between equally good splits so reruns are reproducible.
decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree_results = cross_validation(
    "data/fold_split",  # root of the fold split (layout defined by training_utils)
    decision_tree,
    5,                  # presumably the number of folds
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------Decision Tree-------------------------")
print_cross_validation_results(decision_tree_results)

---------------------Decision Tree-------------------------
Accuracy: 0.8400 ± 0.0182
Precision: 0.6282 ± 0.3325
Recall: 0.6570 ± 0.3268
f1score: 0.6245 ± 0.3163
Confusion matrix:
 35.2±1.8    0.8±1.0    0.4±0.8    1.0±0.6    1.8±1.7    0.2±0.4    0.6±0.8  
  1.2±0.7   24.0±1.4    0.2±0.4    1.0±1.1    0.8±0.7    0.6±0.8    0.4±0.5  
  0.2±0.4    0.2±0.4    1.2±1.2    0.2±0.4    0.2±0.4    0.2±0.4    0.2±0.4  
  0.2±0.4    0.8±0.4    0.2±0.4    7.4±1.6    1.2±1.5    1.6±1.0    0.8±0.7  
  1.0±2.0    1.8±1.7    0.8±0.7    0.8±0.4   55.6±2.6    0.0±0.0    0.8±1.2  
  0.0±0.0    0.2±0.4    0.2±0.4    0.2±0.4    0.4±0.5    3.2±0.4    0.0±0.0  
  0.6±0.8    0.8±0.7    0.0±0.0    1.0±0.6    0.0±0.0    0.6±0.5    0.4±0.5  


In [6]:
# 1-nearest-neighbour baseline: each sample takes the label of its single
# closest training embedding.
knn = KNeighborsClassifier(n_neighbors=1)
knn_results = cross_validation(
    "data/fold_split",  # root of the fold split (layout defined by training_utils)
    knn,
    5,                  # presumably the number of folds
    device,
    transform,
    label_map,
    model_feat,
)

print("---------------------KNN k=1-------------------------")
print_cross_validation_results(knn_results)

---------------------KNN k=1-------------------------
Accuracy: 0.9550 ± 0.0128
Precision: 0.9094 ± 0.1396
Recall: 0.8914 ± 0.1519
f1score: 0.8895 ± 0.1308
Confusion matrix:
 39.4±0.8    0.0±0.0    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0    0.4±0.5  
  0.0±0.0   27.6±1.0    0.0±0.0    0.2±0.4    0.0±0.0    0.2±0.4    0.2±0.4  
  0.0±0.0    0.0±0.0    2.4±0.5    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.4±0.5    0.6±0.8    0.0±0.0   10.4±0.8    0.0±0.0    0.6±0.8    0.2±0.4  
  0.2±0.4    1.4±1.2    0.0±0.0    0.2±0.4   59.0±1.1    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.6±0.5    0.0±0.0    3.2±0.4    0.4±0.8  
  0.2±0.4    0.4±0.5    0.0±0.0    0.2±0.4    0.0±0.0    0.2±0.4    2.4±1.0  
