In [1]:
data = '/home/boris/Projects/Voice_Assistant_for_Voice_Anomaly_Persons/Multi-lingual Phoneme Recognition/data/raw_kaggle/Speeches.xlsx'

In [22]:
import math

import jiwer
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

from utils.dataset import SpeechDataset, collate_spl
from utils.transforms import XLSRTransformer


In [3]:
xls_r_transformer = XLSRTransformer()

Some weights of the model checkpoint at /home/boris/Projects/Voice_Assistant_for_Voice_Anomaly_Persons/Multi-lingual Phoneme Recognition/models/phonemizer were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.encoder.pos_conv_embed.conv.weight_g']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at /home/boris/Projects/Voice_Assistant_for_Voice_Anomaly_Persons/Multi-lingual Phoneme Recognition/models/phonemizer and are newly initialized: ['wav2vec2.e

# Data loading

In [4]:
BATCH_SIZE = 2

In [5]:
df = pd.read_excel(data)
df.head()

Unnamed: 0,Число,Русская речь
0,1,Как пройти до корпуса?
1,2,Где взять направление?
2,3,Бумага есть
3,4,анальгин
4,5,вата


In [36]:
train_dataset = SpeechDataset(df, 0, 1500)
val_dataset = SpeechDataset(df, 1500, 1700)
test_dataset = SpeechDataset(df, 1600, 2000)

In [37]:
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, True, collate_fn=collate_spl)
val_dataloader = DataLoader(val_dataset, BATCH_SIZE, True, collate_fn=collate_spl)

In [38]:
abc = '?абвгдеёжзийклмнопрстуфхшщчцьыъэюя'

def show_batch(yam, epoch, bn, labels):
    if yam.sum()<1: return
    labels = labels.argmax(-1)
    etalon_s=''
    for i in range(labels.shape[0]):
        for j in range(labels.shape[1]):
            if (ci := labels[i,j])>0:
                etalon_s+=abc[ci]
        etalon_s+='|'
    s=''
    for i in yam:
        f=False
        for l in i:
            if li:=l.item():
                f=True
                s+=abc[li]
        if f:
            s+='|'
    c=-1
    if len(s)>0:
        c=jiwer.cer(etalon_s, s)
        print(f'{epoch}.{bn}:', s, etalon_s, 'cer: ', c)
    return c

In [39]:
from IPython.display import clear_output, display

In [40]:
def get_y_lengths(y):
    return torch.count_nonzero(y, -1)

def padded_stack(list_of_tensors, maxlen=16):
    # print([i.shape for i in list_of_tensors])
    maxlen = max(x.shape[0] for x in list_of_tensors)
    output = torch.zeros((len(list_of_tensors), maxlen, list_of_tensors[0].shape[-1]))
    # print(output.shape)
    for i, t in enumerate(list_of_tensors):
        output[i, :min(maxlen, t.shape[0]), :] = t[:maxlen, :]
    return output

In [41]:
from tqdm.notebook import tqdm

In [42]:
class SimplePerceptron(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(392,34)
        self.act = nn.LogSoftmax(-1)

    def forward(self, X):
        X, _ =self.rnn(X)
        X = self.act(X)
        return X

perceptron = SimplePerceptron()

In [43]:
loss = nn.CTCLoss(
    zero_infinity=True,
    blank=0, )
optimizer = torch.optim.AdamW(perceptron.parameters(), lr=0.0001,# weight_decay=0.1
                              )
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

In [44]:
for epoch in (t1:=tqdm(range(100))):
    elv=0.0
    elc=0
    blv = 0.0
    blc = 0
    minlv=float("inf")
    cer = 0
    cer_cnt = 0
    for bn, ((wf_batch, wf_lengths), (labels, labels_lengths)) in (t2:=tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False)):
        # clear_output()
        #display("ji")
        try:
            tensor_list=[xls_r_transformer(wf_batch[i], label=labels[i]) for i in range(wf_batch.shape[0])]
            phonemes = padded_stack(tensor_list)
        except Exception as e:
            raise e

        # phonemes: batch  x len x classes
        # print(phonemes.shape)
        y = perceptron(phonemes)
        
        yam = y[:,:].argmax(-1)
        yl = get_y_lengths(yam)
        yam = y[:,:].argmax(-1)

        # print(y.shape, labels.shape)

        lv = loss(
            y.permute(1,0,2),
            labels.argmax(-1),
            yl,
            labels_lengths
        )
        # plotlosses.update({'loss': lv.item()})
        # plotlosses.send()
        # print(lv.item())
        if not (math.isnan(lv.item()) or math.isinf(lv.item())):
            blv+=lv.exp().item()
        blc+=BATCH_SIZE
        # writer.add_scalar('loss', lv.item(), step)
        t2.set_postfix_str(f'loss: {blv/blc}/{elv/elc if elc>0 else "inf"}')
        # clear_output()

        lv.backward()
        if bn%10==0:
            cer += show_batch(yam, epoch, bn, labels)
            cer_cnt += BATCH_SIZE
            elv+=blv
            elc+=blc
            optimizer.step()
            while (lv1 := loss(
            y.permute(1,0,2),
            labels.argmax(-1),
            yl,
            labels_lengths
        ).item())<lv.item() and not (math.isnan(lv.item()) or math.isnan(lv1)):
                print('>', end='')
                optimizer.step()
            optimizer.zero_grad()
            blv=0
            blc=0
        # scheduler.step()
    t1.set_postfix_str(f'loss: {elv/elc} cer: {cer/cer_cnt}')


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/750 [00:00<?, ?it/s]

0.0: ёрётёбёввтёёрёввётёрхо|ёвбётёёбббёрврвбёеееее| вкакихморях|римскийогоньпотух| cer:  1.4
0.10: ввёвровввввёрррхффффффффффффффф|хрввррвёётётовёёхёрврвтвюёроёрр| способсказать|космоснеперестаетнасудивлять| cer:  1.3488372093023255
0.20: вёвтвёвввёёоёоротрррвоёрвхрвррввё|рвхвбхвёрттвввввёввёворрёёеееееее| книгасопровождающаяваснавсюжизнь|шатенкиибрюнеткипохожи| cer:  1.0714285714285714
0.30: вёввтётррёхвоёёр|ввввввттоёввёвех| разработкаэвм|расширения| cer:  1.2
0.40: ворввтвтввёохоорвффффффф|ровввввтвтвёвхврвёввввтв| явзялсебявруки|естьялюблювязаниеисамбо| cer:  1.1538461538461537
0.50: вввртохвввфффффффффффффффф|врвтрвёрврртррвтотввхрввёх| наборданных|видеонаютубе| cer:  1.96
0.60: втотоотёвовёёрвврхо|воорёёвхёрврввеееее| повторитепожалуйста|первоеправило| cer:  0.9117647058823529
0.70: вооррёврввёвёвёёцётёёхрвррвввворвх|вттвтрёвёрхрхееееееееееецееееееее| рыбалкавеселеееслистобойестькот|красивыйцветок| cer:  1.297872340425532
0.80: рёвёюёроввёвтрров|твоовввотооввввхв| безмятежнаягава

KeyboardInterrupt: 

In [None]:
torch.save(perceptron, '240104.torch')

In [None]:
yam.sum()>-0.1

tensor(True)