In [1]:
# Step two directories up so that the relative paths used later in this notebook
# (e.g. "data/test_transcriptions") resolve against that directory.
%cd ../..

/Users/alexanderpolok/PycharmProjects/IS24_DeCRED


In [2]:
!pip install datasets
!pip install transformers
!pip install torch
!pip install numpy
!pip install tqdm



In [3]:
import datasets
from transformers import pipeline
import torch
import numpy as np
import tqdm
from torch.nn import CrossEntropyLoss


In [4]:

def compute_ppl(pipe, test_dataset, batch_size: int = 16, add_start_token: bool = True, device=None,
                max_length=None, use_encoder=False
                ):
    """Compute the per-text perplexity of ``pipe.model.decoder`` on ``test_dataset["text"]``.

    All texts are tokenized up front, then scored batch by batch. For every text the
    perplexity is ``exp`` of the mean token-level cross-entropy over the positions that
    the attention mask marks as real (non-padding) tokens.

    Args:
        pipe: Object exposing ``tokenizer`` and ``model.decoder``; when ``use_encoder`` is
            true it must also expose ``feature_extractor`` and ``model.encoder``.
        test_dataset: Dataset with a ``"text"`` column. When ``use_encoder`` is true it must
            also support row slicing (``test_dataset[a:b]["audio"]``) and every audio item
            must provide an ``"array"`` entry.
        batch_size: Number of texts scored per decoder call.
        add_start_token: If true, the tokenizer's BOS token id is prepended to every text
            so that the first real token is scored as well.
        device: Passed to ``.to(device)`` on the tokenizer and feature-extractor outputs.
        max_length: Optional truncation length. One position is reserved for the BOS
            token when ``add_start_token`` is true.
        use_encoder: If true, the encoder is run on the batch audio and its last hidden
            state is passed to the decoder as ``encoder_hidden_states``; otherwise the
            decoder receives ``encoder_hidden_states=None``.

    Returns:
        ``{"perplexities": [...], "mean_perplexity": ...}`` with one perplexity per text
        (in dataset order) and their ``numpy`` mean.

    Raises:
        AssertionError: If padding is required but no special token exists to pad with,
            if ``add_start_token`` is true but the tokenizer has no BOS token, or if a
            text is too short to be scored.
    """
    tokenizer = pipe.tokenizer

    # if batch_size > 1 (which generally leads to padding being required), and
    # if there is not an already assigned pad_token, assign an existing
    # special token to also be the padding token
    if tokenizer.pad_token is None and batch_size > 1:
        existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
        # check that the model already has at least one special token defined
        assert (
                len(existing_special_tokens) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        # assign one of the special tokens to also be the pad token
        tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

    # A BOS token is needed whenever it is going to be prepended, not only when
    # max_length is set; checking here avoids an obscure tensor-construction error later.
    if add_start_token:
        assert (
                tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

    if add_start_token and max_length:
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    # NOTE(review): add_special_tokens=True lets the tokenizer insert its own special
    # tokens in addition to the BOS prepended below - confirm it does not emit BOS itself.
    encodings = tokenizer(
        test_dataset["text"],
        add_special_tokens=True,
        padding=True,
        truncation=bool(max_tokenized_len),
        max_length=max_tokenized_len,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # check that each input is long enough:
    if add_start_token:
        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
    else:
        assert torch.all(
            torch.ge(attn_masks.sum(1), 2)
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    ppls = []
    # reduction="none" keeps one loss value per token so padding can be masked out.
    loss_fct = CrossEntropyLoss(reduction="none")
    num_texts = len(encoded_texts)

    for start_index in tqdm.tqdm(range(0, num_texts, batch_size)):
        end_index = min(start_index + batch_size, num_texts)
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        if add_start_token:
            # Prepend one BOS column (and a matching column of ones in the mask).
            bos_tokens_tensor = torch.full(
                (encoded_batch.size(dim=0), 1),
                tokenizer.bos_token_id,
                dtype=encoded_batch.dtype,
                device=encoded_batch.device,
            )
            encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
            attn_mask = torch.cat([attn_mask.new_ones(bos_tokens_tensor.size()), attn_mask], dim=1)

        labels = encoded_batch

        # Both forward passes are inference-only, so neither should record autograd state.
        with torch.no_grad():
            encoder_outputs = None
            if use_encoder:
                # Slice rows first so only this batch's audio is materialised, instead of
                # fetching the whole "audio" column on every iteration.
                audio_batch = test_dataset[start_index:end_index]["audio"]
                features = pipe.feature_extractor(
                    [sample["array"] for sample in audio_batch], sampling_rate=16_000,
                    return_tensors="pt", padding=True).to(device)
                # The encoder is called with "input_values" rather than "input_features".
                features["input_values"] = features["input_features"]
                del features["input_features"]
                encoder_outputs = pipe.model.encoder(**features).last_hidden_state

            out_logits = pipe.model.decoder(encoded_batch, attention_mask=attn_mask,
                                            encoder_hidden_states=encoder_outputs).logits

        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # Masked mean of the token losses per text, exponentiated to a perplexity.
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

In [5]:
# Default to the CPU and switch to the GPU only when torch reports CUDA as available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Test transcriptions, read from a path relative to the current working directory.
test_datasets = datasets.load_from_disk("data/test_transcriptions")

/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/librispeech_test.clean/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/fleurs_test/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/tedlium3_test/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/commonvoice_en_test/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/voxpopuli_test/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/librispeech_test.other/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/gigaspeech_test/data-00000-of-00001.arrow
done
/Users/alexanderpolok/PycharmProjects/IS24_DeCRED/data/test_transcriptions/ami_corpus_test/data-00000-of-00001.arrow
done


In [6]:
# Each entry is (model id, whether the decoder is conditioned on encoder outputs).
for model_id, use_encoder in (
    ("BUT-FIT/DeCRED-small", False),
    ("BUT-FIT/ED-small", False),
):
    # The same model id supplies both the model and its feature extractor.
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model_id,
        feature_extractor=model_id,
        trust_remote_code=True,
        device=device,
    )
    pipe.type = "seq2seq"

    if torch.cuda.is_available():
        pipe.model = pipe.model.to(device)

    tokenizer = pipe.tokenizer
    print(f"{model_id}_{use_encoder}")

    # Report the mean perplexity of every test set, rounded to one decimal place.
    for test_set in test_datasets:
        test_data = test_datasets[test_set]
        ppl = compute_ppl(
            pipe=pipe,
            test_dataset=test_data,
            add_start_token=True,
            device=device,
            use_encoder=use_encoder,
        )
        mean_ppl = round(ppl["mean_perplexity"], 1)
        print(test_set, mean_ppl)



BUT-FIT/DeCRED-small_False


100%|██████████| 164/164 [00:36<00:00,  4.49it/s]


librispeech_test.clean 129.1


100%|██████████| 41/41 [00:05<00:00,  7.05it/s]


fleurs_test 111.5


100%|██████████| 73/73 [00:18<00:00,  3.90it/s]


tedlium3_test 89.0


100%|██████████| 1024/1024 [01:24<00:00, 12.11it/s]


commonvoice_en_test 141.0


100%|██████████| 116/116 [00:38<00:00,  2.98it/s]


voxpopuli_test 101.4


100%|██████████| 184/184 [00:45<00:00,  4.01it/s]


librispeech_test.other 140.4


100%|██████████| 1584/1584 [08:26<00:00,  3.13it/s]


gigaspeech_test 66.3


100%|██████████| 789/789 [02:21<00:00,  5.57it/s]


ami_corpus_test 136.6
BUT-FIT/ED-small_False


100%|██████████| 164/164 [00:38<00:00,  4.25it/s]


librispeech_test.clean 206.1


100%|██████████| 41/41 [00:06<00:00,  6.19it/s]


fleurs_test 159.8


100%|██████████| 73/73 [00:18<00:00,  3.97it/s]


tedlium3_test 134.6


100%|██████████| 1024/1024 [01:29<00:00, 11.45it/s]


commonvoice_en_test 232.4


100%|██████████| 116/116 [00:41<00:00,  2.82it/s]


voxpopuli_test 142.7


100%|██████████| 184/184 [00:48<00:00,  3.78it/s]


librispeech_test.other 199.9


100%|██████████| 1584/1584 [08:55<00:00,  2.96it/s]


gigaspeech_test 84.1


100%|██████████| 789/789 [02:25<00:00,  5.44it/s]

ami_corpus_test 308.3



