In [1]:
import sys
sys.path.insert(0,'..')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import wandb
import numpy as np
from tqdm import tqdm
from pprint import pprint

from config import *
from data_processing import ukr_lang_chars_handle
from data_processing import CommonVoiceUkr
from model import EfConfRecognizer as Model
from model import get_cosine_schedule_with_warmup, OneCycleLR

import os


def train(model, train_dataloader, optimizer, device, scheduler=None, epoch=1, wb=None, log_every=10):
    """Run one training epoch of ``model`` with a CTC objective on the transcripts.

    Args:
        model: recognizer called as ``model(X, one_hots)``; returns ``(emb, output)``
            where ``output`` holds per-frame class scores, commented as
            ``(batch, time, n_class)``.
        train_dataloader: yields ``(X, tgt)`` batches; ``tgt["text"]`` is the list
            of transcripts for the batch.
        optimizer: optimizer over the model parameters; stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped once per batch when given.
        epoch: epoch number, used for logging only.
        wb: optional wandb run; per-batch and per-phase losses are logged to it.
        log_every: print/log the mean loss every ``log_every`` batches.

    Returns:
        List with the mean CTC loss of every batch (Python floats).
    """
    print(f"Training begin")
    model.train()
    # reduction="none" keeps per-utterance losses; they are averaged below.
    # zero_infinity is left off as before — enable it if some utterances have fewer
    # frames than transcript characters and the loss becomes inf.
    ctc_criterion = nn.CTCLoss(reduction="none", blank=ukr_lang_chars_handle.token_to_index["<blank>"])
    running_loss = []
    losses_per_phase = []

    for idx, (X, tgt) in enumerate(tqdm(train_dataloader)):
        tgt_text = tgt["text"]
        tgt_lengths = [len(txt) for txt in tgt_text]
        tgt_max_len = max(tgt_lengths)

        one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt_text, tgt_max_len).to(device)
        one_hots = one_hots.squeeze(dim=1).permute(0, 2, 1).float()

        X = X.to(device)
        X = X.squeeze(dim=1).permute(0, 2, 1)

        optimizer.zero_grad()
        _emb, output = model(X, one_hots)  # (batch, time, n_class), (batch, time, n_class)

        # CTCLoss expects (time, batch, n_class) log-probabilities. The output must NOT be
        # detached: detaching cut the autograd graph, so backward() never reached the model
        # parameters. log_softmax is idempotent, so it is harmless if the model already applies it.
        log_probs = F.log_softmax(output.permute(1, 0, 2), dim=-1)
        indeces = ukr_lang_chars_handle.sentences_to_indeces(tgt_text).to(device)

        seq_len, batch_size = log_probs.shape[0], log_probs.shape[1]
        # All frames of the (padded) input are used: the length is the time dim (shape[0]);
        # shape[-2] of a (time, batch, n_class) tensor is the batch size.
        input_lengths = torch.full(size=(batch_size,), fill_value=seq_len, dtype=torch.long, device=device)
        # Per-utterance target lengths so padded positions are not scored as labels
        # (assumes sentences_to_indeces emits one index per character — TODO confirm).
        target_lengths = torch.tensor(tgt_lengths, dtype=torch.long, device=device)

        # Everything stays on one device (log-probs were previously moved to CPU alone).
        loss = ctc_criterion(log_probs, indeces, input_lengths, target_lengths).mean()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        loss_value = loss.item()  # scalar after .mean(), so .item() is valid for any batch size
        if wb:
            wb.log({
                "loss": loss_value,
                "epoch": epoch
            })
        running_loss.append(loss_value)
        losses_per_phase.append(loss_value)
        if (idx + 1) % log_every == 0:
            loss_mean = float(np.mean(losses_per_phase))
            print(f"Epoch: {epoch}, Last loss: {loss_value:.4f}, Loss phase mean: {loss_mean:.4f}")
            if wb:
                wb.log({"loss phase mean": loss_mean})
            losses_per_phase = []

    return running_loss
        
        

In [3]:


def val(model, train_dataloader, device, epoch, wb=None, max_tgt_len=152):
    """Report accuracy of ``model`` on the (training) data served by ``train_dataloader``.

    The decoder input is built from a fixed placeholder transcript (``" "``) instead of
    ``tgt["text"]``; predictions are ``argmax(output, dim=-1)`` compared with ``tgt["label"]``.

    Args:
        model: recognizer called as ``model(X, one_hots)`` returning ``(emb, output)``.
        train_dataloader: yields ``(X, tgt)``; ``tgt["label"]`` holds integer class ids.
        device: device the batch tensors are moved to.
        epoch: epoch number, used for logging only.
        wb: optional wandb run used for logging.
        max_tgt_len: length passed to ``sentences_to_one_hots`` for the placeholder text.

    Returns:
        The accuracy as a Python float (``0.0`` for an empty dataset).
    """
    model.eval()
    positive = 0
    # len(dataset) works for every sampler; ``sampler.num_samples`` exists only on RandomSampler.
    train_len = len(train_dataloader.dataset)

    print("\n")
    print("Evaluation on train dataset")
    with torch.no_grad():
        for X, tgt in tqdm(train_dataloader):
            tgt_text = " "  # placeholder transcript used instead of tgt["text"]
            labels = torch.as_tensor(tgt["label"], dtype=torch.long, device=device)
            one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt_text, max_tgt_len).to(device)
            one_hots = one_hots.squeeze(dim=1).permute(0, 2, 1).float()
            X = X.to(device)
            X = X.squeeze(dim=1).permute(0, 2, 1)
            _emb, output = model(X, one_hots)
            predicted = torch.argmax(output, dim=-1)
            # NOTE(review): train() comments output as (batch, time, n_class); if that holds here,
            # `predicted` is (batch, time) and this comparison broadcasts against (batch,) labels,
            # counting per-frame hits rather than per-sample hits — confirm the eval output shape.
            positive += (predicted == labels).sum().item()

    train_accuracy = positive / train_len if train_len else 0.0
    if wb:
        wb.log({
            "train accuracy": train_accuracy,
            "epoch": epoch
        })
    print(f"Accuracy on TRAIN dataset: {train_accuracy*100:.2f}%\n")
    return train_accuracy


def get_scheduler(epochs, train_len, optimizer, scheduler_name="cosine_with_warmup", wb=None):
    """Build the learning-rate scheduler called ``scheduler_name``.

    Args:
        epochs: number of training epochs; sizes the cosine warm-up schedule.
        train_len: total number of optimizer steps (batches * epochs); sizes the
            one-cycle schedule.
        optimizer: optimizer whose learning rate is scheduled.
        scheduler_name: one of "cosine_with_warmup", "constant", "exponential",
            "one_circle" (or its correctly spelled alias "one_cycle").
        wb: optional wandb run; the chosen name is stored in ``wb.config["scheduler"]``.

    Returns:
        The scheduler instance.

    Raises:
        ValueError: if ``scheduler_name`` is not recognised, instead of silently
            returning ``None`` and failing later on ``scheduler.step()``.
    """
    known_names = ("cosine_with_warmup", "constant", "exponential", "one_circle", "one_cycle")
    if scheduler_name not in known_names:
        raise ValueError(f"Unknown scheduler_name {scheduler_name!r}; expected one of {known_names}")
    if wb:
        wb.config["scheduler"] = scheduler_name
    if scheduler_name == "cosine_with_warmup":
        warmup_steps = epochs // 5
        return get_cosine_schedule_with_warmup(optimizer,
                                               num_warmup_steps=warmup_steps,
                                               num_training_steps=epochs - warmup_steps,
                                               num_cycles=0.5)
    if scheduler_name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(optimizer)
    if scheduler_name == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
    # "one_circle" / "one_cycle": sized in optimizer steps, not epochs.
    return OneCycleLR(optimizer,
                      max_lr=CONFIG["learning_rate"]*10,
                      total_steps=train_len)


def collate_fn(data):
    """Combine ``(features, labels)`` samples into one padded batch for the DataLoader.

    Each 3-D feature tensor has its last two axes swapped and its leading axis squeezed,
    then all of them are zero-padded along the first remaining axis into a
    ``batch_first`` tensor. The per-sample label dicts are merged into a single dict
    mapping every key of the first sample to the list of that key's values.
    """
    features, label_dicts = zip(*data)
    per_sample = [feat.permute(0, 2, 1).squeeze(dim=0) for feat in features]
    padded = pad_sequence(per_sample, batch_first=True)
    merged_labels = {key: [labels[key] for labels in label_dicts] for key in label_dicts[0]}
    return padded, merged_labels


def main():
    """Train the recognizer and report train-set accuracy after every epoch.

    wandb logging is disabled (``wandb_stat = None``); replace it with a
    ``wandb.init(...)`` run to enable it. The model is written to
    ``DATA_DIR/model_1.pt`` only when ``save_model`` is True.
    """
    wandb_stat = None  # wandb.init(project="ASR", entity="Alex2135", config=CONFIG)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Making dataset and loader
    ds = CommonVoiceUkr(TRAIN_REC_PATH, TRAIN_REC_SPEC_PATH, batch_size=BATCH_SIZE)
    train_dataloader = DataLoader(ds, shuffle=True, collate_fn=collate_fn, batch_size=BATCH_SIZE)
    train_val_dataloader = DataLoader(ds, shuffle=True, collate_fn=collate_fn, batch_size=64)

    epochs = CONFIG["epochs"]
    train_len = len(train_dataloader) * epochs  # total number of optimizer steps in the run

    d_model = 64
    model = Model(d_model=d_model,
                  n_encoders=CONFIG["n_encoders"],
                  n_decoders=CONFIG["n_decoders"],
                  device=device)
    checkpoint_path = os.path.join(DATA_DIR, "model_1.pt")
    if CONFIG["pretrain"]:
        # Load into the model built above so its architecture (including d_model) matches
        # the one saved below; map_location lets a GPU checkpoint load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Create optimizator
    optimizer = AdamW(model.parameters(), lr=CONFIG["learning_rate"])
    save_model = False
    scheduler_name = "constant"
    scheduler = get_scheduler(epochs, train_len, optimizer, scheduler_name=scheduler_name, wb=wandb_stat)
    # Step the scheduler in exactly ONE place. The one-cycle schedule is sized in optimizer
    # steps (total_steps=train_len), so it is stepped per batch inside train(); the other
    # schedules are stepped once per epoch here (without the deprecated epoch argument).
    step_per_batch = scheduler_name in ("one_circle", "one_cycle")

    for epoch in range(1, epochs + 1):
        print(f"Epoch №{epoch}")
        train(model, train_dataloader, optimizer, device,
              scheduler=scheduler if step_per_batch else None,
              epoch=epoch, wb=wandb_stat)
        val(model, train_val_dataloader, device, epoch, wb=wandb_stat)
        if not step_per_batch:
            scheduler.step()
        last_lr = scheduler.get_last_lr()[0]
        print(f"scheduler last_lr: {last_lr}")
        if wandb_stat:
            wandb_stat.log({"scheduler lr": last_lr})

    if save_model:
        print(f"Save model to path: '{checkpoint_path}'")
        torch.save(model.state_dict(), checkpoint_path)


# Launch the full training run (in a notebook kernel __name__ is "__main__", so this always runs).
if __name__ == "__main__":
    main()

KeyError: 'class'

In [None]:
from torch.nn.utils.rnn import pad_sequence

# Scratch: pad three variable-length (len, 300) sequences. With the default
# batch_first=False the result is (max_len, batch, 300).
a, b, c = (torch.randn(length, 300) for length in (25, 22, 15))
pad_sequence([a, b, c]).shape

In [None]:
# Scratch: merge same-keyed dicts into one dict of lists (the trick collate_fn uses for labels).
d1 = {"a": 1, "b": 2}
d2 = {"a": 3, "b": 4}

ds = (d1, d2)
d_out = {key: [d[key] for d in ds] for key in d1}

print(d_out)

In [None]:
"""
Inputs shape: torch.Size([34, 4, 38])
Tgt shape: torch.Size([4, 36])
Input length shape: torch.Size([4])
Tgt length shape: torch.Size([4])
"""

# Dimensions copied from the shapes recorded in the string above.
T, C, N, S = (
    34,  # input sequence length
    38,  # number of classes, including blank
    4,   # batch size
    36,  # target sequence length of the longest target in the batch (padding length)
)

ctc_loss = nn.CTCLoss()  # PyTorch defaults: blank=0, reduction="mean"

In [None]:
# Sanity check of nn.CTCLoss on random data shaped like the recorded batch.
# CTCLoss expects log-probabilities, so the random scores go through log_softmax
# (as in the PyTorch CTCLoss example); detach() makes the result a leaf that can require grad.
# The name `log_probs` also avoids shadowing the builtin `input`.
log_probs = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.full(size=(N,), fill_value=S, dtype=torch.long)

# Print shapes only (not the whole tensor), in the format recorded in the cell above.
print(f"Inputs shape: {log_probs.shape}")
print(f"Tgt shape: {target.shape}")
print(f"Input length shape: {input_lengths.shape}")
print(f"Tgt length shape: {target_lengths.shape}")

# NOTE: CTC needs at least as many input frames as target labels. With T=34 < S=36 no
# valid alignment exists, so this loss is inf (and its gradients are not meaningful)
# unless CTCLoss(zero_infinity=True) is used — the same situation as the recorded batch.
loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
loss.backward()

In [None]:
# Scratch: output shape of a strided Conv1d front-end on a short dummy input.
d_inputs = 768
d_model = 64

X = torch.randn([1, d_inputs, 14])  # (batch, channels, length), as nn.Conv1d expects
conv1 = nn.Sequential(
    nn.Conv1d(d_inputs, d_model, kernel_size=7, stride=3),
    nn.ReLU())
# conv2 must consume conv1's d_model output channels (it was hard-coded to 256, which could
# never accept conv1's output). With length 14, conv1 yields only (14 - 7) // 3 + 1 = 3 steps,
# which is shorter than kernel_size=7, so conv2 still cannot be applied to this particular input.
conv2 = nn.Sequential(
    nn.Conv1d(d_model, d_model, kernel_size=7, stride=3),
    nn.ReLU()
)
out = conv1(X)
print(f"shape after conv1: {out.shape=}")
#out = conv2(out)
#print(f"shape after conv2: {out.shape=}")

In [None]:
# Archived earlier version of the training script (Conformer model, CTC loss, per-batch
# cosine warm-up schedule), kept only as an inert string literal — none of it is executed.
# NOTE(review): as the last expression in the cell, Jupyter echoes this whole string as the
# cell's output; delete the cell (or rely on version control) if it is no longer needed.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model import Conformer as con
from data_processing import ukr_lang_chars_handle
from data_processing import CommonVoiceUkr
from config import *
from torch.optim import RAdam, AdamW
from tqdm import tqdm
import pprint
import numpy as np
#from torch.optim.lr_scheduler import MultiplicativeLR
from model import MaskedSoftmaxCELoss
from model import get_cosine_schedule_with_warmup
import wandb

# wandb.init(project="ASR", entity="alex2135")

# wandb.config = CONFIG

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Making dataset and loader
ds = CommonVoiceUkr(TRAIN_PATH, TRAIN_SPEC_PATH, batch_size=BATCH_SIZE)
train_dataloader = DataLoader(ds, shuffle=True, batch_size=BATCH_SIZE)
train_len = len(train_dataloader) * CONFIG["epochs"]
print("train len:", train_len)

def eleminate_channels(X: torch.Tensor) -> torch.Tensor:
    b, c, h, w = X.shape
    X = X.view(b, c*h, w)
    return X

tgt_n = 152
model = con(n_encoders=CONFIG["n_encoders"], n_decoders=CONFIG["n_decoders"], device=device)
if CONFIG["pretrain"] == True:
    PATH = os.path.join(DATA_DIR, "model_1.pt")
    model = con(n_encoders=CONFIG["n_encoders"], n_decoders=CONFIG["n_decoders"], device=device)
    model.load_state_dict(torch.load(PATH))

# Create optimizator
optimizer = AdamW(model.parameters(), lr=CONFIG["learning_rate"])
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=train_len//5, num_training_steps=train_len)

# Create CTC criterion
alpha_loss = torch.Tensor([0.7]).to(device)
ctc_criterion = nn.CTCLoss(blank=ukr_lang_chars_handle.token_to_index["<blank>"], zero_infinity=True)
ce_criterion = nn.CrossEntropyLoss()

running_loss = []
losses_per_phase = []
epochs = CONFIG["epochs"]
try:
    for epoch in range(1, epochs + 1):
        print(f"Epoch №{epoch}")
        for idx, (X, tgt) in tqdm(enumerate(train_dataloader)):

            tgt_text = tgt["text"]
            tgt_class = tgt["label"]

            one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt_text, 152).to(device)
            X = X.to(device) #

            emb, output = model(X, one_hots)  # (batch, _, n_class, time), (batch, _, time, n_class)
            b, cnls, t, clss = output.shape
            output = output.view(t * cnls, b, clss)  # (time, batch, n_class)
            output = F.log_softmax(output, dim=-1)
            indeces = ukr_lang_chars_handle.sentences_to_indeces(tgt_text).to(device)

            input_lengths = torch.full(size=(BATCH_SIZE,), fill_value=t, dtype=torch.long).to(device)
            target_lengths = torch.full(size=(BATCH_SIZE,), fill_value=indeces.shape[-1], dtype=torch.long).to(device)
            ctc_loss = ctc_criterion(output.to(device), indeces, input_lengths, target_lengths)

            print(f"{output.shape=}")
            break
            #print(f"\n{emb.shape=}\n{one_hots.shape=}")
            #emb, one_hots = eleminate_channels(emb.to(device)), eleminate_channels(one_hots.float())
            #ce_loss = ce_criterion(emb, one_hots)

            #loss = alpha_loss * torch.log(ce_loss) + (1-alpha_loss) * torch.log(ctc_loss)
            loss = ctc_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            #print(f"\n{ctc_loss.item()=}, {torch.log(ctc_loss)}\n{ce_loss.item()=}, {torch.log(ce_loss)}")
            running_loss.append(loss.cpu().detach().numpy())
            losses_per_phase.append(loss.cpu().detach().numpy())
            # wandb.log({"loss": loss})

            if torch.isnan(loss) or torch.isinf(loss):
                print("Target label:", tgt)
                print("Running loss:")
                pprint.pprint(running_loss)
                print(output.shape)
                print("Is nan in output:", torch.sum(torch.isnan(output)))
                print("Is inf in output:", torch.sum(torch.isinf(output)))
                pprint.pprint(output)
                break
            if (idx + 1) % 50 == 0:  # print every 200 mini-batches
                print(f"Epoch: {epoch}, Last loss: {loss.item():.4f}, Loss phase mean: {np.mean(np.array(losses_per_phase)):.4f}")
                losses_per_phase = []
            optimizer.zero_grad()
    import os
    PATH = os.path.join(DATA_DIR, "model_1.pt")
    print(PATH)
    torch.save(model.state_dict(), PATH)
except Exception as e:
    print(e)
"""