In [1]:
# NOTE(review): presumably mounts the raadsinformatie blob storage via blobfuse so the
# data under cf.output_path is reachable -- TODO confirm. This is a hard-coded absolute
# Azure ML path and it runs unconditionally (my_run is only defined in the next cell).
!bash /home/azureuser/cloudfiles/code/blobfuse/blobfuse_raadsinformatie.sh

In [9]:
import sys
# Make modules in the parent directory (my_secrets, settings, config*) importable.
sys.path.append("..")

# Select where to run notebook: "azure" or "local"
my_run = "azure"

import my_secrets as sc
import settings as st

if my_run == "azure":
    import config_azure as cf
elif my_run == "local":
    import config as cf
else:
    # Without this branch `cf` stays unbound and every later cell fails with a NameError.
    raise ValueError(f'my_run must be "azure" or "local", got {my_run!r}')


import os
if my_run == "azure":
    # Create the Hugging Face cache directory, including any missing parents.
    os.makedirs(cf.HUGGING_CACHE, exist_ok=True)
    # transformers reads this variable when it is imported, so this cell must run
    # before the `from transformers import AutoTokenizer` cell.
    os.environ["TRANSFORMERS_CACHE"] = cf.HUGGING_CACHE


## Notebook overview
- Get insight into tokenizer, tokens and doc lengths.
- Test different text truncation thresholds on the baseline.

#### Text truncation -- overview in tokenizer/doc lengths
- Tokenize the text with the Mistral, GEITje and Llama tokenizers.
- Check whether Mistral and GEITje indeed use the same tokenizer.
- Inspect the distribution of token counts per document.
- Truncate the text and test multiple truncation thresholds on the baseline.

Results are saved to `txtfiles_tokenizer.pkl`, so `txtfiles.pkl` stays untouched as a backup in case anything gets messed up.

In [2]:
import pandas as pd
# Raw document texts. This file is kept untouched as the backup; tokenized results
# are written to txtfiles_tokenizer.pkl instead.
df = pd.read_pickle(f"{cf.output_path}/txtfiles.pkl")

In [3]:
from transformers import AutoTokenizer

def get_tokens(model_name, df, save_to_path, text_col, new_col_name):
    """Tokenize every text in ``df[text_col]`` with the tokenizer of ``model_name``.

    Two columns are added to ``df`` in place:
      * ``new_col_name``            -- the token list of each row;
      * ``count_<new_col_name>``    -- the number of tokens in that list.
    The updated frame is then pickled to ``save_to_path`` and the same
    (mutated) frame object is returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    token_lists = [tokenizer.tokenize(text) for text in df[text_col].values]
    token_counts = [len(tokens) for tokens in token_lists]

    df[new_col_name] = token_lists
    df[f"count_{new_col_name}"] = token_counts
    df.to_pickle(save_to_path)
    return df

# Quick smoke test on the first two rows:
# subdf = df.iloc[0:2]
# # display(subdf)
# get_tokens('Rijgersberg/GEITje-7B-chat-v2', subdf, f"{cf.output_path}/try_out_token_count.pkl", 'text', 'token_count_geitje')

def fraction_token(df, max_token, token_len_col):
    """Print how many documents exceed ``max_token`` tokens, plus summary stats.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding one or more token-count columns.
    max_token : int or float
        Threshold; a document "exceeds" it when its count is strictly greater.
    token_len_col : str or iterable of str
        Name(s) of the token-count column(s) to report on. A single string is
        treated as one column name (it is no longer iterated per character).

    For every column one line "N out of M (P%) docs exceed a token length of X"
    is printed, followed by ``Series.describe()`` for every column.
    Returns None.
    """
    # Normalise to a list: a bare str would otherwise be iterated character by
    # character, and a one-shot iterator would be exhausted by the first loop.
    if isinstance(token_len_col, str):
        columns = [token_len_col]
    else:
        columns = list(token_len_col)

    n_docs = len(df)
    for col in columns:
        n_exceeding = int((df[col] > max_token).sum())
        # Guard against ZeroDivisionError on an empty frame.
        pct = round(n_exceeding / n_docs * 100, 2) if n_docs else 0.0
        print(f"{n_exceeding} out of {n_docs} ({pct}%) docs exceed a token length of {max_token}")

    for col in columns:
        print(df[col].describe())

    


    



##### Tokenize text

In [4]:
# Tokenize the texts once per model family with get_tokens, which adds `<name>` and
# `count_<name>` columns and saves the frame back to txtfiles_tokenizer.pkl.
# Each pair re-reads that pickle first, so new columns accumulate alongside earlier ones.
# The calls are commented out, presumably because the pickle already exists (it is
# loaded in the next section); uncomment a pair to (re)build that column.
# The bare strings are only section labels; the last one is echoed as the cell output.
"""GEITje""" ## not necessary -> GEITje's tokenizer is the same as Mistral's
# df = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")
# get_tokens('Rijgersberg/GEITje-7B-chat-v2', df, f"{cf.output_path}/txtfiles_tokenizer.pkl", 'text', 'GEITjeTokens')

"""Mistral"""
# df = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")
# get_tokens('mistralai/Mistral-7B-v0.1', df, f"{cf.output_path}/txtfiles_tokenizer.pkl", 'text', 'MistralTokens')

"""Llama"""
# df = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")
# get_tokens('meta-llama/Llama-2-7b-hf', df, f"{cf.output_path}/txtfiles_tokenizer.pkl", 'text', 'LlamaTokens')

'Llama'

##### Analyse token length of model tokenizers

In [10]:
import pandas as pd
# Text plus per-tokenizer token lists and counts, as written by get_tokens above.
tok = pd.read_pickle(f"{cf.output_path}/txtfiles_tokenizer.pkl")
# Share of documents above 4096 tokens per tokenizer (4096 presumably matches the
# models' context length -- confirm):
# fraction_token(tok, 4096, ['count_MistralTokens', 'count_LlamaTokens'])

#### Test text truncation on baseline

In [11]:
import itertools
from sklearn.naive_bayes import MultinomialNB


# load file with baseline function
import sys
sys.path.append('../scripts/') 
import baseline as bf

# load file with truncation function
from truncation import add_truncation_column

from sklearn.svm import LinearSVC
# variables for text truncation
DATAFRAME = tok  # token table loaded from txtfiles_tokenizer.pkl
TEXT_COL = 'text'
TOKENS_COL = 'LlamaTokens'  # token lists used to decide where to cut the text

# variables for baseline
BASELINE_FUNCTION = MultinomialNB()  # e.g. swap for LinearSVC()
# Derive the label from the estimator so results can never be logged under the wrong
# model name. NOTE(review): the overview displayed further down lists LinearSVC rows,
# i.e. it comes from an earlier run, not from this MultinomialNB configuration.
MODEL_NAME = type(BASELINE_FUNCTION).__name__
TRAIN_SET = 'train' # must be dev or train
TEST_SET = 'test' # must be val or test
SPLIT_COLUMN = '4split' #must be either 2split or 4split. 2split = data split into train and test. 4split = data split into train, test, dev and val. 
LABEL_COLUMN = 'label'
PREDICTION_PATH = f"{cf.output_path}/predictions/baselineTruncationPredictions.pkl"
OVERVIEW_PATH = f"{cf.output_path}/overview/baselineTruncationOverview.pkl"
# PREDICTION_PATH = f"{cf.output_path}/predictions/tryoutBaselineTruncationPredictions.pkl"
# OVERVIEW_PATH = f"{cf.output_path}/overview/tryoutBaselineTruncationOverview.pkl"
TRUNC_COLUMN = 'trunc_txt'
# (front, back) thresholds passed to add_truncation_column. The run names in the
# overview (e.g. "Front1000Back0") suggest they keep the first / last N tokens --
# confirm in scripts/truncation.py.
threshold_combinations = [
    (100, 0), (200, 0), (500, 0), (1000, 0), (2000, 0),
    (100, 100), (200, 200), (500, 500), (1000, 1000),
    (0, 100), (0, 200), (0, 500), (0, 1000), (0, 2000),
]
# threshold_combinations = [(100,0)]

# Validate the split settings up front instead of failing mid-way through a long loop.
if SPLIT_COLUMN not in ('2split', '4split'):
    raise ValueError(f"SPLIT_COLUMN must be '2split' or '4split', got {SPLIT_COLUMN!r}")
if TRAIN_SET not in ('dev', 'train'):
    raise ValueError(f"TRAIN_SET must be 'dev' or 'train', got {TRAIN_SET!r}")
if TEST_SET not in ('val', 'test'):
    raise ValueError(f"TEST_SET must be 'val' or 'test', got {TEST_SET!r}")
if SPLIT_COLUMN == '2split' and (TRAIN_SET, TEST_SET) != ('train', 'test'):
    # A 2split only has train and test partitions; dev/val exist only in 4split.
    raise ValueError("SPLIT_COLUMN '2split' only supports TRAIN_SET='train' and TEST_SET='test'")

In [12]:
# Evaluate the baseline once per (front, back) truncation setting. run_baseline also
# receives PREDICTION_PATH / OVERVIEW_PATH, where it presumably stores each run's results.
for front_threshold, back_threshold in threshold_combinations:
    trunc = add_truncation_column(
        DATAFRAME, TEXT_COL, TOKENS_COL, front_threshold, back_threshold
    )
    bf.run_baseline(
        BASELINE_FUNCTION, MODEL_NAME, trunc,
        SPLIT_COLUMN, TRAIN_SET, TEST_SET,
        TRUNC_COLUMN, LABEL_COLUMN,
        PREDICTION_PATH, OVERVIEW_PATH,
    )

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       1.00      0.06      0.11       152
             Agenda       0.86      0.98      0.92       528
            Besluit       0.99      0.81      0.89       113
              Brief       0.80      0.79      0.79       206
          Factsheet       1.00      0.02      0.04        45
              Motie       0.84      0.97      0.90      1545
  Onderzoeksrapport       0.75      0.68      0.71       222
         Raadsadres       0.79      0.71      0.75       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       1.00      0.94      0.97       603
         Voordracht       0.94      0.99      0.97       395

           accuracy                           0.87      4164
          macro avg       0.81      0.63      0.64      4164
       weighted avg       0.87      0.87      0.84      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       1.00      0.01      0.03       152
             Agenda       0.95      0.99      0.97       528
            Besluit       0.99      0.73      0.84       113
              Brief       1.00      0.08      0.15       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.70      1.00      0.82      1545
  Onderzoeksrapport       0.77      0.51      0.61       222
         Raadsadres       0.94      0.50      0.65       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       0.94      0.90      0.92       603
         Voordracht       0.97      0.98      0.98       395

           accuracy                           0.81      4164
          macro avg       0.75      0.52      0.54      4164
       weighted avg       0.83      0.81      0.77      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       0.00      0.00      0.00       152
             Agenda       0.97      0.95      0.96       528
            Besluit       1.00      0.19      0.31       113
              Brief       0.00      0.00      0.00       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.51      1.00      0.67      1545
  Onderzoeksrapport       0.77      0.42      0.55       222
         Raadsadres       1.00      0.04      0.08       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       1.00      0.33      0.50       603
         Voordracht       0.98      0.63      0.77       395

           accuracy                           0.63      4164
          macro avg       0.57      0.32      0.35      4164
       weighted avg       0.69      0.63      0.56      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       0.00      0.00      0.00       152
             Agenda       0.97      0.93      0.95       528
            Besluit       1.00      0.03      0.05       113
              Brief       0.00      0.00      0.00       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.48      1.00      0.65      1545
  Onderzoeksrapport       0.75      0.31      0.43       222
         Raadsadres       1.00      0.01      0.02       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       0.99      0.15      0.26       603
         Voordracht       0.98      0.59      0.74       395

           accuracy                           0.58      4164
          macro avg       0.56      0.27      0.28      4164
       weighted avg       0.68      0.58      0.49      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       0.00      0.00      0.00       152
             Agenda       0.98      0.92      0.95       528
            Besluit       1.00      0.02      0.03       113
              Brief       0.00      0.00      0.00       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.48      1.00      0.65      1545
  Onderzoeksrapport       0.64      0.21      0.32       222
         Raadsadres       1.00      0.01      0.02       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       0.99      0.11      0.20       603
         Voordracht       0.99      0.71      0.83       395

           accuracy                           0.58      4164
          macro avg       0.55      0.27      0.27      4164
       weighted avg       0.67      0.58      0.49      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       1.00      0.01      0.01       152
             Agenda       0.92      0.96      0.94       528
            Besluit       1.00      0.55      0.71       113
              Brief       0.93      0.12      0.21       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.65      1.00      0.79      1545
  Onderzoeksrapport       0.66      0.36      0.47       222
         Raadsadres       0.96      0.30      0.45       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       1.00      0.88      0.94       603
         Voordracht       0.98      0.99      0.99       395

           accuracy                           0.78      4164
          macro avg       0.74      0.47      0.50      4164
       weighted avg       0.81      0.78      0.73      4164



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

        Actualiteit       0.00      0.00      0.00       152
             Agenda       0.96      0.97      0.96       528
            Besluit       1.00      0.23      0.37       113
              Brief       1.00      0.01      0.02       206
          Factsheet       0.00      0.00      0.00        45
              Motie       0.55      1.00      0.71      1545
  Onderzoeksrapport       0.68      0.35      0.46       222
         Raadsadres       1.00      0.09      0.16       313
       Raadsnotulen       0.00      0.00      0.00        42
Schriftelijke Vraag       1.00      0.39      0.56       603
         Voordracht       0.98      0.99      0.99       395

           accuracy                           0.68      4164
          macro avg       0.65      0.37      0.39      4164
       weighted avg       0.75      0.68      0.61      4164



: 

: 

: 

In [6]:
import pandas as pd
OVERVIEW_PATH = f"{cf.output_path}/overview/baselineTruncationOverview.pkl"
PREDICTION_PATH = f"{cf.output_path}/predictions/baselineTruncationPredictions.pkl"

# Rank every saved truncation run by macro-averaged F1, best first;
# accuracy breaks ties.
yeet = (
    pd.read_pickle(OVERVIEW_PATH)
    .sort_values(by=['macro_avg_f1', 'accuracy'], ascending=False)
)
display(yeet)

Unnamed: 0,model,date,train_set,test_set,train_set_support,test_set_support,split_col,text_col,runtime,accuracy,macro_avg_precision,macro_avg_recall,macro_avg_f1,classification_report
0,LinearSVC,2024-04-24 17:00:16.579063+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront2000Back0,80.954455,0.960375,0.951155,0.906994,0.920938,precision recall f1-...
0,LinearSVC,2024-04-24 16:58:14.136467+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront500Back0,49.083472,0.960855,0.948753,0.903805,0.91665,precision recall f1-...
0,LinearSVC,2024-04-24 16:59:07.293244+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront1000Back0,64.067199,0.961335,0.95034,0.902041,0.916227,precision recall f1-...
0,LinearSVC,2024-04-24 17:07:29.160956+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront1000Back1000,146.854178,0.961095,0.948089,0.902357,0.91512,precision recall f1-...
0,LinearSVC,2024-04-24 17:05:07.923798+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront500Back500,134.567976,0.960375,0.946173,0.900545,0.911388,precision recall f1-...
0,LinearSVC,2024-04-24 17:03:08.622085+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront200Back200,113.316262,0.959174,0.952566,0.889015,0.902993,precision recall f1-...
0,LinearSVC,2024-04-24 17:22:30.597072+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront0Back2000,244.802182,0.949087,0.932871,0.887491,0.901655,precision recall f1-...
0,LinearSVC,2024-04-24 17:01:31.788353+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront100Back100,93.574816,0.952209,0.945007,0.887995,0.901268,precision recall f1-...
0,LinearSVC,2024-04-24 16:57:41.418294+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront200Back0,28.803534,0.958213,0.94899,0.888563,0.899384,precision recall f1-...
0,LinearSVC,2024-04-24 16:57:27.306647+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront100Back0,11.678588,0.95245,0.935202,0.883517,0.897058,precision recall f1-...


In [8]:
import pandas as pd
OVERVIEW_PATH = f"{cf.output_path}/overview/baselineTruncationOverview.pkl"
PREDICTION_PATH = f"{cf.output_path}/predictions/baselineTruncationPredictions.pkl"

# Same overview table, this time ranked by accuracy first (macro-F1 breaks ties).
yeet = (
    pd.read_pickle(OVERVIEW_PATH)
    .sort_values(by=['accuracy', 'macro_avg_f1'], ascending=False)
)
display(yeet)

Unnamed: 0,model,date,train_set,test_set,train_set_support,test_set_support,split_col,text_col,runtime,accuracy,macro_avg_precision,macro_avg_recall,macro_avg_f1,classification_report
0,LinearSVC,2024-04-24 16:59:07.293244+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront1000Back0,64.067199,0.961335,0.95034,0.902041,0.916227,precision recall f1-...
0,LinearSVC,2024-04-24 17:07:29.160956+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront1000Back1000,146.854178,0.961095,0.948089,0.902357,0.91512,precision recall f1-...
0,LinearSVC,2024-04-24 16:58:14.136467+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront500Back0,49.083472,0.960855,0.948753,0.903805,0.91665,precision recall f1-...
0,LinearSVC,2024-04-24 17:00:16.579063+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront2000Back0,80.954455,0.960375,0.951155,0.906994,0.920938,precision recall f1-...
0,LinearSVC,2024-04-24 17:05:07.923798+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront500Back500,134.567976,0.960375,0.946173,0.900545,0.911388,precision recall f1-...
0,LinearSVC,2024-04-24 17:03:08.622085+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront200Back200,113.316262,0.959174,0.952566,0.889015,0.902993,precision recall f1-...
0,LinearSVC,2024-04-24 16:57:41.418294+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront200Back0,28.803534,0.958213,0.94899,0.888563,0.899384,precision recall f1-...
0,LinearSVC,2024-04-24 16:57:27.306647+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront100Back0,11.678588,0.95245,0.935202,0.883517,0.897058,precision recall f1-...
0,LinearSVC,2024-04-24 17:01:31.788353+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront100Back100,93.574816,0.952209,0.945007,0.887995,0.901268,precision recall f1-...
0,LinearSVC,2024-04-24 17:22:30.597072+02:00,train,test,15613,4164,4split,TruncationLlamaTokensFront0Back2000,244.802182,0.949087,0.932871,0.887491,0.901655,precision recall f1-...
