# PhoneLM

## Test `G2P` and `Encodec`

In [None]:
!pip install g2p_en encodec

### `G2P`

In [2]:
from g2p_en import G2p

In [3]:
import torch
import random
import string
from functools import cache
from tqdm import tqdm

@cache
def _get_model():
    return G2p()

@cache
def _get_graphs(path):
    with open(path, "r") as f:
        graphs = f.read()
    return graphs

def encode(graphs: str) -> list[str]:
    g2p = _get_model()
    phones = g2p(graphs)
    ignored = {" ", *string.punctuation}
    return ["_" if p in ignored else p for p in phones]

@torch.no_grad()
def write_phones(folder, suffix=".normalized.txt"):
    print("ello?")
    paths = list(folder.rglob(f"*{suffix}"))
    random.shuffle(paths)

    print("paths:", paths)
    for path in tqdm(paths):
        phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
        if phone_path.exists():
            continue
        print("?")
        graphs = _get_graphs(path)
        phones = encode(graphs)
        with open(phone_path, "w") as f:
            f.write(" ".join(phones))

In [4]:
from pathlib import Path
write_phones(Path("./data/text"))

ello?
paths: [WindowsPath('data/text/test.normalized.txt')]


100%|██████████| 1/1 [00:00<?, ?it/s]


### `Encodec`

In [36]:
from tqdm import tqdm
import random
import torch
from functools import cache
import torchaudio
from encodec import EncodecModel
from torch import Tensor
from einops import rearrange
import soundfile
from encodec.utils import convert_audio

SAMPLE_RATE = 24_000
BANDWIDTHS  = [1.5, 3.0, 6.0, 12.0, 24.0]
BANDWIDTH   = BANDWIDTHS[0]

@cache
def _load_model(bandwidth=6.0, device="cuda"):
    # Instantiate a pretrained EnCodec model
    assert SAMPLE_RATE == 24_000
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(bandwidth)
    model.to(device)
    return model

def unload_model():
    return _load_model.cache_clear()

@torch.inference_mode()
def decode(codes: Tensor, bandwidth=6.0, device="cuda"):
    """
    Args:
        codes: (b q t)
    """
    assert codes.dim() == 3
    model = _load_model(bandwidth, device)
    return model.decode([(codes, None)]), model.sample_rate

def decode_to_file(resps: Tensor, path: Path):
    assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
    resps = rearrange(resps, "t q -> 1 q t")
    wavs, sr = decode(codes=resps, bandwidth=BANDWIDTH)
    soundfile.write(str(path), wavs.cpu()[0, 0], sr)

def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

@torch.inference_mode()
def encode(wav: Tensor, sr: int, bandwidth=6.0, device="cuda"):
    """
    Args:
        wav: (t)
        sr: int
    """
    model = _load_model(bandwidth, device)
    wav = wav.unsqueeze(0)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
    encoded_frames = model.encode(wav)
    qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # (b q t)
    return qnt

def encode_from_file(path, bandwidth=6.0, device="cuda"):
    wav, sr = torchaudio.load(str(path))
    if wav.shape[0] == 2:
        wav = wav[:1]
    return encode(wav, sr, bandwidth, device)

def quantize_audio(folder, suffix=".wav"):
    paths = [*folder.rglob(f"*{suffix}")]
    random.shuffle(paths)

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qnt.pt")
        if out_path.exists():
            continue
        qnt = encode_from_file(path, BANDWIDTH)
        print(qnt.shape)
        torch.save(qnt.cpu(), out_path)

def decode_files(folder, suffix=".qnt.pt"):
    paths = [*folder.rglob(f"*{suffix}")]
    random.shuffle(paths)

    for path in tqdm(paths):
        out_path = _replace_file_extension(path, ".qt.wav")
        if out_path.exists():
            continue
        fi = rearrange(torch.load(path).squeeze(0).cuda(), "q t -> t q")
        decode_to_file(fi, out_path)

In [37]:
from pathlib import Path
quantize_audio(Path("./data/audio"))
decode_files(Path("./data/audio"))

100%|██████████| 1/1 [00:00<00:00,  2.86it/s]


torch.Size([1, 2, 128])


100%|██████████| 1/1 [00:00<00:00, 58.82it/s]


In [27]:
torch.load("data/audio/test.qnt.pt").shape

torch.Size([1, 2, 128])

## Dataset

### LJSpeech

In [1]:
BANDWIDTH_IDX = 0
CODEBOOKS     = [2, 4, 8, 16, 32]
BANDWIDTHS    = [1.5, 3.0, 6.0, 12.0, 24.0]
BANDWIDTH     = BANDWIDTHS[BANDWIDTH_IDX]
CODEBOOK      = CODEBOOKS[BANDWIDTH_IDX]

import torchaudio
from ljspeech import LJSPEECH
DATASET_PATH = "./data/LJSpeech/"
dataset = LJSPEECH(
    "./data/LJSpeech",
    encodec_bandwidth=BANDWIDTH)

In [2]:
len(dataset)

13100

In [3]:
dataset[0][-1].shape

torch.Size([1, 2, 725])

In [4]:
import torch
import torchaudio
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split

indices = list(range(len(dataset)))
train_indices, test_indices = train_test_split(indices, test_size=0.1, random_state=42)

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler, collate_fn=lambda x: x)
test_loader = DataLoader(dataset, batch_size=32, sampler=test_sampler, collate_fn=lambda x: x)

In [5]:
len(train_loader), len(test_loader)

(369, 41)

In [6]:
item = next(iter(train_loader))

In [None]:
item

## Model

In [8]:
import megabyte
import torch
import torch.nn as nn
from einops import rearrange

def get_reserved_mem_gb():
    device = torch.cuda.current_device()
    reserved = torch.cuda.memory_reserved(device)
    reserved_gb = reserved / 1024 / 1024 / 1024
    return reserved_gb

class PhoneLM(nn.Module):
    def __init__(self, n_phone_tokens, n_audio_tokens):
        super(PhoneLM, self).__init__()
        self.megabyte   = megabyte.MEGABYTE(
            heads       = 8, # 1,
            dim_head    = 32, # 16,
            num_tokens  = n_phone_tokens + n_audio_tokens + 4,
            dim         = (768, 256, 128), # (32, 32, 32), # (768, 256, 128)# Dg, Dl1, Dl2
            depth       = (6, 4, 2), # (6, 4, 2)
            max_seq_len = (32, 4, 4),
            flash_attn  = False)

    def forward(self, x, debug=False, return_loss=True):
        x = self.megabyte(x, return_loss=return_loss)
        return x
    
    def get_params(self):
        o = [param.numel() for param in self.parameters() if param.requires_grad]
        o = sum(o)
        return o
    
    def generate(self, *args):
        return self.megabyte.generate(*args)
    
def multi_encode(
        phone_tokens,
        audio_tokens,
        n_phone_tokens,
        n_audio_tokens,
        max_clip_length=1.0):
    """NOTE: 75 steps per second for 24kHz in `encodec.
    Set `max_clip_length` to 0 for original clip length."""

    # Start text token, end text token, start audio token, end audio token
    STT, ETT, SAT, EAT = [n_phone_tokens + n_audio_tokens + i
                          for i in range(4)]
    STT = torch.tensor([STT]).long().cuda()
    ETT = torch.tensor([ETT]).long().cuda()
    SAT = torch.tensor([SAT]).long().cuda()
    EAT = torch.tensor([EAT]).long().cuda()

    if max_clip_length:
        audio_tokens = audio_tokens[:, :, :int(max_clip_length * 75)]
    audio_tokens = rearrange(audio_tokens.squeeze(0), "q s -> (q s)")
    
    # offset phone tokens past audio tokens
    phone_tokens += n_audio_tokens
    
    print("phone_tokens.shape:", phone_tokens.shape)
    print("audio_tokens.shape:", audio_tokens.shape)
    
    device = torch.cuda.current_device()
    phone_tokens = torch.cat((STT, phone_tokens, ETT), dim=0).to(device)
    audio_tokens = torch.cat((SAT, audio_tokens, EAT,), dim=0).to(device)
    combined_tokens = torch.cat((phone_tokens, audio_tokens), dim=0).to(device)
    return phone_tokens, audio_tokens, combined_tokens

In [24]:
from einops import rearrange

from encodec_util import decode_to_file

def generate_audio(sample,
                   n_phone_tokens,
                   n_audio_tokens,
                   audio_path="./out.wav"):
    STT, ETT, SAT, EAT = [n_phone_tokens + n_audio_tokens + i
                          for i in range(4)]
    ST_S = [STT, ETT, SAT, EAT]
    print("STT, ETT, SAT, EAT ids:", ST_S)
    seq = sample.cpu().tolist()[0]
    print("seq:", seq)
    # all special tokens in list
    if all(st_t in seq for st_t in ST_S) and len(seq) >= len(ST_S) + 2:
        # text_tokens  = seq[seq.index(STT + 1):seq.index(ETT - 1)]
        audio_tokens = seq[seq.index(SAT)+1:seq.index(EAT)]
        print(seq.index(SAT), seq.index(EAT), len(audio_tokens))
        audio_tokens = torch.tensor(audio_tokens).cuda()
        audio_tokens = rearrange(
            audio_tokens,
            '(t q) -> t q',
            q=CODEBOOK,
            t=audio_tokens.size(0) // CODEBOOK)
        print("audio_tokens.shape:", audio_tokens, audio_tokens.shape)
        decode_to_file(audio_tokens, audio_path)
        return True
    else:
        return False

## PhoneLM - LJSpeech

### Train

In [10]:
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PhoneLM(
    n_phone_tokens=len(dataset.phone_dict),
    n_audio_tokens=1024).to(device)

model.megabyte.get_num_params()

number of parameters: 37.30M


37302863

In [11]:
item = next(iter(train_loader))[0]
item_phone_tokens = item[-2]
item_audio_tokens = item[-1]
item_phone_tokens.shape, item_audio_tokens.shape

(torch.Size([121]), torch.Size([1, 2, 660]))

In [27]:
item

(WindowsPath('data/LJSpeech/LJSpeech-1.1/wavs/LJ016-0073.wav'),
 tensor([[-9.1553e-05,  0.0000e+00,  9.1553e-05,  ..., -6.1035e-05,
          -5.4932e-04, -2.4414e-04]]),
 22050,
 'Mr. Cope, the governor of Newgate, having been communicated with, proceeded to Winchester, where he at once identified Williams.',
 'Mr. Cope, the governor of Newgate, having been communicated with, proceeded to Winchester, where he at once identified Williams.',
 ['M',
  'IH1',
  'S',
  'T',
  'ER0',
  '_',
  '_',
  '_',
  'K',
  'OW1',
  'P',
  '_',
  '_',
  '_',
  'DH',
  'AH0',
  '_',
  'G',
  'AH1',
  'V',
  'ER0',
  'N',
  'ER0',
  '_',
  'AH1',
  'V',
  '_',
  'N',
  'UW1',
  'G',
  'EY0',
  'T',
  '_',
  '_',
  '_',
  'HH',
  'AE1',
  'V',
  'IH0',
  'NG',
  '_',
  'B',
  'IH1',
  'N',
  '_',
  'K',
  'AH0',
  'M',
  'Y',
  'UW1',
  'N',
  'AH0',
  'K',
  'EY2',
  'T',
  'IH0',
  'D',
  '_',
  'W',
  'IH1',
  'DH',
  '_',
  '_',
  '_',
  'P',
  'R',
  'AH0',
  'S',
  'IY1',
  'D',
  'AH0',
  'D',
  '

In [12]:
phone_prompt, audio_target, test_inp = multi_encode(
    item_phone_tokens,
    item_audio_tokens,
    n_phone_tokens=len(dataset.phone_dict),
    n_audio_tokens=1024)
test_inp.shape

phone_tokens.shape: torch.Size([121])
audio_tokens.shape: torch.Size([150])


torch.Size([275])

### Training Process

In [13]:
from tqdm.notebook import tqdm

In [14]:
import torch.optim as optim

epochs = 10

MAX_LR       = 1e-2
# MAX_LR       = 1e-2
WEIGHT_DECAY = 1e-4
GRAD_CLIP    = 0.1

optimizer = optim.Adam(
    model.parameters(),
    lr=MAX_LR)
    #,weight_decay=WEIGHT_DECAY)

# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']

# sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, MAX_LR, epochs=epochs, 
#                                                 steps_per_epoch=len(trainloader))

In [15]:
test_inp.dtype

torch.int64

In [16]:
test_inp

tensor([1099, 1071, 1063, 1082, 1084, 1053, 1098, 1098, 1098, 1069, 1075, 1080,
        1098, 1098, 1098, 1049, 1034, 1098, 1060, 1035, 1093, 1053, 1072, 1053,
        1098, 1035, 1093, 1098, 1072, 1091, 1060, 1056, 1084, 1098, 1098, 1098,
        1061, 1032, 1093, 1062, 1073, 1098, 1046, 1063, 1072, 1098, 1069, 1034,
        1071, 1095, 1091, 1072, 1034, 1069, 1058, 1084, 1062, 1048, 1098, 1094,
        1063, 1049, 1098, 1098, 1098, 1080, 1081, 1034, 1082, 1066, 1048, 1034,
        1048, 1098, 1084, 1091, 1098, 1094, 1063, 1072, 1047, 1052, 1082, 1084,
        1053, 1098, 1098, 1098, 1094, 1051, 1081, 1098, 1061, 1066, 1098, 1032,
        1084, 1098, 1094, 1035, 1072, 1082, 1098, 1043, 1048, 1051, 1072, 1084,
        1034, 1059, 1045, 1048, 1098, 1094, 1063, 1070, 1095, 1034, 1071, 1096,
        1098, 1098, 1100, 1101,  408,  976,  860,  388,  540,  901,  373,  574,
         574,  574,   47,  574,  148,  738,  339,  176,  254,  103,  862,  958,
         612, 1011,  472,  475,  475,  7

In [17]:
test_inp.dtype

torch.int64

In [18]:
test_inp.shape

torch.Size([275])

In [19]:
import torch.nn.functional as F

EPOCHS = 1000
PRINT_INTERVAL = 100

seq_len = 512

def train(model, trainloader):
    model.train()
    
    padding_len = max(0, seq_len - test_inp.size(0))
    n_test_inp = F.pad(test_inp, (0, padding_len))
    batch = n_test_inp.unsqueeze(0)
    # print(batch.shape)
    loss = model(batch, return_loss=True)
    # loss = model(next(trainloader), return_loss=True)
    loss.backward()
    return loss

# pbar = tqdm.tqdm(EPOCHS, mininterval=10., desc='training')
for epoch in range(EPOCHS):
    loss = train(model, train_loader)
    optimizer.step()
    optimizer.zero_grad()
    mem_gb = get_reserved_mem_gb()
    if epoch % PRINT_INTERVAL == 0:
        print(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}")
    #' pbar.set_description(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}")

Reserved Memory (GB): 0.99609375, loss: 7.006810188293457


Reserved Memory (GB): 1.017578125, loss: 4.011221408843994
Reserved Memory (GB): 1.017578125, loss: 0.15099826455116272
Reserved Memory (GB): 1.017578125, loss: 0.04681147634983063
Reserved Memory (GB): 1.017578125, loss: 0.016797270625829697
Reserved Memory (GB): 1.017578125, loss: 0.016020676121115685
Reserved Memory (GB): 1.017578125, loss: 0.016087962314486504
Reserved Memory (GB): 1.017578125, loss: 0.01562683843076229
Reserved Memory (GB): 1.017578125, loss: 0.015517176128923893
Reserved Memory (GB): 1.017578125, loss: 0.015402528457343578


### Evaluate

In [20]:
phone_prompt

tensor([1099, 1071, 1063, 1082, 1084, 1053, 1098, 1098, 1098, 1069, 1075, 1080,
        1098, 1098, 1098, 1049, 1034, 1098, 1060, 1035, 1093, 1053, 1072, 1053,
        1098, 1035, 1093, 1098, 1072, 1091, 1060, 1056, 1084, 1098, 1098, 1098,
        1061, 1032, 1093, 1062, 1073, 1098, 1046, 1063, 1072, 1098, 1069, 1034,
        1071, 1095, 1091, 1072, 1034, 1069, 1058, 1084, 1062, 1048, 1098, 1094,
        1063, 1049, 1098, 1098, 1098, 1080, 1081, 1034, 1082, 1066, 1048, 1034,
        1048, 1098, 1084, 1091, 1098, 1094, 1063, 1072, 1047, 1052, 1082, 1084,
        1053, 1098, 1098, 1098, 1094, 1051, 1081, 1098, 1061, 1066, 1098, 1032,
        1084, 1098, 1094, 1035, 1072, 1082, 1098, 1043, 1048, 1051, 1072, 1084,
        1034, 1059, 1045, 1048, 1098, 1094, 1063, 1070, 1095, 1034, 1071, 1096,
        1098, 1098, 1100], device='cuda:0')

In [21]:
audio_target

tensor([1101,  408,  976,  860,  388,  540,  901,  373,  574,  574,  574,   47,
         574,  148,  738,  339,  176,  254,  103,  862,  958,  612, 1011,  472,
         475,  475,  779,  855,  835,  835,  106,  405,  213, 1014,  798,  537,
         887,  575,  575,  504,  288,  755,  259,  837,  291,  808,  942,  921,
         291,  155,  523,   52,   52,  370,   52,  106,  370,  257,  257,  904,
         537,  395,  408,  404,  106,  855,  475,  738,  738,  738,  408,  738,
         106,  408,  408,  408,  913,  877,  252,  418,  792,  674,  991,  160,
        1010,  214,  209,  765,  652,  652,  646,  870, 1010,  420,  624,  486,
         948,  948,  419,  200,  580,  913,  700,  913,  424,  544,  632,  872,
         292,  863,  729,  384,  146,  563,  889,  499,  599,  751,  684,  668,
         283,  379,   30,  593,  128,  230,  687,  984,  984,  928,  913,  924,
         942,  913,  969,  387,  921,  928, 1007, 1007, 1007, 1007,  544,  913,
         518,  424,  518,  913,  424,  4

In [22]:
def generate(model, prompt):
    model.eval()

    prompt = prompt.unsqueeze(0)
    sample = model.generate(prompt)
    sample = sample.flatten(1)
    print("sample:", sample, sample.shape)

    return prompt, sample

prompt, sample = generate(model, phone_prompt)

100%|██████████| 389/389 [00:05<00:00, 75.96it/s]

sample: tensor([[1099, 1071, 1063, 1082, 1084, 1053, 1098, 1098, 1098, 1069, 1075, 1080,
         1098, 1098, 1098, 1049, 1034, 1098, 1060, 1035, 1093, 1053, 1072, 1053,
         1098, 1035, 1093, 1098, 1072, 1091, 1060, 1056, 1084, 1098, 1098, 1098,
         1061, 1032, 1093, 1062, 1073, 1098, 1046, 1063, 1072, 1098, 1069, 1034,
         1071, 1095, 1091, 1072, 1034, 1069, 1058, 1084, 1062, 1048, 1098, 1094,
         1063, 1049, 1098, 1098, 1098, 1080, 1081, 1034, 1082, 1066, 1048, 1034,
         1048, 1098, 1084, 1091, 1098, 1094, 1063, 1072, 1047, 1052, 1082, 1084,
         1053, 1098, 1098, 1098, 1094, 1051, 1081, 1098, 1061, 1066, 1098, 1032,
         1084, 1098, 1094, 1035, 1072, 1082, 1098, 1043, 1048, 1051, 1072, 1084,
         1034, 1059, 1045, 1048, 1098, 1094, 1063, 1070, 1095, 1034, 1071, 1096,
         1098, 1098, 1100, 1101,  408,  976,  860,  388,  540,  901,  373,  574,
          574,  574,   47,  574,  148,  738,  339,  176,  254,  103,  862,  958,
          612, 1011,




In [25]:
out = generate_audio(
    sample,
    n_phone_tokens=len(dataset.phone_dict),
    n_audio_tokens=1024)

STT, ETT, SAT, EAT ids: [1099, 1100, 1101, 1102]
seq: [1099, 1071, 1063, 1082, 1084, 1053, 1098, 1098, 1098, 1069, 1075, 1080, 1098, 1098, 1098, 1049, 1034, 1098, 1060, 1035, 1093, 1053, 1072, 1053, 1098, 1035, 1093, 1098, 1072, 1091, 1060, 1056, 1084, 1098, 1098, 1098, 1061, 1032, 1093, 1062, 1073, 1098, 1046, 1063, 1072, 1098, 1069, 1034, 1071, 1095, 1091, 1072, 1034, 1069, 1058, 1084, 1062, 1048, 1098, 1094, 1063, 1049, 1098, 1098, 1098, 1080, 1081, 1034, 1082, 1066, 1048, 1034, 1048, 1098, 1084, 1091, 1098, 1094, 1063, 1072, 1047, 1052, 1082, 1084, 1053, 1098, 1098, 1098, 1094, 1051, 1081, 1098, 1061, 1066, 1098, 1032, 1084, 1098, 1094, 1035, 1072, 1082, 1098, 1043, 1048, 1051, 1072, 1084, 1034, 1059, 1045, 1048, 1098, 1094, 1063, 1070, 1095, 1034, 1071, 1096, 1098, 1098, 1100, 1101, 408, 976, 860, 388, 540, 901, 373, 574, 574, 574, 47, 574, 148, 738, 339, 176, 254, 103, 862, 958, 612, 1011, 472, 475, 475, 779, 855, 835, 835, 106, 405, 213, 1014, 798, 537, 887, 575, 575, 504, 288, 

In [26]:
out

True