## Multinomial Regression Demonstration

Using the `regression-inference` package.

In [1]:
from regression_inference import MultinomialLogisticRegression, summary

In [2]:
import numpy as np
import pandas as pd

In [3]:
# Wine recognition dataset bundled with scikit-learn, loaded as a pandas DataFrame
# (`.frame` holds the feature columns together with the `target` column).
# © Copyright 2007 - 2025, scikit-learn developers (BSD License).
from sklearn.datasets import load_wine
data = load_wine(as_frame = True).frame

In [4]:
# GPU switch passed to the main fit: True trains on the GPU via CUDA
# (reported as experimental by regression-inference); False skips CUDA.
CUDA: bool = True

### Model Fitting

- Fit a multinomial logistic regression on the full wine dataset (this demo has no separate train/test split)

In [5]:
# Columns used by the model: an explicit intercept ('const'), the outcome, and five predictors.
MODEL_COLUMNS = [
    'const', 'target',
    'alcohol', 'malic_acid', 'ash',
    'alcalinity_of_ash', 'total_phenols',
]

# Add the intercept on a copy via .assign, so the raw `data` frame loaded above is never
# mutated and re-running this cell always starts from the same input.
features = data.assign(const=np.ones(len(data)))[MODEL_COLUMNS].dropna()

X = features.drop(columns=['target'])  # 'const' first, then the five predictors
y = features['target']

# cov_type=None: default (non-robust) covariance; the HC0/HC1 fits later are compared to this one.
model = MultinomialLogisticRegression().fit(X=X, y=y, cuda=CUDA, cov_type=None, alpha=0.05)

CUDA Acceleration is Experimental
Device: NVIDIA GeForce RTX 3060

  model = MultinomialLogisticRegression().fit(X=X, y=y, cuda = CUDA, cov_type=None, alpha=0.05)


In [6]:
model_report = str(model)  # plain-text regression results table
print(model_report)

Multinomial Regression Results
---------------------------------------------
Dependent:                             target
---------------------------------------------
Class:                                      1

const                             111.3940***
                                    (36.7062)
 
alcohol                            -6.7107***
                                     (2.1485)
 
malic_acid                          -2.3517**
                                     (0.9595)
 
ash                               -19.7606***
                                     (6.5256)
 
alcalinity_of_ash                   2.4221***
                                     (0.8217)
 
total_phenols                        -7.1129*
                                     (3.8594)
 
---------------------------------------------
Class:                                      2

const                                 54.9528
                                    (36.4229)
 
alcohol                          

### Model Predictions

- Predict a set of data

- Predict new values

- Predict with inference table

- Predict at sample mean

- Predict over range of specified values

In [7]:
predictor_names = model.feature_names[1:]  # skip the leading 'const' intercept column
predictor_names

Index(['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'total_phenols'], dtype='object')

In [8]:
# Each row passed to predict() lists values in model.feature_names[1:] order;
# the intercept column is not supplied (5 values for the 5 predictors).
new_observation = [[12.85, 1.6, 2.52, 17.8, 2.48]]
model.predict(X=new_observation)

array([[0.93456116, 0.05016003, 0.0152788 ]])

In [9]:
# Same observation, but request the full inference table instead of bare class probabilities.
observation = [[12.85, 1.6, 2.52, 17.8, 2.48]]
prediction = model.predict(X=observation, return_table=True)

pd.DataFrame(prediction)

Unnamed: 0,features,prediction_linear,prediction_class,prediction_prob,std_error,z_statistic,P>|z|,ci_low_0.05,ci_high_0.05
0,"{'alcohol': '12.85', 'malic_acid': '1.60', 'as...","[0.0, -2.9249, -4.1136]",0,"[0.9346, 0.0502, 0.0153]","[0.0774, 1.3106, 1.437]","[None, -2.2316, -2.8625]","[None, 0.026, 0.004]","[0.7829, 0.004, 0.0009]","[1.0, 0.408, 0.206]"


In [10]:
# Each entry is a one-row batch whose values follow model.feature_names[1:] order.
prediction_set = [
    [[12.85, 1.6, 2.52, 17.8, 2.48]],
    [[13.73, 1.5, 2.7, 22.5, 3]],
]

# Every per-batch table is indexed 0, so a plain concat yields duplicate index labels;
# ignore_index=True renumbers the combined rows 0..n-1.
predictions = pd.concat(
    [pd.DataFrame(model.predict(X=pred, return_table=True)) for pred in prediction_set],
    ignore_index=True,
)

predictions

Unnamed: 0,features,prediction_linear,prediction_class,prediction_prob,std_error,z_statistic,P>|z|,ci_low_0.05,ci_high_0.05
0,"{'alcohol': '12.85', 'malic_acid': '1.60', 'as...","[0.0, -2.9249, -4.1136]",0,"[0.9346, 0.0502, 0.0153]","[0.0774, 1.3106, 1.437]","[None, -2.2316, -2.8625]","[None, 0.026, 0.004]","[0.7829, 0.004, 0.0009]","[1.0, 0.408, 0.206]"
0,"{'alcohol': '13.73', 'malic_acid': '1.50', 'as...","[0.0, -4.4669, -4.5099]",0,"[0.978, 0.0112, 0.0108]","[0.0375, 1.8937, 1.7745]","[None, -2.3588, -2.5415]","[None, 0.018, 0.011]","[0.9045, 0.0003, 0.0003]","[1.0, 0.3173, 0.2605]"


In [11]:
# Predict at the sample mean of every predictor; iterating model.feature_names[1:]
# keeps the values in the order the model expects.
sample_mean = [X[name].mean() for name in model.feature_names[1:]]

prediction_set = [[sample_mean]]

predictions = pd.concat(
    pd.DataFrame(model.predict(X=batch, return_table=True))
    for batch in prediction_set
)

predictions

Unnamed: 0,features,prediction_linear,prediction_class,prediction_prob,std_error,z_statistic,P>|z|,ci_low_0.05,ci_high_0.05
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, 2.786, 2.131]",1,"[0.039, 0.6325, 0.3285]","[0.0599, 1.6018, 1.6716]","[None, 1.7393, 1.2748]","[None, 0.082, 0.202]","[0.0, 0.0694, 0.0181]","[0.1565, 0.9755, 0.9283]"


In [12]:
# Predict over a grid of 'ash' values, holding every other predictor at its sample mean.
# Rows are assembled from model.feature_names[1:], so the value order always matches the
# model, i.e. ['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'total_phenols'].

focus_feature = 'ash'
n_points = 20                     # number of grid points between the observed min and max

ordered_names = list(model.feature_names[1:])
sample_means = {name: X[name].mean() for name in ordered_names}

prediction_range = np.linspace(
    X[focus_feature].min(),
    X[focus_feature].max(),
    n_points,
)

# One single-row batch per grid value: the focus feature varies, all others stay at their mean.
prediction_set = [
    [[value if name == focus_feature else sample_means[name] for name in ordered_names]]
    for value in prediction_range
]

# ignore_index=True numbers rows by grid position (0..n_points-1) instead of repeating 0.
predictions = pd.concat(
    [pd.DataFrame(model.predict(X=pred, return_table=True)) for pred in prediction_set],
    ignore_index=True,
)

predictions.tail()

Unnamed: 0,features,prediction_linear,prediction_class,prediction_prob,std_error,z_statistic,P>|z|,ci_low_0.05,ci_high_0.05
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, -6.4975, -3.8841]",0,"[0.9784, 0.0015, 0.0201]","[0.0421, 2.1866, 2.0052]","[None, -2.9716, -1.937]","[None, 0.003, 0.053]","[0.8959, 0.0, 0.0004]","[1.0, 0.0969, 0.5111]"
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, -8.4424, -5.1442]",0,"[0.994, 0.0002, 0.0058]","[0.0152, 2.7616, 2.553]","[None, -3.057, -2.0149]","[None, 0.002, 0.044]","[0.9642, 0.0, 0.0]","[1.0, 0.0458, 0.4649]"
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, -10.3872, -6.4043]",0,"[0.9983, 0.0, 0.0017]","[0.0052, 3.3611, 3.1291]","[None, -3.0904, -2.0467]","[None, 0.002, 0.041]","[0.988, 0.0, 0.0]","[1.0, 0.0219, 0.4325]"
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, -12.3321, -7.6645]",0,"[0.9995, 0.0, 0.0005]","[0.0018, 3.974, 3.7202]","[None, -3.1032, -2.0602]","[None, 0.002, 0.039]","[0.9961, 0.0, 0.0]","[1.0, 0.0105, 0.4078]"
0,"{'alcohol': '13.00', 'malic_acid': '2.34', 'as...","[0.0, -14.2769, -8.9246]",0,"[0.9999, 0.0, 0.0001]","[0.0006, 4.5949, 4.3204]","[None, -3.1071, -2.0657]","[None, 0.002, 0.039]","[0.9987, 0.0, 0.0]","[1.0, 0.0051, 0.3877]"


### Coefficient Inference Table

- Comprehensive regression inference

In [13]:
coefficient_rows = model.inference_table()  # per-class coefficient inference
pd.DataFrame(coefficient_rows)

Unnamed: 0,feature,class,coefficient,std_error,z_statistic,P>|z|,ci_low_0.05,ci_high_0.05
0,const,Class_1.0,111.394,36.7062,3.0347,0.002,39.4512,183.3368
1,alcohol,Class_1.0,-6.7107,2.1485,-3.1235,0.002,-10.9216,-2.4998
2,malic_acid,Class_1.0,-2.3517,0.9595,-2.4511,0.014,-4.2322,-0.4712
3,ash,Class_1.0,-19.7606,6.5256,-3.0282,0.002,-32.5504,-6.9707
4,alcalinity_of_ash,Class_1.0,2.4221,0.8217,2.9475,0.003,0.8115,4.0327
5,total_phenols,Class_1.0,-7.1129,3.8594,-1.843,0.065,-14.6772,0.4513
6,const,Class_2.0,54.9528,36.4229,1.5087,0.131,-16.4348,126.3403
7,alcohol,Class_2.0,-2.6855,2.1389,-1.2556,0.209,-6.8776,1.5066
8,malic_acid,Class_2.0,-1.3694,0.9514,-1.4394,0.15,-3.2341,0.4953
9,ash,Class_2.0,-12.8035,6.329,-2.023,0.043,-25.2082,-0.3988


### Variance Inflation Factor

- Generate a variance inflation factor (VIF) table for the model's features

In [14]:
vif_rows = model.variance_inflation_factor()  # one VIF per predictor (intercept excluded)
pd.DataFrame(vif_rows)

Unnamed: 0,feature,VIF
0,alcohol,1.4244
1,malic_acid,1.2727
2,ash,1.6087
3,alcalinity_of_ash,1.8769
4,total_phenols,1.3988


### Robust Covariance (Apply on Fit)

- Apply a robust (heteroskedasticity-consistent, HC0/HC1) covariance estimator to a model at fit time

- Coefficients are unchanged; standard errors, and therefore inference and subsequent prediction intervals, use the robust covariance

In [15]:
# Refit the same specification with HC0 and HC1 robust covariance, fitted in that order.
# Unlike the main fit, no `cuda=` argument is passed here.
robust_hc0, robust_hc1 = (
    MultinomialLogisticRegression().fit(X=X, y=y, cov_type=cov, alpha=0.05, target_name=label)
    for cov, label in (("HC0", "targetHC0"), ("HC1", "targetHC1"))
)

In [16]:
# Side-by-side table: the non-robust fit next to the HC0 and HC1 robust fits.
comparison = summary(model, robust_hc0, robust_hc1)
print(comparison)

Multinomial Regression Results
---------------------------------------------------------------------------
Dependent:                             target      targetHC0      targetHC1
---------------------------------------------------------------------------
Class:                                      1

const                             111.3940***    111.3940***    111.3940***
                                    (36.7062)      (32.1213)      (33.2620)
 
alcohol                            -6.7107***     -6.7107***     -6.7107***
                                     (2.1485)       (2.0159)       (2.0875)
 
malic_acid                          -2.3517**     -2.3517***     -2.3517***
                                     (0.9595)       (0.6121)       (0.6339)
 
ash                               -19.7606***    -19.7606***    -19.7606***
                                     (6.5256)       (6.4799)       (6.7100)
 
alcalinity_of_ash                   2.4221***      2.4221***      2.4221***
  