In [1]:
import sys
sys.path.insert(0,'..')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from model import Conformer as con
from data_processing import ukr_lang_chars_handle
from data_processing import CommonVoiceUkr
from model.conformer import Conformer as con

from config import *

In [2]:
tgt_n = 152
target = torch.randn(BATCH_SIZE, tgt_n) # (N, S) where N =batch size and S = max target length 

outputs = torch.randn(BATCH_SIZE, 1, 256, 38) # Tensor of size (T, N, C), where T = input length, N = batch size, and C = number of classes (including blank)
b, cnls, t, clss = outputs.shape
outputs = outputs.view(t*cnls, b, clss)


input_lengths = torch.full(size=(BATCH_SIZE,), fill_value=outputs.shape[0], dtype=torch.long)
target_lengths = torch.full(size=(BATCH_SIZE,), fill_value=target.shape[-1], dtype=torch.long)
print(input_lengths.shape)
print("input lengths:", input_lengths)

print(target_lengths.shape)
print("target lengths:", target_lengths)

ctc_loss = nn.CTCLoss(zero_infinity=False, reduction="none")
loss = ctc_loss(outputs, target, input_lengths, target_lengths)

torch.Size([8])
input lengths: tensor([256, 256, 256, 256, 256, 256, 256, 256])
torch.Size([8])
target lengths: tensor([152, 152, 152, 152, 152, 152, 152, 152])


In [3]:
print(BATCH_SIZE)
ds = CommonVoiceUkr(TRAIN_PATH, TRAIN_SPEC_PATH)
n_ds = len(ds)
n_ds = n_ds - n_ds % BATCH_SIZE

8


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import Conformer as con
from data_processing import ukr_lang_chars_handle
from data_processing import CommonVoiceUkr
from config import *
from torch.optim import RAdam
from tqdm import tqdm
import pprint
import numpy as np
import wandb
#wandb.init(project="ASR", entity="alex2135")

#wandb.config = CONFIG

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Making dataset and loader
ds = CommonVoiceUkr(TRAIN_PATH, TRAIN_SPEC_PATH, batch_size=BATCH_SIZE)
train_dataloader = DataLoader(ds, shuffle=True, batch_size=BATCH_SIZE)


tgt_n = 152
model = con(n_encoders=8, n_decoders=8, device=device)

# Create optimizator
optimizer = RAdam(model.parameters(),lr=CONFIG["learning_rate"])


# Create CTC criterion
criterion = nn.CTCLoss(blank=ukr_lang_chars_handle.token_to_index["<blank>"], zero_infinity=True)


running_loss = []
epochs = CONFIG["epochs"]
for epoch in range(1, epochs + 1):
    print(f"Epoch №{epoch}")
    for idx, (X, tgt) in tqdm(enumerate(train_dataloader)):
        optimizer.zero_grad()

        one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt, 152).to(device)
        X = X.to(device)
        
        output = model(X, one_hots)  # (batch, _, time, n_class)
        b, cnls, t, clss = output.shape
        output = output.view(t * cnls, b, clss)  # (time, batch, n_class)
        output = F.log_softmax(output, dim=2)
        indeces = ukr_lang_chars_handle.sentences_to_indeces(tgt).to(device)
        
        input_lengths = torch.full(size=(BATCH_SIZE,), fill_value=t, dtype=torch.long).to(device)
        target_lengths = torch.full(size=(BATCH_SIZE,), fill_value=indeces.shape[-1], dtype=torch.long).to(device)
        loss = criterion(output, indeces.cpu(), input_lengths.cpu(), target_lengths.cpu())
        loss.backward()
        optimizer.step()

        running_loss.append(loss)
        #wandb.log({"loss": loss})

        if torch.isnan(loss) or torch.isinf(loss):
            print("Target label:", tgt)
            print("Running loss:")
            pprint.pprint(running_loss)
            print(output.shape)
            print("Is nan in output:", torch.sum(torch.isnan(output)))
            print("Is inf in output:", torch.sum(torch.isinf(output)))
            pprint.pprint(output)
            break
        if (idx + 1) % 50 == 0:  # print every 200 mini-batches
            running_loss = [t.cpu().detach().numpy() if type(t) is torch.Tensor else t for t in running_loss]
            running_loss = np.array(running_loss)
            print(f"Epoch: {epoch}, Last loss: {loss:.4f}, Loss mean: {np.mean(running_loss):.4f}")
            running_loss = list(running_loss)

Epoch №1


50it [01:47,  2.16s/it]

Epoch: 1, Last loss: 9.1475, Loss mean: 9.3085


100it [03:35,  2.15s/it]

Epoch: 1, Last loss: 7.1415, Loss mean: 9.0365


150it [05:22,  2.15s/it]

Epoch: 1, Last loss: 4.6850, Loss mean: 8.2542


186it [06:41,  2.16s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

losses_list = [t.cpu().detach().numpy() if type(t) is torch.Tensor else t for t in running_loss ]
plt.figure(figsize=(12, 10))
plt.plot(losses_list)
plt.show()

In [None]:
"""
import matplotlib.pyplot as plt

wtout_zeros = np.array([t.cpu().detach().numpy() if type(t) is torch.Tensor else t for t in running_loss])
print(len(running_loss))
wtout_zeros = wtout_zeros[wtout_zeros != 0]
plt.figure(figsize=(12, 10))
plt.plot(wtout_zeros)
plt.show()
"""

In [None]:
"""
import os
PATH = os.path.join(DATA_DIR, "model_1.pt")
model = con(device=device)
model.load_state_dict(torch.load(PATH))
"""


In [None]:
running_loss[-1]

In [None]:
sent = ("Привіт",)
oh_sent = ukr_lang_chars_handle.sentences_to_one_hots(tgt, 152)
#print(oh_sent)

result = ukr_lang_chars_handle.one_hots_to_sentence(oh_sent)
#print(result)
indeces = ukr_lang_chars_handle.sentence_to_indeces(sent[0])
#print(indeces)

one_hots = F.one_hot(torch.Tensor(indeces).long(), num_classes=38)
#print(one_hots)


reproduced_sent = ukr_lang_chars_handle.onehot_matrix_to_idxs(one_hots)
#print(f"{reproduced_sent=}")

In [None]:
import torch
import torch.nn.functional as F
from data_processing import ukr_lang_chars_handle
from config import *
from model import Conformer as con
from data_processing import CommonVoiceUkr
from torch.utils.data import DataLoader
import pprint

device = "cpu"

PATH = os.path.join(DATA_DIR, "model_1.pt")
model = con(device=device)
model.load_state_dict(torch.load(PATH))

model.eval()
ds = CommonVoiceUkr(TRAIN_PATH, TRAIN_SPEC_PATH)
train_dataloader = DataLoader(ds, shuffle=True, batch_size=1)

with torch.no_grad():
    X, tgt = next(iter(train_dataloader))
    X = X.to(device)
    print("Target:", tgt)
    print("X shape:", X.shape)
    #tgt = ("",)

    tgt_one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt, 152)
    print("tgt to one_hots shape:", tgt_one_hots.shape)
    print("tgt to one_hots:", ukr_lang_chars_handle.one_hots_to_sentences(tgt_one_hots))

    out_data = model(X, tgt_one_hots.to(device))
    out_data = F.softmax(out_data, dim=-1)
    out_data = out_data.cpu()
    print("\n\nOutput data shape:", out_data.shape)
    print("output:", out_data)
    out_data = out_data.transpose(-1, -2).contiguous()
    result = ukr_lang_chars_handle.one_hots_to_sentences(out_data)
    pprint.pprint(len(result))
    pprint.pprint(result)

In [None]:
import os
PATH = os.path.join(DATA_DIR, "model_1.pt")
print(PATH)
torch.save(model.state_dict(), PATH)
