# TTS - Experiments

## Dataset

In [63]:
from ljspeech import LJSPEECH

BLOCK_SIZE        = 512
DATASET_PATH      = "./data/LJSpeech/"
BANDWIDTH_IDX     = 0
BANDWIDTHS        = [1.5, 3.0, 6.0, 12.0, 24.0]
BANDWIDTH         = BANDWIDTHS[BANDWIDTH_IDX]
MAX_PROMPT_LENGTH = 128

dataset = LJSPEECH("./data/LJSpeech",
                    encodec_bandwidth=BANDWIDTH,
                    max_prompt_length=MAX_PROMPT_LENGTH)

## Model

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    def __init__(self, head_size, n_embed, BLOCK_SIZE, dropout):
        super().__init__()
        self.k = nn.Linear(n_embed, head_size, bias=False)
        self.q = nn.Linear(n_embed, head_size, bias=False)
        self.v = nn.Linear(n_embed, head_size, bias=False)

        self.register_buffer("tril",
                             torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x)
        q = self.q(x)
        w = q @ k.transpose(-2, -1) * C ** -0.5
        w = w.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        w = F.softmax(w, dim=-1)
        w = self.dropout(w)
        v = self.v(x)
        o = w @ v
        return o

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed, num_heads, head_size, dropout, BLOCK_SIZE):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(head_size, n_embed, BLOCK_SIZE, dropout)
             for _ in range(num_heads)])
        self.proj  = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)
        # print("MHSA proj.shape:", n_embed)
    def forward(self, x):
        # print("MHSA x.shape:", x.shape)
        o = torch.cat([h(x) for h in self.heads], dim=-1)
        # print("MHSA concat o.shape:", o.shape)
        o = self.dropout(self.proj(o))
        # print("MHSA project o.shape:", o.shape)
        return o
    
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
    
class Block(nn.Module):
    def __init__(self, n_embed, n_head, dropout, BLOCK_SIZE):
        super().__init__()
        head_size = n_embed // n_head
        self.sa   = MultiHeadAttention(n_embed, n_head, head_size, dropout, BLOCK_SIZE)
        self.ffwd = FeedForward(n_embed, dropout)
        self.ln1  = nn.LayerNorm(n_embed)
        self.ln2  = nn.LayerNorm(n_embed)
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
    
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_len, n_embed, n_heads, n_layer, BLOCK_SIZE, dropout=0.2):
        super().__init__()
        self.BLOCK_SIZE = BLOCK_SIZE
        self.token_emb_table    = nn.Embedding(vocab_len, n_embed)
        self.position_emb_table = nn.Embedding(BLOCK_SIZE, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, n_head=n_heads, dropout=dropout, BLOCK_SIZE=BLOCK_SIZE)
              for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_len)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embed = self.token_emb_table(idx)
        pos_embed = self.position_emb_table(
            torch.arange(T, device="cpu"))

        x = tok_embed + pos_embed
        x = self.blocks(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if not targets is None:
            B, T, C = logits.shape
            logits  = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.BLOCK_SIZE:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

vocab_len = 1024 + 1 + len(dataset.phone_dict)
# vocab_len = 1024
model = TransformerDecoder(
    vocab_len=vocab_len,
    n_embed=256,
    n_heads=4,
    n_layer=1,
    BLOCK_SIZE=BLOCK_SIZE
)

In [65]:
dataset[0]

(PosixPath('data/LJSpeech/LJSpeech-1.1/wavs/LJ001-0002.wav'),
 tensor([[-0.0003,  0.0000,  0.0000,  ..., -0.0009, -0.0010, -0.0011]]),
 22050,
 'in being comparatively modern.',
 'in being comparatively modern.',
 ['IH0',
  'N',
  '_',
  'B',
  'IY1',
  'IH0',
  'NG',
  '_',
  'K',
  'AH0',
  'M',
  'P',
  'EH1',
  'R',
  'AH0',
  'T',
  'IH0',
  'V',
  'L',
  'IY0',
  '_',
  'M',
  'AA1',
  'D',
  'ER0',
  'N',
  '_',
  '_'],
 tensor([38, 48, 74, 22, 42, 38, 49, 74, 45, 10, 47, 56, 27, 57, 10, 60, 38, 69,
         46, 41, 74, 47,  5, 24, 29, 48, 74, 74]),
 tensor([[[ 738,  523,  141,  504,  970,  363,  746,  913,  949, 1010,  530,
            347,  860,  319,  477,  840,  801,  319,  765,  465,  727,  727,
            906,  840,  990,  801,  765,  563,  807,  565,   25,  276,  904,
            194,  935,  779,  283,  913,  945,  563,  807,  976,  404,   52,
            325,  904, 1020,  666,  372,  677,  537,  695,  352,  348,  240,
            222,  612,  734,  950,  734,  451,  694,

In [66]:
from einops import rearrange

In [67]:
org_item = dataset[4][-1]
org_item.shape

torch.Size([1, 2, 134])

In [68]:
item = rearrange(org_item, "b q t -> b (t q)")

In [69]:
item = torch.tensor(item, device="cpu")

  item = torch.tensor(item, device="cpu")


In [70]:
item = torch.clamp(item, 0, 1023)

In [71]:
item.shape

torch.Size([1, 268])

In [72]:
logits, loss = model(item)

In [73]:
logits, loss

(tensor([[[ 0.2993,  0.2959, -0.3080,  ...,  0.5133, -0.7259, -0.0948],
          [-0.1246,  0.2828,  0.0867,  ..., -0.7155,  0.3794,  1.1395],
          [-0.2055, -1.0728,  0.2162,  ...,  0.1552,  1.1568,  0.1658],
          ...,
          [ 0.1314,  0.2591,  0.6565,  ...,  0.1458,  0.2549,  0.5848],
          [-1.0401, -0.4139, -0.6558,  ...,  0.4244,  0.2001,  0.9856],
          [-0.1315, -0.0691, -0.4677,  ...,  0.8147, -0.8719, -0.3951]]],
        grad_fn=<ViewBackward0>),
 None)

In [74]:
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [75]:
EPOCHS = 100

In [76]:
item.shape[1]

268

In [77]:
item[0, :]

tensor([ 276,  609,  807,  478,  112,  727,  658,  836,  575,  460,  942,  422,
         942,  460,  160, 1010,   47, 1010,   47,  420,   47,  973,   47,  742,
         160,  742,  339,  752,  583,  672,  868,  315,  784,  227,  987,  185,
         984,  952,  998,  985,  333,  792,  841,  869,  548,  881,  548,  869,
         514,  469,  375,  469,  695,  276,  432,  722,   73,  829,  251,  177,
         759,    6,   43,  345,  656,  855,  743,  959,   43,  858,  808,   89,
          43,  320,  808,  227,  699, 1002,  967, 1002, 1011,  646,  457,  646,
         604,  230,  602,  721,  980,  536,  920,  711,  356,  792,  939,  711,
         584,  711,  796,  869,  833,  869,  699,  133,  321,  556,  457,  418,
         136,   31,  676,  102,   47,  420,  744, 1010,   47,  973,  574,  160,
         160,  549,  160,  857,  148,  993,  148,  443,  103,  336,  276, 1010,
         264,  615,  197,  615,  833,  413,  182,  920,  862,  541, 1019,  857,
         855,   71,  855,  888,  855,  6

In [78]:
xb = item[0, 0:-2].unsqueeze(0)
new_list = torch.tensor([item[0, 1:-1].tolist()], dtype=torch.long)
yb = new_list

In [79]:
xb.shape, yb.shape

(torch.Size([1, 266]), torch.Size([1, 266]))

In [80]:
for e in range(EPOCHS):
    optim.zero_grad()
    _, loss = model(xb, targets=yb)
    loss.backward()
    optim.step()
    if e % 10 == 0:
        print(e, loss)

0 tensor(7.2001, grad_fn=<NllLossBackward0>)
10 tensor(3.5264, grad_fn=<NllLossBackward0>)
20 tensor(0.9719, grad_fn=<NllLossBackward0>)
30 tensor(0.1463, grad_fn=<NllLossBackward0>)
40 tensor(0.0374, grad_fn=<NllLossBackward0>)
50 tensor(0.0188, grad_fn=<NllLossBackward0>)
60 tensor(0.0128, grad_fn=<NllLossBackward0>)
70 tensor(0.0105, grad_fn=<NllLossBackward0>)
80 tensor(0.0090, grad_fn=<NllLossBackward0>)
90 tensor(0.0081, grad_fn=<NllLossBackward0>)


In [81]:
from encodec_util import decode_to_file
pred = model.generate(torch.zeros((1, 1), dtype=torch.int), max_new_tokens=100)

In [82]:
print(pred.shape)

torch.Size([1, 101])


In [83]:
clipped_pred = pred[0, :pred.shape[1] - (pred.shape[1] % 2)]

In [84]:
clipped_pred.shape

torch.Size([100])

In [85]:
out_pred = rearrange(clipped_pred.squeeze(0), "(t q) -> t q", q=2)

In [86]:
out_pred

tensor([[   0,  609],
        [ 807,  478],
        [ 112,  727],
        [ 658,  836],
        [ 575,  460],
        [ 942,  422],
        [ 942,  460],
        [ 160, 1010],
        [  47, 1010],
        [  47,  420],
        [  47,  973],
        [  47,  742],
        [ 160,  742],
        [ 339,  752],
        [ 583,  672],
        [ 868,  315],
        [ 784,  227],
        [ 987,  185],
        [ 984,  952],
        [ 998,  985],
        [ 333,  792],
        [ 841,  869],
        [ 548,  881],
        [ 548,  869],
        [ 514,  469],
        [ 375,  469],
        [ 695,  276],
        [ 432,  722],
        [  73,  829],
        [ 251,  177],
        [ 759,    6],
        [  43,  345],
        [ 656,  855],
        [ 743,  959],
        [  43,  858],
        [ 808,   89],
        [  43,  320],
        [ 808,  227],
        [ 699, 1002],
        [ 967, 1002],
        [1011,  646],
        [ 457,  646],
        [ 604,  230],
        [ 602,  721],
        [ 980,  536],
        [ 

In [87]:
torch.min(out_pred), torch.max(out_pred)

(tensor(0), tensor(1011))

In [88]:
out_pred.shape

torch.Size([50, 2])

In [89]:
org_item_test = rearrange(org_item.squeeze(0), "q t -> t q")

In [90]:
out_pred = torch.clamp(out_pred, 1, 1023)

In [91]:
decode_to_file(out_pred, "out.wav")