# DESCRIPTION
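Experiment with fine-tuning MIT's pretrained AST (Audio Spectrogram Transformer) for UAV audio classification. The notebook loads the UAV dataset, splits it into train / test / inference subsets, trains the model (with optional Weights & Biases logging), evaluates it on the held-out inference subset, and optionally saves the trained model.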

In [2]:
from AST_helper.util import AudioDataset, train_test_split_custom, save_model
from AST_helper.engine import train, inference_loop
from AST_helper.model import auto_extractor, custom_AST
from AST_helper.util import save_model # noqa: F401

import torch
from torch.utils.data import DataLoader
import torch.optim
import torch.nn as nn
from torchinfo import summary

import wandb

# Pick the compute device once; later cells pass this string to the model
# constructor, the training loop and the inference loop.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
display(device)

'cuda'

In [3]:
# --- Data location and pretrained checkpoint --------------------------------
# NOTE(review): machine-specific absolute path; edit it when running elsewhere.
data_path = "C:/Users/Sidewinders/Research_notebooks/Drone_classification/Research/UAV_Dataset_9"
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"

# --- DataLoader settings (consumed by the DataLoader cell) ------------------
BATCH_SIZE = 16
NUM_CUDA_WORKERS = 0
PINNED_MEMORY = True
SHUFFLED = True

# --- Training settings (consumed by the optimizer and train() cells) --------
SEED = 42
EPOCHS = 10
ACCUMULATION_STEPS = 3  # multiplies by batch size for large batch size effect.
OPTIM_LR = 0.0001
TRAIN_PATIENCE = 5

# --- Notebook switches ------------------------------------------------------
multiple_runs = False  # when True, the final cell does not call wandb.finish()
wandb_init = False     # when True, a W&B run is started before training
SAVE_MODEL = False     # when True, the last cell calls save_model()

torch.cuda.empty_cache()

# Hyperparameters recorded with the W&B run. Gradient accumulation is logged
# explicitly because "batch_size" alone under-reports the batch size the
# optimizer effectively sees when ACCUMULATION_STEPS > 1.
config = {
    "learning_rate": OPTIM_LR,
    "batch_size": BATCH_SIZE,
    "accumulation_steps": ACCUMULATION_STEPS,
    "effective_batch_size": BATCH_SIZE * ACCUMULATION_STEPS,
    "num_epochs": EPOCHS,
    "patience": TRAIN_PATIENCE,
    "random_seed": SEED,
    "model_name": model_name,
    "optimizer": "AdamW",
    "loss_function": "CrossEntropyLoss",
}

# Keyword arguments read by the wandb.init(...) cell.
wandb_params = {
    "project": "vanilla_AST",
    "name": "classifier_grad_true_lowerLR",
    "reinit": False,
    "notes": "8457 trainable params",
    "tags": ["AST"],
    "config": config,
}

# Seed PyTorch's CPU and CUDA generators so runs are repeatable.
# NOTE(review): Python's `random` and NumPy are not seeded here; confirm that
# the AST_helper split/augmentation code relies only on torch's generator.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [5]:
# Feature extractor matching the pretrained checkpoint, then the dataset that
# applies it to the audio files under `data_path`.
feature_extractor = auto_extractor(model_name)
dataset_0 = AudioDataset(data_path, feature_extractor)

# Shape of the first element of the first sample, kept for reference.
shape = dataset_0[0][0].shape

# Three-way split: 20% test, 10% inference, remainder for training.
splits = train_test_split_custom(dataset_0, test_size=0.2, inference_size=0.1)
train_subset, test_subset, inference_subset = splits

# One output unit per class reported by the dataset.
class_names = dataset_0.get_classes()
num_classes = len(class_names)

model = custom_AST(model_name, num_classes, device)

# Parameter counts and trainable flags per layer (displayed as the cell output).
summary(
    model,
    col_names=["num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                                Param #              Trainable
ASTForAudioClassification (ASTForAudioClassification)                  --                   Partial
├─ASTModel (audio_spectrogram_transformer)                             --                   False
│    └─ASTEmbeddings (embeddings)                                      933,888              False
│    │    └─ASTPatchEmbeddings (patch_embeddings)                      (197,376)            False
│    │    └─Dropout (dropout)                                          --                   --
│    └─ASTEncoder (encoder)                                            --                   False
│    │    └─ModuleList (layer)                                         (85,054,464)         False
│    └─LayerNorm (layernorm)                                           (1,536)              False
├─ASTMLPHead (classifier)                                              --                   True
│    └─LayerNorm (

In [6]:
# Options shared by every loader, all taken from the configuration cell.
# NOTE(review): SHUFFLED is applied to the evaluation loaders as well; this is
# kept as in the original so the RNG stream of existing runs is unchanged.
loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_CUDA_WORKERS,
    "pin_memory": PINNED_MEMORY,
    "shuffle": SHUFFLED,
}

train_dataloader_custom = DataLoader(dataset=train_subset, **loader_kwargs)
test_dataloader_custom = DataLoader(dataset=test_subset, **loader_kwargs)

# `inference_subset` is always bound by the split cell, but it may be falsy
# (presumably None or empty when no inference split is produced; confirm
# against train_test_split_custom). Bind the loader name in both branches so a
# re-run can never pick up a stale loader left in the kernel by an earlier
# execution, and later cells can test for None instead of hitting a NameError.
if inference_subset:
    inference_dataloader_custom = DataLoader(dataset=inference_subset, **loader_kwargs)
else:
    inference_dataloader_custom = None

In [7]:
# Classification loss.
loss_fn = nn.CrossEntropyLoss()

# AdamW over the model's parameters at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=OPTIM_LR)

# Cut the learning rate by 10x when the monitored quantity has not decreased
# for 3 scheduler steps.
# TODO experiment w/ diff hyperparams
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.1,
    patience=3,
)

In [8]:
# Start a Weights & Biases run only when logging is switched on in the
# configuration cell. Missing keys fall back to the defaults given here.
if wandb_init:
    init_kwargs = {
        "project": wandb_params.get("project"),
        "config": wandb_params.get("config"),
        "name": wandb_params.get("name"),
        "reinit": wandb_params.get("reinit", True),
        "tags": wandb_params.get("tags", []),
        "notes": wandb_params.get("notes", ""),
        "dir": wandb_params.get("dir", None),
    }
    wandb.init(**init_kwargs)

In [9]:
# Everything the training loop needs besides the model itself, gathered in one
# place so the call below stays short.
train_kwargs = dict(
    train_dataloader=train_dataloader_custom,
    test_dataloader=test_dataloader_custom,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=EPOCHS,
    device=device,
    accumulation_steps=ACCUMULATION_STEPS,
    patience=TRAIN_PATIENCE,
)

# Run training; whatever train() returns is kept in `results`.
results = train(model, **train_kwargs)


wandb: Agent Starting Run: o3bn04sd with config:
	batch_size: 64
	epochs: 15
	learning_rate: 0.05211007753744672
	optimizer: adam
wandb: Agent Starting Run: aggy3w4j with config:
	batch_size: 64
	epochs: 5
	learning_rate: 0.008980299649315186
	optimizer: adam
wandb: Agent Starting Run: m8l1vqf8 with config:
	batch_size: 16
	epochs: 15
	learning_rate: 0.03771697645406378
	optimizer: adamW
wandb: Agent Starting Run: 9jc0pnm0 with config:
	batch_size: 64
	epochs: 5
	learning_rate: 0.08420366782944151
	optimizer: adam
wandb: Agent Starting Run: z3ru84zf with config:
	batch_size: 16
	epochs: 5
	learning_rate: 0.009585436261313113
	optimizer: adam


2024-09-24 13:47:19,317 - wandb.wandb_agent - ERROR - Detected 5 failed runs in a row, shutting down.


  0%|          | 0/10 [00:00<?, ?it/s]

  context_layer = torch.nn.functional.scaled_dot_product_attention(


KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out inference split. The inference DataLoader is only
# created when `inference_subset` is truthy, so guard on the same condition:
# otherwise this cell would fail with a NameError and the W&B run below would
# never be closed.
if inference_subset:
    inference_loop(model=model,
                   device=device,
                   loss_fn=loss_fn,
                   inference_loader=inference_dataloader_custom)
else:
    print("[INFO] No inference subset available; skipping inference_loop.")

# Close the W&B run unless further runs will reuse it or logging was never on.
if not multiple_runs and wandb_init:
    wandb.finish()

In [9]:
# Hand the model to save_model() only when saving is switched on in the
# configuration cell.
if SAVE_MODEL:
    save_kwargs = {
        "target_dir": "saved_models",
        "model_name": "AST_classifier_true.pt",
    }
    save_model(model=model, **save_kwargs)

[INFO] Saving model to: saved_models\AST_classifier_true.pt
