In [1]:
%pip install polars

%pip install pyarrow

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


Files included in this EDA below and aggregation at the case_id level as part of this EDA:

Properties: depth=0, internal data source 
    train_static_0_0.csv
    train_static_0_1.csv

Properties: depth=0, external data source 
    train_static_cb_0.csv

Properties: depth=1, external data source, Tax registry provider A 
    test_tax_registry_a_1.csv

Properties: depth=1, external data source, Tax registry provider B 
    test_tax_registry_b_1.csv

Properties: depth=1, external data source, Tax registry provider C 
    train_tax_registry_c_1.csv

In [2]:
#Import necessary libraries
import polars as pl
from typing import List
from polars import DataFrame
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyarrow as pow
from sklearn.preprocessing import OneHotEncoder

dataPath = "/Users/artjolameli/Desktop/Credit_Risk_Predictions/Kaggle_Credit_Risk_Predictions/dataset/home-credit-credit-risk-model-stability-2/parquet_files/train/"



In [3]:
# Substitute this by the import 

def group_file_data(
    df: pl.DataFrame, 
    num_cols: List[str] = [], 
    date_cols: List[str] = [], 
    cat_cols: List[str] = []
) -> pl.DataFrame:
    '''
    Function to group numerical, date, and categorical columns

    Parameters:
    -----------
    df : Polars DataFrame
    num_cols : List of numerical column names (remember to drop num_group columns)
    date_cols : List of date column names
    cat_cols : List of categorical column names (becomes dummies)
    '''
    
    # Convert date columns
    df_date = df[['case_id'] + date_cols].with_columns([ pl.col(col).str.to_date() for col in date_cols ])

    # One-hot categories
    df_dummies = df[['case_id'] + cat_cols].to_dummies(cat_cols)

    # Num DataFrame
    df_num = df[['case_id'] + num_cols]

    # Date aggs
    date_aggs = [ pl.min(col).name.suffix('_min') for col in date_cols ] +\
                [ pl.max(col).name.suffix('_max') for col in date_cols ] +\
                [ pl.n_unique(col).name.suffix('_distinct') for col in date_cols]
    df_date_grouped = df_date.group_by('case_id').agg(date_aggs)

    # One-hot aggs
    dummy_cols = [ col for col in df_dummies.columns if col != 'case_id']
    dummies_aggs = [ pl.sum(col).name.suffix('_sum') for col in dummy_cols ]
    df_dummies_grouped = df_dummies.group_by('case_id').agg(dummies_aggs)

    # Numerical aggs
    num_aggs = [ pl.min(col).name.suffix('_min') for col in num_cols ] +\
            [ pl.max(col).name.suffix('_max') for col in num_cols ] +\
            [ pl.mean(col).name.suffix('_mean') for col in num_cols ] +\
            [ pl.median(col).name.suffix('_median') for col in num_cols ] +\
            [ pl.sum(col).name.suffix('_sum') for col in num_cols ]
    df_num_grouped = df_num.group_by('case_id').agg(num_aggs)

    # Join DataFrames
    df_joined = df_num_grouped.join(df_date_grouped, on='case_id')
    df_joined = df_joined.join(df_dummies_grouped, on='case_id')

    return df_joined

In [4]:
def set_table_dtypes(df: pl.DataFrame) -> pl.DataFrame:
    # implement here all desired dtypes for tables
    for col in df.columns:
        if col[-1] in ("P", "A"):
            df = df.with_columns(pl.col(col).cast(pl.Float64).alias(col))

    return df

def convert_strings(df: pd.DataFrame) -> pd.DataFrame:
    for col in df.columns:  
        if df[col].dtype.name in ['object', 'string']:
            df[col] = df[col].astype("string").astype('category')
            current_categories = df[col].cat.categories
            new_categories = current_categories.to_list() + ["Unknown"]
            new_dtype = pd.CategoricalDtype(categories=new_categories, ordered=True)
            df[col] = df[col].astype(new_dtype)
    return df

In [5]:
# Load each parquet file and print its schema
files = ["train_static_0_0.parquet", "train_static_0_1.parquet", "train_static_cb_0.parquet", 
         "train_tax_registry_a_1.parquet", "train_tax_registry_b_1.parquet", "train_tax_registry_c_1.parquet", "train_base.parquet"]
for file in files:
    df = pl.read_parquet(dataPath + file)
    print(f"Schema for {file}:")
    print(df.schema, "\n")

Schema for train_static_0_0.parquet:
OrderedDict([('case_id', Int64), ('actualdpdtolerance_344P', Float64), ('amtinstpaidbefduel24m_4187115A', Float64), ('annuity_780A', Float64), ('annuitynextmonth_57A', Float64), ('applicationcnt_361L', Float64), ('applications30d_658L', Float64), ('applicationscnt_1086L', Float64), ('applicationscnt_464L', Float64), ('applicationscnt_629L', Float64), ('applicationscnt_867L', Float64), ('avgdbddpdlast24m_3658932P', Float64), ('avgdbddpdlast3m_4187120P', Float64), ('avgdbdtollast24m_4525197P', Float64), ('avgdpdtolclosure24_3658938P', Float64), ('avginstallast24m_3658937A', Float64), ('avglnamtstart24m_4525187A', Float64), ('avgmaxdpdlast9m_3716943P', Float64), ('avgoutstandbalancel6m_4187114A', Float64), ('avgpmtlast12m_4525200A', Float64), ('bankacctype_710L', String), ('cardtype_51L', String), ('clientscnt12m_3712952L', Float64), ('clientscnt3m_3712950L', Float64), ('clientscnt6m_3712949L', Float64), ('clientscnt_100L', Float64), ('clientscnt_1022L

"recorddate_4527225D" column in train_tax_registry_a file is the same as "deductiondate_4917603D" and "processingdate_168D" in "train_tax_registry_c"

"name_4527232M" in train_tax_registry_a file, "name_4917606M" in train_tax_registry_b file and "employername_160M" in "train_tax_registry_c" are same.

"amount_4527230A" in train_tax_registry_a file, "amount_4917619A" in train_tax_registry_b file and "pmtamount_36A" in train_tax_registry_c are the same.

In [6]:
# Load and merge the static and tax_registry files + base file as well as we will use it later for the aggregation

train_basetable = pl.read_parquet(dataPath + "train_base.parquet")


# Load the rest of the parquet files separately
train_static_0_0 = pl.read_parquet(dataPath + "train_static_0_0.parquet").pipe(set_table_dtypes)
train_static_0_1 = pl.read_parquet(dataPath + "train_static_0_1.parquet").pipe(set_table_dtypes)
train_static_cb = pl.read_parquet(dataPath + "train_static_cb_0.parquet").pipe(set_table_dtypes)
train_tax_registry_a = pl.read_parquet(dataPath + "train_tax_registry_a_1.parquet").pipe(set_table_dtypes)
train_tax_registry_b = pl.read_parquet(dataPath + "train_tax_registry_b_1.parquet").pipe(set_table_dtypes)
train_tax_registry_c = pl.read_parquet(dataPath + "train_tax_registry_c_1.parquet").pipe(set_table_dtypes)

In [7]:
# Overview and statistics for each DataFrame
dataframes = {
    'train_static_0_0': train_static_0_0,
    'train_static_0_1': train_static_0_1,
    'train_static_cb': train_static_cb,
    'train_tax_registry_a': train_tax_registry_a,
    'train_tax_registry_b': train_tax_registry_b,
    'train_tax_registry_c': train_tax_registry_c
}

for name, df in dataframes.items():
    print(f"Overview of {name}:")
    # Print the schema for an overview of the DataFrame structure
    print(df.schema)
    
    # Get a summary of the data
    print(df.describe())

    print("\nMissing values count in each column:")
    # Count missing values in each column
    missing_values = df.select([pl.col(column).is_null().sum().alias(column) for column in df.columns])
    print(missing_values)

    print("\n-----\n")

Overview of train_static_0_0:
OrderedDict([('case_id', Int64), ('actualdpdtolerance_344P', Float64), ('amtinstpaidbefduel24m_4187115A', Float64), ('annuity_780A', Float64), ('annuitynextmonth_57A', Float64), ('applicationcnt_361L', Float64), ('applications30d_658L', Float64), ('applicationscnt_1086L', Float64), ('applicationscnt_464L', Float64), ('applicationscnt_629L', Float64), ('applicationscnt_867L', Float64), ('avgdbddpdlast24m_3658932P', Float64), ('avgdbddpdlast3m_4187120P', Float64), ('avgdbdtollast24m_4525197P', Float64), ('avgdpdtolclosure24_3658938P', Float64), ('avginstallast24m_3658937A', Float64), ('avglnamtstart24m_4525187A', Float64), ('avgmaxdpdlast9m_3716943P', Float64), ('avgoutstandbalancel6m_4187114A', Float64), ('avgpmtlast12m_4525200A', Float64), ('bankacctype_710L', String), ('cardtype_51L', String), ('clientscnt12m_3712952L', Float64), ('clientscnt3m_3712950L', Float64), ('clientscnt6m_3712949L', Float64), ('clientscnt_100L', Float64), ('clientscnt_1022L', Floa

In [8]:
for name, df in dataframes.items():
    print(f"Null value percentages in {name}:")
    total_rows = len(df)

    # Calculate and print the percentage of null values for each column
    for column in df.columns:
        null_count = df.select(pl.col(column).is_null().sum()).to_numpy()[0, 0]  # Accessing the scalar value directly
        null_percentage = (null_count / total_rows) * 100
        print(f"{column}: {null_percentage:.2f}%")

    print("\n-----\n")

Null value percentages in train_static_0_0:
case_id: 0.00%
actualdpdtolerance_344P: 29.56%
amtinstpaidbefduel24m_4187115A: 42.72%
annuity_780A: 0.00%
annuitynextmonth_57A: 0.00%
applicationcnt_361L: 0.00%
applications30d_658L: 0.00%
applicationscnt_1086L: 0.00%
applicationscnt_464L: 0.00%
applicationscnt_629L: 0.00%
applicationscnt_867L: 0.00%
avgdbddpdlast24m_3658932P: 42.45%
avgdbddpdlast3m_4187120P: 65.55%
avgdbdtollast24m_4525197P: 78.28%
avgdpdtolclosure24_3658938P: 32.80%
avginstallast24m_3658937A: 43.21%
avglnamtstart24m_4525187A: 93.05%
avgmaxdpdlast9m_3716943P: 51.34%
avgoutstandbalancel6m_4187114A: 59.68%
avgpmtlast12m_4525200A: 80.43%
bankacctype_710L: 69.65%
cardtype_51L: 87.47%
clientscnt12m_3712952L: 0.00%
clientscnt3m_3712950L: 0.00%
clientscnt6m_3712949L: 0.00%
clientscnt_100L: 0.00%
clientscnt_1022L: 0.00%
clientscnt_1071L: 0.00%
clientscnt_1130L: 0.00%
clientscnt_136L: 99.96%
clientscnt_157L: 0.00%
clientscnt_257L: 0.00%
clientscnt_304L: 0.00%
clientscnt_360L: 0.00%
c

After we checked the data overall, lets do the data aggregation 

In [9]:
# Aggregation functions for numerical columns
numerical_agg_funcs = {
    'min': 'min',
    'max': 'max',
    'mean': 'mean',
    'median': 'median',
    'sum': 'sum'
}

# Aggregation functions for categorical columns
categorical_agg_funcs = {
    'mode': lambda x: x.mode().iloc[0],  # Mode
    'one_hot_encoding': lambda x: x.sum()  # One-hot encoding with sum of counts
}

# Aggregation functions for date columns
date_agg_funcs = {
    'min': 'min',
    'max': 'max',
    'distinct_count': 'nunique'  # Count of distinct values
}

Train Static 0 Aggregation

In [11]:
# Date columns
date_cols = [ train_static_0_0.columns[i] for i in range(len(train_static_0_0.columns)) if (train_static_0_0.columns[i].__contains__('dat')) and (train_static_0_0.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_static_0_0.columns[i] for i in range(len(train_static_0_0.columns)) if (train_static_0_0.columns[i] not in date_cols) and (train_static_0_0.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id']
num_col = [ 
    train_static_0_0.columns[i] for i in range(len(train_static_0_0.columns)) 
    if (train_static_0_0.columns[i] not in date_cols) and (train_static_0_0.columns[i] not in cat_cols) and (train_static_0_0.columns[i] not in ignore_col)
]

# Group data
train_static_0_agg = group_file_data(train_static_0_0, num_col, date_cols, cat_cols)

In [12]:
train_static_0_agg.head()

case_id,actualdpdtolerance_344P_min,amtinstpaidbefduel24m_4187115A_min,annuity_780A_min,annuitynextmonth_57A_min,applicationcnt_361L_min,applications30d_658L_min,applicationscnt_1086L_min,applicationscnt_464L_min,applicationscnt_629L_min,applicationscnt_867L_min,avgdbddpdlast24m_3658932P_min,avgdbddpdlast3m_4187120P_min,avgdbdtollast24m_4525197P_min,avgdpdtolclosure24_3658938P_min,avginstallast24m_3658937A_min,avglnamtstart24m_4525187A_min,avgmaxdpdlast9m_3716943P_min,avgoutstandbalancel6m_4187114A_min,avgpmtlast12m_4525200A_min,clientscnt12m_3712952L_min,clientscnt3m_3712950L_min,clientscnt6m_3712949L_min,clientscnt_100L_min,clientscnt_1022L_min,clientscnt_1071L_min,clientscnt_1130L_min,clientscnt_136L_min,clientscnt_157L_min,clientscnt_257L_min,clientscnt_304L_min,clientscnt_360L_min,clientscnt_493L_min,clientscnt_533L_min,clientscnt_887L_min,clientscnt_946L_min,cntincpaycont9m_3716944L_min,…,validfrom_1069D_2019-11-25_sum,validfrom_1069D_2019-11-26_sum,validfrom_1069D_2019-11-27_sum,validfrom_1069D_2019-11-28_sum,validfrom_1069D_2019-11-29_sum,validfrom_1069D_2019-11-30_sum,validfrom_1069D_2019-12-01_sum,validfrom_1069D_2019-12-02_sum,validfrom_1069D_2019-12-03_sum,validfrom_1069D_2019-12-04_sum,validfrom_1069D_2019-12-05_sum,validfrom_1069D_2019-12-06_sum,validfrom_1069D_2019-12-07_sum,validfrom_1069D_2019-12-08_sum,validfrom_1069D_2019-12-09_sum,validfrom_1069D_2019-12-10_sum,validfrom_1069D_2019-12-11_sum,validfrom_1069D_2019-12-12_sum,validfrom_1069D_2019-12-13_sum,validfrom_1069D_2019-12-14_sum,validfrom_1069D_2019-12-15_sum,validfrom_1069D_2019-12-16_sum,validfrom_1069D_2019-12-17_sum,validfrom_1069D_2019-12-18_sum,validfrom_1069D_2019-12-19_sum,validfrom_1069D_2019-12-20_sum,validfrom_1069D_2019-12-21_sum,validfrom_1069D_2019-12-22_sum,validfrom_1069D_2019-12-23_sum,validfrom_1069D_2019-12-24_sum,validfrom_1069D_2019-12-25_sum,validfrom_1069D_2019-12-26_sum,validfrom_1069D_2019-12-27_sum,validfrom_1069D_2019-12-28_sum,validfrom_1069D_2019-12-29_sum,validfrom_1069D_2019-12-30_sum,validfrom_1069D_null_sum
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
1305573,0.0,,2741.4001,8387.2,0.0,0.0,0.0,0.0,0.0,4.0,-21.0,,,0.0,4908.2,,0.0,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
789395,,,3665.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,,,,,,,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
851903,,,6023.4,0.0,0.0,1.0,0.0,9.0,3.0,1.0,,,,,,,,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
1555141,0.0,0.0,3038.8,0.0,0.0,0.0,0.0,0.0,0.0,2.0,,,,0.0,,,,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
114861,0.0,,1585.2001,5478.8003,0.0,1.0,0.0,0.0,0.0,1.0,-6.0,,,0.0,2761.2,,0.0,,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1


In [13]:
train_static_0_agg.write_parquet(dataPath + 'train_static_0_grouped.parquet')


Train Static 1 Aggregation

In [None]:
# Date columns
date_cols = [ train_static_0_1.columns[i] for i in range(len(train_static_0_1.columns)) if (train_static_0_1.columns[i].__contains__('dat')) and (train_static_0_1.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_static_0_1.columns[i] for i in range(len(train_static_0_1.columns)) if (train_static_0_1.columns[i] not in date_cols) and (train_static_0_1.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id']
num_col = [ 
    train_static_0_1.columns[i] for i in range(len(train_static_0_1.columns)) 
    if (train_static_0_1.columns[i] not in date_cols) and (train_static_0_1.columns[i] not in cat_cols) and (train_static_0_1.columns[i] not in ignore_col)
]

# Group data
train_static_1_agg = group_file_data(train_static_0_1, num_col, date_cols, cat_cols)

In [None]:
train_static_1_agg.head()

In [None]:
train_static_0_agg.write_parquet(dataPath + 'train_static_1_grouped.parquet')

Train Static cb Aggregation

In [None]:
# Date columns
date_cols = [ train_static_cb.columns[i] for i in range(len(train_static_cb.columns)) if (train_static_cb.columns[i].__contains__('dat')) and (train_static_cb.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_static_cb.columns[i] for i in range(len(train_static_cb.columns)) if (train_static_cb.columns[i] not in date_cols) and (train_static_cb.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id']
num_col = [ 
    train_static_cb.columns[i] for i in range(len(train_static_cb.columns)) 
    if (train_static_cb.columns[i] not in date_cols) and (train_static_cb.columns[i] not in cat_cols) and (train_static_cb.columns[i] not in ignore_col)
]

# Group data
train_static_cb_agg = group_file_data(train_static_cb, num_col, date_cols, cat_cols)

In [None]:
train_static_cb_agg.head()

In [None]:
train_static_cb_agg.write_parquet(dataPath + 'train_static_cb_grouped.parquet')

Train Tax Registry A Aggregation

In [None]:
# Date columns
date_cols = [ train_tax_registry_a.columns[i] for i in range(len(train_tax_registry_a.columns)) if (train_tax_registry_a.columns[i].__contains__('dat')) and (train_tax_registry_a.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_tax_registry_a.columns[i] for i in range(len(train_tax_registry_a.columns)) if (train_tax_registry_a.columns[i] not in date_cols) and (train_tax_registry_a.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id', "num_group1"]
num_col = [ 
    train_tax_registry_a.columns[i] for i in range(len(train_tax_registry_a.columns)) 
    if (train_tax_registry_a.columns[i] not in date_cols) and (train_tax_registry_a.columns[i] not in cat_cols) and (train_tax_registry_a.columns[i] not in ignore_col)
]

# Group data
train_tax_reg_a_agg = group_file_data(train_tax_registry_a, num_col, date_cols, cat_cols)

In [None]:
train_tax_reg_a_agg.head()

In [None]:
train_tax_reg_a_agg.write_parquet(dataPath + 'train_tax_registry_a_grouped.parquet')

Train Tax Registry B Aggregation

In [None]:
# Date columns
date_cols = [ train_tax_registry_b.columns[i] for i in range(len(train_tax_registry_b.columns)) if (train_tax_registry_b.columns[i].__contains__('dat')) and (train_tax_registry_b.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_tax_registry_b.columns[i] for i in range(len(train_tax_registry_b.columns)) if (train_tax_registry_b.columns[i] not in date_cols) and (train_tax_registry_b.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id', "num_group1"]
num_col = [ 
    train_tax_registry_b.columns[i] for i in range(len(train_tax_registry_b.columns)) 
    if (train_tax_registry_b.columns[i] not in date_cols) and (train_tax_registry_b.columns[i] not in cat_cols) and (train_tax_registry_b.columns[i] not in ignore_col)
]

# Group data
train_tax_reg_b_agg = group_file_data(train_tax_registry_b, num_col, date_cols, cat_cols)

In [None]:
train_tax_reg_b_agg.head()

In [None]:
train_tax_reg_b_agg.write_parquet(dataPath + 'train_tax_registry_b_grouped.parquet')

Train Tax Registry C Aggregation

In [None]:
# Date columns
date_cols = [ train_tax_registry_c.columns[i] for i in range(len(train_tax_registry_c.columns)) if (train_tax_registry_c.columns[i].__contains__('dat')) and (train_tax_registry_c.dtypes[i] == pl.String) ]

# Categorical columns
cat_cols = [ train_tax_registry_c.columns[i] for i in range(len(train_tax_registry_c.columns)) if (train_tax_registry_c.columns[i] not in date_cols) and (train_tax_registry_c.dtypes[i] == pl.String) ]

# Numerical columns
ignore_col = ['case_id', "num_group1"]
num_col = [ 
    train_tax_registry_c.columns[i] for i in range(len(train_tax_registry_c.columns)) 
    if (train_tax_registry_c.columns[i] not in date_cols) and (train_tax_registry_c.columns[i] not in cat_cols) and (train_tax_registry_c.columns[i] not in ignore_col)
]

# Group data
train_tax_reg_c_agg = group_file_data(train_tax_registry_c, num_col, date_cols, cat_cols)

In [None]:
train_tax_reg_c_agg.head()

In [None]:
train_tax_reg_c_agg.write_parquet(dataPath + 'train_tax_registry_c_grouped.parquet')

Joining all aggregated tables together into the basetable

In [None]:
# Join all tables together.
data = train_basetable.join(
    train_static_0_agg, how="left", on="case_id"
).join(
    train_static_1_agg, how="left", on="case_id"
).join(
    train_static_cb, how="left", on="case_id"
).join(
    train_tax_reg_a_agg, how="left", on="case_id"
).join(
    train_tax_reg_b_agg, how="left", on="case_id"
).join(
    train_tax_reg_c_agg, how="left", on="case_id"
)
