In [None]:
# Automatically re-import edited Python modules before each cell executes,
# so changes to project .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import project_root

## Pre-requisites

You need to download a single batch of the dataset

```
# The following downloads roughly 50-100 GB of data
# to the directory:
# `~/datasets/seamless_interaction/improvised/dev/0000/`
download_single_batch()
```

into the `seamless-interaction` submodule directory for this notebook to work.

In [None]:
import torch
import numpy as np
from transformers import MimiModel, AutoFeatureExtractor

from IPython.display import Audio, display


In [None]:
#
from seamless_interaction.fs import DatasetConfig, SeamlessInteractionFS
from dataset_utils.duplex import get_data_sample

In [None]:
# Dataset view: the "improvised" dev split, read with 8 workers.
config = DatasetConfig(
    label="improvised",
    split="dev",
    num_workers=8,
)
fs = SeamlessInteractionFS(config=config)

# File IDs of one interaction (shared V00_S0696_I00000375 prefix, different P-suffix),
# presumably the two sides of the conversation.
# NOTE(review): only the first ID ends in a trailing 'A' -- confirm both exist in the batch.
pair = [
    'V00_S0696_I00000375_P0844A',
    'V00_S0696_I00000375_P0847',
]

In [None]:
# Fetch one data sample per file ID in `pair` (first ID -> sample1, second -> sample2).
first_id, second_id = pair
sample1 = get_data_sample(fs, first_id)
sample2 = get_data_sample(fs, second_id)

## 1. Model Initialization

In [None]:
import warnings

# NOTE(review): blanket filter -- silences *every* Python warning for the rest
# of the kernel session, not only noise from model loading.
warnings.filterwarnings('ignore')

# Use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mimi codec and its matching feature extractor, both from the "kyutai/mimi" repo.
# num_quantizers=8 overrides the checkpoint config's value (presumably the number
# of RVQ codebooks used when encoding -- confirm).
model = MimiModel.from_pretrained("kyutai/mimi", num_quantizers=8)
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

# Only move the model when the GPU was selected.
on_gpu = device.type == "cuda"
if on_gpu:
    model.to(device)


In [None]:
# Display the sample rate (samples/second) the Mimi feature extractor expects;
# the audio below is requested at, and played back at, this same rate.
feature_extractor.sampling_rate

## 2. Encoding and Decoding

In [None]:
# Single-channel waveform of the first sample at Mimi's sample rate.
# mask_transcript=True is forwarded as-is (its exact effect lives in dataset_utils -- confirm).
audio_arr = sample1.get_single_channel_audio(
    mask_transcript=True,
    rate=feature_extractor.sampling_rate,
)

In [None]:
# Turn the raw waveform into PyTorch tensors and place them on the model's device.
inputs = feature_extractor(
    raw_audio=audio_arr,
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
).to(model.device)

In [None]:
# Round-trip the waveform through the codec: audio -> discrete codes -> audio.
# Pure inference, so run under no_grad to avoid building an autograd graph through
# the whole encoder/decoder. no_grad (not inference_mode) keeps the outputs as
# ordinary tensors that later cells can reuse outside any special mode.
with torch.no_grad():
    encoder_outputs = model.encode(inputs["input_values"])
    # decode(...) returns an output object; [0] takes its first field
    # (presumably the audio_values waveform tensor).
    audio_values = model.decode(encoder_outputs.audio_codes)[0]

In [None]:
# Convert the decoded waveform to a NumPy array for playback and comparison.
# .detach() is the supported way to drop autograd tracking (the legacy .data
# bypasses autograd's safety checks). squeeze() drops every size-1 dimension
# (presumably batch and channel, leaving a 1-D waveform -- confirm). .cpu() is
# required before .numpy() when the model runs on CUDA.
reconstructed_audio = audio_values.detach().squeeze().cpu().numpy()

In [None]:
# Play the codec reconstruction first, then the original, both at Mimi's rate.
for caption, waveform in (
    ('Reconstructed audio:', reconstructed_audio),
    ('Original audio:', audio_arr),
):
    print(caption)
    display(Audio(waveform, rate=feature_extractor.sampling_rate))

In [None]:
# Reconstruction error as per-sample RMSE.
# The decoded waveform can be longer than the input (presumably padded up to a
# whole number of codec frames -- confirm). Subtracting arrays of different
# lengths raises a broadcasting error, so only the overlapping samples are compared.
# RMSE (norm / sqrt(N)) replaces norm / N: the latter shrinks with clip length
# even at identical per-sample error, so it is not comparable across clips.
original_audio = np.asarray(audio_arr)
n_common = min(reconstructed_audio.shape[-1], original_audio.shape[-1])
reconstruction_residual = reconstructed_audio[..., :n_common] - original_audio[..., :n_common]
np.sqrt(np.mean(np.square(reconstruction_residual)))

## 3. Extracting Latent Features from Codes

In [None]:
# Map the discrete codes back to continuous quantized latents (one vector per codec frame).
# Mimi's split RVQ `decode` takes codes in the same (batch, n_quantizers, n_frames)
# layout that `model.encode` returns. Unlike EnCodec's quantizer, it must NOT be
# transposed to (n_quantizers, batch, ...); doing so treats each codebook as a
# separate batch item.
codes = encoder_outputs.audio_codes
with torch.no_grad():
    latent = model.quantizer.decode(codes)

In [None]:
# Duplicate the single-item batch (presumably to check that latent decoding
# handles batch > 1).
# Layout stays (bs, n_quantizers, n_frames), which is what model.quantizer.decode
# expects, so no transpose is applied.
batched_codes = torch.cat([codes, codes], dim=0)
with torch.no_grad():
    batched_latents = model.quantizer.decode(batched_codes)

In [None]:
# Inspect the latent shape for the duplicated batch. When the codes are decoded
# in (bs, n_quantizers, n_frames) layout, the leading dimension is expected to
# be 2 (the batch) -- TODO confirm against the printed shape.
batched_latents.shape

In [None]:
# Codes shape for the duplicated batch: (bs=2, n_quantizers, n_frames);
# n_quantizers is presumably 8, matching num_quantizers=8 at load time -- confirm.
batched_codes.shape