
# WavLM Feature Extractor

##### https://github.com/microsoft/unilm/tree/master/wavlm

In [1]:
%run WavLM

In [2]:
import pandas as pd
import numpy as np

In [3]:
import torch
from WavLM import WavLM, WavLMConfig

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.cuda.get_device_name(0))

# load the pre-trained checkpoints
#checkpoint = torch.load('model/WavLM-Base+.pt')
checkpoint = torch.load('model/WavLM-Large.pt')
cfg = WavLMConfig(checkpoint['cfg'])
model = WavLM(cfg)
#model = model.to(device) #, dtype=torch.float32)
model.load_state_dict(checkpoint['model'])
model.eval()

Using device: cuda:2
NVIDIA A100-SXM4-40GB




WavLM(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU(approximate='none')
      )
      (1-4): 4 x Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU(approximate='none')
      )
      (5-6): 2 x Sequential(
        (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLa

### Fine-Tuning Layers

In [4]:
# Path to source files
path2 = 'GTzan_16k_Wav'

In [5]:
import os

# Read a directory and put all files in a list
file_list = []
i = 0 
for path, subdirs, files in os.walk( path2 ):
    for name in files:
        file_list.append( os.path.join( path, name) )
        i += 1
print("Files processed: "+str(i) )

Files processed: 1000


In [6]:
import soundfile as sf

# sample rate = 16,000
#  1s = 16,000 x 1 =  16,000
# 60s = 16,000 x 60 = 960,500

length = list()
i      = 0
avg    = 0

for file in file_list:
    data, samplerate = sf.read( file )
    
    if len(data) <= 960500:                         
            print("Audio length: "+str(len(data))+" with less than 30s: "+str(file) )
    #computer average lenght of files
    avg = avg + len(data)
    length.append(len(data))
    i += 1

print( "Files processed: "+str(i) )
print( "Average file length: "+str(avg/i) + " samples   "+str(avg/i/samplerate)+" s   "+str(avg/i/samplerate/60)+" min" )
print( "Max length: "+str(max(length))+ " samples   "+str(max(length)/samplerate)+" s   "+str(max(length)/samplerate/60)+" min" )
print( "Min length: "+str(min(length))+ " samples   "+str(min(length)/samplerate)+" s   "+str(min(length)/samplerate/60)+" min" )

Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00060.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00020.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00076.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00036.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00081.au.wav
Audio length: 480189 with less than 30s: GTzan_16k_Wav/country/country.00001.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00041.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00017.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00057.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00097.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00022.au.wav
Audio length: 480214 with less than 30s: GTzan_16k_Wav/country/country.00062

In [8]:
sampling_rate = 16000
track_count = 0

for file in file_list:
    data , samplerate = sf.read( file )
    print ("--------------")
    print ("Sample Rate: " + str(samplerate) + " Length: " + str(data.shape) + " " + str( file ) )

    # extract the representation of last layer
    wav_input = torch.from_numpy(data).float()
    wav_input_16khz = torch.unsqueeze(wav_input,0)
    # wav_input_16khz = torch.randn(1,48000)
    # wav_input_16khz = torch.randn(1,2786987)
    if cfg.normalize:
        wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz , wav_input_16khz.shape)

    # last layer
    #rep = model.extract_features(wav_input_16khz.to(device, dtype=torch.float32))[0]
    rep = model.extract_features(wav_input_16khz)[0]

    # intermediate layers
    #rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0]
    # Layers 0 to 12 layer_reps[0] .... layer_reps[12] 
    #layer_reps = [x.transpose(0, 1) for x, _ in layer_results]

    # rep = layer_reps[11] # 12th layer
    # rep = layer_reps[7] # 8th layer
    # rep = layer_reps[3] # 4th layer

    rep[0].shape
    rep_np = rep[0].detach().numpy()
    rep_df = pd.DataFrame(rep_np)

    # file_id = 'features/'+file.split('/')[2].split('.')[0]+'.wavlmbasefeat'  

    file_id = 'features/'+file.split('/')[2].split('.')[0]+'.'+file.split('/')[2].split('.')[1]+'.wavlmlargefeat'
    
    rep_df.to_csv(file_id)
    
    # Pooling 
    #df_temp = pd.DataFrame(rep[0].detach().numpy())
    #df_pool = df_temp.rolling(2, step=2).mean().drop(index=0) 
    
    #file_id = 'Recola2018_16k/features/'+file.split('/')[3].split('.')[0]+'.wavlmbasefeatpool'  
    #df_pool.to_csv(file_id)
    
    print( file, " ", str(track_count) )
    
    track_count += 1

--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00060.au.wav
GTzan_16k_Wav/country/country.00060.au.wav   0
--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00020.au.wav
GTzan_16k_Wav/country/country.00020.au.wav   1
--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00076.au.wav
GTzan_16k_Wav/country/country.00076.au.wav   2
--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00036.au.wav
GTzan_16k_Wav/country/country.00036.au.wav   3
--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00081.au.wav
GTzan_16k_Wav/country/country.00081.au.wav   4
--------------
Sample Rate: 16000 Length: (480189,) GTzan_16k_Wav/country/country.00001.au.wav
GTzan_16k_Wav/country/country.00001.au.wav   5
--------------
Sample Rate: 16000 Length: (480214,) GTzan_16k_Wav/country/country.00041.au.wav
GTzan_16k_Wav/country/country.00041.au.wav   6
------

In [9]:
file_id

'features/reggae.00033.wavlmlargefeat'