# Audio Classification with TART!!!

In [None]:
# Load IPython's autoreload extension so imported modules can be reloaded without restarting the kernel.
%load_ext autoreload
# NOTE(review): reloading the extension right after loading it looks redundant — confirm it is needed.
%reload_ext autoreload

In [None]:
# Mode 2: reload all imported modules before each cell executes, so local source edits take effect immediately.
%autoreload 2

In [None]:
import os
import sys

import torch

# Put the parent of the current working directory on sys.path.
# Presumably this lets the `tart` imports below resolve when the notebook is run from a subdirectory of the repo — TODO confirm.
sys.path.append(f'{os.path.dirname(os.getcwd())}/')
import warnings
import yaml

from tart.tart_modules import Tart
from tart.registry import DATASET_REGISTRY

# Silence UserWarning and FutureWarning output for the rest of the session to keep cell output readable.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# NOTE(review): the blanket UserWarning filter above has no module restriction, so this tqdm-specific filter looks redundant.
warnings.filterwarnings("ignore", category=UserWarning, module='tqdm')

### !! Running this notebook with a pre-trained TART head !!
* Download the [pre-trained TART reasoning module](https://github.com/HazyResearch/TART/releases/download/reasoning_module/tart_heads.zip) (see the download cell below).

* Set `path_tart_weights` in the configuration cell below to the location of the downloaded module.

In [None]:
! wget https://github.com/HazyResearch/TART/releases/download/reasoning_module/tart_heads.zip
! unzip tart_heads.zip

In [None]:
#### CUSTOMIZE AS NEEDED ####
# PATH to the pretrained weights downloaded above.
path_tart_weights = "/u/scr/nlp/data/ic-fluffy-head-k-2/3e9724ed-5a49-4070-9b7d-4209a30e2392"
# If you are using the pre-trained module above, don't change this!
path_tart_config = "tart_conf.yaml"
# Cache directory for downloaded artifacts.
cache_dir = "/u/scr/nlp/data/neo/hub"
# No local data directory is configured by default.
data_dir_path = None

### Step #1: Set up TART
* To set up TART, we first load the TART reasoning module and then load the base embedding model.

In [None]:
# Base embedding model and the method used to embed inputs with it.
BASE_EMBED_MODEL = "openai/whisper-large"
EMBED_METHOD = "stream"
# Checkpoint file of the pre-trained TART head inside the weights directory.
PATH_TO_PRETRAINED_HEAD = f"{path_tart_weights}/model_24000.pt"
# Read the TART head config inside a context manager so the file handle is closed
# as soon as parsing finishes (a bare open() would leave it open for the whole session).
# NOTE: FullLoader is acceptable for a trusted local config; prefer yaml.safe_load
# if the config could ever come from an untrusted source.
with open(path_tart_config, "r") as config_file:
    TART_CONFIG = yaml.load(config_file, Loader=yaml.FullLoader)
# Training-sample budget derived from the head's `n_positions`.
# NOTE(review): the reason for the -2 offset is not visible here — document it.
TOTAL_TRAIN_SAMPLES = TART_CONFIG['n_positions'] - 2
# None: no fine-tuned embedding model is supplied.
PATH_TO_FINETUNED_EMBED_MODEL = None
CACHE_DIR = cache_dir
NUM_PCA_COMPONENTS = 8
# Data domain; also used as the first key into DATASET_REGISTRY.
DOMAIN = "audio"


In [None]:
#### Instantiate TartModule
# Build the TART pipeline from the constants defined in the set-up cell.
# `domain` is taken from DOMAIN rather than a repeated "audio" literal, so the
# module stays in sync with the DATASET_REGISTRY[DOMAIN] lookup used to load data.
tart_module = Tart(
    embed_method=EMBED_METHOD,
    embed_model_name=BASE_EMBED_MODEL,
    path_to_pretrained_head=PATH_TO_PRETRAINED_HEAD,
    tart_head_config=TART_CONFIG,
    path_to_finetuned_embed_model=PATH_TO_FINETUNED_EMBED_MODEL,
    cache_dir=CACHE_DIR,
    num_pca_components=NUM_PCA_COMPONENTS,
    domain=DOMAIN,
)


### Step #2: Load in data...

In [None]:
#### CUSTOMIZE AS NEEDED ####
# Name of the dataset to load.
DATASET_NAME = "speech_commands"
# Seed shared by dataset sampling and evaluation.
seed = 42
# Class ids used as the positive / negative label of the binary task.
pos_class, neg_class = 0, 1
# k = number of in-context examples
k_range = [18, 32, 64, 128]
# Total number of samples to evaluate on.
max_eval_samples = 1000

Download the data from HF datasets and sample a class-balanced "train" set of ICL examples.

More concretely, slicing `X_train` with $k$ (i.e. `X_train[0:k]`) returns a list of train samples in which $k/2$ samples have a positive label and $k/2$ have a negative label.

In [None]:
# Resolve the dataset builder registered under this domain / dataset name.
dataset_builder = DATASET_REGISTRY[DOMAIN][DATASET_NAME]
dataset = dataset_builder(
    seed=seed,
    cache_dir=CACHE_DIR,
    k_range=k_range,
    total_train_samples=TOTAL_TRAIN_SAMPLES,
    max_eval_samples=max_eval_samples,
    pos_class=pos_class,
    neg_class=neg_class,
)

# `get_dataset` is read as an attribute (no call) and unpacked into the four splits.
X_train, y_train, X_test, y_test = dataset.get_dataset

### Step #3: Evaluate!

In [None]:
# Evaluate at every in-context budget k and keep each result keyed by k.
results_at_k = {}
for k in k_range:
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        result = tart_module.evaluate(X_train, y_train, X_test, y_test, k=k, seed=seed)
    results_at_k[k] = result
    print(f"Accuracy at {k} samples: {result['accuracy']}")

        