In [1]:
"""
Author         : Aditya Jain
Last modified  : May 30th, 2023
About          : Test notebook to debug and test changes to the main training file
"""

import wandb
import torchvision.models as torchmodels
import torch
from torch import nn
import json
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
import torch.optim as optim
import datetime
import pandas as pd
import tensorflow as tf
import time
from sklearn.metrics import confusion_matrix
import argparse
import random

%load_ext autoreload
%autoreload 2

from data import dataloader
from models.build_model import build_model
from training_params.loss import Loss
from training_params.optimizer import Optimizer
from training_params.lr_scheduler import LRScheduler
from evaluation.micro_accuracy import (
    MicroAccuracyBatch,
    add_batch_microacc,
    final_micro_accuracy,
)
from evaluation.macro_accuracy import (
    MacroAccuracyBatch,
    add_batch_macroacc,
    final_macro_accuracy,
)
from evaluation.taxon_accuracy import taxon_accuracy, add_taxon_accuracy_to_species_checklist
from evaluation.confusion_matrix_data import confusion_matrix_data
from evaluation.confusion_data_conversion import ConfusionDataConvert

#### Define variables

In [7]:
# Evaluation setup: read the experiment configuration, derive the values the
# rest of the notebook needs, and build the test-set dataloader.
config_file = "config/01-config_uk-denmark_efficientnet.json"

# Read the config inside a context manager so the file handle is closed as
# soon as parsing finishes, instead of being left open for the garbage collector.
with open(config_file) as config_fp:
    config_data = json.load(config_fp)

# Training-related settings.
image_resize = config_data["training"]["image_resize"]
batch_size = config_data["training"]["batch_size"]
epochs = config_data["training"]["epochs"]
loss_name = config_data["training"]["loss"]["name"]
early_stop = config_data["training"]["early_stopping"]
start_val_loss = config_data["training"]["start_val_loss"]

# Model-related settings.
model_type = config_data["model"]["type"]
preprocess_mode = config_data["model"]["preprocess_mode"]

# `label_list` is the path of a JSON file; the number of classes is the number
# of entries under its "species" key.
label_list = config_data["dataset"]["label_info"]
with open(label_list) as label_fp:
    label_read = json.load(label_fp)
num_classes = len(label_read["species"])

# NOTE(review): machine-specific absolute path; the notebook only runs where
# these shards exist. Consider moving it into the config file.
test_webdataset_url = "/home/mila/a/aditya.jain/scratch/GBIF_Data/webdataset_moths_uk-denmark/test/test-500-{000000..000179}.tar"

# Build the test pipeline with the project helper.
# NOTE(review): the meaning of `test_set_num=4` is defined by
# data.dataloader.build_webdataset_pipeline, which is not visible here — confirm.
test_dataloader = dataloader.build_webdataset_pipeline(
    sharedurl=test_webdataset_url,
    input_size=image_resize,
    batch_size=batch_size,
    set_type="test",
    num_workers=2,
    preprocess_mode=preprocess_mode,
    test_set_num=4,
)

In [4]:
# Start one Weights & Biases run for this session, using the project and entity
# named in the config.
# The settings are passed in the same call: a second, separate wandb.init()
# carrying only `settings` would either be ignored in favour of the run that is
# already active or start a new run without the configured project/entity.
wandb.init(
    project=config_data["training"]["wandb"]["project"],
    entity=config_data["training"]["wandb"]["entity"],
    settings=wandb.Settings(start_method="fork"),
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madityajain07[0m ([33mmoth-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Read other dataset specific files
taxon_hierarchy = config_data["dataset"]["taxon_hierarchy"]
label_info = config_data["dataset"]["label_info"]
species_checklist = pd.read_csv(config_data["dataset"]["species_checklist"])

# Loading model
# Get cpu or gpu device for evaluation.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Build the network with the project helper. No checkpoint is loaded in this
# cell, so the weights are whatever build_model() returns.
model = build_model(num_classes, model_type)

# Put the model on the same device the evaluation loop moves its batches to
# (otherwise a CUDA run fails with a device mismatch), and switch to eval mode
# so dropout / batch-norm layers behave deterministically during evaluation.
# Both calls return the module itself, so the chained assignment is safe and
# avoids echoing the model repr as the cell output.
model = model.to(device).eval()

cpu


In [8]:
# Running accumulators for the two metrics. Both start as None and are replaced
# on every batch by whatever the add_batch_* helpers return.
global_micro_acc_data = None
global_macro_acc_data = None

# Forward passes only: gradient tracking is switched off for the whole loop.
with torch.no_grad():
    for image_batch, label_batch in test_dataloader:
        # Move inputs and targets to the selected device before the forward pass.
        image_batch = image_batch.to(device)
        label_batch = label_batch.to(device)
        predictions = model(image_batch)

        # micro-accuracy calculation
        micro_metric = MicroAccuracyBatch(
            predictions, label_batch, label_info, taxon_hierarchy
        )
        micro_accuracy = micro_metric.batch_accuracy()
        global_micro_acc_data = add_batch_microacc(
            global_micro_acc_data, micro_accuracy
        )

        # macro-accuracy calculation
        macro_metric = MacroAccuracyBatch(
            predictions, label_batch, label_info, taxon_hierarchy
        )
        macro_accuracy = macro_metric.batch_accuracy()
        global_macro_acc_data = add_batch_macroacc(
            global_macro_acc_data, macro_accuracy
        )

        # Debug run: the unconditional break stops after the first batch.
        break

: 

: 

In [None]:
# Reduce the accumulated per-batch data to the final metrics.
final_micro_acc = final_micro_accuracy(global_micro_acc_data)

# final_macro_accuracy() yields a pair: the macro accuracy and the per-taxon
# data that is passed on to taxon_accuracy() below.
macro_summary = final_macro_accuracy(global_macro_acc_data)
final_macro_acc, taxon_acc = macro_summary

# Combine the per-taxon data with the label file contents, then attach the
# result to the species checklist.
taxa_accuracy = taxon_accuracy(taxon_acc, label_read)
species_checklist_w_accuracy = add_taxon_accuracy_to_species_checklist(
    species_checklist,
    taxa_accuracy,
)

In [7]:
# Bare last expression: the notebook renders this object (the value returned by
# add_taxon_accuracy_to_species_checklist) as the cell output.
species_checklist_w_accuracy

Unnamed: 0,accepted_taxon_key,order_name,family_name,genus_name,search_species_name,gbif_species_name,confidence,status,match_type,rank,source,accuracy,num_of_train_images,num_of_test_images
0,1845962,Lepidoptera,Autostichidae,Oegoconia,Oegoconia quadripuncta,Oegoconia quadripuncta,99,ACCEPTED,EXACT,SPECIES,uksi_09May2022 denmark_Dec2022,-1.0,-1,-1
1,10055273,Lepidoptera,Tineidae,Oinophila,Oinophila v-flava,Oinophila v-flava,99,ACCEPTED,EXACT,SPECIES,uksi_09May2022,-1.0,-1,-1
2,1742185,Lepidoptera,Tortricidae,Olethreutes,Olethreutes arcuella,Olethreutes arcuella,99,ACCEPTED,EXACT,SPECIES,uksi_09May2022 denmark_Dec2022,0.0,5,1
3,1741545,Lepidoptera,Tortricidae,Olindia,Olindia schumacherana,Olindia schumacherana,99,ACCEPTED,EXACT,SPECIES,uksi_09May2022 denmark_Dec2022,-1.0,-1,-1
4,1875120,Lepidoptera,Pyralidae,Oncocera,Oncocera semirubella,Oncocera semirubella,99,ACCEPTED,EXACT,SPECIES,uksi_09May2022 denmark_Dec2022,-1.0,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3018,4532050,Lepidoptera,Crambidae,Udea,Udea hamalis Thnbg.,Udea hamalis,99,ACCEPTED,EXACT,SPECIES,denmark_Dec2022,-1.0,-1,-1
3019,1882158,Lepidoptera,Crambidae,Loxostege,Loxostege turbidalis Tr.,Loxostege turbidalis,99,ACCEPTED,EXACT,SPECIES,denmark_Dec2022,-1.0,-1,-1
3020,1892242,Lepidoptera,Crambidae,Ecpyrrhorrhoe,Ecpyrrhorrhoe rubiginalis Hb.,Ecpyrrhorrhoe rubiginalis,99,ACCEPTED,EXACT,SPECIES,denmark_Dec2022,-1.0,-1,-1
3021,1890699,Lepidoptera,Crambidae,Pyrausta,Pyrausta porphyralis D.& S.,Pyrausta porphyralis,98,ACCEPTED,EXACT,SPECIES,denmark_Dec2022,-1.0,-1,-1
