<a href="https://colab.research.google.com/github/MpRonald/Machine-Learning/blob/main/PySpark_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Spark Session

In [2]:
!pip install pyspark -q

[K     |████████████████████████████████| 281.3 MB 59 kB/s 
[K     |████████████████████████████████| 199 kB 73.8 MB/s 
[?25h  Building wheel for pyspark (setup.py) ... [?25l[?25hdone


In [99]:
from pyspark.sql import SparkSession, functions as f
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [4]:
spark = SparkSession.builder.master('local[*]').appName('Spark Classifier').getOrCreate()
spark

# Loading and Treatment Dataset 

In [5]:
data = spark.read.csv('/content/drive/MyDrive/Datasets/dados_clientes.csv', sep=',',header=True, inferSchema=True)
data

DataFrame[id: int, Churn: string, Mais65anos: int, Conjuge: string, Dependentes: string, MesesDeContrato: int, TelefoneFixo: string, MaisDeUmaLinhaTelefonica: string, Internet: string, SegurancaOnline: string, BackupOnline: string, SeguroDispositivo: string, SuporteTecnico: string, TVaCabo: string, StreamingFilmes: string, TipoContrato: string, ContaCorreio: string, MetodoPagamento: string, MesesCobrados: double]

In [6]:
data.show()

+---+-----+----------+-------+-----------+---------------+------------+------------------------+-----------+------------------+------------------+------------------+------------------+------------------+------------------+------------+------------+----------------+-------------+
| id|Churn|Mais65anos|Conjuge|Dependentes|MesesDeContrato|TelefoneFixo|MaisDeUmaLinhaTelefonica|   Internet|   SegurancaOnline|      BackupOnline| SeguroDispositivo|    SuporteTecnico|           TVaCabo|   StreamingFilmes|TipoContrato|ContaCorreio| MetodoPagamento|MesesCobrados|
+---+-----+----------+-------+-----------+---------------+------------+------------------------+-----------+------------------+------------------+------------------+------------------+------------------+------------------+------------+------------+----------------+-------------+
|  0|  Nao|         0|    Sim|        Nao|              1|         Nao|    SemServicoTelefonico|        DSL|               Nao|               Sim|              

In [7]:
data.count()

10348

In [8]:
data.groupBy('Churn').count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|  Sim| 5174|
|  Nao| 5174|
+-----+-----+



In [9]:
data.printSchema()

root
 |-- id: integer (nullable = true)
 |-- Churn: string (nullable = true)
 |-- Mais65anos: integer (nullable = true)
 |-- Conjuge: string (nullable = true)
 |-- Dependentes: string (nullable = true)
 |-- MesesDeContrato: integer (nullable = true)
 |-- TelefoneFixo: string (nullable = true)
 |-- MaisDeUmaLinhaTelefonica: string (nullable = true)
 |-- Internet: string (nullable = true)
 |-- SegurancaOnline: string (nullable = true)
 |-- BackupOnline: string (nullable = true)
 |-- SeguroDispositivo: string (nullable = true)
 |-- SuporteTecnico: string (nullable = true)
 |-- TVaCabo: string (nullable = true)
 |-- StreamingFilmes: string (nullable = true)
 |-- TipoContrato: string (nullable = true)
 |-- ContaCorreio: string (nullable = true)
 |-- MetodoPagamento: string (nullable = true)
 |-- MesesCobrados: double (nullable = true)



###Data Treatement

In [10]:
data.columns

['id',
 'Churn',
 'Mais65anos',
 'Conjuge',
 'Dependentes',
 'MesesDeContrato',
 'TelefoneFixo',
 'MaisDeUmaLinhaTelefonica',
 'Internet',
 'SegurancaOnline',
 'BackupOnline',
 'SeguroDispositivo',
 'SuporteTecnico',
 'TVaCabo',
 'StreamingFilmes',
 'TipoContrato',
 'ContaCorreio',
 'MetodoPagamento',
 'MesesCobrados']

In [11]:
binary_columns = [
                'Churn',
                'Conjuge',
                'Dependentes',                
                'TelefoneFixo',
                'MaisDeUmaLinhaTelefonica',                
                'SegurancaOnline',
                'BackupOnline',
                'SeguroDispositivo',
                'SuporteTecnico',
                'TVaCabo',
                'StreamingFilmes',                
                'ContaCorreio',
]

In [12]:
all_columns = [f.when(f.col(x)=='Sim', 1).otherwise(0).alias(x) for x in binary_columns]

In [13]:
all_columns

[Column<'CASE WHEN (Churn = Sim) THEN 1 ELSE 0 END AS Churn'>,
 Column<'CASE WHEN (Conjuge = Sim) THEN 1 ELSE 0 END AS Conjuge'>,
 Column<'CASE WHEN (Dependentes = Sim) THEN 1 ELSE 0 END AS Dependentes'>,
 Column<'CASE WHEN (TelefoneFixo = Sim) THEN 1 ELSE 0 END AS TelefoneFixo'>,
 Column<'CASE WHEN (MaisDeUmaLinhaTelefonica = Sim) THEN 1 ELSE 0 END AS MaisDeUmaLinhaTelefonica'>,
 Column<'CASE WHEN (SegurancaOnline = Sim) THEN 1 ELSE 0 END AS SegurancaOnline'>,
 Column<'CASE WHEN (BackupOnline = Sim) THEN 1 ELSE 0 END AS BackupOnline'>,
 Column<'CASE WHEN (SeguroDispositivo = Sim) THEN 1 ELSE 0 END AS SeguroDispositivo'>,
 Column<'CASE WHEN (SuporteTecnico = Sim) THEN 1 ELSE 0 END AS SuporteTecnico'>,
 Column<'CASE WHEN (TVaCabo = Sim) THEN 1 ELSE 0 END AS TVaCabo'>,
 Column<'CASE WHEN (StreamingFilmes = Sim) THEN 1 ELSE 0 END AS StreamingFilmes'>,
 Column<'CASE WHEN (ContaCorreio = Sim) THEN 1 ELSE 0 END AS ContaCorreio'>]

In [14]:
for column in reversed(data.columns):
    if column not in binary_columns:
        all_columns.insert(0, column)
all_columns

['id',
 'Mais65anos',
 'MesesDeContrato',
 'Internet',
 'TipoContrato',
 'MetodoPagamento',
 'MesesCobrados',
 Column<'CASE WHEN (Churn = Sim) THEN 1 ELSE 0 END AS Churn'>,
 Column<'CASE WHEN (Conjuge = Sim) THEN 1 ELSE 0 END AS Conjuge'>,
 Column<'CASE WHEN (Dependentes = Sim) THEN 1 ELSE 0 END AS Dependentes'>,
 Column<'CASE WHEN (TelefoneFixo = Sim) THEN 1 ELSE 0 END AS TelefoneFixo'>,
 Column<'CASE WHEN (MaisDeUmaLinhaTelefonica = Sim) THEN 1 ELSE 0 END AS MaisDeUmaLinhaTelefonica'>,
 Column<'CASE WHEN (SegurancaOnline = Sim) THEN 1 ELSE 0 END AS SegurancaOnline'>,
 Column<'CASE WHEN (BackupOnline = Sim) THEN 1 ELSE 0 END AS BackupOnline'>,
 Column<'CASE WHEN (SeguroDispositivo = Sim) THEN 1 ELSE 0 END AS SeguroDispositivo'>,
 Column<'CASE WHEN (SuporteTecnico = Sim) THEN 1 ELSE 0 END AS SuporteTecnico'>,
 Column<'CASE WHEN (TVaCabo = Sim) THEN 1 ELSE 0 END AS TVaCabo'>,
 Column<'CASE WHEN (StreamingFilmes = Sim) THEN 1 ELSE 0 END AS StreamingFilmes'>,
 Column<'CASE WHEN (ContaCorr

In [15]:
dataset = data.select(all_columns)
dataset.show()

+---+----------+---------------+-----------+------------+----------------+-------------+-----+-------+-----------+------------+------------------------+---------------+------------+-----------------+--------------+-------+---------------+------------+
| id|Mais65anos|MesesDeContrato|   Internet|TipoContrato| MetodoPagamento|MesesCobrados|Churn|Conjuge|Dependentes|TelefoneFixo|MaisDeUmaLinhaTelefonica|SegurancaOnline|BackupOnline|SeguroDispositivo|SuporteTecnico|TVaCabo|StreamingFilmes|ContaCorreio|
+---+----------+---------------+-----------+------------+----------------+-------------+-----+-------+-----------+------------+------------------------+---------------+------------+-----------------+--------------+-------+---------------+------------+
|  0|         0|              1|        DSL| Mensalmente|BoletoEletronico|        29.85|    0|      1|          0|           0|                       0|              0|           1|                0|             0|      0|              0|      

In [16]:
dataset.printSchema()

root
 |-- id: integer (nullable = true)
 |-- Mais65anos: integer (nullable = true)
 |-- MesesDeContrato: integer (nullable = true)
 |-- Internet: string (nullable = true)
 |-- TipoContrato: string (nullable = true)
 |-- MetodoPagamento: string (nullable = true)
 |-- MesesCobrados: double (nullable = true)
 |-- Churn: integer (nullable = false)
 |-- Conjuge: integer (nullable = false)
 |-- Dependentes: integer (nullable = false)
 |-- TelefoneFixo: integer (nullable = false)
 |-- MaisDeUmaLinhaTelefonica: integer (nullable = false)
 |-- SegurancaOnline: integer (nullable = false)
 |-- BackupOnline: integer (nullable = false)
 |-- SeguroDispositivo: integer (nullable = false)
 |-- SuporteTecnico: integer (nullable = false)
 |-- TVaCabo: integer (nullable = false)
 |-- StreamingFilmes: integer (nullable = false)
 |-- ContaCorreio: integer (nullable = false)



### Dummies Variables

In [17]:
data.select(['internet', 'TipoContrato', 'MetodoPagamento']).show()

+-----------+------------+----------------+
|   internet|TipoContrato| MetodoPagamento|
+-----------+------------+----------------+
|        DSL| Mensalmente|BoletoEletronico|
|        DSL|       UmAno|          Boleto|
|        DSL| Mensalmente|          Boleto|
|        DSL|       UmAno|   DebitoEmConta|
|FibraOptica| Mensalmente|BoletoEletronico|
|FibraOptica| Mensalmente|BoletoEletronico|
|FibraOptica| Mensalmente|   CartaoCredito|
|        DSL| Mensalmente|          Boleto|
|FibraOptica| Mensalmente|BoletoEletronico|
|        DSL|       UmAno|   DebitoEmConta|
|        DSL| Mensalmente|          Boleto|
|        Nao|    DoisAnos|   CartaoCredito|
|FibraOptica|       UmAno|   CartaoCredito|
|FibraOptica| Mensalmente|   DebitoEmConta|
|FibraOptica| Mensalmente|BoletoEletronico|
|FibraOptica|    DoisAnos|   CartaoCredito|
|        Nao|       UmAno|          Boleto|
|FibraOptica|    DoisAnos|   DebitoEmConta|
|        DSL| Mensalmente|   CartaoCredito|
|FibraOptica| Mensalmente|Boleto

In [18]:
dataset.groupBy('id').pivot('Internet').agg(f.lit(1)).na.fill(0).show()

+----+---+-----------+---+
|  id|DSL|FibraOptica|Nao|
+----+---+-----------+---+
|7982|  1|          0|  0|
|9465|  0|          1|  0|
|2122|  1|          0|  0|
|3997|  1|          0|  0|
|6654|  0|          1|  0|
|7880|  0|          1|  0|
|4519|  0|          1|  0|
|6466|  0|          1|  0|
| 496|  1|          0|  0|
|7833|  0|          1|  0|
|1591|  0|          0|  1|
|2866|  0|          1|  0|
|8592|  0|          1|  0|
|1829|  0|          1|  0|
| 463|  0|          1|  0|
|4900|  0|          1|  0|
|4818|  0|          1|  0|
|7554|  1|          0|  0|
|1342|  0|          0|  1|
|5300|  0|          1|  0|
+----+---+-----------+---+
only showing top 20 rows



In [19]:
internet = dataset.groupBy('id').pivot('Internet').agg(f.lit(1)).na.fill(0)
tipo_contrato = dataset.groupBy('id').pivot('TipoContrato').agg(f.lit(1)).na.fill(0)
metodo_pagamento = dataset.groupBy('id').pivot('MetodoPagamento').agg(f.lit(1)).na.fill(0)

In [20]:
dataset.columns

['id',
 'Mais65anos',
 'MesesDeContrato',
 'Internet',
 'TipoContrato',
 'MetodoPagamento',
 'MesesCobrados',
 'Churn',
 'Conjuge',
 'Dependentes',
 'TelefoneFixo',
 'MaisDeUmaLinhaTelefonica',
 'SegurancaOnline',
 'BackupOnline',
 'SeguroDispositivo',
 'SuporteTecnico',
 'TVaCabo',
 'StreamingFilmes',
 'ContaCorreio']

In [21]:
dataset = dataset\
    .join(internet, 'id', how='inner')\
    .join(tipo_contrato, 'id', how='inner')\
    .join(metodo_pagamento, 'id', how='inner')\
    .select(
        '*',
        f.col('DSL').alias('internet_dsl'),
        f.col('FibraOptica').alias('internet_fibra'),
        f.col('Nao').alias('internet_nao'),
        f.col('Mensalmente').alias('tipo_contrat_mensal'),
        f.col('UmAno').alias('contrato_um_ano'),
        f.col('DoisAnos').alias('contrato_dois_anos'),
        f.col('DebitoEmConta').alias('pagamento_deb_conta'),
        f.col('CartaoCredito').alias('pagamento_cartao_credito'),
        f.col('BoletoEletronico').alias('pagamento_boleto_eletr'),
        f.col('Boleto').alias('pagamento_boleto')
    )\
    .drop('Internet', 'TipoContrato', 'MetodoPagamento', 'DSL',
          'FibraOptica', 'Nao', 'Mensalmente', 'UmAno', 'DoisAnos',
          'DebitoEmConta', 'CartaoCredito', 'BoletoEletronico', 'Boleto'

    )

In [22]:
dataset.show()

+----+----------+---------------+-----------------+-----+-------+-----------+------------+------------------------+---------------+------------+-----------------+--------------+-------+---------------+------------+------------+--------------+------------+-------------------+---------------+------------------+-------------------+------------------------+----------------------+----------------+
|  id|Mais65anos|MesesDeContrato|    MesesCobrados|Churn|Conjuge|Dependentes|TelefoneFixo|MaisDeUmaLinhaTelefonica|SegurancaOnline|BackupOnline|SeguroDispositivo|SuporteTecnico|TVaCabo|StreamingFilmes|ContaCorreio|internet_dsl|internet_fibra|internet_nao|tipo_contrat_mensal|contrato_um_ano|contrato_dois_anos|pagamento_deb_conta|pagamento_cartao_credito|pagamento_boleto_eletr|pagamento_boleto|
+----+----------+---------------+-----------------+-----+-------+-----------+------------+------------------------+---------------+------------+-----------------+--------------+-------+---------------+-------

In [23]:
dataset.printSchema()

root
 |-- id: integer (nullable = true)
 |-- Mais65anos: integer (nullable = true)
 |-- MesesDeContrato: integer (nullable = true)
 |-- MesesCobrados: double (nullable = true)
 |-- Churn: integer (nullable = false)
 |-- Conjuge: integer (nullable = false)
 |-- Dependentes: integer (nullable = false)
 |-- TelefoneFixo: integer (nullable = false)
 |-- MaisDeUmaLinhaTelefonica: integer (nullable = false)
 |-- SegurancaOnline: integer (nullable = false)
 |-- BackupOnline: integer (nullable = false)
 |-- SeguroDispositivo: integer (nullable = false)
 |-- SuporteTecnico: integer (nullable = false)
 |-- TVaCabo: integer (nullable = false)
 |-- StreamingFilmes: integer (nullable = false)
 |-- ContaCorreio: integer (nullable = false)
 |-- internet_dsl: integer (nullable = true)
 |-- internet_fibra: integer (nullable = true)
 |-- internet_nao: integer (nullable = true)
 |-- tipo_contrat_mensal: integer (nullable = true)
 |-- contrato_um_ano: integer (nullable = true)
 |-- contrato_dois_anos: int

In [24]:
dataset = dataset.withColumnRenamed('Churn', 'label')

In [25]:
X = dataset.columns
X.remove('label')
X.remove('id')
X

['Mais65anos',
 'MesesDeContrato',
 'MesesCobrados',
 'Conjuge',
 'Dependentes',
 'TelefoneFixo',
 'MaisDeUmaLinhaTelefonica',
 'SegurancaOnline',
 'BackupOnline',
 'SeguroDispositivo',
 'SuporteTecnico',
 'TVaCabo',
 'StreamingFilmes',
 'ContaCorreio',
 'internet_dsl',
 'internet_fibra',
 'internet_nao',
 'tipo_contrat_mensal',
 'contrato_um_ano',
 'contrato_dois_anos',
 'pagamento_deb_conta',
 'pagamento_cartao_credito',
 'pagamento_boleto_eletr',
 'pagamento_boleto']

In [26]:
dataset_prep = VectorAssembler(inputCols=X, outputCol='features').transform(dataset).select('features', 'label')

In [27]:
dataset_prep.show(10, truncate=False)

+-----------------------------------------------------------------------------------------------------------+-----+
|features                                                                                                   |label|
+-----------------------------------------------------------------------------------------------------------+-----+
|(24,[1,2,11,12,13,14,17,22],[1.0,45.30540797610398,1.0,1.0,1.0,1.0,1.0,1.0])                               |1    |
|(24,[1,2,3,5,6,8,9,11,12,13,15,17,22],[60.0,103.6142230120257,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|1    |
|(24,[1,2,5,6,10,11,12,13,14,18,23],[12.0,75.85,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])                       |0    |
|(24,[1,2,3,5,8,12,13,14,19,21],[69.0,61.45,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])                               |0    |
|(24,[1,2,3,5,6,11,13,15,17,22],[7.0,86.5,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])                                 |1    |
|(24,[1,2,5,6,12,13,15,17,22],[14.0,85.03742670311915,1.0,1.0,1.0,1.0,1.

### Adjust Sample

In [28]:
seed=42
train, test = dataset_prep.randomSplit([0.8, 0.2], seed=seed)

In [29]:
train.show()

+--------------------+-----+
|            features|label|
+--------------------+-----+
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
+--------------------+-----+
only showing top 20 rows



In [30]:
test.show()

+--------------------+-----+
|            features|label|
+--------------------+-----+
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    1|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,5,...|    0|
|(24,[0,1,2,3,4,8,...|    0|
|(24,[0,1,2,3,4,8,...|    0|
|(24,[0,1,2,3,4,12...|    0|
|(24,[0,1,2,3,5,6,...|    0|
|(24,[0,1,2,3,5,6,...|    0|
|(24,[0,1,2,3,5,6,...|    0|
|(24,[0,1,2,3,5,6,...|    1|
|(24,[0,1,2,3,5,6,...|    1|
|(24,[0,1,2,3,5,6,...|    0|
+--------------------+-----+
only showing top 20 rows



In [31]:
train.count(), test.count()

(8356, 1992)

# Logistic Regression

In [32]:
lr = LogisticRegression().fit(train)

In [33]:
pred_lr = lr.transform(test)

In [34]:
pred_lr.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|(24,[0,1,2,3,4,5,...|    0|[0.79030537977963...|[0.68789689778178...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[0.97169805769358...|[0.72545782729374...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[0.84487634423129...|[0.69949123545955...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[1.21615504810257...|[0.77138620159290...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[0.84459470045534...|[0.69943202969724...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[0.12555956766631...|[0.53134871782033...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[4.26288405643322...|[0.98611390727200...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[1.52765897351149...|[0.82166353586314...|       0.0|
|(24,[0,1,2,3,4,5,...|    1|[1.42098663640013...|[0.80549304327390...|       0.0|
|(24,[0,1,2,3,4,

### Metrics

In [35]:
resume_lr_train = lr.summary

In [73]:
def metrics(x):
    print('Accuracy: %.2f' % resume_lr_train.accuracy)
    print('Precision: %.2f' % resume_lr_train.precisionByLabel[1])
    print('Recall: %.2f' % resume_lr_train.recallByLabel[1])
    print('F1 Score: %.2f' % resume_lr_train.fMeasureByLabel()[1])

In [74]:
metrics(resume_lr_train)

Accuracy: 0.78
Precision: 0.77
Recall: 0.81
F1 Score: 0.79


### Confusion Matrix

In [37]:
tp = pred_lr.select('label', 'prediction').where((f.col('label')==1) & (f.col('prediction')==1)).count()
tn = pred_lr.select('label', 'prediction').where((f.col('label')==0) & (f.col('prediction')==0)).count()
fp = pred_lr.select('label', 'prediction').where((f.col('label')==0) & (f.col('prediction')==1)).count()
fn = pred_lr.select('label', 'prediction').where((f.col('label')==1) & (f.col('prediction')==0)).count()

In [38]:
tp, tn, fp, fn

(776, 765, 243, 208)

In [67]:
def conf_matrix(df_transform_modelo, normalize=False, percentage=True):
  tp = df_transform_modelo.select('label', 'prediction').where((f.col('label') == 1) & (f.col('prediction') == 1)).count()
  tn = df_transform_modelo.select('label', 'prediction').where((f.col('label') == 0) & (f.col('prediction') == 0)).count()
  fp = df_transform_modelo.select('label', 'prediction').where((f.col('label') == 0) & (f.col('prediction') == 1)).count()
  fn = df_transform_modelo.select('label', 'prediction').where((f.col('label') == 1) & (f.col('prediction') == 0)).count()
  
  valorP = 1
  valorN = 1

  if normalize:
    valorP = tp + fn
    valorN = fp + tn
  
  if percentage and normalize:
    valorP = valorP / 100
    valorN = valorN / 100

  print(' '*20, 'Previsto')
  print(' '*15, 'Churn', ' '*5 ,'Não-Churn')
  print(' '*4, 'Churn', ' '*6, int(tp/valorP), ' '*7, int(fn/valorP))
  print('Real')
  print(' '*4, 'Não-Churn', ' '*2, int(fp/valorN), ' '*7, int(tn/valorN), '\n')
  print(f'True Positive:{round((tp / (tp+fp))*100,2)}%')
  print(f'False Positive:{round((fp / (tp+fp))*100,2)}%')
  print(f'True Negative:{round((tn / (fn+tn))*100,2)}%')
  print(f'False Negative:{round((fn / (fn+tn))*100,2)}%')

In [68]:
conf_matrix(pred_lr)

                     Previsto
                Churn       Não-Churn
     Churn        776         208
Real
     Não-Churn    243         765 

True Positive:76.15%
False Positive:23.85%
True Negative:78.62%
False Negative:21.38%


# Decision Tree Classifier

In [76]:
tree = DecisionTreeClassifier(seed=seed).fit(train)
pred_tree = tree.transform(test)
pred_tree.show()

+--------------------+-----+--------------+--------------------+----------+
|            features|label| rawPrediction|         probability|prediction|
+--------------------+-----+--------------+--------------------+----------+
|(24,[0,1,2,3,4,5,...|    0|    [24.0,4.0]|[0.85714285714285...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|    [24.0,4.0]|[0.85714285714285...|       0.0|
|(24,[0,1,2,3,4,5,...|    0| [269.0,239.0]|[0.52952755905511...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[2376.0,388.0]|[0.85962373371924...|       0.0|
|(24,[0,1,2,3,4,5,...|    0| [269.0,239.0]|[0.52952755905511...|       0.0|
|(24,[0,1,2,3,4,5,...|    0| [269.0,239.0]|[0.52952755905511...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[2376.0,388.0]|[0.85962373371924...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[2376.0,388.0]|[0.85962373371924...|       0.0|
|(24,[0,1,2,3,4,5,...|    1| [269.0,239.0]|[0.52952755905511...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[2376.0,388.0]|[0.85962373371924...|       0.0|
|(24,[0,1,2,

In [77]:
conf_matrix(pred_tree)

                     Previsto
                Churn       Não-Churn
     Churn        736         248
Real
     Não-Churn    174         834 

True Positive:80.88%
False Positive:19.12%
True Negative:77.08%
False Negative:22.92%


In [80]:
evaluator = MulticlassClassificationEvaluator()
evaluator.evaluate(pred_tree, {evaluator.metricName: 'accuracy'})

0.7881526104417671

In [86]:
def tree_metrics(x):
    print("Acurácia: %f" % evaluator.evaluate(x, {evaluator.metricName: "accuracy"}))
    print("Precisão: %f" % evaluator.evaluate(x, {evaluator.metricName: "precisionByLabel", evaluator.metricLabel: 1}))
    print("Recall: %f" % evaluator.evaluate(x, {evaluator.metricName: "recallByLabel", evaluator.metricLabel: 1}))
    print("F1: %f" % evaluator.evaluate(x, {evaluator.metricName: "fMeasureByLabel", evaluator.metricLabel: 1}))

In [87]:
tree_metrics(pred_tree)

Acurácia: 0.788153
Precisão: 0.808791
Recall: 0.747967
F1: 0.777191


# Random Forest

In [102]:
random = RandomForestClassifier(maxDepth=10, seed=seed).fit(train)
pred_random = random.transform(test)
pred_random.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|(24,[0,1,2,3,4,5,...|    0|[14.1789095320942...|[0.70894547660471...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[12.3622643498605...|[0.61811321749302...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[11.6920597849830...|[0.58460298924915...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[13.5920578280655...|[0.67960289140327...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[10.7988361406302...|[0.53994180703151...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[10.5064451544935...|[0.52532225772467...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[16.7497603025320...|[0.83748801512660...|       0.0|
|(24,[0,1,2,3,4,5,...|    0|[18.1402617355559...|[0.90701308677779...|       0.0|
|(24,[0,1,2,3,4,5,...|    1|[11.4763164019559...|[0.57381582009779...|       0.0|
|(24,[0,1,2,3,4,

In [103]:
conf_matrix(pred_random)

                     Previsto
                Churn       Não-Churn
     Churn        830         154
Real
     Não-Churn    202         806 

True Positive:80.43%
False Positive:19.57%
True Negative:83.96%
False Negative:16.04%


In [104]:
evaluator = MulticlassClassificationEvaluator()
evaluator.evaluate(pred_tree, {evaluator.metricName: 'accuracy'})

0.7881526104417671

In [105]:
def tree_metrics(x):
    print("Acurácia: %f" % evaluator.evaluate(x, {evaluator.metricName: "accuracy"}))
    print("Precisão: %f" % evaluator.evaluate(x, {evaluator.metricName: "precisionByLabel", evaluator.metricLabel: 1}))
    print("Recall: %f" % evaluator.evaluate(x, {evaluator.metricName: "recallByLabel", evaluator.metricLabel: 1}))
    print("F1: %f" % evaluator.evaluate(x, {evaluator.metricName: "fMeasureByLabel", evaluator.metricLabel: 1}))

In [106]:
tree_metrics(pred_tree)

Acurácia: 0.788153
Precisão: 0.808791
Recall: 0.747967
F1: 0.777191


In [107]:
tree_metrics(pred_random)

Acurácia: 0.821285
Precisão: 0.804264
Recall: 0.843496
F1: 0.823413


# Cross Validation

In [114]:
random_forest = RandomForestClassifier(seed=seed)
grid_random = ParamGridBuilder()\
        .addGrid(random.maxDepth, [5, 10, 15, 20])\
        .addGrid(random.maxBins, [10, 20, 30, 40])\
        .addGrid(random.numTrees, [100, 200, 500])\
        .build()

In [115]:
random_cv = CrossValidator(
    estimator=random_forest,
    estimatorParamMaps=grid_random,
    evaluator=evaluator,
    numFolds=3,
    seed=seed)

In [116]:
random_cv = random_cv.fit(train)
previsoes_rfc_cv_teste = random_cv.transform(test)

In [117]:
print('Random Forest Classifier - Tuning')
print("="*40)
print("Dados de Teste")
print("="*40)
print("Matriz de Confusão")
print("-"*40)
conf_matrix(previsoes_rfc_cv_teste, normalize=False)
print("-"*40)
print("Métricas")
print("-"*40)
print("Acurácia: %f" % evaluator.evaluate(previsoes_rfc_cv_teste, {evaluator.metricName: "accuracy"}))
print("Precisão: %f" % evaluator.evaluate(previsoes_rfc_cv_teste, {evaluator.metricName: "precisionByLabel", evaluator.metricLabel: 1}))
print("Recall: %f" % evaluator.evaluate(previsoes_rfc_cv_teste, {evaluator.metricName: "recallByLabel", evaluator.metricLabel: 1}))
print("F1: %f" % evaluator.evaluate(previsoes_rfc_cv_teste, {evaluator.metricName: "fMeasureByLabel", evaluator.metricLabel: 1}))

Random Forest Classifier - Tuning
Dados de Teste
Matriz de Confusão
----------------------------------------
                     Previsto
                Churn       Não-Churn
     Churn        818         166
Real
     Não-Churn    267         741 

True Positive:75.39%
False Positive:24.61%
True Negative:81.7%
False Negative:18.3%
----------------------------------------
Métricas
----------------------------------------
Acurácia: 0.782631
Precisão: 0.753917
Recall: 0.831301
F1: 0.790720


In [118]:
melhor_modelo_rfc_cv = random_cv.bestModel
print(melhor_modelo_rfc_cv.getMaxDepth())
print(melhor_modelo_rfc_cv.getMaxBins())
print(melhor_modelo_rfc_cv.getNumTrees)

5
32
20


In [119]:
random_tunning = RandomForestClassifier(maxDepth=5, maxBins=32, numTrees=20, seed=seed).fit(dataset_prep)

In [142]:
new_customer = [{
 'id':123456,
 'Mais65anos':0,
 'MesesDeContrato':24,
 'MesesCobrados':0,
 'label':0,
 'Conjuge':1,
 'Dependentes':0,
 'TelefoneFixo':0,
 'MaisDeUmaLinhaTelefonica':0,
 'SegurancaOnline':1,
 'BackupOnline':0,
 'SeguroDispositivo':0,
 'SuporteTecnico':1,
 'TVaCabo':0,
 'StreamingFilmes':1,
 'ContaCorreio':0,
 'internet_dsl':1,
 'internet_fibra':0,
 'internet_nao':1,
 'tipo_contrat_mensal':0,
 'contrato_um_ano':1,
 'contrato_dois_anos':0,
 'pagamento_deb_conta':1,
 'pagamento_cartao_credito':0,
 'pagamento_boleto_eletr':0,
 'pagamento_boleto':0
}]

In [143]:
new_customer = spark.createDataFrame(new_customer)
new_customer.show()

+------------+-------+------------+-----------+----------+------------------------+-------------+---------------+---------------+-----------------+---------------+--------------+-------+------------+------------------+---------------+------+------------+--------------+------------+-----+----------------+----------------------+------------------------+-------------------+-------------------+
|BackupOnline|Conjuge|ContaCorreio|Dependentes|Mais65anos|MaisDeUmaLinhaTelefonica|MesesCobrados|MesesDeContrato|SegurancaOnline|SeguroDispositivo|StreamingFilmes|SuporteTecnico|TVaCabo|TelefoneFixo|contrato_dois_anos|contrato_um_ano|    id|internet_dsl|internet_fibra|internet_nao|label|pagamento_boleto|pagamento_boleto_eletr|pagamento_cartao_credito|pagamento_deb_conta|tipo_contrat_mensal|
+------------+-------+------------+-----------+----------+------------------------+-------------+---------------+---------------+-----------------+---------------+--------------+-------+------------+-------------

In [147]:
assembler = VectorAssembler(inputCols = X, outputCol = 'features')

In [148]:
new_customer_prep = assembler.transform(new_customer).select('features')
new_customer_prep.show(truncate=False)

+---------------------------------------------------------------------+
|features                                                             |
+---------------------------------------------------------------------+
|(24,[1,3,7,10,12,14,16,18,20],[24.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|
+---------------------------------------------------------------------+



In [157]:
pred = random_tunning.transform(new_customer_prep).show()

+--------------------+--------------------+--------------------+----------+
|            features|       rawPrediction|         probability|prediction|
+--------------------+--------------------+--------------------+----------+
|(24,[1,3,7,10,12,...|[16.0214514024386...|[0.80107257012193...|       0.0|
+--------------------+--------------------+--------------------+----------+

