In [1]:
import pandas as pd

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from src.model import ModelConfig, Encoder, Decoder
from src.audio import AudioTransform
from src.text import TextTransform
from src.dataset import LJSpeechDataset
from tqdm import tqdm

In [2]:
# Instantiate the project's audio and text transforms with their default settings.
# Both are fitted on the working subset in later cells before being used.
audio_transform = AudioTransform()
text_transform = TextTransform()

  WeightNorm.apply(module, name, dim)


In [3]:
def load_data(path:str = r'data/LJSpeech-1.1/metadata.csv') -> pd.DataFrame:
    """Load the LJSpeech manifest, keeping clip ids and normalized transcripts.

    The file is read as a header-less, pipe-separated table with the columns
    ``id``, ``transcription`` and ``normalized_transcription``; the raw
    ``transcription`` column is discarded.

    Args:
        path: Location of the pipe-separated metadata file.

    Returns:
        A DataFrame with columns ``id`` and ``normalized_transcription``,
        containing no missing values and indexed 0..n-1.
    """
    # NOTE(review): pandas' default quoting is in effect here; if transcripts
    # contain bare double quotes, rows may be merged -- confirm against the data.
    manifest = pd.read_csv(path, sep='|', header=None, names=["id", "transcription", "normalized_transcription"])
    manifest = manifest.drop(columns=['transcription'])
    # dropna() returns a new frame rather than mutating in place, so the result
    # must be kept; the index is renumbered so it stays contiguous after rows
    # are removed.
    manifest = manifest.dropna().reset_index(drop=True)
    return manifest

def partition(df:pd.DataFrame, num:int = 5):
    """Return the leading ``num`` rows of ``df`` (5 when not specified)."""
    leading_rows = df.iloc[:num]
    return leading_rows

In [4]:
# Load the manifest from its default path, then keep only the first 30 rows
# as the working subset used by the rest of the notebook.
df = load_data()
part = partition(df, 30)

In [5]:
# Encode every clip of the working subset, then fit the audio transform on
# the collected encodings.
y_list = []

for clip_id in part['id']:
    path = 'data/LJSpeech-1.1/wavs/' + clip_id + '.wav'
    wav, sr = torchaudio.load(path)
    # A leading axis is added to the loaded waveform before it is encoded.
    y_list.append(audio_transform.encode(wav.unsqueeze(0), sr))

audio_transform.fit(y_list)

In [6]:
# Collect the normalized transcripts of the working subset as a plain list ...
text = part['normalized_transcription'].to_list()

# ... and fit the text transform on them. Its vocab_size and max_length are
# read later when the model config is built.
text_transform.fit(text)

In [7]:
# Wrap the working subset and the two fitted transforms in the project dataset.
dataset = LJSpeechDataset(part, audio_transform, text_transform)
# Sanity check: print the shapes of the first sample's two tensors (element 0, then element 1).
print(dataset[0][0].shape, dataset[0][1].shape)

torch.Size([1, 141]) torch.Size([1, 8, 746])


In [8]:
# Central configuration shared by the model, the data split and the training loop.
config = ModelConfig(
    vocab_size=text_transform.vocab_size,
    # Id of the first special token in the text vocabulary
    # (presumably the padding token -- confirm in TextTransform).
    pad_idx=text_transform.phoneme2id[text_transform.special_tokens[0]],
    D_model=256,
    K=8,
    # Sequence lengths come from the fitted transforms.
    T_text=text_transform.max_length,
    T_audio=audio_transform.max_length,
    # Number of bins of the codec quantizer; used as the class count of the loss.
    C=audio_transform.encodec.quantizer.bins,
    epochs=30,
    lr=3e-2,
    # Fall back to the CPU so the notebook still runs on machines without CUDA.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',

    # Fractions of the dataset held out for validation and testing.
    val_split = 0.1,
    test_split = 0.1,
    # StepLR settings: multiply the learning rate by gamma every `scheduler_step` steps.
    scheduler_step = 20,
    scheduler_gamma= 0.5,
)

In [9]:
config

ModelConfig(vocab_size=58, D_model=256, K=8, T_text=140, T_audio=746, C=1024, pad_idx=0, epochs=30, lr=0.03, device='cuda:0', val_split=0.1, test_split=0.1, scheduler_step=20, scheduler_gamma=0.5)

In [10]:
def collate_fn(batch):
    """Merge a list of ``(x, y)`` samples into one ``(x_batch, y_batch)`` pair.

    The first elements of all samples are concatenated along dimension 0, and
    likewise the second elements; each sample is expected to already carry a
    leading axis of its own.
    """
    inputs, targets = zip(*batch)
    stacked_inputs = torch.cat(inputs, dim=0)
    stacked_targets = torch.cat(targets, dim=0)
    return stacked_inputs, stacked_targets

def create_data_loaders(dataset: Dataset, cfg:ModelConfig, collate_fn=collate_fn,
                        batch_size:int=8,) -> tuple:
    """Randomly split ``dataset`` and wrap each part in a DataLoader.

    The validation and test sizes are ``int(len(dataset) * cfg.val_split)`` and
    ``int(len(dataset) * cfg.test_split)``; every remaining sample goes to the
    training part. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    total = len(dataset)
    val_count = int(total * cfg.val_split)
    test_count = int(total * cfg.test_split)
    # Whatever the truncated fractions leave over is used for training.
    split_sizes = [total - val_count - test_count, val_count, test_count]
    train_ds, val_ds, test_ds = random_split(dataset, split_sizes)

    # Options common to all three loaders.
    shared = dict(batch_size=batch_size, collate_fn=collate_fn)
    train_loader = DataLoader(train_ds, shuffle=True, **shared)
    val_loader = DataLoader(val_ds, **shared)
    test_loader = DataLoader(test_ds, **shared)
    return train_loader, val_loader, test_loader

In [11]:
# Split the dataset according to config.val_split / config.test_split and
# build the three loaders with a batch size of 8.
train_loader, val_loader, test_loader = create_data_loaders(dataset, config, batch_size=8)

In [16]:
def train_model(config:ModelConfig, encoder:nn.Module, decoder:nn.Module,
                train_loader:DataLoader, val_loader:DataLoader=None) -> None:
    """Train the encoder/decoder pair with Adam and a StepLR schedule.

    Every epoch runs one optimisation pass over ``train_loader`` and, when
    ``val_loader`` is given and non-empty, one gradient-free pass over it.
    A one-line summary is printed per epoch; nothing is returned, the modules
    are updated in place.

    Args:
        config: Supplies ``device``, ``lr``, ``epochs``, ``pad_idx``, ``C``,
            ``scheduler_step`` and ``scheduler_gamma``.
        encoder: Module applied to the input ids.
        decoder: Module applied to the encoder output; its last output
            dimension is reshaped to ``C`` classes.
        train_loader: Yields ``(x_ids, y)`` training batches.
        val_loader: Optional loader of ``(x_ids, y)`` validation batches.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    device = config.device
    encoder.to(device); decoder.to(device)
    print(f'Device {device}, cuda is available: {torch.cuda.is_available()} ')
    opt = torch.optim.Adam(list(encoder.parameters())+list(decoder.parameters()), lr=config.lr)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=config.scheduler_step, gamma=config.scheduler_gamma)

    pad_idx = config.pad_idx; C = config.C

    def batch_loss(x_ids, y):
        """Cross-entropy between the decoder logits and the target codes of one batch."""
        x_ids, y = x_ids.to(device), y.to(device)
        logits = decoder(encoder(x_ids))
        # Flatten so that every target position is one C-way classification.
        logits_flat = logits.reshape(-1, C)
        y_flat = y.reshape(-1)
        # NOTE(review): pad_idx comes from the text vocabulary but is used here
        # to mask audio-code targets; confirm that padded audio positions really
        # carry this value and that it is not also a valid code.
        return F.cross_entropy(logits_flat, y_flat, ignore_index=pad_idx)

    for epoch in range(1, config.epochs+1):
        encoder.train(); decoder.train()
        train_losses = []
        for x_ids, y in tqdm(train_loader, desc=f"Epoch {epoch}/{config.epochs} [Train]", leave=False):
            opt.zero_grad()
            loss = batch_loss(x_ids, y)
            loss.backward(); opt.step()
            train_losses.append(loss.item())

        if not train_losses:
            raise ValueError("train_loader produced no batches")

        # StepLR's step_size is counted in scheduler steps, and the LR log below
        # treats scheduler_step as a number of epochs, so the scheduler advances
        # once per epoch (after the optimizer updates), not once per batch.
        sched.step()

        avg_train = sum(train_losses)/len(train_losses)
        log = f"Epoch {epoch}/{config.epochs} Train Loss: {avg_train:.4f}"

        # A DataLoader defines __len__, so an empty loader is falsy and is
        # skipped here just like None.
        if val_loader:
            encoder.eval(); decoder.eval()
            with torch.no_grad():
                val_losses = []
                for x_ids, y in tqdm(val_loader, desc=f"Epoch {epoch}/{config.epochs} [Val]", leave=False):
                    val_losses.append(batch_loss(x_ids, y).item())
                avg_val = sum(val_losses)/len(val_losses)
            log += f" | Val Loss: {avg_val:.4f}"

        # Report the learning rate on the first epoch and whenever it has just decayed.
        if epoch==1 or epoch%config.scheduler_step==0:
            lr = sched.get_last_lr()[0]
            log += f" | LR: {lr:.6f}"
        print(log)

In [17]:
# Build a new encoder/decoder pair from the shared config ...
encoder = Encoder(config)
decoder = Decoder(config)

# ... and train them in place. train_model returns None, so the trained
# weights live in `encoder` and `decoder`.
train_model(config, encoder, decoder, train_loader, val_loader)

Device cuda:0, cuda is available: True 


                                                                 

KeyboardInterrupt: 

In [None]:
def evaluate(test_loader:DataLoader, encoder:nn.Module, decoder:nn.Module,
             cfg:ModelConfig) -> dict:
    """Measure how often the arg-max prediction equals the target code.

    Positions whose target equals ``cfg.pad_idx`` are excluded from both the
    numerator and the denominator.

    Args:
        test_loader: Yields ``(x_ids, y)`` batches.
        encoder: Module applied to the input ids.
        decoder: Module applied to the encoder output; predictions are the
            arg-max over its last dimension.
        cfg: Supplies ``pad_idx`` and ``device``.

    Returns:
        ``{'code_accuracy': value}`` where ``value`` is the fraction of
        non-masked positions predicted correctly, or NaN when there are no
        such positions (for example an empty loader).
    """
    # Keep the models on the same device the batches are moved to, as
    # train_model does; this is a no-op when they are already there.
    encoder.to(cfg.device); decoder.to(cfg.device)
    encoder.eval(); decoder.eval()
    total, correct = 0, 0
    pad_idx = cfg.pad_idx
    with torch.no_grad():
        for x_ids, y in test_loader:
            x_ids, y = x_ids.to(cfg.device), y.to(cfg.device)
            # Shape trace of each input batch.
            print(x_ids.shape)
            pred = decoder(encoder(x_ids)).argmax(dim=-1)
            mask = y!=pad_idx
            correct += (pred[mask]==y[mask]).sum().item()
            total += mask.sum().item()
    # Small datasets can yield an empty test split (int(n * test_split) == 0);
    # report NaN instead of dividing by zero.
    if total == 0:
        return {'code_accuracy': float('nan')}
    return {'code_accuracy': correct/total}

In [None]:
# Report the code-level accuracy of the trained models on the held-out test loader.
print(evaluate(
    test_loader, encoder, decoder, config
))

torch.Size([5, 141])
{'code_accuracy': 0.01664238484814524}


In [None]:
def synthesize(text:str, encoder:nn.Module, decoder:nn.Module,
               text_t:TextTransform, audio_t:AudioTransform, cfg:ModelConfig) -> torch.Tensor:
    """Turn ``text`` into audio using the given models and transforms.

    The text is converted to ids by calling ``text_t``, passed through the
    encoder and decoder without gradient tracking, reduced to codes by an
    arg-max over the last dimension, and handed to ``audio_t.decode``.

    Returns:
        Whatever ``audio_t.decode`` produces for the predicted codes.
    """
    for module in (encoder, decoder):
        module.eval()
    with torch.no_grad():
        token_ids = text_t(text).to(cfg.device)
        print(token_ids.shape)
        logits = decoder(encoder(token_ids))
        # NOTE(review): the original comment gave the code shape as
        # [K, T_audio]; a leading batch axis may also be present -- confirm.
        predicted_codes = logits.argmax(dim=-1)
        return audio_t.decode(predicted_codes)

In [None]:
# Synthesize a short test utterance with the trained models.
wave = synthesize('hello world', encoder, decoder, text_transform, audio_transform, config)

# Write the result to disk.
# NOTE(review): 24_000 is presumably the codec's sample rate -- confirm; also
# confirm `wave` is a CPU tensor of the shape torchaudio.save expects.
torchaudio.save('out.wav', wave, 24_000)

torch.Size([1, 141])
