# PhoneLM

## Test `G2P` and `Encodec`

In [None]:
!pip install g2p_en encodec

### `G2P`

In [2]:
from g2p_en import G2p

In [3]:
import torch
import random
import string
from functools import cache
from tqdm import tqdm

@cache
def _get_model():
    return G2p()

@cache
def _get_graphs(path):
    with open(path, "r") as f:
        graphs = f.read()
    return graphs

def encode(graphs: str) -> list[str]:
    g2p = _get_model()
    phones = g2p(graphs)
    ignored = {" ", *string.punctuation}
    return ["_" if p in ignored else p for p in phones]

@torch.no_grad()
def write_phones(folder, suffix=".normalized.txt"):
    print("ello?")
    paths = list(folder.rglob(f"*{suffix}"))
    random.shuffle(paths)

    print("paths:", paths)
    for path in tqdm(paths):
        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
        if phone_path.exists():
            continue
        print("?")
        graphs = _get_graphs(path)
        phones = encode(graphs)
        with open(phone_path, "w") as f:
            f.write(" ".join(phones))

In [4]:
from pathlib import Path
write_phones(Path("./data/text"))

ello?
paths: [WindowsPath('data/text/test.normalized.txt')]


100%|██████████| 1/1 [00:00<?, ?it/s]


### `Encodec`

In [2]:
from tqdm import tqdm
import random
import torch
from functools import cache
import torchaudio
from encodec import EncodecModel
from torch import Tensor
from einops import rearrange
import soundfile
from encodec.utils import convert_audio
from pathlib import Path

SAMPLE_RATE = 24_000
BANDWIDTHS  = [1.5, 3.0, 6.0, 12.0, 24.0]
BANDWIDTH   = BANDWIDTHS[0]

@cache
def _load_model(bandwidth=6.0, device="cuda"):
    # Instantiate a pretrained EnCodec model
    assert SAMPLE_RATE == 24_000
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(bandwidth)
    model.to(device)
    return model

def unload_model():
    return _load_model.cache_clear()

@torch.inference_mode()
def decode(codes: Tensor, bandwidth=6.0, device="cuda"):
    """
    Args:
        codes: (b q t)
    """
    assert codes.dim() == 3
    model = _load_model(bandwidth, device)
    return model.decode([(codes, None)]), model.sample_rate

def decode_to_file(resps: Tensor, path: Path):
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    resps = rearrange(resps, "t q -> 1 q t")
    wavs, sr = decode(codes=resps, bandwidth=BANDWIDTH)
    soundfile.write(str(path), wavs.cpu()[0, 0], sr)

def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

@torch.inference_mode()
def encode(wav: Tensor, sr: int, bandwidth=6.0, device="cuda"):
    """
    Args:
        wav: (t)
        sr: int
    """
    model = _load_model(bandwidth, device)
    wav = wav.unsqueeze(0)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
    encoded_frames = model.encode(wav)
    qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # (b q t)
    return qnt

def encode_from_file(path, bandwidth=6.0, device="cuda"):
    wav, sr = torchaudio.load(str(path))
    if wav.shape[0] == 2:
        wav = wav[:1]
    return encode(wav, sr, bandwidth, device)

def quantize_audio(folder, suffix=".wav"):
    paths = [*folder.rglob(f"*{suffix}")]
    random.shuffle(paths)

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qnt.pt")
        if out_path.exists():
            continue
        qnt = encode_from_file(path, BANDWIDTH)
        print(qnt.shape)
        torch.save(qnt.cpu(), out_path)

def decode_files(folder, suffix=".qnt.pt"):
    paths = [*folder.rglob(f"*{suffix}")]
    random.shuffle(paths)

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qt.wav")
        if out_path.exists():
            continue
        fi = rearrange(torch.load(path).squeeze(0).cuda(), "q t -> t q")
        decode_to_file(fi, out_path)

In [3]:
# from pathlib import Path
# quantize_audio(Path("./data/audio"))
# decode_files(Path("./data/audio"))

In [4]:
# torch.load("data/audio/test.qnt.pt").shape

#### Generate Audio from Tensor

In [9]:
audio_tensor = torch.tensor([[1019,  662],
        [ 598,   25],
        [ 321,  463],
        [1063,  575],
        [ 745,  727],
        [1073,  344],
        [1098,  344],
        [1046,  959],
        [1062,  874],
        [1059,  804],
        [1038, 1010],
        [1081,  577],
        [1098,  323],
        [1049,  858],
        [1034,  278],
        [1098,  469],
        [1069,  626],
        [1034,  482],
        [1071,  398],
        [1063,  858],
        [1083,  443],
        [1034,  418],
        [1072,  632],
        [1075,  914],
        [1098, 1010],
        [1094,  357],
        [1087,  898],
        [1084,  702],
        [1099,  654],
        [ 835,  364],
        [ 208,  416],
        [ 987,  722],
        [ 872,  708],
        [ 994,  399],
        [ 264,  648],
        [ 264, 1007],
        [1001,  961],
        [ 598,  320],
        [ 360,  993],
        [ 879,  747],
        [ 325,  700],
        [  52,  770],
        [ 257,  268],
        [ 257,  824],
        [ 819,  662],
        [ 709,  567],
        [ 656,  662],
        [  43,  602],
        [1038,  742],
        [  24,  964],
        [1098,  289],
        [1099,  722],
        [ 855,  870],
        [  25,  561],
        [ 472,  519],
        [ 472,  754],
        [ 475, 1038],
        [ 404,  857],
        [ 331,  913],
        [ 574,  434],
        [ 537,  154],
        [1022,  612],
        [ 324,  321],
        [ 937,  563],
        [ 230, 1001],
        [ 912,  563],
        [ 912,  807],
        [ 928,   99],
        [ 928,   99],
        [ 942,  228],
        [ 604,  772],
        [ 904,   94],
        [ 472, 1063],
        [  52,  812],
        [  52,  645],
        [  52,  697],
        [ 257,  387],
        [  52,  362],
        [ 935,  247],
        [ 983,   65],
        [ 683,  874],
        [ 155,  518],
        [  30,  822],
        [ 855,  467],
        [ 904,  909],
        [ 904,  529],
        [ 904,  852],
        [ 855,  399],
        [ 855,  470],
        [ 855, 1023],
        [ 106,  870],
        [ 176,  580],
        [ 574,  669],
        [ 502,  888],
        [ 588,  708],
        [ 782,  700],
        [ 588,  743],
        [ 890,  417],
        [ 373,  822],
        [ 160,  514],
        [  47,  455],
        [  47,  328],
        [  47,  259],
        [ 909,  971],
        [1023,  962],
        [ 577,  367]]).cuda()

In [10]:
audio_tensor = torch.clamp(audio_tensor, min=0, max=1023)

In [11]:
audio_tensor.shape

torch.Size([106, 2])

In [12]:
decode_to_file(audio_tensor, "general_out.wav")

## Dataset

### LJSpeech

In [1]:
BANDWIDTH_IDX = 1 # original VALL-E
CODEBOOKS     = [2, 4, 8, 16, 32]
BANDWIDTHS    = [1.5, 3.0, 6.0, 12.0, 24.0]
BANDWIDTH     = BANDWIDTHS[BANDWIDTH_IDX]
CODEBOOK      = CODEBOOKS[BANDWIDTH_IDX]

import torchaudio
from ljspeech import LJSPEECH
DATASET_PATH = "./data/LJSpeech/"
dataset = LJSPEECH(
    "./data/LJSpeech",
    encodec_bandwidth=BANDWIDTH)

In [2]:
len(dataset)

1919

In [3]:
dataset[0][-1].shape

torch.Size([1, 4, 143])

In [4]:
import torch
device = torch.cuda.current_device()

"""
fileid_audio,
waveform,
sample_rate,
transcript,
normalized_transcript,
phones,
phone_ids,
codes
"""

# def collate_fn(batch) -> torch.tensor:
#     audio_tokens = []
#     phone_tokens = []

#     for item in batch:
#         cur_aud_tok  = torch.tensor(item[7], device=device)
#         cur_phonemes = torch.tensor(item[6], device=device)
#         audio_tokens.append(cur_aud_tok)
#         phone_tokens.append(cur_phonemes)

#     # audio_tokens = torch.tensor(phone_tokens, device=device)
#     audio_tokens = nn.utils.rnn.pad_sequence(audio_tokens, batch_first=True)
#     # phone_tokens = torch.tensor(phone_tokens, device=device)
#     phone_tokens = nn.utils.rnn.pad_sequence(phone_tokens, batch_first=True)

#     return batch, phone_tokens, audio_tokens

'\nfileid_audio,\nwaveform,\nsample_rate,\ntranscript,\nnormalized_transcript,\nphones,\nphone_ids,\ncodes\n'

In [2]:
import torch
import torchaudio
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split

indices = list(range(len(dataset)))
train_indices, test_indices = train_test_split(indices, test_size=0.1, random_state=42)

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler, collate_fn=lambda x: x)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler, collate_fn=lambda x: x)

# train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler, collate_fn=collate_fn)
# test_loader = DataLoader(dataset, batch_size=32, sampler=test_sampler, collate_fn=collate_fn)

In [3]:
# it = next(iter(train_loader))
# it

In [4]:
len(train_loader), len(test_loader)

(1727, 192)

In [5]:
item = next(iter(train_loader))

In [6]:
item

[(WindowsPath('data/LJSpeech/LJSpeech-1.1/wavs/LJ050-0137.wav'),
  tensor([[ 0.0003,  0.0004,  0.0005,  ..., -0.0008, -0.0008, -0.0007]]),
  22050,
  'FBI, and the Secret Service.',
  'FBI, and the Secret Service.',
  ['B',
   'AY1',
   '_',
   '_',
   '_',
   'AH0',
   'N',
   'D',
   '_',
   'DH',
   'AH0',
   '_',
   'S',
   'IY1',
   'K',
   'R',
   'AH0',
   'T',
   '_',
   'S',
   'ER1',
   'V',
   'AH0',
   'S',
   '_',
   '_'],
  tensor([22, 20, 74, 74, 74, 10, 48, 24, 74, 25, 10, 74, 58, 42, 45, 57, 10, 60,
          74, 58, 30, 69, 10, 58, 74, 74], device='cuda:0'),
  tensor([[[ 865,   59,  309,  392,  695,  361,  706,  913,  822,  325,  176,
             438,  438,  360,  360,  176,  176,  106,  257,  106,  106,  408,
              63,  913,  801,  908,  801,  611,  530,  151,  944,  971,  347,
             523,  855,   25,  593,  695,  723,  683,  169,  203,  760,  683,
             240,  925,  925,   20,  162,  216,  216,  216,  793,  793,  901,
             402,  216,  21

## Model

In [7]:
import megabyte
import torch
import torch.nn as nn
from einops import rearrange

def get_reserved_mem_gb():
    device = torch.cuda.current_device()
    reserved = torch.cuda.memory_reserved(device)
    reserved_gb = reserved / 1024 / 1024 / 1024
    return reserved_gb

class PhoneLM(nn.Module):
    def __init__(self, n_phone_tokens, n_audio_tokens):
        super(PhoneLM, self).__init__()
        self.megabyte   = megabyte.MEGABYTE(
            heads       = 8, # 1,
            dim_head    = 32, # 16,
            num_tokens  = n_phone_tokens + n_audio_tokens + 4,
            dim         = (768, 256, 128), # (32, 32, 32), # (768, 256, 128)# Dg, Dl1, Dl2
            depth       = (6, 4, 2), # (6, 4, 2)
            max_seq_len = (32, 4, 4), # (128, 4, 4), # , # 512
            flash_attn  = False)

    def forward(self, x, debug=False, return_loss=True):
        x = self.megabyte(x, return_loss=return_loss)
        return x
    
    def get_params(self):
        o = [param.numel() for param in self.parameters() if param.requires_grad]
        o = sum(o)
        return o
    
    def generate(self, *args):
        return self.megabyte.generate(*args)
    
def multi_encode(
        phone_tokens,
        audio_tokens,
        n_phone_tokens,
        n_audio_tokens,
        max_clip_length=1.0):
    """NOTE: 75 steps per second for 24kHz in `encodec.
    Set `max_clip_length` to 0 for original clip length."""

    # Start text token, end text token, start audio token, end audio token
    ETT, EAT = [n_phone_tokens + n_audio_tokens + i
                          for i in range(2)]
    ETT = torch.tensor([ETT]).long().cuda()
    EAT = torch.tensor([EAT]).long().cuda()

    if max_clip_length:
        #print("pre audio_tokens.shape", audio_tokens.shape)
        audio_tokens = audio_tokens[:, :, :int(max_clip_length * 75)]
    audio_tokens = rearrange(audio_tokens.squeeze(0), "q s -> (q s)")
    #print("post audio_tokens.shape", audio_tokens.shape)
    
    # offset phone tokens past audio tokens
    phone_tokens += n_audio_tokens
    
    #print("phone_tokens.shape:", phone_tokens.shape)
    #print("audio_tokens.shape:", audio_tokens.shape)
    
    device = torch.cuda.current_device()
    phone_tokens = torch.cat((phone_tokens, ETT), dim=0).to(device)
    audio_tokens = torch.cat((audio_tokens, EAT,), dim=0).to(device)
    combined_tokens = torch.cat((phone_tokens, audio_tokens), dim=0).to(device)
    return phone_tokens, audio_tokens, combined_tokens

In [8]:
from einops import rearrange

from encodec_util import decode_to_file

"""
EinopsError:  Error while processing rearrange-reduction pattern "(t q) -> t q".
 Input tensor shape: torch.Size([75]). Additional info: {'q': 4, 't': 75}.
 Shape mismatch, 75 != 300
"""

def generate_audio(sample,
                   n_phone_tokens,
                   n_audio_tokens,
                   audio_path="./out.wav"):
    ETT, EAT = [n_phone_tokens + n_audio_tokens + i
                          for i in range(2)]
    ST_S = [ETT, EAT]
    print("ETT, EAT ids:", ST_S)
    seq = sample.cpu().tolist()[0]
    print("seq:", seq)
    # all special tokens in list
    if all(st_t in seq for st_t in ST_S) and len(seq) >= len(ST_S) + 2:
        # text_tokens  = seq[seq.index(STT + 1):seq.index(ETT - 1)]
        audio_tokens = seq[seq.index(ETT)+1:seq.index(EAT)]
        print(seq.index(ETT), seq.index(EAT), len(audio_tokens))
        audio_tokens = torch.tensor(audio_tokens).cuda()
        audio_tokens = rearrange(
            audio_tokens,
            '(t q) -> t q',
            q=1, # CODEBOOK,
            t=audio_tokens.size(0) // 1) # t=audio_tokens.size(0) // CODEBOOK)
        print("audio_tokens.shape:", audio_tokens, audio_tokens.shape)
        decode_to_file(audio_tokens, audio_path)
        return True
    else:
        return False

## PhoneLM - LJSpeech (Overfit Multi)

### Train

In [9]:
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PhoneLM(
    n_phone_tokens=len(dataset.phone_dict),
    n_audio_tokens=1024).to(device)

model.megabyte.get_num_params()

number of parameters: 37.30M


37302863

In [10]:
item = next(iter(train_loader))
# item_phone_tokens = item[-2]
# # item_audio_tokens = item[-1]
# item_audio_tokens = item[-1][:, 0, :] # Only keep primary coarse tokens, for now
# item_audio_tokens = item_audio_tokens.unsqueeze(0)
# item_phone_tokens.shape, item_audio_tokens.shape

In [11]:
len(item)

1

In [12]:
item[0][3]

'to taking the descriptions of newly-arrived prisoners.'

In [13]:
# item

In [14]:
# phone_prompt, audio_target, test_inp = multi_encode(
#     item_phone_tokens,
#     item_audio_tokens,
#     n_phone_tokens=len(dataset.phone_dict),
#     n_audio_tokens=1024,
#     max_clip_length=5)
# test_inp.shape

### Training Process

In [15]:
from tqdm.notebook import tqdm

In [16]:
import torch.optim as optim

epochs = 10

MAX_LR       = 1e-2
# MAX_LR       = 1e-2
WEIGHT_DECAY = 1e-4
GRAD_CLIP    = 0.1

optimizer = optim.Adam(
    model.parameters(),
    lr=MAX_LR)
    #,weight_decay=WEIGHT_DECAY)

# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']

# sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, MAX_LR, epochs=epochs, 
#                                                 steps_per_epoch=len(trainloader))

In [17]:
len(item)

1

In [19]:
import torch.nn.functional as F

EPOCHS = 1000
PRINT_INTERVAL = 100

seq_len = 512 # 2048

# phone_prompt, audio_target, test_inp = multi_encode(
#     item[0][6], # phone tokens
#     item[0][7], # audio tokens
#     n_phone_tokens=len(dataset.phone_dict),
#     n_audio_tokens=1024,
#     max_clip_length=1)

# model = PhoneLM(
#     n_phone_tokens=len(dataset.phone_dict),
#     n_audio_tokens=1024).to(device)

prompt = None

def create_seq(item_phone_tokens, item_audio_tokens):
    global prompt
    phone_prompt, audio_target, test_inp = multi_encode(
        item_phone_tokens,
        item_audio_tokens,
        n_phone_tokens=len(dataset.phone_dict),
        n_audio_tokens=1024,
        max_clip_length=1)
    prompt = phone_prompt
    padding_len = max(0, seq_len - test_inp.size(0))
    n_test_inp  = F.pad(test_inp, (0, padding_len))
    cur_item    = n_test_inp
    # cur_item = n_test_inp.unsqueeze(0)
    return cur_item

def create_batch(batch):
    rnn_batch = []
    for item in batch:
        item_phone_tokens = item[6]
        item_audio_tokens = item[7]
        seq = create_seq(item_phone_tokens, item_audio_tokens)
        rnn_batch.append(seq)
    rnn_batch = nn.utils.rnn.pad_sequence(rnn_batch, batch_first=True)
    return rnn_batch

cur_batch = item
batch = create_batch(cur_batch)

def train(model, trainloader):
    model.train()

    # print("batch:", batch.shape, batch)
    loss = model(batch, return_loss=True)
    loss.backward()
    return loss

# def train(model, trainloader):
#     model.train()
    
#     padding_len = max(0, seq_len - test_inp.size(0))
#     n_test_inp = F.pad(test_inp, (0, padding_len))
#     batch = n_test_inp.unsqueeze(0)
#     print("batch.shape:", batch.shape, batch)
#     loss = model(batch, return_loss=True)
#     # loss = model(next(trainloader), return_loss=True)
#     loss.backward()
#     return loss

# pbar = tqdm.tqdm(EPOCHS, mininterval=10., desc='training')
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    loss = train(model, train_loader)
    optimizer.step()
    
    mem_gb = get_reserved_mem_gb()
    if epoch % PRINT_INTERVAL == 0:
        print(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}")
    # pbar.set_description(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [22]:
print(batch.shape)

torch.Size([1, 512])


### Evaluate

In [23]:
prompt.shape

torch.Size([31])

In [34]:
prompt

tensor([[1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049,
         1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098,
         1069, 1094, 1075, 1084, 1098, 1098, 1099]], device='cuda:0')

In [39]:
def generate(model, inp):
    model.eval()

    # inp = inp.unsqueeze(0)
    sample = model.generate(inp)
    sample = sample.flatten(1)
    print("sample:", sample, sample.shape)

    return prompt, sample

prompt, sample = generate(model, prompt)

FUCKING HELLO?
prompt: torch.Size([1, 31])


100%|██████████| 481/481 [00:06<00:00, 74.27it/s]

sample: tensor([[1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049,
         1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098,
         1069, 1094, 1075, 1084, 1098, 1098, 1099,  835,  160,  438,  488,  887,
          203,  503,  441,    6,   81,  727,  141,  908,  908,  502,  303,  148,
          103,  496,  145,  731,  977,  259,  582,  808,  921,  432,  779,  779,
          472,  472,  331,  103,  887,  457,  987,  501,  921,  197,  197,  931,
          928,  881,  834,  432,  604,  491,  373,  994,  782,  834,  408,  855,
          855,  798,  176,  798,  537,  779,  936,  457,  751,  651,  687,  751,
          790,  686,  994,   57,  916,  751,  699,  145,  145,  148, 1010,    4,
          973,  984, 1002,    3,  472,  185,  662,  662,  471,  471,  106,  734,
          893,  802,  285, 1010,  812,  272,  529,  930,  486,  323,  632,  466,
          930,  601,  646,  924,  913,  160, 1010,  857,  984,  668,  399,  399,
          399,  710,




In [40]:
sample.shape

torch.Size([1, 512])

In [41]:
ETT, EAT = [len(dataset.phone_dict) + 1024 + i
                          for i in range(2)]
print(ETT, EAT)
# sample.index(STT)

1099 1100


In [42]:
out = generate_audio(
    sample,
    n_phone_tokens=len(dataset.phone_dict),
    n_audio_tokens=1024)

ETT, EAT ids: [1099, 1100]
seq: [1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049, 1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098, 1069, 1094, 1075, 1084, 1098, 1098, 1099, 835, 160, 438, 488, 887, 203, 503, 441, 6, 81, 727, 141, 908, 908, 502, 303, 148, 103, 496, 145, 731, 977, 259, 582, 808, 921, 432, 779, 779, 472, 472, 331, 103, 887, 457, 987, 501, 921, 197, 197, 931, 928, 881, 834, 432, 604, 491, 373, 994, 782, 834, 408, 855, 855, 798, 176, 798, 537, 779, 936, 457, 751, 651, 687, 751, 790, 686, 994, 57, 916, 751, 699, 145, 145, 148, 1010, 4, 973, 984, 1002, 3, 472, 185, 662, 662, 471, 471, 106, 734, 893, 802, 285, 1010, 812, 272, 529, 930, 486, 323, 632, 466, 930, 601, 646, 924, 913, 160, 1010, 857, 984, 668, 399, 399, 399, 710, 27, 710, 710, 564, 765, 888, 770, 404, 405, 420, 708, 752, 928, 71, 216, 43, 471, 801, 404, 112, 214, 48, 425, 1011, 48, 767, 792, 792, 869, 268, 471, 200, 1010, 857, 857, 970, 457, 212, 188, 212, 54, 90, 573, 

In [43]:
out

True

## PhoneLM - LJSpeech (Generalise)