# Let's dive into a spatial sound world.

In [1]:
import os

import numpy as np
import soundfile as sf
from scipy import signal
from IPython.display import Audio

import torch

from dataset.spatial_audio_dataset import format_prompt, SpatialAudioDatasetJsonl

In [2]:
audio_path1 = "./assets/YCqvbWnTBfTk.wav"
audio_path2 = "./assets/Yq4Z8j3IalYs.wav"

print("Let's load and listen to anechoic audio...")
print("Audio 1: Drum; Percussion")
display(Audio(audio_path1))

print("Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren")
display(Audio(audio_path2))

Let's load and listen to anechoic audio...
Audio 1: Drum; Percussion


Audio 2: Emergency vehicle; Fire engine, fire truck (siren); Siren


In [3]:
def reverb_waveform(audio_path, reverb_path):
    waveform, sr = sf.read(audio_path)
    if len(waveform.shape) > 1:
        waveform = waveform[:, 0]
    if sr != 32000: #! Please make sure the audio is 32000 Hz
        waveform = signal.resample_poly(waveform, 32000, sr)
        sr = 32000
    waveform = SpatialAudioDatasetJsonl.normalize_audio(waveform, -14.0).reshape(1, -1)
    reverb = np.load(reverb_path)
    waveform = signal.fftconvolve(waveform, reverb, mode='full')
    return waveform, sr

reverb_path1 = "./assets/74.npy"
reverb_path2 = "./assets/75.npy"
"""
{"fname": "q9vSo1VnCiC/74.npy", "agent_position": "-12.8775,0.0801,8.415", "sensor_position": "-12.8775,1.5801,8.415", "source_position": "-13.4677,1.2183,8.6525",},
{"fname": "q9vSo1VnCiC/75.npy", "agent_position": "-12.8775,0.0801,8.415", "sensor_position": "-12.8775,1.5801,8.415", "source_position": "-11.8976,1.0163,8.9789",}
"""

print("Let's load and listen to reverb audio 1 (w/o mixup)...")
waveform1, _ = reverb_waveform(audio_path1, reverb_path1)
display(Audio(waveform1, rate=32000))

print("Let's load and listen to reverb audio 2 (w/o mixup)...")
waveform2, _ = reverb_waveform(audio_path2, reverb_path2)
display(Audio(waveform2, rate=32000))

Let's load and listen to reverb audio 1 (w/o mixup)...


Let's load and listen to reverb audio 2 (w/o mixup)...


In [4]:
print("Now let's mix them up!")
if waveform1.shape[1] < waveform2.shape[1]:
    waveform2 = waveform2[:, :waveform1.shape[1]]
else:
    waveform1 = waveform1[:, :waveform2.shape[1]]
waveform = (waveform1 + waveform2) / 2
display(Audio(waveform, rate=32000))

Now let's mix them up!


In [5]:
prompts = [
    "What is the distance between the sound of the drum and the sound of the siren?",
    "What is the sound on the right side of the sound of the drum?",
    "Are you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?",
]

gt_answers = [
    "1.5m",
    "emergency vehicle; fire engine, fire truck (siren); siren",
    "Yes"
]

prompts = [format_prompt(prompt) for prompt in prompts]
print(prompts)

["Based on the audio you've heard, refer to the instruction and provide a response.\n\n### Instruction:\nWhat is the distance between the sound of the drum and the sound of the siren?\n\n### Response:", "Based on the audio you've heard, refer to the instruction and provide a response.\n\n### Instruction:\nWhat is the sound on the right side of the sound of the drum?\n\n### Response:", "Based on the audio you've heard, refer to the instruction and provide a response.\n\n### Instruction:\nAre you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?\n\n### Response:"]


In [6]:
from omegaconf import OmegaConf
from seld_config import TrainConfig, ModelConfig
from model.slam_model_seld import model_factory

train_config = TrainConfig(
    model_name="BAT",
    batching_strategy="custom",
    num_epochs=1,
    num_workers_dataloader=2,
    use_peft=True,
    freeze_encoder=True,
    freeze_llm=True
)
train_config = OmegaConf.merge(train_config)

model_config = ModelConfig(
    llm_name="llama-2-7b",
    llm_path="https://huggingface.co/meta-llama/Llama-2-7b-hf", # 
    encoder_name="SpatialAST",
    encoder_ckpt="https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth", # 
)

kwargs = {
    "decode_log": None,
    "ckpt_path": "https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt", # Download it from huggingface
}
model, tokenizer = model_factory(train_config, model_config, **kwargs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
model.to(device)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards:  50%|█████     | 1/2 [00:07<00:07,  7.04s/it]

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.79s/it]


trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622


slam_model_seld(
  (encoder): BinauralEncoder(
    (patch_embed): PatchEmbed_new(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1-11): 11 x Block(

In [7]:
audio_length = 64 # We use 64 learnable tokens as audio representation
audio_pseudo = torch.full((audio_length,), -1)
waveform = torch.from_numpy(waveform).float()
waveform = SpatialAudioDatasetJsonl.padding(waveform, padding_length=10*32000-waveform.shape[1])

In [8]:
for prompt, answer in zip(prompts, gt_answers):
    input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.int64)
    input_ids = torch.cat((audio_pseudo, input_ids))  # [audio, prompt]
    input_ids = input_ids.unsqueeze(0)  # [batch, seq]
    attention_mask = input_ids.ge(-1)
    modality_mask = input_ids.eq(-1)

    model_outputs = model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        modality_mask=modality_mask.to(device),
        audio=waveform.unsqueeze(0).to(device),
    )
    output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
    output_text = output_text[0]
    print(f"Question: {prompt}")
    print(f"Pred: {output_text}")
    print(f"Ground Truth: {answer}")
    print("-------------------------")

Question: Based on the audio you've heard, refer to the instruction and provide a response.

### Instruction:
What is the distance between the sound of the drum and the sound of the siren?

### Response:
Pred: 1.5m
Ground Truth: 1.5m
-------------------------
Question: Based on the audio you've heard, refer to the instruction and provide a response.

### Instruction:
What is the sound on the right side of the sound of the drum?

### Response:
Pred: fire engine, fire truck (siren); emergency vehicle; siren; police car (siren)
Ground Truth: emergency vehicle; fire engine, fire truck (siren); siren
-------------------------
Question: Based on the audio you've heard, refer to the instruction and provide a response.

### Instruction:
Are you able to detect the percussion's sound coming from the left and the emergency vehicle's sounds from the right?

### Response:
Pred: Yes
Ground Truth: Yes
-------------------------
