In [1]:
from models import parameters
# Default parameter set from the local `models.parameters` module; later cells
# read img_size, batch_size, epocs and device from it.
params = parameters.paramStore.default

In [2]:
import torch
import torchvision
import torchvision.models as models
import torch.nn as nn
import numpy as np
from torchvision import transforms, datasets
from sklearn import metrics
import time
import copy
import os
import random
from matplotlib import pyplot as plt
import glob
from pathlib import Path
import PIL
import math
import sys
from wilds import get_dataset
from wilds.datasets.wilds_dataset import WILDSSubset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
import torchvision.transforms as transforms
from wilds.common.grouper import CombinatorialGrouper
from sklearn.model_selection import train_test_split
from collections import Counter

In [3]:
# Display the compute device configured in the parameter set.
params.device

device(type='cuda')

In [5]:
# --- Run configuration -------------------------------------------------------
# Values taken from the shared parameter set.
img_size = params.img_size
batch_size = params.batch_size
epochs = params.epocs
device = params.device

# Notebook-local switches.
load_pretrained_model = False
location = None  # None keeps every location when the subset is filtered
eval_per_epochs = 1  # evaluate every N epochs

# Show the selected device as the cell output.
device

device(type='cuda')

In [4]:
# Open the iWildCam dataset through WILDS (download disabled).
iwildcam = get_dataset(
    dataset="iwildcam",
    download=False,
)

# Grouper keyed on the "location" metadata field; used later to filter samples.
grouper = CombinatorialGrouper(iwildcam, ["location"])

# Show the mapping from split name to split id.
iwildcam.split_dict

{'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4}

In [26]:
def _load_split(split_name):
  """Return the named iWildCam split with the shared preprocessing applied.

  Every image is resized to ``img_size`` x ``img_size`` and converted to a
  tensor. Keeping the pipeline in one place stops the three splits from
  drifting apart.
  """
  return iwildcam.get_subset(
    split_name,
    transform=transforms.Compose(
        [transforms.Resize((img_size, img_size)), transforms.ToTensor()]
    ),
  )


trainset = _load_split("train")
testset = _load_split("test")
valset = _load_split("val")

In [20]:
def train(device, model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        device: torch device the inputs and labels are moved to.
        model: network to optimise; it is switched to training mode first.
        dataloader: iterable yielding ``(inputs, labels, metadata)`` triples.
            The third element is ignored.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer that is zeroed and stepped once per batch.

    Returns:
        None. ``model`` is updated in place.
    """
    model.train()
    for inputs, labels, _ in dataloader:
        inputs = inputs.to(device).float()
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

In [21]:
def evaluate(device, model, data_loader, dataset, text=None):
  """Score ``model`` on ``data_loader`` using ``dataset.eval``.

  Args:
    device: torch device the inputs are moved to for the forward pass.
    model: network to evaluate; it is switched to eval mode.
    data_loader: iterable yielding ``(inputs, labels, metadata)`` batches.
    dataset: object whose ``eval(predictions, labels, metadata)`` returns a
      pair; element 0 is returned to the caller and element 1 is a string.
    text: optional prefix. When given, ``text`` followed by the string from
      ``dataset.eval`` is printed.

  Returns:
    Element 0 of the ``dataset.eval`` result.
  """
  model.eval()
  predictions = []
  true_labels = []
  metadatas = []

  with torch.no_grad():
    for inputs, labels, metadata in data_loader:
      outputs = model(inputs.to(device).float())
      _, predicted = torch.max(outputs, 1)
      # Only the predictions live on `device`; labels and metadata are never
      # moved there, so they are collected as they come from the loader.
      predictions.append(predicted.cpu())
      true_labels.append(labels.cpu())
      metadatas.append(metadata.cpu())

  def _join(chunks):
    # An empty loader yields an empty tensor, as torch.tensor([]) did before.
    return torch.cat(chunks) if chunks else torch.tensor([])

  outcome = dataset.eval(_join(predictions), _join(true_labels), _join(metadatas))
  if text is not None:
    print(text + outcome[1])
  return outcome[0]

In [22]:
def get_targets(dataset):
  """Map each index of a subset to its integer class label.

  Args:
    dataset: subset exposing ``indices`` (positions in the underlying
      dataset) and ``dataset`` (the underlying dataset itself).

  Returns:
    dict of ``index -> int(label)`` in the order of ``dataset.indices``.
  """
  full_dataset = dataset.dataset
  # Indexing the underlying dataset builds a whole sample (image included)
  # just to read the label. When it exposes its labels as `y_array`, as WILDS
  # datasets do, read them directly instead.
  labels = getattr(full_dataset, "y_array", None)
  if labels is not None:
    return {i: int(labels[i]) for i in dataset.indices}
  # Fallback for datasets without `y_array`: original per-sample lookup.
  return {i: int(full_dataset[i][1]) for i in dataset.indices}

In [23]:
# targets = get_targets(trainset)

In [28]:
# Fetch NVIDIA's pretrained ResNet-50 through torch.hub (downloads on first use).
# NOTE(review): `resnet50` is not referenced by the later visible cells, which
# build `model` from torchvision instead. Confirm this download is still needed.
resnet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)

Downloading: "https://github.com/NVIDIA/DeepLearningExamples/archive/torchhub.zip" to C:\Users\jakob/.cache\torch\hub\torchhub.zip
Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/resnet50_pyt_amp/versions/20.06.0/files/nvidia_resnet50_200821.pth.tar" to C:\Users\jakob/.cache\torch\hub\checkpoints\nvidia_resnet50_200821.pth.tar


  0%|          | 0.00/97.7M [00:00<?, ?B/s]

In [None]:
# `dataset` is an alias of `testset`: the filtering below rewrites
# `testset.indices` in place, and the split cell relies on that.
dataset = testset

# Step 1: keep only samples from `location` (None keeps every location).
# Group ids are computed from the stored metadata rows in one batched call, so
# no image is loaded while filtering.
subset_indices = [int(i) for i in dataset.indices]
if location is None or not subset_indices:
  location_indices = subset_indices
else:
  metadata_rows = dataset.dataset.metadata_array[
      torch.as_tensor(subset_indices, dtype=torch.long)
  ]
  group_ids = grouper.metadata_to_group(metadata_rows).tolist()
  location_indices = [
      i for i, group_id in zip(subset_indices, group_ids) if int(group_id) == location
  ]
dataset.indices = location_indices

# Step 2: drop classes that occur exactly once in the remaining samples.
# Presumably needed because the later stratified 50/50 split cannot place a
# single-member class on both sides (verify against the split cell).
unfiltered_targets = get_targets(dataset)
class_freq = Counter(unfiltered_targets.values())
filter_classes = [class_ for class_, freq in class_freq.items() if freq == 1]
singleton_classes = set(filter_classes)  # constant-time membership checks
indices = [i for i in dataset.indices if unfiltered_targets[i] not in singleton_classes]
dataset.indices = indices

# Labels of the surviving samples, in the same order as `dataset.indices`.
# Reuses the labels already collected instead of querying the dataset again.
targets = {i: unfiltered_targets[i] for i in indices}




In [38]:
# Stratified 50/50 split of the filtered subset into train and validation parts.
# Labels are looked up per index so they stay aligned with `dataset.indices`.
# The fixed `random_state` makes the split identical on every re-run.
split_labels = [targets[i] for i in dataset.indices]
train_indices, valid_indices = train_test_split(
    dataset.indices,
    test_size=0.5,
    stratify=split_labels,
    random_state=0,
)
train_data = WILDSSubset(dataset.dataset, train_indices, dataset.transform)
valid_data = WILDSSubset(dataset.dataset, valid_indices, dataset.transform)
train_loader = get_train_loader('standard', train_data, batch_size)
valid_loader = get_eval_loader('standard', valid_data, batch_size)

In [40]:
# Pretrained torchvision ResNet-50 with its final layer replaced by a fresh
# linear head sized to the number of classes in `dataset`.
model = torchvision.models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, dataset.n_classes)
model.to(device)

# Cross-entropy loss; SGD with momentum over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer_conv = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Best validation macro-F1 so far, stored as [score, epoch], plus a copy of the
# model that achieved it. Both are initialised here so the cell runs on a fresh
# kernel; they were previously read before ever being assigned.
best_f1 = [float("-inf"), -1]
best_model = None

# NOTE(review): this runs `epochs + 1` training passes (epoch 0 .. epochs
# inclusive), as the original loop did. Confirm that is intended.
for epoch in range(epochs + 1):
    # Train
    train(device, model, train_loader, criterion, optimizer_conv)

    # Evaluation
    if epoch % eval_per_epochs == 0 or epoch == 1:
        evaluate(device, model, train_loader, dataset, f"Epoch {epoch} train set\n")
        f1 = evaluate(device, model, valid_loader, dataset, f"Epoch {epoch} valid set\n")['F1-macro_all']
        print("--------------")

        # Keep a snapshot of the best model so far, judged by validation macro-F1.
        if f1 > best_f1[0]:
            best_model = copy.deepcopy(model)
            best_f1 = [f1, epoch]

In [46]:
# Inspect the metadata rows of the training subset.
train_data.metadata_array

tensor([[  188, 14022,  2013,  ...,     1,     0,     0],
        [  287, 13916,  2013,  ...,    26,    79,     0],
        [  120,  8505,  2013,  ...,    25,    15,     0],
        ...,
        [  163, 21053,  2015,  ...,    42,    27,     0],
        [  187, 34468,  2013,  ...,    53,     0,     0],
        [  288, 12740,  2013,  ...,    12,    36,     0]])

In [144]:
# Inspect one raw sample. This indexes the underlying full dataset, not the
# subset, so the image comes back as an untransformed PIL image.
train_data.dataset[50]

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=560x448>,
 tensor(146),
 tensor([   2, 2157, 2013,    6,   10,    6,   53,   18,  146,    0]))

In [143]:
# Size of the underlying full dataset (not of the training subset).
len(train_data.dataset)

203029