In [4]:
import torch
import torchaudio
from transformers import MimiModel, AutoFeatureExtractor
from torchview import draw_graph

In [5]:
audio_path = "./music.mp3"

# Load the audio file
waveform_orig, original_sampling_rate = torchaudio.load(audio_path)
num_samples = waveform_orig.shape[1]
print("sample count: ", num_samples)
print(f"Audio sampled at {original_sampling_rate}S/s. lenght: {num_samples / original_sampling_rate}s")

sample count:  21046753
Audio sampled at 48000S/s. lenght: 438.4740208333333s


In [None]:
# load the model + feature extractor (for pre-processing the audio)
model = MimiModel.from_pretrained("kyutai/mimi")
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

# Define the path to your local wav file
# I'll use a file from your project as an example


sample count:  21046753
Audio sampled at 48000S/s. lenght: 438.4740208333333s


In [None]:
waveform_orig = waveform_orig[0, :].unsqueeze(0)
waveform_orig.shape

In [None]:
sample_seconds = 5
skip_seconds = 15
skip_samples = skip_seconds * original_sampling_rate
waveform = waveform_orig[:, skip_samples:skip_samples+sample_seconds * original_sampling_rate]
waveform.shape

In [None]:
# Resample the audio to the model's required sampling rate
resampler = torchaudio.transforms.Resample(orig_freq=original_sampling_rate, new_freq=feature_extractor.sampling_rate)
resampled_waveform = resampler(waveform)

# The model expects a 1D array, so we squeeze the tensor
audio_sample = resampled_waveform.squeeze().numpy()
audio_sample.shape

# pre-process the inputs
inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")


In [None]:
model_graph = draw_graph(model, input_data=inputs).visual_graph
# Line below triggers an error: "RuntimeError: Failed to run torchgraph"
# model_graph.render(directory='/').replace('\\', '/')




In [None]:
# model_graph

In [None]:

# explicitly encode then decode the audio inputs
encoder_outputs = model.encode(inputs["input_values"], num_quantizers=32)
print("encoder out:", encoder_outputs.audio_codes.shape)
print("encoder out:", encoder_outputs.audio_codes)
audio_values_decoded = model.decode(encoder_outputs.audio_codes)[0]

# or the equivalent with a forward pass
audio_values_forward = model(inputs["input_values"]).audio_values

# You can now save, play, or further process the audio_values
# For example, to save the output as a new wav file:
# torchaudio.save("output_audio.wav", audio_values_forward.cpu(), feature_extractor.sampling_rate)

print("Successfully processed the local audio file.")
print("Shape of the output audio from decoded:", audio_values_decoded.shape)
print("Shape of the output audio from forward pass:", audio_values_forward.shape)

In [None]:
from IPython.display import display, Audio

In [None]:
display(Audio(data=audio_sample, rate = feature_extractor.sampling_rate, autoplay=False))

In [None]:
display(Audio(data=audio_values_decoded.detach().numpy().squeeze(0), rate = feature_extractor.sampling_rate, autoplay=False))

## MIMI torch compatible

In [None]:
from huggingface_hub import hf_hub_download
import torch

from moshi.models import loaders, LMGen

In [None]:

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

In [None]:
wav = audio_sample

## WARNING: When streaming, make sure to always feed a total amount of audio that is a multiple
#           of the frame size (1920). You should pad or buffer accordingly. Since version 0.2.5a, 
#           Mimi no longer supports partial frames in streaming mode. Besides, when executing on GPU,
#           you should always pass the same amount of audio, as the calls are CUDAGraphed for efficiency.

with torch.no_grad():
    codes = mimi.encode(wav)  # [B, K = 8, T]
    decoded = mimi.decode(codes)

    # Supports streaming too.
    frame_size = mimi.frame_size
    all_codes = []
    with mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame)
            assert codes.shape[-1] == 1, codes.shape
            all_codes.append(codes)

In [None]:

# Now if you have a GPU around.
mimi.cuda()
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.
out_wav_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)


# DAC

In [6]:
import dac


In [64]:
(320 * 7.5) / 2400

1.0

In [43]:
waveform_orig = waveform_orig[:, :24000*5]
waveform_orig.shape

torch.Size([2, 120000])

In [44]:
model_path = dac.utils.download(model_type="24khz")
model = dac.DAC.load(model_path)
model.eval()


  WeightNorm.apply(module, name, dim)


DAC(
  (encoder): Encoder(
    (block): Sequential(
      (0): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=(3,))
      (1): EncoderBlock(
        (block): Sequential(
          (0): ResidualUnit(
            (block): Sequential(
              (0): Snake1d()
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))
              (2): Snake1d()
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (1): ResidualUnit(
            (block): Sequential(
              (0): Snake1d()
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
              (2): Snake1d()
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
            )
          )
          (2): ResidualUnit(
            (block): Sequential(
              (0): Snake1d()
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(27,), dilation=(9,))
              (2): Snake1d()
              

In [45]:
resampler = torchaudio.transforms.Resample(
                orig_freq=original_sampling_rate, new_freq=24000
            )
waveform = resampler(waveform_orig)

waveform = torch.mean(waveform, dim=0, keepdim=True)
waveform.shape

torch.Size([1, 60000])

In [46]:
waveform = waveform[:, :320]
waveform.shape

torch.Size([1, 320])

In [48]:
# waveform
# waveform = AudioSignal(waveform.audio_data, waveform.sample_rate)

x = model.preprocess(waveform, 24000).unsqueeze(0)
print(x.shape)
samplesnum = x.shape[2]

z, codes, latents, _, _ = model.encode(x, n_quantizers=16)

torch.Size([1, 1, 320])


In [49]:
codes.shape

torch.Size([1, 16, 1])

In [50]:
frames = codes.shape[2]


In [51]:
samplesnum / frames

320.0

# dataset?

In [43]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import time

In [47]:
class Ds(Dataset):
  def __init__(self):
    print("inited")
    ...
  
  def __len__(self):
    return 100

  def __getitem__(self, idx):
    print(torch.utils.data.get_worker_info().id, flush=True)
    return idx

In [50]:
ds = Ds()
dataloader = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=8,  # Reduced from 8 to 1 to prevent memory issues on GPU server
    persistent_workers=False,  # Don't keep workers alive between epochs
)

inited


In [49]:
for b in range(3):
  for i in dataloader:
    print(i)

23105

4


1
0423
5

162
4




1265
03
7



4
6107
32



65

4
103



7
5324


16
07
53




20


6541
73


0
45
1

27
2

63
4506





3


170236
45



2
1
0


765413
2


0
754





6
0
7
6
0
7

70

0
43162
0
57

4

3
6
12
5
0



430726


1

5
0
347

16



2
305
4
6

41720



3



1
30457

6

2

45
02

137




6
043257





143
7
6
20
34



3167

0
257



0
1

724



20
571



105



05

05

6
6
6
071342


056


4315




042
37
5
6



40

3

456
20
1


76
3

45
7
0

3
416
2



71036
2




475

1032




64

15307




60174

32



7
4

1026
57

1

3
04

2


51
6703


71
2




052
6

5
02
5


0
6
tensor([51, 57,  2, 41])
tensor([56, 99, 95, 68])
tensor([45, 75, 35, 72])
tensor([11,  1,  9, 66])
tensor([10, 26, 90, 36])
tensor([61, 63, 84,  5])
tensor([98, 71, 44, 92])
tensor([27, 85, 19, 58])
tensor([65,  8, 80, 32])
tensor([82,  3, 28, 79])
tensor([38, 62, 25, 12])
tensor([ 7, 59, 93, 54])
tensor([86, 78, 13, 21])
tensor([97, 60, 42, 40])
tensor([ 6, 81, 91, 67])
tensor([18, 88, 52, 24])
