In [1]:
import os

# Move from the notebooks directory up to the project root so that the
# relative paths used later in the notebook (e.g. "checkpoints/...") resolve.
# The guard makes the cell idempotent: an unconditional chdir("..") climbs one
# directory further every time the cell is re-run in the same kernel.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

In [2]:
import sys

# Put the current working directory at the front of the import search path so
# the project packages imported by later cells (`models`, `utils`) resolve.
# The check on the first entry keeps the first run unchanged while stopping
# re-runs of this cell from stacking duplicate entries.
cwd = os.getcwd()
if sys.path[:1] != [cwd]:
    sys.path.insert(0, cwd)
# NOTE(review): this prints absolute local directories; consider clearing the
# output before sharing the notebook.
print(sys.path)

['/storage/ice1/5/4/rso31/meadow', '/storage/ice1/5/4/rso31/meadow/notebooks', '/storage/ice1/5/4/rso31/miniforge3/envs/dml_env/lib/python312.zip', '/storage/ice1/5/4/rso31/miniforge3/envs/dml_env/lib/python3.12', '/storage/ice1/5/4/rso31/miniforge3/envs/dml_env/lib/python3.12/lib-dynload', '', '/storage/ice1/5/4/rso31/miniforge3/envs/dml_env/lib/python3.12/site-packages']


In [3]:
from torchensemble import SnapshotEnsembleClassifier, BaggingClassifier
from torchensemble.utils import io
import torch

In [4]:
from models.base_pretrained import PreTrainedResNet

In [5]:
# Settings forwarded to every PreTrainedResNet built for the ensemble.
out_classes: int = 182  # presumably the number of target classes in the dataset; confirm
model_variant: int = 34  # presumably the ResNet depth; the checkpoint loaded later is named "resnet-34"
freeze_backbone: bool = True  # passed through unchanged; its effect is defined by PreTrainedResNet

In [6]:
# Keyword arguments used to construct each base estimator of the ensemble.
resnet_kwargs = dict(
    out_classes=out_classes,
    variant=model_variant,
    freeze_backbone=freeze_backbone,
)

# Build the snapshot-ensemble wrapper; the trained weights are restored from a
# checkpoint in a later cell.
ensemble_model = SnapshotEnsembleClassifier(
    estimator=PreTrainedResNet,
    estimator_args=resnet_kwargs,
    n_estimators=7,
    cuda=True,  # run this notebook with a GPU!
)

In [7]:
# Checkpoint location, given relative to the current working directory.
checkpoint_path = "checkpoints/resnet-34_20241111-221323_lr2.00e-01_bs128_snapshot7"

# Restore the saved state into `ensemble_model`; the return value is not used.
io.load(ensemble_model, checkpoint_path)

  state = torch.load(save_dir, map_location=map_location)


In [8]:
# Display the ensemble's base estimators to confirm the checkpoint populated them.
ensemble_model.estimators_

ModuleList(
  (0-6): 7 x PreTrainedResNet(
    (accuracy): MulticlassAccuracy()
    (f1_score): MulticlassF1Score()
    (resnet_feat_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel

Loading the checkpoint yields 7 trained models, stored in the `estimators_` attribute of `torchensemble`'s `SnapshotEnsembleClassifier`.

---

In [9]:
# Dataset splits to evaluate, in the order they are reported.
# NOTE(review): the "id_" prefix presumably marks in-distribution splits; confirm.
EVAL_SPLIT_TYPES = ["val", "test", "id_val", "id_test"]

In [10]:
from utils.data import get_iwildcam_datasets, create_loader
from utils.mappings import TFMS_MAP
from tqdm import tqdm
import pandas as pd

In [11]:
# Load the iWildCam data. The helper returns two objects; only the first is
# kept here and the second is discarded.
# NOTE(review): presumably the second value is an unlabeled dataset; confirm.
labeled_data, _ = get_iwildcam_datasets()



In [12]:
# Per-split evaluation results, keyed by split name; filled by the evaluation loop.
agg_res = dict()
# Transforms registered under the "resnet" key, passed to the data loaders.
model_tfms = TFMS_MAP["resnet"]

In [13]:
# Evaluate the loaded ensemble on every split and record the resulting metrics.
for split in EVAL_SPLIT_TYPES:
    loader = create_loader(
        labeled_data,
        subset_type=split,
        tfms=model_tfms,
        batch_size=128,
    )
    assert loader is not None

    # Per-batch pieces live in separately named lists so that the `all_*`
    # names only ever refer to the final concatenated tensors.
    pred_batches = []
    true_batches = []
    metadata_batches = []
    for X, y_true, m in tqdm(loader, desc=f"Collecting predictions for split {split}"):
        # Reduce each batch to its predicted class index right away instead of
        # keeping the full per-class score matrix for the whole split.
        # Assumes `predict` returns a (batch, n_classes) tensor; TODO confirm.
        pred_batches.append(ensemble_model.predict(X).argmax(dim=-1))
        true_batches.append(y_true)
        metadata_batches.append(m)
    all_y_pred = torch.hstack(pred_batches)
    all_y_true = torch.hstack(true_batches)
    all_metadata = torch.vstack(metadata_batches)

    print(f"==={split}===")
    # `eval` returns a pair; only the first element (the metrics printed below) is kept.
    res, _ = labeled_data.eval(all_y_pred, all_y_true, all_metadata)
    print(res)
    agg_res[split] = res

Collecting predictions for split val: 100%|██████████| 117/117 [01:13<00:00,  1.59it/s]


===val===
{'acc_avg': 0.4035158157348633, 'recall-macro_all': 0.1906242998149338, 'F1-macro_all': 0.19060184811571088}


Collecting predictions for split test: 100%|██████████| 335/335 [02:51<00:00,  1.95it/s]


===test===
{'acc_avg': 0.5993316173553467, 'recall-macro_all': 0.16462148596415108, 'F1-macro_all': 0.15080931897809782}


Collecting predictions for split id_val: 100%|██████████| 58/58 [00:31<00:00,  1.86it/s]


===id_val===
{'acc_avg': 0.6707683801651001, 'recall-macro_all': 0.3311155057677788, 'F1-macro_all': 0.3041295430455254}


Collecting predictions for split id_test: 100%|██████████| 64/64 [00:34<00:00,  1.87it/s]

===id_test===
{'acc_avg': 0.608290433883667, 'recall-macro_all': 0.3266538278595561, 'F1-macro_all': 0.32122721048855024}





In [14]:
# Assemble the per-split metrics into a table: one column per split, one row per metric.
df = pd.DataFrame(agg_res)
print("====SUMMARY====")
print(df)

====SUMMARY====
                       val      test    id_val   id_test
acc_avg           0.413742  0.616321  0.695789  0.628281
recall-macro_all  0.180431  0.149897  0.299366  0.295248
F1-macro_all      0.183849  0.150922  0.301406  0.299228
