In [None]:
# Install packages into the kernel's environment (%pip targets the running kernel).
# NOTE(review): none of these are imported in the visible cells; presumably they
# are needed by the external script referenced by LLAMA_RUNNER — confirm.
# NOTE(review): versions are unpinned; consider pinning for reproducibility.
%pip install torch
%pip install fairscale
%pip install fire
%pip install sentencepiece

In [35]:
import subprocess
import json
import os
from dotenv import load_dotenv

# Load variables from a .env file into the process environment (python-dotenv).
load_dotenv()

# Path of the JSON file that llama() and gpt() write their input to before
# launching a runner. A missing variable raises KeyError here.
LLM_DATA = os.environ['LLM_DATA']

# Script paths handed to the subprocess: LLAMA_RUNNER is launched through
# `torchrun`, GPT_RUNNER through `python3`.
LLAMA_RUNNER = os.environ['LLAMA_RUNNER']
GPT_RUNNER = os.environ['GPT_RUNNER']


'''
LLAMA PIPELINE
'''

def llama(input):
    """Run the Llama runner script on ``input`` and return its stdout.

    ``input`` is serialised as indented JSON to the file named by ``LLM_DATA``
    (overwriting it), then ``LLAMA_RUNNER`` is launched through ``torchrun``
    as a single process. The runner's stderr and exit status are printed;
    a non-zero exit status is reported but not raised.

    Returns:
        The captured stdout of the runner process, as text.
    """
    # Hand the payload to the runner through the shared JSON file.
    with open(LLM_DATA, 'w') as payload_file:
        json.dump(input, payload_file, indent=4)

    command = [
        'torchrun',
        '--nproc_per_node=1',
        LLAMA_RUNNER,
        '--max_seq_len=512',
        '--max_batch_size=6',
    ]
    completed = subprocess.run(command, capture_output=True, text=True)

    print("Errors:", completed.stderr)
    print("Return Code:", completed.returncode)

    return completed.stdout

def parse_llama_result(data:str, begin_key=None, end_key=None, inclusive_begin=True):
    """Extract the model response from raw runner output and optionally trim it.

    The response span runs from the ``[@RESPONSEBEGIN]`` marker up to (but not
    including) the ``[@RESPONSEEND]`` marker. The begin marker itself is part
    of the span, so pass ``begin_key`` to trim it away. A marker or key that
    does not occur in the text is ignored instead of being used as a slice
    bound, so a missing key never truncates or empties the result.

    Args:
        data: Raw text produced by the runner.
        begin_key: If given and present, the result starts at its first
            occurrence inside the response span.
        end_key: If given and present after the start position, the result
            stops just before that occurrence.
        inclusive_begin: When True the result includes ``begin_key`` itself;
            when False it starts right after it.

    Returns:
        The trimmed response text.
    """
    # str.find returns -1 when the needle is absent; using -1 as a slice bound
    # would silently cut from/to the last character, so fall back explicitly.
    response_start = data.find('[@RESPONSEBEGIN]')
    if response_start == -1:
        response_start = 0
    response_end = data.find('[@RESPONSEEND]', response_start)
    if response_end == -1:
        response_end = len(data)
    story = data[response_start:response_end]

    # trims to between begin key and end key (if they exist)
    begin_index = 0
    if begin_key is not None:
        key_pos = story.find(begin_key)
        if key_pos != -1:
            begin_index = key_pos if inclusive_begin else key_pos + len(begin_key)

    end_index = len(story)
    if end_key is not None:
        # Search only after the start so an earlier occurrence of end_key
        # (e.g. inside the response marker) cannot produce an empty slice.
        key_pos = story.find(end_key, begin_index)
        if key_pos != -1:
            end_index = key_pos

    return story[begin_index:end_index]

def parse_characters(data:str):
    """Collect every name written between triangle brackets in ``data``.

    Scans left to right for ``<...>`` pairs and returns the text between the
    brackets, in order of appearance. A ``<`` with no closing ``>`` ends the
    scan instead of contributing a bogus entry, and when several ``<`` precede
    one ``>`` only the innermost pair is used.

    Args:
        data: Text expected to contain names formatted as ``<name>``.

    Returns:
        A list of the bracketed strings (empty if none are found).
    """
    characters = []
    cursor = 0
    while True:
        begin = data.find('<', cursor)
        if begin == -1:
            break

        end = data.find('>', begin + 1)
        if end == -1:
            # Unterminated '<': nothing valid can follow, stop scanning.
            break

        # A stray '<' earlier in the text must not swallow the real name;
        # start from the last '<' before this '>'.
        begin = data.rfind('<', begin, end)
        characters.append(data[begin + 1:end])

        # Resume right after the closing bracket so no tag is skipped.
        cursor = end + 1
    print(f"found characters: {characters}")
    return characters


'''
GPT PIPELINE
'''

def gpt(input):
    """Run the GPT runner script on ``input`` and return its stdout.

    ``input`` is serialised as indented JSON to the file named by ``LLM_DATA``
    (overwriting it), then ``GPT_RUNNER`` is executed with ``python3``. The
    runner's stderr and exit status are printed; a non-zero exit status is
    reported but not raised.

    Returns:
        The captured stdout of the runner process, as text.
    """
    # Hand the payload to the runner through the shared JSON file.
    with open(LLM_DATA, 'w') as payload_file:
        json.dump(input, payload_file, indent=4)

    command = ["python3", GPT_RUNNER]
    completed = subprocess.run(command, capture_output=True, text=True)

    print("Errors:", completed.stderr)
    print("Return Code:", completed.returncode)

    return completed.stdout

    

In [38]:
def llama_pipeline(prompt):
    """Generate a story, its characters, captions and dialogue with Llama.

    Makes four sequential calls to ``llama``: one to write the story, then
    three that feed the parsed story back in to obtain a character list,
    per-scene captions and per-scene dialogue. The parsed results are printed;
    nothing is returned.

    Args:
        prompt: Short description of what the story should be about.
    """

    # things needed:
    #   characters list
    #   series of captions
    #   series of dialogue and the character that says it

    def ask(content):
        # The runner takes a batch of dialogs; send one single-turn dialog.
        return llama([[{"role": "user", "content": content}]])

    # obtaining story:
    story = ask(
        f"write a story about {prompt}, breaking apart each scene into its own paragraph and label it with [SCENE <number>]. End the last paragraph with [END]"
    )
    story_parsed = parse_llama_result(story, "[SCENE", "[END]")

    # obtaining characters
    characters_raw = ask(
        f"here is a story: {story_parsed}\nGenerate a list of characters that may be in this story by providing their names between triangle brackets as such in the following format: <name>"
    )
    characters = parse_characters(characters_raw)

    # generating captions
    captions = ask(
        f"here is a story: {story_parsed}\nGenerate a caption for each scene, labelling each caption with [SCENE <number>]. End the last caption with [END]"
    )
    captions_parsed = parse_llama_result(captions, "[SCENE", "[END]")

    # generating dialogue
    # The reply is trimmed at "[END]" below, so the prompt has to ask for it.
    dialogue = ask(
        f"here is a story: {story_parsed}\nGenerate a set of dialogues ranging from 1 to 3 for each scene, labelling each dialogue with [SCENE <number>][CHARACTER <character name>]. Use characters from the following pool: {','.join(characters)}. End the last dialogue with [END]"
    )
    dialogue_parsed = parse_llama_result(dialogue, "[SCENE", "[END]")

    print(f"story_parsed: \n{story_parsed}\nEND")
    print(f"characters parsed: \n{characters}\nEND")
    print(f"captions_parsed: \n{captions_parsed}\nEND")
    print(f"dialogue_parsed: \n{dialogue_parsed}\nEND")


def gpt_pipeline(prompt):
    """Ask the GPT runner for a JSON-formatted story about ``prompt``.

    Builds a two-message chat (system role plus a user request that embeds an
    example of the expected JSON layout), sends it through ``gpt`` and prints
    the raw response. Nothing is returned.

    Args:
        prompt: Short description of what the story should be about.
    """
    # Template showing the model which fields to create and fill in.
    example = """
{
    "characters": [
        {"name": "Radke"}
    ],
    "script": [ 
        {"caption": "A sheep looking at cheese in a supermarket.", 
        "dialogue": [
            {"character": "Radke", "text": "In the mist-enshrouded hills of an ancient land, there lies a mystery as old as time itself. Behold the enigmatic sheep, creatures shrouded in the lore and legend of yesteryears."}
        ]
        }
    ]
}"""

    system_message = {"role": "system", "content": "You are a TV show writer."}
    user_message = {
        "role": "user",
        "content": f"generate a story about {prompt} by filling in a json file, remember that captions should be physical descriptions of an image, and the script should contain around 20 captions. The following json example shows all the fields you should create and fill in: {example}",
    }

    response = gpt([system_message, user_message])

    print(f'response:\n{response}')

In [39]:

# Example run: gpt_pipeline prints the runner's stderr, its return code and
# the raw response text (see the cell output below).
prompt = 'animals overrunning a village and destroying its crops'
gpt_pipeline(prompt)

Errors: 
Return Code: 0
response:
{
    "characters": [
        {"name": "James"},
        {"name": "Emily"},
        {"name": "Sheriff Hank"}
    ],
    "script": [ 
        {"caption": "A swarm of sparrows descends on a wheat field, pecking away at the golden crops.", 
        "dialogue": [
            {"character": "James", "text": "Emily, look! Our crops are under attack!"},
            {"character": "Emily", "text": "Those birds... they're eating everything! What are we going to do?"}
        ]
        },
        {"caption": "Countless deer causally munching on the green lettuce in the village's vegetable field.", 
        "dialogue": [
            {"character": "Sheriff Hank", "text": "This is the worst outbreak we've had in decades. And it's not just the birds... the deer... they're ravaging the vegetable fields."},
            {"character": "James", "text": "We need to do something soon or the entire village will starve."}
        ]
        },
        {"caption": "Night falls a