In [None]:
!pip install -q piper-tts pathvalidate soundfile librosa datasets transformers accelerate

In [None]:
!python -m piper.download_voices en_US-lessac-medium

In [None]:
import os
import json
import subprocess
import torch
import soundfile as sf
import librosa
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

# ================= CONFIG =================

# Input dataset: main() loads it with the "csv" builder when the path ends in
# ".csv", otherwise with the "json" builder.
DATASET_PATH = "context_situated_pun.csv"   # local file
DATASET_SPLIT = "train"

# JSONL file that receives one {"id", "Explanation"} object per processed row.
OUT_PATH = "cache/pun_explanations_qwen_audio.jsonl"

MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
MAX_NEW_TOKENS = 120   # upper bound on generated tokens per explanation
MAX_ITEMS = 500        # cap applied to both the TTS phase and the inference phase

# TTS
PIPER_MODEL = "/content/piper_models/en_US-lessac-medium.onnx"
AUDIO_DIR = "cache/pun_tts"   # synthesized clips are cached here as <uid><AUDIO_EXT>
AUDIO_EXT = ".wav"

os.makedirs(AUDIO_DIR, exist_ok=True)
# Derive the output directory from OUT_PATH so the two cannot drift apart when
# OUT_PATH is edited; fall back to "." when OUT_PATH has no directory part.
os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)

# ================= PROMPT =================

def build_messages(text):
    """Build the chat messages asking the model to explain a possible pun.

    Returns a two-element list: a fixed system turn, followed by a user turn
    whose content holds the instructions, an ``<|AUDIO|>`` placeholder wrapped
    in ``<Audio>`` tags, and *text* interpolated after the ``Text:`` label.
    """
    user_prompt = f"""Explain whether the following text contains a pun.

You are given the written text and its spoken audio.

<Audio>
<|AUDIO|>
</Audio>

Instructions:
- Do NOT explain your analysis process.
- Do NOT define what a pun is.
- Focus ONLY on the linguistic mechanism.
- If the text is a pun, clearly state:
  • the word or phrase involved
  • the two meanings or sound-based ambiguity
- If it is not a pun, clearly state that no wordplay or ambiguity is present.

Write a concise paragraph (3–6 sentences).

Text:
{text}
"""
    system_turn = {"role": "system", "content": "You are an expert linguist."}
    user_turn = {"role": "user", "content": user_prompt}
    return [system_turn, user_turn]

# ================= HELPERS =================

def normalize_id(idx):
    """Return the identifier ``pun_<idx>`` used to name a row's cached audio."""
    return "pun_{}".format(idx)

def load_audio(path, target_sr=16000):
    """Read an audio file and return a mono waveform at *target_sr* Hz.

    Args:
        path: Audio file readable by ``soundfile``.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        A 1-D array of samples, resampled only when the file's rate differs
        from *target_sr*.
    """
    wav, sr = sf.read(path)
    # soundfile lays multi-channel audio out as (frames, channels), while
    # librosa.resample works along the last axis by default. Downmix first so
    # resampling always runs along time and the result is a single channel.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return wav

def generate_tts(text, uid):
    """Synthesize *text* with piper into ``AUDIO_DIR/<uid><AUDIO_EXT>``.

    An existing output file larger than the minimum size is reused without
    calling piper again. The same size check decides whether a fresh synthesis
    succeeded, so the cache test and the success test cannot disagree.

    Args:
        text: Sentence to speak; sent to piper on stdin.
        uid: Identifier used as the output file's base name.

    Returns:
        True when a usable audio file exists afterwards, False otherwise.
    """
    out_wav = os.path.join(AUDIO_DIR, uid + AUDIO_EXT)
    min_bytes = 1000  # anything this small is treated as an empty/truncated clip

    def _usable():
        return os.path.exists(out_wav) and os.path.getsize(out_wav) > min_bytes

    if _usable():
        return True

    p = subprocess.run(
        [
            "piper",
            "--model", PIPER_MODEL,
            "--output_file", out_wav,
        ],
        input=text + "\n",
        text=True,
        capture_output=True,
    )

    if p.returncode == 0 and _usable():
        return True

    # Drop any partial output: the inference phase only checks that the file
    # exists, so a truncated file left here would be loaded later.
    try:
        os.remove(out_wav)
    except OSError:
        pass

    # Report the failure instead of discarding the captured stderr; tqdm.write
    # keeps the message from breaking an active progress bar.
    detail = (p.stderr or "").strip()[:200]
    tqdm.write(f"[TTS] piper failed for {uid} (exit code {p.returncode}): {detail}")
    return False

def valid_user_pun(item):
    """Return the stripped ``user_pun`` value of *item*, or None if unusable.

    A value is unusable when the key is missing or None, when it is blank
    after stripping, or when it is one of the placeholder strings below.
    Non-string values are converted with ``str`` before the checks.
    """
    placeholders = ("{}", "{ }", "null", "None")

    raw = item.get("user_pun")
    candidate = None if raw is None else str(raw).strip()

    if candidate and candidate not in placeholders:
        return candidate
    return None

# ================= MAIN =================

def main():
    """Run the pipeline: synthesize speech for each pun, then explain it.

    Phase A walks the dataset and creates a cached audio clip for every row
    with a usable ``user_pun`` value, stopping after MAX_ITEMS successful
    clips. Phase B walks the dataset again, feeds the text plus its audio to
    the model, and writes one JSON object per row to OUT_PATH. It also stops
    after MAX_ITEMS rows. Rows without a cached clip are skipped in Phase B.
    """
    torch.set_grad_enabled(False)

    print(f"Loading model: {MODEL_ID}")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
    ).eval()

    # device_map="auto" decides where the weights are placed, so ask the model
    # for its device instead of assuming "cuda"; a hard-coded device fails on
    # machines without CUDA and can disagree with the actual placement.
    device = model.device

    print(f"Loading local dataset: {DATASET_PATH}")
    ds = load_dataset(
        "csv" if DATASET_PATH.endswith(".csv") else "json",
        data_files=DATASET_PATH,
        split=DATASET_SPLIT,
    )

    # ---- Phase A: TTS ----
    print("=== Generating TTS ===")
    tts_count = 0
    for idx, item in tqdm(enumerate(ds), total=len(ds), desc="TTS"):
        if tts_count >= MAX_ITEMS:
            break

        text = valid_user_pun(item)
        if not text:
            continue

        # The row index is the cache key, so Phase B can find the clip again.
        uid = normalize_id(idx)
        if generate_tts(text, uid):
            tts_count += 1

    # ---- Phase B: Inference ----
    def generate(text, uid):
        """Return the model's explanation for *text* and its cached audio."""
        messages = build_messages(text)

        prompt = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        audio = load_audio(os.path.join(AUDIO_DIR, uid + AUDIO_EXT))

        inputs = processor(
            text=prompt,
            audio=audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                min_new_tokens=40,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # The output holds prompt tokens followed by the continuation; keep
        # only what comes after the prompt.
        gen_tokens = out[0][inputs["input_ids"].shape[1]:]
        return processor.tokenizer.decode(
            gen_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()

    print("=== Explaining texts (Text + Audio) ===")
    count = 0
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        for idx, item in tqdm(enumerate(ds), total=len(ds), desc="Inference"):
            if count >= MAX_ITEMS:
                break

            text = valid_user_pun(item)
            if not text:
                continue

            uid = normalize_id(idx)
            wav = os.path.join(AUDIO_DIR, uid + AUDIO_EXT)
            if not os.path.exists(wav):
                # No clip was produced for this row in Phase A.
                continue

            explanation = generate(text, uid)

            out_obj = {
                "id": idx,
                "Explanation": explanation,
            }

            # Flush after every row so finished work survives an interruption.
            f.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
            f.flush()

            count += 1
            torch.cuda.empty_cache()

    print("Done.")
    print(f"Generated {count} audio-based explanations")
    print(f"Output → {OUT_PATH}")

# ---------------- RUN -----------------

# Run the pipeline only when this code executes as the main module
# (e.g. as a script), not when it is imported.
if __name__ == "__main__":
    main()


Loading model: Qwen/Qwen2-Audio-7B-Instruct


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/876 [00:00<?, ?it/s]

Loading local dataset: context_situated_pun.csv


Generating train split: 0 examples [00:00, ? examples/s]

=== Generating TTS ===


TTS:  19%|█▉        | 886/4551 [18:49<1:17:52,  1.27s/it]


=== Explaining texts (Text + Audio) ===


Inference:   0%|          | 0/4551 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Inference:  19%|█▉        | 886/4551 [34:21<2:22:05,  2.33s/it]

Done.
Generated 500 audio-based explanations
Output → cache/pun_explanations_qwen_audio.jsonl



