In [None]:
# Reload imported modules before each cell runs, so edits to the local modules
# (helpers, Classifier, SubsampledDataset) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from tqdm import tqdm
# NOTE: `from helpers import *` is kept at this position on purpose: the imports
# below it take precedence over any same-named objects it exports.
from helpers import *
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
import sys
import os
import time
from Classifier import Classifier
from SubsampledDataset import SubsampledDataset


# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution (e.g. for debugging), assign device = "cpu" here.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Mini-batch size shared by the training and validation data loaders.
BATCH_SIZE = 64
# Seed PyTorch's RNG (CPU and all CUDA devices) so that data-loader shuffling and
# (presumably) Classifier weight initialisation repeat across Restart & Run All.
# This does not by itself make cuDNN kernels deterministic.
SEED = 0
torch.manual_seed(SEED)

## Load the training split and the OOD and in-distribution validation splits

In [None]:
# fMoW from WILDS; download=False means the data must already be present locally.
dataset = get_dataset(dataset="fmow", download=False)
# Groups examples by their "region" metadata field; passed to SubsampledDataset below.
grouper = CombinatorialGrouper(dataset, ["region"])

# Every split uses the same preprocessing (image -> tensor), so build it once.
to_tensor = transforms.Compose([transforms.ToTensor()])

ood_train_data = dataset.get_subset("train", transform=to_tensor)
ood_val_data = dataset.get_subset("val", transform=to_tensor)
id_val_data = dataset.get_subset("id_val", transform=to_tensor)

ood_train_dataset = SubsampledDataset(ood_train_data, grouper)
ood_val_dataset = SubsampledDataset(ood_val_data, grouper)
id_val_dataset = SubsampledDataset(id_val_data, grouper)

# Only the training loader needs shuffling. The validation loaders use WILDS'
# evaluation loader, which iterates in a fixed order; aggregate metrics do not
# depend on order, and a fixed order keeps evaluation runs repeatable.
ood_train_loader = get_train_loader("standard", ood_train_dataset, batch_size=BATCH_SIZE)
ood_val_loader = get_eval_loader("standard", ood_val_dataset, batch_size=BATCH_SIZE)
id_val_loader = get_eval_loader("standard", id_val_dataset, batch_size=BATCH_SIZE)


## Train, recording evaluation metrics after every epoch

In [None]:
NUM_EPOCHS = 2
LEARNING_RATE = 0.0001
# L2 penalty passed to Adam. The previous value, 0.96, matches the StepLR gamma of
# the WILDS fMoW ERM baseline and was most likely meant as learning-rate decay.
# As an L2 coefficient it swamps the loss gradient and drives the weights toward zero.
WEIGHT_DECAY = 0.0
# The learning rate is multiplied by this factor after every epoch (StepLR, step_size=1).
LR_DECAY_GAMMA = 0.96
MODEL_PATH = "models"
# NOTE(review): full fMoW has 62 categories; 20 presumably matches the class subset
# kept by SubsampledDataset — confirm.
NUM_CLASSES = 20
# Set to True to checkpoint the model whenever the OOD validation loss improves.
SAVE_BEST = False


model = Classifier(NUM_CLASSES)
model.to(device)

# The checkpoint name encodes the hyper-parameters used for this run.
model_name = (
    f"ERM_{NUM_EPOCHS}_Adam_{LEARNING_RATE}_{WEIGHT_DECAY}"
    f"_StepLR{LR_DECAY_GAMMA}_CrossEntropy.pth"
)
save_name = os.path.join(MODEL_PATH, model_name)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_DECAY_GAMMA)

# One metrics dict per epoch for each split; read by the plotting cell.
train_evolution = []
val_evolution = []
id_val_evolution = []
best_loss = sys.float_info.max

for epoch in range(NUM_EPOCHS):
    print(f"EPOCH {epoch + 1}:")

    # train
    y_true, y_pred, metadata, loss = train_step(model, ood_train_loader, loss_fn, optimizer, device)
    train_evolution.append(build_metrics_dict(dataset, y_true, y_pred, metadata, loss))

    # out-of-distribution validation
    y_true, y_pred, metadata, loss = val_step(model, ood_val_loader, loss_fn, device)
    val_evolution.append(build_metrics_dict(dataset, y_true, y_pred, metadata, loss))
    print(f"OOD Loss: {loss}")

    # Optional checkpoint selected by the best OOD validation loss seen so far.
    if SAVE_BEST and loss < best_loss:
        best_loss = loss
        os.makedirs(MODEL_PATH, exist_ok=True)
        torch.save(model, save_name)

    # in-distribution validation
    y_true, y_pred, metadata, loss = val_step(model, id_val_loader, loss_fn, device)
    id_val_evolution.append(build_metrics_dict(dataset, y_true, y_pred, metadata, loss))
    print(f"ID Loss: {loss}")

    # Decay the learning rate once per epoch, after this epoch's optimizer updates.
    scheduler.step()



## Plot loss and accuracy per region across epochs

In [None]:
# Call plot_graph once per metric recorded by build_metrics_dict, comparing the
# training, OOD validation, and ID validation histories across epochs.
# Only epochs recorded for all three splits are plotted, rather than NUM_EPOCHS,
# so the histories stay aligned if the training cell was interrupted early.
epochs_recorded = min(len(train_evolution), len(val_evolution), len(id_val_evolution))
assert epochs_recorded > 0, "no epochs recorded: run the training cell first"

metrics = list(train_evolution[0].keys())

for metric in metrics:
    plot_graph(
        metric,
        train_evolution[:epochs_recorded],
        val_evolution[:epochs_recorded],
        id_val_evolution[:epochs_recorded],
        epochs_recorded,
    )