In [23]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [24]:
import IPython.display as ipd
import numpy as np, pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.file_utils import ModelOutput
from transformers import AutoConfig, Wav2Vec2Processor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)

from pathlib import Path
from tqdm import tqdm
import torchaudio, os, sys, json, pickle, librosa

from dataclasses import dataclass
from typing import Optional, Tuple
from tqdm import tqdm

In [25]:


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    hidden_rep: Optional[Tuple[torch.FloatTensor]] = None



class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x0 = self.dropout(x)
#         print('----------------------------')
#         print(x0[:,-10:])
#         print(x0.shape)
#         print('----------------------------')
        x1 = self.out_proj(x0)
        return x0, x1

class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()
        
    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()
        for module in self.wav2vec2.encoder.layers[:10]:
            for param in module.parameters():
                param.requires_grad = False

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception("The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs

    def forward(
            self,
            input_values,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        hidden_rep, logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (hidden_rep + logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            hidden_rep=hidden_rep
        )

In [26]:
!export CUDA_VISIBLE_DEVICES=0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.current_device()

0

In [27]:
# model_name_or_path = "/mnt/data/aman/mayank/MTP/mount_points/jan_19/Error-Driven-ASR-Personalization/MCV_accent/data/dristi_accent-recognition/checkpoint-6400/"
model_name_or_path = "/home/mayank/MTP/begin_again/Error-Driven-ASR-Personalization/mz-isca/classifier-data/training_data/8acc_10freeze/checkpoint-4000/"
config = AutoConfig.from_pretrained(model_name_or_path)
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
sampling_rate = processor.feature_extractor.sampling_rate
model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path).to(device)

In [28]:
!gpustat

[1m[37mswara[m  Thu Mar 24 13:05:28 2022
[36m[0][m [34mGeForce GTX 1080 Ti[m |[1m[31m 79'C[m, [32m 12 %[m | [36m[1m[33m 5176[m / [33m11178[m MB | [1m[30mpiyush[m([33m781M[m) [1m[30mpiyush[m([33m779M[m) [1m[30mmayank[m([33m2837M[m) [1m[30mpiyush[m([33m773M[m)
[36m[1][m [34mGeForce GTX 1080 Ti[m |[1m[31m 60'C[m, [32m 11 %[m | [36m[1m[33m 7697[m / [33m11178[m MB | [1m[30mmayank[m([33m3139M[m) [1m[30mmayank[m([33m2345M[m) [1m[30mmayank[m([33m2209M[m)
[36m[2][m [34mGeForce GTX 1080 Ti[m |[1m[31m 78'C[m, [32m  0 %[m | [36m[1m[33m10363[m / [33m11178[m MB | [1m[30mmayank[m([33m10359M[m)
[36m[3][m [34mGeForce GTX 1080 Ti[m |[1m[31m 77'C[m, [1m[32m 35 %[m | [36m[1m[33m 5747[m / [33m11177[m MB | [1m[30mmayank[m([33m5743M[m)


In [29]:
def speech_file_to_array_fn(path, sampling_rate):
    speech_array, _sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(_sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    return speech

def predict(path, sampling_rate):
#     print(path)
    speech = speech_file_to_array_fn(path, sampling_rate)
    features = processor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    input_values = features.input_values.to(device)
    attention_mask = None

    with torch.no_grad():
        op = model(input_values, attention_mask=attention_mask)
        logits = op.logits
        hidden_rep = op.hidden_rep
        
    scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
    outputs = [{"Accent": config.id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
    return outputs, hidden_rep
#     return outputs, logits

def prediction(df_row):
    if 'path' in df_row: path = df_row["path"]
    else: path = df_row["audio_filepath"]
    speech, sr = torchaudio.load(path)
    speech = speech[0].numpy().squeeze()
    speech = librosa.resample(np.asarray(speech), sr, sampling_rate)
    outputs, hidden_rep = predict(path, sampling_rate)
#     print(hidden_rep[:,-10:])
    return hidden_rep

def extract_features(file_list, file_dir):
    with open(file_dir.replace('.json', '_wv8.file'), 'wb') as f:
        for file in tqdm(file_list):
            w2v2_features = prediction(file).cpu().detach().numpy()
            pickle.dump(w2v2_features, f)

In [30]:
def extract_features_to_dir(file_list, file_path, feature):
    file_dir, accent = '/'.join(file_path.split('/')[:-1]), file_path.split('/')[-1].split('.json')[0]
    print(file_dir, accent)
    with open("{}/{}/{}_{}.file".format(file_dir, feature, accent, feature), 'wb') as f:
        for file in tqdm(file_list):
            file['audio_filepath'] = file['audio_filepath'].replace('/wav/', '/clips/').replace('.wav', '.mp3')
            features = prediction(file).cpu().detach().numpy()
            pickle.dump(features, f)

In [31]:
base_dir = '../mz-isca/classifier-data/inval'


In [32]:
base_dir = '../mz-isca/classifier-data/inval/'
feature = 'wv10'
os.makedirs(base_dir+feature, exist_ok = True)
jsons = [f.name for f in os.scandir(base_dir) if not(f.is_dir())]
print(jsons)

['philippines.json', 'scotland.json', 'hongkong.json', 'indian.json', 'us.json', 'england.json', 'ireland.json', 'african.json']


In [33]:
for json_file in jsons[4:6]:

#     seed_file_dir = manifests_path + 'seed.json'
#     seed_file = open(seed_file_dir)
#     seed_list = [json.loads(line.strip()) for line in seed_file]

#     print('seed_file_starting')
#     print(seed_file_dir)
#     extract_features(seed_list, seed_file_dir)
#     print(len(seed_list))
#     print('seed_file_ending ...\n')
    
    
#     selection_file_dir = manifests_path + 'selection.json'
#     selection_file = open(selection_file_dir)
#     selection_list = [json.loads(line.strip()) for line in selection_file]
    
#     print('selection_file_starting')
#     extract_features(selection_list, selection_file_dir)
#     print(len(selection_list))
#     print('selection_file_ending ...\n\n')
    
    
    test_file_name = base_dir+json_file
    test_file = open(test_file_name)
    test_list = [json.loads(line.strip()) for line in test_file]

    print('test_file_starting')
    extract_features_to_dir(test_list[:350], test_file_name, feature)
    print(len(test_list))
    print('test_file_ending ...\n\n')
    
    
#     dev_file_dir = manifests_path + 'dev.json'
#     dev_file = open(dev_file_dir)
#     dev_list = [json.loads(line.strip()) for line in dev_file]

#     print('dev_file_starting')
#     print(dev_file_dir)
#     extract_features(dev_list, dev_file_dir)
#     print(len(dev_list))
#     print('dev_file_ending ...\n')

test_file_starting
../mz-isca/classifier-data/inval us


100%|██████████| 350/350 [02:54<00:00,  2.01it/s]


2500
test_file_ending ...


test_file_starting
../mz-isca/classifier-data/inval england


100%|██████████| 350/350 [02:44<00:00,  2.12it/s]

2500
test_file_ending ...







In [13]:
# base_dir = 'accent-without/'

# accents = ['hindi', 'chinese', 'spanish', 'arabic', 'korean', 'vietnamese']

# for accent in accents:
#     manifests_path = base_dir + accent + '/manifests/'
#     print('_'*20)
#     print(accent)

# #     seed_file_dir = manifests_path + 'seed.json'
# #     seed_file = open(seed_file_dir)
# #     seed_list = [json.loads(line.strip()) for line in seed_file]

# #     print('seed_file_starting')
# #     print(seed_file_dir)
# #     extract_features(seed_list, seed_file_dir)
# #     print(len(seed_list))
# #     print('seed_file_ending ...\n')
    
    
# #     selection_file_dir = manifests_path + 'selection.json'
# #     selection_file = open(selection_file_dir)
# #     selection_list = [json.loads(line.strip()) for line in selection_file]
    
# #     print('selection_file_starting')
# #     extract_features(selection_list, selection_file_dir)
# #     print(len(selection_list))
# #     print('selection_file_ending ...\n\n')
    
    
#     test_file_dir = manifests_path + 'test.json'
#     test_file = open(test_file_dir)
#     test_list = [json.loads(line.strip()) for line in test_file]

#     print('test_file_starting')
#     extract_features(test_list, test_file_dir)
#     print(len(test_list))
#     print('test_file_ending ...\n\n')
    
    
# #     dev_file_dir = manifests_path + 'dev.json'
# #     dev_file = open(dev_file_dir)
# #     dev_list = [json.loads(line.strip()) for line in dev_file]

# #     print('dev_file_starting')
# #     print(dev_file_dir)
# #     extract_features(dev_list, dev_file_dir)
# #     print(len(dev_list))
# #     print('dev_file_ending ...\n')

____________________
hindi
test_file_starting


100%|██████████| 1224/1224 [12:39<00:00,  1.61it/s]


1224
test_file_ending ...


____________________
chinese
test_file_starting


100%|██████████| 1224/1224 [13:32<00:00,  1.51it/s]


1224
test_file_ending ...


____________________
spanish
test_file_starting


100%|██████████| 1191/1191 [17:50<00:00,  1.11it/s]


1191
test_file_ending ...


____________________
arabic
test_file_starting


100%|██████████| 1182/1182 [13:59<00:00,  1.41it/s]


1182
test_file_ending ...


____________________
korean
test_file_starting


100%|██████████| 1224/1224 [13:52<00:00,  1.47it/s]


1224
test_file_ending ...


____________________
vietnamese
test_file_starting


100%|██████████| 1224/1224 [10:01<00:00,  2.03it/s]

1224
test_file_ending ...







In [14]:
# jsons_path = '.json'
# jsons = [f.name for f in os.scandir('../MCV_accent/invalidated/') if '.json' in f.name and f.name.split('.')[0] not in ['unlabelled',
# def path_proc(pth):
#     return pth.replace('./', '~/MTP/begin_again/Error-Driven-ASR-Personalization/mz-expts/')                                                                                                            'other']]
# print(jsons)

# for file in tqdm(jsons):
#     print('_'*20)
    
    
#     json_file_path = '../MCV_accent/invalidated/' + file 
#     json_file = open(json_file_path)
#     json_list = [json.loads(line.strip()) for line in json_file]
#     print(json_file_path)
    
#     extract_features(json_list, json_file_path)
#     print(len(json_list))
#     print('_'*20, '\n\n')