In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import json

In [None]:
def remove_stress(phoneme):
    """Return ``phoneme`` with a single trailing stress digit removed.

    A final ``0``, ``1`` or ``2`` (the stress markers seen in the sample data,
    e.g. ``"EH1"`` -> ``"EH"``) is stripped; any other string is returned
    unchanged.
    """
    stress_marker = re.compile(r'[0-2]$')
    return stress_marker.sub('', phoneme)

In [None]:
# Function to preprocess phoneme input from dataset format
def preprocess_phonemes(phoneme_list, word_separator=' '):
    """Flatten per-word phoneme strings into one stress-free phoneme string.

    Args:
        phoneme_list: Iterable of strings, each holding whitespace-separated
            phonemes for one word (e.g. ``["IH0 K S P EH1 N S IH0 V", "AE1 T"]``).
            A bare string is treated as a single word instead of being
            iterated character by character.
        word_separator: Text placed between words. The default ``' '`` gives a
            flat space-separated phoneme sequence (word boundaries are lost);
            pass a different separator to keep words distinguishable.

    Returns:
        The phonemes with trailing stress digits removed (via ``remove_stress``),
        joined by single spaces within a word and by ``word_separator`` between
        words. Entries containing no phonemes are skipped.
    """
    # A str is itself iterable: without this guard "AE1 T" would be processed
    # one character at a time and yield garbage.
    if isinstance(phoneme_list, str):
        phoneme_list = [phoneme_list]

    words = []
    for phoneme_str in phoneme_list:
        phonemes = [remove_stress(p) for p in phoneme_str.split()]
        # Skip empty/whitespace-only entries so they add no extra separators,
        # matching the original flat join.
        if phonemes:
            words.append(' '.join(phonemes))

    # NOTE(review): space-separated, stress-free symbols are an *assumed* input
    # format for phoneme-llama — confirm against the model's tokenizer/training data.
    return word_separator.join(words)

In [None]:
# Function to test phoneme-llama with a single phoneme sequence
def test_phoneme_llama(model, tokenizer, phoneme_input, max_length=50):
    try:
        # Tokenize input
        inputs = tokenizer(phoneme_input, return_tensors="pt").to(model.device)
        
        # Generate sentence
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_length=max_length,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True
            )
        
        # Decode output
        sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return sentence
    except Exception as e:
        return f"Error processing input '{phoneme_input}': {str(e)}"

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load phoneme-llama model and tokenizer (float16 on GPU, float32 on CPU).
model_name = "bbunzeck/phoneme-llama"
try:
    print(f"Loading model {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()  # inference mode for layers such as dropout
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    # Re-raise: continuing would leave `model`/`tokenizer` undefined (or stale
    # from an earlier run), so later cells would fail with a misleading
    # NameError or silently use the wrong objects.
    raise

In [None]:
# Sample input from your dataset: one transcript chunk with per-word phonemes.
# NOTE(review): `phonemes` has entries for "expensive" and "at" but none for
# "all" — confirm whether the dataset drops words or this sample is truncated.
dataset_sample = dict(
    video_id="YPvP_C4qy0E",
    chunk_name="12-1",
    text="expensive at all.",
    phonemes=[
        "IH0 K S P EH1 N S IH0 V",  # expensive
        "AE1 T",                    # at
    ],
)

In [None]:
# Preprocess the dataset sample, run it through the model, and show the
# generated text next to the expected transcript for manual comparison.
phoneme_input = preprocess_phonemes(dataset_sample["phonemes"])
print("\nTesting dataset sample:")  # no placeholders, so no f-prefix needed
print(f"Input phonemes: {phoneme_input}")
print(f"Expected text: {dataset_sample['text']}")
result = test_phoneme_llama(model, tokenizer, phoneme_input)
print(f"Generated sentence: {result}")