<a href="https://colab.research.google.com/github/anarlavrenov/n2/blob/main/inference_github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the model weights and the optimizer state (uncomment to fetch them from Google Drive)
# NOTE(review): these commands save to LibriSpeech_100_model.pth, while the load cell reads /content/model.pth — keep the paths in sync.

# !gdown --no-check-certificate "https://drive.google.com/uc?export=download&id=10bx8VZ4LVJz1JnU2qklgdN0FsS9KM-3d" -O LibriSpeech_100_model.pth
# !gdown --no-check-certificate "https://drive.google.com/uc?export=download&id=10cfMP77QvQ8jl1_7OJCMKraXIQLpZW6h" -O LibriSpeech_100_optimizer.pth

import numpy as np
import pandas as pd
import torch

# Prefer the GPU when one is visible; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
  device = "cuda"

Mounted at /content/drive


In [None]:
# Download (if not already present) and open the LJSpeech dataset used for evaluation
import torchaudio

dataset = torchaudio.datasets.LJSPEECH(root=".", download=True)

100%|██████████| 2.56G/2.56G [02:34<00:00, 17.8MB/s]


In [None]:
# Build the character vocabulary
import torchtext; torchtext.disable_torchtext_deprecation_warning()
# Aliased so the `vocab` name below can hold the built Vocab without shadowing the factory function
from torchtext.vocab import vocab as build_vocab
from collections import Counter

# Every character the model knows; the order of this string fixes the token ids
chars = list("abcdefghijklmnopqrstuvwxyz'?! ")

# Counter preserves first-seen order (each count is 1), so the ids follow `chars`
counter = Counter(chars)
vocab = build_vocab(counter)

# The empty string is the unknown token: it is inserted at index 0 (shifting every other id up by one)
# and becomes the lookup result for any character not listed above.
# The greedy decoder in this notebook treats index 0 as the CTC blank.
unk_token = ""
vocab.insert_token(unk_token, 0)
vocab.set_default_index(vocab[unk_token])

In [None]:
# Audio preprocessing
import functools

import torch
import torchaudio

win_length = 256
hop_length = 160
n_fft = 384


@functools.lru_cache(maxsize=None)
def _spectrogram_transform(win_length, hop_length, n_fft):
  """Build (once per parameter set) the complex STFT transform used by preprocess_audio."""
  return torchaudio.transforms.Spectrogram(
      win_length=win_length,
      hop_length=hop_length,
      n_fft=n_fft,
      power=None  # keep the complex STFT; magnitude is taken below
  )


def preprocess_audio(waveform, orig_sr, target_sr=16000):
  """Turn a raw waveform into a per-frame standardized magnitude spectrogram.

  Args:
    waveform: audio tensor, presumably (channels, samples) as torchaudio datasets
      return it — TODO confirm for other sources. A single leading channel is
      squeezed away; multi-channel audio is averaged down to mono.
    orig_sr: sample rate of `waveform`.
    target_sr: sample rate the audio is resampled to before the STFT
      (16000 by default, the rate this notebook always used).

  Returns:
    float32 tensor of shape (frames, n_fft // 2 + 1): the square root of the STFT
    magnitude, standardized across frequency bins within each frame.
  """
  if not waveform.is_floating_point():
    # resample operates on floating-point audio
    waveform = waveform.to(torch.float32)
  if waveform.dim() > 1 and waveform.shape[0] > 1:
    # Mix multi-channel audio down to mono; otherwise the layout below breaks
    waveform = waveform.mean(dim=0, keepdim=True)

  waveform = torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=target_sr)
  waveform = torch.squeeze(waveform, dim=0).to(torch.float32)

  # Complex STFT (freq_bins, frames) -> (frames, freq_bins); these are linear-frequency bins, not mels
  spectrogram = _spectrogram_transform(win_length, hop_length, n_fft)(waveform)
  spectrogram = spectrogram.transpose(0, 1)

  # Magnitude, compressed with a square root
  spectrogram = spectrogram.abs().pow(0.5)

  # Standardize each frame across its frequency bins (must match the preprocessing the model was trained with)
  means = spectrogram.mean(dim=1, keepdim=True)
  stddevs = spectrogram.std(dim=1, keepdim=True)
  return (spectrogram - means) / (stddevs + 1e-10)

In [None]:
def collate_fn(batch, max_frames=2048, max_tokens=216):
  """Collate dataset items into fixed-size spectrogram and token-id tensors.

  Args:
    batch: sequence of dataset items whose first three fields are
      (waveform, sample_rate, transcript); any further fields are ignored.
    max_frames: number of spectrogram frames every item is zero-padded (or truncated) to.
    max_tokens: number of token ids every transcript is zero-padded (or truncated) to.

  Returns:
    (spectrograms, tokens): a float tensor of shape (batch, max_frames, freq_bins)
    and an int64 tensor of shape (batch, max_tokens). Padding uses 0, which is
    also the vocab index of the empty unknown token.
  """
  waveforms, sample_rates, texts, *_ = zip(*batch)

  spectrograms = []
  for waveform, sample_rate in zip(waveforms, sample_rates):
    # preprocess_audio needs the source sample rate to resample correctly
    spectrogram = preprocess_audio(waveform, sample_rate)
    spectrogram = spectrogram[:max_frames]  # explicit truncation of over-long clips
    spectrogram = torch.nn.functional.pad(
        spectrogram, (0, 0, 0, max_frames - spectrogram.shape[0]), "constant", 0
    )  # (left, right, top, bottom): zero frames appended at the end
    spectrograms.append(spectrogram)

  tokens = []
  for text in texts:
    # NOTE(review): characters outside the vocab (digits, punctuation) map to index 0 and
    # vanish from the reference; for LJSpeech a normalized-transcript field may be a better
    # reference — verify against the dataset item layout.
    # dtype is forced so an empty transcript still yields an int64 tensor that stacks.
    ids = torch.tensor([vocab[ch] for ch in text.lower()][:max_tokens], dtype=torch.long)
    ids = torch.nn.functional.pad(ids, (0, max_tokens - ids.shape[0]), "constant", 0)
    tokens.append(ids)

  return torch.stack(spectrograms, dim=0), torch.stack(tokens, dim=0)

In [None]:
# Model: 2-D conv front end -> bidirectional LSTM -> two-layer classifier head
class Model(torch.nn.Module):
  """Convolutional + bidirectional-LSTM acoustic model emitting per-frame log-probabilities.

  Args:
    rnn_layers: number of stacked LSTM layers.
    rnn_units: hidden size of each LSTM direction.
    output_dim: the final layer produces output_dim + 1 classes.
  """

  def __init__(self, rnn_layers, rnn_units, output_dim):
    super().__init__()

    # Conv stack over (time, frequency). conv1 halves both axes; conv2/conv3 halve frequency only.
    # NOTE: attribute names and their order must stay as-is so the pickled checkpoint loads.
    self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(11, 41), stride=(2, 2), padding=(5, 20), bias=False)
    self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=(11, 21), stride=(1, 2), padding=(5, 10), bias=False)
    self.conv3 = torch.nn.Conv2d(32, 64, kernel_size=(11, 21), stride=(1, 2), padding=(5, 10), bias=False)

    # 64 channels x 25 frequency bins survive the conv stack when the input has
    # 193 bins (n_fft=384 in the preprocessing cell): 193 -> 97 -> 49 -> 25.
    self.lstm = torch.nn.LSTM(
        input_size=64 * 25,
        hidden_size=rnn_units,
        num_layers=rnn_layers,
        bias=True,
        batch_first=True,
        dropout=0.5,
        bidirectional=True
    )

    hidden = rnn_units * 2  # forward + backward LSTM states are concatenated
    self.fc1 = torch.nn.Linear(hidden, hidden)
    self.fc2 = torch.nn.Linear(hidden, output_dim + 1)

    self.bn1 = torch.nn.BatchNorm2d(32)
    self.bn2 = torch.nn.BatchNorm2d(32)
    self.bn3 = torch.nn.BatchNorm2d(64)

    self.dp = torch.nn.Dropout(p=0.5)
    self.relu = torch.nn.ReLU()


  def forward(self, src):
    """Map spectrograms (batch, frames, freq_bins) to log-probabilities (time, batch, classes)."""
    x = src.unsqueeze(1)  # add a channel axis: (batch, 1, frames, freq_bins)
    for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
      x = self.relu(bn(conv(x)))  # (batch, channels, time, freq)

    # (batch, time, channels, freq) -> (batch, time, channels * freq)
    x = x.permute(0, 2, 1, 3).flatten(start_dim=2)
    x, _ = self.lstm(x)  # (batch, time, 2 * rnn_units)

    x = self.dp(self.relu(self.fc1(x)))
    logits = self.fc2(x).permute(1, 0, 2)  # (time, batch, classes)
    return logits.log_softmax(dim=2)

# Load the full pickled model (architecture + weights)
MODEL_PATH = "/content/model.pth"
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
# weights_only=False is needed to unpickle a whole nn.Module (newer torch versions default to True).
# SECURITY: unpickling can execute arbitrary code — only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)

In [None]:
class CTCGreedyDecoder(torch.nn.Module):
  """Greedy (best-path) CTC decoder for a single utterance.

  Args:
    labels: index -> string lookup, e.g. the vocab's itos list.
    blank: index of the CTC blank token; it is dropped from the output.
  """

  def __init__(self, labels, blank=0):
    super().__init__()

    self.labels = labels
    self.blank = blank

  def forward(self, outputs):
    """Decode per-frame class scores into text.

    Args:
      outputs: scores of shape (time, classes) or (time, 1, classes);
        only one utterance can be decoded per call.

    Returns:
      The decoded string with whitespace runs collapsed to single spaces and no
      leading/trailing spaces.

    Raises:
      ValueError: if `outputs` holds more than one utterance or has an unsupported rank.
    """
    best_path = torch.argmax(outputs, dim=-1)
    if best_path.dim() == 2:
      if best_path.shape[1] != 1:
        raise ValueError(
            f"CTCGreedyDecoder decodes one utterance at a time, got a batch of {best_path.shape[1]}"
        )
      best_path = best_path[:, 0]
    elif best_path.dim() != 1:
      raise ValueError(f"expected outputs of shape (time, classes) or (time, 1, classes), got {tuple(outputs.shape)}")

    # Best-path rule: merge repeated predictions first, then drop blanks.
    collapsed = torch.unique_consecutive(best_path).tolist()
    text = "".join(self.labels[i] for i in collapsed if i != self.blank)

    return " ".join(text.split())

# Decoder over the vocab's index -> token list; index 0 (the empty unknown token) serves as the CTC blank
greedy_decoder = CTCGreedyDecoder(vocab.get_itos(), blank=0)

In [None]:
# Evaluation loader: fixed order, fixed-size padded batches built by collate_fn
# NOTE(review): drop_last=True leaves the final partial batch (len(dataset) % 64 items) out of the
# metrics; before turning it off, make sure any code indexing the last batch respects its real size.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_size=64,
    shuffle=False,
    drop_last=True,
    num_workers=2,
)

In [None]:
# Рассчет метрик WER & CER
!pip install jiwer -q

from jiwer import wer, cer
from tqdm import tqdm

model.eval()

wer_scores = []
cer_scores = []

with torch.no_grad():

  for batch in tqdm(loader, desc="Calculating WER & CER scores"):
    val_src, val_tgt = batch[0].to(device), batch[1].to(device)

    for idx in range(val_src.shape[0]):

      y_pred = model(val_src[idx].unsqueeze(0))

      y_pred = greedy_decoder(y_pred)
      y_true = "".join([vocab.get_itos()[i] for i in val_tgt[idx]])

      wer_scores.append(wer(y_pred, y_true))
      cer_scores.append(cer(y_pred, y_true))

wer_score = (sum(wer_scores) / len(wer_scores))
cer_score = (sum(cer_scores) / len(cer_scores))

print(f"\n\n wer score: {wer_score:.2f}, cer score: {cer_score:.2f}")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/3.4 MB[0m [31m24.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.4/3.4 MB[0m [31m50.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[?25h

Calculating WER & CER scores: 100%|██████████| 204/204 [55:10<00:00, 16.23s/it]



 wer score: 0.37, cer score: 0.14





In [None]:
# Show a few predictions next to their references (taken from the last evaluated batch by default)
def show_result(idx, src=None, tgt=None):
  """Print the greedy transcription of one sample next to its reference text.

  Args:
    idx: position of the sample inside the batch.
    src: batch of spectrograms; defaults to `val_src` left over from the evaluation loop.
    tgt: matching batch of token ids; defaults to `val_tgt`.
  """
  src = val_src if src is None else src
  tgt = val_tgt if tgt is None else tgt

  itos = vocab.get_itos()
  with torch.no_grad():  # inference only; don't build an autograd graph
    y_pred = greedy_decoder(model(src[idx].unsqueeze(0)))
  # Padding / out-of-vocab ids decode to ""; collapse the leftover space runs
  y_true = " ".join("".join(itos[i] for i in tgt[idx].tolist()).split())

  print(f"pred: {y_pred}")
  print(f"true: {y_true}")

# Sample across the whole batch (index 0 included) instead of a hard-coded 1..63 range
samples = torch.randint(0, val_src.shape[0], (5, ))

for sample in samples:
  show_result(sample)
  print("*" * 100)

pred: studie's indicate that there is some utility ind attempting to desg nate certain buildings as in volving a higher risk of an others
true: the studies indicate that there is some utility in attempting to designate certain buildings as involving a higher risk than others
****************************************************************************************************
pred: ad coordination might be achieved to a greater extet than seems now to be contemplated without intefearence but the primary mession of ecagent se involved
true: that coordination might be achieved to a greater extent than seems now to be contemplated without interference with the primary mission of each agency involved
****************************************************************************************************
pred: at rickon instructions might come into the hands of local newspapers to the prejidice of the procautions described
true: that written instructions might come into the hands of local newspap