# DESCRIPTION
Experiments with MIT's AST (Audio Spectrogram Transformer) for UAV audio classification, tuned with Weights & Biases hyperparameter sweeps.

# NOTES
- Add an inference metric to the sweep loop.
- Change the learning-rate scheduler (currently `ReduceLROnPlateau`).

In [None]:
from AST_helper.util import AudioDataset, train_test_split_custom, save_model
from AST_helper.engine import sweep_train, inference_loop
from AST_helper.model import auto_extractor, custom_AST
from AST_helper.util import save_model # noqa: F401

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary

import wandb

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
display(device)

In [2]:
# --- Data & model ----------------------------------------------------------
data_path = "C:/Users/Sidewinders/Desktop/CODE/UAV_Classification_repo/UAV_Dataset_9"
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
NUM_CLASSES = 9

# --- Training defaults -----------------------------------------------------
# NOTE(review): the sweep path reads batch size, epochs and accumulation steps
# from wandb.config, so these three appear unused there — confirm before relying on them.
BATCH_SIZE = 16
EPOCHS = 7
ACCUMULATION_STEPS = 3  # effective batch size = BATCH_SIZE * ACCUMULATION_STEPS
SEED = 42

# --- DataLoader settings ---------------------------------------------------
NUM_CUDA_WORKERS = 0
PINNED_MEMORY = True
SHUFFLED = True

# --- Run bookkeeping -------------------------------------------------------
SAVE_MODEL = False
PROJECT_NAME = "AST_Sweeps"

# Release memory held by PyTorch's CUDA caching allocator, then seed the RNGs.
torch.cuda.empty_cache()
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# Random search over learning rate, batch size and gradient-accumulation steps.
# Epochs, optimizer and scheduler are each pinned to a single value. Runs are
# ranked by the logged "test_acc" metric, which the training code is assumed to log.
sweep_config = dict(
    name="accumulation_steps",
    method="random",
    metric=dict(goal="maximize", name="test_acc"),
    parameters=dict(
        learning_rate=dict(distribution="uniform", min=0.0009, max=0.005),
        batch_size=dict(values=[4, 8, 16]),
        epochs=dict(values=[7]),
        optimizer=dict(values=["adamW"]),
        scheduler=dict(values=["ReduceLROnPlateau"]),
        accumulation_steps=dict(values=[1, 2, 3]),
    ),
)
sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
sweep_count = 100  # maximum number of trials the agent below will run


In [4]:
def make(config):
    """Build the data loaders, model, loss, optimizer and scheduler for one run.

    Args:
        config: A ``wandb.config``-like object exposing ``batch_size`` and
            ``learning_rate`` attributes.

    Returns:
        Tuple ``(model, train_loader, test_loader, criterion, optimizer, scheduler)``.
    """
    # Re-seed for every sweep trial so each run starts from the same RNG state.
    # This covers the model head init and train shuffling, and the split too if
    # train_test_split_custom draws from torch's RNG (TODO confirm). Otherwise
    # trials inherit whatever state the previous trial left behind.
    torch.manual_seed(SEED)

    # Make the data
    feature_extractor = auto_extractor(model_name)
    dataset = AudioDataset(data_path, feature_extractor)
    train_subset, test_subset = train_test_split_custom(dataset, test_size=0.2)  # type: ignore

    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=NUM_CUDA_WORKERS,
        pin_memory=PINNED_MEMORY,
    )
    train_loader = DataLoader(dataset=train_subset, shuffle=SHUFFLED, **loader_kwargs)
    # Evaluation gains nothing from shuffling. A fixed order keeps batch
    # composition, and therefore per-batch-averaged test metrics, deterministic.
    test_loader = DataLoader(dataset=test_subset, shuffle=False, **loader_kwargs)
    # TODO: add a separate inference loader once an inference subset exists.

    # Make the model
    model = custom_AST(model_name, NUM_CLASSES, device)

    # Make the loss, optimizer and LR scheduler. config.optimizer / config.scheduler
    # are not consulted; AdamW + ReduceLROnPlateau are always used.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    # mode="min" assumes sweep_train steps the scheduler with a loss value — TODO confirm.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

    return model, train_loader, test_loader, criterion, optimizer, scheduler

In [5]:
def model_pipeline(config=None):
    """Run one W&B trial: build everything from the config, train, then evaluate.

    Args:
        config: Hyperparameter dict for a standalone run, or ``None`` when the
            function is launched by ``wandb.agent`` (the agent supplies the
            sweep's values).

    Returns:
        Tuple ``(model, results)``, where ``results`` is whatever ``sweep_train`` returns.
    """
    # `config` must be passed by keyword. The first positional parameter of
    # wandb.init is not `config`, so `wandb.init(config)` does not register it.
    with wandb.init(config=config):
        # Read hyperparameters back through wandb.config so logging matches execution.
        config = wandb.config
        model, train_loader, test_loader, criterion, optimizer, scheduler = make(config)
        print(model)

        results = sweep_train(
            model,
            train_dataloader=train_loader,
            test_dataloader=test_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_fn=criterion,
            epochs=config.epochs,  # type: ignore
            device=device,
            num_classes=NUM_CLASSES,
            accumulation_steps=config.accumulation_steps,  # type: ignore  # TODO change typing to sweeps format
        )

        # Evaluate on held-out data. Running inference on train_loader would only
        # measure how well the model fits the data it was trained on.
        inference_loop(
            model=model,
            device=device,
            loss_fn=criterion,
            inference_loader=test_loader,
        )

    return model, results

In [None]:
# Start a W&B agent that pulls up to `sweep_count` trials from the sweep and
# calls model_pipeline for each one. For a single standalone run, call
# model_pipeline(config) with a hyperparameter dict instead.
wandb.agent(sweep_id, function=model_pipeline, count=sweep_count)

In [7]:
# if SAVE_MODEL:
#     save_model(model=model,
#             target_dir="saved_models",
#             model_name="AST_classifier_true.pt")