In [None]:
!pip install -U scikit-learn


In [None]:
!pip install imodelsx

# Explaining Patterns in Data with Language Models via Interpretable Autoprompting

The paper introduces iPrompt, an algorithm for generating human-interpretable natural-language explanations of data patterns using large language models (LLMs). iPrompt iteratively generates and refines candidate prompts, and the paper reports that it outperforms traditional autoprompting methods. It explains diverse datasets, including synthetic math, natural language understanding, and scientific data, sometimes surpassing human-written prompts. The approach uses the capabilities of a pre-trained LLM without requiring fine-tuning.

The original article is [here](https://arxiv.org/abs/2210.01848), and the code is adapted from [here](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/iprompt.ipynb).
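To make the "iteratively generates and refines prompts" idea concrete, here is a toy propose-and-rerank loop. It is only a sketch: the `TOY_MODEL` dictionary is a hypothetical stand-in for an LLM (each candidate prompt is mapped by hand to the behavior it would elicit), and `accuracy` and `search` are illustrative names, not part of `imodelsx` and not the paper's implementation.

```python
from typing import Callable, Dict, List, Tuple

# Toy stand-in for an LLM: each candidate prompt maps to the behavior it would elicit.
# These prompts and behaviors are hypothetical, chosen only for illustration.
TOY_MODEL: Dict[str, Callable[[int, int], int]] = {
    "Return the first number": lambda a, b: a,
    "Multiply the two numbers": lambda a, b: a * b,
    "Return the larger number": lambda a, b: max(a, b),
    "Add the two numbers": lambda a, b: a + b,
}


def accuracy(prompt: str, data: List[Tuple[int, int, int]]) -> float:
    """Fraction of examples the toy model answers correctly under this prompt."""
    fn = TOY_MODEL[prompt]
    return sum(fn(a, b) == y for a, b, y in data) / len(data)


def search(data, n_iters: int = 2, n_proposals: int = 2, top_k: int = 2) -> List[str]:
    """Alternate between proposing new candidates and keeping the best-scoring ones."""
    pool = list(TOY_MODEL)
    kept: List[str] = []
    for it in range(n_iters):
        # "Propose": take the next slice of unseen candidates.
        proposals = pool[it * n_proposals:(it + 1) * n_proposals]
        # "Rerank": score old and new candidates together, keep the best few.
        candidates = kept + proposals
        kept = sorted(candidates, key=lambda p: accuracy(p, data), reverse=True)[:top_k]
    return kept


# All pairs of single-digit numbers, labeled with their sum.
data = [(a, b, a + b) for a in range(10) for b in range(10)]
best = search(data)
print(best[0])
```

In the real algorithm, both the proposal step and the scoring step are performed by the language model itself; the fixed candidate pool here only keeps the example small and deterministic.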

In [4]:
import pandas as pd


def get_add_two_numbers_dataset(num_examples: int = None):
    """Load (input, output) string pairs for the add-two-numbers task."""
    url = 'https://raw.githubusercontent.com/SalvatoreRa/explanaibleAI/main/tabular_methods/Data/add_two.csv'
    df = pd.read_csv(url, delimiter=',')
    # remove the literal single quotes from each output string
    df['output_strings'] = df['output_strings'].str.replace("'", "")
    # optionally subsample the dataset
    if num_examples is not None:
        df = df.sample(n=num_examples)
    inputs = df['input_strings'].values
    # the CSV stores newlines as the two characters "\n"; convert them to real newlines
    outputs = [v.replace('\\n', '\n') for v in df['output_strings'].values]
    return inputs, outputs

In [5]:
%load_ext autoreload
%autoreload 2
from imodelsx import explain_dataset_iprompt

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

'Given the input numbers 0 and 4, the answer is' ' 4.\n\n'
'Given the input numbers 9 and 6, the answer is' ' 15.\n\n'
'Given the input numbers 7 and 6, the answer is' ' 13.\n\n'
'Given the input numbers 9 and 9, the answer is' ' 18.\n\n'
'Given the input numbers 6 and 2, the answer is' ' 8.\n\n'
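For readers who cannot download the CSV, the sketch below builds examples in the same string format as the output above. It is an assumption about how such data could be generated, not how the original file was produced: `make_add_two_dataset` is a hypothetical helper, and the choice of single-digit operands (0 to 9) is inferred from the samples shown.

```python
import random


def make_add_two_dataset(num_examples: int, seed: int = 0):
    """Generate (input, output) strings mimicking the add-two-numbers format."""
    rng = random.Random(seed)  # local RNG so results are reproducible
    inputs, outputs = [], []
    for _ in range(num_examples):
        # assumption: both operands are single digits, as in the samples above
        a, b = rng.randint(0, 9), rng.randint(0, 9)
        inputs.append(f"Given the input numbers {a} and {b}, the answer is")
        outputs.append(f" {a + b}.\n\n")
    return inputs, outputs


synthetic_inputs, synthetic_outputs = make_add_two_dataset(num_examples=5)
for x, y in zip(synthetic_inputs, synthetic_outputs):
    print(repr(x), repr(y))
```

The returned lists can be passed as `input_strings` and `output_strings` in place of the downloaded data.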


In [7]:
# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=6, # length of the learned prompt, in tokens
    n_shots=3, # shots per example

    n_epochs=15, # how many epochs to search
    max_n_datapoints=1000, # limit to using this many datapoints in search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16
)

start_word_id = tensor([1169])
preprefix: ''
iPrompt got 1000 datapoints, now loading model...
Training with 19 possible answers / random acc 5.3% / majority acc 10.4%
Beginning epoch 0


Loss = 0.907: 100%|██████████| 16/16 [03:45<00:00, 14.12s/it]


Epoch 0. average loss = 1.287 / 803 / 1000 correct (80.30%)
Beginning epoch 1


Loss = 0.613:   0%|          | 0/16 [00:14<?, ?it/s]


Ending epoch 1 early...
Epoch 1. average loss = 0.613 / 59 / 64 correct (92.19%)
Stopping early after 17 steps and 1064 datapoints
Final prefixes
   index                                 prefix      loss  accuracy                        prefix_str  n_queries
0    439     (220, 19430, 257, 2446, 284, 2160)  0.798539  0.884896             Write a method to sum          6
1    315     (19430, 257, 2163, 1444, 2160, 62)  1.254141  0.846875      Write a function called sum_          7
2    721    (357, 16594, 257, 2163, 3706, 2160)  1.168026  0.796875       (Write a function named sum          1
3    137   (19430, 257, 2163, 1444, 2160, 5189)  1.370576  0.796875     Write a function called sumOf          4
4     28  (19430, 257, 2163, 1444, 2160, 20560)  1.427683  0.789062  Write a function called sumInput          4
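To put the accuracies in the table in context, the training log reports "19 possible answers / random acc 5.3% / majority acc 10.4%". The sketch below shows one plausible way to compute such baselines from the output strings. The helper `baseline_accuracies` is hypothetical, and the definitions used (random accuracy as one over the number of distinct answers, majority accuracy as the share of the most frequent answer) are an assumption about what the log reports, not the library's code.

```python
from collections import Counter


def baseline_accuracies(outputs):
    """Return (number of distinct answers, random-guess accuracy, majority-class accuracy)."""
    counts = Counter(outputs)
    n_answers = len(counts)
    random_acc = 1 / n_answers                          # guess uniformly among distinct answers
    majority_acc = max(counts.values()) / len(outputs)  # always predict the most common answer
    return n_answers, random_acc, majority_acc


# All 100 pairs of single digits: sums range from 0 to 18, i.e. 19 distinct answers.
grid_outputs = [f" {a + b}.\n\n" for a in range(10) for b in range(10)]
n_answers, random_acc, majority_acc = baseline_accuracies(grid_outputs)
print(f"{n_answers} possible answers / random acc {random_acc:.1%} / majority acc {majority_acc:.1%}")
```

On the full grid of digit pairs the majority baseline is exactly 10%, since the sum 9 occurs in 10 of the 100 pairs; a value such as the 10.4% in the log would be consistent with computing it on a random sample instead. Against these baselines, the best prefix's accuracy of about 88% is a large improvement.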
