In [1]:
import torch
from training_utils import *
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reproducibility: fix PyTorch's RNG so torch-based randomness (presumably weight
# init, DataLoader shuffling, dropout inside training_utils) repeats across
# "Restart & Run All". torch.manual_seed seeds the CPU generator and all CUDA devices.
# NOTE(review): if training_utils also draws from `random` or `numpy`, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as the last statement so the cell does not display the Generator repr.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Image preprocessing pipeline shared by training and evaluation:
#   1. resize every image to a fixed 224x224 input,
#   2. convert it to a float tensor (ToTensor scales pixel values to [0, 1]),
#   3. normalize each of the three channels with mean 0.5 and std 0.5,
#      which maps [0, 1] onto [-1, 1].
# NOTE(review): the 3-channel Normalize assumes RGB input images — TODO confirm
# grayscale/RGBA images are converted before reaching this transform.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
# Class name -> integer target index. Indices are assigned in the order of the
# tuple below (0..12); the catch-all 'other' class comes last.
# Presumably the names match the dataset's label strings / folder names — TODO confirm.
# NOTE(review): the confusion matrices printed by the training cells have 12
# columns while this map defines 13 classes — check whether 'other' occurs in the data.
label_map = {
    class_name: class_idx
    for class_idx, class_name in enumerate((
        'badger', 'bird', 'boar', 'butterfly', 'cat', 'dog', 'fox',
        'lizard', 'podolic_cow', 'porcupine', 'weasel', 'wolf', 'other',
    ))
}

### Training the ViT model with a frozen backbone

In [5]:
# Cross-validate the ViT classifier with its backbone frozen (frozen=True),
# training at most 5 epochs per run, and save a checkpoint to vitdet_frozen.pt.
frozen_results = nn_cross_validation(
    "data/augmented_fold_split",  # presumably the directory holding the pre-built folds — TODO confirm
    5,                            # presumably the number of folds — TODO confirm
    device,
    transform,
    label_map,
    num_epochs=5,
    frozen=True,
    save_model_path="vitdet_frozen.pt",
)

# Print the aggregated metrics (mean ± std) and the averaged confusion matrix.
print("---------------------VitDet finetuned frozen backbone-------------------------")
print_cross_validation_results(frozen_results)

Model saved at vitdet_frozen.pt with f1score 0.9596
Early stopping at epoch 2
Early stopping at epoch 1
Early stopping at epoch 1
Early stopping at epoch 1
---------------------VitDet finetuned frozen backbone-------------------------
Accuracy: 0.9492 ± 0.0114
Precision: 0.9390 ± 0.0876
Recall: 0.9367 ± 0.1087
f1score: 0.9322 ± 0.0177
Confusion matrix:
 14.2±1.2    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.2±0.4    0.0±0.0    0.6±1.2  
  0.0±0.0   39.2±1.2    0.0±0.0    0.2±0.4    0.0±0.0    0.2±0.4    0.2±0.4    0.2±0.4    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.2±0.4    0.2±0.4   26.6±1.4    0.0±0.0    0.0±0.0    0.0±0.0    0.6±0.8    0.2±0.4    0.0±0.0    0.2±0.4    0.0±0.0    0.2±0.4  
  0.0±0.0    0.0±0.0    0.0±0.0   15.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0  
  0.0±0.0    0.0±0.0    0.0±0.0    0.0±0.0   14.8±0.4    0.0±0.0    0.0±0.0    0.2±0.4    0.0±0.0    0.0±0.0    0

### Training the ViT model with all weights updated (full fine-tuning)

In [8]:
# Cross-validate the ViT classifier with every weight trainable (frozen=False),
# using the same data, folds, preprocessing and epoch budget as the frozen run.
# NOTE(review): unlike the frozen run, no save_model_path is passed here.
# NOTE(review): this run scores far below the frozen-backbone run (accuracy ~0.46
# vs ~0.95). Full fine-tuning usually needs a much smaller learning rate — check
# what nn_cross_validation uses when frozen=False.
results = nn_cross_validation(
    "data/augmented_fold_split",  # presumably the directory holding the pre-built folds — TODO confirm
    5,                            # presumably the number of folds — TODO confirm
    device,
    transform,
    label_map,
    num_epochs=5,
    frozen=False,
)

# Print the aggregated metrics (mean ± std) and the averaged confusion matrix.
print("---------------------VitDet finetuned-------------------------")
print_cross_validation_results(results)

Early stopping at epoch 3
Early stopping at epoch 1
Early stopping at epoch 1
Early stopping at epoch 1
---------------------VitDet finetuned-------------------------
Accuracy: 0.4614 ± 0.1222
Precision: 0.2852 ± 0.3199
Recall: 0.3490 ± 0.3897
f1score: 0.2811 ± 0.3043
Confusion matrix:
  8.0±6.8    0.0±0.0    0.0±0.0    6.0±7.3    0.0±0.0    0.0±0.0    0.4±0.8    0.0±0.0    0.0±0.0    0.0±0.0    0.6±1.2    0.0±0.0  
  0.4±0.5   35.4±2.6    0.0±0.0    0.4±0.5    0.2±0.4    0.0±0.0    0.4±0.8    0.8±1.0    2.0±2.8    0.2±0.4    0.2±0.4    0.0±0.0  
  1.8±2.4    4.6±2.6    2.8±5.6    2.0±2.3    0.0±0.0    0.4±0.8    3.6±7.2    0.0±0.0   11.6±6.7    0.2±0.4    1.2±1.2    0.0±0.0  
  3.0±6.0    0.0±0.0    0.0±0.0   10.2±5.4    0.0±0.0    0.0±0.0    0.0±0.0    1.2±1.2    0.0±0.0    0.6±1.2    0.0±0.0    0.0±0.0  
  1.0±2.0    7.8±4.1    0.0±0.0    1.0±1.3    1.4±2.8    0.8±1.6    0.4±0.8    1.4±2.0    1.0±0.9    0.0±0.0    0.2±0.4    0.0±0.0  
  0.0±0.0    7.6±6.3    0.0±0.0    0.0±0.0    0.