In [None]:
import sys
from pathlib import Path

# Parent of the notebook's working directory; it is put on sys.path so that
# project packages can be imported (presumably `src/` lives there — verify
# against the repository layout).
project_path = Path.cwd().parent

# Guard the append so re-running this cell does not pile up duplicate
# sys.path entries for the same directory.
_project_root = str(project_path.resolve())
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
import json
from src.dataset.load_data_soda import SODADataLoader
from src.utils.inferencing import HFModelForInferencing

In [None]:
# Keyword arguments for SODADataLoader, gathered in one mapping so every
# loader setting can be read and tweaked in a single place.
soda_loader_kwargs = dict(
    data_types=['train'],
    samples_per_split=10,
    min_story_length=20,
    max_story_length=250,
    join_dialogue_and_speakers=True,
    add_characters_in_narrative=True,
    add_turns_count_in_narrative=True,
)

soda_dataset_obj = SODADataLoader(**soda_loader_kwargs)
soda_ds = soda_dataset_obj.dataset

# Use the first key exposed by the loaded dataset as the active split
# (presumably 'train', given data_types above — verify against the loader).
split_names = tuple(soda_ds.keys())
d_type = split_names[0]

In [None]:
# Constructor arguments for the inference wrapper: a base checkpoint, a PEFT
# adapter repository, and a pinned commit hash.
# NOTE(review): the base model is BART while the adapter repo name mentions
# "BERT" — confirm this is the intended adapter for facebook/bart-base.
gen_model_kwargs = dict(
    hf_model_repo_name='facebook/bart-base',
    is_lora=True,
    peft_model_repo_name='abirmondalind/story2dialogue-SODA-BERT-LoRA',
    hf_commit_hash='5d671360acad574afac587e3b8e893a7e12ed631',
)

gen_model_obj = HFModelForInferencing(**gen_model_kwargs)

In [None]:
# Index of the example to inspect within the selected split.
i = 3

# Fetch the row once instead of re-indexing the dataset for every field.
sample = soda_ds[d_type][i]

# The first line of the 'narrative' field is used as the story text and the
# second line is parsed for character names.
narrative_lines = sample['narrative'].split("\n")
if len(narrative_lines) < 2:
    # Fail with an actionable message rather than a bare IndexError.
    raise ValueError(
        f"Sample {i} in split {d_type!r} has no second narrative line to "
        "parse character names from."
    )

narrative = narrative_lines[0]
actual_dialogue = sample['dialogue']

# Take the text after the last ':' on the second line, drop periods, and split
# on commas (presumably the line looks like "<label>: A, B." — TODO confirm
# against SODADataLoader). Note that every '.' is removed, not just a trailing
# one.
characters_field = narrative_lines[1].split(':')[-1].replace('.', '')
# Skip blank entries (e.g. from a trailing comma) so no empty character name
# is handed to generation.
characters = [
    name.strip() for name in characters_field.split(',') if name.strip()
]

print("Narrative:", narrative, "-" * 50, sep="\n")
print("Characters:", characters, "-" * 50, sep="\n")
print("Actual Dialogue:", actual_dialogue, "-" * 50, sep="\n")

In [None]:
# Left empty here; set a prefix prompt when using a T5 model.
prefix_prompt = ""

# Options passed through as `generation_kwargs`.
# NOTE(review): do_sample=True with no seed set in the visible cells —
# presumably the output differs between runs; confirm whether the helper
# seeds internally.
sampling_options = {
    "max_new_tokens": 128,
    "no_repeat_ngram_size": 3,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
}

# Generate a dialogue from the selected narrative and its characters.
gen_output = gen_model_obj.generate_dialogue(
    input_text=narrative,
    characters=characters,
    prefix_prompt=prefix_prompt,
    tokenizer_max_length=900,
    gen_turn_by_turn=True,
    max_turns=5,
    generation_kwargs=sampling_options,
)

print("Generated Dialogue:", gen_output, "-" * 50, sep="\n")