In [6]:
%load_ext autoreload
%autoreload 2
import os
# NOTE(review): absolute, machine-specific path. Presumably read by terra/spr when
# they are imported, so it must be set before the `spr` imports below — confirm,
# and consider reading it from the environment or a relative config instead.
os.environ["TERRA_CONFIG_PATH"] = "/home/sabri/code/spr-21/terra_config.json"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import robustnessgym as rg

In [8]:
import torch
import torchvision
import torch.nn as nn
from spr.data.iwildcam import iwildcam_task_config
def initialize_torchvision_model(name, d_out, **kwargs):
    """Build a torchvision backbone and replace its final layer.

    Args:
        name: One of ``'wideresnet50'``, ``'densenet121'``, ``'resnet50'`` or
            ``'resnet34'``.
        d_out: Width of the new final ``nn.Linear`` layer. If ``None``, the
            final layer is replaced by ``nn.Identity`` so the model returns the
            features that fed the original last layer (a featurizer).
        **kwargs: Forwarded unchanged to the torchvision constructor.

    Returns:
        The model, with an extra ``d_out`` attribute holding its output width.

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    # short name -> (torchvision constructor name, attribute holding the last layer)
    architectures = {
        'wideresnet50': ('wide_resnet50_2', 'fc'),
        'densenet121': ('densenet121', 'classifier'),
        'resnet50': ('resnet50', 'fc'),
        'resnet34': ('resnet34', 'fc'),
    }
    if name not in architectures:
        raise ValueError(f'Torchvision model {name} not recognized')
    constructor_name, last_layer_name = architectures[name]

    # Construct the default model, which still carries its original last layer.
    model = getattr(torchvision.models, constructor_name)(**kwargs)
    d_features = getattr(model, last_layer_name).in_features

    if d_out is None:
        # Featurizer: forward the penultimate features unchanged. Uses the
        # built-in nn.Identity (a bare `Identity` is not defined in this notebook).
        last_layer = nn.Identity()
        model.d_out = d_features
    else:
        # Classifier head for a particular number of classes.
        last_layer = nn.Linear(d_features, d_out)
        model.d_out = d_out
    setattr(model, last_layer_name, last_layer)
    return model

from spr.vision import Classifier

# NOTE(review): machine-specific absolute path; the checkpoint stores the model
# weights under the "algorithm" key (as indexed below).
IWILDCAM_CHECKPOINT = "/home/common/datasets/iwildcam_v2.0/best_model.pth"

model = initialize_torchvision_model(
    "resnet50",
    d_out=iwildcam_task_config["num_classes"],
)
classifier = Classifier(model=model, metrics=["accuracy"], config=iwildcam_task_config)
# map_location="cpu" lets a GPU-saved checkpoint load on hosts without that GPU;
# load_state_dict then copies the tensors into the module's existing parameters.
# torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(IWILDCAM_CHECKPOINT, map_location="cpu")
# NOTE(review): `classifier` is rebound by get_iwildcam_model() in In [14];
# cells executed after that use that model instead of this one.
classifier.load_state_dict(state_dict["algorithm"])

<All keys matched successfully>

In [14]:
from spr.data.iwildcam import get_iwildcam_model
# Rebinds `classifier`, replacing the one built manually in In [8].
# NOTE(review): presumably constructs the same resnet50 + checkpoint as In [8] —
# confirm; if so, one of the two cells is redundant and should be removed.
classifier = get_iwildcam_model()



In [15]:
from spr.data.iwildcam import build_iwildcam_df, iwildcam_task_config
from spr.vision import score

# Evaluate the current `classifier` on the "id_valid" split (presumably the
# in-distribution validation split in WILDS terminology — confirm).
# NOTE(review): an earlier experiment overrode iwildcam_task_config["img_transform"]
# with a `transform` object that is not defined in the cells shown here.
out = score(
    split="id_valid",
    batch_size=256,
    model=classifier,
    data_df=build_iwildcam_df.out(),
    **iwildcam_task_config
)

task: score, run_id=39
Global seed set to 123
[2021-04-23 23:14:38,910][INFO][lightning:54] :: Global seed set to 123
GPU available: True, used: True
[2021-04-23 23:14:39,025][INFO][lightning:55] :: GPU available: True, used: True
TPU available: None, using: 0 TPU cores
[2021-04-23 23:14:39,026][INFO][lightning:55] :: TPU available: None, using: 0 TPU cores


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'valid_accuracy': 0.8341536521911621,
 'valid_loss': 1.0643936395645142,
 'valid_macro_f1': 0.20122647285461426,
 'valid_macro_recall': 0.20635363459587097}
--------------------------------------------------------------------------------


In [10]:
from spr.data.iwildcam import iwildcam_task_config
from spr.data.iwildcam import build_iwildcam_df
from spr.vision import score

# Evaluate `classifier` on the "valid" split, with the same settings as the
# "id_valid" cell except for the split name.
# NOTE(review): in the saved outputs this cell ran as In [10], i.e. before In [14]
# rebound `classifier`, so its numbers come from the model loaded in In [8].
# Restart & Run All before comparing against the "id_valid" results.
out = score(
    model=classifier, data_df=build_iwildcam_df.out(),
    batch_size=256, split="valid",
    **iwildcam_task_config
)

task: score, run_id=38
Global seed set to 123
[2021-04-23 22:59:24,451][INFO][lightning:54] :: Global seed set to 123
GPU available: True, used: True
[2021-04-23 22:59:24,667][INFO][lightning:55] :: GPU available: True, used: True
TPU available: None, using: 0 TPU cores
[2021-04-23 22:59:24,668][INFO][lightning:55] :: TPU available: None, using: 0 TPU cores


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'valid_accuracy': 0.6154000163078308,
 'valid_loss': 1.8163567781448364,
 'valid_macro_f1': 0.14952275156974792,
 'valid_macro_recall': 0.15714137256145477}
--------------------------------------------------------------------------------
