In [None]:
import cohere
import re
import os

In [None]:
# Read the Cohere API key from the local `api_key` file (so the secret is not
# hardcoded in the notebook) and build the shared client used by call().
with open('api_key', 'r') as file:
    key = file.read().strip()  # the `with` block closes the file itself

# Fail fast with a readable message instead of an opaque auth error on the
# first generation request.
if not key:
    raise ValueError("The 'api_key' file is empty; put your Cohere API key in it.")

co = cohere.Client(key)

In [None]:
def call(prompt, mtok, temp, k, p, seed):
    """Request one generation from the `command-r-plus` model via the module-level client `co`.

    Args:
        prompt: text prompt sent to the model.
        mtok: value passed as `max_tokens`.
        temp: value passed as `temperature`.
        k: value passed as `k` (top-k).
        p: value passed as `p` (top-p).
        seed: value passed as `seed`.

    Returns:
        Whatever `co.generate` returns (the raw response object).
    """
    request = dict(
        model='command-r-plus',
        prompt=prompt,
        max_tokens=mtok,
        temperature=temp,
        k=k,
        p=p,
        seed=seed,
    )
    return co.generate(**request)

# ANSI escape-sequence helpers: each wraps text in an SGR code followed by a reset.
_ANSI_RESET = "\033[0m"


def _ansi_wrap(code, text):
    """Return `text` between the SGR `code` and a reset (non-strings are rendered via str.format)."""
    return "\033[{}m{}{}".format(code, text, _ANSI_RESET)


def red(text):
    """Wrap `text` in the red foreground code."""
    return _ansi_wrap("31", text)


def print_red(text):
    """Print `text` in red."""
    print(red(text))


def blue(text):
    """Wrap `text` in the blue foreground code."""
    return _ansi_wrap("34", text)


def print_blue(text):
    """Print `text` in blue."""
    print(blue(text))


def green(text):
    """Wrap `text` in the green foreground code."""
    return _ansi_wrap("32", text)


def print_green(text):
    """Print `text` in green."""
    print(green(text))


def yellow(text):
    """Wrap `text` in the yellow foreground code."""
    return _ansi_wrap("33", text)


def print_yellow(text):
    """Print `text` in yellow."""
    print(yellow(text))


def bold(text):
    """Wrap `text` in SGR 1;4, i.e. bold *and* underlined despite the name."""
    return _ansi_wrap("1;4", text)

# is numeric logic
def is_int(value):
    """Return True if `value` converts to an int without losing information.

    Strings must be integer literals ("5", "-3"); "5.0" or "1e3" are rejected
    because int() cannot parse them. Floats count only when they are whole.
    Anything float()/int() cannot convert (None, "abc", inf, nan) gives False.
    """
    try:
        return float(value) == int(value)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "not an int"; a bare except would also
        # swallow KeyboardInterrupt and SystemExit.
        return False


def is_float(value):
    """Return True if float(value) succeeds.

    Note that float() accepts "nan" and "inf", so those strings return True;
    callers that need a finite number must check that separately.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        return False
    return True

# Input normalisation helpers used for parameter parsing and file names.
def remove_spaces(text):
    """Strip surrounding whitespace, then delete every remaining ' ' character.

    Only the space character is removed from the interior; tabs or newlines
    between other characters are left in place.
    """
    return "".join(text.strip().split(" "))


# Matches any run of characters that is not an ASCII letter, digit, or space.
_NON_ALNUM_SPACE = re.compile(r'[^A-Za-z0-9 ]+')


def remove_schars(text):
    """Return `text` keeping only ASCII letters, digits, and spaces."""
    return _NON_ALNUM_SPACE.sub('', text)

In [None]:
# Greeting shown once when this cell starts.
for _greeting in (
    'Hello! This application lets you practice crafting prompts and save the zaniest stories you come across!',
    'You may exit at any time by inputting exit when prompted.',
):
    print_blue(_greeting)
print_yellow('Please enter a prompt: ')

# Session flags read by exit() and the interactive loop:
#   saved  - True while there is no unsaved generation; starts True so that
#            exiting before any generation skips the save offer.
#   status - set to False once the user asks to exit.
saved, status = True, True

# Saved stories are written under ./saved/, so make sure it exists up front.
os.makedirs('saved', exist_ok=True)

def exit(saved):
    if not saved:
        print_yellow(bold('Would you like to save your latest work? ' + red('This is the last chance.' +  yellow('(y/n)'))))
        yn = input().strip().lower()
        match yn:
            case 'y':
                print_yellow('Please name your file, ' + bold('existing files will be overritten:'))
                fname = remove_schars(remove_spaces(input())) # disallows any spaces/special characters
                path = 'saved/output_'+fname+'.txt'
                with open(path, 'w') as file:
                    for val in inputs:
                        file.write(val + '\n')
                    file.write(output.generations[0].text)
                    file.close()
                    saved = True
                    print_green(f'Save to {path} successful.\n')
            case 'n':
                pass
            case _:
                print_red('Sorry, the given input is invalid.')
    print_green('Have a good day!')

while True:
    prompt = input()
    match prompt.strip().lower():
        case 'exit':
            exit(saved)
            status = False
            break
        case '':
            print_red('Oops! The prompt you gave was empty.')
            print_yellow('Please try entering a prompt:\n')
            continue

    # this array allows us to avoid more complicated if statements when parsing validity
    acceptable_forms = [['mtok=', 'is_int(value)'],
                        ['temp=', 'is_float(value)'],
                        ['k=', 'is_int(value)'],
                        ['p=', 'is_float(value)'],
                        ['seed=', 'is_int(value)']]
    
    while True:
        used = [False] * 5
        kargs = {
            'prompt':prompt,
            'mtok':800,
            'temp':0.9,
            'k':5,
            'p':0.8,
            'seed':123
        }
        valid = 0
    
        print_blue('\nThat prompt is valid, are there any parameters you would like to set?')
        print_yellow('Max tokens defaults at 800\nTemperature defaults to 0.9\nTop-k defaults to 5\nTop-p defaults to 0.8\nSeed defaults to 123')
        print_yellow(bold('Inputs should use form "keyword=value" for keywords (mtok, temp, k, p, seed) in a comma seperated entry.') + yellow('\nParameters k, mtok, and seed must be integers, all other parameters must be floats.'))
        print_yellow('Leave blank if you wish to use the default parameters:')
        param = input()
        if param == 'exit':
            exit(saved)
            status = False
            break
        elif param != '': # if arguments were defined, parse them with ',' delimiter
            args = param.split(',')
            
            for k,arg in enumerate(args):
                if valid == -1: # check if last iter encountered an error
                    print('error')
                    break
                args[k] = remove_spaces(arg)
                arg = args[k]
                
                for i,n in enumerate(acceptable_forms):
                    key = arg[0:len(n[0])]
                    value = arg[len(n[0]):] # value in str type
                    key_cond = key == n[0] # validates key
                    type_cond = eval(n[1]) # validates type of value
                    typ = n[1][3:].replace('(value)', '') # appropriate type
                    
                    if key_cond and type_cond: # compares each parsed argument start with list of acceptable keywords
                        value = eval(n[1][3:]) # casts value to appropriate type
                        if used[i]: # if parameter has been defined already
                            print_red('Each parameter can be defined only once.')
                            break
                        if value < 0: # if parameter value is negative
                            print_red('All parameters must be non-negative.')
                            break
                        if (i == 1 or i == 3) and value > 1: # if temp or p value is greater than 1.0
                            print_red(f'Parameter {key} must not exceed 1.0.')
                            break
                        used[i] = True # note that parameter has been defined
                        valid = 1 # note that code has executed properly
                        kargs[arg[0:len(n[0])-1]] = value # convert all values to acceptable key:type forms
                        break
                        
                    elif key_cond and not type_cond:
                        print_red(f'Parameter value must be of type {typ}.')
                        break
                            
                if valid == 1: # breaks out of parsing loop
                    valid = 0
                    continue
                else:
                    print_red(f'Argument "{arg}" not valid, please try again:\n')
                    valid = -1 # signal error
                    break
                    
                    
        if valid != 0: # if encounter error, run arg prompting loop again
            continue
        else:
            print()
            break

    if not status:
        break
    output = call(**kargs)
    inputs = [
        'Prompt: ' + kargs['prompt'],
        'Max Tokens: ' + str(kargs['mtok']),
        'Temperature: ' + str(kargs['temp']),
        'Top-k: ' + str(kargs['k']),
        'Top-p: ' + str(kargs['p']),
        'Seed: ' + str(kargs['seed'])
    ]
    print(green(inputs))
    print(green(output.generations[0].text), '\n')
    saved = False

    while True:
        print_yellow(bold('Would you like to save your latest work? (y/n)'))
        yn = input().strip().lower()
        match yn:
            case 'y':
                print_yellow('Please name your file, ' + bold('existing files will be overritten:'))
                fname = remove_schars(remove_spaces(input())) # disallows any spaces/special characters
                path = 'saved/output_'+fname+'.txt'
                with open(path, 'w') as file:
                    for val in inputs:
                        file.write(val + '\n')
                    file.write(output.generations[0].text)
                    file.close()
                    saved = True
                    print_green(f'Save to {path} successful.\n')
                print_blue('You may exit at any time by inputting exit when prompted.')
                print_yellow('Please enter a prompt:')
                break
            case 'n':
                print_blue('You may exit at any time by inputting exit when prompted.')
                print_yellow('Please enter a prompt:')
                break
            case _:
                print_red('Sorry, the given input is invalid.')
                continue