In [35]:
#IMPORTS
#built-in
import argparse
import ast
import os
import pickle
import sys
import time

#third party
import numpy as np
import pandas as pd
import torch

from google.cloud import storage, bigquery
from torch.utils.data import WeightedRandomSampler

#local
from dataloader_mayo import AudioDataset
from models import ASTModel_pretrain, ASTModel_finetune
#from traintest_mayo import train, validate
#from traintest_mask_mayo import trainmask


## Load data from google cloud storage

In [15]:
# Google Cloud project / bucket that hold the study's audio files.
project_name = 'ml-mps-aif-afdgpet01-p-6827'
study = 'speech_poc_freeze_1'
bucket_name = 'ml-e107-phi-shared-aif-us-p'
gcs_prefix = f'speech_ai/speech_lake/{study}'

# Clients for Cloud Storage and BigQuery, plus a handle on the audio bucket.
storage_client = storage.Client(project=project_name)
bq_client = bigquery.Client(project=project_name)
bucket = storage_client.bucket(bucket_name)

# Name of every blob stored under the study prefix.
file_list = [blob.name for blob in storage_client.list_blobs(bucket_name, prefix=gcs_prefix)]

# Text after the last '.' of each blob name (the whole name when it has no '.').
# Built once after the listing is complete: it used to be rebuilt inside the
# listing loop on every iteration and was left undefined when nothing matched.
extensions = [f.split('.')[-1] for f in file_list]

# Train / test split tables.
data_split_root = 'gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620'
gcs_train_path = f'{data_split_root}/train.csv'
gcs_test_path = f'{data_split_root}/test.csv'

In [16]:
# Step 1 -- read the train and test split tables, indexed by their 'uid' column
train_df, test_df = (
    pd.read_csv(split_path, index_col='uid')
    for split_path in (gcs_train_path, gcs_test_path)
)

# Step 2 -- derive a binary "distortions" column: 1 when the sum of the
# "distorted Cs" and "distorted V" columns is positive, else 0
for split_df in (train_df, test_df):
    distortion_total = split_df["distorted Cs"] + split_df["distorted V"]
    split_df["distortions"] = (distortion_total > 0).astype(int)

# Step 3 -- columns used as the multi-label classification targets
target_labels = [
    'breathy',
    'loudness decay',
    'slow rate',
    'high pitch',
    'hoarse / harsh',
    'irregular artic breakdowns',
    'rapid rate',
    'reduced OA loudness',
    'abn pitch variability',
    'strained',
    'hypernasal',
    'abn loudness variability',
    'distortions',
]

## Set additional variables for running SSAST and set up audio configurations

In [18]:
#audio configuration
dataset = 'retrospeech'
#audio
resample_rate = 16000
reduce = True
clip_length = 0
#audio augmentations
tshift = 0 #time shift
speed = 0
gauss = 0 #amt noise
pshift = 0 #pitch shift
pshiftn = 0 #pitch shift n steps
gain = 0
stretch = 0
#spectrogram
dataset_mean = -4.2677393
dataset_std = 4.5689974
target_length = 1024
num_mel_bins = 128
freqm = 0
timem = 0
mixup = 0
noise = False
# For reference: an earlier experiment used tshift=0.9, gauss_noise=0.8 and
# gain=0.9 with dataset 'demo'; every other value matched the settings above.


def _make_audio_conf(mode):
    """Build the audio configuration dict for the given dataset ``mode``.

    Every entry is taken from the variables defined above, so the train and
    evaluation configurations cannot drift apart; only 'mode' differs.
    'clip_length' now follows the ``clip_length`` variable rather than a
    separately hard-coded 0.
    """
    return {'dataset': dataset, 'mode': mode, 'resample_rate': resample_rate, 'reduce': reduce, 'clip_length': clip_length,
            'tshift': tshift, 'speed': speed, 'gauss_noise': gauss, 'pshift': pshift, 'pshiftn': pshiftn, 'gain': gain, 'stretch': stretch,
            'num_mel_bins': num_mel_bins, 'target_length': target_length, 'freqm': freqm, 'timem': timem, 'mixup': mixup, 'noise': noise,
            'mean': dataset_mean, 'std': dataset_std}


train_audio_conf = _make_audio_conf('train')
eval_audio_conf = _make_audio_conf('evaluation')

## Initialize dataset

In [36]:
# Build the training dataset from the train split, the target label columns, the
# training audio configuration and the GCS prefix / bucket defined earlier.
train_dataset = AudioDataset(train_df, target_labels, train_audio_conf, gcs_prefix, bucket) # TODO(review): original note says "librosa = True" might need debugging here

# Same construction for the test split, using the evaluation audio configuration.
test_dataset = AudioDataset(test_df, target_labels, eval_audio_conf, gcs_prefix, bucket)
# TODO: optionally build a validation dataset as well

---------------the train dataloader---------------
now process retrospeech
now using following mask: 0 freq, 0 time
MIXUP NOT CURRENTLY AVAILABLE
use dataset mean -4.268 and std 4.569 to normalize the input.
number of classes is 13
---------------the evaluation dataloader---------------
now process retrospeech
now using following mask: 0 freq, 0 time
MIXUP NOT CURRENTLY AVAILABLE
use dataset mean -4.268 and std 4.569 to normalize the input.
number of classes is 13


# RUNNING SSAST

In [39]:
# DataLoader settings shared by the train and eval loaders.
batch_size, num_workers = 1, 0

In [49]:
# Pinned host memory only pays off when batches are later copied to a CUDA device,
# so enable it only when one is available and use the same setting for both loaders
# (previously the train loader used False and the eval loader True).
pin_memory = torch.cuda.is_available()

# Shuffled training batches; a trailing incomplete batch is dropped.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

# TODO: eventually add an optional validation loader
eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

In [28]:
# Pull a single batch from the training loader so its contents can be inspected.
batch = next(iter(train_loader))

In [23]:
# Display the current value of `batch`.
batch

{'waveform': tensor([[[ 0.0086,  0.0086,  0.0086,  ..., -0.0047, -0.0044,  0.0023]]]),
 'targets': tensor([[0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.]]),
 'sample_rate': tensor([16000]),
 'fbank': tensor([[[-1.2776, -1.2776, -1.2776,  ..., -1.2776, -1.2776, -1.2776],
          [-1.2776, -1.2776, -1.2776,  ..., -1.2776, -1.2776, -1.2776],
          [-1.2776, -1.2776, -1.2776,  ..., -1.0430, -1.0535, -1.0200],
          ...,
          [-0.7442, -1.1240, -0.7471,  ..., -0.4353, -0.5055, -0.3359],
          [-1.2776, -1.2776, -1.1366,  ..., -0.3679, -0.5114, -0.4016],
          [-1.0614, -1.2776, -1.1437,  ..., -0.4080, -0.4991, -0.4675]]])}

# FINETUNING THE WAY WE WANT
TODO: accept a flag that controls whether the pretrained model is frozen during fine-tuning.

In [44]:
# Location of the self-supervised pretrained SSAST checkpoint. Set the
# SSAST_PRETRAINED_PATH environment variable to point somewhere else; the
# fallback is the path this notebook was originally run with.
pretrained_mdl_path = os.environ.get(
    'SSAST_PRETRAINED_PATH',
    '/Users/m144443/Documents/mayo_ssast/pretrained_model/SSAST-Base-Frame-400.pth')

# Fine-tuning model with one output per target label. The input frequency
# dimension follows num_mel_bins (128) so it cannot drift from the audio
# configuration used to build the datasets.
ast_mdl = ASTModel_finetune(
    task='ft_cls', label_dim=len(target_labels), fshape=128, tshape=2, fstride=128, tstride=2,
    input_fdim=num_mel_bins, input_tdim=target_length,
    model_size='base', load_pretrained_mdl_path=pretrained_mdl_path)

now load a SSL pretrained models from /Users/m144443/Documents/mayo_ssast/pretrained_model/SSAST-Base-Frame-400.pth
pretraining patch split stride: frequency=128, time=2
pretraining patch shape: frequency=128, time=2
pretraining patch array dimension: frequency=1, time=512
pretraining number of patches=512
fine-tuning patch split stride: frequncey=128, time=2
fine-tuning number of patches=512




In [45]:
#FREEZE THE MODEL (only finetuning classifier head)
# Turn off gradients for every parameter under `ast_mdl.v`.
for param in ast_mdl.v.parameters():
    param.requires_grad = False

# Parameters that will still be updated. Kept as a list: the previous one-shot
# `filter` iterator was empty after the sum below.
model_parameters = [p for p in ast_mdl.parameters() if p.requires_grad]
# numel() gives the element count directly (same idiom as train()).
params = sum(p.numel() for p in model_parameters)
print(f'Number of trainable parameters: {params}')

Number of trainable parameters: 11533


Basic training loop (no learning-rate scheduler and no validation yet)

In [31]:
# Loss applied to the raw model outputs and the multi-hot targets.
criterion = torch.nn.BCEWithLogitsLoss()
# AdamW (default hyper-parameters) over the parameters that still require gradients.
optim = torch.optim.AdamW(
    filter(lambda weight: weight.requires_grad, ast_mdl.parameters())
)

In [46]:
epochs = 1 #for now just 1 epoch
n_train_batches = len(train_loader)
for e in range(epochs):
    # Make sure the model is in training mode: the evaluation cell switches it to
    # eval mode, so re-running this cell afterwards would otherwise train in eval mode.
    ast_mdl.train()
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        x = batch['fbank']
        targets = batch['targets'] #have to change to select targets like this
        optim.zero_grad()
        o = ast_mdl(x) #no need for task + give just fbank
        loss = criterion(o, targets)
        loss.backward()
        optim.step()
        running_loss += loss.item()
        # (i + 1) so the counter reaches 100% on the final batch
        print(f'Progress: {round((i + 1)/n_train_batches*100)}%    ',end='\r')

    # epoch index and mean training loss per batch
    print(e, running_loss/n_train_batches)

0 0.6259236772253056


In [47]:
# Save the model's state_dict to a file in the current working directory.
torch.save(ast_mdl.state_dict(), 'ast_mdl_base_frame_400_speechfeat_13_adamw_1epoch.pt')

In [50]:
# Collect raw model outputs and targets over the whole evaluation loader.
ast_mdl.eval()
all_preds = []
all_targets = []
n_eval_batches = len(eval_loader)
# Inference only: without no_grad() every stored output keeps its autograd graph
# alive, so memory grows with the size of the evaluation set. The stray
# optim.zero_grad() call was dropped because no gradients are computed here.
with torch.no_grad():
    for i, batch in enumerate(eval_loader):
        x = batch['fbank']
        targets = batch['targets']
        o = ast_mdl(x)
        all_preds.append(o)
        all_targets.append(targets)
        # (i + 1) so the counter reaches 100% on the final batch
        print(f'Progress: {round((i + 1)/n_eval_batches*100)}%    ',end='\r')

Progress: 99%    

In [51]:
#simple metrics
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt

# Stack the per-batch outputs, map them through a sigmoid, and move everything
# to NumPy (via .cpu() so this also works if tensors live on a GPU).
pred_mat = torch.sigmoid(torch.cat(all_preds)).detach().cpu().numpy()
target_mat = torch.cat(all_targets).detach().cpu().numpy()

# average=None returns one ROC AUC per label column instead of a single mean.
aucs = roc_auc_score(target_mat, pred_mat, average=None)
print(aucs)

# Per-label AUC table. The unused `data` list that paired 'AUC' with the label
# names instead of the scores has been removed.
pd.DataFrame({'Label': target_labels, 'AUC': aucs})


[0.39930556 0.51215686 0.60547504 0.38705739 0.5703125  0.53494505
 0.57765152 0.66666667 0.44106667 0.47769622 0.27272727 0.42365967
 0.57014157]


Unnamed: 0,Label,AUC
0,breathy,0.399306
1,loudness decay,0.512157
2,slow rate,0.605475
3,high pitch,0.387057
4,hoarse / harsh,0.570312
5,irregular artic breakdowns,0.534945
6,rapid rate,0.577652
7,reduced OA loudness,0.666667
8,abn pitch variability,0.441067
9,strained,0.477696


In [None]:
import sys
import os
import datetime
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from src.utilities import *
import time
import torch
from torch import nn
import numpy as np
import pickle
from torch.cuda.amp import autocast,GradScaler

In [None]:
def _unpack_batch(batch):
    """Return ``(audio_input, labels)`` from either batch layout a loader may yield.

    Dict batches are read through their 'fbank' and 'targets' entries (the layout
    produced by the AudioDataset loaders in this notebook); anything else is
    unpacked as a two-item ``(audio_input, labels)`` sequence.
    """
    if isinstance(batch, dict):
        return batch['fbank'], batch['targets']
    audio_input, labels = batch
    return audio_input, labels


def train(audio_model, train_loader, test_loader, args):
    """Train ``audio_model`` on ``train_loader`` for ``args.n_epochs`` epochs.

    The model is wrapped in ``nn.DataParallel`` (unless it already is), moved to
    CUDA when available, and optimised with Adam using two parameter groups: the
    ``mlp_head`` parameters (learning rate ``args.lr * args.head_lr``) and all
    remaining parameters (learning rate ``args.lr``). After every epoch the model
    weights, ``result.csv`` and ``progress.pkl`` are written under ``args.exp_dir``.

    Per-epoch validation is currently disabled (the code is kept below inside a
    string literal). While it is disabled, every validation-derived value is
    recorded as NaN, best-checkpoint selection and the stats pickle are skipped,
    and an adaptive (ReduceLROnPlateau) scheduler is stepped on the negated
    training loss. Previously these steps read names that only the disabled
    validation code defines, so the first epoch ended in a NameError.

    Args:
        audio_model: torch module to train; called as ``audio_model(audio_input)``.
        train_loader: loader yielding either ``(audio_input, labels)`` pairs or
            dicts with 'fbank' and 'targets' entries. Must support ``len()`` and
            expose ``.dataset``.
        test_loader: loader intended for validation; unused while validation is
            disabled.
        args: namespace providing ``exp_dir``, ``lr``, ``head_lr``, ``adaptschedule``,
            ``lr_patience``, ``lrscheduler_start``, ``lrscheduler_step``,
            ``lrscheduler_decay``, ``metrics``, ``loss``, ``dataset``, ``n_epochs``,
            ``warmup`` and ``n_print_steps``. ``args.loss_fn`` is set as a side effect.

    Raises:
        ValueError: if ``args.loss`` is neither 'BCE' nor 'CE'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('running on ' + str(device))
    torch.set_grad_enabled(True)

    # Initialize all of the statistics we want to keep track of
    batch_time = AverageMeter()
    per_sample_time = AverageMeter()
    data_time = AverageMeter()
    per_sample_data_time = AverageMeter()
    loss_meter = AverageMeter()
    per_sample_dnn_time = AverageMeter()
    progress = []
    # best_cum_mAP is checkpoint ensemble from the first epoch to the best epoch
    best_epoch, best_cum_epoch, best_mAP, best_acc, best_cum_mAP = 0, 0, -np.inf, -np.inf, -np.inf
    global_step, epoch = 0, 0
    start_time = time.time()
    exp_dir = args.exp_dir

    def _save_progress():
        # Append the current counters to the history and rewrite the pickle.
        progress.append([epoch, global_step, best_epoch, best_mAP,
                time.time() - start_time])
        with open("%s/progress.pkl" % exp_dir, "wb") as f:
            pickle.dump(progress, f)

    if not isinstance(audio_model, nn.DataParallel):
        audio_model = nn.DataParallel(audio_model)

    audio_model = audio_model.to(device)
    # Set up the optimizer
    trainables = [p for p in audio_model.parameters() if p.requires_grad]
    print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in audio_model.parameters()) / 1e6))
    print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))

    # diff lr optimizer: the mlp head gets its own parameter group / learning rate
    mlp_list = ['mlp_head.0.weight', 'mlp_head.0.bias', 'mlp_head.1.weight', 'mlp_head.1.bias']
    mlp_params = list(filter(lambda kv: kv[0] in mlp_list, audio_model.module.named_parameters()))
    base_params = list(filter(lambda kv: kv[0] not in mlp_list, audio_model.module.named_parameters()))
    mlp_params = [i[1] for i in mlp_params]
    base_params = [i[1] for i in base_params]
    # only finetuning small/tiny models on balanced audioset uses different learning rate for mlp head
    print('The mlp header uses {:d} x larger lr'.format(args.head_lr))
    optimizer = torch.optim.Adam([{'params': base_params, 'lr': args.lr}, {'params': mlp_params, 'lr': args.lr * args.head_lr}], weight_decay=5e-7, betas=(0.95, 0.999))
    mlp_lr = optimizer.param_groups[1]['lr']
    # Target learning rate of each parameter group, used by the warm-up below.
    lr_list = [args.lr, mlp_lr]

    print('Total mlp parameter number is : {:.3f} million'.format(sum(p.numel() for p in mlp_params) / 1e6))
    print('Total base parameter number is : {:.3f} million'.format(sum(p.numel() for p in base_params) / 1e6))

    # Scheduler, main metric and loss are selected through `args`; the earlier
    # per-dataset hard-coded settings (audioset / esc50 / speechcommands) were removed.
    if args.adaptschedule == True:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args.lr_patience, verbose=True)
        print('now use adaptive learning rate scheduler.')
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args.lrscheduler_start, 1000, args.lrscheduler_step)),gamma=args.lrscheduler_decay)
    main_metrics = args.metrics
    if args.loss == 'BCE':
        loss_fn = nn.BCEWithLogitsLoss()
    elif args.loss == 'CE':
        loss_fn = nn.CrossEntropyLoss()
    else:
        # Fail with a clear message instead of an unbound `loss_fn` further down.
        raise ValueError('unknown loss {!r}, loss should be in [BCE, CE]'.format(args.loss))
    args.loss_fn = loss_fn

    print('now training with {:s}, main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'.format(str(args.dataset), str(main_metrics), str(loss_fn), str(scheduler)))
    print('The learning rate scheduler starts at {:d} epoch with decay rate of {:.3f} every {:d} epoches'.format(args.lrscheduler_start, args.lrscheduler_decay, args.lrscheduler_step))

    # Epochs are counted from 1.
    epoch += 1

    print("current #steps=%s, #epochs=%s" % (global_step, epoch))
    print("start training...")
    # One row per epoch: [main metric, mAUC, avg precision, avg recall, d_prime,
    # train loss, valid loss, cumulative main metric, cumulative mAUC, lr]
    result = np.zeros([args.n_epochs, 10])
    audio_model.train()
    while epoch < args.n_epochs + 1:
        begin_time = time.time()
        end_time = time.time()
        audio_model.train()
        print('---------------')
        print(datetime.datetime.now())
        print("current #epochs=%s, #steps=%s" % (epoch, global_step))

        # Placeholders for the validation-derived values; they stay NaN / None
        # unless validation is re-enabled and overwrites them.
        stats = None
        mAP = mAUC = acc = average_precision = average_recall = np.nan
        valid_loss = cum_mAP = cum_mAUC = cum_acc = np.nan

        for i, batch in enumerate(train_loader):
            # Accept both dict batches and (input, label) pairs.
            audio_input, labels = _unpack_batch(batch)

            B = audio_input.size(0)
            audio_input = audio_input.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            data_time.update(time.time() - end_time)
            per_sample_data_time.update((time.time() - end_time) / audio_input.shape[0])
            dnn_start_time = time.time()

            # first several steps for warm-up: every 50 steps up to step 1000 the
            # learning rate of each group is set to (step / 1000) of its target
            if global_step <= 1000 and global_step % 50 == 0 and args.warmup == True:
                for group_id, param_group in enumerate(optimizer.param_groups):
                    warm_lr = (global_step / 1000) * lr_list[group_id]
                    param_group['lr'] = warm_lr
                    print('warm-up learning rate is {:f}'.format(param_group['lr']))

            audio_output = audio_model(audio_input)
            if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                # CrossEntropyLoss takes class indices, so reduce the label rows with argmax
                loss = loss_fn(audio_output, torch.argmax(labels.long(), axis=1))
            else:
                loss = loss_fn(audio_output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # record loss
            loss_meter.update(loss.item(), B)
            batch_time.update(time.time() - end_time)
            per_sample_time.update((time.time() - end_time)/audio_input.shape[0])
            per_sample_dnn_time.update((time.time() - dnn_start_time)/audio_input.shape[0])

            print_step = global_step % args.n_print_steps == 0
            # NOTE: epoch starts at 1 here, so this early-print condition is never true.
            early_print_step = epoch == 0 and global_step % (args.n_print_steps/10) == 0
            print_step = print_step or early_print_step

            if print_step and global_step != 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                  'Per Sample Total Time {per_sample_time.avg:.5f}\t'
                  'Per Sample Data Time {per_sample_data_time.avg:.5f}\t'
                  'Per Sample DNN Time {per_sample_dnn_time.avg:.5f}\t'
                  'Train Loss {loss_meter.avg:.4f}\t'.format(
                   epoch, i, len(train_loader), per_sample_time=per_sample_time, per_sample_data_time=per_sample_data_time,
                      per_sample_dnn_time=per_sample_dnn_time, loss_meter=loss_meter), flush=True)
                if np.isnan(loss_meter.avg):
                    # Dump model, optimizer and the offending batch for debugging;
                    # training continues afterwards (the early return is disabled).
                    print("yuan training diverged...")
                    torch.save(audio_model.state_dict(), "%s/models/nan_audio_model.pth" % (exp_dir))
                    torch.save(optimizer.state_dict(), "%s/models/nan_optim_state.pth" % (exp_dir))
                    with open(exp_dir + '/audio_input.npy', 'wb') as f:
                        np.save(f, audio_input.cpu().detach().numpy())
                    np.savetxt(exp_dir + '/audio_output.csv', audio_output.cpu().detach().numpy(), delimiter=',')
                    np.savetxt(exp_dir + '/labels.csv', labels.cpu().detach().numpy(), delimiter=',')
                    print('audio output and label saved for debugging.')
                    #return

            end_time = time.time()
            global_step += 1

        # Disabled validation step, kept for reference. When re-enabled it must
        # assign `stats` and the metric variables initialised at the top of the epoch.
        '''
        if not skip_validation:
            print('start validation')
            stats, valid_loss = validate(audio_model, test_loader, args, epoch)

            # ensemble results
            cum_stats = validate_ensemble(args, epoch)
            cum_mAP = np.mean([stat['AP'] for stat in cum_stats])
            cum_mAUC = np.mean([stat['auc'] for stat in cum_stats])
            cum_acc = cum_stats[0]['acc']

            mAP = np.mean([stat['AP'] for stat in stats])
            mAUC = np.mean([stat['auc'] for stat in stats])
            acc = stats[0]['acc']

            middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats]
            middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats]
            average_precision = np.mean(middle_ps)
            average_recall = np.mean(middle_rs)

            if main_metrics == 'mAP':
                print("mAP: {:.6f}".format(mAP))
            else:
                print("acc: {:.6f}".format(acc))
            print("AUC: {:.6f}".format(mAUC))
            print("Avg Precision: {:.6f}".format(average_precision))
            print("Avg Recall: {:.6f}".format(average_recall))
            print("d_prime: {:.6f}".format(d_prime(mAUC)))
            print("train_loss: {:.6f}".format(loss_meter.avg))
            print("valid_loss: {:.6f}".format(valid_loss))
        '''

        has_validation = stats is not None
        # d_prime is only computed when a real mAUC is available.
        d_prime_value = d_prime(mAUC) if has_validation else np.nan

        if main_metrics == 'mAP':
            result[epoch-1, :] = [mAP, mAUC, average_precision, average_recall, d_prime_value, loss_meter.avg, valid_loss, cum_mAP, cum_mAUC, optimizer.param_groups[0]['lr']]
        else:
            result[epoch-1, :] = [acc, mAUC, average_precision, average_recall, d_prime_value, loss_meter.avg, valid_loss, cum_acc, cum_mAUC, optimizer.param_groups[0]['lr']]
        np.savetxt(exp_dir + '/result.csv', result, delimiter=',')

        if has_validation:
            print('validation finished')

            if mAP > best_mAP:
                best_mAP = mAP
                if main_metrics == 'mAP':
                    best_epoch = epoch

            if acc > best_acc:
                best_acc = acc
                if main_metrics == 'acc':
                    best_epoch = epoch

            if cum_mAP > best_cum_mAP:
                best_cum_epoch = epoch
                best_cum_mAP = cum_mAP

            if best_epoch == epoch:
                torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir))
                torch.save(optimizer.state_dict(), "%s/models/best_optim_state.pth" % (exp_dir))
        else:
            print('validation skipped: validation metrics recorded as NaN, no best model selected')

        # save every models
        torch.save(audio_model.state_dict(), "%s/models/audio_model.%d.pth" % (exp_dir, epoch))
        if len(train_loader.dataset) > 2e5:
            torch.save(optimizer.state_dict(), "%s/models/optim_state.%d.pth" % (exp_dir, epoch))

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            print('adaptive learning rate scheduler step')
            # The scheduler runs in 'max' mode. Without validation there is no mAP,
            # and a NaN metric would count every epoch as "no improvement", so fall
            # back to the negated training loss (higher is better).
            scheduler.step(mAP if has_validation else -loss_meter.avg)
        else:
            print('normal learning rate scheduler step')
            scheduler.step()

        print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
        print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[1]['lr']))

        # Per-epoch validation statistics exist only when validation ran.
        if has_validation:
            with open(exp_dir + '/stats_' + str(epoch) +'.pickle', 'wb') as handle:
                pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
        _save_progress()

        finish_time = time.time()
        print('epoch {:d} training time: {:.3f}'.format(epoch, finish_time-begin_time))

        epoch += 1

        # break if lr too small
        # if optimizer.param_groups[0]['lr'] < args.lr/64 and epoch > 10:
        #     break

        # Start the next epoch with fresh running statistics.
        batch_time.reset()
        per_sample_time.reset()
        data_time.reset()
        per_sample_data_time.reset()
        loss_meter.reset()
        per_sample_dnn_time.reset()
    # Disabled weight-averaging evaluation, kept for reference.
    '''
    if args.wa == True:
        stats = validate_wa(audio_model, test_loader, args, args.wa_start, args.wa_end)
        mAP = np.mean([stat['AP'] for stat in stats])
        mAUC = np.mean([stat['auc'] for stat in stats])
        middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats]
        middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats]
        average_precision = np.mean(middle_ps)
        average_recall = np.mean(middle_rs)
        wa_result = [mAP, mAUC, average_precision, average_recall, d_prime(mAUC)]
        print('---------------Training Finished---------------')
        print('weighted averaged models results')
        print("mAP: {:.6f}".format(mAP))
        print("AUC: {:.6f}".format(mAUC))
        print("Avg Precision: {:.6f}".format(average_precision))
        print("Avg Recall: {:.6f}".format(average_recall))
        print("d_prime: {:.6f}".format(d_prime(mAUC)))
        print("train_loss: {:.6f}".format(loss_meter.avg))
        print("valid_loss: {:.6f}".format(valid_loss))
        np.savetxt(exp_dir + '/wa_result.csv', wa_result)
    '''

def validate_ensemble(args, epoch):
    """Score the running average of the per-epoch prediction files.

    Reads ``target.csv`` and the prediction CSVs from ``<args.exp_dir>/predictions``.
    For ``epoch == 1`` the average is just ``predictions_1.csv``. For later epochs
    the stored ``cum_predictions.csv`` (the mean over the previous ``epoch - 1``
    epochs) is scaled back to a sum, the current epoch's predictions are added,
    and the previous epoch's prediction file is deleted to save storage. The new
    mean is written back to ``cum_predictions.csv``.

    Returns:
        Whatever ``calculate_stats(cum_predictions, target)`` returns.
    """
    pred_dir = args.exp_dir + '/predictions'

    def _read_csv(file_name):
        # Load one comma-separated matrix from the predictions directory.
        return np.loadtxt(pred_dir + '/' + file_name, delimiter=',')

    target = _read_csv('target.csv')

    if epoch != 1:
        prediction_sum = _read_csv('cum_predictions.csv') * (epoch - 1)
        prediction_sum = prediction_sum + _read_csv('predictions_' + str(epoch) + '.csv')
        # the previous epoch's predictions are no longer needed
        os.remove(pred_dir + '/predictions_' + str(epoch - 1) + '.csv')
    else:
        prediction_sum = _read_csv('predictions_1.csv')

    ensemble_mean = prediction_sum / epoch
    np.savetxt(pred_dir + '/cum_predictions.csv', ensemble_mean, delimiter=',')

    return calculate_stats(ensemble_mean, target)


In [None]:
#set arguments for running pre-training/fine-tuning
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# data
parser.add_argument("--data-train", type=str, default='train_ssast.json', help="training data json")
parser.add_argument("--data-val", type=str, default=None, help="validation data json")
parser.add_argument("--data-eval", type=str, default='test_ssast.json', help="evaluation data json")
parser.add_argument("--label-csv", type=str, default='./label_df.csv', help="csv with class labels")
parser.add_argument("--n_class", type=int, default=len(target_labels), help="number of classes")

# input / spectrogram
parser.add_argument("--dataset", type=str, default='demo', help="the dataset used for training")
parser.add_argument("--dataset_mean", type=float, default= -4.2677393, help="the dataset mean, used for input normalization")
parser.add_argument("--dataset_std", type=float, default=4.5689974, help="the dataset std, used for input normalization")
parser.add_argument("--target_length", type=int, default=1024, help="the input length in frames")
parser.add_argument("--num_mel_bins", type=int, default=128, help="number of input mel bins")

# optimisation
parser.add_argument("--exp-dir", type=str, default="", help="directory to dump experiments")
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate')
# boolean-like options are parsed with ast.literal_eval; argparse also applies
# `type` to string defaults, so 'True' / 'False' end up as real booleans
parser.add_argument('--warmup', help='if use warmup learning rate scheduler', type=ast.literal_eval, default='True')
parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["sgd", "adam"])
parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N', help='mini-batch size')
# help text no longer claims "(default: 32)"; the formatter appends the real default (8)
parser.add_argument('-w', '--num-workers', default=8, type=int, metavar='NW', help='# of workers for dataloading')
parser.add_argument("--n-epochs", type=int, default=80, help="number of maximum training epochs")
# only used in pretraining stage or from-scratch fine-tuning experiments
parser.add_argument("--lr_patience", type=int, default=2, help="how many epoch to wait to reduce lr if mAP doesn't improve")
parser.add_argument('--adaptschedule', help='if use adaptive scheduler ', type=ast.literal_eval, default='False')

parser.add_argument("--n-print-steps", type=int, default=100, help="number of steps to print statistics")
parser.add_argument('--save_model', help='save the models or not', type=ast.literal_eval, default='True')

# augmentation
parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0)
parser.add_argument('--timem', help='time mask max length', type=int, default=0)
parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training")
parser.add_argument("--bal", type=str, default=None, help="use balanced sampling or not")

# the stride used in patch splitting, e.g., for patch size 16*16, a stride of 16 means no overlapping, a stride of 10 means overlap of 6.
# during self-supervised pretraining stage, no patch split overlapping is used (to avoid shortcuts), i.e., fstride=fshape and tstride=tshape
# during fine-tuning, using patch split overlapping (i.e., smaller {f,t}stride than {f,t}shape) improves the performance.
# it is OK to use different {f,t}stride in pretraining and finetuning stages (though fstride is better to keep the same)
# but {f,t}stride in pretraining and finetuning stages must be consistent.
# NOTE(review): the last line contradicts the one before it; it presumably means {f,t}shape must be consistent -- confirm.
parser.add_argument("--fstride", type=int,default=128,help="soft split freq stride, overlap=patch_size-stride")
parser.add_argument("--tstride", type=int,default=2, help="soft split time stride, overlap=patch_size-stride")
# fixed keyword typo: `defaut=` made add_argument raise TypeError
parser.add_argument("--fshape", type=int, default=128, help="shape of patch on the frequency dimension")
parser.add_argument("--tshape", type=int, default=2, help="shape of patch on the time dimension")
parser.add_argument('--model_size', help='the size of AST models', type=str, default='base')

parser.add_argument("--task", type=str, default='ft_cls', help="pretraining or fine-tuning task", choices=["ft_avgtok", "ft_cls", "pretrain_mpc", "pretrain_mpg", "pretrain_joint"])

# pretraining arguments
#parser.add_argument('--pretrain_stage', help='True for self-supervised pretraining stage, False for fine-tuning stage', type=ast.literal_eval, default='False')
parser.add_argument('--mask_patch', help='how many patches to mask (used only for ssl pretraining)', type=int, default=400)
parser.add_argument("--cluster_factor", type=int, default=3, help="mask clutering factor")
parser.add_argument("--epoch_iter", type=int, default=2000, help="for pretraining, how many iterations to verify and save models")

# fine-tuning arguments
parser.add_argument("--pretrained_mdl_path", type=str, default='./pretrained_model/SSAST-Base-Frame-400.pth', help="the ssl pretrained models path")
parser.add_argument("--head_lr", type=int, default=1, help="the factor of mlp-head_lr/lr, used in some fine-tuning experiments only")
parser.add_argument("--noise", help='if augment noise in finetuning', type=ast.literal_eval, default='False')
parser.add_argument("--metrics", type=str, default="mAP", help="the main evaluation metrics in finetuning", choices=["mAP", "acc"])
parser.add_argument("--lrscheduler_start", default=10, type=int, help="when to start decay in finetuning")
parser.add_argument("--lrscheduler_step", default=5, type=int, help="the number of step to decrease the learning rate in finetuning")
parser.add_argument("--lrscheduler_decay", default=0.5, type=float, help="the learning rate decay ratio in finetuning")
parser.add_argument("--wa", help='if do weight averaging in finetuning', type=ast.literal_eval, default='False')
parser.add_argument("--wa_start", type=int, default=16, help="which epoch to start weight averaging in finetuning")
parser.add_argument("--wa_end", type=int, default=30, help="which epoch to end weight averaging in finetuning")
parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["BCE", "CE"])

# parse_known_args() instead of parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own arguments (e.g. "-f <connection file>"), which
# parse_args() rejects. Unrecognised arguments are collected instead of aborting.
args, unknown_args = parser.parse_known_args()