# Audio Flamingo Inference Notebook

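This notebook demonstrates how to run inference with the Audio Flamingo model on an audio file using a variety of prompts. It is written for Google Colab (files are uploaded with `google.colab.files`). It also requires the Audio Flamingo project code that provides `create_model_and_transforms` and `AudioTextDataProcessor`.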

In [None]:
# Install required packages
!pip install torch transformers yaml

In [None]:
import os
import yaml
import torch
from transformers import AutoTokenizer, set_seed
from google.colab import files

# You'll need to implement these imports based on your actual project structure
# from src.factory import create_model_and_transforms
# from data import AudioTextDataProcessor

# Seed the random number generators via transformers.set_seed so that the
# sampled generations below are more repeatable between runs.
set_seed(0)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Upload your config file (YAML with top-level `clap_config` and `model_config` sections).
# files.upload() returns a dict keyed by uploaded filename; only the first file is used.
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No config file was uploaded; please upload the model's YAML config.")
config_file = next(iter(uploaded))

# Load configuration. safe_load only constructs plain Python objects, so a
# config from an untrusted source cannot execute code while being parsed.
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

clap_config = config['clap_config']
model_config = config['model_config']

In [None]:
def prepare_tokenizer(model_config):
    """Load the text tokenizer and register the extra tokens the model relies on.

    Args:
        model_config: mapping that must contain 'tokenizer_path' (passed to
            AutoTokenizer.from_pretrained) and 'cache_dir'.

    Returns:
        The tokenizer with '<audio>' and '<|endofchunk|>' registered as
        additional special tokens. '<PAD>' and '<SEP>' are added as the pad and
        sep tokens only when the tokenizer does not already define them.

    NOTE(review): adding tokens grows the vocabulary; presumably the model's
    embedding table is sized for this inside create_model_and_transforms — verify.
    """
    tok = AutoTokenizer.from_pretrained(
        model_config['tokenizer_path'],
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=model_config['cache_dir'],
    )
    tok.add_special_tokens({"additional_special_tokens": ["<audio>", "<|endofchunk|>"]})

    # Fill in pad/sep only if missing; the pad check runs before the sep check.
    for attr, literal in (("pad_token", "<PAD>"), ("sep_token", "<SEP>")):
        if getattr(tok, attr) is None:
            tok.add_special_tokens({attr: literal})
    return tok


tokenizer = prepare_tokenizer(model_config)

In [None]:
# Upload your model checkpoint (a torch-saved dict containing a "model_state_dict" entry).
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No checkpoint file was uploaded; please upload the model checkpoint.")
checkpoint_path = next(iter(uploaded))

def prepare_model(model_config, clap_config, checkpoint_path):
    model, _ = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model.eval()
    model = model.to(device)

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_state_dict = checkpoint["model_state_dict"]
    model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
    model.load_state_dict(model_state_dict, False)

    return model

# Build the model and load the uploaded checkpoint weights into it.
model = prepare_model(
    model_config=model_config,
    clap_config=clap_config,
    checkpoint_path=checkpoint_path,
)

In [None]:
# Set up the data processor.
# DATA_ROOT can be overridden with the AUDIO_FLAMINGO_DATA_ROOT environment
# variable. It defaults to the current working directory, where this notebook's
# uploads are opened from (see the config cell).
# NOTE(review): how AudioTextDataProcessor combines data_root with an item's
# 'name' is not visible here — confirm the uploaded audio path resolves.
DATA_ROOT = os.environ.get("AUDIO_FLAMINGO_DATA_ROOT", os.getcwd())
DataProcessor = AudioTextDataProcessor(
    data_root=DATA_ROOT,
    clap_config=clap_config,
    tokenizer=tokenizer,
    max_tokens=512,
)

In [None]:
def inference(model, tokenizer, item, processed_item):
    filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

    eos_token_id = tokenizer.eos_token_id
    
    inference_kwargs = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "num_return_sequences": 1
    }
    
    outputs = model.generate(
        audio_x=audio_clips.unsqueeze(0),
        audio_x_mask=audio_embed_mask.unsqueeze(0),
        lang_x=input_ids.unsqueeze(0),
        eos_token_id=eos_token_id,
        max_new_tokens=128,
        **inference_kwargs,
    )

    outputs_decoded = [
        tokenizer.decode(output).split(tokenizer.sep_token)[-1].replace(tokenizer.eos_token, '').replace(tokenizer.pad_token, '').replace('<|endofchunk|>', '') for output in outputs
    ]

    return outputs_decoded[0]

In [None]:
def process_audio(audio_file, prompt):
    """Ask the model one question about one audio file and return its reply.

    Builds the request dict consumed by DataProcessor.process (keys 'name',
    'prefix', 'prompt'), preprocesses it, then runs inference() with the
    notebook-level model and tokenizer.
    """
    request = dict(
        name=audio_file,
        prefix="The task is audio analysis.",
        prompt=prompt,
    )
    features = DataProcessor.process(request)
    return inference(model, tokenizer, request, features)

In [None]:
# Upload an audio file
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No audio file was uploaded; please upload an audio file to analyse.")
audio_file = next(iter(uploaded))

# Example usage
prompt = "Describe the sound in this audio file."
response = process_audio(audio_file, prompt)
print(f"Prompt: {prompt}")
print(f"Response: {response}")

In [None]:
# Interactive cell for trying different prompts against the uploaded audio file.
# Enter 'q' (case-insensitive, surrounding spaces ignored) to quit. Blank
# prompts are skipped. End-of-input or an interrupt also ends the loop cleanly.
while True:
    try:
        prompt = input("Enter your prompt (or 'q' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        break
    if prompt.lower() == 'q':
        break
    if not prompt:
        continue
    response = process_audio(audio_file, prompt)
    print(f"Response: {response}\n")