In [1]:
!bash /home/azureuser/cloudfiles/code/blobfuse/blobfuse_raadsinformatie.sh

In [1]:
import sys
sys.path.append("..")

# Select where to run notebook: "azure" or "local"
my_run = "azure"

# import my_secrets as sc
# import settings as st

if my_run == "azure":
    import config_azure as cf
elif my_run == "local":
    import config as cf
else:
    # Fail fast: otherwise `cf` is silently undefined and later cells crash with a NameError.
    raise ValueError(f"my_run must be 'azure' or 'local', got {my_run!r}")


import os
if my_run == "azure":
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # exists()/mkdir() race of a separate check.
    os.makedirs(cf.HUGGING_CACHE, exist_ok=True)
    # NOTE(review): newer transformers releases prefer HF_HOME / HF_HUB_CACHE over
    # TRANSFORMERS_CACHE — confirm against the installed version.
    os.environ["TRANSFORMERS_CACHE"] = cf.HUGGING_CACHE


import pandas as pd

# set-up environment - GEITje-7b-chat InContextLearning:
# - install blobfuse -> sudo apt-get install blobfuse
# - pip install transformers
# - pip install torch
# - pip install accelerate
# - pip install jupyter
# - pip install ipywidgets
# - pip install scikit-learn   (used for classification_report below)

### Notebook Overview
Goal: gain insight into the model's predictions — response length, prediction errors, and evaluation metrics.

In [2]:
# Load txtfiles_tokenizer.pkl from the configured output directory.
# NOTE(review): `txt` is not used by the cells visible below — confirm it is still needed.
# pickle can execute code on load; only read files produced by this project.
txt = pd.read_pickle(
    f"{cf.output_path}"
    "/txtfiles_tokenizer.pkl"
)

In [70]:
from transformers import AutoTokenizer
from collections import Counter
from sklearn.metrics import classification_report
import sys
sys.path.append('../scripts/') 
import prompt_template as pt
import warnings


def get_tokens(model_name, df, text_col, new_col_name, tokenizer=None):
    """Tokenize every text in ``df[text_col]`` and store the tokens and their counts.

    Adds two columns to ``df`` in place (and also returns ``df``):
    ``new_col_name`` holds the list of tokens per row, and
    ``count_{new_col_name}`` holds the number of tokens per row.

    Args:
        model_name: Hugging Face model id whose tokenizer is loaded with
            ``AutoTokenizer.from_pretrained`` when ``tokenizer`` is not given.
        df: DataFrame containing the texts.
        text_col: Name of the column with the texts to tokenize.
        new_col_name: Name of the token column to create.
        tokenizer: Optional already-loaded tokenizer (any object with a
            ``tokenize(text)`` method). Passing one avoids reloading the
            tokenizer from ``model_name`` on every call.

    Returns:
        The same ``df`` object, now with the two new columns.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    all_tokens = [tokenizer.tokenize(text) for text in df[text_col]]

    df[new_col_name] = all_tokens
    df[f"count_{new_col_name}"] = [len(tokens) for tokens in all_tokens]
    return df

def format_label(label):
    """Render a gold label as the 'ideal' response string ``{'categorie': <label>}``.

    The label is inserted as-is, without quotes. NOTE(review): the model's actual
    responses are JSON such as ``{"categorie": "Factsheet"}`` (double quotes,
    newlines), so token counts of this string only approximate an ideal response.
    """
    return "{{'categorie': {}}}".format(label)
    

def get_response_length(df, model_name):
    """Print the mean and std token length of ideal vs. predicted responses.

    The "ideal" response is the gold label rendered by ``format_label``; the
    prediction is the raw ``response`` column. Both are tokenized with the
    tokenizer of ``model_name``.

    Side effect: adds ``responseTokens``, ``count_responseTokens``,
    ``label_formatted``, ``label_formattedTokens`` and
    ``count_label_formattedTokens`` columns to ``df``. Returns None.
    """
    df = get_tokens(model_name, df, 'response', 'responseTokens')
    df['label_formatted'] = df['label'].apply(format_label)
    df = get_tokens(model_name, df, 'label_formatted', 'label_formattedTokens')

    print("RESPONSE LENGTH")

    # Compute the summary statistics once per column.
    ideal_stats = df['count_label_formattedTokens'].describe()
    ideal_line = (
        f"Average tokens of IDEAL response: {round(ideal_stats['mean'],1)} tokens "
        f"(std = {round(ideal_stats['std'], 1)}) "
    )
    print(ideal_line)

    predicted_stats = df['count_responseTokens'].describe()
    predicted_line = (
        f"Average tokens of PREDICTION response: {round(predicted_stats['mean'],1)} tokens "
        f"(std = {round(predicted_stats['std'], 1)}) \n"
    )
    print(predicted_line)
   

def prediction_errors(df):
    """Summarise responses from which no class prediction could be extracted.

    A row is a prediction error when its ``prediction`` value is one of the
    error markers in ``error_names``. The function prints:
      - the count per error marker (0 for markers that do not occur),
      - how many erroneous responses mention N class names (one line per N),
      - how many erroneous responses mention the true label.
    It then displays a table of counts per (error marker, true label).

    Args:
        df: DataFrame with at least ``prediction``, ``label`` and ``response`` columns.
    """
    error_names = ['NoPredictionInOutput', 'MultiplePredictionErrorInFormatting', 'NoPredictionFormat', 'MultiplePredictionErrorInOutput']

    # Only keep rows whose response yielded no extractable prediction.
    errors_df = df.loc[df['prediction'].isin(error_names)]

    # Count instances per error; markers that never occur are reported as 0.
    count = dict(Counter(errors_df['prediction']))
    for error in error_names:
        count.setdefault(error, 0)

    print('PREDICTION ERRORS')
    print(count)

    if errors_df.empty:
        # Nothing to break down; also avoids dividing by len(errors_df) == 0 below.
        print('No prediction errors found.')
        return

    # Counts per (error marker, true label), for display.
    class_count = errors_df.groupby('prediction')['label'].value_counts().reset_index(name='count')

    # The class list is loop-invariant, so lowercase it once.
    class_names = [category.lower() for category in pt.get_class_list()]

    classes_in_responses = []
    correct_class_in_response = []
    for response, label in zip(errors_df['response'], errors_df['label']):
        # str() guards against non-string responses (e.g. NaN).
        response_lower = str(response).lower()
        # Substring match: every class name occurring anywhere in the response.
        mentioned = [name for name in class_names if name in response_lower]
        classes_in_responses.append(mentioned)
        correct_class_in_response.append(str(label).lower() in mentioned)

    # How many responses mention 0, 1, 2, ... class names; one sentence per line.
    amount_of_classes = Counter(len(mentioned) for mentioned in classes_in_responses)
    print('\n'.join(
        f'There are {n_responses} responses that contain {amount} classes.'
        for amount, n_responses in sorted(amount_of_classes.items())
    ))

    n_correct = sum(correct_class_in_response)
    n_errors = len(errors_df)
    print(f"{n_correct} responses out of {n_errors} ({round(n_correct / n_errors * 100, 1)}%) prediction errors contain the correct label.")
    display(class_count)

def evaluation_metrics(df):
    """Print sklearn's classification report of ``df['label']`` vs ``df['prediction']``.

    Error markers (e.g. ``NoPredictionFormat``) appear only in ``prediction``,
    so they show up in the report as classes with zero support. Their undefined
    precision/recall are reported as 0.0 (``zero_division=0``).
    """
    # Scope the suppression to this call instead of silencing every
    # UserWarning for the rest of the kernel session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        report = classification_report(df['label'], df['prediction'], zero_division=0)
    print('EVALUATION METRICS')
    print(report)




# `predictions` is loaded in the "GEITje" cell further down. Load it here when that
# cell has not been run yet, so the notebook also works after Restart & Run All.
# Keep this path in sync with the GEITje cell.
if "predictions" not in globals():
    predictions = pd.read_pickle(
        f"{cf.output_path}/predictionsFinal/in_context/GEITje/zeroshot_prompt_geitje/First100Last0Predictions.pkl"
    )

get_response_length(predictions, 'Rijgersberg/GEITje-7B-chat-v2')
prediction_errors(predictions)
evaluation_metrics(predictions)



RESPONSE LENGTH
Average tokens of IDEAL response: 9.3 tokens (std = 1.6) 
Average tokens of PREDICTION response: 16.1 tokens (std = 5.8) 

PREDICTION ERRORS
{'NoPredictionFormat': 107, 'MultiplePredictionErrorInOutput': 4, 'NoPredictionInOutput': 7, 'MultiplePredictionErrorInFormatting': 0}
There are 105 response that contain 1 classes.There are 9 response that contain 0 classes.There are 4 response that contain 2 classes.
84 responses out of 118 (71.2%) prediction errors contain the correct label.


Unnamed: 0,prediction,label,count
0,MultiplePredictionErrorInOutput,besluit,4
1,NoPredictionFormat,schriftelijke vraag,41
2,NoPredictionFormat,motie,22
3,NoPredictionFormat,raadsadres,13
4,NoPredictionFormat,actualiteit,12
5,NoPredictionFormat,besluit,6
6,NoPredictionFormat,onderzoeksrapport,5
7,NoPredictionFormat,agenda,3
8,NoPredictionFormat,factsheet,3
9,NoPredictionFormat,brief,1


EVALUATION METRICS
                                 precision    recall  f1-score   support

MultiplePredictionErrorInOutput       0.00      0.00      0.00         0
             NoPredictionFormat       0.00      0.00      0.00         0
           NoPredictionInOutput       0.00      0.00      0.00         0
                    actualiteit       0.92      0.46      0.61       100
                         agenda       0.93      0.74      0.82       100
                        besluit       0.69      0.51      0.59       100
                          brief       0.54      0.88      0.67       100
                      factsheet       0.68      0.54      0.60       100
                          motie       0.96      0.45      0.61       100
              onderzoeksrapport       0.71      0.10      0.18       100
                     raadsadres       0.94      0.29      0.44       100
                   raadsnotulen       0.37      1.00      0.54       100
            schriftelijke vraag

#### GEITje

In [5]:
# In-context learning, zero-shot prompt (GEITje-7B-chat-v2): load the stored predictions.
# NOTE(review): pickle can execute code on load — only read files produced by this project.
predictions = pd.read_pickle(
    f"{cf.output_path}"
    "/predictionsFinal/in_context/GEITje/zeroshot_prompt_geitje/First100Last0Predictions.pkl"
)
display(predictions)

Unnamed: 0,id,path,text_column,prompt_function,response,prediction,label,runtime,date,prompt,run_id,train_set,test_set,shots
1,32939,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Factsheet""\n}",factsheet,factsheet,28.125553,2024-05-17 09:56:06.381670+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
2,33085,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Onderzoeksrapport""\n}",onderzoeksrapport,factsheet,33.540810,2024-05-17 09:56:39.924174+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
3,22985,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Raadsinformatiebrief""\n}",brief,brief,35.132770,2024-05-17 09:57:15.058680+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
4,32991,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Actualiteit""\n}",actualiteit,factsheet,30.356459,2024-05-17 09:57:45.417157+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
5,26317,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Raadsnotulen""\n}",raadsnotulen,raadsnotulen,33.297573,2024-05-17 09:58:18.716722+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6,28826,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Raadsnotulen""\n}",raadsnotulen,voordracht,30.467479,2024-05-17 18:47:42.568726+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
7,25722,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Agenda""\n}",agenda,agenda,25.089451,2024-05-17 18:48:07.660346+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
8,23998,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Raadsadres""\n}",raadsadres,raadsadres,28.899608,2024-05-17 18:48:36.561795+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
9,26414,/home/azureuser/cloudfiles/code/blobfuse/raads...,TruncationLlamaTokensFront100Back0,zeroshot_prompt_geitje,"{\n ""categorie"": ""Raadsnotulen""\n}",raadsnotulen,raadsnotulen,30.416466,2024-05-17 18:49:06.979859+02:00,\n Classificeer het document in één van de ...,IC_GEITje-7B-chat-v2zeroshot_prompt_geitjeLlam...,train,test,0
