In [None]:
import timm
import torch
import wandb
import fastai
import dill
from fastai.callback.wandb import WandbCallback
from fastai.vision.all import *
from fastai.vision.core import *
from fastai.text.core import RegexLabeller
from fastai.vision.utils import get_image_files
from fastai.data.block import DataBlock
from fastai.data.core import *
from fastai.tabular.all import *
# Tell W&B which notebook file these runs come from.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'Parameter_Optimization_Sweep.ipynb'})

In [None]:
# define configs and parameters
# Static settings shared by every sweep trial (not searched over).
meta_config = SimpleNamespace(
    dataset_path=r"/blue/hulcr/gmarais/Beetle_data/selected_images/train_data",
    img_size=224,
    seed=42,
    project="Ambrosia_Symbiosis",
    top_k_losses=9,  # how many highest-loss validation images to plot after training
    # group="Beetle_classifier",
    # job_type="parameter_optimization"
    )

# W&B sweep definition: Bayesian search over `parameters`, capped at `run_cap` trials.
sweep_config = {
    'name': 'Beetle_Classifier_Sweep',
    'project': meta_config.project,
    'method': 'bayes',
    'run_cap': 10,
    'metric': {
        'goal': 'minimize',
        # Must match a key that is actually logged: fastai's WandbCallback logs the
        # recorder's validation loss as 'valid_loss' (not 'validation_loss').
        'name': 'valid_loss'
        },
    # 'early_terminate':{
    #     'type': 'hyperband',
    #     'min_iter': 1,
    #     'max_iter': 100,
    #     'eta': 3,
    #     's': 2
    # },
    'parameters': {
        'pretrained': {'values': [True, False]},
        'model_name': {'values': ["maxvit_rmlp_small_rw_224.sw_in1k"]},
        'batch_size': {'values': [128, 64, 256]},
        'epochs': {'values': [2, 5, 3]},
     }
}

In [None]:
# define functions required for sweep
def get_images(dataset_path, batch_size, img_size, seed, subfolders=('train', 'valid')):
    """Build image-classification DataLoaders for the beetles dataset.

    Expects the layout ``dataset_path/<split>/<class>/<image>``: the split comes
    from each image's grandparent folder and the label from its parent folder.

    Args:
        dataset_path: root folder containing the split subfolders.
        batch_size: batch size used by the train and valid DataLoaders.
        img_size: target size; images are padded with zeros and resized to it.
        seed: if not None, passed to fastai's ``set_seed`` before the DataLoaders
            are built (previously this argument was accepted but ignored).
        subfolders: ``(train_name, valid_name)`` split folder names. Only these
            top-level folders are scanned for images.

    Returns:
        A fastai ``DataLoaders``.
    """
    from functools import partial

    train_name, valid_name = subfolders
    if seed is not None:
        set_seed(seed)
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        # restrict the scan to the split folders instead of walking everything
        get_items=partial(get_image_files, folders=[train_name, valid_name]),
        splitter=GrandparentSplitter(train_name=train_name, valid_name=valid_name),
        get_y=parent_label,
        item_tfms=Resize(img_size, ResizeMethod.Pad, pad_mode='zeros'),
    )
    return dblock.dataloaders(dataset_path, bs=batch_size)

# def train(meta_config):
#     "Train the model using the supplied configs"
#     run = wandb.init(project=meta_config.project) # , job_type=meta_config.job_type, group =meta_config.group, 
#     dls = get_images(dataset_path=meta_config.dataset_path, img_size=meta_config.img_size, seed=meta_config.seed, batch_size=wandb.config.batch_size)
#     cbs = [MixedPrecision(), ShowGraphCallback(), SaveModelCallback(), WandbCallback(log='all')] 
#     learn = vision_learner(dls, 
#                            wandb.config.model_name,
#                            loss_func=LabelSmoothingCrossEntropyFlat(),
#                            metrics=[error_rate, 
#                                     accuracy, 
#                                     top_k_accuracy], 
#                            cbs=cbs, 
#                            pretrained=wandb.config.pretrained)
#     learn.fine_tune(wandb.config.epochs)
#     run.finish()

def _balanced_class_weights(labels, vocab):
    """Return one loss weight per entry of `vocab` using the 'balanced' heuristic.

    Each class ``c`` gets ``len(labels) / (n_classes * count(c))``, where
    ``n_classes`` is the number of distinct labels. This is the same formula as
    scikit-learn's ``compute_class_weight(class_weight='balanced', ...)``, so
    rarer classes receive larger weights.

    Raises:
        ValueError: if a class in `vocab` never occurs in `labels`.
    """
    import numpy as np

    classes, counts = np.unique(np.asarray(labels), return_counts=True)
    per_class = dict(zip(classes.tolist(), (len(labels) / (len(classes) * counts)).tolist()))
    missing = [c for c in vocab if c not in per_class]
    if missing:
        raise ValueError(f"no training images for classes {missing}; cannot compute balanced class weights")
    return [per_class[c] for c in vocab]


def train(config, dataset_path, subfolders=('train','valid')):
    """Fine-tune a vision model on the beetles dataset and log the run to W&B.

    Args:
        config: namespace-like object providing ``batch_size``, ``img_size``,
            ``seed``, ``model_name``, ``pretrained`` and ``epochs``. Optional fields:
            ``wandb_project`` (falls back to ``project``), ``wandb_group``,
            ``job_type`` and ``top_k_losses`` (defaults to 9).
        dataset_path: root folder containing the split subfolders.
        subfolders: ``(train, valid)`` folder names forwarded to ``get_images``.

    Returns:
        The trained fastai ``Learner``.
    """
    import torch

    if wandb.run is None:
        wandb.init(project=getattr(config, 'wandb_project', getattr(config, 'project', None)),
                   group=getattr(config, 'wandb_group', None),
                   job_type=getattr(config, 'job_type', None),
                   config=config)
    else:
        # A run is already active (e.g. started by a sweep-agent wrapper so that the
        # sampled hyperparameters were available): attach our settings to it rather
        # than calling wandb.init a second time.
        wandb.config.update(config, allow_val_change=True)
    try:
        dls = get_images(dataset_path=dataset_path, batch_size=config.batch_size, img_size=config.img_size,
                         seed=config.seed, subfolders=subfolders)
        # Same labelling rule as the DataBlock's get_y, applied to the training split only.
        labels = [parent_label(o) for o in dls.train.items]
        # float32: numpy yields float64, which does not match the dtype of the loss inputs.
        weights = torch.tensor(_balanced_class_weights(labels, dls.vocab), dtype=torch.float32, device=dls.device)
        # log options: 'all', 'parameters', 'gradients' or None; the author notes
        # 'parameters'/'all' did not work with the installed wandb version.
        cbs = [MixedPrecision(), ShowGraphCallback(), SaveModelCallback(), WandbCallback(log='gradients')]
        learn = vision_learner(dls,
                               config.model_name,
                               # label smoothing as regularisation; per-class weights counter class imbalance
                               loss_func=LabelSmoothingCrossEntropy(weight=weights),
                               metrics=[error_rate,
                                        accuracy,
                                        top_k_accuracy],
                               cbs=cbs,
                               pretrained=config.pretrained)
        learn.fine_tune(config.epochs)
        interp = ClassificationInterpretation.from_learner(learn)
        interp.plot_confusion_matrix()
        top_k = getattr(config, 'top_k_losses', 9)
        interp.plot_top_losses(top_k, nrows=top_k)
    finally:
        # always close the W&B run, even if training fails
        wandb.finish()
    return learn
    
# Prepare training wrapper based on configs
def train_wrapper():
    """Entry point for ``wandb.agent``: train one sweep trial.

    The hyperparameters sampled by the sweep are only available on the run's
    config after ``wandb.init()``, so the run is started here first. The sampled
    values are layered over the static ``meta_config`` fields, and the merged
    config is passed to ``train``.
    """
    from types import SimpleNamespace

    run = wandb.init(project=meta_config.project)
    merged = {
        **vars(meta_config),
        'wandb_project': meta_config.project,
        'wandb_group': getattr(meta_config, 'group', None),
        'job_type': getattr(meta_config, 'job_type', None),
        'top_k_losses': getattr(meta_config, 'top_k_losses', 9),
    }
    merged.update(dict(run.config))  # sweep-sampled values take precedence
    train(SimpleNamespace(**merged), dataset_path=meta_config.dataset_path)

In [None]:
# Run sweep: register the sweep definition with W&B, then start an in-process
# agent that calls `train_wrapper` once per sampled hyperparameter set.
sweep_id = wandb.sweep(sweep=sweep_config)
wandb.agent(sweep_id=sweep_id, function=train_wrapper)