In [None]:
import timm
import torch
import wandb
import fastai
import dill
import numpy as np
from fastai.callback.wandb import WandbCallback
from fastai.vision.all import *
from fastai.vision.core import *
from fastai.text.core import RegexLabeller
from fastai.vision.utils import get_image_files
from fastai.data.block import DataBlock
from fastai.data.core import *
from fastai.tabular.all import *

In [None]:
# Experiment settings consumed by `train` (and logged to Weights & Biases).
_config_fields = dict(
    batch_size=256,   # other value tried: 16
    img_size=224,     # side length that images are resized/padded to
    seed=42,
    pretrained=True,  # start from pretrained weights
    # Other architectures to try: maxvit_nano_rw_256.sw_in1k, regnetx_040,
    # coatnet_bn_0_rw_224.sw_in1k
    model_name="maxvit_rmlp_small_rw_224.sw_in1k",
    epochs=2,
)
config = SimpleNamespace(**_config_fields)

def get_images(dataset_path, batch_size, img_size, seed):
    """Build the fastai DataLoaders for the beetles dataset.

    Images are collected recursively from `dataset_path`. An image goes to the
    training or validation set according to whether its grandparent folder is
    named 'train' or 'valid', and is labelled with its parent folder name.
    Every image is resized to `img_size` using zero padding.

    Args:
        dataset_path: Root directory containing the 'train' and 'valid' folders.
        batch_size: Number of images per batch.
        img_size: Target size passed to `Resize`.
        seed: Seed applied to the random number generators (via fastai's
            `set_seed`) before the loaders are built. Note that this seeds
            the global RNG state of the process.

    Returns:
        The DataLoaders produced by `DataBlock.dataloaders`.
    """
    # The seed used to be accepted but silently ignored; apply it so that
    # shuffling (and anything downstream that draws random numbers) is repeatable.
    set_seed(seed)
    dblock = DataBlock(blocks = (ImageBlock, CategoryBlock),
                       get_items = get_image_files,
                       splitter = GrandparentSplitter(train_name='train', valid_name='valid'),
                       get_y = parent_label,
                       item_tfms = Resize(img_size, ResizeMethod.Pad, pad_mode='zeros'))
    # The DataBlock scans `dataset_path` itself through `get_items`, so no
    # separate directory listing is needed here.
    dls = dblock.dataloaders(dataset_path, bs = batch_size)
    return dls

def train(config, dataset_path):
    """Train the model using the supplied config and return the fitted Learner.

    Args:
        config: Object with the attributes `batch_size`, `img_size`, `seed`,
            `model_name`, `pretrained` and `epochs`. It is also passed to
            `wandb.init` as the run configuration.
        dataset_path: Root directory of the training data, forwarded to
            `get_images`.

    Returns:
        The Learner after `fine_tune(config.epochs)`.

    Side effects:
        Opens a Weights & Biases run, plots the confusion matrix and the five
        top losses, and closes the run before returning. If anything fails
        (or the cell is interrupted) the run is closed with a non-zero exit
        code and the exception is re-raised.
    """
    dls = get_images(dataset_path=dataset_path, batch_size=config.batch_size, img_size=config.img_size, seed=config.seed)
    wandb.init(project="Beetle_classifier", group='ambrosia_symbiosis', job_type='training', config=config)
    try:
        cbs = [MixedPrecision(), WandbCallback(log='all')]
        learn = vision_learner(dls,
                               config.model_name,
                               loss_func=LabelSmoothingCrossEntropyFlat(),
                               metrics=[error_rate,
                                        accuracy,
                                        top_k_accuracy],
                               cbs=cbs,
                               pretrained=config.pretrained)
        learn.fine_tune(config.epochs)
        interp = ClassificationInterpretation.from_learner(learn)
        interp.plot_confusion_matrix()
        interp.plot_top_losses(5, nrows=5)
    except BaseException:
        # Do not leave the run open in the kernel when training fails or is
        # interrupted; flag it as failed and let the error propagate.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    return learn

In [None]:
# Train the model on the training data directory
train_data_dir = r"/blue/hulcr/gmarais/Beetle_data/selected_images/train_data"
learn = train(config=config, dataset_path=train_data_dir)

# Save the model to disk for inference.
models_dir = Path("/blue/hulcr/gmarais/Beetle_classifier/Models")
learn.path = models_dir
# Remove the WandbCallback so the exported model can be applied without wandb.
learn.remove_cb(WandbCallback)
# Use learn.save instead to save the model and continue training later.
learn.export('beetle_classifier.pkl', pickle_module=dill)

In [None]:
# Import the model again to test it (uncomment to evaluate a previously exported learner)
# learn = load_learner(Path("/blue/hulcr/gmarais/Beetle_classifier/Models") / 'beetle_classifier.pkl', cpu=False, pickle_module=dill)
print(learn.dls.vocab) # print all possible classes of model

# import testing data
dataset_path=r"/blue/hulcr/gmarais/Beetle_data/selected_images"
# `folders` must be a sequence of folder names: ('test_data') without the
# trailing comma is just a parenthesised string, so pass a real one-element tuple.
files = get_image_files(path=dataset_path, recurse=True, folders=('test_data',)) # get files from directory
test_dl = learn.dls.test_dl(files, with_labels=True) # load data as a dataloader

# Predictions and targets for the test set, then the loss followed by the
# learner's metrics in the order they were registered.
preds, targets = learn.get_preds(dl=test_dl)
val_out = learn.validate(dl=test_dl)
print(" Loss: "+str(val_out[0])+"\n",
      "Error Rate: "+str(val_out[1])+"\n",
      "Accuracy: "+str(val_out[2])+"\n",
      "Top k(5) Accuracy: "+str(val_out[3])+"\n")

# This function only describes how much a single value in a list stands out.
# If all values in the list are high, or all are low, the result is 1.
# The smaller the proportion of dissimilar values relative to the other, more similar values, the lower this number.
# The larger the gap between the dissimilar values and the similar values, the smaller this number.
# It can only interpret probabilities, i.e. values between 0 and 1.
# The function outputs an estimate of the inverse of the classification confidence, based on the probabilities of all the classes.
# The wedge threshold splits the data at a threshold, with a magnitude given as a positive int, to force a ledge/peak in the data.
def unkown_prob_calc(probs, wedge_threshold, wedge_magnitude=1, wedge=True):
    """Estimate how "unknown" a sample is from its per-class probabilities.

    The score is ``1 - spread``, where ``spread`` is the sum of all pairwise
    absolute differences between the (optionally wedged) probabilities,
    divided by their sum and by ``2 * (n - 1)``. A one-hot vector therefore
    scores 0 and a vector whose entries are all equal scores 1.

    Args:
        probs: 1-D sequence of class probabilities (values in [0, 1]). Any
            array-like accepted by ``np.asarray`` works, e.g. a list or a
            NumPy array.
        wedge_threshold: Cut-off used by the wedge. Values at or below it are
            pushed down, values above it are pushed up.
        wedge_magnitude: Strength of the wedge (intended to be a positive
            int). Low values are raised to the power ``2 * wedge_magnitude``,
            high values to ``1 / (2 * wedge_magnitude)``.
        wedge: When False the probabilities are used unchanged.

    Returns:
        The unknown score as a float. 1.0 is returned when there is nothing
        to compare: fewer than two classes, or probabilities that sum to zero.
    """
    # Accept lists and other array-likes, not only NumPy arrays.
    probs = np.asarray(probs, dtype=float)
    if wedge:
        exponent = 2 * wedge_magnitude
        # Decide low/high once, from the incoming values, so that no value
        # can be transformed by both branches.
        probs = np.where(probs <= wedge_threshold,
                         probs ** exponent,
                         probs ** (1 / exponent))
    n_classes = len(probs)
    probs_sum = np.sum(probs)
    # Avoid dividing by zero: with a single class or an all-zero vector no
    # value can stand out, which is the "all values alike" case (score 1).
    if n_classes < 2 or probs_sum == 0:
        return 1.0
    diff_matrix_sum = np.sum(np.abs(probs[:, np.newaxis] - probs))
    class_val = diff_matrix_sum / probs_sum
    # Largest value class_val reaches for a one-hot vector.
    max_class_val = (n_classes - 1) * 2
    known_prob = class_val / max_class_val
    return 1 - known_prob

In [None]:
# Show the predictions that learn.get_preds returned for the test dataloader.
preds