# Example of prior elicitation for a dataset

In [1]:
# import the necessary functions and classes
from llm_elicited_priors.utils import load_prompts
from llm_elicited_priors.gpt import (
    GPTOutputs, get_llm_elicitation_for_dataset
)
from llm_elicited_priors.datasets import load_breast_cancer
import numpy as np

In [2]:
# Language-model wrapper used for elicitation, and the keyword arguments
# it is constructed with (see llm_elicited_priors.gpt for their meaning).
CLIENT_CLASS = GPTOutputs
CLIENT_KWARGS = {
    "temperature": 0.1,
    "model_id": "gpt-3.5-turbo-0125",
    # NOTE(review): presumably forwarded to the completion request so the
    # model replies with a JSON object -- confirm in llm_elicited_priors.gpt
    "result_args": {
        "response_format": {"type": "json_object"},
    },
}

In [3]:
# Load the breast cancer dataset object. Its `feature_names` and
# `target_names` attributes are passed to the elicitation call further
# down, alongside the data itself.
dataset = load_breast_cancer()

In [4]:
# Read the prompt files for the breast cancer task: one file holds the
# system-role descriptions, the other the user-role descriptions.
# Both results are sliced later, so each is a sequence of prompts.
system_roles = load_prompts(
    "prompts/elicitation/system_roles_breast_cancer.txt"
)
user_roles = load_prompts(
    "prompts/elicitation/user_roles_breast_cancer.txt"
)

In [5]:
# Keep only the first two system prompts and the first two user prompts
# so the demonstration stays short: 2 x 2 = 4 prompt combinations are
# sent to the language model instead of the full set.
system_roles, user_roles = system_roles[:2], user_roles[:2]

In [6]:
# Instantiate the language-model client from the wrapper class and the
# keyword arguments configured earlier in the notebook.
client = CLIENT_CLASS(**CLIENT_KWARGS)

In [7]:
#### elicit the priors for the dataset ####

# Feature names as a plain Python list (the dataset stores them as an array).
feature_name_list = dataset.feature_names.tolist()

# Map each target name to its position in `dataset.target_names`,
# i.e. {first_name: 0, second_name: 1, ...}.
label_to_index = {
    label: index for index, label in enumerate(dataset.target_names)
}

# One elicitation is run per (system role, user role) combination;
# verbose=True prints the prompts before they are sent to the model.
expert_priors = get_llm_elicitation_for_dataset(
    client=client,
    system_roles=system_roles,
    user_roles=user_roles,
    feature_names=feature_name_list,
    target_map=label_to_index,
    verbose=True,
)

Getting priors for 4 combinations:   0%|          | 0/4 [00:00<?, ?it/s]

System role 
 --------- 
 
You are a simulator of a logistic regression predictive model 
for predicting breast cancer diagnosis from tumour characteristics.
Here the inputs are tumour characteristics and the output is 
the probability of breast cancer diagnosis from tumour characteristics. 
Specifically, the targets are benign or malignant with mapping
'benign' = 0 and 'malignant' = 1.
With your best guess, you can provide the probabilities of a malignant 
breast cancer diagnosis for the given tumour characteristics. 
User query 
 --------- 
 
I am a data scientist with a dataset and the task: predicting breast 
cancer diagnosis from tumour characteristics. 
I would like to use your model to predict the diagnosis of my samples.
I have a dataset that is made up of the following features:
['mean radius', 'mean texture', 'mean perimeter', 'mean area', 'mean smoothness', 'mean compactness', 'mean concavity', 'mean concave points', 'mean symmetry', 'mean fractal dimension', 'radius error',

Getting priors for 4 combinations:  25%|██▌       | 1/4 [00:09<00:27,  9.18s/it]



matched features:
mean radius: mean radius: [0.5, 0.1]
mean texture: mean texture: [-0.3, 0.05]
mean perimeter: mean perimeter: [0.4, 0.08]
mean area: mean area: [0.6, 0.12]
mean smoothness: mean smoothness: [-0.2, 0.04]
mean compactness: mean compactness: [0.3, 0.06]
mean concavity: mean concavity: [0.4, 0.07]
mean concave points: mean concave points: [0.5, 0.09]
mean symmetry: mean symmetry: [-0.1, 0.03]
mean fractal dimension: mean fractal dimension: [-0.05, 0.02]
radius error: radius error: [0.2, 0.05]
texture error: texture error: [-0.1, 0.03]
perimeter error: perimeter error: [0.3, 0.06]
area error: area error: [0.4, 0.07]
smoothness error: smoothness error: [-0.1, 0.03]
compactness error: compactness error: [0.2, 0.04]
concavity error: concavity error: [0.3, 0.06]
concave points error: concave points error: [0.4, 0.07]
symmetry error: symmetry error: [-0.1, 0.03]
fractal dimension error: fractal dimension error: [-0.05, 0.02]
worst radius: worst radius: [0.6, 0.1]
worst textur

Getting priors for 4 combinations:  50%|█████     | 2/4 [00:15<00:15,  7.52s/it]



matched features:
mean radius: mean radius: [0.75, 0.15]
mean texture: mean texture: [0.25, 0.1]
mean perimeter: mean perimeter: [0.8, 0.2]
mean area: mean area: [0.7, 0.18]
mean smoothness: mean smoothness: [-0.4, 0.12]
mean compactness: mean compactness: [0.6, 0.16]
mean concavity: mean concavity: [0.65, 0.14]
mean concave points: mean concave points: [0.7, 0.15]
mean symmetry: mean symmetry: [-0.3, 0.1]
mean fractal dimension: mean fractal dimension: [-0.2, 0.08]
radius error: radius error: [0.4, 0.1]
texture error: texture error: [0.15, 0.06]
perimeter error: perimeter error: [0.45, 0.12]
area error: area error: [0.5, 0.14]
smoothness error: smoothness error: [-0.25, 0.08]
compactness error: compactness error: [0.35, 0.1]
concavity error: concavity error: [0.4, 0.1]
concave points error: concave points error: [0.45, 0.12]
symmetry error: symmetry error: [-0.2, 0.07]
fractal dimension error: fractal dimension error: [-0.15, 0.06]
worst radius: worst radius: [0.8, 0.2]
worst textur

Getting priors for 4 combinations:  75%|███████▌  | 3/4 [00:21<00:06,  6.89s/it]



matched features:
mean radius: mean radius: [0.5, 0.1]
mean texture: mean texture: [0.3, 0.05]
mean perimeter: mean perimeter: [0.6, 0.1]
mean area: mean area: [0.7, 0.15]
mean smoothness: mean smoothness: [-0.2, 0.05]
mean compactness: mean compactness: [0.4, 0.1]
mean concavity: mean concavity: [0.6, 0.1]
mean concave points: mean concave points: [0.7, 0.1]
mean symmetry: mean symmetry: [-0.1, 0.05]
mean fractal dimension: mean fractal dimension: [-0.3, 0.05]
radius error: radius error: [0.4, 0.1]
texture error: texture error: [0.2, 0.05]
perimeter error: perimeter error: [0.5, 0.1]
area error: area error: [0.6, 0.1]
smoothness error: smoothness error: [-0.1, 0.05]
compactness error: compactness error: [0.3, 0.1]
concavity error: concavity error: [0.4, 0.1]
concave points error: concave points error: [0.5, 0.1]
symmetry error: symmetry error: [-0.1, 0.05]
fractal dimension error: fractal dimension error: [-0.2, 0.05]
worst radius: worst radius: [0.6, 0.1]
worst texture: worst textu

Getting priors for 4 combinations: 100%|██████████| 4/4 [00:27<00:00,  6.94s/it]



matched features:
mean radius: mean radius: [0.75, 0.15]
mean texture: mean texture: [0.25, 0.1]
mean perimeter: mean perimeter: [0.8, 0.2]
mean area: mean area: [0.7, 0.18]
mean smoothness: mean smoothness: [-0.3, 0.08]
mean compactness: mean compactness: [0.6, 0.12]
mean concavity: mean concavity: [0.65, 0.14]
mean concave points: mean concave points: [0.7, 0.16]
mean symmetry: mean symmetry: [-0.2, 0.06]
mean fractal dimension: mean fractal dimension: [-0.15, 0.05]
radius error: radius error: [0.4, 0.1]
texture error: texture error: [0.1, 0.04]
perimeter error: perimeter error: [0.45, 0.11]
area error: area error: [0.5, 0.13]
smoothness error: smoothness error: [-0.25, 0.07]
compactness error: compactness error: [0.35, 0.09]
concavity error: concavity error: [0.4, 0.1]
concave points error: concave points error: [0.45, 0.11]
symmetry error: symmetry error: [-0.15, 0.05]
fractal dimension error: fractal dimension error: [-0.1, 0.04]
worst radius: worst radius: [0.85, 0.2]
worst tex




In [9]:
# Combine the per-combination priors into a single array for display;
# np.stack requires every elicited prior to have the same shape.
# NOTE(review): each row looks like a [mean, std] pair, with a leading
# [0, 1] row that is presumably the intercept -- confirm against the library.
stacked_priors = np.stack(expert_priors)
print("Elicited priors:", stacked_priors, sep="\n")

Elicited priors:
[[[ 0.    1.  ]
  [ 0.5   0.1 ]
  [-0.3   0.05]
  [ 0.4   0.08]
  [ 0.6   0.12]
  [-0.2   0.04]
  [ 0.3   0.06]
  [ 0.4   0.07]
  [ 0.5   0.09]
  [-0.1   0.03]
  [-0.05  0.02]
  [ 0.2   0.05]
  [-0.1   0.03]
  [ 0.3   0.06]
  [ 0.4   0.07]
  [-0.1   0.03]
  [ 0.2   0.04]
  [ 0.3   0.06]
  [ 0.4   0.07]
  [-0.1   0.03]
  [-0.05  0.02]
  [ 0.6   0.1 ]
  [-0.4   0.08]
  [ 0.5   0.09]
  [ 0.7   0.12]
  [-0.3   0.06]
  [ 0.4   0.07]
  [ 0.5   0.09]
  [ 0.6   0.1 ]
  [-0.2   0.05]
  [-0.1   0.03]]

 [[ 0.    1.  ]
  [ 0.75  0.15]
  [ 0.25  0.1 ]
  [ 0.8   0.2 ]
  [ 0.7   0.18]
  [-0.4   0.12]
  [ 0.6   0.16]
  [ 0.65  0.14]
  [ 0.7   0.15]
  [-0.3   0.1 ]
  [-0.2   0.08]
  [ 0.4   0.1 ]
  [ 0.15  0.06]
  [ 0.45  0.12]
  [ 0.5   0.14]
  [-0.25  0.08]
  [ 0.35  0.1 ]
  [ 0.4   0.1 ]
  [ 0.45  0.12]
  [-0.2   0.07]
  [-0.15  0.06]
  [ 0.8   0.2 ]
  [ 0.3   0.1 ]
  [ 0.85  0.22]
  [ 0.75  0.18]
  [-0.35  0.1 ]
  [ 0.55  0.15]
  [ 0.6   0.14]
  [ 0.65  0.15]
  [-0.25  0.1 ]
  [-0