In [5]:
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
import torchvision.transforms as transforms
import torch
# import torch.nn as nn
# from tqdm import tqdm
from erm_helpers import val_step, build_metrics_dict
# import sys
# import os
# import time
from Classifier import Classifier
from SubsampledDataset import SubsampledDataset, NUM_CLASSES


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Mini-batch size used by the evaluation loaders built below.
BATCH_SIZE = 16

cuda


In [3]:
# FMoW from WILDS; download=False means the data is expected to already be on disk.
dataset = get_dataset(dataset="fmow", download=False)
# Groups are defined by the "region" metadata field.
grouper = CombinatorialGrouper(dataset, ["region"])

# Evaluation only needs each image converted to a tensor -- no augmentation.
eval_transform = transforms.Compose([transforms.ToTensor()])

# "val" is the out-of-distribution validation split, "id_val" the in-distribution one.
ood_val_data = dataset.get_subset("val", transform=eval_transform)
id_val_data = dataset.get_subset("id_val", transform=eval_transform)

# Wrap each split with the project's region-aware subsampler.
ood_val_dataset = SubsampledDataset(ood_val_data, grouper)
id_val_dataset = SubsampledDataset(id_val_data, grouper)

ood_val_loader = get_eval_loader("standard", ood_val_dataset, batch_size=BATCH_SIZE)
id_val_loader = get_eval_loader("standard", id_val_dataset, batch_size=BATCH_SIZE)

In [14]:
def eval_model(model, loss_fn, loaders=None):
    """Evaluate ``model`` on each validation split and print its metrics.

    Args:
        model: the network to evaluate. It is switched to eval mode first.
        loss_fn: loss criterion forwarded to ``val_step``
            (e.g. ``torch.nn.CrossEntropyLoss()``).
        loaders: optional mapping of printed split title -> data loader.
            Defaults to the notebook's ``ood_val_loader`` ("ood validation")
            followed by ``id_val_loader`` ("id validation").

    Returns:
        dict mapping each split title to the metrics dict produced by
        ``build_metrics_dict``. Previously these results were discarded.
    """
    if loaders is None:
        loaders = {
            'ood validation': ood_val_loader,
            'id validation': id_val_loader,
        }

    # NOTE(review): it is not visible here whether val_step switches to eval
    # mode itself. eval() is idempotent, and it prevents a model restored via
    # torch.load from being scored with training-mode dropout/batch-norm.
    model.eval()

    # Score every split first, then print, matching the original output order
    # (progress bars first, then the metric summaries).
    results = {}
    for split_title, loader in loaders.items():
        y_true, y_pred, metadata, loss = val_step(model, loader, loss_fn, device)
        results[split_title] = build_metrics_dict(dataset, y_true, y_pred, metadata, loss)

    for split_title, metrics in results.items():
        print(split_title)
        print(', '.join(f'{key} - {value:.8f}' for key, value in metrics.items()))

    return results

In [7]:
# Build a fresh classifier head for NUM_CLASSES outputs and move it to the selected device.
# NOTE(review): the checkpoint-evaluation cell reassigns `model` via torch.load,
# so this instance is not the one that gets evaluated.
model = Classifier(NUM_CLASSES).to(device)

In [15]:
# (printed title, checkpoint path) pairs to compare. Each checkpoint appears to be
# a whole pickled model object rather than a state_dict.
CHECKPOINTS = [
    ("ERM", "models/ERM_10_15_SGD_0.0001_0.9_CrossEntropy.pth"),
    ("ERM with pretrained Unet", "models/ERM_10_10_SGD_0.0001_0.9_CrossEntropy_pretrained_unet.pth"),
]

for title, checkpoint_path in CHECKPOINTS:
    # map_location puts the restored tensors on `device`, the same device
    # eval_model evaluates on, even if the checkpoint was saved elsewhere
    # (e.g. trained on GPU, evaluated on a CPU-only machine).
    # SECURITY: torch.load unpickles arbitrary objects -- only load trusted files.
    # NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True; loading a
    # full model object then requires passing weights_only=False explicitly.
    model = torch.load(checkpoint_path, map_location=device)
    print(f'\n{title}')
    eval_model(model, loss_fn=torch.nn.CrossEntropyLoss())


ERM


100%|██████████| 363/363 [00:33<00:00, 10.70it/s]
100%|██████████| 209/209 [00:18<00:00, 11.15it/s]


ood validation
acc_region:Europe - 0.59820282, acc_region:Americas - 0.49238202, loss - 1.40786018
id validation
acc_region:Europe - 0.52923834, acc_region:Americas - 0.55657494, loss - 1.40919449
ERM with pretrained Unet


100%|██████████| 363/363 [00:30<00:00, 11.84it/s]
100%|██████████| 209/209 [00:18<00:00, 11.48it/s]

ood validation
acc_region:Europe - 0.57445443, acc_region:Americas - 0.46302488, loss - 1.79384711
id validation
acc_region:Europe - 0.51744473, acc_region:Americas - 0.49770641, loss - 1.84223977



