In [1]:
import pandas as pd
import numpy as np
import os 
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator,TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

from pipeline_testing_util import get_transformers

# Reference:-
# https://stackoverflow.com/questions/70527088/columntransformer-pipeline-with-ohe-is-the-ohe-encoded-field-retained-or-rem
# https://towardsdatascience.com/creating-custom-transformers-for-sklearn-pipelines-d3d51852ecc1
# https://stackoverflow.com/questions/62225230/consistent-columntransformer-for-intersecting-lists-of-columns

In [2]:
titanic = pd.read_csv('data/titanic.csv')
titanic.head()

Unnamed: 0,Survived,Pclass,Name,Sex,Age,Siblings/Spouses Aboard,Parents/Children Aboard,Fare
0,0,2,Mr. William John Berriman,male,23.0,0,0,13.0
1,1,3,Mrs. (Beila) Moor,female,27.0,0,1,12.475
2,0,3,Mr. Nestor Cyriel Vande Walle,male,28.0,0,0,9.5
3,0,3,Mr. Khalil Saad,male,25.0,0,0,7.225
4,0,3,Miss. Gerda Ulrika Dahlberg,female,22.0,0,0,10.5167


In [3]:
titanic = titanic[titanic['Fare']!=0]

In [4]:
titanic['title'] = titanic['Name'].apply(lambda x: x.split()[0][:-1])
titanic = titanic.drop('Name',axis=1)

In [5]:
X = titanic.drop("Survived", axis=1)
y = titanic["Survived"]

In [6]:
# class GroupScale(BaseEstimator, TransformerMixin):
#     def __init__(self, columns):
#         self.columns = columns
        
#     def fit(self, X, y = None):
#         return self
#     def transform(self, X, y = None):
#         val = X.groupby(self.columns[0])[self.columns[1]].transform(lambda x: (x - x.mean()) / x.std())
#         return val.values.reshape(-1,1)

In [7]:
class GroupScale1(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns
        
    def fit(self, X, y = None):
        return self
    def transform(self, X, y = None):
#         print(list(X))
        val = X.groupby('Pclass')[self.columns].transform(lambda x: (x - x.mean()) / x.std())
        return val.values.reshape(-1,1)

In [8]:
cat_indx = [indx for indx,tp in enumerate(X.dtypes) if tp=='O']

cat_col = list(X.select_dtypes('object'))
num_col = list(set(X) - set(cat_col))
num_col = list(set(num_col) - set(['Age','Fare']))

num_attribs = num_col
cat_attribs = cat_col

def log_transform(x):
#     return x.groupby('Pclass').transform(lambda x: (x - x.mean()) / x.std())
#     print(type(x))
    return np.log(x)

transformer = FunctionTransformer(log_transform)


# ColumnTransformer applies its transformers in parallel
preprocessor = ColumnTransformer([
        ("num_log", transformer, ['Fare']),
        ('group_age_scale', GroupScale1('Age'),['Pclass','Age']),
        ("num", StandardScaler(), num_attribs),
        ("cat", OneHotEncoder(), cat_attribs),
    ])


# preprocessor = ColumnTransformer([
#         ("num_log", transformer, ['Fare']),
#         ('columns selector', GroupScale(['Pclass','Age']),['Pclass','Age']),
#         ("num", StandardScaler(), num_attribs),
#         ("cat", OneHotEncoder(), cat_attribs),
#     ])

In [9]:
model = Pipeline(steps=[
                        ("preprocessor", preprocessor), 
                        ('model', LinearRegression())
                       ])

model.fit(X, y)

In [10]:
model.predict(X)

array([ 2.66110324e-01,  7.01470533e-01,  1.02440116e-01,  9.88371110e-02,
        6.63668579e-01,  2.88616682e-01,  6.38588043e-01, -2.66278962e-01,
       -5.86367526e-01,  3.22436167e-01,  1.04348340e+00,  7.94974792e-02,
        6.48266801e-01,  5.77819613e-01,  8.63585795e-01,  2.85890152e-01,
        3.30857466e-01,  2.66110324e-01,  7.06269222e-01, -6.19553794e-02,
        6.17060991e-01,  7.17546925e-01,  6.51669517e-01,  2.10465801e-02,
        1.15822384e-01,  3.24923880e-01,  7.97334996e-01,  9.40701515e-01,
        2.88416850e-01,  7.95901169e-01,  1.22656122e-01, -4.21385333e-03,
        1.14743015e-02,  6.45990668e-01,  7.36678591e-02,  1.35213790e-01,
        5.09467613e-04,  6.45698247e-01,  6.96232669e-01,  6.88932051e-01,
        1.43649242e-01,  6.12293218e-01,  1.55604870e-01,  1.26357646e-01,
        3.96407003e-02,  2.82099939e-01,  8.72787226e-02,  7.70344371e-01,
       -1.59257672e-02,  8.19397008e-01,  1.02310208e+00,  1.22685443e-01,
        9.98275457e-01,  