# SSAST for speech
All functions copied from this repo and then edited as indicated

https://github.com/YuanGongND/ssast
Additional authors: Matt, Daniela

Original ASTModel class was split into two - one for pretraining and one for finetuning

To begin, you will need access to google cloud storage bucket and the following packages must be installed on your system 

* opencv-python
* albumentations (may run into issues in AIF)
* librosa
* torch, torchvision, torchaudio

(can ignore the following if using AIF)
* google-cloud
* google-cloud-storage
* google-cloud-bigquery

If working on a local computer, you can run the following commands to gain access to the google storage bucket

```gcloud auth application-default login```
```gcloud auth application-defaul set-quota-project PROJECT_NAME```

In [39]:
# Import
import pandas as pd
import io
import numpy as np
import sys
import json
import torch
from google.cloud import storage, bigquery

In [40]:
import os
os.getcwd()

'/home/jupyter/ssast_mayo/ssast/src'

In [41]:
from models.ast_models import ASTModel_pretrain, ASTModel_finetune
from dataloader_gcs import AudioDataset

In [42]:
from utilities.ssast_utils import *
from utilities.speech_utils import *

In [43]:
# First, load data from google storage bucket

project_name = 'ml-mps-aif-afdgpet01-p-6827'
study = 'speech_poc_freeze_1'
bucket_name = 'ml-e107-phi-shared-aif-us-p'
gcs_prefix = f'speech_ai/speech_lake/{study}'

storage_client = storage.Client(project=project_name)
bq_client = bigquery.Client(project=project_name)
bucket = storage_client.bucket(bucket_name)

file_list=[]
for blob in storage_client.list_blobs(bucket_name, prefix='speech_ai/speech_lake/speech_poc_freeze_1'):
    file_list.append(blob.name)

    extensions=[f.split('.')[-1] for f in file_list]

data_split_root = 'gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/amr_subject_dedup_594_train_100_test_binarized_v20220620'
gcs_train_path = f'{data_split_root}/train.csv'
gcs_test_path = f'{data_split_root}/test.csv'

In [44]:

# (1) load the train and test files to a df
train_df = pd.read_csv(gcs_train_path, index_col = 'uid')
test_df = pd.read_csv(gcs_test_path, index_col = 'uid')

# (2) alter columns as necessary 
train_df["distortions"]=((train_df["distorted Cs"]+train_df["distorted V"])>0).astype(int)
test_df["distortions"]=((test_df["distorted Cs"]+test_df["distorted V"])>0).astype(int)

# (3) define target labels
target_labels=['breathy',
             'loudness decay',
             'slow rate',
             'high pitch',
             'hoarse / harsh',
             'irregular artic breakdowns',
             'rapid rate',
             'reduced OA loudness',
             'abn pitch variability',
             'strained',
             'hypernasal',
             'abn loudness variability',
              'distortions']

# (4) select only the target labels from train and test df
train_df=train_df[target_labels]
test_df=train_df[target_labels]

# (5) prep the data
prep_ssast_data(train_df,target_labels,'train_ssast',create_label_csv=True)
prep_ssast_data(test_df,target_labels,'test_ssast')

In [31]:
train_df.head()

Unnamed: 0_level_0,breathy,loudness decay,slow rate,high pitch,hoarse / harsh,irregular artic breakdowns,rapid rate,reduced OA loudness,abn pitch variability,strained,hypernasal,abn loudness variability,distortions
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
4d28f730-5814-48e1-bc29-3c0bf562e2fb,0,0,1,0,1,0,0,0,0,1,0,0,0
1e2dedd0-4f93-42ee-b0fb-c77fb7ba4cf4,0,0,1,0,0,0,0,0,1,0,0,1,0
f31c13e4-9f49-411e-b59f-f692244fb740,1,1,0,0,1,0,0,0,0,1,1,0,1
d917de91-c421-40bf-9d75-a0b5b0736c5b,0,0,0,0,0,0,1,1,0,0,0,0,1
9c4a9e77-3080-4591-8797-d712e42d6ed6,1,0,0,0,1,0,0,0,0,1,0,0,1


## Run SSAST
additional imports to support running SSAST

In [45]:
import argparse
import os
import ast
import pickle
import sys
import time
import torch
from torch.utils.data import WeightedRandomSampler
basepath = os.path.dirname(os.path.dirname(sys.path[0]))

In [14]:
basepath

'/home/jupyter/ssast_mayo'

In [46]:
sys.path.append(basepath)
import dataloader_gcs as dataloader
from models.ast_models import ASTModel_pretrain, ASTModel_finetune
import numpy as np
from traintest import train, validate
from traintest_mask import trainmask

Set arguments for running SSAST
#TODO: label csv

In [47]:
#set arguments for running pre-training/fine-tuning
parser = argparse.ArgumentParser()
parser.add_argument("--data-train", type=str, default='train_ssast.json', help="training data json")
parser.add_argument("--data-val", type=str, default='test_ssast.json', help="validation data json")
parser.add_argument("--data-eval", type=str, default=None, help="evaluation data json")
parser.add_argument("--label-csv", type=str, default='label_df.csv', help="csv with class labels")
parser.add_argument("--n_class", type=int, default=len(target_labels), help="number of classes")

parser.add_argument("--dataset", type=str, default='demo', help="the dataset used for training")
parser.add_argument("--dataset_mean", type=float, default= -4.2677393, help="the dataset mean, used for input normalization")
parser.add_argument("--dataset_std", type=float, default=4.5689974, help="the dataset std, used for input normalization")
parser.add_argument("--target_length", type=int, default=1024, help="the input length in frames")
parser.add_argument("--num_mel_bins", type=int, default=128, help="number of input mel bins")

parser.add_argument("--exp-dir", type=str, default="experiments", help="directory to dump experiments")
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--warmup', help='if use warmup learning rate scheduler', type=ast.literal_eval, default='True')
parser.add_argument("--optim", type=str, default="adam", help="training optimizer", choices=["sgd", "adam"])
parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N', help='mini-batch size')
parser.add_argument('-w', '--num-workers', default=8, type=int, metavar='NW', help='# of workers for dataloading (default: 32)')
parser.add_argument("--n-epochs", type=int, default=80, help="number of maximum training epochs")
# only used in pretraining stage or from-scratch fine-tuning experiments
parser.add_argument("--lr_patience", type=int, default=2, help="how many epoch to wait to reduce lr if mAP doesn't improve")
parser.add_argument('--adaptschedule', help='if use adaptive scheduler ', type=ast.literal_eval, default='False')

parser.add_argument("--n-print-steps", type=int, default=100, help="number of steps to print statistics")
parser.add_argument('--save_model', help='save the models or not', type=ast.literal_eval, default='True')

parser.add_argument('--freqm', help='frequency mask max length', type=int, default=0)
parser.add_argument('--timem', help='time mask max length', type=int, default=0)
parser.add_argument("--mixup", type=float, default=0, help="how many (0-1) samples need to be mixup during training")
parser.add_argument("--bal", type=str, default=None, help="use balanced sampling or not")
# the stride used in patch spliting, e.g., for patch size 16*16, a stride of 16 means no overlapping, a stride of 10 means overlap of 6.
# during self-supervised pretraining stage, no patch split overlapping is used (to aviod shortcuts), i.e., fstride=fshape and tstride=tshape
# during fine-tuning, using patch split overlapping (i.e., smaller {f,t}stride than {f,t}shape) improves the performance.
# it is OK to use different {f,t} stride in pretraining and finetuning stages (though fstride is better to keep the same)
# but {f,t}stride in pretraining and finetuning stages must be consistent.
parser.add_argument("--fstride", type=int,default=128,help="soft split freq stride, overlap=patch_size-stride")
parser.add_argument("--tstride", type=int,default=2, help="soft split time stride, overlap=patch_size-stride")
parser.add_argument("--fshape", type=int, default=128, help="shape of patch on the frequency dimension")
parser.add_argument("--tshape", type=int, default=2, help="shape of patch on the time dimension")
parser.add_argument('--model_size', help='the size of AST models', type=str, default='base')

parser.add_argument("--task", type=str, default='ft_cls', help="pretraining or fine-tuning task", choices=["ft_avgtok", "ft_cls", "pretrain_mpc", "pretrain_mpg", "pretrain_joint"])

# pretraining augments
#parser.add_argument('--pretrain_stage', help='True for self-supervised pretraining stage, False for fine-tuning stage', type=ast.literal_eval, default='False')
parser.add_argument('--mask_patch', help='how many patches to mask (used only for ssl pretraining)', type=int, default=400)
parser.add_argument("--cluster_factor", type=int, default=3, help="mask clutering factor")
parser.add_argument("--epoch_iter", type=int, default=2000, help="for pretraining, how many iterations to verify and save models")

# fine-tuning arguments
parser.add_argument("--pretrained_mdl_path", type=str, default='SSAST-Base-Frame-400.pth', help="the ssl pretrained models path")
parser.add_argument("--head_lr", type=int, default=1, help="the factor of mlp-head_lr/lr, used in some fine-tuning experiments only")
parser.add_argument("--noise", help='if augment noise in finetuning', type=ast.literal_eval, default='False')
parser.add_argument("--metrics", type=str, default="mAP", help="the main evaluation metrics in finetuning", choices=["mAP", "acc"])
parser.add_argument("--lrscheduler_start", default=10, type=int, help="when to start decay in finetuning")
parser.add_argument("--lrscheduler_step", default=5, type=int, help="the number of step to decrease the learning rate in finetuning")
parser.add_argument("--lrscheduler_decay", default=0.5, type=float, help="the learning rate decay ratio in finetuning")
parser.add_argument("--wa", help='if do weight averaging in finetuning', type=ast.literal_eval, default='False')
parser.add_argument("--wa_start", type=int, default=16, help="which epoch to start weight averaging in finetuning")
parser.add_argument("--wa_end", type=int, default=30, help="which epoch to end weight averaging in finetuning")
parser.add_argument("--loss", type=str, default="BCE", help="the loss function for finetuning, depend on the task", choices=["BCE", "CE"])

parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")

args = parser.parse_args()

Further data prep

In [48]:
# # dataset spectrogram mean and std, used to normalize the input
# norm_stats = {'librispeech':[-4.2677393, 4.5689974], 'howto100m':[-4.2677393, 4.5689974], 'audioset':[-4.2677393, 4.5689974], 'esc50':[-6.6268077, 5.358466], 'speechcommands':[-6.845978, 5.5654526]}
# target_length = {'librispeech': 1024, 'howto100m':1024, 'audioset':1024, 'esc50':512, 'speechcommands':128}
# # if add noise for data augmentation, only use for speech commands
# noise = {'librispeech': False, 'howto100m': False, 'audioset': False, 'esc50': False, 'speechcommands':True}

audio_conf = {'num_mel_bins': args.num_mel_bins, 'target_length': args.target_length, 'freqm': args.freqm, 'timem': args.timem, 'mixup': args.mixup, 'dataset': args.dataset,
              'mode':'train', 'mean':args.dataset_mean, 'std':args.dataset_std, 'noise':args.noise}

val_audio_conf = {'num_mel_bins': args.num_mel_bins, 'target_length': args.target_length, 'freqm': 0, 'timem': 0, 'mixup': 0, 'dataset': args.dataset, 'mode': 'evaluation', 'mean': args.dataset_mean, 'std': args.dataset_std, 'noise': False}
  

In [49]:
train_dataset = dataloader.AudioDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, bucket=bucket, gcs_prefix=gcs_prefix)

---------------the train dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process demo
use dataset mean -4.268 and std 4.569 to normalize the input.
number of classes is 289


RuntimeError: Failed to load audio from <_io.BytesIO object at 0x7fa42ad2a950>

In [37]:
# if use balanced sampling, note - self-supervised pretraining should not use balance sampling as it implicitly leverages the label information.
if args.bal == 'bal':
    print('balanced sampler is being used')
    samples_weight = np.loadtxt(args.data_train[:-5]+'_weight.csv', delimiter=',')
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)


    train_loader = torch.utils.data.DataLoader(
        dataloader.AudioDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, bucket=bucket, gcs_prefix=gcs_prefix),
        batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers, pin_memory=False, drop_last=True)
else:
    print('balanced sampler is not used')
    train_loader = torch.utils.data.DataLoader(
        dataloader.AudioDataset(args.data_train, label_csv=args.label_csv, audio_conf=audio_conf, bucket=bucket, gcs_prefix=gcs_prefix),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=False, drop_last=True)

val_loader = torch.utils.data.DataLoader(
    dataloader.AudioDataset(args.data_val, label_csv=args.label_csv, audio_conf=val_audio_conf,bucket=bucket, gcs_prefix=gcs_prefix),
    batch_size=args.batch_size * 2, shuffle=False, num_workers=args.num_workers, pin_memory=False)

balanced sampler is not used
---------------the train dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process demo
use dataset mean -4.268 and std 4.569 to normalize the input.
number of classes is 289
---------------the evaluation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process demo
use dataset mean -4.268 and std 4.569 to normalize the input.
number of classes is 289


In [38]:
batch = next(iter(train_loader))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader_gcs.py", line 435, in __getitem__
    fbank, mix_lambda = self._wav2fbank(datum['wav'])
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader_gcs.py", line 315, in _wav2fbank
    waveform, metadata = load_waveform_from_gcs(self.bucket,self.gcs_prefix,filename)
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader_gcs.py", line 205, in load_waveform_from_gcs
    waveform, _ = torchaudio.load(wave_bytes, format = extension)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 214, in load
    return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 33, in _fail_load_fileobj
    raise RuntimeError(f"Failed to load audio from {fileobj}")
RuntimeError: Failed to load audio from <_io.BytesIO object at 0x7fa42994e8f0>


Initalize model

In [24]:
print('Now train with {:s} with {:d} training samples, evaluate with {:d} samples'.format(args.dataset, len(train_loader.dataset), len(val_loader.dataset)))

# in the pretraining stage
if 'pretrain' in args.task:
    cluster = (args.num_mel_bins != args.fshape)
    if cluster == True:
        print('The num_mel_bins {:d} and fshape {:d} are different, not masking a typical time frame, using cluster masking.'.format(args.num_mel_bins, args.fshape))
    else:
        print('The num_mel_bins {:d} and fshape {:d} are same, masking a typical time frame, not using cluster masking.'.format(args.num_mel_bins, args.fshape))
    # no label dimension needed as it is self-supervised, fshape=fstride and tshape=tstride
    audio_model = ASTModel_pretrain(fshape=args.fshape, tshape=args.tshape, fstride=args.fshape, tstride=args.tshape,
                       input_fdim=args.num_mel_bins, input_tdim=args.target_length, model_size=args.model_size, load_pretrained_mdl_path=args.pretrained_mdl_pth)
# in the fine-tuning stage
else:
    audio_model = ASTModel_finetune(task = args.task, label_dim=args.n_class, fshape=args.fshape, tshape=args.tshape, fstride=args.fstride, tstride=args.tstride,
                       input_fdim=args.num_mel_bins, input_tdim=args.target_length, model_size=args.model_size,
                       load_pretrained_mdl_path=args.pretrained_mdl_path)

if not isinstance(audio_model, torch.nn.DataParallel):
    audio_model = torch.nn.DataParallel(audio_model)


Now train with demo with 594 training samples, evaluate with 594 samples
now load a SSL pretrained models from SSAST-Base-Frame-400.pth
pretraining patch split stride: frequency=128, time=2
pretraining patch shape: frequency=128, time=2
pretraining patch array dimension: frequency=1, time=512
pretraining number of patches=512
fine-tuning patch split stride: frequncey=128, time=2
fine-tuning number of patches=512


Run trainings

In [28]:
print("\nCreating experiment directory: %s" % args.exp_dir)
if os.path.exists("%s/models" % args.exp_dir) == False:
    os.makedirs("%s/models" % args.exp_dir)
with open("%s/args.pkl" % args.exp_dir, "wb") as f:
    pickle.dump(args, f)


Creating experiment directory: experiments


In [32]:
batch = next(iter(train_loader))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader.py", line 177, in __getitem__
    fbank, mix_lambda = self._wav2fbank(datum['wav'])
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader.py", line 98, in _wav2fbank
    waveform, sr = torchaudio.load(filename)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 227, in load
    return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 29, in _fail_load
    raise RuntimeError("Failed to load audio from {}".format(filepath))
RuntimeError: Failed to load audio from ce8c1402-8f50-4580-92ed-48e67b0fa756


In [29]:
if 'pretrain' not in args.task:
    print('Now starting fine-tuning for {:d} epochs'.format(args.n_epochs))
    train(audio_model, train_loader, val_loader, args)
else:
    print('Now starting self-supervised pretraining for {:d} epochs'.format(args.n_epochs))
    trainmask(audio_model, train_loader, val_loader, args)

Now starting fine-tuning for 80 epochs
running on cpu
Total parameter number is : 87.199 million
Total trainable parameter number is : 87.199 million
The mlp header uses 1 x larger lr
Total mlp parameter number is : 0.012 million
Total base parameter number is : 87.188 million
now training with demo, main metrics: mAP, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7fa42aeb2750>
The learning rate scheduler starts at 10 epoch with decay rate of 0.500 every 5 epoches
current #steps=0, #epochs=1
start training...
---------------
2023-04-20 15:34:40.012710
current #epochs=1, #steps=0


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader.py", line 177, in __getitem__
    fbank, mix_lambda = self._wav2fbank(datum['wav'])
  File "/home/jupyter/ssast_mayo/ssast/src/dataloader.py", line 98, in _wav2fbank
    waveform, sr = torchaudio.load(filename)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 227, in load
    return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
  File "/opt/conda/lib/python3.7/site-packages/torchaudio/backend/sox_io_backend.py", line 29, in _fail_load
    raise RuntimeError("Failed to load audio from {}".format(filepath))
RuntimeError: Failed to load audio from 6822ad12-e1ae-4e3c-ba66-1ffe43401244


If fine-tuning, evaluate

In [None]:
# if the dataset has a seperate evaluation set (e.g., speechcommands), then select the model using the validation set and eval on the evaluation set.
# this is only for fine-tuning
if args.data_eval != None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sd = torch.load(args.exp_dir + '/models/best_audio_model.pth', map_location=device)
    if not isinstance(audio_model, torch.nn.DataParallel):
        audio_model = torch.nn.DataParallel(audio_model)
    audio_model.load_state_dict(sd, strict=False)

    # best models on the validation set
    args.loss_fn = torch.nn.BCEWithLogitsLoss()
    stats, _ = validate(audio_model, val_loader, args, 'valid_set')
    # note it is NOT mean of class-wise accuracy
    val_acc = stats[0]['acc']
    val_mAUC = np.mean([stat['auc'] for stat in stats])
    print('---------------evaluate on the validation set---------------')
    print("Accuracy: {:.6f}".format(val_acc))
    print("AUC: {:.6f}".format(val_mAUC))

    # test the models on the evaluation set
    eval_loader = torch.utils.data.DataLoader(
        dataloader.AudioDataset(args.data_eval, label_csv=args.label_csv, audio_conf=val_audio_conf),
        batch_size=args.batch_size*2, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    stats, _ = validate(audio_model, eval_loader, args, 'eval_set')
    eval_acc = stats[0]['acc']
    eval_mAUC = np.mean([stat['auc'] for stat in stats])
    print('---------------evaluate on the test set---------------')
    print("Accuracy: {:.6f}".format(eval_acc))
    print("AUC: {:.6f}".format(eval_mAUC))
    np.savetxt(args.exp_dir + '/eval_result.csv', [val_acc, val_mAUC, eval_acc, eval_mAUC])


## Embeddings
Once you have a model pre-trained to your liking (fine-tuned on your data) you can extract embeddings

First prepare the dataframe you would like to get embeddings for

In [None]:
annotations_df = pd.read_csv(
    'gs://ml-e107-phi-shared-aif-us-p/speech_ai/share/data_splits/r01_prelim_161_amrs/test.csv', 
    index_col = 'uid'
)

Then prepare the bucket where the data is stored

In [None]:
study = 'r01_prelim'
bucket_name = 'ml-e107-phi-shared-aif-us-p'
gcs_prefix = f'speech_ai/speech_lake/{study}'

storage_client = storage.Client(project=project_name)
bucket = storage_client.bucket(bucket_name = bucket_name)

Get embeddings

In [None]:
#TODO get model_name
embeddings = get_ssast_embeddings(model_name, args, bucket, gcs_prefix)