In [1]:
# This file is intended to be used to predict more than 2 states with multinomial logistic regression. (For exactly 2 states, please use the 2-state regression code.)
# Independent Variable: lifetime
# Dependent Variable: states (> 2 classes; for binary classification, refer to the 2-state regression code.)

import pandas as pd
import statsmodels.formula.api as smf
import numpy as np
import statsmodels.api as sm

In [2]:
# Load the two CSV exports without treating the first row as a header:
# the lifetime values first, then the matching state values.
lifetime, states = (
    pd.read_csv(csv_path, header=None)
    for csv_path in (
        "/content/HEKcell_lifetime.csv",
        "/content/HEKcell_ACh_concentration.csv",
    )
)

In [3]:
# Empty two-column frame; a later cell fills it with the lifetime values
# and the state labels.
data = pd.DataFrame(
    columns=['lifetime', 'states'],
)

In [4]:
# Render both raw input frames, lifetime first and states second.
display(lifetime, states)

Unnamed: 0,0
0,3.325809
1,3.320660
2,3.279992
3,3.364916
4,3.368403
...,...
320,3.531065
321,3.529696
322,3.505948
323,3.517041


Unnamed: 0,0
0,0
1,0
2,0
3,0
4,0
...,...
320,10
321,10
322,10
323,10


In [5]:
# Feed the data from the individual datasets into the one final dataset.
# A scalar column position (iloc[:, 0]) yields a Series, which is the natural
# right-hand side for a single-column assignment.
data['lifetime'] = lifetime.iloc[:, 0]

# Convert the raw state values to equally-spaced class labels in one pass.
# An explicit mapping replaces the per-value masked assignments (two of which
# were no-ops) and makes the recoding rule visible in a single place.
state_codes = {0: 0, 1: 1, 10: 2}
raw_states = states.iloc[:, 0]
recoded_states = raw_states.map(state_codes)

# Any raw value without an entry in the mapping (including missing values)
# comes back as NaN. Fail loudly instead of passing it on as a class label.
unmapped_states = raw_states[recoded_states.isna()]
if not unmapped_states.empty:
    raise ValueError(
        "Unexpected state values with no class label: "
        f"{unmapped_states.unique().tolist()}"
    )

data['states'] = recoded_states

In [6]:
# Display the formatted data: the lifetime values alongside the recoded
# state labels (0, 1, 2) produced by the previous cell.
data

Unnamed: 0,lifetime,states
0,3.325809,0
1,3.320660,0
2,3.279992,0
3,3.364916,0
4,3.368403,0
...,...,...
320,3.531065,2
321,3.529696,2
322,3.505948,2
323,3.517041,2


In [8]:
# Prepare the inputs for the MNLogit call:
#   y - the class labels, kept as a pandas Series
#   x - the lifetime values as a 2-D column vector (one row per observation)
y = data['states']
x = data['lifetime'].values[:, np.newaxis]

In [13]:
# Multinomial logit model (statsmodels)
# Refer: https://www.statsmodels.org/dev/generated/statsmodels.discrete.discrete_model.MNLogit.html

# Add the intercept column; prepend=False puts it after the lifetime column,
# which is why the summary lists x1 first and const second.
x = sm.add_constant(x, prepend=False)

# endog = y: multiclass label vector
# exog  = x: lifetime data plus the constant
mnlogit_mod = sm.MNLogit(endog=y, exog=x)

# Fit by maximum likelihood (the optimizer prints its convergence message).
mnlogit_fit = mnlogit_mod.fit()

# Predict on the same design matrix that was used for fitting.
predict = mnlogit_fit.predict(x)

# Show the fitted model's summary table.
print(mnlogit_fit.summary())

Optimization terminated successfully.
         Current function value: 0.480075
         Iterations 10
                          MNLogit Regression Results                          
Dep. Variable:                 states   No. Observations:                  325
Model:                        MNLogit   Df Residuals:                      321
Method:                           MLE   Df Model:                            2
Date:                Tue, 15 Nov 2022   Pseudo R-squ.:                  0.5623
Time:                        17:07:09   Log-Likelihood:                -156.02
converged:                       True   LL-Null:                       -356.43
Covariance Type:            nonrobust   LLR p-value:                 9.243e-88
  states=1       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
x1            83.5757     14.127      5.916      0.000      55.887     111.265
const       -285.3517     48

In [10]:
# Turn the model output into one predicted class per observation by taking,
# for every row, the column index holding the largest value. The column index
# doubles as the class label because the states were recoded to 0, 1, 2.
prediction = np.asarray(predict).argmax(axis=1)
prediction

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2,
       2, 1, 0, 0, 2, 2, 1, 1, 2, 2, 1, 1, 0, 2, 2, 1, 2, 2, 1, 1, 1, 1,
       2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1,
       1, 2, 1, 2, 1, 1, 2, 0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 0, 1, 2, 1, 1,
       1, 2, 2, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 0, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1,

In [11]:
# Generating confusion matrix, classification report and accuracy score of the model.
# Refer: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
from sklearn.metrics import (confusion_matrix,
                           accuracy_score, classification_report)

# confusion matrix (arguments are passed as actual labels first, predictions second)
cm = confusion_matrix(data['states'], prediction)
print ("Confusion Matrix : \n", cm)

# classification report
cr = classification_report(data['states'], prediction)
print ("Classification Report : \n", cr)

# accuracy score of the model
# NOTE: `prediction` was produced from the same observations the model was
# fitted on, so this is training (in-sample) accuracy, not accuracy on a
# held-out test set. The label says so to avoid overstating the result.
print('Training accuracy = ', accuracy_score(data['states'], prediction))

# You can see the output spread (actual/predicted 0,1,2; false positives)

Confusion Matrix : 
 [[108   4   0]
 [  7  53  39]
 [  1  19  94]]
Classification Report : 
               precision    recall  f1-score   support

           0       0.93      0.96      0.95       112
           1       0.70      0.54      0.61        99
           2       0.71      0.82      0.76       114

    accuracy                           0.78       325
   macro avg       0.78      0.77      0.77       325
weighted avg       0.78      0.78      0.78       325

Test accuracy =  0.7846153846153846


In [12]:
# Extracting pseudo r-squared value (McFadden’s pseudo R-squared value)
# When comparing two models on the same data, McFadden’s would be higher for the model with the greater relation.
# The value is read directly from the fitted results so it is actually
# extracted, not just visible inside the summary table.
print("McFadden pseudo R-squared = ", mnlogit_fit.prsquared)

# Full summary table as the cell result; the same value appears there as "Pseudo R-squ.".
mnlogit_fit.summary()

0,1,2,3
Dep. Variable:,states,No. Observations:,325.0
Model:,MNLogit,Df Residuals:,321.0
Method:,MLE,Df Model:,2.0
Date:,"Tue, 15 Nov 2022",Pseudo R-squ.:,0.5623
Time:,17:05:26,Log-Likelihood:,-156.02
converged:,True,LL-Null:,-356.43
Covariance Type:,nonrobust,LLR p-value:,9.242999999999999e-88

states=1,coef,std err,z,P>|z|,[0.025,0.975]
x1,83.5757,14.127,5.916,0.000,55.887,111.265
const,-285.3517,48.128,-5.929,0.000,-379.682,-191.022
states=2,coef,std err,z,P>|z|,[0.025,0.975]
x1,112.0821,15.189,7.379,0.000,82.312,141.852
const,-384.9802,51.919,-7.415,0.000,-486.740,-283.221
