In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
for root, _, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from kaggle_secrets import UserSecretsClient
# Read the token stored under the Kaggle secret "personalhuggingface";
# the next cell passes it to `huggingface-cli login`.
user_secrets = UserSecretsClient()
hf_personal = user_secrets.get_secret("personalhuggingface")

In [None]:
# Log the Hugging Face CLI in with the token held in `hf_personal`.
# NOTE(review): the token is interpolated into a shell command line — confirm
# that this is acceptable for the environment this notebook runs in.
!huggingface-cli login --token {hf_personal}

## Generate Model Relevant Data

In [None]:
MODEL_NAME="llama-mimi1.3B-greedy" # Only used to name the output Hugging Face repo; no slashes allowed
# NOTE: the actual model path on Hugging Face is not set here; it is defined as HF_MODEL_ID in the model-preparation cell.
assert MODEL_NAME!="temp"  # fail fast if MODEL_NAME is left as "temp"

In [None]:
# Model Preparation: All Model Utilities Go Here

from transformers import AutoModelForCausalLM, AutoTokenizer, MimiModel, AutoFeatureExtractor, StoppingCriteria
import torch
import torch.nn.functional as F
import torchaudio
import re
import requests
import io

def audio_array_to_text(
    audio_array: torch.Tensor,
    audio_tokenizer,
    feature_extractor,
    num_quantizers: int,
) -> str:
    """Encode a waveform into the ``<audio>...</audio>`` token string.

    The waveform is run through ``feature_extractor`` and then
    ``audio_tokenizer.encode``; every resulting code ``c`` is rendered as
    ``<c_j>``, where ``j`` cycles through ``0 .. num_quantizers - 1``.

    Args:
        audio_array: Waveform handed to ``feature_extractor`` as ``raw_audio``.
        audio_tokenizer: Codec exposing ``.device`` and
            ``.encode(input_values, padding_mask, num_quantizers=...)`` whose
            result has an ``audio_codes`` tensor.
        feature_extractor: Callable exposing ``sampling_rate``; its output must
            support ``.to(device)`` and item access for ``"input_values"`` and
            ``"padding_mask"``.
        num_quantizers: Number of codes emitted per frame.

    Returns:
        The concatenated code tokens wrapped in ``<audio>`` / ``</audio>``.

    Raises:
        AssertionError: If the number of codes is not a multiple of
            ``num_quantizers``.
    """
    inputs = feature_extractor(
        raw_audio=audio_array,
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="pt",
    ).to(audio_tokenizer.device)
    with torch.no_grad():
        encoder_outputs = audio_tokenizer.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            num_quantizers=num_quantizers,
        )

    # Swap axes 1 and 2 so that, after flattening, the codes of one frame are
    # adjacent (assumes audio_codes is (batch, quantizer, frame) — TODO confirm).
    # A single .tolist() transfers everything at once instead of one .item()
    # call (and, on GPU, one device sync) per code.
    flat_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1).tolist()
    assert len(flat_codes) % num_quantizers == 0

    text = "".join(
        f"<{code}_{position % num_quantizers}>"
        for position, code in enumerate(flat_codes)
    )
    return f"<audio>{text}</audio>"

def text_to_audio_values(
    text: str,
    num_quantizers: int,
    output_file: str,
    audio_tokenizer,
    feature_extractor,
):
    """Decode ``<val_idx>`` code tokens in ``text`` to a waveform and save it.

    Tokens are consumed in groups of ``num_quantizers``. A group is accepted
    only if its indices are exactly ``0, 1, ..., num_quantizers - 1`` in order;
    parsing stops at the first incomplete or out-of-order group, and only the
    groups before it are decoded.

    Args:
        text: String containing ``<val_idx>`` tokens (other text is ignored).
        num_quantizers: Number of codes per frame.
        output_file: Path the decoded audio is written to via ``torchaudio.save``.
        audio_tokenizer: Codec exposing ``.device`` and ``.decode(codes)``.
        feature_extractor: Object exposing ``sampling_rate`` used when saving.

    Returns:
        The decoded waveform for the first batch item, detached and on CPU.

    Raises:
        ValueError: If ``text`` contains no complete, correctly ordered frame.
    """
    # Extract (val, idx) pairs from the <val_idx> format in the text
    matches = re.findall(r"<(\d+)_(\d+)>", text)
    expected_indices = list(range(num_quantizers))

    vals = []
    for start in range(0, len(matches), num_quantizers):
        chunk = matches[start : start + num_quantizers]
        if len(chunk) < num_quantizers:
            break  # trailing partial frame
        if [int(idx) for _, idx in chunk] != expected_indices:
            break  # quantizer indices out of order: stop at the last good frame
        vals.extend(int(val) for val, _ in chunk)

    if not vals:
        # Without this guard the reshape below fails with an obscure
        # "unspecified dimension is ambiguous" error on an empty tensor.
        raise ValueError("text contains no complete audio frame to decode")

    codes_bt = torch.tensor(vals, dtype=torch.long).reshape(1, -1, num_quantizers)  # (B, T, Q)
    # Move to the codec's own device rather than relying on a notebook global,
    # matching how audio_array_to_text picks its device.
    codes_bq = codes_bt.transpose(1, 2).to(audio_tokenizer.device)  # (B, Q, T)

    with torch.no_grad():
        audio_values = audio_tokenizer.decode(codes_bq)[0]
    waveform = audio_values[0].detach().cpu()

    torchaudio.save(
        output_file,
        waveform,
        feature_extractor.sampling_rate,
    )
    return waveform


class StopOnAudioEnd(StoppingCriteria):
    """Stopping criterion that fires once the sequence ends with ``</audio>``.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.target_text = "</audio>"
        # Token ids of the closing tag, encoded without special tokens.
        encoded = tokenizer(self.target_text, add_special_tokens=False)
        self.target_ids = encoded.input_ids

    def __call__(self, input_ids, scores, **kwargs):
        first_sequence = input_ids[0]
        suffix_len = len(self.target_ids)
        if len(first_sequence) >= suffix_len:
            # Compare the tail of the sequence with the closing-tag ids.
            return first_sequence[-suffix_len:].tolist() == self.target_ids
        return False


# Hugging Face repo id of the causal LM loaded below.
HF_MODEL_ID="llm-jp/Llama-Mimi-1.3B"
# Run on the GPU when one is available, otherwise fall back to CPU.
device="cuda" if torch.cuda.is_available() else "cpu"
# Load the LM in bfloat16 on GPU (float32 on CPU), switch it to eval mode and move it to `device`.
MODEL = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).eval().to(device)

TOKENIZER = AutoTokenizer.from_pretrained(HF_MODEL_ID)
# Use the model config's `num_quantizers` when present; otherwise default to 4.
NUM_QUANTIZERS = getattr(MODEL.config, "num_quantizers", 4)


# Mimi model and its feature extractor, both loaded from "kyutai/mimi";
# MIMI is what the functions in this notebook pass as `audio_tokenizer`.
MIMI = MimiModel.from_pretrained("kyutai/mimi").to(device).eval()
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

# Stopping criterion for "</audio>"; in this notebook it is only referenced
# from a commented-out argument in generate_continuation_audio.
STOPPING_CRITERIA = StopOnAudioEnd(TOKENIZER)

## Implement the following two functions!

In [None]:
@torch.no_grad()
def get_per_token_losses(
    audio_sample
): 
    #return torch.Tensor([3])
    """
    Calculate all loss, given model and audio sample
    """
    wav, sr = torch.Tensor(audio_sample["array"]), audio_sample["sampling_rate"]
   
    if sr != FEATURE_EXTRACTOR.sampling_rate:
        wav = torchaudio.transforms.Resample(
            sr, FEATURE_EXTRACTOR.sampling_rate
        )(wav)
        sr = FEATURE_EXTRACTOR.sampling_rate
    
    txt = audio_array_to_text(wav, MIMI, FEATURE_EXTRACTOR, NUM_QUANTIZERS)
    #print(pos_txt, neg_txt)

    # 3) Tokenize
    input_ids = TOKENIZER(txt, return_tensors="pt").input_ids

    input_ids = input_ids.to(device)

    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:].clone()
    labels[:, -1] = -100  # don't predict the last token in this chunk

    out = MODEL(input_ids=input_ids)
    logits = out.logits  # (B, T, V)

    loss_all_tokens = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction='none',
    )
    return loss_all_tokens

@torch.no_grad()
def generate_continuation_audio(
    audio_sample,
    temperature = 0.8,
    top_k = 1,
    do_sample = False,
    min_length = 144,
    max_length = 240, #around 5s is sufficient
): 
    wav, sr = torch.Tensor(audio_sample["array"]), audio_sample["sampling_rate"]
   
    if sr != FEATURE_EXTRACTOR.sampling_rate:
        wav = torchaudio.transforms.Resample(
            sr, FEATURE_EXTRACTOR.sampling_rate
        )(wav)
        sr = FEATURE_EXTRACTOR.sampling_rate
    
    txt = audio_array_to_text(wav, MIMI, FEATURE_EXTRACTOR, NUM_QUANTIZERS)
    txt = txt.replace("</audio>", "")
    #print(pos_txt, neg_txt)

    # 3) Tokenize
    inputs = TOKENIZER(txt, return_tensors="pt")#.input_ids

    inputs = inputs.to(device)
    
    generated = MODEL.generate(
        **inputs,
        max_new_tokens=max_length,
        min_new_tokens=min_length,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        bad_words_ids=[TOKENIZER.convert_tokens_to_ids(["</audio>"])] #forced continuation, do not early stop
        #stopping_criteria=[STOPPING_CRITERIA],
    )
    
    #print("gen shape", generated.shape, "iid shape", inputs.input_ids.shape)
    generated_text = TOKENIZER.decode(generated[0]) + "</audio>" #add stop token at generation end

    #print(txt)
    #print(generated_text)
    #s()
    
    audio_values = text_to_audio_values(
        generated_text,
        num_quantizers=NUM_QUANTIZERS,
        output_file="output.wav",
        audio_tokenizer=MIMI,
        feature_extractor=FEATURE_EXTRACTOR,
    )
    return audio_values

In [None]:
def get_model_features(e):
    """Attach model-derived features to one dataset example and return it.

    Adds per-token losses for the positive and negative audio, and — when the
    example's ``task`` contains ``"consistency"`` — the prompt's per-token
    losses plus a model-generated continuation. Also records codec metadata
    and a sanity flag that is 1 when the positive sample's mean loss is lower
    than the negative sample's.

    Args:
        e: Example mapping with ``"task"``, ``"positive_audio"``,
            ``"negative_audio"`` and, for consistency tasks, ``"prompt_audio"``.

    Returns:
        The same mapping ``e`` with the new keys added.
    """
    # audio is 16000hz, maybe resample
    # NOTE: the "postive" spelling is the column name that gets written; it is kept as is.
    e["postive_sample_tokenwise_loss"] = get_per_token_losses(e["positive_audio"])
    e["negative_sample_tokenwise_loss"] = get_per_token_losses(e["negative_audio"])
    if "consistency" in e["task"]:
        e["prompt_sample_tokenwise_loss"] = get_per_token_losses(e["prompt_audio"])
        generated_audio = generate_continuation_audio(e["prompt_audio"])
        e["model_generated_continuation"] = {
            "sampling_rate": FEATURE_EXTRACTOR.sampling_rate,
            "array": generated_audio.squeeze().numpy(),
        }

    # Scalars, not 1-tuples: the stray trailing commas previously stored
    # (12,) and (sampling_rate,) instead of the plain values.
    # NOTE(review): frame rate and depth are hard-coded — confirm against the codec in use.
    e["code_frame_rate"] = 12
    e["code_depth"] = 4
    e["model_sampling_rate"] = FEATURE_EXTRACTOR.sampling_rate

    # sanity check if number same as SALMon, would be rerun with other methods
    positive_mean = e["postive_sample_tokenwise_loss"].mean()
    negative_mean = e["negative_sample_tokenwise_loss"].mean()
    e["ppl_sanity"] = int((positive_mean < negative_mean).item())
    print("sample correct:",e["ppl_sanity"])
    return e

In [None]:
from datasets import Audio, load_dataset

# SALMon configurations to process; each one is pushed as its own config of the output repo.
splts = ['bg_all_consistency', 'bg_domain_consistency', 'gender_consistency', 'rir_consistency', 'sentiment_consistency', 'speaker_consistency', 'bg_alignment', 'sentiment_alignment']

for splt in splts:
    print(splt)
    ds = load_dataset("SpeechPPL/SALMon_with_meta", splt)
    ds = ds.map(get_model_features)

    # get_model_features only adds the generated continuation for tasks whose
    # name contains "consistency"; casting a column that does not exist fails,
    # so only cast when the column is actually present.
    if "model_generated_continuation" in ds["train"].column_names:
        ds = ds.cast_column("model_generated_continuation", Audio(sampling_rate=FEATURE_EXTRACTOR.sampling_rate))

    # Fraction of examples where the positive sample has the lower mean loss.
    print("Accuracy:",sum(ds["train"]["ppl_sanity"])/len(ds["train"]))

    ds.push_to_hub(f"SpeechPPL/SALMon_{MODEL_NAME}", config_name=splt)

