## Setup

In [1]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import sklearn


## Demo

In [32]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted

# Demo inputs, five rows each. X is the main example; W has the same values
# as X but its second column is named "x3"; Y keeps only column "x1"; Z holds
# X's values as a plain ndarray (no column names).
X = pd.DataFrame(dict(x1=range(1, 6), x2=range(5, 0, -1)))
W = pd.DataFrame(dict(x1=range(1, 6), x3=range(5, 0, -1)))
Y = pd.DataFrame(dict(x1=range(1, 6)))
Z = np.array(X)

In [81]:
class interact_features(TransformerMixin, BaseEstimator):
    """Append pairwise products of input features (degree-2 interaction terms).

    For ``n`` input columns the transformer builds one product column for
    every pair ``(i, j)`` with ``i < j``, ordered ``(0, 1), (0, 2), ...,
    (n-2, n-1)``. It never emits squared terms (``i == j``) or a bias column.

    Parameters
    ----------
    interaction_only : bool, default=False
        If True, ``transform`` returns only the product columns. If False, the
        original columns come first, followed by the product columns.
    """

    def __init__(self, interaction_only=False):
        # Stored unchanged so get_params/set_params/clone work as sklearn expects.
        self.interaction_only = interaction_only

    def fit(self, X, y=None):
        """Record the number (and, when available, names) of input features.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data; must have at least two features.
        y : None
            Ignored; present for API consistency.

        Returns
        -------
        self : interact_features
            The fitted transformer.
        """
        # Validation with reset=True sets n_features_in_ (and feature_names_in_
        # when X carries column names); nothing else is learned.
        self._validate_input(X, reset=True, ensure_min_features=2)
        return self

    def transform(self, X, y=None):
        """Return X augmented with (or replaced by) its pairwise products.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features_in_)
            Data to transform; must have the same features as seen in ``fit``.
        y : None
            Ignored; present for API consistency.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_output_features)
            ``n_features_in_ * (n_features_in_ - 1) // 2`` product columns,
            preceded by the original columns unless ``interaction_only``.

        Raises
        ------
        NotFittedError
            If called before ``fit``.
        ValueError
            If X fails validation or its feature count/names differ from ``fit``.
        """
        check_is_fitted(self)
        # Compute on the validated array rather than re-converting the raw
        # input, so validation and computation see exactly the same data.
        X = self._validate_input(X, reset=False)
        products = [
            X[:, i] * X[:, j] for i, j in self._pair_indices(self.n_features_in_)
        ]
        X_new = np.column_stack(products)
        if not self.interaction_only:
            X_new = np.column_stack([X, X_new])
        return X_new

    def get_feature_names_out(self, input_features=None):
        """Return output feature names matching the columns of ``transform``.

        Parameters
        ----------
        input_features : array-like of str or None, default=None
            Input feature names. If None, ``feature_names_in_`` is used when it
            exists, otherwise ``["x0", "x1", ...]``. If given, it must have
            length ``n_features_in_`` and, when ``feature_names_in_`` exists,
            equal it.

        Returns
        -------
        feature_names_out : list
            Product names of the form ``"a * b"``; the input names come first
            unless ``interaction_only``.

        Raises
        ------
        NotFittedError
            If called before ``fit``.
        ValueError
            If ``input_features`` has the wrong length or differs from
            ``feature_names_in_``.
        """
        check_is_fitted(self)
        names = self._input_feature_names(input_features)
        # Same pair order as transform(), so names line up with columns.
        pair_names = [
            f"{names[i]} * {names[j]}"
            for i, j in self._pair_indices(self.n_features_in_)
        ]
        if self.interaction_only:
            return pair_names
        return names + pair_names

    def _validate_input(self, X, *, reset, **check_params):
        """Validate X, setting (reset=True) or checking (reset=False) feature metadata."""
        try:
            # scikit-learn >= 1.6 replaced the private
            # BaseEstimator._validate_data with this public helper.
            from sklearn.utils.validation import validate_data
        except ImportError:
            return self._validate_data(X=X, reset=reset, **check_params)
        return validate_data(self, X=X, reset=reset, **check_params)

    def _input_feature_names(self, input_features):
        """Resolve and check the input feature names, returned as a list."""
        if input_features is None:
            if hasattr(self, "feature_names_in_"):
                return list(self.feature_names_in_)
            return [f"x{i}" for i in range(self.n_features_in_)]
        names = list(input_features)
        if hasattr(self, "feature_names_in_") and names != list(self.feature_names_in_):
            raise ValueError("input_features is not equal to feature_names_in_")
        if len(names) != self.n_features_in_:
            raise ValueError(
                "input_features should have length equal to number of features "
                f"({self.n_features_in_}), got {len(names)}"
            )
        return names

    @staticmethod
    def _pair_indices(n_features):
        """Return every column-index pair (i, j) with i < j, in row-major order."""
        return [(i, j) for i in range(n_features) for j in range(i + 1, n_features)]

In [78]:
# Build the transformer with its default setting spelled out, then fit it on X.
itf = interact_features(interaction_only=False)
itf = itf.fit(X)

In [79]:
# interaction_only=False: original columns first, then the product column x1 * x2.
itf.transform(X)

array([[1, 5, 5],
       [2, 4, 8],
       [3, 3, 9],
       [4, 2, 8],
       [5, 1, 5]])

In [80]:
# Names for the columns returned by transform(): input names, then "x1 * x2".
itf.get_feature_names_out()

['x1', 'x2']
['x1 * x2']
['x1', 'x2', 'x1 * x2']


['x1', 'x2', 'x1 * x2']