In [69]:
import warnings

import hydra
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

from src.datasets.data_utils import get_dataloaders
from src.trainer import Trainer
from src.utils.init_utils import set_random_seed, setup_saving_and_logging


warnings.filterwarnings("ignore", category=UserWarning)

In [70]:
from src.text_encoder import CTCTextEncoder
from src.datasets import LibrispeechDataset
from src.transforms.wav_augs import Gain
from src.datasets.collate import collate_fn
import torch
import torchaudio 

part="dev-clean"
instance_transforms= {
    'get_spectrogram': torchaudio.transforms.MelSpectrogram(sample_rate=16000),
    'audio': Gain()
}

text_encoder = CTCTextEncoder()
dataset = LibrispeechDataset(
    text_encoder=text_encoder,
    part= part,
    max_audio_length= 20.0,
    max_text_length= 200,
    limit= 10,
    instance_transforms= instance_transforms
)
# dataloaders, batch_transforms = get_dataloaders(config, text_encoder, device)

  >>> augment = Gain(..., output_type='dict')
  >>> augmented_samples = augment(samples).samples


In [75]:
dataloader = torch.utils.data.DataLoader(
    batch_size=2,
    num_workers=2,
    pin_memory=True,
    dataset=dataset,
    collate_fn=collate_fn,
    drop_last=True,
    shuffle=True,
)

In [76]:
for batch in dataloader:
    break

In [78]:
batch

{'audio': tensor([[-0.0007, -0.0077, -0.0030,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0004, -0.0002,  0.0012,  ..., -0.0003, -0.0004, -0.0004]]),
 'spectrogram': tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.9926e-04, 3.0717e-03, 1.4854e-03,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.0729e-03, 1.6539e-02, 7.9980e-03,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [7.9302e-05, 3.6578e-05, 5.7361e-05,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.0455e-04, 6.7070e-05, 3.9266e-05,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.8559e-04, 2.9940e-05, 6.8682e-05,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],
 
         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [7.7752e-04, 9.0164e-03, 4.0364e-03,  ..., 4.7646e-02,
           5.4324e-02, 9.7332e-03],
      

In [79]:
batch['text_encoded_length']

tensor([ 51, 176], dtype=torch.int32)

In [80]:
from torch.nn import CTCLoss
criterion = CTCLoss()


In [81]:
from src.utils.io_utils import ROOT_PATH
import json

data_dir = ROOT_PATH / "data" / "datasets" / "librispeech"
index_path = data_dir / f"{part}_index.json"
with index_path.open() as f:
    index = json.load(f)

In [82]:
import numpy as np

max_audio_length = 20.0
max_text_length = 300

initial_size = len(index)
if max_audio_length is not None:
    exceeds_audio_length = (
        np.array([el["audio_len"] for el in index]) >= max_audio_length
    )
    _total = exceeds_audio_length.sum()
    print(
        f"{_total} ({_total / initial_size:.1%}) records are longer then "
        f"{max_audio_length} seconds. Excluding them."
    )
else:
    exceeds_audio_length = False

initial_size = len(index)
if max_text_length is not None:
    exceeds_text_length = (
        np.array(
            [len(CTCTextEncoder.normalize_text(el["text"])) for el in index]
        )
        >= max_text_length
    )
    _total = exceeds_text_length.sum()
    print(
        f"{_total} ({_total / initial_size:.1%}) records are longer then "
        f"{max_text_length} characters. Excluding them."
    )
else:
    exceeds_text_length = False

records_to_filter = exceeds_text_length | exceeds_audio_length

if records_to_filter is not False and records_to_filter.any():
    _total = records_to_filter.sum()
    index = [el for el, exclude in zip(index, records_to_filter) if not exclude]
    print(
        f"Filtered {_total} ({_total / initial_size:.1%}) records  from dataset"
    )

61 (2.3%) records are longer then 20.0 seconds. Excluding them.
59 (2.2%) records are longer then 300 characters. Excluding them.
Filtered 75 (2.8%) records  from dataset


In [83]:
import random

shuffle_index = True
limit = 10

if shuffle_index:
    random.seed(42)
    random.shuffle(index)

if limit is not None:
    index_ = index[:limit]
# index_

In [84]:
def load_audio(target_sr, path):
    audio_tensor, sr = torchaudio.load(path)
    audio_tensor = audio_tensor[0:1, :]  # remove all channels but the first
    target_sr = target_sr
    if sr != target_sr:
        audio_tensor = torchaudio.functional.resample(audio_tensor, sr, target_sr)
    return audio_tensor

In [85]:
def preprocess_data(instance_transforms, instance_data):
    if instance_transforms is not None:
        for transform_name in instance_transforms.keys():
            if transform_name == "get_spectrogram":
                continue  # skip special key
            instance_data[transform_name] = instance_transforms[
                transform_name
            ](instance_data[transform_name])
    return instance_data

In [86]:
def getitem_(ind):
    target_sr=16000
    data_dict = index_[ind]
    audio_path = data_dict["path"]
    audio = load_audio(target_sr=target_sr, path=audio_path)
    text = data_dict["text"]
    text_encoded = text_encoder.encode(text)

    spectrogram = instance_transforms["get_spectrogram"](audio)

    instance_data = {
        "audio": audio,
        "spectrogram": spectrogram,
        "text": text,
        "text_encoded": text_encoded,
        "audio_path": audio_path,
    }


    instance_data_ = preprocess_data(instance_transforms, instance_data)
    return instance_data_

In [87]:
audio_shapes, spectrogram_shapes, text_encoded_shapes = [], [], []
for ind in range(len(index_)):
    audio_shapes.append(getitem_(ind)['audio'].shape)
    spectrogram_shapes.append(getitem_(ind)['spectrogram'].shape)
    text_encoded_shapes.append(getitem_(ind)['text_encoded'].shape)

print(audio_shapes, spectrogram_shapes, text_encoded_shapes, sep='\n')


[torch.Size([1, 55920]), torch.Size([1, 120960]), torch.Size([1, 47360]), torch.Size([1, 242400]), torch.Size([1, 131600]), torch.Size([1, 158400]), torch.Size([1, 192640]), torch.Size([1, 108160]), torch.Size([1, 125520]), torch.Size([1, 75840])]
[torch.Size([1, 128, 280]), torch.Size([1, 128, 605]), torch.Size([1, 128, 237]), torch.Size([1, 128, 1213]), torch.Size([1, 128, 659]), torch.Size([1, 128, 793]), torch.Size([1, 128, 964]), torch.Size([1, 128, 541]), torch.Size([1, 128, 628]), torch.Size([1, 128, 380])]
[torch.Size([1, 38]), torch.Size([1, 97]), torch.Size([1, 53]), torch.Size([1, 222]), torch.Size([1, 100]), torch.Size([1, 169]), torch.Size([1, 160]), torch.Size([1, 65]), torch.Size([1, 131]), torch.Size([1, 55])]


In [89]:
dataset_items = []
for ind in range(len(index_)):
    dataset_items.append(getitem_(ind))
# dataset_items[0]

In [90]:
from torch.nn.utils.rnn import pad_sequence

# For BaseModel
audios = [item['audio'].squeeze() for item in dataset_items]
spectrograms = [item['spectrogram'].squeeze().transpose(0,1) for item in dataset_items]
texts = [item['text'] for item in dataset_items]
text_encoded = [item['text_encoded'].squeeze() for item in dataset_items]
audio_paths = [item['audio_path'] for item in dataset_items]

# Pad audios, spectrograms and text_encoded sequences
padded_audios = pad_sequence(audios, batch_first=True, padding_value=0)
padded_spectrograms = pad_sequence(spectrograms, batch_first=True, padding_value=0).transpose(1,2)
padded_text_encoded = pad_sequence(text_encoded, batch_first=True, padding_value=0)

# Create the result batch dictionary
result_batch = {
    'audio': padded_audios,
    'spectrogram': padded_spectrograms,
    'text_encoded': padded_text_encoded,
    'text': texts,
    'audio_path': audio_paths
}

In [173]:
from src.model.conformer import Conformer, ConformerEncoder, ConformerBlock, LSTMDecoder, \
                                Conv2dSubsampling, FeedForwardBlock, ConvBlock, RelativeMultiHeadAttention, PositionalEncoder
import math
import torch
from torch import nn
import torch.nn.functional as F

d_input = 956
d_model = 144
encoder_num_layers = 16
num_heads = 4
kernel_size = 31
dropout = 0.1
feed_forward_residual_factor = 0.5
feed_forward_expansion_factor = 4
d_decoder = 320
decoder_num_layers = 1

In [174]:
conv_subsample = Conv2dSubsampling(d_model=d_model)
linear_proj = nn.Linear(d_model * ((((d_input - 3)//2 + 1) - 3)//2 + 1), d_model)  # project subsamples to d_model
dropout_layer = nn.Dropout(p=dropout)

# define global positional encoder to limit model parameters
positional_encoder = PositionalEncoder(d_model)
layers = nn.ModuleList([ConformerBlock(
        d_model=d_model,
        kernel_size=kernel_size,
        feed_forward_residual_factor=feed_forward_residual_factor,
        feed_forward_expansion_factor=feed_forward_expansion_factor,
        num_heads=num_heads,
        positional_encoder=positional_encoder,
        dropout=dropout
    ) for _ in range(encoder_num_layers)])

def forward(x, mask=None):
    x = conv_subsample(x)
    if mask is not None:
        mask = mask[:, :-2:2, :-2:2]  # account for subsampling
        mask = mask[:, :-2:2, :-2:2]  # account for subsampling
        assert mask.shape[1] == x.shape[1], f'{mask.shape} {x.shape}'
        
    x = linear_proj(x)
    x = dropout_layer(x)

    for layer in layers:
        x = layer(x, mask=mask)
        return x

In [183]:
x = batch['spectrogram']
print(x.shape)
B, D, T = x.shape

torch.Size([2, 128, 956])


In [184]:
x = conv_subsample(x)
print(x.shape)

# (B, (((D - 3)//2 + 1) - 3)//2 + 1, d_model * ((((d_input - 3)//2 + 1) - 3)//2 + 1))

torch.Size([2, 31, 34272])


In [185]:
x = linear_proj(x)
x = dropout_layer(x)
print(x.shape)

torch.Size([2, 31, 144])


In [186]:
for layer in layers:
    x = layer(x, mask=None)

In [187]:
x.shape # (B, (((D - 3)//2 + 1) - 3)//2 + 1, d_model)

torch.Size([2, 31, 144])

In [188]:
lstm = nn.LSTM(input_size=d_model, hidden_size=d_decoder, num_layers=decoder_num_layers, batch_first=True)
linear = nn.Linear(d_decoder, 28)

def forward(self, x):
    x, _ = self.lstm(x)
    logits = self.linear(x)
    return logits

In [191]:
x_lstm, _ = lstm(x)
x_lstm.shape

torch.Size([2, 31, 320])

In [192]:
logits = linear(x_lstm)

In [193]:
logits.shape

torch.Size([2, 31, 28])

In [194]:
batch['spectrogram_length']

tensor([309, 956], dtype=torch.int32)

In [195]:
log_probs = nn.functional.log_softmax(logits, dim=-1)
log_probs_length = batch['spectrogram_length']
{"log_probs": log_probs, "log_probs_length": log_probs_length}

{'log_probs': tensor([[[-3.3323, -3.3108, -3.3361,  ..., -3.3956, -3.2959, -3.2685],
          [-3.3289, -3.3358, -3.2979,  ..., -3.4059, -3.3126, -3.2541],
          [-3.3644, -3.4266, -3.2446,  ..., -3.3554, -3.3281, -3.1901],
          ...,
          [-3.3651, -3.1910, -3.2827,  ..., -3.3874, -3.3250, -3.2776],
          [-3.4006, -3.2088, -3.2813,  ..., -3.4104, -3.2806, -3.2939],
          [-3.4238, -3.2461, -3.2897,  ..., -3.3975, -3.2951, -3.2552]],
 
         [[-3.4573, -3.3119, -3.3816,  ..., -3.3381, -3.3896, -3.2692],
          [-3.4425, -3.3378, -3.4109,  ..., -3.3623, -3.3256, -3.3057],
          [-3.4132, -3.2908, -3.4506,  ..., -3.3858, -3.3482, -3.3471],
          ...,
          [-3.3580, -3.2561, -3.3572,  ..., -3.3753, -3.3795, -3.2994],
          [-3.3785, -3.2847, -3.3927,  ..., -3.4138, -3.3738, -3.3354],
          [-3.3365, -3.2762, -3.3231,  ..., -3.3392, -3.3141, -3.2813]]],
        grad_fn=<LogSoftmaxBackward0>),
 'log_probs_length': tensor([309, 956], dtype=to

In [196]:
log_probs.shape

torch.Size([2, 31, 28])