In [24]:
# !pip install torch-audiomentations

In [60]:
import torch
T=50
C=20
target_lengths = torch.randint(low=1, high=50, size=(), dtype=torch.long)
target = torch.randint(low=1, high=20, size=(target_lengths,), dtype=torch.long)

In [62]:
target_lengths, target

(tensor(10), tensor([ 6,  9,  5,  5,  6, 15,  7,  3,  6,  1]))

In [39]:
from src.text_encoder import CTCTextEncoder
from src.datasets import LibrispeechDataset
from src.transforms.wav_augs import Gain
from src.datasets.collate import collate_fn
import torch
import torchaudio 

part="dev-clean"
instance_transforms= {
    'get_spectrogram': torchaudio.transforms.MelSpectrogram(sample_rate=16000),
    'audio': Gain()
}

text_encoder = CTCTextEncoder()
dataset = LibrispeechDataset(
    text_encoder=text_encoder,
    part= part,
    max_audio_length= 20.0,
    max_text_length= 200,
    limit= 10,
    instance_transforms= instance_transforms
)
# dataloaders, batch_transforms = get_dataloaders(config, text_encoder, device)

  >>> augment = Gain(..., output_type='dict')
  >>> augmented_samples = augment(samples).samples


In [40]:
dataloader = torch.utils.data.DataLoader(
    batch_size=10,
    num_workers=2,
    pin_memory=True,
    dataset=dataset,
    collate_fn=collate_fn,
    drop_last=True,
    shuffle=True,
)

In [41]:
for batch in dataloader:
    break

In [44]:
batch.keys()

dict_keys(['audio', 'spectrogram', 'text_encoded', 'text', 'audio_path', 'spectrogram_length', 'text_encoded_length'])

In [50]:
batch['text_encoded_length']

[176, 83, 161, 37, 65, 52, 136, 36, 147, 51]

In [None]:
from torch.nn import CTCLoss
criterion = CTCLoss()


In [56]:
torch.nn.CTCLoss?

[0;31mInit signature:[0m
[0mtorch[0m[0;34m.[0m[0mnn[0m[0;34m.[0m[0mCTCLoss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mblank[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduction[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'mean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mzero_infinity[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
The Connectionist Temporal Classification loss.

Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
probability of possible alignments of input to target, producing a loss value which is differentiable
with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
limits the length of the target sequence such that it must be :math:`\leq` the input length.

Args:
    blank (int, opt

In [14]:
from src.utils.io_utils import ROOT_PATH
import json

data_dir = ROOT_PATH / "data" / "datasets" / "librispeech"
index_path = data_dir / f"{part}_index.json"
with index_path.open() as f:
    index = json.load(f)

In [15]:
import numpy as np

max_audio_length = 20.0
max_text_length = 300

initial_size = len(index)
if max_audio_length is not None:
    exceeds_audio_length = (
        np.array([el["audio_len"] for el in index]) >= max_audio_length
    )
    _total = exceeds_audio_length.sum()
    print(
        f"{_total} ({_total / initial_size:.1%}) records are longer then "
        f"{max_audio_length} seconds. Excluding them."
    )
else:
    exceeds_audio_length = False

initial_size = len(index)
if max_text_length is not None:
    exceeds_text_length = (
        np.array(
            [len(CTCTextEncoder.normalize_text(el["text"])) for el in index]
        )
        >= max_text_length
    )
    _total = exceeds_text_length.sum()
    print(
        f"{_total} ({_total / initial_size:.1%}) records are longer then "
        f"{max_text_length} characters. Excluding them."
    )
else:
    exceeds_text_length = False

records_to_filter = exceeds_text_length | exceeds_audio_length

if records_to_filter is not False and records_to_filter.any():
    _total = records_to_filter.sum()
    index = [el for el, exclude in zip(index, records_to_filter) if not exclude]
    print(
        f"Filtered {_total} ({_total / initial_size:.1%}) records  from dataset"
    )

61 (2.3%) records are longer then 20.0 seconds. Excluding them.
59 (2.2%) records are longer then 300 characters. Excluding them.
Filtered 75 (2.8%) records  from dataset


In [16]:
import random

shuffle_index = True
limit = 10

if shuffle_index:
    random.seed(42)
    random.shuffle(index)

if limit is not None:
    index_ = index[:limit]
# index_

In [18]:
def load_audio(target_sr, path):
    audio_tensor, sr = torchaudio.load(path)
    audio_tensor = audio_tensor[0:1, :]  # remove all channels but the first
    target_sr = target_sr
    if sr != target_sr:
        audio_tensor = torchaudio.functional.resample(audio_tensor, sr, target_sr)
    return audio_tensor

In [22]:
def preprocess_data(instance_transforms, instance_data):
    if instance_transforms is not None:
        for transform_name in instance_transforms.keys():
            if transform_name == "get_spectrogram":
                continue  # skip special key
            instance_data[transform_name] = instance_transforms[
                transform_name
            ](instance_data[transform_name])
    return instance_data

In [23]:
def getitem_(ind):
    target_sr=16000
    data_dict = index_[ind]
    audio_path = data_dict["path"]
    audio = load_audio(target_sr=target_sr, path=audio_path)
    text = data_dict["text"]
    text_encoded = text_encoder.encode(text)

    spectrogram = instance_transforms["get_spectrogram"](audio)

    instance_data = {
        "audio": audio,
        "spectrogram": spectrogram,
        "text": text,
        "text_encoded": text_encoded,
        "audio_path": audio_path,
    }


    instance_data_ = preprocess_data(instance_transforms, instance_data)
    return instance_data_

In [24]:
audio_shapes, spectrogram_shapes, text_encoded_shapes = [], [], []
for ind in range(len(index_)):
    audio_shapes.append(getitem_(ind)['audio'].shape)
    spectrogram_shapes.append(getitem_(ind)['spectrogram'].shape)
    text_encoded_shapes.append(getitem_(ind)['text_encoded'].shape)

print(audio_shapes, spectrogram_shapes, text_encoded_shapes, sep='\n')


[torch.Size([1, 55920]), torch.Size([1, 120960]), torch.Size([1, 47360]), torch.Size([1, 242400]), torch.Size([1, 131600]), torch.Size([1, 158400]), torch.Size([1, 192640]), torch.Size([1, 108160]), torch.Size([1, 125520]), torch.Size([1, 75840])]
[torch.Size([1, 128, 280]), torch.Size([1, 128, 605]), torch.Size([1, 128, 237]), torch.Size([1, 128, 1213]), torch.Size([1, 128, 659]), torch.Size([1, 128, 793]), torch.Size([1, 128, 964]), torch.Size([1, 128, 541]), torch.Size([1, 128, 628]), torch.Size([1, 128, 380])]
[torch.Size([1, 38]), torch.Size([1, 97]), torch.Size([1, 53]), torch.Size([1, 222]), torch.Size([1, 100]), torch.Size([1, 169]), torch.Size([1, 160]), torch.Size([1, 65]), torch.Size([1, 131]), torch.Size([1, 55])]


In [25]:
dataset_items = []
for ind in range(len(index_)):
    dataset_items.append(getitem_(ind))
# dataset_items[0]

In [26]:
from torch.nn.utils.rnn import pad_sequence

# For BaseModel
audios = [item['audio'].squeeze() for item in dataset_items]
spectrograms = [item['spectrogram'].squeeze().transpose(0,1) for item in dataset_items]
texts = [item['text'] for item in dataset_items]
text_encoded = [item['text_encoded'].squeeze() for item in dataset_items]
audio_paths = [item['audio_path'] for item in dataset_items]

# Pad audios, spectrograms and text_encoded sequences
padded_audios = pad_sequence(audios, batch_first=True, padding_value=0)
padded_spectrograms = pad_sequence(spectrograms, batch_first=True, padding_value=0).transpose(1,2)
padded_text_encoded = pad_sequence(text_encoded, batch_first=True, padding_value=0)

# Create the result batch dictionary
result_batch = {
    'audio': padded_audios,
    'spectrogram': padded_spectrograms,
    'text_encoded': padded_text_encoded,
    'text': texts,
    'audio_path': audio_paths
}