In [None]:
import category_encoders as ce
import numpy as np
import polars as pl
from sklearn.model_selection import KFold

class TargetEncoderCV:
    """Out-of-fold target encoding of one column of a polars DataFrame.

    ``fit_transform`` adds ``f"{col}_target_Encoded"`` to the training frame using
    K-fold cross-fitting: each row is encoded by a ``category_encoders.TargetEncoder``
    fitted only on the *other* folds, so a row's own target is never used to encode
    it. A final encoder fitted on every training row then encodes the test frame and
    is kept on the instance so ``transform`` can encode further frames later.
    """

    def __init__(self, n_folds, seed):
        """
        Args:
            n_folds: Number of KFold splits used for the out-of-fold encoding.
            seed: ``random_state`` for the shuffled KFold, making folds reproducible.
        """
        self.n_folds = n_folds
        self.seed = seed
        # Encoder fitted on the full training frame; None until fit_transform runs.
        self.final_target_encoder = None
        # Name of the column final_target_encoder was fitted on.
        self.col = None

    @staticmethod
    def _encoded_name(col):
        """Return the name of the output column for source column *col*."""
        return f'{col}_target_Encoded'

    @staticmethod
    def _new_encoder(col):
        """Create an unfitted TargetEncoder that always encodes *col*."""
        # Explicit cols=[col]: with the default cols=None, category_encoders only
        # encodes object/category columns and passes numeric codes through
        # unchanged, which would silently yield the raw values instead of an encoding.
        return ce.TargetEncoder(cols=[col])

    @staticmethod
    def _encode(encoder, values, fitted_col):
        """Encode polars Series *values* with a fitted *encoder*; return a 1-D numpy array.

        The pandas input is renamed to *fitted_col* so the encoder finds the column it
        was fitted on, and the (n, 1) DataFrame that ``transform`` returns is flattened
        because ``pl.Series`` cannot be built from a pandas DataFrame.
        """
        encoded = encoder.transform(values.to_pandas().rename(fitted_col))
        return np.asarray(encoded).ravel()

    def fit_transform(self, df: "pl.DataFrame", test: "pl.DataFrame", col, target_col):
        """Add out-of-fold target encodings of *col* to *df* and full-fit ones to *test*.

        Args:
            df: Training frame containing *col* and *target_col*.
            test: Frame containing *col*; encoded with the encoder fitted on all of *df*.
            col: Name of the column to encode.
            target_col: Name of the target column in *df*.

        Returns:
            ``(df, test)``, each with an added ``f"{col}_target_Encoded"`` column.
            The row order of both frames is preserved.
        """
        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.seed)
        values = df[col]
        target = df[target_col]

        # Positional buffer: KFold validation folds partition the rows, so every slot
        # is written exactly once and each encoding lands back on its own row,
        # independent of how df is sorted or which columns it has.
        out_of_fold = np.full(df.height, np.nan)
        for train_idx, val_idx in kf.split(np.arange(df.height)):
            encoder = self._new_encoder(col)
            encoder.fit(values[train_idx].to_pandas(), target[train_idx].to_pandas())
            out_of_fold[val_idx] = self._encode(encoder, values[val_idx], col)
        df = df.with_columns(pl.Series(self._encoded_name(col), out_of_fold))

        # Fit on all training rows for encoding unseen data (test, later transforms).
        self.final_target_encoder = self._new_encoder(col)
        self.final_target_encoder.fit(values.to_pandas(), target.to_pandas())
        self.col = col
        return df, self.transform(test, col)

    def transform(self, test: "pl.DataFrame", col=None):
        """Encode *col* of *test* with the encoder fitted by ``fit_transform``.

        Args:
            test: Frame containing the column to encode.
            col: Column to encode; defaults to the column used in ``fit_transform``.

        Returns:
            *test* with an added ``f"{col}_target_Encoded"`` column.

        Raises:
            ValueError: If ``fit_transform`` has not been called yet.
        """
        if self.final_target_encoder is None:
            raise ValueError("The model has not been fit yet. Call 'fit_transform' first.")
        if col is None:
            col = self.col
        encoded = self._encode(self.final_target_encoder, test[col], self.col)
        return test.with_columns(pl.Series(self._encoded_name(col), encoded))


# target_encoder_cv = TargetEncoderCV(n_folds=CFG.n_folds, seed=CFG.seed)
# df_encoded, test_encoded = target_encoder_cv.fit_transform(df, test, col='col1', target_col='target')
# test_transformed = target_encoder_cv.transform(test, col='col1')