In [1]:
import os
import sys
# Make the project package in ./accent-embeddings importable by the cells below.
sys.path.append("accent-embeddings")

# Dataset root, presumably read by the project's data-loading code.
# setdefault keeps a DATASET_PATH the user already exported instead of clobbering it
# with this Colab-specific default location.
os.environ.setdefault("DATASET_PATH", "/content/VCTK-Corpus-0.92")

env: DATASET_PATH=/content/VCTK-Corpus-0.92


In [2]:
# Reload edited project modules automatically before each cell executes, so changes to
# the local packages imported below take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torchaudio.transforms import GriffinLim
import IPython.display as ipd
import numpy as np
import librosa
import matplotlib.pyplot as plt
import pytorch_lightning as pl

from text import text_to_sequence, sequence_to_text

from hyper_params import *
from train import load_data

from models.tacotron2 import Tacotron2
from models.wav2vec_id import Wav2VecID
from models.wav2vec_asr import Wav2VecASR
from multitask import AccentedMultiTaskNetwork, Task
from metrics import *

from transformers import Wav2Vec2Processor 

import wandb
from tqdm import tqdm

In [4]:
# Build train/validation loaders: 10% held out for validation, one sample per batch.
training_params = TrainingParams(val_size=0.1, batch_size=1)
data_params = DataParams(
  filter_length=800,
  sample_rate=16000,
  win_length=800,
  hop_length=200,
)
train_loader, val_loader = load_data(training_params, data_params)

INFO: Loading Audio Lengths
Number of samples:  37372


# Load the Model

In [5]:
# W&B artifact holding the trained multitask checkpoint.
MODEL_ARTIFACT = 'g-luo/accent_embeddings/model-3lllyvm6:v3'

run = wandb.init()
artifact = run.use_artifact(MODEL_ARTIFACT, type='model')
# Local directory the artifact files are downloaded into (model.ckpt is loaded from here).
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33maparande[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: Downloading large artifact model-3lllyvm6:v3, 3901.45MB. 1 files... Done. 0:0:0


In [6]:
# TTS head: Tacotron2, evaluated with the MSE metric.
tacotron_params = TacotronParams()
tacotron = Tacotron2(tacotron_params)
tts_task = Task(
  model=tacotron,
  loss=None,
  learning_rate=1e-3,
  weight_decay=1e-6,
  name='TTS',
  loss_weight=1,
  metrics=[MSE()],
)

In [7]:
# ASR head: Wav2VecASR, evaluated with the WERAccuracy metric.
asr_params = Wav2VecASRParams()
asr = Wav2VecASR(asr_params)
asr_task = Task(
  model=asr,
  loss=None,
  learning_rate=1e-5,
  weight_decay=0,
  name='ASR',
  loss_weight=0.5,
  metrics=[WERAccuracy()],
)

In [8]:
# Accent-ID head: Wav2VecID, evaluated with the SoftmaxAccuracy metric.
accent_id_params = Wav2VecIDParams()
accent_id = Wav2VecID(accent_id_params)
accent_id_task = Task(
  model=accent_id,
  loss=None,
  learning_rate=1e-5,
  name='ID',
  weight_decay=0,
  loss_weight=1,
  metrics=[SoftmaxAccuracy()],
)

In [9]:
# Multitask wrapper hyper-parameters; must match the ones the checkpoint was trained with.
mp = MultiTaskParams(hidden_dim=[13], in_dim=1024, alternate_epoch_interval=2)

checkpoint_path = f"{artifact_dir}/model.ckpt"
model = AccentedMultiTaskNetwork.load_from_checkpoint(
  checkpoint_path,
  params=mp,
  tasks=[accent_id_task, asr_task, tts_task],
)
# Inference mode on the GPU.
model = model.eval()
model = model.cuda()

Some weights of the model checkpoint at facebook/wav2vec2-large-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
@torch.no_grad()
def predict_batch(batch, predict=False, accent_embed=None):
  """Run the multitask model on one batch, optionally overriding its accent embedding.

  Runs with autograd disabled: this notebook only does inference, and without it the
  returned tensors would keep references to their autograd graphs (GPU memory grows with
  every embedding/output the caller stores).

  Args:
    batch: dataloader batch dict. Its "wav2vec_input" entry (and "text_tensor" when
      `predict` is True) is replaced in place with a CUDA copy.
    predict: when True, also run every task head in `model.tasks` on the batch.
    accent_embed: embedding passed to the task heads. When None, it is computed from the
      batch with `model.get_wav2vec_features` followed by `model.bottleneck`.

  Returns:
    (accent_embed, outs): the embedding used, and a dict mapping task name -> that task
    model's output (empty when `predict` is False).
  """
  batch["wav2vec_input"] = batch["wav2vec_input"].cuda()
  # NOTE(review): features are computed even when `accent_embed` is supplied (and then go
  # unused); kept unconditional in case get_wav2vec_features has side effects the task
  # heads rely on — confirm before skipping it.
  wav2vec_feats = model.get_wav2vec_features(batch)

  if accent_embed is None:
    accent_embed = model.bottleneck(wav2vec_feats)

  outs = {}

  if predict:
    batch["text_tensor"] = batch["text_tensor"].cuda()
    for task in model.tasks:
      task_input = task.model.parse_batch(batch)
      outs[task.name] = task.model(task_input, accent_embed)

  return accent_embed, outs

# Data Collection

In [11]:
# Integer accent label (as stored in batch["accents"]) -> accent name.
label_map = dict([
  (4, 'English'),
  (9, 'Scottish'),
  (8, 'NorthernIrish'),
  (6, 'Irish'),
  (5, 'Indian'),
  (12, 'Welsh'),
  (11, 'Unknown'),
  (0, 'American'),
  (3, 'Canadian'),
  (10, 'SouthAfrican'),
  (1, 'Australian'),
  (7, 'NewZealand'),
  (2, 'British'),
])

In [21]:
# Compute the bottleneck accent embedding for every validation batch (batch_size=1, so one
# sample per entry) and record that batch's accent name alongside it.
accents = []
embeddings = []
# Inference only: without no_grad every stored embedding would keep its autograd graph
# alive on the GPU for the whole validation set.
with torch.no_grad():
  for batch in tqdm(val_loader):
    embedding, _ = predict_batch(batch)
    accents.append(label_map[int(batch["accents"][0])])
    embeddings.append(embedding)

  0%|          | 11/3737 [00:00<03:09, 19.64it/s]


In [22]:
# Per-accent running SUM of embeddings (divided by the counts in the next cell) and the
# number of embeddings seen per accent.
avg_embed = {}
accent_counts = {}
for embed, accent in tqdm(zip(embeddings, accents), total=len(embeddings)):
  if accent in avg_embed:
    # Out-of-place add on purpose: the first tensor stored for an accent is the same
    # object as an entry of `embeddings`, so `+=` would mutate that entry in place,
    # corrupting the collected embeddings and making this cell non-idempotent.
    avg_embed[accent] = avg_embed[accent] + embed
  else:
    avg_embed[accent] = embed

  accent_counts[accent] = accent_counts.get(accent, 0) + 1

12it [00:00, 27822.91it/s]


In [23]:
# Mean embedding per accent: summed embeddings divided by how many were summed.
accent_embeds = {}
for accent, count in accent_counts.items():
  accent_embeds[accent] = avg_embed[accent] / count

In [24]:
# Gather up to MAX_BATCHES_PER_ACCENT validation batches per accent; these are the source
# utterances for the accent-swap evaluation below.
MAX_BATCHES_PER_ACCENT = 25
accented_batches = { accent: [] for accent in accent_embeds }

for batch in val_loader:
  accent = label_map[int(batch["accents"][0])]
  bucket = accented_batches.get(accent)
  # Skip accents that have no mean embedding (e.g. the embedding pass was interrupted)
  # instead of raising KeyError, and accents whose quota is already filled.
  if bucket is None or len(bucket) >= MAX_BATCHES_PER_ACCENT:
    continue
  bucket.append(batch)
  # Stop as soon as every accent has its quota rather than scanning the rest of the loader.
  if all(len(b) >= MAX_BATCHES_PER_ACCENT for b in accented_batches.values()):
    break

In [None]:
# For every collected batch and every target accent, swap in the target accent's mean
# embedding and score each task's output against the output with the batch's own embedding.
# Row layout: (orig_accent, target_accent, score for model.tasks[0], model.tasks[1], ...).
# ASR score = original metric minus swapped metric; other tasks compare swapped output
# against the unswapped output.
results = []

# Select the ASR task by name (as the scoring loop below does) rather than by list index.
asr_eval_task = next(task for task in model.tasks if task.name == "ASR")

# Inference only: skip building autograd graphs for thousands of forward passes.
with torch.no_grad():
  for orig_accent in accented_batches:
    for batch in accented_batches[orig_accent]:
      _, out = predict_batch(batch, predict=True)

      for task in out:
        if task == "TTS":
          # Rename so the unswapped TTS output can be passed as the metric's target below;
          # presumably MSE reads targets under "mfcc" — confirm against metrics.MSE.
          out[task]["mfcc"] = out[task].pop("mel_out_postnet")
        out[task] = [out[task]]

      asr_targets = [asr_eval_task.model.get_targets(batch)]
      orig_asr_wer = asr_eval_task.metrics[0](out["ASR"], asr_targets)

      for target_accent in accent_embeds:
        accent_scores = np.zeros(len(model.tasks))
        _, accented_out = predict_batch(batch, predict=True, accent_embed=accent_embeds[target_accent])

        for task in accented_out:
          accented_out[task] = [accented_out[task]]

        for i, task in enumerate(model.tasks):
          if task.name == "ASR":
            score = orig_asr_wer - task.metrics[0](accented_out[task.name], asr_targets)
          else:
            score = task.metrics[0](accented_out[task.name], out[task.name])
          # float() so scalar tensors (possibly on the GPU) are stored as plain numbers.
          accent_scores[i] = float(score)

        results.append((orig_accent, target_accent, *accent_scores))

In [None]:
import pickle

In [None]:
# Persist the evaluation rows. pickle writes bytes, so the file must be opened for binary
# writing; the default read-text mode made this cell fail every time.
with open("results.pkl", "wb") as f:
  pickle.dump(results, f)