# Surgical Classification: Complication prediction


We are given a dataset which has 25 columns, and 14636 rows. The first row is a header, and the last column is the target. will be using Tablesaw to easily create a table out of the CSV dataset. The goal is to learn from this data, and be able to make a prediction based on pre-existing conditions, the chance of a patient having a complication.

# Implementing Tablesaw for graphing



We now have the header in it's own Array, and all of the content in a 2d array. We are able to examine any of the data we want, being able to match each column with the proper header. Eventually we will need to randomize the order of the rows, but for now it's not important.

The next two cells implement Tablesaw graph into the notebook.
I did not do this code myself, I found somebody who gave this code on stack overflow. The post can be found here: https://stackoverflow.com/questions/54654434/how-to-embed-tablesaw-graph-in-jupyter-notebook-with-ijava-kernel

In [1]:
%%loadFromPOM
<dependency>
    <groupId>tech.tablesaw</groupId>
    <artifactId>tablesaw-jsplot</artifactId>
    <version>0.38.0</version>
</dependency>

In [2]:
import io.github.spencerpark.ijava.IJava;

IJava.getKernelInstance().getRenderer()
    .createRegistration(tech.tablesaw.plotly.components.Figure.class)
    .preferring(io.github.spencerpark.jupyter.kernel.display.mime.MIMEType.TEXT_HTML)
    .register((figure, ctx) -> {
        ctx.renderIfRequested(io.github.spencerpark.jupyter.kernel.display.mime.MIMEType.TEXT_HTML, () -> {
            String id = UUID.randomUUID().toString().replace("-", "");

            figure.asJavascript(id);
            Map<String, Object> context = figure.getContext();

            StringBuilder html = new StringBuilder();
            html.append("<div id=\"").append(id).append("\"></div>\n");
            html.append("<script>require(['https://cdn.plot.ly/plotly-1.44.4.min.js'], Plotly => {\n");
            html.append("var target_").append(id).append(" = document.getElementById('").append(id).append("');\n");
            html.append(context.get("figure")).append('\n');
            html.append(context.get("plotFunction")).append('\n');
            html.append("})</script>\n");
            return html.toString();
        });
    });

In [3]:
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.BufferedReader;
import java.lang.*;
import tech.tablesaw.api.*;
import tech.tablesaw.plotly.api.*;
import tech.tablesaw.plotly.components.*;


String csv = "/Users/surav/notebook/Java ML Project/Surgical-deepnet.csv"; //CSV file directory
Table Data = Table.read().csv(csv); //Create a table of the CSV

//An error may pop up, but as far as I can tell everything still works even when it does.
//Running the cells 1 at a time makes the error go away for me.

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.


# Understanding the Data


Taking a closer look at the data we're working with will help us when graphing it. Right now we want to know exactly how much data we have, and what types of data our attributes are being stored as. Remember, the target is in the last column.

In [4]:
Data.shape(); //Show number of rows and columns in the dataset.

14635 rows X 25 cols

In [5]:
Data.structure(); /*Show info of the dataset, column names and types.
10           baseline_psych                     INTEGER
11           baseline_pulmonary                 INTEGER
12           ahrq_ccs                           INTEGER
13           ccsComplicationRate                DOUBLE
14           ccsMort30Rate                      DOUBLE
The above columns do not show up in the table below.
*/

       Structure of Surgical-deepnet.csv        
 Index  |     Column Name      |  Column Type  |
------------------------------------------------
     0  |                 bmi  |       DOUBLE  |
     1  |                 Age  |       DOUBLE  |
     2  |          asa_status  |      INTEGER  |
     3  |     baseline_cancer  |      INTEGER  |
     4  |   baseline_charlson  |      INTEGER  |
     5  |        baseline_cvd  |      INTEGER  |
     6  |   baseline_dementia  |      INTEGER  |
     7  |   baseline_diabetes  |      INTEGER  |
     8  |  baseline_digestive  |      INTEGER  |
     9  |   baseline_osteoart  |      INTEGER  |
   ...  |                 ...  |          ...  |
    15  |    complication_rsi  |       DOUBLE  |
    16  |                 dow  |      INTEGER  |
    17  |              gender  |      INTEGER  |
    18  |                hour  |       DOUBLE  |
    19  |               month  |      INTEGER  |
    20  |           moonphase  |      INTEGER  |


# Complication graph


When graphing, we'll look at percentages and averages instead of actual numbers. It's more relevant to us if 70% of patients with a certain condition experience a complication, than knowing that 150 patients with the condition experience a complication.

# Easily changable graphs



Below is an easy way to check the complication rate for any other integer in the dataset. Changing the condition to any name of an integer column in the table allows us to examine the percentage of each category which experiences a complication in surgery. This can help us determine which columns are good to prune away, specifically if the complication rate is extremely close for every number that appears. We're able to get a very good understanding of the integers in this dataset with these graphs.

Looking at doubles is a bit trickier. I couldn't find a way to give a more accurate representation of the data, so for doubles we are looking at the average value for the column when the complication is 1 and when it is 0. This does mean that 0 will not have an impact on the graph, so if it is a common digit in the column, it risks being overlooked entirely.

Before looking at graphs, lets change bmi, Age and Hour to integers. We will lose some significance from bmi, but most of the impact comes from the integer number even when it's a double.

In [6]:
import static tech.tablesaw.aggregate.AggregateFunctions.sum;
import static tech.tablesaw.aggregate.AggregateFunctions.mean;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.DoubleColumn;

//Changing some double columns to int columns

IntColumn temp1 = Data.doubleColumn("bmi").asIntColumn(); //bmi column as an int column
IntColumn temp2 = Data.doubleColumn("Age").asIntColumn(); //Age column as an int column
IntColumn temp3 = Data.doubleColumn("Hour").asIntColumn(); //Hour column as an int column

Data.removeColumns("bmi", "Age", "Hour"); //Remove bmi Age and Hour from the table
Data.insertColumn(0, temp1); //insert temp1 column to the table in index 0
Data.insertColumn(1, temp2); //insert temp2 column to the table in index 1
Data.insertColumn(18, temp3); //insert temp3 column to the table in index 18
//The structure should maintain it's original shape, but the Age and Hour columns are now Integer types.

Data.structure();

       Structure of Surgical-deepnet.csv        
 Index  |     Column Name      |  Column Type  |
------------------------------------------------
     0  |                 bmi  |      INTEGER  |
     1  |                 Age  |      INTEGER  |
     2  |          asa_status  |      INTEGER  |
     3  |     baseline_cancer  |      INTEGER  |
     4  |   baseline_charlson  |      INTEGER  |
     5  |        baseline_cvd  |      INTEGER  |
     6  |   baseline_dementia  |      INTEGER  |
     7  |   baseline_diabetes  |      INTEGER  |
     8  |  baseline_digestive  |      INTEGER  |
     9  |   baseline_osteoart  |      INTEGER  |
   ...  |                 ...  |          ...  |
    15  |    complication_rsi  |       DOUBLE  |
    16  |                 dow  |      INTEGER  |
    17  |              gender  |      INTEGER  |
    18  |                hour  |      INTEGER  |
    19  |               month  |      INTEGER  |
    20  |           moonphase  |      INTEGER  |


# Duplication in the Dataset

I noticed an error in the data. There are a lot of rows with identical information. So before looking at the data in graphs, we need to clean this up. It's extremely unlikely to have the same exact conditions multiple times with this many attributes, so we'll be removing all duplicate rows. This risks removing actual data, but there were well over 100 duplicate rows that I found by hand, so it should make our result better. The result of this error can be seen in the following graph, where people aged 90 are very unlikely to have a complication.

In [7]:
String condition1 = "Age"; //using string as what we're testing against the complication data
//The input string must point to an integer type column.
String target = "complication"; //this is what we're comparing against. by default is the target.

Table Error = Data.summarize(target, mean).by("Age"); //table of complication and condition

Figure figure = VerticalBarPlot.create("% " + target + " by Age", //plot title
    Error,                  //table
    "Age",                  //grouping column name
    "mean [" + target + "]"); //numeric column name

//removing duplicate rows
Data = Data.dropDuplicateRows(); //sets the table to the table without duplicates


//showing the graph from before the rows were duplicated
display(figure);

Data.shape();

11733 rows X 25 cols

In [8]:
//This graph shows the data without the duplicates. nearly 3000 rows were removed.
//The string "condition" can be changed to any integer column, and will display a graph and table.

String condition1 = "Age"; //using string as what we're testing against the complication data
//The input string must point to an integer type column.
String target = "complication";//this is what we're comparing against. by default is the target.

Table complications1 = Data.summarize(target, mean).by(condition1); //table of complication and condition

Figure figure = VerticalBarPlot.create("% " + target + " by " + condition1, //plot title
    complications1,         //table
    condition1,              //grouping column name
    "mean [" + target + "]"); //numeric column name

display(figure); //displays a graph of the table
complications1.summary(); //gives information for the table

               Surgical-deepnet.csv summary                
 Summary   |         Age          |  Mean [complication]  |
-----------------------------------------------------------
    Count  |                  80  |                   80  |
      sum  |                4030  |   51.385688043736565  |
     Mean  |  50.374999999999986  |   0.6423211005467073  |
      Min  |                   6  |  0.09689922480620156  |
      Max  |                  90  |                    1  |
    Range  |                  84  |   0.9031007751937985  |
 Variance  |    550.363924050633  |  0.12568702395083914  |
 Std. Dev  |   23.45983640289576  |  0.35452365781543993  |

In [9]:
String condition2 = "mortality_rsi"; //using string as what we're testing against the complication data
//The input string can point to a double type column.
String target = "complication"; //this is what we're comparing against. by default is the target.

Table complications2 = Data.summarize(condition2, mean).by(target); //table of condition2 and complication

Figure figure = VerticalBarPlot.create("Average " + condition2 + " of "+ target + " status", //plot title
    complications2,                  //table
    target,                          //grouping column name
    "mean [" + condition2 + "]");    //numeric column name

display(figure); //Display the graph of the table
complications2.summary(); //Displays information for the table
//complication column will always be the same

                Surgical-deepnet.csv summary                
 Summary   |     complication     |  Mean [mortality_rsi]  |
------------------------------------------------------------
    Count  |                   2  |                     2  |
      sum  |                   1  |    -0.767170941959326  |
     Mean  |                 0.5  |    -0.383585470979663  |
      Min  |                   0  |   -0.6627427576774825  |
      Max  |                   1  |  -0.10442818428184356  |
    Range  |                   1  |     0.558314573395639  |
 Variance  |                 0.5  |   0.15585758143297715  |
 Std. Dev  |  0.7071067811865476  |    0.3947880208833307  |

# Data Management



To start off we want to look at integer attributes, where there are multiple values of the same attribute which have similar impact on the complication for a patient. We need to set a threshold, in this case we'll be using +/-2.5%, which is close enough to justify merging the different values. We will combine the values into the same integer if they are within the threshold, and then move on to the next attribute. 

After we've consolodated the data for the attributes, we need to again make minor modifications to the stored data. If there are only 2 integers remaining, we should change those to +1 and -1, regardless of what their values were before. If there are more than two numbers remaining we need to reexamine the attribute. If the numbers were sequential, we need to reorder them in accordance to how they were before (for example if the numbers were 1-10 but now are only 1, 2 and 6, they become 1, 2, 3). If they were not sequential and instead represent something which cannot be boiled down further, we may leave the values in their new state.

If there are any integer attributes with only 1 digit after this process, they need to be removed from the dataset, and be declared useless for determining the complication for a patient. Similarly we can prune double attributes if the average for the complication status is within the same threshold of +/-2.5%.

One important thing to remember is that if we wanted to find predictions for new data, any modifications we make to the dataset before training will need to be done to that data.

In [10]:
int size = Data.columnCount(); //sets size = # of columns

String[] Name = new String[size]; // Names of the columns
String[] Type = new String[size]; // Types of the columns
for(int i = 0; i < Name.length; i++){
    Name[i] = Data.column(i).toString(); //Getting the info of the column in position i of the array
    int startIndex = Name[i].indexOf(" "); //first space is after type
    Type[i] = Name[i].substring(0, startIndex);
    startIndex = Name[i].indexOf(":"); //: is after "type column"
    Name[i] = Name[i].substring(startIndex + 2, Name[i].length()); //start index +2 because there is a space after the :
    //removing everything before the column name
    System.out.println(Name[i] + "; " + Type[i]);
}


//Combine, consolidate, remove.
//Change any binary integer columns (0, 1) to instead be (-1, +1). Including target. 

bmi; Integer
Age; Integer
asa_status; Integer
baseline_cancer; Integer
baseline_charlson; Integer
baseline_cvd; Integer
baseline_dementia; Integer
baseline_diabetes; Integer
baseline_digestive; Integer
baseline_osteoart; Integer
baseline_psych; Integer
baseline_pulmonary; Integer
ahrq_ccs; Integer
ccsComplicationRate; Double
ccsMort30Rate; Double
complication_rsi; Double
dow; Integer
gender; Integer
hour; Integer
month; Integer
moonphase; Integer
mort30; Integer
mortality_rsi; Double
race; Integer
complication; Integer


In [11]:
double threshold = 0.025; //% considered to be statistically similar to another source
//Too high will result in every column being removed, too low will result in no changes to the dataset.
Table NewData = Data.copy(); //creating a new table which we will make the modifications on.



//This for loop will go through the dataset and automatically prune the data for us. Based on the double value 
//given above, it will combine numbers which are close enough together, and remove columns which do not add 
//significant information for the target.
for(int i = 0; i < Name.length - 1; i++){ //requires target to be last attribute (it is in our case)
    if(Type[i].charAt(0) == 'I'){ //first character is only I if it's an integer type column.
        Table tempTable = Data.summarize("complication", mean).by(Name[i]); //shows % target for each value
        int rows = tempTable.rowCount();
        double[] percentTarget = new double[rows];
        int[] condition = new int[rows];
        for(int j = 0; j < rows; j++){ //filling arrays with values from the created table
            percentTarget[j] = tempTable.doubleColumn(1).get(j); //set array equal to percent values
            condition[j] = tempTable.intColumn(0).get(j); //set array equal to the int values
        }
        double tempD = 0; //temp double for rearranging array
        int tempI = 0; //temp int for rearranging array
        for(int j = 0; j < rows-1; j++){ //sorting arrays based on target %. Slow, but sorts both arrays based on one.
            for(int k = j+1; k < rows; k++){ //This is just bubble sort, could be improved in the future.
                if(percentTarget[j] > percentTarget[k]){ //sorts smallest to largest
                    tempD = percentTarget[j];
                    tempI = condition[j];
                    percentTarget[j] = percentTarget[k];
                    condition[j] = condition[k];
                    percentTarget[k] = tempD;
                    condition[k] = tempI;
                }
            }
        }
        for(int j = rows-1; j > 0; j--){ //replacing original column values if it's within closeness threshold
            tempD = percentTarget[j] - threshold;
            for(int k = j-1; k >= 0; k--){
                if(tempD < percentTarget[k] && condition[k] != condition[j]){ //also break if conditions are equal
                    // they should only be equal if the row was considered similar to a previous row. 
                    
                    NewData.intColumn(i).set(NewData.intColumn(i).isEqualTo(condition[k]), condition[j]);
                    //replace column values
                    condition[k] = condition[j];
                    percentTarget[k] = percentTarget[j]; //changes probability to that in j. Prevents a case where
                    //the entire column gradually changes slowly enough to where every number would be considered
                    //similar to eachother. In that same case there would instead be much more values in the column.
                }
                else //if tempD is bigger break, since the array is sorted nothing else can be similar
                    break;
            }
        }
    }
    else{//else always catches double columns.
    Table tempTable = NewData.summarize(Name[i], mean).by("complication"); //shows average value for 0 and 1 on target
    }
}










//Lets take a look at what this did to our age graph.





String condition = "Age"; //using string as what we're testing against the complication data
//The input string must point to an integer type column.

Table complications3 = NewData.summarize(target, mean).by(condition); //table of complication and condition

VerticalBarPlot.create("% " + target + " by " + condition, //plot title
    complications3,         //table
    condition,              //grouping column name
    "mean [" + target + "]"); //numeric column name


//Now only 20 groups remain

# Further changes



Now that we've removed some numbers from the original data, we need to do one final change to the remaining numbers. Integer columns which contain only two digits should be changed to -1 and +1, and integer columns with more should be placed in order from highest to lowest %, then put half as negative values and half as positive values. Integer columns with only 1 value should be outright removed.

We will make no changes for double columns.

In [12]:
Table Modified = NewData.copy();

for(int i = Name.length - 1; i >= 0; i--){ //Starting at the end because columns are going to be removed.
    if(Type[i].charAt(0) == 'I'){ //first character is only I if it's an integer type column.
        Table tempTable = NewData.summarize("complication", mean).by(Name[i]); //shows % target for each value
        int rows = tempTable.rowCount(); //if there is only 1 row, the attribute is not significant. Drop the column.
        if(rows == 1){
            Modified.removeColumns(i);
        }
        else if(rows == 2){//no need to sort again if there are only two rows.
            int store1 = tempTable.intColumn(0).get(0); //one of two values in the column
            int store2 = tempTable.intColumn(0).get(1); //one of two values in the column
            
            if(store1 == -1 || store2 == -1){//if either value is already -1
                if(store1 == -1){ //store1 is already -1
                    Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(store2), 1);
                }
                else{ //store2 is already -1
                    Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(store1), 1);
                }
            }
            else{ //if neither value is -1 already
                Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(store1), -1);
                Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(store2), 1);
            }
        }
        else{ //3 or more values in the column
            double[] percentTarget = new double[rows];
            int[] condition = new int[rows];
            for(int j = 0; j < rows; j++){ //filling arrays with values from the created table
                percentTarget[j] = tempTable.doubleColumn(1).get(j); //set array equal to percent values
                condition[j] = tempTable.intColumn(0).get(j); //set array equal to the int values
            }//arrays now have 
            double tempD = 0; //temp double for rearranging array
            int tempI = 0; //temp int for rearranging array
            for(int j = 0; j < rows-1; j++){ //sorting arrays based on target %. Slow, but sorts both arrays based on one.
                for(int k = j+1; k < rows; k++){ //This is just bubble sort, could be improved in the future.
                    if(percentTarget[j] > percentTarget[k]){ //sorts smallest to largest
                        tempD = percentTarget[j];
                        tempI = condition[j];
                        percentTarget[j] = percentTarget[k];
                        condition[j] = condition[k];
                        percentTarget[k] = tempD;
                        condition[k] = tempI;
                    }//The data is now sorted from lowest to highest %, attribute values are jumbled.
                }
            }//only columns with 3 or more remaining attributes arrive here.
            int half = rows / 2; //integer division, rounds down
            boolean crossover = false; //if there is crossover in the allocation set to 1.
            int largest = 0; //will be the index of the largest number if needed.
            int count = 0; //count how many unique negatives there are in the column
            int[] newAllocation = new int[rows]; //new allocation value for the attributes
            int[] crossoverList = new int[rows]; //will hold the index of the crossover cases
            
            for(int j = 0; j < rows; j++){ //new allocations will be approximately evently distrubuted across
                //the negative. 0 is not a taken value.
                if(j < half){
                    newAllocation[j] = j - half;
                }
                else{ //+1 so we skip 0 being taken
                    newAllocation[j] = j - half + 1;
                }
            }
            for(int j = 0; j < rows; j++){
                for(int k = 0; k < rows; k++){
                    if(condition[j] == newAllocation[k]){
                        crossover = true;
                        crossoverList[j] = 1;
                        break;
                    }
                }
                if(condition[j] > condition[largest]){
                    largest = j; //largest holds the index for the largest number
                }
            }
            largest = condition[largest]; //make largest hold the largest value
            
            if(crossover){ //if some new allocations are already filled change the value in the column
                //and in the array
                for(int j = 0; j < rows; j++){
                    if(crossoverList[j] == 1){
                        largest++; //increment largest so we can use it
                        Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(condition[j]), largest);
                        condition[j] = largest;
                        crossoverList[j] = 0;
                        //replacing crossover with new large value. This is temporary.
                    }
                }
            }
            //if no new allocations are already filled. Should be no crossover left, if there was any.
            for(int j = 0; j < rows; j++){
                Modified.intColumn(i).set(Modified.intColumn(i).isEqualTo(condition[j]), newAllocation[j]);
                //set the values to their new allocations.
            }
        }//end of 3 or more values if statement
    }
    else{//else always catches double columns. We won't modify these at all, the below instruction runs, but does not
        //modify the table at all. 
        Table tempTable = Modified.summarize(Name[i], mean).by("complication"); //shows average value for 0 and 1 on target
    }
}
Modified.structure(); //A look at the new 

       Structure of Surgical-deepnet.csv       
 Index  |     Column Name     |  Column Type  |
-----------------------------------------------
     0  |                bmi  |      INTEGER  |
     1  |                Age  |      INTEGER  |
     2  |         asa_status  |      INTEGER  |
     3  |    baseline_cancer  |      INTEGER  |
     4  |  baseline_charlson  |      INTEGER  |
     5  |  baseline_dementia  |      INTEGER  |
     6  |  baseline_diabetes  |      INTEGER  |
     7  |  baseline_osteoart  |      INTEGER  |
     8  |     baseline_psych  |      INTEGER  |
     9  |           ahrq_ccs  |      INTEGER  |
   ...  |                ...  |          ...  |
    12  |   complication_rsi  |       DOUBLE  |
    13  |                dow  |      INTEGER  |
    14  |             gender  |      INTEGER  |
    15  |               hour  |      INTEGER  |
    16  |              month  |      INTEGER  |
    17  |          moonphase  |      INTEGER  |
    18  |           

In [13]:
//Split the data into train and test sets. 


Table[] splits = Modified.sampleSplit(.8); //splits the NewData into an 80/20 split
Table train = splits[0]; //takes the 80% split
Table test = splits[1];  //takes the 20% split
train.shape();

9386 rows X 22 cols

In [14]:
//Pocket algorithm

double[] weight = new double[train.columnCount()]; //weights for attributes + 1 for bias
double[] bestWeight = new double[train.columnCount()]; //weights for attributes + 1 for bias
int inError = 0; //in error 
int bestInError = 99999999; //best in error. initialized high so it's instantly replaced.
int itterations = 0;
int randError = 0; //will hold the row which will help us learn for weights
double check = 0;
int trainSize = train.rowCount(); //how many rows are in the training set
int[] errorArray = new int[trainSize]; //tracks what rows have had errors.
int attrLoop = train.columnCount() - 1;
double target = 0;

//The following variables can be changed at will
int loopThreshold = 1500; //number of loops we'll do maximum.
int threshold = 1000; //exit if less than this many errors
//should be at least smaller than the number of rows you have in your dataset or it does nothing
double bias = 1; //set bais to 1


while(true){ //goes forever, only exits from break;
    for(int i = 0; i < trainSize; i++){
        check = 0;
        for(int j = 0; j < attrLoop; j++){ //-1 so we don't do the target as well.
            if(j == attrLoop - 1){
                check += weight[j] * bias;
            }
            else{
                target = Double.parseDouble(train.get(i, j).toString());
                check += weight[j] * target;
            }
        }
        target = Double.parseDouble(train.get(i, train.columnCount() - 1).toString());
        //target is the value of the target for that row
        if(check >= 0 && (int) target == -1){ //calculated target != target
            inError += 1;
            errorArray[i] = 1;
        }
        else if(check < 0 && (int) target == 1){//calculated target != target
            inError += 1;
            errorArray[i] = 1;
        }
    }
    if(inError < bestInError){ //if we have a new best in sample error update best weights and best inError
        for(int i = 0; i < weight.length; i++){
            bestWeight[i] = weight[i]; //update best weights
        }
        bestInError = inError;
    }
    if(bestInError < threshold){ //if we're happy with the error, break from the loop.
        break;
    }
    else if(itterations >= loopThreshold){ //break if we do too many itterations
        break;
    }
    else{ //if we're continuing the loop, update the weights.
        while(true){
            Random rand = new Random();
            randError = rand.nextInt(trainSize);
            if(errorArray[randError] == 1) //choosing a random error which modifies weights.
                break;
        }
        //System.out.println(randError + " " + itterations + " " + inError);
        for(int i = 0; i < weight.length; i++){
            if(i < weight.length - 1){
                weight[i] = bestWeight[i] + (Double.parseDouble(train.get(randError, i).toString()) 
                                         * Double.parseDouble(train.get(randError, train.columnCount() - 1).toString()));
            }
            else{
                weight[i] = bestWeight[i] + (bias 
                                        * Double.parseDouble(train.get(randError, train.columnCount() - 1).toString()));
            }
                
        }
        inError = 0;
    }
    itterations += 1;
}


for(int i = 0; i < weight.length; i++){
    System.out.print("W" + i + ": " + bestWeight[i] + "; "); //bias weight is the final weight
    if(i % 8 == 0 && i != 0)
        System.out.print("\n"); //makes the output more readable
}
System.out.println("\nWith " + bestInError + " errors.");
System.out.println((((double)bestInError / (double)trainSize) * 100) + "% Error"); //gives % in error

W0: 40.0; W1: 25.0; W2: -1.0; W3: -7.0; W4: -6.0; W5: -3.0; W6: -3.0; W7: -7.0; W8: -7.0; 
W9: 10.0; W10: 0.33889034000000007; W11: -0.00841289; W12: 2.4299999999999997; W13: 8.0; W14: -1.0; W15: -5.0; W16: 9.0; 
W17: -7.0; W18: -3.0; W19: 5.0200000000000005; W20: -3.0; W21: 3.0; 
With 1256 errors.
13.381632218197314% Error


# Out of sample error


We used the pokket algorithm to get some weights, and the in sample error seems pretty good. Now that we have our final weights, we need to calculate our out of sample error. Our in sample error will be different each time we run the program, because we randomly select the row which changes the weights, and the train/test sets are also randomly calculated.

In [15]:
test.shape();

2347 rows X 22 cols

In [16]:
//finding the out of sample error

int testSize = test.rowCount(); //rows in test set
int testAttr = test.columnCount() - 1; //columns in test set (will be the same as those in train set)
double testCheck = 0; //sum weights * attribtues
double testTarget = 0; //holds target value.
int Eout = 0; //out of sample error
int falsePositive = 0; //count false positives
int falseNegative = 0; //count false negatives


for(int i = 0; i < testSize; i++){
    testCheck = 0;
    for(int j = 0; j < testAttr - 1; j++){ //-1 so we don't do the target as well.
        if(j == testAttr - 1){
            testCheck += bestWeight[j] * bias;
        }
        else{
            testTarget = Double.parseDouble(test.get(i, j).toString()); //holds attribute value for readability.
            testCheck += bestWeight[j] * testTarget;
            testTarget = 0;
        }
    }
    testTarget = Double.parseDouble(test.get(i, test.columnCount() - 1).toString());
    //target is the value of the target for that row
    if(testCheck >= 0 && testTarget == -1){ //calculated target sign != target sign
        Eout++;
        falsePositive++;
    }
    else if(testCheck < 0 && testTarget == 1){//calculated target sign != target sign
        Eout++;
        falseNegative++;
    }
}



System.out.println("There were " + Eout + " Errors.");
System.out.println("There were " + falsePositive + " false positives, and " + falseNegative + " false negatives.");
System.out.println("The out of sample error is: " + ((double) Eout / (double)testSize) * 100 + "%");

There were 291 Errors.
There were 32 false positives, and 259 false negatives.
The out of sample error is: 12.398806987643802%


# Good out of sample error


It looks like we got some good weights. Our out of sample error is very close to our in sample error. We could try some other algorithms as well if we wanted to, but as long as future data is formatted the same way we formatted this data, we'll have developed a good precdictor with the current weights.

We can see which attributes are most important by comparing the magnitude of the weights, heavier weights imply more important attributes.

In [17]:
int finalSize = Modified.columnCount(); //sets finalSize = # of columns

String[] columnNames = new String[finalSize]; // Names of the columns
int startIndex = 0;


for(int i = 0; i < columnNames.length; i++){
    columnNames[i] = Modified.column(i).toString(); //Getting the info of the column in position i of the array
    startIndex = columnNames[i].indexOf(":"); //: is after "type column"
    columnNames[i] = columnNames[i].substring(startIndex + 2, columnNames[i].length()); //start index +2 because there is a space after the :
    //removing everything before the column name
    System.out.println(columnNames[i] + "; " + bestWeight[i]);
}

bmi; 40.0
Age; 25.0
asa_status; -1.0
baseline_cancer; -7.0
baseline_charlson; -6.0
baseline_dementia; -3.0
baseline_diabetes; -3.0
baseline_osteoart; -7.0
baseline_psych; -7.0
ahrq_ccs; 10.0
ccsComplicationRate; 0.33889034000000007
ccsMort30Rate; -0.00841289
complication_rsi; 2.4299999999999997
dow; 8.0
gender; -1.0
hour; -5.0
month; 9.0
moonphase; -7.0
mort30; -3.0
mortality_rsi; 5.0200000000000005
race; -3.0
complication; 3.0


# Closing remarks


We were able to create a pretty good set of weights for the given dataset. We started off looking at how the data was formatted, and converted a few double columns into int columns. We proceeded to see graphs of the data, allowing us to easily look at any given attribute in relation to the target. Finally, we combined data which hit the target function a similar amount of times, and completely removed columns which only had 1 value after the combination. Finally, we gave the columns completely new data, ordering it in terms of probability to have the target function be a 1, and then put half in the negative class and half positive. After completing data manipulation, we ran the data against pocket algorithm to find weights, and calculated our final Ein and Eout.

We can also see which attributes impact the prediction the most by observing the magnitude of the weights.