In [30]:
# Description: This notebook runs inference with a trained AST whose input has been modified to take audio instead of patches of images
# Original code is based on a tutorial by Brian Pulfer
# https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
# Andrei Cartera -- Mar 2025

import sys
import heapq
import math
import numpy as np
from pathlib import Path
import torch
import torch.nn.functional as nnF
import torchaudio
from AudioTransformer import AudioTransformer
from Inference_process import vectorize_f
from Inference_process import process_repcycles as get_repcycles; 

# Seed NumPy's legacy global RNG and torch's default generators with 0 so runs repeat.
np.random.seed(0)
torch.manual_seed(0)

# Report the installed torch build.
print(torch.__version__)

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
print(device)

2.6.0+cu126
cuda:0


In [31]:
def crop(waveform, target_length=16000):
  """Return `waveform` with exactly `target_length` samples along dim 1.

  Inputs that are at least `target_length` long are truncated; shorter ones
  are right-padded with zeros. Assumes a 2-D (channels, samples) tensor —
  TODO confirm for other callers.
  """
  n_samples = waveform.shape[1]
  if n_samples >= target_length:
    return waveform[:, :target_length]
  # nnF.pad's (left, right) pair applies to the last dimension (dim 1 for 2-D input).
  return nnF.pad(waveform, (0, target_length - n_samples))

def get_repcycles_wav(waveform, repcycles, vec_size=48):
  """Build one fixed-length row vector per entry of `repcycles`.

  Args:
    waveform: tensor indexed as waveform[0, start:end]; only row 0 is read
      (the first channel, assuming a (channels, samples) layout — TODO confirm).
    repcycles: sequence with one entry per output row. A falsy entry (e.g. 0 or
      ()) means "no cycle" and leaves that row all zeros; otherwise entry[0] and
      entry[1] are (possibly fractional) start/end sample positions, floored to
      integers before slicing. Example: [0, 0, (2025.2, 2125.3), (), 0, ...].
    vec_size: target length passed to `vectorize_f`; also the number of columns.

  Returns:
    Tensor of shape (len(repcycles), vec_size), created by torch.zeros with the
    default dtype/device, with each non-empty cycle's vector written into its row.
  """
  out = torch.zeros(len(repcycles), vec_size)

  for row, cycle in enumerate(repcycles):
    if not cycle:
      continue
    start, end = math.floor(cycle[0]), math.floor(cycle[1])
    repc_wav = vectorize_f(waveform[0, start:end], vec_size)
    # as_tensor avoids the copy-construct warning torch.tensor() raises if
    # vectorize_f returns a tensor, and casts directly to the row's dtype.
    out[row] = torch.as_tensor(repc_wav, dtype=out.dtype)

  return out

In [36]:
# Hyperparameters — these must match the configuration the checkpoint below was trained with.
classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
NUM_CLASSES = len(classes)  # derived from `classes` so the label list and the model head cannot drift apart
N_SEGMENTS = 32
REPC_VEC_SIZE = 48

EPOCHS = 100
N_HEADS = 8
N_ENCODERS = 4
BATCH_SIZE = 64
HIDDEN_DIM = 32
# DROPOUT, ACTIVATION, LR and CONTROL are not used in this cell; EPOCHS and
# BATCH_SIZE only feed the checkpoint file name.
# NOTE(review): ACTIVATION is not passed to AudioTransformer, so the model uses
# whatever its constructor defaults to — confirm that matches training.
DROPOUT = 0.15
ACTIVATION = "gelu"
LR = 0.0009
CONTROL = False

# The checkpoint name encodes the hyperparameters above, so build it from them.
# With the values set here this is 'models/N_ATmodel_32SEG_48VEC_E100_8_4_B64_H32_3.pth'.
# The trailing "_3" is not derived from any constant (presumably a run index — confirm).
MODEL_PATH = (
  f'models/N_ATmodel_{N_SEGMENTS}SEG_{REPC_VEC_SIZE}VEC_E{EPOCHS}'
  f'_{N_HEADS}_{N_ENCODERS}_B{BATCH_SIZE}_H{HIDDEN_DIM}_3.pth'
)

model = AudioTransformer(N_SEGMENTS, REPC_VEC_SIZE, N_ENCODERS, HIDDEN_DIM, N_HEADS, NUM_CLASSES).to(device)
# `device` is already a torch.device, so it can be passed to map_location as-is.
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True, map_location=device))
print(f"Model loaded from {MODEL_PATH}")

model.eval()  # Set the model to evaluation mode

audio_path = Path('audio/inference/2_2.wav')

# Model input: one row per entry of `repcycles`, each a REPC_VEC_SIZE vector;
# rows for entries without a cycle stay all zeros, e.g.
#   0   0   0   0   ...
#   0   0   0   0   ...
#   0.1 0.2 0.5 0.2 ...
#   0   0   0   0   ...
# NOTE(review): the model was built for N_SEGMENTS rows; presumably
# get_repcycles returns exactly N_SEGMENTS entries — confirm.
# NOTE(review): the sample rate returned by torchaudio.load is discarded and the
# audio is not resampled — confirm the file matches the training sample rate.
waveform_t, _ = torchaudio.load(audio_path)  # (channels, samples) tensor; sample rate ignored
repcycles = get_repcycles(waveform_t)
model_input = get_repcycles_wav(waveform_t, repcycles, vec_size=REPC_VEC_SIZE).to(device)

# Add a batch dimension: (1, len(repcycles), REPC_VEC_SIZE).
model_input = model_input.unsqueeze(0)

with torch.no_grad():
  output = model(model_input)

np_output = output.cpu().numpy()[0]  # scores for the single item in the batch

# Print one score per class label.
assert len(np_output) == len(classes), "model output size does not match the class list"
for label, value in zip(classes, np_output):
  print("%s:\t%.2f" % (label, value))



Model loaded from models/N_ATmodel_32SEG_48VEC_E100_8_4_B64_H32_3.pth
zero:	0.00
one:	0.00
two:	1.00
three:	0.00
four:	0.00
five:	0.00
six:	0.00
seven:	0.00
eight:	0.00
nine:	0.00
