In [1]:
from pyspark.sql import SparkSession

# getOrCreate() returns the already-active session if there is one, otherwise builds a new one.
spark = (
    SparkSession.builder
    .appName('logistic_consulting')
    .getOrCreate()
)

In [2]:
# Loading the data: every column of the labelled churn table.
churn_query = 'SELECT * FROM customer_churn_csv'
df = spark.sql(churn_query)

In [3]:
# Exploring the data: preview the first rows (show()'s defaults, spelled out).
df.show(n=20, truncate=True)

In [4]:
# Print the schema tree: column names, data types and nullability.
df.printSchema()

In [5]:
# Summary statistics (count, mean, stddev, min, max) per column.
df.describe().show(n=20, truncate=True)

In [6]:
# Rows per company, most frequent first. Many distinct companies means a wide
# one-hot CompanyVec, and companies that appear only once are likely to be unseen
# by the fitted indexer on other splits.
df_count = df.groupby('Company').count()
df_count.sort('count', ascending=False).show()

In [7]:
# Column names available as model inputs.
list(df.columns)

In [8]:
# One hot encoding of the categorical `Company` column, in two steps:
#   StringIndexer: string label -> 0-based numeric index; by default the most
#     frequent label gets 0.0 (e.g. 'Low', 'Medium', 'High' -> 0.0, 1.0, 2.0
#     if they are in that frequency order).
#   OneHotEncoder: index -> sparse indicator vector; with dropLast=True (the
#     default) the last category is encoded as the all-zeros vector.
from pyspark.ml.feature import StringIndexer, OneHotEncoder
# Company is high-cardinality, so test_data / new customers will contain companies
# the indexer never saw during fit. The default handleInvalid='error' would make
# transform fail on them; 'keep' maps them to an extra "unknown" index instead.
company_indexer = StringIndexer(inputCol='Company', outputCol='CompanyIndex', handleInvalid='keep')
company_encoder = OneHotEncoder(inputCol='CompanyIndex', outputCol='CompanyVec')

In [9]:
# Transforming dataframe into PySpark format: Spark ML estimators read a single
# vector column, so the numeric columns and the one-hot CompanyVec are packed
# into 'features'.
from pyspark.ml.feature import VectorAssembler, VectorIndexer
feature_columns = ['Age', 'Total_Purchase', 'Years', 'Num_Sites', 'CompanyVec']
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

In [10]:
# Creating the model: logistic regression on the assembled 'features' vector,
# learning the 'Churn' label column (presumably 0/1 — confirm in the schema above).
from pyspark.ml.classification import LogisticRegression
classifier = LogisticRegression(
    featuresCol='features',
    labelCol='Churn',
)

In [11]:
# Creating the pipeline: index Company -> one-hot encode -> assemble -> classify.
from pyspark.ml.pipeline import Pipeline
pipeline_stages = [company_indexer, company_encoder, assembler, classifier]
pipeline = Pipeline(stages=pipeline_stages)

In [12]:
# Train test split (70/30). A fixed seed makes the split — and therefore the
# fitted model and its AUC — reproducible on Restart & Run All (for the same
# input data and partitioning).
SPLIT_SEED = 42
train_data, test_data = df.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [13]:
# Fitting the model: fits every pipeline stage in order on the training split
# and returns a PipelineModel holding the fitted stages.
fitted_classifier = pipeline.fit(dataset=train_data)

In [14]:
# Evaluation: run the held-out split through the fitted pipeline to get
# prediction columns alongside the original ones.
results = fitted_classifier.transform(dataset=test_data)

In [15]:
# Inspect the columns added by the fitted pipeline stages.
results.printSchema()

In [16]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# ROC AUC must be computed from a continuous score, not the hard 0/1 'prediction'
# column. With hard labels the ROC curve has a single operating point, and the
# "AUC" collapses to (TPR + TNR) / 2. 'rawPrediction' is the classifier's score
# column (also the evaluator's default input).
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol='rawPrediction',
    labelCol='Churn',
    metricName='areaUnderROC',
)
area_under_curve = evaluator.evaluate(results)

In [17]:
# Area under the ROC curve on the held-out split (0.5 = chance, 1.0 = perfect ranking).
area_under_curve

In [18]:
# Predicting on unlabeled test set: load the new customers to score
# (presumably without a 'Churn' column).
new_customers_query = 'SELECT * FROM new_customers_csv'
unlabeled_df = spark.sql(new_customers_query)
unlabeled_df.show(n=20, truncate=True)

In [19]:
# Pass only the raw columns the pipeline consumes; the fitted indexer, encoder
# and assembler rebuild the 'features' vector from them.
model_input_columns = ['Age', 'Total_Purchase', 'Years', 'Num_Sites', 'Company']
predictions = fitted_classifier.transform(unlabeled_df.select(*model_input_columns))

In [20]:
# Confirm the prediction columns were added to the new-customer rows.
predictions.printSchema()

In [21]:
# LogisticRegression writes its label output to 'prediction' (its default
# predictionCol). Use the exact name rather than relying on Spark's
# case-insensitive resolution, which breaks if spark.sql.caseSensitive is enabled.
# Company is shown alongside so each prediction can be traced back to a customer.
predictions.select('Company', 'prediction').show()