In [114]:
%load_ext autoreload
%autoreload 2
import os, time, copy
import numpy as np
import torch
import torchvision as tv
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm
from utils import get_train_val_split, get_test_dataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [115]:
# Data for 2006-2014 is available at: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip the archives and merge them into the directories below.
# `~` is expanded here so the paths do not depend on downstream helpers expanding it
# (expanduser is a no-op on paths that are already expanded).
data_dir = os.path.expanduser('~/mai_datasets/plankton/2014')  # passed to get_test_dataset
model_data_dir = os.path.expanduser('~/mai_datasets/plankton/merged-2006-2013')  # passed to get_train_val_split
model_path = './models/33ckpt.pth'  # checkpoint whose state dict is loaded into the model

In [116]:
# Fraction handed to get_train_val_split as the validation share
pct_val = 0.2

print("Initializing Datasets and Dataloaders...")

# Train/val come from the merged model data; the test set comes from data_dir
train_dataset, val_dataset = get_train_val_split(model_data_dir, pct_val)
test_dataset = get_test_dataset(data_dir)

# Class-name lists for the model (training) side and the test side.
# NOTE(review): train_dataset exposes the underlying dataset via `.dataset`
# (looks like a torch Subset wrapper — confirm in utils).
model_classes = train_dataset.dataset.classes
test_classes = test_dataset.classes
num_classes = len(train_dataset.dataset.classes)

# Label correspondence: corr[i] is the model-side class index of test class i.
# Every test class name must exist in the training class_to_idx mapping.
corr = torch.tensor(
    [train_dataset.dataset.class_to_idx[name] for name in test_dataset.classes]
)

Initializing Datasets and Dataloaders...


In [117]:
# Run on the first CUDA device when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

# Build an untrained ResNet-18, swap its final layer for a num_classes-way
# linear head, then restore the weights stored at model_path.
print("Loading model")
model = tv.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.load_state_dict(
    # Deserialize onto the CPU first so the checkpoint loads without a GPU
    torch.load(model_path, map_location=torch.device('cpu'))
)
# Move to the selected device and switch to inference mode
model = model.to(device).eval()

Using device cpu
Loading model


In [118]:
# Initialize dataloaders
def collate_fn(batch):
    """Collate a batch after discarding every sample that is ``None``.

    The remaining samples are passed, in their original order, to PyTorch's
    default collate function and its result is returned unchanged.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)
# Samples per batch. No earlier cell defines this name, so it is set here to
# keep the notebook runnable from a fresh kernel. 512 matches the recorded run:
# its prediction counts sum to 5120, i.e. 10 batches of 512.
batch_size = 512

# Train and validation loaders are shuffled; the test loader keeps dataset order.
# All three drop `None` samples through collate_fn and load in the main process.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, num_workers=0
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, num_workers=0
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, num_workers=0
)

Using device cpu


In [119]:
# Upper bound on how many test samples are scored. The loop stops after the
# first batch that pushes the running count past this value, so slightly more
# rows than the cap may be kept. Set to None to score the whole test set.
# NOTE(review): the test loader is not shuffled, so a capped run only scores
# the first samples in dataset order — confirm that subset is representative.
max_test_samples = 5000

with torch.no_grad():
    # Pre-allocate for the full test set; unused trailing rows are trimmed below.
    test_preds = torch.zeros((len(test_dataset), num_classes))
    # Class indices are integers, so store them in an integer tensor.
    test_labels = torch.zeros((len(test_dataset),), dtype=torch.long)
    counter = 0  # number of samples written so far
    for imgs, labels in tqdm(test_dataloader):
        n_batch = labels.shape[0]
        # Raw model outputs, one row per sample
        test_preds[counter:counter + n_batch, :] = model(imgs.to(device)).cpu()
        test_labels[counter:counter + n_batch] = labels
        counter += n_batch
        if max_test_samples is not None and counter > max_test_samples:
            break

    # Keep only the rows that were actually filled (collate_fn may drop samples
    # and the loop may stop early).
    test_preds = test_preds[:counter]
    test_labels = test_labels[:counter]

  1%|██                                                                                                                                                     | 9/645 [02:24<2:50:28, 16.08s/it]


In [122]:
# Distinct class indices, with occurrence counts, among the argmax predictions
# and among the ground-truth labels.
uq_preds, counts_preds = np.unique(
    test_preds.argmax(dim=1).numpy(), return_counts=True
)
uq_labels, counts_labels = np.unique(test_labels.numpy(), return_counts=True)
uq_preds, uq_labels = uq_preds.astype(int), uq_labels.astype(int)

# Predictions are indexed by model classes, labels by the test dataset's classes.
print(f"Preds in this batch: {dict(zip((model_classes[idx] for idx in uq_preds), counts_preds))}")
print()
print(f"Labels in this batch: {dict(zip((test_classes[idx] for idx in uq_labels), counts_labels))}")
print()

# corr translates test-label indices into model class indices before comparing.
accuracy = (
    (test_preds.argmax(dim=1) == corr[test_labels.long()])
    .float()
    .mean()
)
print(f"Accuracy: {accuracy}")

Preds in this batch: {'Akashiwo': 12, 'Asterionellopsis': 186, 'Bidulphia': 1, 'Chaetoceros_didymus_flagellate': 15, 'Chaetoceros_flagellate': 1, 'Chrysochromulina': 182, 'Cochlodinium': 1, 'Corethron': 5, 'DactFragCerataul': 891, 'Dactyliosolen': 94, 'Delphineis': 1, 'Dictyocha': 988, 'Dinobryon': 106, 'Ditylum': 1, 'Ditylum_parasite': 3, 'Ephemera': 3, 'Euplotes_sp': 1, 'Guinardia_flaccida': 1, 'Guinardia_striata': 112, 'Gyrodinium': 231, 'Hemiaulus': 3, 'Katodinium_or_Torodinium': 777, 'Leptocylindrus': 1, 'Pleuronema_sp': 1, 'Prorocentrum': 110, 'Protoperidinium': 5, 'Pseudochattonella_farcimen': 3, 'Thalassionema': 17, 'Tiarina_fusus': 144, 'amoeba': 20, 'bead': 482, 'detritus': 2, 'dino_large1': 103, 'flagellate_sp3': 6, 'other_interaction': 606, 'pennate': 4, 'spore': 1}

Labels in this batch: {'Akashiwo': 1, 'Amphidinium_sp': 66, 'Asterionellopsis': 128, 'Cerataulina': 412, 'Cerataulina_flagellate': 5, 'Ceratium': 6, 'Chaetoceros': 1871, 'Chaetoceros_didymus': 11, 'Chaetoceros_