# In [None]:
import polars as pl # type: ignore
import warnings

from glob import glob
from pathlib import Path
from typing import Any

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides polars DeprecationWarnings (e.g. for renamed
# APIs); consider narrowing the filter by category or module.
warnings.simplefilter('ignore')

# Root directory of the competition data; file paths are built by appending
# relative paths to this string.
base_path: str = '/kaggle/input/home-credit-credit-risk-model-stability/'

# In [None]:
class DTypeHandler:
    """Static helpers for inspecting feature definitions and normalising dtypes.

    Features are grouped by the last character of their name (e.g. ``P``,
    ``A``, ``D``, ``M``), and ``change_dtypes`` casts each group to a fixed
    target dtype.
    """

    # Identifier / grouping columns that are always cast to unsigned integers.
    _ID_COLUMNS = frozenset({'case_id', 'WEEK_NUM', 'num_group1', 'num_group2'})

    # Target dtype keyed by the last character of a feature name.
    _SUFFIX_DTYPES = {
        'P': pl.UInt32,   # 'P - Transform DPD (Days past due)' must be integers.
        'A': pl.Float64,  # 'A - Transform amount' must be floats.
        'D': pl.Date,     # 'D - Transform date' are dates.
        'M': pl.String,
    }

    # Display names keyed by dtype *class*. ``pl.Utf8`` is listed before
    # ``pl.String`` on purpose: where Utf8 is merely an alias of String
    # (polars >= 0.20) the later entry wins, so string columns are reported
    # as "String" rather than "Utf8".
    _DTYPE_NAMES = {
        pl.Decimal: "Decimal",

        pl.Float32: "Float32",
        pl.Float64: "Float64",

        pl.UInt8: "UInt8",
        pl.UInt16: "UInt16",
        pl.UInt32: "UInt32",
        pl.UInt64: "UInt64",

        pl.Int8: "Int8",
        pl.Int16: "Int16",
        pl.Int32: "Int32",
        pl.Int64: "Int64",

        pl.Date: "Date",
        pl.Datetime: "Datetime",
        pl.Duration: "Duration",
        pl.Time: "Time",

        pl.Array: "Array",
        pl.List: "List",
        pl.Struct: "Struct",

        pl.Utf8: "Utf8",
        pl.String: "String",
        pl.Categorical: "Categorical",
        pl.Enum: "Enum",

        pl.Binary: "Binary",
        pl.Boolean: "Boolean",
        pl.Null: "Null",
        pl.Object: "Object",
        pl.Unknown: "Unknown",
    }

    @staticmethod
    def _load_feat_defs(ending_with: str) -> pl.DataFrame:
        """Read ``feature_definitions.csv`` and keep rows whose ``Variable`` ends with *ending_with*."""
        feat_defs = pl.read_csv(base_path + 'feature_definitions.csv')
        # Vectorised suffix test; replaces a per-row Python lambda passed to
        # ``Expr.apply``, which polars deprecated and later removed.
        return feat_defs.filter(pl.col('Variable').str.ends_with(ending_with))

    @staticmethod
    def get_feat_defs(ending_with: str) -> None:
        """Print every feature definition whose variable name ends with *ending_with*.

        The whole table is printed (no row limit) with long descriptions kept
        up to 200 characters.
        """
        with pl.Config(fmt_str_lengths=200, tbl_rows=-1):
            print(DTypeHandler._load_feat_defs(ending_with))

    @staticmethod
    def find_index(lst: list, item: Any) -> int | None:
        """Return the index of the first occurrence of *item* in *lst*, or ``None`` if absent."""
        try:
            return lst.index(item)
        except ValueError:
            return None

    @staticmethod
    def dtype_to_str(dtype: pl.DataType) -> str | None:
        """Return a short display name for a polars dtype, or ``None`` if unrecognised.

        Accepts a dtype class (``pl.Int64``) or a dtype instance such as
        ``pl.Datetime('us')`` or ``pl.List(pl.Int64)``. The dtype is reduced to
        its base class before the lookup. Parametrised instances do not hash
        like their bare class, so a direct dict lookup would miss them.
        """
        base_type = getattr(dtype, 'base_type', None)
        if base_type is None:
            # Not a polars dtype at all.
            return None
        return DTypeHandler._DTYPE_NAMES.get(base_type())

    @staticmethod
    def find_feat_occur(regex_path: str, ending_with: str) -> pl.DataFrame:
        """Report which dtypes and files each suffix-matched feature appears in.

        Args:
            regex_path: A ``glob`` pattern (not a regex) selecting parquet files,
                e.g. ``base_path + 'parquet_files/train/train_*.parquet'``.
            ending_with: Suffix used to select rows of ``feature_definitions.csv``.

        Returns:
            The matching feature definitions sorted by ``Variable``, plus two
            list columns:
            - ``Data_Type(s)``: sorted dtype names seen for that feature.
            - ``File_Loc(s)``: sorted stems of the files containing it.
            Both lists are empty when a feature occurs in no matched file.
        """
        # ``DataFrame.sort`` returns a new frame. Reassign it so the row order
        # matches the name list the per-feature results are indexed by.
        feat_defs = DTypeHandler._load_feat_defs(ending_with).sort('Variable')
        feats: list[str] = feat_defs['Variable'].to_list()

        # Feature name -> first row index: O(1) lookups instead of a linear
        # list scan for every schema column.
        position: dict[str, int] = {}
        for i, feat in enumerate(feats):
            position.setdefault(feat, i)

        dtype_names: list[set] = [set() for _ in feats]
        file_stems: list[set] = [set() for _ in feats]

        for path in glob(str(regex_path)):
            stem = Path(path).stem
            for feat, dtype in pl.read_parquet_schema(path).items():
                index = position.get(feat)
                if index is None:
                    continue
                # Fall back to the dtype's repr so an unmapped dtype stays
                # visible (and sortable) instead of becoming None.
                dtype_names[index].add(DTypeHandler.dtype_to_str(dtype) or str(dtype))
                file_stems[index].add(stem)

        # Sort the set contents so the output is deterministic.
        return feat_defs.with_columns(
            pl.Series('Data_Type(s)', [sorted(names) for names in dtype_names]),
            pl.Series('File_Loc(s)', [sorted(stems) for stems in file_stems]),
        )

    @staticmethod
    def change_dtypes(df: pl.DataFrame) -> pl.DataFrame:
        """Return *df* with known columns cast to their intended dtypes.

        Rules, applied in order:
        - ``case_id``, ``WEEK_NUM``, ``num_group1``, ``num_group2`` -> ``UInt32``
        - ``date_decision`` -> ``Date``
        - name ending in ``P`` -> ``UInt32``; ``A`` -> ``Float64``;
          ``D`` -> ``Date``; ``M`` -> ``String``

        Columns matching no rule are left unchanged.
        """
        casts = []
        for col in df.columns:
            if col in DTypeHandler._ID_COLUMNS:
                target = pl.UInt32
            elif col == 'date_decision':
                target = pl.Date
            else:
                # ``col[-1:]`` rather than ``col[-1]`` so an empty column name
                # cannot raise IndexError.
                target = DTypeHandler._SUFFIX_DTYPES.get(col[-1:])
            if target is not None:
                casts.append(pl.col(col).cast(target))
        # Each cast touches only its own column, so one batched ``with_columns``
        # is equivalent to the per-column calls and avoids repeated copies.
        return df.with_columns(casts) if casts else df

# In [None]:
# Show where, and with which dtypes, one suffix group of features occurs
# across the training parquet files.
# To list every definition unfiltered instead, use
# pl.read_csv(base_path + 'feature_definitions.csv').
feat_defs: pl.DataFrame = DTypeHandler.find_feat_occur(
    base_path + 'parquet_files/train/train_*.parquet',
    'D',  # suffix group to inspect: 'P', 'M', 'A', 'D', 'T' or 'L'
)

with pl.Config(fmt_str_lengths=1000, tbl_rows=-1, tbl_width_chars=180):
    print(feat_defs)