# In [None]:


from datasets import load_dataset
from transformers import WhisperFeatureExtractor, WhisperTokenizer
# Longest allowed clip length (``end - start``) used by filter_data.
# NOTE(review): presumably seconds, judging by the source dataset name — confirm.
MAX_DURATION = 30

# Label set the model chooses from. Both prompt halves list these tags, so the
# list is defined once here to keep the two copies from drifting apart.
TAGS = (
    "Accusation",
    "Defense",
    "Evidence",
    "Identity Declaration",
    "Interrogation",
    "No Strategy",
)
# Rendered as: 'Accusation', 'Defense', ..., 'No Strategy'
_TAG_LIST = ", ".join(f"'{tag}'" for tag in TAGS)

# Text placed before the rendered dialogue. It opens a ``` fence that
# prompt_suffix closes.
prompt_prefix = (
    "Given the following dialogue and audio, assign the last utterance "
    "one or more of the following tags (delimited by commas):\n"
    f"{_TAG_LIST}\n"
    "\n"
    "```\n"
)

# Text placed after the dialogue. It closes the fence, repeats the tag list as
# a reminder, and ends with the "Assignment:" cue that the completion follows.
prompt_suffix = (
    "\n"
    "```\n"
    "\n"
    "Reminder - Assign one or more of the following tags to the last "
    "utterance (delimited by commas):\n"
    f"{_TAG_LIST}\n"
    "\n"
    "Assignment:\n"
)

def load_werewolf_data(dataset_name="iohadrubin/werewolf_dialogue_data_10sec"):
    """Load the Werewolf dialogue dataset from the Hugging Face Hub.

    Args:
        dataset_name: Hub repository id passed straight to
            ``datasets.load_dataset``. It defaults to the repository this script
            was written against, so existing no-argument calls are unchanged.

    Returns:
        Whatever ``load_dataset`` returns for that id. Per the ``datasets`` docs,
        that is a ``DatasetDict`` when no split is requested.
    """
    return load_dataset(dataset_name)

def filter_data(werewolf_data, max_duration=None, max_dialogue_len=200):
    """Drop samples that cannot be turned into a prompt/completion pair.

    A sample is kept only if all of the following hold:

    * ``end - start`` is at most ``max_duration``. A NaN duration is rejected.
    * The dialogue has at least one turn and at most ``max_dialogue_len`` turns.
    * The last turn's ``"target"`` is not None and not blank.

    Args:
        werewolf_data: An object with a ``.filter(fn)`` method, such as a
            ``datasets`` Dataset/DatasetDict. Its rows are dicts with
            ``"start"``, ``"end"`` and ``"dialogue"`` keys.
        max_duration: Longest allowed ``end - start``. ``None`` (the default)
            uses the module-level ``MAX_DURATION``.
        max_dialogue_len: Maximum number of dialogue turns (default 200).

    Returns:
        The result of ``werewolf_data.filter`` with the predicate above.
    """
    if max_duration is None:
        max_duration = MAX_DURATION

    def filter_fn(x):
        duration = x["end"] - x["start"]
        # Written as `not <=` rather than `>` so a NaN duration is rejected,
        # matching the original check.
        if not duration <= max_duration:
            return False
        dialogue = x["dialogue"]
        # Check emptiness before indexing dialogue[-1]. An empty dialogue would
        # otherwise raise IndexError and abort the whole .filter() call.
        if not dialogue or len(dialogue) > max_dialogue_len:
            return False
        target = dialogue[-1]["target"]
        return target is not None and len(target.strip()) > 0

    return werewolf_data.filter(filter_fn)

def create_into_prompt_completion_fn(model_name, max_length=447, tokenizer=None):
    """Build a per-sample callback that renders a dialogue as prompt/completion text.

    The returned function:

    * renders each dialogue turn as a ``"speaker: utterance"`` line;
    * wraps the lines in ``prompt_prefix`` / ``prompt_suffix``;
    * uses the last turn's ``"target"`` as the completion.

    If ``prompt + completion`` encodes to more than ``max_length`` token ids,
    the oldest turns are dropped one at a time until it fits. If even the
    final turn alone does not fit, or the dialogue is empty, both fields are
    ``""`` so the caller can filter the sample out.

    Args:
        model_name: Hugging Face model id whose ``WhisperTokenizer`` counts
            tokens. Only used when ``tokenizer`` is not given.
        max_length: Maximum length of ``tokenizer.encode(prompt + completion)``.
            This counts whatever ids ``encode`` returns, including any special
            tokens it adds by default.
        tokenizer: Optional already-built tokenizer with an ``encode(str)``
            method. It lets callers reuse a loaded tokenizer instead of loading
            one from ``model_name``.

    Returns:
        A function ``sample -> {"prompt": str, "completion": str}``.
    """
    if tokenizer is None:
        tokenizer = WhisperTokenizer.from_pretrained(model_name)

    def into_prompt_completion(sample):
        turns = sample["dialogue"]
        if not turns:
            return {"prompt": "", "completion": ""}
        # The completion is always the final turn's label, however many leading
        # turns get trimmed, so read it once outside the loop.
        target = turns[-1]["target"]
        # Try progressively shorter suffixes of the dialogue, dropping the
        # oldest turn first. The first suffix within budget wins.
        for start in range(len(turns)):
            dialogue = "\n".join(
                f"{turn['speaker']}: {turn['utterance']}" for turn in turns[start:]
            )
            prompt = prompt_prefix + dialogue + prompt_suffix
            if len(tokenizer.encode(prompt + target)) <= max_length:
                return {"prompt": prompt, "completion": target}
        # Not even the final turn on its own fits within max_length tokens.
        return {"prompt": "", "completion": ""}

    return into_prompt_completion


# Pipeline: load the raw dataset, drop unusable samples, render
# prompt/completion pairs, discard samples that could not fit the token budget,
# then publish the result under a new Hub id.
model_name = "openai/whisper-small"

werewolf_data = filter_data(load_werewolf_data())

into_prompt_completion = create_into_prompt_completion_fn(model_name)
werewolf_data = werewolf_data.map(into_prompt_completion, num_proc=50)

# into_prompt_completion emits an empty prompt when no suffix of the dialogue
# fits within its max_length, so those rows are removed here.
werewolf_data = werewolf_data.filter(lambda example: len(example["prompt"]) > 0)

werewolf_data.push_to_hub("iohadrubin/werewolf_dialogue_data_10sec_v2")

