In [None]:
import cohere
import json
import io
import warnings

import pandas as pd
from IPython.display import display
from PIL import Image

from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# API clients: Cohere (text generation) and Stability AI (image generation)

In [None]:
# Read the API keys from files kept outside the repo. .strip() drops the
# trailing newline most editors append; left in, it becomes part of the key
# and the auth header is rejected.
with open('../cohere_api_key.txt', 'r') as f:
    cohere_api_key = f.read().strip()
co = cohere.Client(cohere_api_key)

with open('../stability_api_key.txt', 'r') as f:
    stability_api_key = f.read().strip()
stability_api = client.StabilityInference(
    key=stability_api_key,
    verbose=True,
)

# The clients hold the keys now; don't leave raw secrets in the kernel namespace.
del cohere_api_key
del stability_api_key

# EDA: inspect the sample stories

In [None]:
# Story files are assumed to be UTF-8 (the standard JSON encoding, RFC 8259).
# Passing the encoding explicitly makes decoding independent of the OS locale.
with open('../stories/hansel_and_gretel.json', encoding='utf-8') as f:
    data = json.load(f)
text = data['text']
print(type(text), len(text))

In [None]:
# Story files are assumed to be UTF-8 (the standard JSON encoding, RFC 8259).
# Passing the encoding explicitly makes decoding independent of the OS locale.
with open('../stories/red_riding_hood.json', encoding='utf-8') as f:
    data = json.load(f)
text = data['text']
print(type(text), len(text))

In [None]:
# Decode as UTF-8 explicitly instead of with the locale default codec.
# errors='ignore' is kept as the original best-effort. Any undecodable bytes are
# silently dropped, which would shorten `text`.
# NOTE(review): check whether it is still needed with UTF-8 decoding.
with open('../stories/three_little_pigs.json', encoding='utf-8', errors='ignore') as f:
    data = json.load(f)
text = data['text']
print(type(text), len(text))

# Generating stories

In [None]:
def generate(prompt, model="xlarge",
             num_generations=5, temperature=0.7,
             max_tokens=2000, stop_sequences=None, client=None):
    """Sample several Cohere completions for a prompt and rank them by likelihood.

    Args:
        prompt: Text prompt sent to the model.
        model: Cohere model name passed to ``generate``.
        num_generations: Number of completions requested in a single call.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens per completion.
        stop_sequences: Strings that end a completion. ``None`` (the default)
            means ``['<end>']``. A fresh list is built on every call rather
            than sharing one mutable default list between calls.
        client: Object exposing Cohere's ``generate`` method. ``None`` (the
            default) uses the notebook-level ``co`` client. Pass another
            client (or a stub) to reuse the ranking without the global.

    Returns:
        pandas.DataFrame with columns ``generation`` (completion text) and
        ``likelihood`` (sum of the per-token likelihoods returned by the API).
        Duplicate texts are dropped (first kept). Rows are sorted by
        ``likelihood`` descending and re-indexed 0..n-1, so row 0 is the
        highest-scoring completion.

    Side effects:
        Sets ``pd.options.display.max_colwidth = 200`` globally, so longer
        generations show more text when a frame is displayed.
    """
    if stop_sequences is None:
        stop_sequences = ['<end>']
    if client is None:
        client = co

    prediction = client.generate(
        model=model,
        prompt=prompt,
        # Request per-token likelihoods for the generated text so the
        # candidates can be scored below.
        return_likelihoods='GENERATION',
        stop_sequences=stop_sequences,
        max_tokens=max_tokens,
        temperature=temperature,
        num_generations=num_generations)

    # Walk the generations exactly once, in case the container is a one-shot iterable.
    gens = []
    likelihoods = []
    for gen in prediction.generations:
        gens.append(gen.text)
        # NOTE(review): the API's token likelihoods are presumably log-probs
        # (<= 0). If so, this sum favours shorter completions. Confirm before
        # comparing outputs of very different lengths.
        likelihoods.append(sum(t.likelihood for t in gen.token_likelihoods))

    pd.options.display.max_colwidth = 200
    df = pd.DataFrame({'generation': gens, 'likelihood': likelihoods})
    df = df.drop_duplicates(subset=['generation'])
    df = df.sort_values('likelihood', ascending=False, ignore_index=True)
    return df

In [None]:
def generate_image(image_prompt):
    """Generate images for a text prompt with Stability AI and display them inline.

    Args:
        image_prompt: Text prompt passed to ``stability_api.generate``.

    Returns:
        list of PIL images, one per image artifact received, in arrival order.
        Each is also displayed as it arrives. The list is empty if no image
        artifact came back. End the call with ``;`` in a notebook cell to
        avoid echoing the list.

    Warns:
        UserWarning when an artifact reports ``generation.FILTER``, i.e. the
        safety filter triggered.
    """
    images = []
    # stability_api.generate returns a generator; iterating it yields the API responses.
    for resp in stability_api.generate(prompt=image_prompt):
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # Trailing space matters: the two literals are concatenated.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
                display(img)
                images.append(img)
    return images


In [None]:
# Ask Cohere for candidate stories (generate's default count) for a short prompt,
# capping each at 1000 tokens. The result comes back ranked by likelihood.
result = generate(prompt="short fairy tale about a girl", max_tokens=1000)

In [None]:
# DataFrame.info() writes its summary to stdout itself and returns None.
# Wrapping it in print() only appended a stray "None" line.
result.info()

In [None]:
# Show the highest-likelihood story (row label 0 after generate's re-indexing).
# print() renders its newlines instead of an escaped repr.
print(result.loc[0, 'generation'])

In [1]:
# Sample story text used to check character length in the next cell.
# NOTE(review): it appears to be a generation cut off mid-word ("a big bad wo").
# The leading/trailing newlines inside the triple quotes are part of the string.
test_string = """
Once upon a time, there was a farmer with three little pigs. He did not have enough food to take care of his pigs so he sent them away to take care of themselves. The first little pig was walking on the road. Suddenly he saw a man with some straws. He could  build a house with the straws, he said to himself and asked the man to give him his straws. The man was kind so he gave the first little pig his straws. The pig used the straws to build a straw house and danced around. Suddenly, a big bad wo
"""

In [2]:
# Character count of the sample, including the leading and trailing newline
# that the triple-quoted literal places around the text.
len(test_string)

502