In [1]:
#%pip install fairseq
import torch
import fairseq
import torch.nn.functional as F
import numpy as np
import os
from torch.hub import download_url_to_file

2023-04-26 00:30:42 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


This notebook embeds a synthetic dataset of wav files <-> semantic token mappings and rederives the "means" for each token, allowing for later routine k-means inference of "ground truth" audio into semantic prompts. This allows for voice cloning and so giving voice to the dead, destroying consensus reality and creating misinformation, destroying copyright, spreading systemic bias, and other similarly enjoyable ways to spend a Saturday afternoon.

## Load HuBERT-formatted dataset

NOTE: This guide presumes you auto-generated a synthetic dataset of examples using the `create_dataset` notebook earlier.

In [32]:
from fairseq.data import HubertDataset
from bark.generation import SEMANTIC_RATE_HZ, SEMANTIC_PAD_TOKEN

dataset_dir = os.path.join("../datasets/", "en")
manifest = f"{dataset_dir}/manifest.tsv"
pad_list = [SEMANTIC_PAD_TOKEN]
# I'm going to regret this
eos_list = [SEMANTIC_PAD_TOKEN]
paths = [f"{dataset_dir}/labels.txt"]

import logging
logging.basicConfig(level=logging.DEBUG)


dataset = HubertDataset(
    manifest,
    sample_rate=16_000,
    label_paths=paths,
    label_rates=SEMANTIC_RATE_HZ,
    pad_list=pad_list,
    eos_list=eos_list,
)


2023-04-26 02:34:24 | INFO | fairseq.data.audio.hubert_dataset | max_keep=None, min_keep=None, loaded 16, skipped 0 short and 0 long, longest-loaded=160000, shortest-loaded=26454
2023-04-26 02:34:24 | INFO | fairseq.data.audio.hubert_dataset | pad_audio=False, random_crop=False, normalize=False, max_sample_size=9223372036854775807


## Infer features


In [23]:
if not os.path.exists("../models/hubert_base_ls960.pt"):
    # Yes, hard-coding the URL of the model is jank. Too bad!
    # Update this if this changes! https://github.com/facebookresearch/textlesslib/blob/698e6a039375bac0cd5f1b8683beeec5e8f702c0/textless/checkpoint_manager/__init__.py#L20
    download_url_to_file("https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", "../models/hubert_base_ls960.pt")

In [24]:
# Copy-pasted from textlesslib
class HubertFeatureReader(torch.nn.Module):
    def __init__(
        self, checkpoint_path, layer=6, max_chunk=100 * 16_000, lazy_load=False
    ):
        super().__init__()
        # NB: fairseq doesn't support pathlib.Path
        self.checkpoint_path = str(checkpoint_path)
        self.should_normalize = False
        self.lazy_load = lazy_load
        self.model = None
        self.layer = layer
        self.max_chunk = max_chunk
        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
        if not self.lazy_load:
            self.load_checkpoint_()

    @torch.no_grad()  # otherwise some non-leaf nodes appear which breaks serialization
    def load_checkpoint_(self):
        model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [self.checkpoint_path]
        )
        self.model = model[0].eval()
        self.model = self.model.to(self.device)
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

        self.should_normalize = task.cfg.normalize

    @property
    def device(self):
        return self._float_tensor.device

    @property
    def code_hop_size(self) -> int:
        return 320

    @property
    def expected_sample_rate(self) -> int:
        return 16_000

    def forward(self, x):
        if self.lazy_load and self.model is None:
            self.load_checkpoint_()

        return self.get_features(x)

    @torch.inference_mode()
    def get_features(self, x):
        x = x.to(self.device)
        if self.should_normalize:
            x = F.layer_norm(x, x.shape)
        x = x.view(1, -1)

        feat = []
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start : start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0).cpu()

In [33]:
print(dataset.audio_root)

/home/ritsuko/projects/ai/audio/bark/datasets/en


In [79]:
from torch.utils.data import DataLoader, BatchSampler, RandomSampler

class CustomBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, collater, drop_last):
        super().__init__(sampler, batch_size, drop_last)
        self.collater = collater

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                sources = [self.sampler.data_source[i] for i in batch]
                # Ungodly hack, fix this later
                sources = list(map(
                    lambda s: {
                        **s, 
                        "label_list": [torch.tensor(list(map(int, s["label_list"][0].split(" "))))]
                    }, 
                    sources
                    ))
                print(sources[0])
                yield self.collater(sources)
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield self.collater([self.sampler.data_source[i] for i in batch])

In [80]:
from torch.utils.data import DataLoader, RandomSampler

batch_size = 8

sampler = RandomSampler(dataset)
custom_batch_sampler = CustomBatchSampler(sampler, batch_size, dataset.collater, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=custom_batch_sampler)
for batch_idx, items in enumerate(dataloader):
    print(items)
    break




{'id': 15, 'source': tensor([-3.6678e-03, -5.1473e-03, -4.3893e-03,  ...,  9.4807e-04,
         4.8220e-05, -7.6926e-04]), 'label_list': [tensor([9863, 8840, 8840, 4262,  206,  206,  186, 2371, 2371,   10, 2371,  206,
        2371,  186,  147,  239,   10, 9400, 3511, 2626, 2000,  560, 7551, 1369,
         302,   10,  608, 1134, 7509, 6670, 9095,  366,   50,   10,    5, 1041,
        6107, 1026, 1026, 4924,  648,  171,  429,  402,   41,   38,   38, 5505,
        5760, 6245, 6245, 4738, 4738, 4683, 3221,  232,   10,   27, 3621,  181,
         107,  417, 7352, 7439, 5337,  117,   17, 4608, 4077, 3376,   45,   10,
          41,  429, 1463, 3840, 3890, 9623, 6428,  122,   59,   28,   28,   28,
        5264, 8045, 1291, 3745, 2441,  166, 3363,   10,    5,    5, 2837, 2651,
        4158, 6444,   17,   17, 4483, 2947, 9285,  555, 9285,  555, 8870, 2682,
        2682,  259, 2682, 1487,   26,   26,   26,   26,  175,  736, 8665, 8844,
        4213, 5128, 5750,  273, 7692, 5265,  117,   41,   10, 

TypeError: list indices must be integers or slices, not str