In [3]:
%%bigquery
-- GLM coefficients (by feature / category)
SELECT *
FROM ML.WEIGHTS(MODEL `telco_churn_ds.churn_logreg_model`)
ORDER BY ABS(weight) DESC;


Query is running:   0%|          |

Downloading:   0%|          |

Unnamed: 0,processed_input,weight,category_weights
0,Contract_mtm,0.3129,[]
1,Contract_2yr,-0.285286,[]
2,Pay_ElectronicCheck,0.242422,[]
3,Internet_Fiber,0.240957,[]
4,Internet_None,-0.21431,[]
5,Contract_1yr,-0.154033,[]
6,Paperless_1,0.13652,[]
7,Pay_CreditCard,-0.125848,[]
8,Pay_BankTransfer,-0.120168,[]
9,Internet_DSL,-0.106851,[]


In [4]:
%%bigquery
CREATE OR REPLACE MODEL
  `telco_churn_ds.churn_logreg_model`
OPTIONS (
  MODEL_TYPE = 'LOGISTIC_REG',
  INPUT_LABEL_COLS = ['label'],      -- your actual target col
  AUTO_CLASS_WEIGHTS = TRUE,         -- handles imbalance
  MAX_ITERATIONS = 30,
  LEARN_RATE_STRATEGY = 'CONSTANT',  -- required for LEARN_RATE
  LEARN_RATE = 0.1,
  ENABLE_GLOBAL_EXPLAIN = TRUE       -- to use ML.GLOBAL_EXPLAIN
) AS
SELECT
  * EXCEPT(customerID, split)        -- drop id + split
FROM `telco_churn_ds.v_features_with_split`
WHERE split = 'train';

Query is running:   0%|          |

In [5]:
%%bigquery
SELECT *
FROM ML.EVALUATE(
  MODEL `telco_churn_ds.churn_logreg_model`,
  (
    SELECT * EXCEPT(customerID, split)
    FROM `telco_churn_ds.v_features_with_split`
    WHERE split = 'test'
  )
);

Query is running:   0%|          |

Downloading:   0%|          |

Unnamed: 0,precision,recall,accuracy,f1_score,log_loss,roc_auc
0,0.496269,0.841772,0.713004,0.624413,0.586346,0.825208


In [6]:
%%bigquery df_weights
  SELECT * FROM ML.WEIGHTS(MODEL `telco_churn_ds.churn_logreg_model`)
  ORDER BY ABS(weight) DESC;

Query is running:   0%|          |

Downloading:   0%|          |

In [7]:
from google.cloud import bigquery
bq = bigquery.Client(project="infinite-mantra-480821-v7")

sql_weights = """
SELECT *
FROM ML.WEIGHTS(MODEL `telco_churn_ds.churn_logreg_model`)
ORDER BY ABS(weight) DESC;
"""
weights_df = bq.query(sql_weights).to_dataframe()

sql_global = """
SELECT *
FROM ML.GLOBAL_EXPLAIN(MODEL `telco_churn_ds.churn_logreg_model`);
"""
global_df = bq.query(sql_global).to_dataframe()

weights_df.head(), global_df.head()

(       processed_input    weight category_weights
 0         Contract_mtm  0.312900               []
 1         Contract_2yr -0.285286               []
 2  Pay_ElectronicCheck  0.242422               []
 3       Internet_Fiber  0.240957               []
 4        Internet_None -0.214310               [],
                feature  attribution
 0         Contract_mtm     0.155655
 1               tenure     0.123749
 2       Internet_Fiber     0.118320
 3  Pay_ElectronicCheck     0.106589
 4         Contract_2yr     0.105763)

In [11]:
%%bigquery
-- Find best thresholds on VALIDATION by F1 (derive precision & F1)
WITH roc AS (
  SELECT *
  FROM ML.ROC_CURVE(
    MODEL `telco_churn_ds.churn_logreg_model`,
    (
      SELECT * EXCEPT(customerID, split)
      FROM `telco_churn_ds.v_features_with_split`
      WHERE split = 'val'
    )
  )
),
metrics AS (
  SELECT
    threshold,
    recall,
    SAFE_DIVIDE(true_positives, true_positives + false_positives) AS precision,
    -- F1 = 2 * (P * R) / (P + R)
    2 * SAFE_DIVIDE(
          (SAFE_DIVIDE(true_positives, true_positives + false_positives) * recall),
          (SAFE_DIVIDE(true_positives, true_positives + false_positives) + recall)
        ) AS f1_score
  FROM roc
)
SELECT *
FROM metrics
ORDER BY f1_score DESC
LIMIT 10;

Query is running:   0%|          |

Downloading:   0%|          |

Unnamed: 0,threshold,recall,precision,f1_score
0,0.568826,0.768627,0.571429,0.655518
1,0.565041,0.780392,0.563739,0.654605
2,0.561337,0.792157,0.556474,0.653722
3,0.575185,0.752941,0.576577,0.653061
4,0.556383,0.803922,0.549598,0.652866
5,0.576894,0.733333,0.57716,0.645941
6,0.550571,0.803922,0.536649,0.643642
7,0.581574,0.713725,0.579618,0.639719
8,0.586766,0.698039,0.585526,0.636852
9,0.546263,0.807843,0.52551,0.636785


In [12]:
%%bigquery
-- Replace 0.34 with your chosen threshold
SELECT *
FROM ML.CONFUSION_MATRIX(
  MODEL `telco_churn_ds.churn_logreg_model`,
  (
    SELECT * EXCEPT(customerID, split)
    FROM `telco_churn_ds.v_features_with_split`
    WHERE split = 'test'
  ),
  STRUCT(0.34 AS threshold)
);

Query is running:   0%|          |

Downloading:   0%|          |

Unnamed: 0,expected_label,_0,_1
0,0,242,557
1,1,6,310
