In [1]:
import sys
sys.path.insert(0,'..')

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import wandb
import numpy as np
from tqdm import tqdm
from pprint import pprint

from config import *
from data_processing import ukr_lang_chars_handle
from data_processing import UkrVoiceDataset
from model import EfConfRecognizer as Model
from model import get_cosine_schedule_with_warmup, OneCycleLR

import os


def train(model, train_dataloader, optimizer, device, scheduler=None, epoch=1, wb=None):
    print(f"Training begin")
    model.train()
    ctc_criterion = nn.CTCLoss(reduction="none", blank=ukr_lang_chars_handle.token_to_index["<blank>"], zero_infinity=True)
    running_loss = []
    losses_per_phase = []
    train_len = len(train_dataloader)

    for idx, (X, tgt) in tqdm(enumerate(train_dataloader)):
        tgt_text = tgt["text"]

        tgt_lengths = [len(txt) for txt in tgt_text]
        tgt_max_len = max(tgt_lengths)
        one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt_text, tgt_max_len).to(device)
        one_hots = one_hots.squeeze(dim=1).permute(0, 2, 1).float()
        
        X = X.to(device) #
        X = X.squeeze(dim=1).permute(0, 2, 1)
        #print(f"{X.shape=}")
        #print(f"{one_hots.shape=}")

        emb, output = model(X, one_hots)  # (batch, time, n_class), (batch, time, n_class)
        output = F.log_softmax(output, dim=-1)
        
        output = output.permute(1, 0, 2).to(device).detach().requires_grad_()
        indeces = ukr_lang_chars_handle.sentences_to_indeces(tgt_text).to(device)
        

        """
        print("tgt_text:")
        pprint(tgt_text)
        print("tgt_len")
        pprint(tgt_lengths)
        print("indeces:")
        pprint(indeces)
        print(f"Inputs shape: {output.shape}")
        print(f"Tgt shape: {indeces.shape}")
        print(f"one_hots shape: {one_hots.shape}")
        """
        
        input_lengths = torch.full(size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long).to(device)
        target_lengths = torch.Tensor(tgt_lengths).long().to(device)        
        
        loss = ctc_criterion(output, indeces, input_lengths, target_lengths)
        #print(f"{loss=}")
        if wb:
            wb.log({
                "loss": loss.item(),
                "epoch": epoch
            })
        #print(f"{loss=}")
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        running_loss.append(loss.cpu().detach().numpy())
        losses_per_phase.append(loss.cpu().detach().numpy())
        if (idx + 1) % 10 == 0:  # print every 200 mini-batches
            loss_mean = np.mean(np.array(losses_per_phase))
            print(f"Epoch: {epoch}, Last loss: {loss.item():.4f}, Loss phase mean: {loss_mean:.4f}")
            if wb:
                wb.log({"loss phase mean": loss_mean})
            losses_per_phase = []
        optimizer.zero_grad()
        
        

In [20]:


def val(model, train_dataloader, device, epoch, wb=None):
    model.eval()
    positive = 0
    train_len = train_dataloader.sampler.num_samples

    print("\n")
    print("Evaluation on train dataset")
    with torch.no_grad():
        for idx, (X, tgt) in tqdm(enumerate(train_dataloader)):
            tgt_text = " "#tgt["text"]
            tgt_class = torch.Tensor(tgt["label"]).long().to(device)
            tgt_class = F.one_hot(tgt_class, num_classes=5)
            one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt_text, 152).to(device)
            one_hots = one_hots.squeeze(dim=1).permute(0, 2, 1).float()
            X = X.to(device)  #
            X = X.squeeze(dim=1).permute(0, 2, 1)
            emb, output = model(X, one_hots)
            A = torch.argmax(output, dim=-1)
            B = torch.argmax(tgt_class, dim=-1)
            is_right = (A == B)
            positive += torch.sum(is_right)

    train_accuracy = positive / train_len
    if wb:
        wb.log({
            "train accuracy": train_accuracy,
            "epoch": epoch
        })
    print(f"Accuracy on TRAIN dataset: {train_accuracy*100:.2f}%\n")


def get_scheduler(epochs, train_len, optimizer, scheduler_name="cosine_with_warmup", wb=None):
    if wb:
        wb.config["scheduler"] = scheduler_name
    if scheduler_name == "cosine_with_warmup":
        return get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=epochs//5,
                                                num_training_steps=epochs - epochs//5,
                                                num_cycles=0.5)#1.25)
    elif scheduler_name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(optimizer)
    elif scheduler_name == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
    elif scheduler_name == "one_circle":
        return OneCycleLR(optimizer,
                          max_lr=CONFIG["learning_rate"]*10,
                          total_steps=train_len)


def collate_fn(data):
    Xs, LBLs = zip(*data)
    Xs_out = pad_sequence([X.permute(0, 2, 1).squeeze(dim=0) for X in Xs], batch_first=True)
    lbl1 = LBLs[0]
    d_out = {}
    for key in lbl1.keys():
        d_out[key] = [d[key] for d in LBLs]
    return Xs_out, d_out


def main():
    wandb_stat = None#wandb.init(project="ASR", entity="Alex2135", config=CONFIG)
    device = torch.device("cuda") if not torch.cuda.is_available() else torch.device("cpu")

    # Making dataset and loader
    ds = UkrVoiceDataset(TRAIN_REC_PATH, TRAIN_REC_SPEC_PATH)
    train_dataloader = DataLoader(ds, shuffle=True, collate_fn=collate_fn, batch_size=CONFIG["batch_size"]["train"])
    train_val_dataloader = DataLoader(ds, shuffle=True, collate_fn=collate_fn, batch_size=CONFIG["batch_size"]["test"])

    epochs = CONFIG["epochs"]
    train_len = len(train_dataloader) * epochs

    tgt_n = 152
    d_model = 64
    model = Model(d_model=d_model, 
                  n_encoders=CONFIG["n_encoders"], 
                  n_decoders=CONFIG["n_decoders"], 
                  device=device)

    # Create optimizator
    optimizer = AdamW(model.parameters(), lr=CONFIG["learning_rate"])
    save_model = False
    scheduler = get_scheduler(CONFIG["epochs"], train_len, optimizer, scheduler_name="constant", wb=wandb_stat)

    for epoch in range(1, epochs + 1):
        print(f"Epoch №{epoch}")
        train(model, train_dataloader, optimizer, device, scheduler=scheduler, epoch=epoch, wb=wandb_stat)
        val(model, train_val_dataloader, device, epoch, wb=wandb_stat)
        scheduler.step(epoch)
        print(f"scheduler last_lr: {scheduler.get_last_lr()[0]}")
        if wandb_stat:
            wandb_stat.log({"scheduler lr": scheduler.get_last_lr()[0]})

    if save_model:
        PATH = os.path.join(DATA_DIR, "model_1.pt")
        print(f"Save model to path: '{PATH}'")
        torch.save(model.state_dict(), PATH)


main()

Epoch №1
Training begin


10it [00:01,  5.22it/s]

Epoch: 1, Last loss: 113.0288, Loss phase mean: 166.0477


21it [00:05,  3.89it/s]

Epoch: 1, Last loss: 129.7336, Loss phase mean: 87.8743


31it [00:06,  5.41it/s]

Epoch: 1, Last loss: 309.3916, Loss phase mean: 158.7272


41it [00:08,  5.52it/s]

Epoch: 1, Last loss: 224.2455, Loss phase mean: 146.5200


51it [00:10,  5.32it/s]

Epoch: 1, Last loss: 48.8244, Loss phase mean: 125.6323


61it [00:12,  6.53it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 112.5623


70it [00:13,  5.23it/s]

Epoch: 1, Last loss: 312.2733, Loss phase mean: 85.1010


80it [00:15,  5.36it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 148.5828


90it [00:17,  6.61it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 116.3164


101it [00:19,  6.00it/s]

Epoch: 1, Last loss: 254.6492, Loss phase mean: 101.4529


110it [00:20,  5.46it/s]

Epoch: 1, Last loss: 140.0344, Loss phase mean: 159.9665


120it [00:23,  4.32it/s]

Epoch: 1, Last loss: 114.7276, Loss phase mean: 156.6473


130it [00:26,  3.96it/s]

Epoch: 1, Last loss: 228.6043, Loss phase mean: 153.3013


141it [00:28,  4.91it/s]

Epoch: 1, Last loss: 278.9314, Loss phase mean: 130.8798


151it [00:30,  4.19it/s]

Epoch: 1, Last loss: 221.3641, Loss phase mean: 114.4657


160it [00:33,  3.81it/s]

Epoch: 1, Last loss: 49.5900, Loss phase mean: 107.6594


171it [00:34,  6.87it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 105.9473


180it [00:36,  6.23it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 137.2498


191it [00:38,  4.74it/s]

Epoch: 1, Last loss: 427.1979, Loss phase mean: 183.2187


201it [00:41,  4.18it/s]

Epoch: 1, Last loss: 241.9095, Loss phase mean: 124.7030


209it [00:42,  4.51it/s]

Epoch: 1, Last loss: 44.6705, Loss phase mean: 156.5662


219it [00:44,  5.65it/s]

Epoch: 1, Last loss: 18.0213, Loss phase mean: 80.5354


230it [00:47,  3.38it/s]

Epoch: 1, Last loss: 158.6105, Loss phase mean: 118.8133


241it [00:50,  5.36it/s]

Epoch: 1, Last loss: 101.4234, Loss phase mean: 153.4497


250it [00:52,  3.80it/s]

Epoch: 1, Last loss: 399.4970, Loss phase mean: 121.3272


260it [00:54,  5.06it/s]

Epoch: 1, Last loss: 103.9547, Loss phase mean: 95.8644


270it [00:56,  4.69it/s]

Epoch: 1, Last loss: 66.8784, Loss phase mean: 137.9778


281it [00:59,  5.14it/s]

Epoch: 1, Last loss: 85.3929, Loss phase mean: 134.0541


290it [01:01,  5.81it/s]

Epoch: 1, Last loss: 88.1944, Loss phase mean: 123.9275


300it [01:03,  4.90it/s]

Epoch: 1, Last loss: 0.0000, Loss phase mean: 116.7593


311it [01:05,  5.60it/s]

Epoch: 1, Last loss: 113.0333, Loss phase mean: 181.8001


311it [01:06,  4.71it/s]


KeyboardInterrupt: 

In [10]:
from torch.nn.utils.rnn import pad_sequence

a = torch.randn(25, 300)
b = torch.randn(22, 300)
c = torch.randn(15, 300)
pad_sequence((a, b, c)).shape

torch.Size([25, 3, 300])

In [None]:
d1 = {"a": 1, "b": 2}
d2 = {"a": 3, "b": 4}

ds = (d1, d2)
d_out = {}

for key in d1.keys():
    d_out[key] = [d[key] for d in ds]

print(d_out)

In [14]:
"""
Inputs shape: torch.Size([34, 4, 38])
Tgt shape: torch.Size([4, 36])
Input length shape: torch.Size([4])
Tgt length shape: torch.Size([4])
"""

T = 52      # Input sequence length
C = 38      # Number of classes (including blank)
N = 4      # Batch size
S = 36      # Target sequence length of longest target in batch (padding length)

ctc_loss = nn.CTCLoss()

In [16]:
input = torch.randn(T, N, C).requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.full(size=(N,), fill_value=S, dtype=torch.long)

print(f"Inputs shape: {input.shape}")
print(f"Tgt shape: {target.shape}")
print(f"Input length shape: {input_lengths.shape}")
print(f"Tgt length shape: {target_lengths.shape}")

loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

Inputs shape: torch.Size([52, 4, 38])
Tgt shape: torch.Size([4, 36])
Input length shape: torch.Size([4])
Tgt length shape: torch.Size([4])


In [1]:
import torch
import torch.nn as nn

ce_criterion = nn.CrossEntropyLoss()

X = torch.Tensor([[0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
tgt = torch.Tensor([[0, 0, 1, 0, 0], [0, 1, 0, 0, 0]])

loss = ce_criterion(X, tgt)
loss

tensor(1.4048)

In [None]:
d_inputs = 768
d_model = 64

X = torch.randn([1, 768, 14])
conv1 = nn.Sequential(
    nn.Conv1d(d_inputs, d_model, kernel_size=7, stride=3),
    nn.ReLU())
conv2 = nn.Sequential(
    nn.Conv1d(256, d_model, kernel_size=7, stride=3),
    nn.ReLU()
)
out = conv1(X)
print(f"shape after conv1: {out.shape=}")
#out = conv2(out)
#print(f"shape after conv2: {out.shape=}")