In [48]:
from datasets import load_dataset, Audio

# Sampling rate (Hz) the "audio" column is cast to; also read by get_target_length.
SAMPLING_RATE = 24_000

# Load LJ Speech, cast its audio column to SAMPLING_RATE, and have items
# returned in "torch" format.
dataset = (
    load_dataset("MikhailT/lj-speech")
    .cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
    .with_format("torch")
)

# Number of examples in the "full" split (displayed as the cell output).
len(dataset["full"])

In [2]:
from transformers import MimiModel, AutoFeatureExtractor

# Pretrained Mimi feature extractor and model, both from the "kyutai/mimi" checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

# Load the model and move it to the GPU in one step.
model = MimiModel.from_pretrained("kyutai/mimi").to("cuda")


  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [69]:
import math

import torch
from torch.nn.utils.rnn import pad_sequence

def get_target_length(arr: torch.Tensor) -> int:
    """Return how many code frames cover the last dimension of ``arr``.

    The frame count is the sample count divided by the number of samples per
    frame (``SAMPLING_RATE / 12.5``), rounded up so a trailing partial frame
    is counted as a whole one.
    """
    samples_per_frame = SAMPLING_RATE / 12.5
    num_samples = arr.size(-1)
    return math.ceil(num_samples / samples_per_frame)

def batch_wav_encoder(batch_dict, num_codebooks: int = 8) -> dict:
    """Encode a batch of waveforms with the global ``model`` and trim the padding.

    Intended for use with ``dataset.map(..., batched=True)``.

    Args:
        batch_dict: Batch mapping whose ``"audio"`` entry is a sequence of
            items, each holding a waveform tensor under ``"array"``
            (assumed 1-D per sample -- TODO confirm against the dataset).
        num_codebooks: Number of leading codebooks to keep from the encoder
            output. Defaults to 8, the value that was previously hard-coded.

    Returns:
        ``{"codes": [...]}`` with one CPU tensor per input sample, sliced to
        ``num_codebooks`` rows and to that sample's own frame count (as given
        by ``get_target_length``), so frames that only cover padding are dropped.
    """
    waveforms = [sample["array"] for sample in batch_dict["audio"]]
    # Per-sample frame counts, computed before padding hides the true lengths.
    target_lengths = [get_target_length(waveform) for waveform in waveforms]

    # Zero-pad to the longest waveform and add a channel axis: (batch, 1, samples).
    padded_batch = pad_sequence(waveforms, batch_first=True).unsqueeze(1)

    # Inference only: disabling autograd avoids keeping activations alive for
    # a backward pass that never happens. Inputs follow the model's device
    # instead of assuming "cuda".
    with torch.no_grad():
        encoder_outputs = model.encode(padded_batch.to(model.device))
    codes = encoder_outputs.audio_codes[:, 0:num_codebooks, :].cpu()

    # Drop the remaining reference to the device-side output, then release
    # cached GPU blocks between batches.
    del encoder_outputs
    torch.cuda.empty_cache()

    # Split along the batch axis and cut each sample back to its own length.
    output = [
        sample_codes[:, :length]
        for sample_codes, length in zip(torch.unbind(codes, dim=0), target_lengths)
    ]

    return {"codes": output}

# Smoke test: run the encoder on the first 16 examples of the "full" split
# and display the code shape of one of them.
first_item = dataset["full"][:16]
foo = batch_wav_encoder(first_item)["codes"]
foo[9].shape

torch.Size([8, 111])

In [71]:
# Encode every example, 24 at a time; this adds a "codes" column to the dataset.
dataset = dataset.map(
    batch_wav_encoder,
    batch_size=24,
    batched=True,
)

Map: 100%|██████████| 13100/13100 [01:53<00:00, 115.83 examples/s]


In [73]:
# The "audio" column is no longer needed once the codes have been computed.
dataset = dataset.remove_columns(["audio"])

In [78]:
# Persist the dataset to the local "encoded_dataset" directory.
dataset.save_to_disk("encoded_dataset")

Saving the dataset (1/1 shards): 100%|██████████| 13100/13100 [00:00<00:00, 287139.97 examples/s]
