In [None]:
import sys
from pathlib import Path

project_path = Path.cwd().parent

sys.path.append(str(project_path.resolve()))

In [None]:
import json
from src.dataset.load_data_soda import SODADataLoader
from src.utils.inferencing import HFModelForInferencing

In [None]:
# Load a small slice of the SODA training split. The narrative is augmented with
# the character list and turn count, which the example-selection cell parses back out.
SODA_LOADER_CONFIG = dict(
    data_types=['train'],
    samples_per_split=10,
    min_story_length=20,
    max_story_length=250,
    join_dialogue_and_speakers=True,
    add_characters_in_narrative=True,
    add_turns_count_in_narrative=True,
)

soda_dataset_obj = SODADataLoader(**SODA_LOADER_CONFIG)
soda_ds = soda_dataset_obj.dataset

# Name of the first split in the loaded dataset (presumably 'train', the only one requested).
d_type = [*soda_ds.keys()][0]

In [None]:
# Load the registry of models to compare (config/model_details.json).
# Prefer the copy under the repository root computed in the setup cell so the
# notebook runs outside Kaggle; fall back to the original Kaggle location.
MODEL_DETAILS_CANDIDATES = [
    project_path / 'config' / 'model_details.json',
    # NOTE(review): "dialgoue" spelling kept verbatim — presumably it matches
    # the cloned directory name on Kaggle; confirm.
    Path('/kaggle/working/story-to-dialgoue/config/model_details.json'),
]

model_details_path = next(
    (candidate for candidate in MODEL_DETAILS_CANDIDATES if candidate.is_file()),
    None,
)
if model_details_path is None:
    raise FileNotFoundError(
        "model_details.json not found; looked in: "
        + ", ".join(str(candidate) for candidate in MODEL_DETAILS_CANDIDATES)
    )

# JSON is UTF-8 by specification; do not rely on the platform default encoding.
with model_details_path.open('r', encoding='utf-8') as file:
    model_data = json.load(file)

In [None]:
# Build one inference wrapper per (model type, training regime) listed in model_data.
# A fine-tuned repo whose last '-'-separated segment is 'LoRA' is loaded on top of
# its original base repo (passed as the PEFT repo); any other fine-tuned repo is
# loaded directly.
gen_models_obj = {}

for model_type, train_data in model_data.items():
    gen_models_obj[model_type] = models_for_type = {}
    for train_type, train_details in train_data.items():
        ft_repo = train_details['hf-ft-model-path']
        if ft_repo.split('-')[-1] != 'LoRA':
            init_kwargs = {'hf_model_repo_name': ft_repo}
        else:
            init_kwargs = {
                'hf_model_repo_name': train_details['hf-org-model-path'],
                'is_lora': True,
                'peft_model_repo_name': ft_repo,
            }
        # Every variant is pinned to a specific commit of its repo.
        models_for_type[train_type] = HFModelForInferencing(
            hf_commit_hash=train_details['hf-commit-id'],
            **init_kwargs,
        )

In [None]:
# Index of the example to inspect; must be smaller than samples_per_split used when loading.
i = 3

example = soda_ds[d_type][i]
narrative_lines = example['narrative'].split("\n")

# Line 0 of the augmented narrative is the story itself.
narrative = narrative_lines[0]
actual_dialogue = example['dialogue']

# Line 1 is presumably "<label>: name1, name2, ..." — keep the text after the last ':',
# drop every '.', then split on commas.
characters_field = narrative_lines[1].split(':')[-1].replace('.', '')
characters = [name.strip() for name in characters_field.split(',')]

for label, value in (
    ("Narrative:", narrative),
    ("Characters:", characters),
    ("Actual Dialogue:", actual_dialogue),
):
    print(label, value, "-" * 50, sep="\n")

In [None]:
# Generate a dialogue for the selected example with every loaded model variant.
# Sampling is enabled, so seed before each call to make results reproducible
# on Restart & Run All.
# NOTE(review): assumes HFModelForInferencing runs PyTorch models (it accepts
# PEFT/LoRA repos) and forwards generation_kwargs to generate(); confirm.
import torch

GENERATION_SEED = 42

# Sampling parameters shared by both generation modes; only max_new_tokens differs.
SAMPLING_KWARGS = {
    "no_repeat_ngram_size": 3,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
}

gen_output = {}

for model_type, train_data in gen_models_obj.items():
    gen_output[model_type] = {}
    # Model types whose first '-'-separated segment is "T5" get a task prefix.
    prefix_prompt = "generate dialogue: " if model_type.split(
        '-')[0] == "T5" else None
    for train_type, gen_obj in train_data.items():
        # Re-seed per call so each model's output does not depend on how many
        # models sampled before it.
        torch.manual_seed(GENERATION_SEED)
        if train_type == "turn-by-turn":
            # Turn-by-turn: story line only; short outputs per turn.
            gen_output[model_type][train_type] = gen_obj.generate_dialogue(
                input_text=narrative,
                tokenizer_max_length=900,
                prefix_prompt=prefix_prompt,
                gen_turn_by_turn=True,
                max_turns=5,
                characters=characters,
                # Fresh dict per call in case the callee mutates it.
                generation_kwargs={"max_new_tokens": 128, **SAMPLING_KWARGS}
            )
        else:
            # Whole-dialogue generation from the full augmented narrative.
            # NOTE(review): the full narrative (story + extra lines) is tokenized with
            # tokenizer_max_length=128 while stories may reach max_story_length=250
            # (units defined by SODADataLoader) — confirm the input is not truncated.
            gen_output[model_type][train_type] = gen_obj.generate_dialogue(
                input_text=soda_ds[d_type][i]['narrative'],
                tokenizer_max_length=128,
                prefix_prompt=prefix_prompt,
                characters=characters,
                generation_kwargs={"max_new_tokens": 900, **SAMPLING_KWARGS}
            )

In [None]:
# Show each generated dialogue under a header naming its model type and training regime.
print("\nGenerated Dialogues:\n")

divider = "-" * 50
for model_type, outputs_by_train_type in gen_output.items():
    for train_type, generated_dialogue in outputs_by_train_type.items():
        header_lines = (
            divider,
            f"Model Type: {model_type}",
            f"Train Type: {train_type}",
            divider,
        )
        print(*header_lines, sep="\n")
        print(generated_dialogue)