# Write chunked wav2vec encodings to disk

In [1]:
import os
import glob
import torch
import torch.nn as nn
F = nn.functional
import pandas as pd
import torchaudio
from fairseq.models.wav2vec import Wav2VecModel
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import DataLoader, Dataset
import gc

In [2]:
# Expose only physical GPU 2 to this process (it is then addressed as cuda:0).
# The variable must be set before the first CUDA call in this process to have
# any effect.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

In [3]:
# First visible GPU when CUDA is usable, otherwise fall back to the CPU.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Build the wav2vec model from a local checkpoint and move it to `dev`.
# map_location='cpu' makes loading independent of which GPU the checkpoint was
# saved from; the weights are moved to `dev` once the model is built.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
# NOTE(review): on newer torch releases torch.load defaults to
# weights_only=True, which may reject the pickled `args` in this checkpoint;
# pass weights_only=False there (trusted file only).
cp = torch.load('/home/karmanya/wav2vec_large.pt', map_location='cpu')
model = Wav2VecModel.build_model(cp['args'], task=None)
model.load_state_dict(cp['model'])
# This notebook only runs inference, so switch to eval mode once here.
model = model.to(dev).eval()

Wav2VecModel(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(1, 512, eps=1e-05, affine=True)
        (3): ReLU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(8,), stride=(4,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(1, 512, eps=1e-05, affine=True)
        (3): ReLU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(1, 512, eps=1e-05, affine=True)
        (3): ReLU()
      )
      (3): Sequential(
        (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(1, 512, eps=1e-05, affine=True)
        (3): ReLU()
      )
 

In [5]:
# CSVs for this run (labels and held-out test split).
labels_path = '/media/nas_mount/Sarthak/ques_wise_models/p3q1/labels.csv'
test_path = '/media/nas_mount/Sarthak/ques_wise_models/p3q1/test.csv'

# Directory holding the input .wav files (the directory name suggests 16 kHz
# audio) and every .wav currently in it.
data_files_path = '/media/nas_mount/Sarthak/ijcai_acl/prompt_wise_16k_audios/prompt3/'
data_files_list = glob.glob(os.path.join(data_files_path, '*.wav'))

In [6]:
class AudioDataset(Dataset):
    def __init__(self, df, audio_dir, file_col):
        '''
        Passed Dataset a dataframe of the filenames to train/validate on
        Pass one hot encoded numpy array for labels
        Pass a string for the directory of audio files
        '''
        self.items = df
        self.items['path'] = self.items[file_col].apply(lambda x: x.split('.jpeg')[0]) # Label files are *.jpeg
        self.items['path'] = self.items['path'].apply(lambda x: os.path.join(audio_dir, f'{x}.wav'))
        self.audio_dir = audio_dir
    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        '''
        Load the audio waveform of idx from dataframe
        returns a touple of torch tensor and label
        '''
        file = self.items['path'].iloc[idx]
        audio, _ = torchaudio.load(file)
        samples = audio.shape[1]
        # Trim and pad to 60 seconds
        if audio.shape[0] > 1:
            audio = torch.mean(audio, axis=0).unsqueeze(dim=0)
        if samples < 60*16000:
            p1d = (60*16000 - samples, 0)
            audio = F.pad(audio, p1d, "constant", 0)
            print('pad')
        elif samples > 60*16000 :
            print('trim')
            audio = torch.narrow(audio, 1, 0, 60*16000)
        return audio, torch.tensor([idx])

In [8]:
# One item per row of the labels CSV; file names come from its 'name' column.
train_dataset = AudioDataset(df=pd.read_csv(labels_path),
                             audio_dir=data_files_path,
                             file_col='name')

In [9]:
# shuffle=False keeps dataset order, which the per-batch output filenames rely on.
train_dl = DataLoader(train_dataset, shuffle=False, batch_size=8, num_workers=10)

In [12]:
# Sanity check: shape of the waveform for item 0.
train_dataset[0][0].size()

torch.Size([1, 960000])

In [22]:
# Sanity check: shape of the waveform for item 610.
train_dataset[610][0].size()

torch.Size([2, 960000])

In [10]:
from pathlib import Path
# Output directory for the per-batch .pt feature files.
path = Path(
    '/media/nas_mount/Karmanya/wav2vec_chunked'
)

In [11]:
# Encode every clip as 30 s chunks and save one feature file per batch.
# Each item is (1, N) with N a multiple of CHUNK_SAMPLES, so view(-1, ...)
# splits a batch of B clips into consecutive rows: clip yb[i] occupies
# N // CHUNK_SAMPLES adjacent rows. Files are named "<first idx>_<last idx>.pt",
# which relies on the DataLoader not shuffling.
CHUNK_SAMPLES = 30 * 16000
path.mkdir(parents=True, exist_ok=True)
model.eval()
with torch.no_grad():
    for xb, yb in train_dl:
        # Guard the reshape: a multi-channel or oddly sized item would silently
        # shift rows away from the indices recorded in the filename.
        if xb.dim() != 3 or xb.shape[1] != 1 or xb.shape[-1] % CHUNK_SAMPLES:
            raise ValueError(f'unexpected batch shape {tuple(xb.shape)}')
        chunks = xb.view(-1, CHUNK_SAMPLES).to(dev)
        feats = model.feature_aggregator(model.feature_extractor(chunks))
        # Save CPU tensors so the files can be loaded without a GPU.
        torch.save(feats.cpu(), path / f'{yb[0].item()}_{yb[-1].item()}.pt')
        print(feats.shape)
        del chunks, feats
        gc.collect()

torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])
torch.Size([16, 512, 2998])


KeyboardInterrupt: 