In [1]:
# Presumably mounts the raadsinformatie blob storage via blobfuse (see the set-up notes in the next
# cell); the output paths used below live under this mount — confirm before running elsewhere.
!bash /home/azureuser/cloudfiles/code/blobfuse/blobfuse_raadsinformatie.sh

In [2]:
import os
import sys

# Make the project-level config modules (config_azure.py / config.py) importable.
sys.path.append("..")

# Select where to run notebook: "azure" or "local"
my_run = "azure"

# import my_secrets as sc
# import settings as st

if my_run == "azure":
    import config_azure as cf
elif my_run == "local":
    import config as cf
else:
    # Fail fast: otherwise `cf` stays undefined and every later cell fails with a NameError.
    raise ValueError(f'my_run must be "azure" or "local", got {my_run!r}')

if my_run == "azure":
    # makedirs(exist_ok=True) also creates missing parent folders and is safe to re-run.
    os.makedirs(cf.HUGGING_CACHE, exist_ok=True)
    # Presumably read by transformers when it is first imported, so keep this cell before any
    # `from transformers import ...` cell.
    os.environ["TRANSFORMERS_CACHE"] = cf.HUGGING_CACHE

# Environment set-up (GEITje-7B-chat in-context learning):
# - install blobfuse:  sudo apt-get install blobfuse
# - pip install transformers torch accelerate jupyter ipywidgets

## Notebook overview
- Goal: run the In-Context Learning experiments with GEITje (further down also Llama and the fine-tuned models)
- Trial run: prompt the model with an example prompt
- Zeroshot prompts
- Fewshot prompts

Load data and functions:
- data is already split
- text is already converted to tokens using model tokenizer 

In [3]:
import pandas as pd
# df = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")

import sys
sys.path.append('../scripts/') 
import prompt_template as pt
import prediction_helperfunctions as ph
import truncation as tf


In [4]:
import torch
# Release cached-but-unused GPU memory held by PyTorch's allocator (e.g. left over from a model
# loaded earlier in this kernel). Tensors that are still referenced are not freed.
torch.cuda.empty_cache()

In [5]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login (renders a token widget). Presumably needed for the gated
# meta-llama checkpoint loaded below — confirm. Never paste the token into a code cell.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

#### Trial run Models 
Code to run the models with a simple prompt.

In [None]:
from transformers import pipeline, Conversation


def _load_chat_pipeline(hub_id):
    """Return a 'conversational' pipeline for `hub_id`, placed with device_map='auto'."""
    return pipeline(
        task='conversational',
        model=hub_id,
        device_map='auto',
        model_kwargs={'offload_buffers': True},
    )


# NOTE: this trial cell keeps all three 7B chat models in memory at the same time. The
# experiment sections further down load the single model they need themselves.
chatbot_geitje = _load_chat_pipeline('Rijgersberg/GEITje-7B-chat-v2')
chatbot_llama = _load_chat_pipeline('meta-llama/Llama-2-7b-chat-hf')
chatbot_mistral = _load_chat_pipeline('mistralai/Mistral-7B-Instruct-v0.2')

## EXAMPLE PROMPT (Dutch: "Which word does not belong in this list: ...?")
# print(chatbot_geitje(
#     Conversation('Welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geitje, bus"?')
# ))

#### Experiment functions
Prompt the model once per document and save the prediction, together with the raw response, the response time and the prompt version used.

Code structure:
- 2 functions/cells:
- predictions_incontextlearning -> given a df with docs that need to be predicted, prompt the model
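- run_experiment -> runs the experiment with built-in fail-safes: the documents are predicted in small batches and saved after every batch, so an interrupted run can be resumed with the same run_id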

In [6]:
import time
import os
import pandas as pd
from bm25 import BM25


""" Given a dataframe with txt, return a df with predictions """
# docs_df = dataframe with the documents that need to be predicted
# text_column = name of the column that includes the input_text. Can be different based on the text representation method. 
# prompt_function = prompt template 
# train_df = dataframe with docs, which can be used as examples/training data/context data
# num_examples = number of examples in the prompt

def predictions_incontextlearning(chatbot, docs_df, text_column, prompt_function, train_df, num_examples):
    results_df = pd.DataFrame(columns = ['id', 'path', 'text_column', 'prompt_function', 'response', 'prediction', 'label', 'runtime', 'date', 'prompt'])


    if prompt_function == pt.fewshot_prompt_with_template or prompt_function == pt.fewshot_prompt_no_template:
        BM25_model = BM25()
        BM25_model.fit(train_df[text_column])
   

    # prompt each document
    for index, row in docs_df.iterrows():
        # if (index + 1) % 200 == 0:
        #     print(f"Iteration {index +1}/{len(docs_df)} completed.")

        start_time = time.time()

        # get the prompt, with the doc filled in
        txt = row[text_column]

        # each prompt function takes different arguments
        # zeroshot prompt for geitje
        if prompt_function == pt.zeroshot_prompt_geitje:
            prompt = prompt_function(txt)

        # zeroshot function for mistral and llama
        elif prompt_function == pt.zeroshot_prompt_mistral_llama:
            prompt = prompt_function(txt)

        # select fewshot examples using bm25, fewshot is the same for all models
        # elif prompt_function == pt.fewshot_prompt_bm25:
        #     prompt = prompt_function(txt, train_df, num_examples, text_column, BM25_model)
        
        elif prompt_function == pt.fewshot_prompt_no_template:
            prompt = prompt_function(txt, train_df, num_examples, text_column, BM25_model)

        elif prompt_function == pt.fewshot_prompt_with_template:
            prompt = prompt_function(txt, train_df, num_examples, text_column, BM25_model)

        else:
            raise ValueError("Prompt function not recognised. Check if prompt function is in prompt_template.py and included in the options above.")

        # prompt and get the response
        # print(prompt)
        converse = chatbot(Conversation(prompt))
        response = converse[1]['content']
        print("label: ", row['label'].lower())
        print("response: ", response)

        # extract prediction from response
        prediction = ph.get_prediction_from_response(response)
        print("prediction:", prediction)

        # save results in dataframe
        results_df.loc[len(results_df)] = {
            'id': row['id'],
            'path' : row['path'],
            'text_column' : docs_df.iloc[0]['trunc_col'],
            'prompt_function': ph.get_promptfunction_name(prompt_function),
            'response':response,
            'prediction':prediction,
            'label':row['label'].lower(),
            'runtime':time.time()-start_time,
            'date': ph.get_datetime(),
            'prompt':prompt
        }
    return results_df



In [7]:
import os
import time
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

"""
Function to run GEITje In-Context Learning experiment. 
The function allows to resume experiment, if run_id matches.
"""
# df = dataframe with all docs that need to have a prediction (docs still need to be predict + already predicted)
# run_id = unqiue for each experiment. 
# prompt_function = which prompt from prompt_template.py to use
# text_col = colum in df where the text is. (Needs to be already truncated)
# split_col = column with the dataset split. Either '2split' (train and test)or '4split'(train, test, dev and val)
# subset_train = indicates which subset to use as training. either 'train' or 'dev'
# subset_test = indicates which subset to use for testing. either 'test' or 'val'
# label_col = column with the true label
# prediction_path = path to file where predictions need to be saved.
# overview_path = path to file where results of each run need to be saved.
# model_name = name of the model. string.
# num_exmples = number of exaples given to prompt. zero in case of zeroshot. 

def run_experiment(chatbot, df, run_id, prompt_function, text_col, split_col, subset_train, subset_test, label_col, prediction_path, overview_path, model_name, num_examples=0):
    test_df = df.loc[df[split_col]==subset_test]
    train_df = df.loc[df[split_col]==subset_train]
    
    # get rows of df that still need to be predicted for the specific run_id
    to_predict, previous_predictions = ph.get_rows_to_predict(test_df, prediction_path, run_id)

    # devide to_predict into subsection of 50 predictions at a time. 
    # Allows to rerun without problem. And save subsections of 50 predictions.
    step_range = list(range(0, len(to_predict), 10))

    for i in range(len(step_range)):
        try:
            sub_to_predict = to_predict.iloc[step_range[i]:step_range[i+1]]
            print(f'Starting...{step_range[i]}:{step_range[i+1]} out of {len(to_predict)}')
        except Exception as e:
            sub_to_predict = to_predict[step_range[i]:]
            print(f'Starting...last {len(sub_to_predict)} docs')

        # prompt geitje
        predictions = predictions_incontextlearning(chatbot, sub_to_predict, text_col, prompt_function, train_df, num_examples)

        # save info
        predictions['run_id'] = run_id
        predictions['train_set'] = subset_train
        predictions['test_set'] = subset_test
        predictions['shots'] = num_examples

        # save new combinations in file
        print("Dont interrupt, saving predictions...")
        ph.combine_and_save_df(predictions, prediction_path)

        # if previous predictions, combine previous with new predictions, to get update classification report
        try:
            predictions = pd.concat([predictions, previous_predictions])

            # set previous predictions to all predictions made until now. Necessary for next loop
            previous_predictions = predictions
        except Exception as e:
            # set previous predictions to all predictions made until now. Necessary for next loop
            previous_predictions = predictions

        # save results in overview file
        date = ph.get_datetime()
        y_test = predictions['label']
        y_pred = predictions['prediction']

        # change error predictions to one error
        # error_names = ['NoPredictionInOutput', 'MultiplePredictionErrorInFormatting','NoPredictionFormat', 'MultiplePredictionErrorInOutput']
        # y_pred = ['OutputError' if x in error_names else x for x in y_pred]

        report = classification_report(y_test, y_pred)

        overview = pd.DataFrame(
            [{
                'model':model_name,
                'run_id':run_id,
                'date': date,
                'train_set': subset_train,
                'test_set': subset_test,
                'train_set_support':len(df.loc[df[split_col]==subset_train]),
                'test_set_support':len(predictions),
                'split_col':split_col,
                'text_col':df.iloc[0]['trunc_col'],
                'runtime':sum(predictions['runtime']),
                'accuracy': accuracy_score(y_test, y_pred),
                'macro_avg_precision': precision_score(y_test, y_pred, average='macro'),
                'macro_avg_recall': recall_score(y_test, y_pred, average='macro'),
                'macro_avg_f1': f1_score(y_test, y_pred, average='macro'),
                'weighted_avg_precision': precision_score(y_test, y_pred, average='weighted'),
                'weighted_avg_recall': recall_score(y_test, y_pred, average='weighted'),
                'weighted_avg_f1': f1_score(y_test, y_pred, average='weighted'),
                'classification_report':report
            }   ]
        )
        # remove previous results of run_id, replace with new/updated results
        ph.replace_and_save_df(overview, overview_path, run_id)
        print("Saving done! Interrupting is allowed.")
        print("Accuracy: ", accuracy_score(y_test, y_pred))




Set up variables that are the same for each model

In [8]:
# Set variables that are the same for each model.
TRAIN_SET = 'train'  # 'train' or 'dev': subset used as training / example data
TEST_SET = 'test'    # 'test' or 'val': subset that gets predicted
# Column with the dataset split: '2split' (train and test), '4split' (train, test, dev and val)
# or 'balanced_split' (written to the predictionsFinal folders by the path set-up cells).
SPLIT_COLUMN = 'balanced_split'
LABEL_COLUMN = 'label'
TEXT_COLUMN = 'trunc_txt'  # truncated-text column, presumably added by tf.add_truncation_column

# Fail fast on typos: an unknown subset value would select zero documents to predict.
assert TRAIN_SET in ('train', 'dev'), TRAIN_SET
assert TEST_SET in ('test', 'val'), TEST_SET
assert SPLIT_COLUMN in ('2split', '4split', 'balanced_split'), SPLIT_COLUMN


In [9]:
# Document table with the pre-computed tokenizer columns (e.g. 'LlamaTokens'), the split columns
# and the labels used by the experiments below.
# NOTE: read_pickle can execute arbitrary code — only load pickles from trusted storage.
txt = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")

### GEITje

In [15]:
SHORT_MODEL_NAME = 'GEITje'
PROMPT = pt.zeroshot_prompt_geitje
PROMPT_NAME = ph.get_promptfunction_name(PROMPT)
TOKENS_COL = 'LlamaTokens' # column with text split using tokenizer of either mistral (MistralTokens) or Llama (LlamaTokens). Using Llama, because Llama split into more tokens.
# Presumably the number of tokens kept from the start / end of each document (see the
# 'First{FRONT}Last{BACK}' prediction file names) — confirm in truncation.py.
FRONT_THRESHOLD = 200
BACK_THRESHOLD = 0

if PROMPT == pt.zeroshot_prompt_geitje:
    NUMBER_EXAMPLES = 0
elif PROMPT in (pt.fewshot_prompt_no_template, pt.fewshot_prompt_with_template):
    NUMBER_EXAMPLES = 2
else:
    # Without this, NUMBER_EXAMPLES would silently keep the value from another model section.
    raise ValueError(f"No NUMBER_EXAMPLES configured for prompt {PROMPT_NAME!r}")



#### Load model - In-context learning
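Note: load only ONE model, either in-context or fine-tuning. Both cells assign `chatbot_geitje` and set `MODEL_NAME`, `SUBFOLDER` and `SHORT_ID`, so whichever cell ran last determines the experiment.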

In [None]:
from transformers import pipeline, Conversation

# Base GEITje chat model, prompted without task-specific fine-tuning (in-context learning).
chatbot_geitje = pipeline(
    task='conversational',
    model='Rijgersberg/GEITje-7B-chat-v2',
    device_map='cpu',
    model_kwargs=dict(offload_buffers=True),
)

# Identifiers used to build the output paths and run_id further down.
MODEL_NAME, SUBFOLDER, SHORT_ID = 'GEITje-7B-chat-v2', 'in_context', 'IC'



#### Load model - Finetuning

In [16]:
from transformers import pipeline, Conversation

# Fine-tuned GEITje checkpoint (the name presumably encodes 200 front tokens and 3 epochs,
# matching FRONT_THRESHOLD and EPOCHS — confirm on the model card).
chatbot_geitje = pipeline(
    task='conversational',
    model='FemkeBakker/AmsterdamDocClassificationGEITje200T3Epochs',
    device_map='cpu',
    model_kwargs=dict(offload_buffers=True),
)

# Identifiers used to build the output paths and run_id further down.
MODEL_NAME, SUBFOLDER, SHORT_ID = 'AmsterdamDocClassificationGEITje200T3Epochs', 'finetuning', 'FT'
EPOCHS = 3  # selects the predictionsFinal/finetuning/{EPOCHS}epochs output folder



config.json:   0%|          | 0.00/667 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

#### Set-up paths to save predictions

In [17]:
import os

# Output locations depend on the split and on whether the model is used in-context or fine-tuned.
if SPLIT_COLUMN in ('4split', '2split'):
    OVERVIEW_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/overview.pkl"
    PREDICTION_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/predictions.pkl"

elif SPLIT_COLUMN == 'balanced_split':
    if SUBFOLDER == 'finetuning':
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/{SHORT_MODEL_NAME}First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    elif SUBFOLDER == 'in_context':
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    else:
        # Without this, the paths would silently keep the values of a previously run section.
        raise ValueError(f"Unknown SUBFOLDER {SUBFOLDER!r}; expected 'finetuning' or 'in_context'")

else:
    raise ValueError(f"Unknown SPLIT_COLUMN {SPLIT_COLUMN!r}; expected '2split', '4split' or 'balanced_split'")

print(OVERVIEW_PATH)
print(PREDICTION_PATH)

if not os.path.isdir(os.path.dirname(os.path.abspath(OVERVIEW_PATH))):
    raise ValueError("Folder to OVERVIEW_PATH does not exist")
if not os.path.isdir(os.path.dirname(os.path.abspath(PREDICTION_PATH))):
    raise ValueError("Folder to PREDICTION_PATH does not exist")

# Reusing an existing run_id resumes that run (see run_experiment).
run_id = f'{SHORT_ID}_{MODEL_NAME}{PROMPT_NAME}{TOKENS_COL}{FRONT_THRESHOLD}_{BACK_THRESHOLD}{TRAIN_SET}{TEST_SET}_numEx{NUMBER_EXAMPLES}'
print ('\n', run_id)


/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/overview.pkl
/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/GEITjeFirst200Last0Predictions.pkl

 FT_AmsterdamDocClassificationGEITje200T3Epochszeroshot_prompt_geitjeLlamaTokens200_0traintest_numEx0


#### Run experiment

In [18]:
# ----- EXPERIMENT --------
# Add the truncated-text column (TOKENS_COL cut with FRONT_THRESHOLD / BACK_THRESHOLD).
trunc_df = tf.add_truncation_column(txt, 'text', TOKENS_COL, FRONT_THRESHOLD, BACK_THRESHOLD)

# New run: MAKE SURE run_id IS UNIQUE. To resume an interrupted run, reuse its run_id.
run_experiment(
    chatbot=chatbot_geitje,
    df=trunc_df,
    run_id=run_id,
    prompt_function=PROMPT,
    text_col=TEXT_COLUMN,
    split_col=SPLIT_COLUMN,
    subset_train=TRAIN_SET,
    subset_test=TEST_SET,
    label_col=LABEL_COLUMN,
    prediction_path=PREDICTION_PATH,
    overview_path=OVERVIEW_PATH,
    model_name=MODEL_NAME,
    num_examples=NUMBER_EXAMPLES,
)


Starting...0:10 out of 1100
label:  raadsnotulen
response:  {'categorie': Raadsnotulen}
prediction: raadsnotulen
label:  factsheet
response:  {'categorie': Onderzoeksrapport}
prediction: onderzoeksrapport
label:  factsheet
response:  {'categorie': Onderzoeksrapport}
prediction: onderzoeksrapport
label:  brief
response:  {'categorie': Brief}
prediction: brief
label:  factsheet
response:  {'categorie': Onderzoeksrapport}
prediction: onderzoeksrapport
label:  raadsnotulen
response:  {'categorie': Raadsnotulen}
prediction: raadsnotulen
label:  voordracht
response:  {'categorie': Voordracht}
prediction: voordracht
label:  agenda
response:  {'categorie': Agenda}
prediction: agenda
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag}
prediction: schriftelijke vraag
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag}
prediction: schriftelijke vraag
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.7
Starting...10:20 out of 1100
label:  brief
response:  {'categorie': Brief}
prediction: brief
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport}
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres}
prediction: raadsadres
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag}
prediction: schriftelijke vraag
label:  motie
response:  {'categorie': Motie}
prediction: motie
label:  factsheet
response:  {'categorie': Factsheet}
prediction: factsheet
label:  raadsadres
response:  {'categorie': Raadsadres}
prediction: raadsadres
label:  raadsadres
response:  {'categorie': Raadsadres}
prediction: raadsadres
label:  besluit
response:  {'categorie': Besluit}
prediction: besluit
label:  motie
response:  {'categorie': Motie}
prediction: motie
Dont interrupt, saving predictions...
Saving done! Interrupting is allowed.
Accuracy:  0.85
Starting...20:30 out of 1100
label:  raadsnot

Bad pipe message: %s [b'\xa5K^\x92\x9a#\x9c6\xe0\xad<\xbc\x8e\xf9\x8aV?\xf6 \xf8:\x15"\x94C\x14=\x1d\xbe\xa6\xbc\xf4a\xd0\x8b\xe5\xdc\xe5\xe6\xe2j\xe6\x0c\x89\xda\n(d\x01\x92X\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08', b'\n\x08\x0b\x08\x04\x08\x05\x08']
Bad pipe message: %s [b'\x01\x05\x01\x06\x01']
Bad pipe message: %s [b"\xe2S\x1b+\\yG.#\xdeD\x04\x9ey'%\x12/ \xbf\xa2j\xafq\xc1\x985?\x92C\xe0\xba\x17w\x91\xc3\xd6yR\xdd\xd4\xe4*\xb9\x9f\xb8\x86E\x98P[\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x

label:  raadsadres
response:  {'categorie': Raadsadres}
prediction: raadsadres
label:  raadsadres
response:  {'categorie': Raadsadres}
prediction: raadsadres
label:  agenda
response:  {'categorie': Agenda}
prediction: agenda
Dont interrupt, saving predictions...
Saving done! Interrupting is allowed.
Accuracy:  0.9
Starting...280:290 out of 1100
label:  brief
response:  {'categorie': Motie}
prediction: motie
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag}
prediction: schriftelijke vraag
label:  factsheet
response:  {'categorie': Onderzoeksrapport}
prediction: onderzoeksrapport
label:  factsheet
response:  {'categorie': Factsheet}
prediction: factsheet
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag}
prediction: schriftelijke vraag
label:  brief
response:  {'categorie': Brief}
prediction: brief
label:  actualiteit
response:  {'categorie': Actualiteit}
prediction: actualiteit
label:  agenda
response:  {'categorie': Agenda}
prediction: age

In [None]:
# Results of every run stored in the current overview file; run_experiment keeps one
# (replaced) row per run_id. To show only the current run:
#     pred.loc[pred['run_id'] == run_id]
pred = pd.read_pickle(filepath_or_buffer=OVERVIEW_PATH)
display(pred)

### Llama


In [10]:
SHORT_MODEL_NAME = 'Llama'
PROMPT = pt.zeroshot_prompt_mistral_llama
PROMPT_NAME = ph.get_promptfunction_name(PROMPT)
TOKENS_COL = 'LlamaTokens' # column with text split using tokenizer of either mistral (MistralTokens) or Llama (LlamaTokens). Using Llama, because Llama split into more tokens.
# Presumably the number of tokens kept from the start / end of each document (see the
# 'First{FRONT}Last{BACK}' prediction file names) — confirm in truncation.py.
FRONT_THRESHOLD = 200
BACK_THRESHOLD = 0

if PROMPT == pt.zeroshot_prompt_mistral_llama:
    NUMBER_EXAMPLES = 0
elif PROMPT in (pt.fewshot_prompt_with_template, pt.fewshot_prompt_no_template):
    NUMBER_EXAMPLES = 2
else:
    # Without this, NUMBER_EXAMPLES would silently keep the value from another model section.
    raise ValueError(f"No NUMBER_EXAMPLES configured for prompt {PROMPT_NAME!r}")



#### Load model - In-context learning
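Note: load only ONE model, either in-context or fine-tuning. Both cells assign `chatbot_llama` and set `MODEL_NAME`, `SUBFOLDER` and `SHORT_ID`, so whichever cell ran last determines the experiment.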

In [None]:
from transformers import pipeline, Conversation

# Base Llama-2 chat model for in-context learning. Loaded on CPU because on GPU the few-shot
# BM25 prompt ran into a CUDA out-of-memory error.
chatbot_llama = pipeline(
    task='conversational',
    model='meta-llama/Llama-2-7b-chat-hf',
    device_map='cpu',
    model_kwargs=dict(offload_buffers=True),
)

# Identifiers used to build the output paths and run_id further down.
MODEL_NAME, SUBFOLDER, SHORT_ID = 'Llama-2-7b-chat-hf', 'in_context', 'IC'

#### Load model - finetuning

In [11]:
from transformers import pipeline, Conversation

# Fine-tuned Llama checkpoint (the name presumably encodes 200 front tokens and 3 epochs,
# matching FRONT_THRESHOLD and EPOCHS — confirm on the model card).
chatbot_llama = pipeline(
    task='conversational',
    model='FemkeBakker/AmsterdamDocClassificationLlama200T3Epochs',
    device_map='cpu',
    model_kwargs=dict(offload_buffers=True),
)

# Identifiers used to build the output paths and run_id further down.
MODEL_NAME, SUBFOLDER, SHORT_ID = 'AmsterdamDocClassificationLlama200T3Epochs', 'finetuning', 'FT'
EPOCHS = 3  # selects the predictionsFinal/finetuning/{EPOCHS}epochs output folder



Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

#### Set-up paths to save predictions

In [12]:
import os

# Output locations depend on the split and on whether the model is used in-context or fine-tuned.
if SPLIT_COLUMN in ('4split', '2split'):
    OVERVIEW_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/overview.pkl"
    PREDICTION_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/predictions.pkl"

elif SPLIT_COLUMN == 'balanced_split':
    if SUBFOLDER == 'finetuning':
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/{SHORT_MODEL_NAME}First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    elif SUBFOLDER == 'in_context':
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    else:
        # Without this, the paths would silently keep the values of a previously run section.
        raise ValueError(f"Unknown SUBFOLDER {SUBFOLDER!r}; expected 'finetuning' or 'in_context'")

else:
    raise ValueError(f"Unknown SPLIT_COLUMN {SPLIT_COLUMN!r}; expected '2split', '4split' or 'balanced_split'")

print(OVERVIEW_PATH)
print(PREDICTION_PATH)

if not os.path.isdir(os.path.dirname(os.path.abspath(OVERVIEW_PATH))):
    raise ValueError("Folder to OVERVIEW_PATH does not exist")
if not os.path.isdir(os.path.dirname(os.path.abspath(PREDICTION_PATH))):
    raise ValueError("Folder to PREDICTION_PATH does not exist")

# Reusing an existing run_id resumes that run (see run_experiment).
run_id = f'{SHORT_ID}_{MODEL_NAME}{PROMPT_NAME}{TOKENS_COL}{FRONT_THRESHOLD}_{BACK_THRESHOLD}{TRAIN_SET}{TEST_SET}_numEx{NUMBER_EXAMPLES}'
print ('\n', run_id)


/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/overview.pkl
/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/LlamaFirst200Last0Predictions.pkl

 FT_AmsterdamDocClassificationLlama200T3Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0


#### Run experiment

In [13]:
# Add the truncated-text column (TOKENS_COL cut with FRONT_THRESHOLD / BACK_THRESHOLD).
trunc_df = tf.add_truncation_column(txt, 'text', TOKENS_COL, FRONT_THRESHOLD, BACK_THRESHOLD)

# New run: MAKE SURE run_id IS UNIQUE. To resume an interrupted run, reuse its run_id.
run_experiment(
    chatbot=chatbot_llama,
    df=trunc_df,
    run_id=run_id,
    prompt_function=PROMPT,
    text_col=TEXT_COLUMN,
    split_col=SPLIT_COLUMN,
    subset_train=TRAIN_SET,
    subset_test=TEST_SET,
    label_col=LABEL_COLUMN,
    prediction_path=PREDICTION_PATH,
    overview_path=OVERVIEW_PATH,
    model_name=MODEL_NAME,
    num_examples=NUMBER_EXAMPLES,
)


Run-id already known, resuming predictions...
Starting...0:10 out of 520
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
Dont interrupt, saving predictions...
Saving done! Interr

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8593406593406593
Starting...330:340 out of 520
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  onderzoeksrapport
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8597826086956522
Starting...340:350 out of 520
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  schriftelijke vraag
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8602150537634409
Starting...350:360 out of 520
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8606382978723405
Starting...360:370 out of 520
label:  factsheet
response:  {'categorie': Factsheet} 
prediction: factsheet
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  brief
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8610526315789474
Starting...370:380 out of 520
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  schriftelijke vraag
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  brief
response:  {'categorie': Brief} 
prediction: brief
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8604166666666667
Starting...380:390 out of 520
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  actualiteit
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8597938144329897
Starting...390:400 out of 520
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  factsheet
response:  {'categorie': Factsheet} 
prediction: factsheet
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  actualiteit
response:  {'categorie': Agenda} 
prediction: agenda
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.860204081632653
Starting...400:410 out of 520
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8616161616161616
Starting...410:420 out of 520
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.863
Starting...420:430 out of 520
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8623762376237624
Starting...430:440 out of 520
label:  actualiteit
response:  {'categorie': Besluit} 
prediction: besluit
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8627450980392157
Starting...440:450 out of 520
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  motie
response:  {'categorie': Motie} 
prediction: motie
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8631067961165049
Starting...450:460 out of 520
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  motie
response:  {'categorie': Motie} 
prediction: motie
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.864423076923077
Starting...460:470 out of 520
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8647619047619047
Starting...470:480 out of 520
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  schriftelijke vraag
response:  {'categorie': Schriftelijke Vraag} 
prediction: schriftelijke vraag
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8650943396226415
Starting...480:490 out of 520
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  actualiteit
response:  {'categorie': Actualiteit} 
prediction: actualiteit
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  motie
response:  {'categorie': Motie} 
prediction: motie
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8654205607476636
Starting...490:500 out of 520
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  factsheet
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  factsheet
response:  {'categorie': Factsheet} 
prediction: factsheet
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  raadsadres
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8648148148148148
Starting...500:510 out of 520
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
label:  onderzoeksrapport
response:  {'categorie': Onderzoeksrapport} 
prediction: onderzoeksrapport
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8660550458715597
Starting...last 10 docs
label:  besluit
response:  {'categorie': Besluit} 
prediction: besluit
label:  motie
response:  {'categorie': Motie} 
prediction: motie
label:  brief
response:  {'categorie': Brief} 
prediction: brief
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  schriftelijke vraag
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  voordracht
response:  {'categorie': Voordracht} 
prediction: voordracht
label:  agenda
response:  {'categorie': Agenda} 
prediction: agenda
label:  raadsadres
response:  {'categorie': Raadsadres} 
prediction: raadsadres
label:  raadsnotulen
response:  {'categorie': Raadsnotulen} 
prediction: raadsnotulen
Dont interrupt, saving predictions...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Saving done! Interrupting is allowed.
Accuracy:  0.8663636363636363


In [28]:
# Show the run-overview table (one row of metrics per logged run_id).
# NOTE(review): OVERVIEW_PATH is not set in this section until the path-setup cell
# further below, so this cell relies on a value left over from an earlier section
# (hidden kernel state). Also, `pred` holds the overview, not per-document predictions.
pred = pd.read_pickle(filepath_or_buffer=OVERVIEW_PATH)
display(pred)

Unnamed: 0,model,run_id,date,train_set,test_set,train_set_support,test_set_support,split_col,text_col,runtime,accuracy,macro_avg_precision,macro_avg_recall,macro_avg_f1,weighted_avg_precision,weighted_avg_recall,weighted_avg_f1,classification_report
1,AmsterdamDocClassificationLlama200T3Epochs,FT_AmsterdamDocClassificationLlama200T3Epochsz...,2024-06-04 21:52:00.725048+02:00,train,test,9900,1100,balanced_split,TruncationLlamaTokensFront200Back0,22934.571034,0.866364,0.830406,0.794167,0.783231,0.905898,0.866364,0.854434,precision recall f1-...


### Mistral

In [10]:
SHORT_MODEL_NAME = 'Mistral'
PROMPT = pt.zeroshot_prompt_mistral_llama
PROMPT_NAME = ph.get_promptfunction_name(PROMPT)
# Column with the text already split by a model tokenizer: either Mistral's
# ('MistralTokens') or Llama's ('LlamaTokens'). Llama is used because the Llama
# tokenizer splits the text into more tokens.
TOKENS_COL = 'LlamaTokens'
# Truncation thresholds; they appear as 'Front<N>Back<M>' / 'First<N>Last<M>' in the
# text column and prediction file names (presumably tokens kept from the start/end).
FRONT_THRESHOLD = 200
BACK_THRESHOLD = 0

# Number of in-context examples that goes with the chosen prompt.
if PROMPT == pt.zeroshot_prompt_mistral_llama:
    NUMBER_EXAMPLES = 0
elif PROMPT == pt.fewshot_prompt_bm25:
    NUMBER_EXAMPLES = 2
else:
    # Without this branch NUMBER_EXAMPLES would be undefined, or silently keep a stale
    # value from a previous configuration, which would end up in run_id.
    raise ValueError(f"Unknown PROMPT {PROMPT_NAME!r}: set NUMBER_EXAMPLES for it explicitly")



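#### Load model - In-context learning
Note: load only ONE model, either the in-context model or the fine-tuned model. Both cells below assign `chatbot_mistral` and set `MODEL_NAME`, `SUBFOLDER` and `SHORT_ID`.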

In [None]:
from transformers import pipeline, Conversation

# In-context learning: the base instruct model from the Hugging Face Hub, wrapped
# as a conversational pipeline. device_map='cpu' loads the model on the CPU;
# model_kwargs are forwarded to the model loader.
chatbot_mistral = pipeline(
    task='conversational',
    model='mistralai/Mistral-7B-Instruct-v0.2',
    device_map='cpu',
    model_kwargs={'offload_buffers': True},
)

# Identifiers consumed by the path-setup cell and the run_id.
MODEL_NAME = 'Mistral-7B-Instruct-v0.2'
SUBFOLDER = 'in_context'  # selects the in_context branch of the output paths
SHORT_ID = 'IC'           # run_id prefix


#### Load model - fine-tuning

In [11]:
from transformers import pipeline, Conversation

# Fine-tuned classifier: Mistral fine-tuned on the Amsterdam documents, loaded from
# the Hugging Face Hub as a conversational pipeline.
# NOTE(review): device_map='cpu' keeps the whole 7B model on the CPU (the logged run
# took ~31k seconds); the trial cells at the top of the notebook use 'auto'.
chatbot_mistral = pipeline(
    task='conversational',
    model='FemkeBakker/AmsterdamDocClassificationMistral200T3Epochs',
    device_map='cpu',
    model_kwargs={'offload_buffers': True},
)

# Identifiers consumed by the path-setup cell and the run_id.
MODEL_NAME = 'AmsterdamDocClassificationMistral200T3Epochs'
SUBFOLDER = 'finetuning'  # selects the finetuning branch of the output paths
SHORT_ID = 'FT'           # run_id prefix
EPOCHS = 3                # must match '3Epochs' in the model name; selects the '<EPOCHS>epochs' folder



Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

#### Set up paths to save predictions

In [12]:
import os

# Output locations for this run:
# - OVERVIEW_PATH: pickled overview table, one row of metrics per run_id.
# - PREDICTION_PATH: pickled per-document predictions of this run.
if SPLIT_COLUMN == '4split' or SPLIT_COLUMN == '2split':
    # Validation splits: one folder per subfolder / model / prompt.
    OVERVIEW_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/overview.pkl"
    PREDICTION_PATH = f"{cf.output_path}/predictionsVal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/predictions.pkl"

elif SPLIT_COLUMN == 'balanced_split':
    if SUBFOLDER == 'finetuning':
        # One overview per epoch count, shared by all fine-tuned models.
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{EPOCHS}epochs/{SHORT_MODEL_NAME}First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    elif SUBFOLDER == 'in_context':
        # One overview per model (shared by all prompts); predictions are stored per prompt.
        OVERVIEW_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/overview.pkl"
        PREDICTION_PATH = f"{cf.output_path}/predictionsFinal/{SUBFOLDER}/{SHORT_MODEL_NAME}/{PROMPT_NAME}/First{FRONT_THRESHOLD}Last{BACK_THRESHOLD}Predictions.pkl"

    else:
        # Without this branch the paths would silently keep values from a previously run cell.
        raise ValueError(f"Unknown SUBFOLDER {SUBFOLDER!r}: expected 'finetuning' or 'in_context'")

else:
    # Same hidden-state risk as above for an unexpected split column.
    raise ValueError(f"Unknown SPLIT_COLUMN {SPLIT_COLUMN!r}: expected '4split', '2split' or 'balanced_split'")

print(OVERVIEW_PATH)
print(PREDICTION_PATH)

# The target folders must already exist; nothing is created here.
if not os.path.isdir(os.path.dirname(os.path.abspath(OVERVIEW_PATH))):
    raise ValueError("Folder to OVERVIEW_PATH does not exist")
if not os.path.isdir(os.path.dirname(os.path.abspath(PREDICTION_PATH))):
    raise ValueError("Folder to PREDICTION_PATH does not exist")

# Identifier of the run; reusing an existing run_id resumes that run.
run_id = f'{SHORT_ID}_{MODEL_NAME}{PROMPT_NAME}{TOKENS_COL}{FRONT_THRESHOLD}_{BACK_THRESHOLD}{TRAIN_SET}{TEST_SET}_numEx{NUMBER_EXAMPLES}'
print('\n', run_id)


/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/overview.pkl
/home/azureuser/cloudfiles/code/blobfuse/raadsinformatie/processed_data/woo_document_classification/predictionsFinal/finetuning/3epochs/MistralFirst200Last0Predictions.pkl

 FT_AmsterdamDocClassificationMistral200T3Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0


#### Run experiment

In [14]:
# Run the experiment.

# Add a truncated-text column built from TOKENS_COL with FRONT_THRESHOLD/BACK_THRESHOLD.
# NOTE(review): presumably 'text' is the source text column and the new column is the
# one named by TEXT_COLUMN (e.g. 'TruncationLlamaTokensFront200Back0') — confirm.
trunc_df = tf.add_truncation_column(txt, 'text', TOKENS_COL, FRONT_THRESHOLD, BACK_THRESHOLD)

# Fail fast: a full run takes hours, so verify the prompt text column exists first.
if TEXT_COLUMN not in trunc_df.columns:
    raise KeyError(
        f"TEXT_COLUMN {TEXT_COLUMN!r} not found in the truncated dataframe; "
        f"check it matches TOKENS_COL / FRONT_THRESHOLD / BACK_THRESHOLD"
    )

# For a new run, MAKE SURE run_id IS UNIQUE; to resume an interrupted run, pass in its run_id.
run_experiment(chatbot_mistral, trunc_df, run_id, PROMPT, TEXT_COLUMN, SPLIT_COLUMN, TRAIN_SET, TEST_SET, LABEL_COLUMN, PREDICTION_PATH, OVERVIEW_PATH, MODEL_NAME, NUMBER_EXAMPLES)


Run-id already known, resuming predictions...


In [13]:
# Show the run-overview table (one row of metrics per logged run_id).
# `pred` holds the overview, not the per-document predictions.
# NOTE(review): the fine-tuned Mistral row reports accuracy ~0.001 (below chance),
# which suggests its responses were not parsed into labels; investigate before use.
pred = pd.read_pickle(filepath_or_buffer=OVERVIEW_PATH)
display(pred)


Unnamed: 0,model,run_id,date,train_set,test_set,train_set_support,test_set_support,split_col,text_col,runtime,accuracy,macro_avg_precision,macro_avg_recall,macro_avg_f1,weighted_avg_precision,weighted_avg_recall,weighted_avg_f1,classification_report
1,AmsterdamDocClassificationLlama200T3Epochs,FT_AmsterdamDocClassificationLlama200T3Epochsz...,2024-06-04 21:52:00.725048+02:00,train,test,9900,1100,balanced_split,TruncationLlamaTokensFront200Back0,22934.571034,0.866364,0.830406,0.794167,0.783231,0.905898,0.866364,0.854434,precision recall f1-...
0,AmsterdamDocClassificationMistral200T3Epochs,FT_AmsterdamDocClassificationMistral200T3Epoch...,2024-06-05 06:42:48.454727+02:00,train,test,9900,1100,balanced_split,TruncationLlamaTokensFront200Back0,31472.814609,0.000909,0.076923,0.000769,0.001523,0.090909,0.000909,0.0018,precision...
