In [None]:
from dotenv import load_dotenv

from memo import memlist, memfile, grid, time_taken
from datetime import datetime, timezone
from tqdm import tqdm
from pathlib import Path
import json
import instructor
from litellm import completion
from pydantic import BaseModel, Field
import vertexai

# Initialise the Vertex AI SDK for this kernel.
# NOTE(review): `project` and `location` are literal `...` (Ellipsis) here —
# they look like redacted placeholders and presumably need real values
# before the Gemini calls further down can work; confirm.
vertexai.init(
    project=...,
    location=...,
)


# Load variables from a .env file into the process environment
# (presumably the provider API keys/credentials used by litellm — confirm).
load_dotenv()
# Subjects the poems are written about. The generation loops pair every topic
# with every entry of MODELS. The ordering below is unchanged; only the layout
# is condensed to several topics per row.
POEM_TOPICS = (
    "dogs", "cats", "spring", "summer", "christmas", "halloween",
    "love", "loss", "family", "friends", "war", "peace",
    "nature", "city", "home", "work", "school", "music", "art", "food",
    "winter", "fall", "dreams", "life", "death", "childhood", "time", "memories",
    "happiness", "sadness", "adventures", "travels", "fantasy", "realities", "freedom",
    "night", "day", "ocean", "mountains", "stars", "moon",
    "birth", "ageing", "loneliness", "hope",
)

# Model identifiers swept by the generation loops; each one is passed as
# `model=` to litellm's `completion` (directly or through the instructor client).
MODELS = (
    # Anthropic Claude 3 models, addressed with a "bedrock/" prefix.
    "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    "bedrock/anthropic.claude-3-opus-20240229-v1:0",
    # GPT models, given without a provider prefix.
    "gpt-3.5-turbo",
    "gpt-4-turbo-2024-04-09",
    # The substring "gemini" is what later cells test for to attach the
    # Vertex safety settings and to sleep between calls.
    "gemini-1.5-pro-preview-0409",
)

In [None]:
# Harm categories for which blocking is switched off on Gemini requests.
_VERTEX_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)

# One {"category", "threshold"} entry per harm category, all set to BLOCK_NONE.
# Passed as `safety_settings` only when the model name contains "gemini".
vertex_safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in _VERTEX_HARM_CATEGORIES
]

In [None]:
from time import sleep

# Direct generation: results are kept as JSON Lines (one JSON object per line).
results_fpath = Path("../data/results.jsonl")

# Start from whatever earlier runs already stored; an absent file means no results yet.
data = (
    list(map(json.loads, results_fpath.read_text().splitlines()))
    if results_fpath.exists()
    else []
)

@memfile(filepath=results_fpath)
@memlist(data=data)
@time_taken()
def create_ABBA_rhyme_directly(model, topic):
    """Ask `model` for a four-line ABBA poem about `topic` via a plain chat completion.

    Args:
        model: Model identifier handed to litellm's `completion`.
        topic: Subject of the poem; interpolated into the user prompt.

    Returns:
        A dict with the raw message content under "poem", the current UTC
        time in ISO format under "datetime", and "format" set to "direct_ask".

    NOTE(review): the memo decorators presumably store each call's keyword
    arguments together with this dict in `data` and in `results_fpath` — the
    loop that follows reads "model"/"topic"/"format" from stored records.
    Confirm against the memo documentation.
    """
    prompt = f"Create a 4 line poem in ABBA rhyme scheme about the following topic: {topic}. Return only the poem."

    # TODO: only needed for vertex_ai/gemini-pro
    safety = vertex_safety_settings if "gemini" in model else None

    response = completion(
        model=model,
        messages=[{"content": prompt, "role": "user"}],
        safety_settings=safety,
    )

    return {
        "poem": response.choices[0].message.content,
        "datetime": datetime.now(timezone.utc).isoformat(),
        "format": "direct_ask",
    }

# Decide up front which (model, topic) pairs still need a direct-ask poem, so
# the tqdm bar only counts real work and `data` is scanned once rather than
# once per grid item.
#
# Fix: the skip check now matches this cell's own "direct_ask" records. It
# previously tested for "with_instructor", so finished direct poems were never
# skipped and direct poems were wrongly skipped whenever an instructor record
# existed for the same pair.
# `.get` is used so a stored record lacking one of the keys is ignored
# instead of raising KeyError.
done_direct = {
    (record.get("model"), record.get("topic"))
    for record in data
    if record.get("format") == "direct_ask"
}
pending_direct = [
    settings
    for settings in grid(model=MODELS, topic=POEM_TOPICS)
    if (settings["model"], settings["topic"]) not in done_direct
]

for settings in tqdm(pending_direct):
    create_ABBA_rhyme_directly(**settings)
    if "gemini" in settings["model"]:
        # Pause after each Gemini call (presumably to stay under a rate limit — confirm).
        sleep(10)

In [None]:
# Structured response schema for one ABBA poem, used as `response_model`
# with the instructor client.
#
# Each poem line is split into two fields, declared rhyme word first and
# sentence second, in the order A1, B1, B2, A2. The Field descriptions spell
# out which rhyme words should and should not rhyme with each other.
#
# NOTE(review): this is deliberately documented with comments rather than a
# class docstring — pydantic is understood to copy a model docstring into the
# generated JSON schema, which would change what is sent to the LLM; confirm
# before adding one. Field order and description text should likewise be
# treated as part of the prompt.
class Poem(BaseModel):
    # Line 1 (A rhyme).
    rhyme_word_A1: str = Field(description = "The word that sentence A1 ends with.")
    sentence_A1: str = Field(description = "The sentence that ends with the rhyme word A1.")
    # Line 2 (B rhyme) — must differ in rhyme from line 1.
    rhyme_word_B1: str = Field(description = "The word that sentence B1 ends with. It SHOULD NOT rhyme with rhyme_word_A1.")
    sentence_B1: str = Field(description = "The sentence that ends with the rhyme word B1.")
    # Line 3 (B rhyme) — rhymes with line 2, not with line 1.
    rhyme_word_B2: str = Field(description = "The word that sentence B2 ends with. It SHOULD rhyme with rhyme_word_B1. It SHOULD NOT rhyme with rhyme_word_A1.")
    sentence_B2: str = Field(description = "The sentence that ends with the rhyme word B2.")
    # Line 4 (A rhyme) — rhymes with line 1, not with line 3.
    rhyme_word_A2: str = Field(description = "The word that sentence A2 ends with. It SHOULD rhyme with rhyme_word_A1. It SHOULD NOT rhyme with rhyme_word_B2.")
    sentence_A2: str = Field(description = "The sentence that ends with the rhyme word A2.")

# instructor wrapper around litellm's `completion`; called below with
# `response_model=Poem`.
client = instructor.from_litellm(completion)

In [None]:
# Instructor-based generation: results are kept as JSON Lines (one object per line).
# NOTE(review): this path is "results.jsonl" in the working directory, whereas the
# direct-ask cell uses "../data/results.jsonl" — confirm the two are meant to differ.
results_fpath = Path("results.jsonl")

# Start from whatever earlier runs already stored; an absent file means no results yet.
data = (
    list(map(json.loads, results_fpath.read_text().splitlines()))
    if results_fpath.exists()
    else []
)

@memfile(filepath=results_fpath)
@memlist(data=data)
@time_taken()
def create_ABBA_rhyme_with_instructor(model, topic):
    """Ask `model` for a four-line ABBA poem about `topic` as a structured `Poem`.

    The request goes through the instructor client with `response_model=Poem`;
    the four sentence fields are then joined with newlines in A1, B1, B2, A2 order.

    Args:
        model: Model identifier handed to the instructor/litellm client.
        topic: Subject of the poem; interpolated into the user prompt.

    Returns:
        A dict with the assembled text under "poem", the current UTC time in
        ISO format under "datetime", and "format" set to "with_instructor".

    NOTE(review): the memo decorators presumably store each call's keyword
    arguments together with this dict in `data` and in `results_fpath` — the
    loop that follows reads "model"/"topic"/"format" from stored records.
    Confirm against the memo documentation.
    """
    prompt = f"Create a 4 line poem in ABBA rhyme scheme about the following topic: {topic}. Return only the poem."

    # Safety settings are only supplied for Gemini models; every other model gets None.
    safety = vertex_safety_settings if "gemini" in model else None

    structured = client.chat.completions.create(
        model=model,
        # max_tokens=1024,
        messages=[{"content": prompt, "role": "user"}],
        response_model=Poem,
        safety_settings=safety,
    )

    # Reassemble the poem from the sentence fields only; the rhyme-word fields are not used here.
    sentences_in_order = (
        structured.sentence_A1,
        structured.sentence_B1,
        structured.sentence_B2,
        structured.sentence_A2,
    )

    return {
        "poem": "\n".join(sentences_in_order),
        "datetime": datetime.now(timezone.utc).isoformat(),
        "format": "with_instructor",
    }

# Decide up front which (model, topic) pairs still need an instructor poem, so
# the tqdm bar only counts real work and `data` is scanned once rather than
# once per grid item.
# `.get` is used so a stored record lacking one of the keys is ignored
# instead of raising KeyError.
done_instructor = {
    (record.get("model"), record.get("topic"))
    for record in data
    if record.get("format") == "with_instructor"
}
pending_instructor = [
    settings
    for settings in grid(model=MODELS, topic=POEM_TOPICS)
    if (settings["model"], settings["topic"]) not in done_instructor
]

for settings in tqdm(pending_instructor):
    create_ABBA_rhyme_with_instructor(**settings)
    if "gemini" in settings["model"]:
        # Pause after each Gemini call (presumably to stay under a rate limit — confirm).
        sleep(10)