 # Notebook used for evaluate the wav2vec with the huggingface model

In [1]:
from datasets import load_dataset, load_metric
from datasets import load_from_disk
import argparse
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
import re
import json
import numpy as np
import random
import torchaudio
import librosa
import os
import json
import torch
import gc
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers.neptune import NeptuneLogger

## Parameters

In [2]:
version = "wav2vec2_fairseq_target_librilabels_test1" #@param {type: "string"}
lr = 1e-4#@param {type: "number"}
w_decay = 0#@param {type: "number"}
bs = 5#@param {type: "integer"}
accum_grads = 4#@param {type: "integer"}
patience = 20#@param {type: "integer"}
max_epochs = 300#@param {type: "integer"}
# warmup_steps = 1000#@param {type: "integer"}
hold_epochs = 20#@param {type: "integer"}

# Define hyperparameters
hparams = {"version": version,
          "lr": lr,
          "w_decay": w_decay,
          "bs": bs,
          "patience": patience,
          "hold_epochs":hold_epochs,
          "accum_grads": accum_grads,
          "max_epochs": max_epochs}
hparams

{'version': 'wav2vec2_fairseq_target_librilabels_test1',
 'lr': 0.0001,
 'w_decay': 0,
 'bs': 5,
 'patience': 20,
 'hold_epochs': 20,
 'accum_grads': 4,
 'max_epochs': 300}

In [3]:
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.data import (
    AddTargetDataset,
    BinarizedAudioDataset,
    Dictionary,
    FileAudioDataset,
    encoders,
)
from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
from fairseq.tasks import FairseqTask
from fairseq.data.data_utils import post_process
import fairseq
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecEncoder, Wav2VecCtc, Wav2Vec2AsrConfig, Wav2Vec2CtcConfig

In [4]:
class AudioDataset(RawAudioDataset):

  def __init__(
        self,
        data,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )
        self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    
    wav = self.data[index]["speech"]
    feats = torch.from_numpy(np.asarray(wav)).float()
    feats = self.postprocess(feats, curr_sample_rate=self.data[index]["sampling_rate"])
    return {"id": index, "source": feats}

In [5]:
class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False
        )

class AudioFinetuneTask(AudioPretrainingTask):
  def __init__(
        self,
        cfg: AudioPretrainingConfig,
    ):
        super().__init__(cfg)

  def load_dataset(self, data, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        task_cfg = task_cfg or self.cfg

        self.datasets[split] = AudioDataset(
            data,
            split=split,
            sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
            max_sample_size=self.cfg.max_sample_size,
            min_sample_size=self.cfg.min_sample_size,
            pad=task_cfg.labels is not None or task_cfg.enable_padding,
            normalize=task_cfg.normalize,
        )
      
        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            dataset=self.datasets[split],
            labels=data["target_text"],
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            # add_to_input=task_cfg.get("autoregressive", False),
        )      

In [6]:
import torch.nn.functional as F
from fairseq import metrics, utils

class CtcCriterionFinetuning(CtcCriterion):
  def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask):
    super().__init__(cfg, task)
  
  def batch_decode(self, targets):
    target_sentences = list()
    with torch.no_grad():
      for t in targets:
        # print('t', t)
        p = (t != self.task.target_dictionary.pad()) & (t != self.task.target_dictionary.eos())
        targ = t[p]
        targ_units = self.task.target_dictionary.string(targ)
        # print(targ_units)

        targ_words = post_process(targ_units, self.post_process)  
        target_sentences.append(targ_words)
    return target_sentences   
  
  def batch_decode_pred(self, model, sample, reduce=True):
    net_output = model(**sample["net_input"])
    lprobs = model.get_normalized_probs(
        net_output, log_probs=True
    ).contiguous()  # (T, B, C) from the encoder

    if "src_lengths" in sample["net_input"]:
        input_lengths = sample["net_input"]["src_lengths"]
    else:
        if net_output["padding_mask"] is not None:
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)
        else:
            input_lengths = lprobs.new_full(
                (lprobs.size(1),), lprobs.size(0), dtype=torch.long
            )

    if not model.training:

        with torch.no_grad():
            lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

            pred_lp_words = []
            for lp, t, inp_l in zip(
                lprobs_t,
                sample["target_label"]
                if "target_label" in sample
                else sample["target"],
                input_lengths,
            ):
                lp = lp[:inp_l].unsqueeze(0)

                toks = lp.argmax(dim=-1).unique_consecutive()
                pred_units_arr = toks[toks != self.blank_idx].tolist()

                pred_units = self.task.target_dictionary.string(pred_units_arr)
                pred_words = post_process(pred_units, self.post_process)
                pred_lp_words.append(pred_words)

    return pred_lp_words 

In [7]:
# !mkdir models
# !wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
# !mv xlsr_53_56k.pt models

In [8]:
cp_path = './models/xlsr_53_56k.pt'
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])

model = model[0]
# model.eval()

# modelo tem 24 camadas transformers e embed_dim de 1024
print(model)
del model
gc.collect()
torch.cuda.empty_cache()

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (2): TransposeLast()
        )
        (3): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
        (1): Dropout(p=0.0, inplace=False)
        (2): Sequential(
          (0): TransposeLast()
          (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True

In [9]:
# print(cfg['task'].enable_padding) 

cfg['task'].enable_padding = True

print('enable_padding:', cfg['task'].enable_padding)

cfg['task'].eval_wer = True

print('eval_wer:', cfg['task'].eval_wer) 

cfg['task'].labels = 'keys'

print('labels:', cfg['task'].labels) 

finetune_task = AudioFinetuneTask(cfg=cfg['task'])

finetune_task.cfg.data = 'data'
# finetune_task.cfg.labels = 'keys'
print(finetune_task.cfg.labels)

enable_padding: True
eval_wer: True
labels: keys
keys


In [10]:
wav2vec2asr_config = Wav2Vec2AsrConfig
# wav2vec2asr_config.w2v_path = './models/wav2vec_small.pt'
wav2vec2asr_config.w2v_path = './models/xlsr_53_56k.pt'
wav2vec2asr_config.normalize = True

wav2vec2asr_config.feature_grad_mult = 0
wav2vec2asr_config.apply_mask=True
wav2vec2asr_config.mask_prob=0.05
wav2vec2asr_config.mask_channel_prob=0.05
wav2vec2asr_config.layerdrop=0.01
wav2vec2asr_config.activation_dropout=0.01
wav2vec2asr_config.dropout=0.01
wav2vec2asr_config.attention_dropout=0.01
wav2vec2asr_config.final_dropout=0.01
wav2vec2asr_config.mask_length=10
wav2vec2asr_config.mask_channel_length=10


print("normalize:", wav2vec2asr_config.normalize)

wav2vec2ctc_config = Wav2Vec2CtcConfig
wav2vec2ctc_config.w2v_path = './models/xlsr_53_56k.pt'
wav2vec2ctc_config.normalize = cfg['task']['normalize']
wav2vec2ctc_config.conv_feature_layers = cfg['model']['conv_feature_layers']
wav2vec2ctc_config.encoder_embed_dim = cfg['model']['encoder_embed_dim']
print(wav2vec2ctc_config.w2v_args)


normalize: True
None


In [11]:
ctc_config = CtcCriterionConfig

In [12]:

finetune_task.state.target_dictionary = finetune_task.load_target_dictionary()
print(finetune_task.target_dictionary.symbols)

print(finetune_task.target_dictionary.indices)

task_dict = finetune_task.target_dictionary.indices

print(len(task_dict))

w2v_encoder = Wav2VecEncoder(cfg=wav2vec2asr_config, output_size=len(task_dict))

['<s>', '<pad>', '</s>', '<unk>', '|', 'a', 'e', 'o', 's', 'r', 'm', 'i', 'n', 'u', 'd', 't', 'c', 'l', 'p', 'v', 'h', 'g', 'b', 'f', 'q', 'ã', 'á', 'é', 'ç', 'ê', 'z', 'j', 'x', 'ó', 'í', 'ú', 'õ', 'k', 'à', 'y', 'w', 'ô', 'â', 'ü', "'", 'ñ']
{'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, '|': 4, 'a': 5, 'e': 6, 'o': 7, 's': 8, 'r': 9, 'm': 10, 'i': 11, 'n': 12, 'u': 13, 'd': 14, 't': 15, 'c': 16, 'l': 17, 'p': 18, 'v': 19, 'h': 20, 'g': 21, 'b': 22, 'f': 23, 'q': 24, 'ã': 25, 'á': 26, 'é': 27, 'ç': 28, 'ê': 29, 'z': 30, 'j': 31, 'x': 32, 'ó': 33, 'í': 34, 'ú': 35, 'õ': 36, 'k': 37, 'à': 38, 'y': 39, 'w': 40, 'ô': 41, 'â': 42, 'ü': 43, "'": 44, 'ñ': 45}
46


In [13]:
ctc_criterion = CtcCriterionFinetuning(cfg=ctc_config, task=finetune_task).cuda()

In [14]:
model_ctc = Wav2VecCtc(cfg=wav2vec2ctc_config, w2v_encoder=w2v_encoder)

print(model_ctc)

Wav2VecCtc(
  (w2v_encoder): Wav2VecEncoder(
    (w2v_model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeLast()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeLast()
            )
            (3): GELU()
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeLast()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeLast()
            )
            (3): GELU()
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
            

In [15]:
class Wav2VecNet(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.hparams = hparams
        
        self.model = model_ctc.cuda()

        # freeze feature_extractor
        for param in self.model.w2v_encoder.w2v_model.feature_extractor.parameters():
            param.requires_grad = False

        self.criterion = ctc_criterion

    def forward(self, sample):
        
        net_output = self.model(**sample["net_input"])
        logits = self.model.get_logits(net_output)
        
        return logits

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        pred_strings = self.criterion.batch_decode_pred(model=self.model.eval(),
                                                        sample=batch)
        return pred_strings    

    def training_step(self, train_batch, batch_idx):

        # loss ctc compute
        loss, sample_size, logging_output = self.criterion.forward(model=self.model,
                                                                  sample=train_batch)

        self.log('ctc_loss_step', loss, on_step=True, prog_bar=True)
        
        return loss

    def training_epoch_end(self, outputs):
        loss = torch.stack([x['loss'] for x in outputs]).mean()       

        self.log("train_loss", loss, prog_bar=True)
  
    def validation_step(self, val_batch, batch_idx):

        # predict 
        val_loss, sample_size, logging_output = self.criterion.forward(model=self.model,
                                                                  sample=val_batch)

        # logits = self.forward(val_batch)
        
        with metrics.aggregate("val"):
          self.criterion.reduce_metrics([logging_output])

        wer = metrics.get_smoothed_values('val')['wer']

        self.log('val_loss_step', val_loss, prog_bar=True)
        self.log('val_wer_step', wer, prog_bar=True)

        return {"val_loss_step": val_loss, "val_wer_step": wer}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss_step'] for x in outputs]).mean()
        val_wer = np.stack([x['val_wer_step'] for x in outputs]).mean()

        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_wer", val_wer, prog_bar=True)
  
    def test_step(self, test_batch, batch_idx):
        
        test_loss, sample_size, logging_output = self.criterion.forward(model=self.model,
                                                                  sample=test_batch)

        # logits = self.forward(test_batch)
        
        with metrics.aggregate("test"):
          self.criterion.reduce_metrics([logging_output])
        
        wer = metrics.get_smoothed_values('test')['wer']

        self.log("test_loss_step", test_loss, prog_bar=True)
        self.log("test_wer_step", wer, prog_bar=True)
        
        return {"test_loss_step": test_loss, "test_wer_step": wer}

    def test_epoch_end(self, outputs):
        loss = torch.stack([x['test_loss_step'] for x in outputs]).mean()
        wer = np.stack([x['test_wer_step'] for x in outputs]).mean()

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_wer", wer, prog_bar=True)  

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(),
                         lr=self.hparams["lr"],
                         betas=(0.9,0.98),
                         eps=1e-6,
                         weight_decay=self.hparams["w_decay"])
        
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, 
                                                  eta_min=0, # final-lr
                                                  warmup_start_lr=self.hparams["lr"],
                                                  warmup_epochs=self.hparams["hold_epochs"], # hold_epochs
                                                  max_epochs=self.hparams["max_epochs"])
        
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
        # return optimizer

In [16]:
best_model = '/home/nm/phd-wav2vec2-xlsr-53/notebooks/Wav2Vec2_Fairseq/wav2vec2_fairseq_target_librilabels_test1-epoch=299-step=149999.ckpt'

print(best_model)

model = Wav2VecNet.load_from_checkpoint(best_model, hparams=hparams)
model.to("cuda")

/home/nm/phd-wav2vec2-xlsr-53/notebooks/Wav2Vec2_Fairseq/wav2vec2_fairseq_target_librilabels_test1-epoch=299-step=149999.ckpt


Wav2VecNet(
  (model): Wav2VecCtc(
    (w2v_encoder): Wav2VecEncoder(
      (w2v_model): Wav2Vec2Model(
        (feature_extractor): ConvFeatureExtractionModel(
          (conv_layers): ModuleList(
            (0): Sequential(
              (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
              (1): Dropout(p=0.0, inplace=False)
              (2): Sequential(
                (0): TransposeLast()
                (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                (2): TransposeLast()
              )
              (3): GELU()
            )
            (1): Sequential(
              (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
              (1): Dropout(p=0.0, inplace=False)
              (2): Sequential(
                (0): TransposeLast()
                (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
                (2): TransposeLast()
              )
              (3): GELU()
            )
            (2): Sequential(
    

In [17]:
from enelvo import normaliser

norm = normaliser.Normaliser()

In [18]:
test_dataset = load_dataset("common_voice", "pt", split="test")
wer = load_metric("wer")


chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\'\�]'
resampler = torchaudio.transforms.Resample(48_000, 16_000)
norm = normaliser.Normaliser()

Reusing dataset common_voice (/home/nm/.cache/huggingface/datasets/common_voice/pt/6.1.0/bb59ce0bb532485ab64b5d488a8dd2addc3104f694e06bcd2c272dc608bb1112)


# Evaluate using enelvo (4-gram)

In [19]:
 # Preprocessing the datasets.
# We need to read the aduio files as arrays
def speech_file_to_array_fn(batch):
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
    
    batch["target_text"] = " ".join(list(batch["sentence"].replace(" ", "|"))) + " |"
    
    batch["sampling_rate"] = 16_000
    
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

HBox(children=(FloatProgress(value=0.0, max=4641.0), HTML(value='')))




In [20]:
finetune_task.load_dataset(test_dataset, split='test')
test_dataset_process = finetune_task.datasets['test']

In [21]:
def data_collator(samples):
    # print(samples)
    return test_dataset_process.collater(samples)

In [22]:
batch_size = hparams["bs"]

test_dataloader = DataLoader(test_dataset_process, batch_size=batch_size,
                              collate_fn = data_collator,
                              shuffle=True, num_workers=8)

In [23]:
batch = next(iter(test_dataloader))
print(batch)

{'id': tensor([3779, 3877, 2803, 4199, 1920]), 'net_input': {'source': tensor([[ 0.0003,  0.0003,  0.0003,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0003,  0.0003,  0.0003,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0740, -0.0740, -0.0740,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0031, -0.0031, -0.0031,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0012,  0.0012,  0.0012,  ...,  0.1398,  0.1341,  0.1411]]), 'padding_mask': tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False]])}, 'target_lengths': tensor([34, 48, 46, 47, 60]), 'ntokens': 235, 'target': tensor([[ 6,  4, 19,  7, 16, 29,  4, 10,  6,  4, 16,  7, 12, 15,  7, 13,  4,  8,
          7, 22,  9,  6,  4, 18,  9,  6,  8,  8, 26, 21, 11,  7,  8,  4,  1,  1,
          1,  1,  1,  1,  1,  1,

In [24]:
y_true = list()
y_pred = list()
y_pred_norm = list()

for i, batch in enumerate(test_dataloader):
  batch['net_input']['source'] = batch['net_input']['source'].cuda()
  batch['net_input']['padding_mask'] = batch['net_input']['padding_mask'].cuda()

  labels = ctc_criterion.batch_decode(batch["target"])
  print('labels:', labels[0])
  
  pred_strings = model.predict_step(batch, i)
  print('pred:', pred_strings[0])

  pred_strings_norm = [norm.normalise(i) for i in pred_strings]

  print('pred norm:', pred_strings_norm[0])
  print('---------------------')
  
  y_true.append(labels)
  y_pred.append(pred_strings)
  y_pred_norm.append(pred_strings_norm)

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)  
y_pred_norm = np.concatenate(y_pred_norm)

labels: ou então vai para quintafeira
pred: o então vai para quintafeira
pred norm: o então vai para tarantela
---------------------
labels: um homem sem camisa pinta o teto
pred: um homem sem camisa pinta o teto
pred norm: um homem sem camisa pinta o teto
---------------------
labels: quanto tempo duraria a entrevista
pred: quanto tempo duraria a entrevista
pred norm: quanto tempo duraria a entrevista
---------------------
labels: isso não é tudo
pred: isso não é tudo
pred norm: isso não é tudo
---------------------
labels: quais são os fatos
pred: quais são os patos
pred norm: quais são os patos
---------------------
labels: uma mulher aponta para o interior de um computador enquanto uma menina olha
pred: uma mulher aponta para o enterver de um computador enquanto uma menina ola
pred norm: uma mulher aponta para o enterter de um computador enquanto uma menina ola
---------------------
labels: tomando o ferry foi uma escolha sábia
pred: tomando o ferre foi uma escolha sabea
pred norm:

## WER with no-LM

In [25]:
print("WER: {:2f}".format(100 * wer.compute(predictions=y_pred, references=y_true)))

WER: 17.787671


## WER with enelvo (4-gram)

In [26]:
print("WER: {:2f}".format(100 * wer.compute(predictions=y_pred_norm, references=y_true)))

WER: 16.457778


In [27]:
del y_pred, y_pred_norm, y_true 
torch.cuda.empty_cache()

# Using PTT5 LM

In [28]:
class PTT5Net(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.hparams = hparams

        # Note como a arquitetura esta dependente dos hiperparâmetros salvos.
        self.model = model_pt

        # for param in self.model.bert.parameters():
        #     param.requires_grad = False

        self.tokenizer = tokenizer

    def forward(self, input_values):
        logits = self.model(**input_values).prediction_logits
        return logits

    def predict_step(self, batch, batch_idx, dataloader_idx=None, gen_type=None):

        if gen_type==0:
          # normal beam search generation  
          pred_tokens = self._generate_tokens(batch["input_ids"])
        elif gen_type==1:
          # fast beam search generation
          pred_tokens = self._fast_generate_tokens(batch["input_ids"])  
        elif gen_type==2:
          # Top-K Sampling  generation
          pred_tokens = self._topK_generate_tokens(batch["input_ids"])  
        elif gen_type==3:   
          # Top-p generation
          pred_tokens = self._topp_generate_tokens(batch["input_ids"]) 
        elif gen_type==4:  
          # # Top-p e Top-K Sampling generation
          pred_tokens = self._toppK_generate_tokens(batch["input_ids"]) 
        else:
            # greedy decoding
          pred_tokens = self._greedy_generate_tokens(batch["input_ids"])

        # Tokens -> String
        decoded_pred = self.tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
        
        if 'labels' in batch:
            target = batch["labels"]
            decoded_target = [self.tokenizer.decode(tokens[tokens!=-100], skip_special_tokens=True) for tokens in target] 
            print(f"\nSample Target: {decoded_target[0]}\nPrediction: {decoded_pred[0]}\n")
            return decoded_pred, decoded_target   
        else:
            return decoded_pred
    
    def _greedy_generate_tokens(self, input_ids):
        
        decoded_ids = self.model.generate(
            input_ids, 
            # max_length=512,
            max_length=input_ids.shape[1]+1, 
            # pad_token_id=0,
            # eos_token_id=1,
            # early_stopping=True
        )


        return decoded_ids

    def _fast_generate_tokens(self, input_ids):
        '''
        Token generation
        '''
          
        # print(input_ids.shape)
        decoded_ids = self.model.generate(
            input_ids, 
            max_length=input_ids.shape[1]+1, 
            num_beams=50, 
            temperature=0.1,
            no_repeat_ngram_size=2, 
            num_return_sequences=1,
            length_penalty=0.8,
            repetition_penalty=0.8,
            num_beam_groups=5,
#             pad_token_id=0,
#             eos_token_id=1,
#             early_stopping=True
        )

        return decoded_ids   
    
    def _generate_tokens(self, input_ids):
        '''
        Token generation
        '''
          
        # print(input_ids.shape)
        decoded_ids = self.model.generate(
            input_ids, 
            max_length=input_ids.shape[1]+1, 
            num_beams=300, 
            temperature=0.1,
            no_repeat_ngram_size=2, 
            num_return_sequences=1,
            length_penalty=0.8,
            repetition_penalty=0.8,
            num_beam_groups=5,
#             pad_token_id=0,
#             eos_token_id=1,
#             early_stopping=True
        )

        return decoded_ids   
    
    def _topK_generate_tokens(self, input_ids):
        '''
        Token TopK generation
        '''
          
        # print(input_ids.shape)
        decoded_ids = self.model.generate(
            input_ids, 
            max_length=input_ids.shape[1]+1, 
            do_sample=True,  
            top_k=500, 
            temperature=0.1,
            num_return_sequences=1,
            length_penalty=0.8,
            repetition_penalty=0.8,
#             pad_token_id=0,
#             eos_token_id=1,
#             early_stopping=True
        )

        return decoded_ids
    
    def _topp_generate_tokens(self, input_ids):
        '''
        Token TopK generation
        '''
          
        # print(input_ids.shape)
        decoded_ids = self.model.generate(
            input_ids, 
            max_length=input_ids.shape[1]+1, 
            do_sample=True,  
            top_p=0.92, 
            top_k=0,
            temperature=0.1,
            num_return_sequences=1,
            length_penalty=0.8,
            repetition_penalty=0.8,
            pad_token_id=0,
            eos_token_id=1,
            early_stopping=True
        )

        return decoded_ids
    
    def _toppK_generate_tokens(self, input_ids):
        '''
        Token TopK generation
        '''
          
        # print(input_ids.shape)
        decoded_ids = self.model.generate(
            input_ids, 
            max_length=input_ids.shape[1]+1, 
            do_sample=True,  
            top_p=0.95, 
            top_k=1500,
            temperature=0.1,
            num_return_sequences=1,
            length_penalty=0.5,
            repetition_penalty=0.8,
#             pad_token_id=0,
#             eos_token_id=1,
#             early_stopping=True
        )

        return decoded_ids

    def training_step(self, train_batch, batch_idx):
        # loss compute
        loss = self.model(**train_batch).loss
        # print('loss', loss)

        self.log('cross_loss_step', loss, on_step=True, prog_bar=True)
        
        return loss

    def training_epoch_end(self, outputs):
        loss = torch.stack([x['loss'] for x in outputs]).mean()       

        self.log("train_loss", loss, prog_bar=True)
  
    def validation_step(self, val_batch, batch_idx):

        # predict 
        target = val_batch["labels"]
        val_loss = self.model(**val_batch).loss

        # pred_tokens = self._fast_generate_tokens(val_batch["input_ids"])
        pred_tokens = self._greedy_generate_tokens(val_batch["input_ids"])
        # Tokens -> String
        # decoded_pred = [self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in pred_tokens]
        decoded_pred = self.tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
        
        decoded_target = [self.tokenizer.decode(tokens[tokens!=-100], skip_special_tokens=True) for tokens in target]
        return {"val_loss_step": val_loss, "pred": decoded_pred, "target": decoded_target}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss_step'] for x in outputs]).mean()
        trues = sum([list(x['target']) for x in outputs], [])
        preds = sum([list(x['pred']) for x in outputs], [])
        
        n = random.choices(range(len(trues)), k=2)
        for i in n:
          print(f"\nSample Target: {trues[i]}\nPrediction: {preds[i]}\n")
        
        f1 = []
        for true, pred in zip(trues, preds):
            f1.append(compute_f1(a_gold=true, a_pred=pred))
        f1_val = np.mean(f1)


        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_f1", f1_val, prog_bar=True)
  
    def test_step(self, test_batch, batch_idx):
        
        # input_values, labels = test_batch['input_values'], test_batch['labels']
        target = test_batch["labels"]
        test_loss = self.model(**test_batch).loss
        
        pred_tokens = self._greedy_generate_tokens(test_batch["input_ids"])
        decoded_pred = self.tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)

        decoded_target = [self.tokenizer.decode(tokens[tokens!=-100], skip_special_tokens=True) for tokens in target]
        return {"test_loss_step": test_loss, "pred": decoded_pred, "target": decoded_target}

    def test_epoch_end(self, outputs):
        loss = torch.stack([x['test_loss_step'] for x in outputs]).mean()
        trues = sum([list(x['target']) for x in outputs], [])
        preds = sum([list(x['pred']) for x in outputs], [])

        n = random.choices(range(len(trues)), k=2)
        for i in n:
          print(f"\nSample Target: {trues[i]}\nPrediction: {preds[i]}\n")
        
        f1 = []
        for true, pred in zip(trues, preds):
            f1.append(compute_f1(a_gold=true, a_pred=pred))
        f1_test = np.mean(f1)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_f1", f1_test, prog_bar=True)

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(),
                         lr=self.hparams["lr"],
                         weight_decay=self.hparams["w_decay"])
        
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, 
                                                  eta_min=0, # final-lr
                                                  warmup_start_lr=self.hparams["lr"],
                                                  warmup_epochs=self.hparams["hold_epochs"], # hold_epochs
                                                  max_epochs=self.hparams["max_epochs"])
        
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor':'val_f1'}  

In [29]:
def evaluate_ptt5(batch, i, gen_type=None):
    batch['net_input']['source'] = batch['net_input']['source'].cuda()
    batch['net_input']['padding_mask'] = batch['net_input']['padding_mask'].cuda()

    labels = ctc_criterion.batch_decode(batch["target"])
    print('labels:', labels[0])

    preds_str = model.predict_step(batch, i)
    print('pred:', preds_str[0])
    
    inputs = tokenizer(preds_str,
                       padding=True,
                       return_tensors='pt')
    # print(inputs)
    inputs['input_ids'] = inputs['input_ids'].cuda()
    inputs['attention_mask'] = inputs['attention_mask'].cuda()
    with torch.no_grad():
        pred_strings = model_ptt5.predict_step(inputs, i, gen_type=gen_type)
    print('pred_strings:', pred_strings[0])
    print('-----------')
    return labels, pred_strings

In [30]:
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration

In [31]:
tokenizer = T5Tokenizer.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab')
model_pt = T5ForConditionalGeneration.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab')

sentence = 'a garoa fria vai parar a unidade de ligação'
input_tokens = tokenizer(sentence)
print(input_tokens)

{'input_ids': [7, 6367, 43, 44, 10139, 1057, 20, 33, 7, 1589, 4, 2496, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [32]:
PATH = "PTT5_LM/PTT5_LM/PTT5_noise_oldvocab_adafactor_test7-epoch=13-step=2071.ckpt"

model_ptt5 = PTT5Net.load_from_checkpoint(PATH, hparams=hparams).cuda().eval()

In [33]:
y_true = list()
y_pred = list()

for i, batch in enumerate(test_dataloader):
  
  labels, pred_strings = evaluate_ptt5(batch, i, gen_type=None) # greedy -> gen_type=None
  
  y_true.append(labels)
  y_pred.append(pred_strings)

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)


labels: a indústria do entretenimento é enorme
pred: a indústria do entretenimento é enorme
pred_strings: a indústria do entretenimento é enorme
-----------
labels: um cão do exército sendo treinado por um soldado
pred: um cão do exerto sendo treinado por um soldado
pred_strings: um cão do deserto sendo treinado por um soldado
-----------
labels: o som está ficando menor
pred: o som está ficando menor
pred_strings: o som está ficando menor
-----------
labels: volte para assistir a caravana ele disse
pred: volte para assistir a caravana ele disse
pred_strings: volte para assistir a caravana ele disse
-----------
labels: ela é a primeira noiva da aldeia
pred: ela raprimentas loira da odela
pred_strings: ela rapidamente saiu da sala
-----------
labels: tem amigos que participaram da palestra de hoje
pred: tem amigos que participaram da palestra de hoje
pred_strings: tem amigos que participaram da palestra de hoje
-----------
labels: franquias de zoológicos são uma desgraça para o moviment

# WER with transformer LM (PTT5)

## Greedy decoding

In [34]:
print("WER: {:2f}".format(100 * wer.compute(predictions=y_pred, references=y_true)))

WER: 16.242273


In [35]:
del y_pred, y_true, test_dataloader 
torch.cuda.empty_cache()

## Beam-decoder

In [36]:
batch_size = 2

test_dataloader = DataLoader(test_dataset_process, batch_size=batch_size,
                              collate_fn = data_collator,
                              shuffle=True, num_workers=8)

In [37]:
y_true = list()
y_pred = list()

for i, batch in enumerate(test_dataloader):
  
  labels, pred_strings = evaluate_ptt5(batch, i, gen_type=0) # beam-search -> gen_type=0
  
  y_true.append(labels)
  y_pred.append(pred_strings)

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)

labels: os netos também são bons filhos
pred: os netos tomem são bons filhos


  "Passing `max_length` to BeamSearchScorer is deprecated and has no effect."
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


pred_strings: os netos também são bons filhos
-----------
labels: a furadeira sem fio da jade se tornou defeituosa
pred: a fura degra semfio da ejade se tornou de feituosa
pred_strings: a parede do subsolo da cidade se tornou deliciosa
-----------
labels: para poder atravessar décadas
pred: para poder atravessar decadas
pred_strings: para poder atravessar fronteiras
-----------
labels: um grupo de pessoas está olhando para a frente enquanto bebe copos de vinho
pred: um grupo de pessoas está olhando para a frente enquanto bebe copas de vinho
pred_strings: um grupo de pessoas está olhando para a frente enquanto bebe garrafas de vinho
-----------
labels: valgrind detectou vários vazamentos de memória
pred: algde detectou vários vazamentos memora
pred_strings: alguém descobriu vários vazamentos depois
-----------
labels: a condição está piorando
pred: a condição está piorando
pred_strings: a condição está piorando
-----------
labels: dentro das restrições de sua mente não havia como voltar

In [38]:
print("WER: {:2f}".format(100 * wer.compute(predictions=y_pred, references=y_true)))

WER: 16.222424


In [39]:
del y_pred, y_true, test_dataloader 
torch.cuda.empty_cache()

## Top-p and Top-K  decoding

In [37]:
batch_size = 2

test_dataloader = DataLoader(test_dataset_process, batch_size=batch_size,
                              collate_fn = data_collator,
                              shuffle=True, num_workers=8)

In [38]:
y_true = list()
y_pred = list()

for i, batch in enumerate(test_dataloader):
  
  labels, pred_strings = evaluate_ptt5(batch, i, gen_type=4) # Topp - TopK -> gen_type=4
  
  y_true.append(labels)
  y_pred.append(pred_strings)

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)

labels: antes que morresse ele seria saciado
pred: antes que morresse ele seria sassiado
pred_strings: antes que morresse ele seria sacudido
-----------
labels: ainda depois daquele primeiro milhão antes dos trinta
pred: ainda depois daquele primeiro milhão antes dostrintar
pred_strings: ainda depois daquele primeiro milhão antes de explodir
-----------
labels: uma jovem mulher está de pé sob uma placa de rua
pred: uma jovem mulher está de pé sobe uma placa de rua
pred_strings: uma jovem mulher está de pé sobre uma placa de rua
-----------
labels: vibração tremulosa de uma nota
pred: vibração tremulosa de uma nota
pred_strings: vibração tremenda de uma nota
-----------
labels: acho melhor eu ir embora
pred: acho melhor filê embora
pred_strings: acho melhor você embora
-----------
labels: ele se irrita por besteira
pred: ele serrita por bespera
pred_strings: ele chorou por socorro
-----------
labels: meninas brincando na neve no capô de um carro
pred: meninas brincando na neve no capo d

In [39]:
print("WER: {:2f}".format(100 * wer.compute(predictions=y_pred, references=y_true)))

WER: 16.279136


In [40]:
del y_pred, y_true 
torch.cuda.empty_cache()