In [13]:
# activate autoreload
%load_ext autoreload
%autoreload 2

# check if session is in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print('Google Colab session!')
except:
    IN_COLAB = False
    print('Not a Google Colab session.')

# add src path to the notebook
import os
import sys
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT: str = '/content/drive/MyDrive/papers/2025b_relevance_2.0'
    !pip install contextily esda deep-translator h3pandas h3~=3.0 datasets optuna setfit
else:
    PROJECT_ROOT: str = os.path.dirname(os.path.abspath(os.path.dirname("__file__")))
if PROJECT_ROOT not in sys.path:
    sys.path.append(os.path.join(PROJECT_ROOT))
print(PROJECT_ROOT)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Not a Google Colab session.
/n/netscratch/cga/Lab/bresch/david/relevance_2.0


# Fine-Tuning with Task-specific Data
In this notebook, we conduct fine-tuning with our labelled, multilingual data.

Specifically, we evaluate:
- fine-tuning of our pre-trained Twhin-BERT model and a regular Twhin-BERT model
- in-context learning when adding non-text features into the text

In [14]:
import pickle
import json
import torch
import numpy as np
import pandas as pd
import geopandas as gpd
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score, root_mean_squared_error, mean_absolute_error
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split
from transformers import pipeline
from tqdm import tqdm
from src.model_training.bert import train_classifier_with_hp, evaluate_text_classification_pipeline
from src.model_training.extended_bert import train_classifier_w_numerical_features_with_hp, load_and_infer_batch
tqdm.pandas()

# set data path
DATA_PATH: str = os.path.join(PROJECT_ROOT, 'data')
RESULTS_PATH: str = os.path.join(PROJECT_ROOT, 'results')
print(DATA_PATH)

# set pytorch device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

/n/netscratch/cga/Lab/bresch/david/relevance_2.0/data
Device: cuda
NVIDIA A100-SXM4-80GB


## 1. Training Data
First, we need to prepare our training and evaluation data that we will use throughout the study.

In [15]:
train_gdf: gpd.GeoDataFrame = gpd.read_parquet(os.path.join(DATA_PATH, 'processed', 'fine_tuning', 'train_data.parquet'))
test_gdf: gpd.GeoDataFrame = gpd.read_parquet(os.path.join(DATA_PATH, 'processed', 'fine_tuning', 'test_data.parquet'))
with open(os.path.join(DATA_PATH, 'processed', 'fine_tuning', 'train_label_encoder.pkl'), 'rb') as f:
    label_encoder: OrdinalEncoder = pickle.load(f)
label_mapping: dict = {category: index for index, category in enumerate(label_encoder.categories_[0])}

# Now you can use the loaded label encoder
print("Class encodings:", label_encoder.categories[0])
print(label_mapping)
print(train_gdf.shape)
print(test_gdf.shape)
pd.DataFrame(train_gdf)

Class encodings: ['Not related', 'Related but not relevant', 'Related and relevant']
{'Not related': 0, 'Related but not relevant': 1, 'Related and relevant': 2}
(3659, 45)
(915, 45)


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Unnamed: 0,message_id,date,use_case,text,tweet_lang,geometry,photo_url,text_raw,related,x,...,sphere_y,sphere_z,int_label,valid,event_distance_km_norm,event_distance_h_norm,n_disaster_tweets_1km_norm,n_disaster_tweets_10km_norm,n_disaster_tweets_50km_norm,n_disaster_tweets_10000km_norm
0,1.296800e+18,2020-08-21 13:40:45,California 🔥,Closed due to the czu august lightning complex...,,POINT (-122.3611 37.16663),,Closed due to the czu august lightning complex...,1,-1.362118e+07,...,0.095685,-0.954961,1,True,-0.457531,-0.099120,-0.370997,-0.451317,0.758702,1.366236
1,1.417100e+18,2021-07-19 12:23:18,Germany 🌊,Mich beunruhigt nichts mehr.Wir sorgen persönl...,de,POINT (12.22671 51.84923),,Mich beunruhigt nichts mehr.Wir sorgen persönl...,1,1.361556e+06,...,0.645793,0.288072,1,True,1.300809,0.095609,0.362793,0.052028,-0.467472,1.510834
2,1.341270e+18,2020-12-22 06:29:24,California 🔥,The view out my kitchen window of the massive ...,,POINT (-118.41191 34.02069),,The view out my kitchen window of the massive ...,1,-1.318155e+07,...,0.095685,-0.954961,2,True,-0.714648,0.772862,0.798591,0.527691,-0.172292,-0.719290
3,1.320860e+18,2020-10-26 22:49:26,California 🔥,@user @user @user I love 1/2 miles from Anahei...,,POINT (-117.85109 33.84275),,@ZestForLifeNow @City_of_Anaheim @AnaheimFire ...,1,-1.311912e+07,...,0.095685,-0.954961,2,True,-0.714648,-0.099120,-0.105182,-0.068227,0.592908,-0.276396
4,1.296050e+18,2020-08-19 11:26:50,California 🔥,Someone fucking set off my apartment building’...,,POINT (-118.41191 34.02069),,Someone fucking set off my apartment building’...,1,-1.318155e+07,...,0.095685,-0.954961,0,True,-0.714648,0.852190,2.446648,1.889790,1.001015,0.244985
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3654,1.624650e+18,2023-02-12 06:18:19,Turkey 🪨,Hatay/Hassa'da en kazın altında çıkan not...#d...,tr,POINT (36.51252 36.78301),http://pbs.twimg.com/media/FovuSvIX0AUnSdz.jpg,Hatay/Hassa'da en kazın altında çıkan not...\r...,1,4.064655e+06,...,-0.317761,0.887312,2,True,-1.168961,-0.188440,-0.317823,-0.436595,0.543710,2.060679
3655,1.610290e+18,2023-01-03 15:19:57,Turkey 🪨,FutBol Sohbet programımızın yeni bölümü YouTub...,tr,POINT (33.78141 41.38023),http://pbs.twimg.com/media/FljqrPJWQAcU-DW.jpg,FutBol Sohbet programımızın yeni bölümü YouTub...,0,3.760672e+06,...,-0.317761,0.887312,0,True,0.882802,-2.645158,-0.544401,-0.635649,-0.646982,-1.109802
3656,1.627130e+18,2023-02-19 02:12:45,Chile 🔥,Pasando ahora 😭😭 #Coronel #Biobio #IncendioFor...,es,POINT (-73.2222 -37.00482),http://pbs.twimg.com/ext_tw_video_thumb/162712...,Pasando ahora 😭😭 #Coronel #Biobio #IncendioFor...,1,-8.154494e+06,...,0.869559,0.356159,2,True,-0.312896,0.300568,-0.480332,-0.410065,0.183655,0.021594
3657,1.416140e+18,2021-07-16 21:10:24,Germany 🌊,Rhein unterspült Uferstrasse in Basel und löst...,de,POINT (7.65276 47.57676),,Rhein unterspült Uferstrasse in Basel und löst...,1,8.524011e+05,...,0.645793,0.288072,1,True,0.852110,-0.180118,-0.141470,-0.150951,-0.432042,0.471640


Next, we generate our train/validation data right away instead of later on.

In [16]:
# Split dataset into training and validation sets
bert_train_gdf, bert_val_gdf = train_test_split(
    train_gdf, test_size=0.2, random_state=42, stratify=train_gdf['int_label']
)
print(bert_train_gdf.shape, bert_val_gdf.shape)

(2927, 45) (732, 45)


## 2. Text-Only Model Training
Okaaaay, so we can already go ahead and train two classification models. To get a grip if generic pre-training also works in this setting, let's evaluate Twhin-BERT-base with and without generic fine-tuning.

In [6]:
MODELS: dict = {
    # 'twhin-bert-disaster-pretrained': os.path.join(DATA_PATH, 'models', 'twhin-bert-disaster-pretrained', 'model'),
    'twhin-bert-base': 'Twitter/twhin-bert-base',
}

We first fine-tune the two models and then evaluate.

In [None]:
# iterate over all models
for model_name, model_path in MODELS.items():
    model, tokenizer, best_hyperparameters, eval_results = train_classifier_with_hp(
        texts_train=bert_train_gdf['text'].tolist(),
        texts_val=bert_val_gdf['text'].tolist(),
        y_train=bert_train_gdf['int_label'].tolist(),
        y_val=bert_val_gdf['int_label'].tolist(),
        model_name=model_path,
        model_path=os.path.join(DATA_PATH, 'models', f'{model_name}_ft', 'model'),
        logging_path=os.path.join(DATA_PATH, 'models', f'{model_name}_ft', 'logs'),
        weighted_loss=False,
        id2label={i: label for i, label in enumerate(label_encoder.categories[0])},
        label2id={label: i for i, label in enumerate(label_encoder.categories[0])}
    )
    print(f'Results for model: {model_name}')
    print(best_hyperparameters)
    print(eval_results)

    # store the results
    with open(os.path.join(DATA_PATH, 'models', f'{model_name}_ft', 'hyperparameters.json'), 'w') as f:
        json.dump(best_hyperparameters, f)
    with open(os.path.join(DATA_PATH, 'models', f'{model_name}_ft', 'eval_results.json'), 'w') as f:
        json.dump(eval_results, f)

In [8]:
metric_dictlist: list[dict] = []

# iterate over all models
for model_name, _ in MODELS.items():
    print(f"Evaluating model: {model_name}")
    
    # Evaluate using the reusable function
    metrics, pred_df = evaluate_text_classification_pipeline(
        test_df=test_gdf,
        model_path=os.path.join(DATA_PATH, 'models', f'{model_name}_ft', 'model'),
        label_mapping=label_mapping,
        device=device,             # Use GPU (e.g., device=0) or CPU (device=-1)
        text_column='text',        # Column with texts
        label_column='int_label'   # Column with ground truth labels
    )
    
    # Add the model name to the metrics dictionary
    metrics['model_name'] = model_name
    metric_dictlist.append(metrics)

    # Store the predicted results
    pred_df.to_parquet(os.path.join(RESULTS_PATH, 'fine_tuning', f'pred_text_{model_name}.parquet'))

# Convert the list of metric dictionaries to a DataFrame
eval_results = pd.DataFrame(metric_dictlist)
eval_results.to_csv(os.path.join(RESULTS_PATH, 'fine_tuning', 'text_fine_tuning_metrics.csv'), index=False)
eval_results

Evaluating model: twhin-bert-base


Processing texts:   0%|          | 1/915 [00:00<05:14,  2.91it/s]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Processing texts: 100%|██████████| 915/915 [00:07<00:00, 125.75it/s]


Evaluation Results:
- Macro Precision: 0.7945589645254074
- Macro Recall: 0.7699348539571425
- Macro F1: 0.7786859079424642
- Accuracy: 0.8021857923497268
- Macro ROC-AUC: 0.9284324153727378
- RMSE: 0.5163977794943222
- MAE: 0.22076502732240438


Unnamed: 0,macro_precision,macro_recall,macro_f1,accuracy,macro_roc_auc,rmse,mae,model_name
0,0.794559,0.769935,0.778686,0.802186,0.928432,0.516398,0.220765,twhin-bert-base


## 3. In-Context Learning
Next, Michael suggested that in-context learning might also work for the task we are trying to achieve. Therefore, let's replicate our fine-tuning approach with in-context learning.

In [17]:
NON_TEXT_COLUMNS: list[str] = [
    'event_distance_km',
    'event_distance_h',
    'n_disaster_tweets_1km',
    'n_disaster_tweets_10km',
    'n_disaster_tweets_50km',
    'n_disaster_tweets_10000km'
]
NON_TEXT_COLUMNS_NORM: list[str] = [f'{x}_norm' for x in NON_TEXT_COLUMNS]

We need some prep code to achieve that effectively.

In [10]:
# we also need names
feature_names_in_context = {
    'event_type_encoding': 'Event Type',
    'event_distance_km': 'Distance from Disaster (km)',
    'event_distance_h': 'Time Gap from Disaster (hours)',
    'n_disaster_tweets_1km': 'Disaster Posts within 1 km',
    'n_disaster_tweets_10km': 'Disaster Posts within 10 km',
    'n_disaster_tweets_50km': 'Disaster Posts within 50 km',
    'n_disaster_tweets_10000km': 'Disaster Posts in Area of Interest',
    'lat_centre': 'Central Latitude',
    'lon_centre': 'Central Longitude'
}

contexts: dict = {
    'none': [], 
    'event': ['event_type_encoding'], 
    'coord': ['lat_centre', 'lon_centre'], 
    'all': ['lat_centre', 'lon_centre', 'event_type_encoding']
}


def encode_row_to_json(row: pd.Series, feature_names: dict = feature_names_in_context) -> str:
    """
    Encodes a row (pandas Series) containing a 'text' field and several features into a JSON string.
    The event type is expected to be encoded as a one-hot vector and is decoded into a string.
    
    Args:
        row (pd.Series): A pandas Series with keys like 'text', 'event_type_encoding', etc.
        feature_names (dict, optional): Mapping from internal feature keys to human-friendly names.
                                        Defaults to feature_names_in_context.
    
    Returns:
        str: A JSON string encoding the text and features.
    """
    # Start with the text field
    output = {"text": row.get("text", "")}
    features = {}
    
    for feature, friendly_name in feature_names.items():
        # Special handling for event type encoding (assumed to be a one-hot vector)
        if feature == 'event_type_encoding':
            vec = row.get(feature, None)
            if vec is not None:
                # Assuming the vector is a list-like object
                if vec[0] == 1:
                    decoded = 'flood'
                elif vec[1] == 1:
                    decoded = 'wildfire'
                else:
                    decoded = 'earthquake'
                features[friendly_name] = decoded
            else:
                features[friendly_name] = None
        else:
            # For other features, take the value directly
            features[friendly_name] = row.get(feature, None)
    
    output.update(features)
    return json.dumps(output, ensure_ascii=False, indent=2)

# Print an example
print(encode_row_to_json(bert_train_gdf.sample(1).iloc[0]))

# Encode the entire data
bert_train_gdf['in_context_string'] = bert_train_gdf.apply(encode_row_to_json, axis=1)
bert_val_gdf['in_context_string'] = bert_val_gdf.apply(encode_row_to_json, axis=1)
test_gdf['in_context_string'] = test_gdf.apply(encode_row_to_json, axis=1)
bert_train_gdf

{
  "text": "Que bueno que los incendios forestales en Ñuble fue controlado @user gracias por la información. Que lo paguen cuanto antes los incendiosEn la región bio bio situación es compleja, humo cenizas en concepciónNo pueden controlar los incendios en la región del bio bio",
  "Event Type": "wildfire",
  "Distance from Disaster (km)": 6.655614316577463,
  "Time Gap from Disaster (hours)": -129.7252777777778,
  "Disaster Posts within 1 km": 0.0,
  "Disaster Posts within 10 km": 122.0,
  "Disaster Posts within 50 km": 238.0,
  "Disaster Posts in Area of Interest": 1363.0,
  "Central Latitude": -34.92167295303011,
  "Central Longitude": -71.06063249257306
}


Unnamed: 0,message_id,date,use_case,text,tweet_lang,geometry,photo_url,text_raw,related,x,...,sphere_z,int_label,valid,event_distance_km_norm,event_distance_h_norm,n_disaster_tweets_1km_norm,n_disaster_tweets_10km_norm,n_disaster_tweets_50km_norm,n_disaster_tweets_10000km_norm,in_context_string
1266,1.625460e+18,2023-02-14 11:49:49,Turkey 🪨,Maraş’tan Elbistan a gelecek olan varmı #14Sub...,tr,POINT (37.38714 38.32362),,Maraş’tan Elbistan a gelecek olan varmı #14Sub...,1,4.162011e+06,...,0.887312,2,True,-1.168961,0.023538,-0.166771,-0.351286,-0.103657,1.115721,"{\n ""text"": ""Maraş’tan Elbistan a gelecek ola..."
2390,1.622770e+18,2023-02-07 01:39:02,Chile 🔥,@user Y estaba en ayuda a los damnificados de ...,es,POINT (-70.56469 -33.37876),,@LaLady98849619 Y estaba en ayuda a los damnif...,1,-7.855515e+06,...,0.356159,1,True,-0.213382,0.040033,2.551017,2.689386,2.961461,1.518689,"{\n ""text"": ""@user Y estaba en ayuda a los da..."
2317,1.629270e+18,2023-02-25 00:01:27,Chile 🔥,@user Sii??..Rusia recuperó el territorio k lo...,es,POINT (-70.82187 -33.51375),,@Cooperativa Sii??..Rusia recuperó el territor...,0,-7.884145e+06,...,0.356159,0,True,-0.353062,-1.250093,-0.480332,-0.529274,-0.671055,-1.160805,"{\n ""text"": ""@user Sii??..Rusia recuperó el t..."
564,1.628820e+18,2023-02-23 18:11:27,Turkey 🪨,Deprem esnasında bırakıp kaçtığınız peş para e...,tr,POINT (40.19679 37.93349),,Deprem esnasında bırakıp kaçtığınız peş para e...,1,4.474747e+06,...,0.887312,1,True,-0.863521,0.540534,0.135333,-0.066923,-0.334859,-0.224169,"{\n ""text"": ""Deprem esnasında bırakıp kaçtığı..."
979,1.311560e+18,2020-10-01 06:40:02,California 🔥,@user Не существует такого национального празд...,,POINT (-122.43598 37.77064),,@dbg_nsk Не существует такого национального пр...,0,-1.362951e+07,...,-0.954961,0,True,-0.660042,-0.099120,-0.424160,-0.536449,-0.758946,-1.041650,"{\n ""text"": ""@user Не существует такого нацио..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2982,1.321960e+18,2020-10-29 23:50:24,California 🔥,The blog post I just linked to is absolutely r...,,POINT (-118.41191 34.02069),,The blog post I just linked to is absolutely r...,0,-1.318155e+07,...,-0.954961,0,True,-0.714648,-0.099120,-0.424160,-0.536449,-0.758946,-1.041650,"{\n ""text"": ""The blog post I just linked to i..."
1427,1.634760e+18,2023-03-12 03:49:30,Chile 🔥,"@user Señores, se ha producido un corte de ene...",es,POINT (-70.86883 -35.22333),,"@CGE_Clientes Señores, se ha producido un cort...",1,-7.889378e+06,...,0.356159,0,True,-0.386665,-0.446052,-0.268843,-0.484571,-0.628319,-0.343514,"{\n ""text"": ""@user Señores, se ha producido u..."
275,1.417030e+18,2021-07-19 08:07:23,Germany 🌊,Der @user hat einen #THREAD geschrieben. Das i...,de,POINT (13.42031 52.50347),,Der @narkosedoc hat einen #THREAD geschrieben....,1,1.494420e+06,...,0.288072,1,True,1.800832,0.076134,6.702106,5.416489,0.678093,1.417898,"{\n ""text"": ""Der @user hat einen #THREAD gesc..."
3031,1.653790e+18,2023-05-03 15:34:59,Italy 🌊,@user Io vivo in Centro noi niente ma la Zona ...,it,POINT (11.89204 44.29132),,@sonotantaroba Io vivo in Centro noi niente ma...,1,1.324238e+06,...,0.440128,2,True,-0.675560,-0.365933,-0.402057,-0.474313,-0.446063,-0.431152,"{\n ""text"": ""@user Io vivo in Centro noi nien..."


We can then go ahead and train our models, with in-context knowledge.

In [None]:
model, tokenizer, best_hyperparameters, eval_results = train_classifier_with_hp(
    texts_train=bert_train_gdf['in_context_string'].tolist(),
    texts_val=bert_val_gdf['in_context_string'].tolist(),
    y_train=bert_train_gdf['int_label'].tolist(),
    y_val=bert_val_gdf['int_label'].tolist(),
    model_name='Twitter/twhin-bert-base',
    model_path=os.path.join(DATA_PATH, 'models', 'in_context_ft', 'model'),
    logging_path=os.path.join(DATA_PATH, 'models', 'in_context_ft', 'logs'),
    weighted_loss=False,
    id2label={i: label for i, label in enumerate(label_encoder.categories[0])},
    label2id={label: i for i, label in enumerate(label_encoder.categories[0])}
)
print(f'Results for model: {model_name}')
print(best_hyperparameters)
print(eval_results)

# store the results
with open(os.path.join(DATA_PATH, 'models', 'in_context_ft', 'hyperparameters.json'), 'w') as f:
    json.dump(best_hyperparameters, f)
with open(os.path.join(DATA_PATH, 'models', 'in_context_ft', 'eval_results.json'), 'w') as f:
    json.dump(eval_results, f)

Best params: `{'learning_rate': 1.4858988509215784e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 9, 'weight_decay': 0.07546222928186692}`

In [12]:
metric_dictlist: list[dict] = []

# iterate over all models
for model_name, _ in {'in_context': 'in_context_ft'}.items():
    print(f"Evaluating model: {model_name}")
    
    # Evaluate using the reusable function
    metrics, pred_df = evaluate_text_classification_pipeline(
        test_df=test_gdf,
        model_path=os.path.join(DATA_PATH, 'models', 'in_context_ft', 'model'),
        label_mapping=label_mapping,
        device=device,                          # Use GPU (e.g., device=0) or CPU (device=-1)
        text_column='in_context_string',        # Column with texts
        label_column='int_label'                # Column with ground truth labels
    )
    
    # Add the model name to the metrics dictionary
    metrics['model_name'] = model_name
    metric_dictlist.append(metrics)

    # Store the predicted results
    pred_df.to_parquet(os.path.join(RESULTS_PATH, 'fine_tuning', f'pred_text_{model_name}.parquet'))

# Convert the list of metric dictionaries to a DataFrame
eval_results = pd.DataFrame(metric_dictlist)
eval_results.to_csv(os.path.join(RESULTS_PATH, 'fine_tuning', 'in_context_metrics.csv'), index=False)
eval_results

Evaluating model: in_context


Processing texts: 100%|██████████| 915/915 [00:07<00:00, 130.69it/s]


Evaluation Results:
- Macro Precision: 0.791883685207016
- Macro Recall: 0.7738605980955781
- Macro F1: 0.78095338020053
- Accuracy: 0.7989071038251366
- Macro ROC-AUC: 0.9232152845407473
- RMSE: 0.5132133846092447
- MAE: 0.22185792349726777


Unnamed: 0,macro_precision,macro_recall,macro_f1,accuracy,macro_roc_auc,rmse,mae,model_name
0,0.791884,0.773861,0.780953,0.798907,0.923215,0.513213,0.221858,in_context


## 4. Embedding Concatenation
Alternatively, we might also try to concatenate the BERT embeddings with our normalised non-text features and see if this helps during training. We use the identical settings as before.

In [7]:
# We need to add our event encodings as explicit columns for easier use
bert_train_gdf_concat: gpd.GeoDataFrame = bert_train_gdf.copy()
bert_train_gdf_concat[['flood', 'wildfire', 'earthquake']] = bert_train_gdf['event_type_encoding'].tolist()
bert_val_gdf_concat: gpd.GeoDataFrame = bert_val_gdf.copy()
bert_val_gdf_concat[['flood', 'wildfire', 'earthquake']] = bert_val_gdf['event_type_encoding'].tolist()
test_gdf_enriched: gpd.GeoDataFrame = test_gdf.copy()
test_gdf_enriched[['flood', 'wildfire', 'earthquake']] = test_gdf['event_type_encoding'].tolist()

# Available contexts, let's use all
contexts: dict = {
    'none': [], 
    'event': ['flood', 'wildfire', 'earthquake'], 
    'coord': ['sphere_x', 'sphere_y', 'sphere_z'], 
    'all': ['flood', 'wildfire', 'earthquake', 'sphere_x', 'sphere_y', 'sphere_z']
}

In [None]:
for norm_variant in ['norm']: # ['regular', 'norm']
    if norm_variant == 'regular':
        columns: list[str] = NON_TEXT_COLUMNS
    else:
        columns: list[str] = NON_TEXT_COLUMNS_NORM

    for classification_head in ['simple', 'complex']:
        print(f'Fine-tuning model with classification head {classification_head}')
        model, tokenizer, best_params, eval_results = train_classifier_w_numerical_features_with_hp(
            texts_train=bert_train_gdf_concat['text'],
            texts_val=bert_val_gdf_concat['text'],
            numerical_features_train=bert_train_gdf_concat[columns + contexts['all']].to_dict(orient='list'),
            numerical_features_val=bert_val_gdf_concat[columns + contexts['all']].to_dict(orient='list'),
            y_train=bert_train_gdf_concat['int_label'], 
            y_val=bert_val_gdf_concat['int_label'], 
            model_name='Twitter/twhin-bert-base',
            model_path=os.path.join(DATA_PATH, 'models', f'extended_twhin-bert_{classification_head}_{norm_variant}', 'model'),
            logging_path=os.path.join(DATA_PATH, 'models', f'extended_twhin-bert_{classification_head}_{norm_variant}', 'logs'),
            id2label = {i: label for i, label in enumerate(label_encoder.categories[0])},
            label2id = {label: i for i, label in enumerate(label_encoder.categories[0])},
            n_trials=10
        )

        # store the results
        with open(os.path.join(DATA_PATH, 'models', f'extended_twhin-bert_{classification_head}_{norm_variant}', 'hyperparameters.json'), 'w') as f:
            json.dump(best_params, f)
        with open(os.path.join(DATA_PATH, 'models', f'extended_twhin-bert_{classification_head}_{norm_variant}', 'eval_results.json'), 'w') as f:
            json.dump(eval_results, f)

Fine-tuning model with classification head simple


Map: 100%|██████████| 2927/2927 [00:00<00:00, 8561.77 examples/s]
Map: 100%|██████████| 732/732 [00:00<00:00, 8189.97 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2025-04-03 14:59:02,909] A new study created in memory with name: no-name-2e8b146b-70f9-4522-b9a1-c4eada8ef8a8
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.594399,0.762295,0.752064,0.725494,0.734163
2,No log,0.547092,0.790984,0.776849,0.771057,0.773449
3,0.581000,0.551158,0.782787,0.779324,0.769063,0.770494


[I 2025-04-03 15:03:26,693] Trial 0 finished with value: 3.1016685132724784 and parameters: {'learning_rate': 1.1148770104604024e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 3, 'weight_decay': 0.002891389011381196}. Best is trial 0 with value: 3.1016685132724784.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and infer

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.588621,0.771858,0.772611,0.726005,0.737324
2,No log,0.538262,0.803279,0.792234,0.783606,0.786939
3,0.558000,0.54285,0.786885,0.78103,0.772865,0.774503


[I 2025-04-03 15:07:50,448] Trial 1 finished with value: 3.1152841258004895 and parameters: {'learning_rate': 1.3880645822701862e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 3, 'weight_decay': 0.005425949233128208}. Best is trial 1 with value: 3.1152841258004895.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and infer

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.632843,0.751366,0.751202,0.697693,0.703932
2,No log,0.534474,0.777322,0.770903,0.783851,0.77082
3,0.526500,0.629742,0.790984,0.782514,0.75843,0.765929
4,0.526500,0.768232,0.806011,0.799196,0.776192,0.784166
5,0.526500,0.974927,0.810109,0.823452,0.779912,0.793594
6,0.156600,1.130098,0.797814,0.793349,0.789192,0.788614
7,0.156600,1.211326,0.804645,0.797135,0.786915,0.790909
8,0.156600,1.226048,0.806011,0.802554,0.783714,0.790954


[I 2025-04-03 15:19:17,301] Trial 2 finished with value: 3.183233206790907 and parameters: {'learning_rate': 3.0090138296789193e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 8, 'weight_decay': 0.012235627609196274}. Best is trial 2 with value: 3.183233206790907.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inferen

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.631128,0.770492,0.762741,0.730047,0.736645
2,No log,0.529617,0.782787,0.771288,0.784584,0.77385
3,0.530800,0.542768,0.796448,0.785006,0.77734,0.780589
4,0.530800,0.622943,0.804645,0.79907,0.786531,0.790789
5,0.530800,0.638981,0.814208,0.808996,0.799932,0.803156


[I 2025-04-03 15:26:31,816] Trial 3 finished with value: 3.2262914586378493 and parameters: {'learning_rate': 1.8285974829290768e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 5, 'weight_decay': 0.00011415461275154408}. Best is trial 3 with value: 3.2262914586378493.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inf

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.659119,0.760929,0.770473,0.701267,0.705297
2,No log,0.542737,0.781421,0.764818,0.765836,0.764942
3,0.522800,0.548831,0.804645,0.798078,0.78277,0.789056
4,0.522800,0.722829,0.806011,0.806974,0.775737,0.786557
5,0.522800,0.809259,0.815574,0.824984,0.778289,0.792489
6,0.171200,0.86618,0.819672,0.81203,0.803597,0.806516


[I 2025-04-03 15:35:10,062] Trial 4 finished with value: 3.241815443719778 and parameters: {'learning_rate': 2.4206340660056165e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 6, 'weight_decay': 0.043833750263877584}. Best is trial 4 with value: 3.241815443719778.
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inferen

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.608691,0.769126,0.761144,0.723926,0.731453
2,No log,0.525872,0.790984,0.774677,0.783827,0.778137
3,0.547200,0.559666,0.803279,0.794633,0.781311,0.78676
4,0.547200,0.608991,0.807377,0.820796,0.768643,0.783951
5,0.547200,0.701828,0.807377,0.801635,0.781028,0.788971


[I 2025-04-03 15:42:09,133] Trial 5 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.599965,0.760929,0.750477,0.722316,0.730932


[I 2025-04-03 15:43:28,067] Trial 6 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.637668,0.763661,0.757014,0.714014,0.718917


[I 2025-04-03 15:44:47,009] Trial 7 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.599237,0.766393,0.766725,0.720683,0.731414
2,No log,0.531548,0.79235,0.776348,0.781262,0.778445
3,0.571100,0.538063,0.795082,0.781683,0.781178,0.78081
4,0.571100,0.590359,0.796448,0.798083,0.771586,0.780527


[I 2025-04-03 15:50:21,607] Trial 8 pruned. 
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.560813,0.759563,0.750475,0.732342,0.738739
2,No log,0.528653,0.785519,0.774427,0.7635,0.76784
3,No log,0.536011,0.800546,0.793968,0.785155,0.787912
4,No log,0.605905,0.786885,0.787264,0.793926,0.782377
5,No log,0.757407,0.784153,0.804791,0.731598,0.746968
6,0.370900,0.701509,0.815574,0.807256,0.79482,0.799916
7,0.370900,0.859227,0.814208,0.808368,0.801975,0.803928
8,0.370900,0.925247,0.803279,0.794262,0.797265,0.793657
9,0.370900,0.90981,0.814208,0.810547,0.795211,0.80132


[I 2025-04-03 16:02:26,458] Trial 9 finished with value: 3.221285919850513 and parameters: {'learning_rate': 2.3332644712937557e-05, 'per_device_train_batch_size': 32, 'num_train_epochs': 9, 'weight_decay': 0.002273866148524572}. Best is trial 4 with value: 3.241815443719778.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Best hyperparameters: {'learning_rate': 2.4206340660056165e-05, 'per_device_train_batch_size': 16, 'num_train_epochs': 6, 'weight_decay': 0.043833750263877584}


Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.659119,0.760929,0.770473,0.701267,0.705297
2,No log,0.542737,0.781421,0.764818,0.765836,0.764942
3,0.522800,0.548831,0.804645,0.798078,0.78277,0.789056
4,0.522800,0.722829,0.806011,0.806974,0.775737,0.786557
5,0.522800,0.809259,0.815574,0.824984,0.778289,0.792489
6,0.171200,0.86618,0.819672,0.81203,0.803597,0.806516


Evaluation Results: {'eval_loss': 0.8661797642707825, 'eval_accuracy': 0.819672131147541, 'eval_precision': 0.8120301744252734, 'eval_recall': 0.8035966490965114, 'eval_f1': 0.8065164890504523, 'eval_runtime': 6.201, 'eval_samples_per_second': 118.046, 'eval_steps_per_second': 14.836, 'epoch': 6.0}
Fine-tuning model with classification head complex


Map: 100%|██████████| 2927/2927 [00:00<00:00, 8373.37 examples/s]
Map: 100%|██████████| 732/732 [00:00<00:00, 8159.65 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of ExtendedNumBertForSequenceClassification were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'numeric_feature_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2025-04-03 16:11:18,891] A new study created in memory with name: no-name-a5ad32bb-b02c-44c5-a136-e68c1821276b
  "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 5e-5),
  "weight_decay": trial.suggest_loguniform("weight_decay", 1e-4, 1e-1)
Detected kernel version 4.18.

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,No log,0.585653,0.766393,0.751133,0.737093,0.741413
2,No log,0.536068,0.781421,0.774942,0.795702,0.775117
3,0.523400,0.526023,0.808743,0.800421,0.793973,0.796589


- Best params simple: `Best hyperparameters: {'learning_rate': 4.5270517707035104e-05, 'per_device_train_batch_size': 32, 'num_train_epochs': 8, 'weight_decay': 0.004931080960267828}`
- Best params complex: `{'learning_rate': 2.78838131001195e-05, 'per_device_train_batch_size': 32, 'num_train_epochs': 8, 'weight_decay': 0.002729789794354492}`

In [18]:
metric_dictlist: list[dict] = []

# iterate over all models
for norm_variant in ['regular', 'norm']:
    if norm_variant == 'regular':
        columns: list[str] = NON_TEXT_COLUMNS
    else:
        columns: list[str] = NON_TEXT_COLUMNS_NORM
    
    for classification_head in ['simple', 'complex']:
        predictions, pred_probs = load_and_infer_batch(
            model_path=os.path.join(DATA_PATH, 'models', f'extended_twhin-bert_{classification_head}_{norm_variant}', 'model'),
            texts=test_gdf_enriched['text'].tolist(), 
            numerical_features=test_gdf_enriched[columns + contexts['all']].values.tolist()
        )
        test_gdf_enriched['prediction'] = predictions
        
        test_gdf_enriched.to_parquet(os.path.join(RESULTS_PATH, 'fine_tuning', f'pred_extended_twhin-bert_{classification_head}_{norm_variant}.parquet'), index=False)
        prec, rec, f1, support = precision_recall_fscore_support(test_gdf_enriched['int_label'], test_gdf_enriched['prediction'], average='macro')
        roc_auc = roc_auc_score(y_true=test_gdf_enriched['int_label'], y_score=np.array(pred_probs), multi_class='ovr', average='macro')
        rmse = root_mean_squared_error(y_true=test_gdf_enriched['int_label'], y_pred=test_gdf_enriched['prediction'])
        mae = mean_absolute_error(y_true=test_gdf_enriched['int_label'], y_pred=test_gdf_enriched['prediction'])

        print(f'Evaluated model: extended_twhin-bert_{classification_head}')
        print(f'- Macro Precision: {prec}')
        print(f'- Macro Recall: {rec}')
        print(f'- Macro F1: {f1}')
        print(f'- Macro ROC-AUC: {roc_auc}')

        metric_dictlist.append({
            'model_name': f'extended_twhin-bert_{classification_head}',
            'variant': norm_variant,
            'macro_precision': prec,
            'macro_recall': rec,
            'macro_f1': f1,
            'accuracy': accuracy_score(y_true=test_gdf_enriched['int_label'], y_pred=test_gdf_enriched['prediction']),
            'macro_roc_auc': roc_auc,
            'rmse': rmse,
            'mae': mae
        })

# store the evaluation results
eval_results: pd.DataFrame = pd.DataFrame.from_dict(metric_dictlist)
eval_results.to_csv(os.path.join(RESULTS_PATH, 'fine_tuning', 'extended_fine_tuning_metrics.csv'), index=False)
eval_results

Inference: 100%|██████████| 58/58.0 [00:07<00:00,  7.97it/s]


Evaluated model: extended_twhin-bert_simple
- Macro Precision: 0.8098936278387189
- Macro Recall: 0.8065953311925876
- Macro F1: 0.8074561316399315
- Macro ROC-AUC: 0.9264340128764607


Inference: 100%|██████████| 58/58.0 [00:07<00:00,  7.99it/s]


Evaluated model: extended_twhin-bert_complex
- Macro Precision: 0.8123291348266767
- Macro Recall: 0.8133997324705939
- Macro F1: 0.8124582622904235
- Macro ROC-AUC: 0.9404968582282311


Inference: 100%|██████████| 58/58.0 [00:07<00:00,  7.97it/s]


Evaluated model: extended_twhin-bert_simple
- Macro Precision: 0.7824496157713331
- Macro Recall: 0.7785027044013714
- Macro F1: 0.7792841028787647
- Macro ROC-AUC: 0.926234402052475


Inference: 100%|██████████| 58/58.0 [00:07<00:00,  7.98it/s]


Evaluated model: extended_twhin-bert_complex
- Macro Precision: 0.7892776126247861
- Macro Recall: 0.7859462035587592
- Macro F1: 0.7872277040670906
- Macro ROC-AUC: 0.9255723927200741


Unnamed: 0,model_name,variant,macro_precision,macro_recall,macro_f1,accuracy,macro_roc_auc,rmse,mae
0,extended_twhin-bert_simple,regular,0.809894,0.806595,0.807456,0.821858,0.926434,0.516398,0.20765
1,extended_twhin-bert_complex,regular,0.812329,0.8134,0.812458,0.825137,0.940497,0.522708,0.20765
2,extended_twhin-bert_simple,norm,0.78245,0.778503,0.779284,0.795628,0.926234,0.550211,0.237158
3,extended_twhin-bert_complex,norm,0.789278,0.785946,0.787228,0.804372,0.925572,0.53306,0.225137
