Skip to content

Commit

Permalink
Switching network structure, cleaning out some logic. print out test …
Browse files Browse the repository at this point in the history
…results to see how much money would be won with our network - it's not a lot
  • Loading branch information
Bob Murrell committed Mar 7, 2019
1 parent c227170 commit 0d7fb79
Showing 1 changed file with 85 additions and 35 deletions.
120 changes: 85 additions & 35 deletions src/main/java/com/secondline/lotto/LottoNN.java
@@ -1,22 +1,17 @@
package com.secondline.lotto; package com.secondline.lotto;


import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;


import com.google.common.primitives.Doubles; import com.google.common.primitives.Doubles;
import java.util.Comparator;
import java.util.List; import java.util.List;


import org.neuroph.core.NeuralNetwork; import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet; import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow; import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent; import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener; import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron; import org.neuroph.nnet.JordanNetwork;
import org.neuroph.nnet.learning.BackPropagation; import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.util.data.norm.MaxNormalizer;
import org.neuroph.util.data.norm.Normalizer;

import com.secondline.lotto.util.DataUtil; import com.secondline.lotto.util.DataUtil;


public class LottoNN { public class LottoNN {
Expand All @@ -26,24 +21,21 @@ public static void main(String[] args) {


DataSet dataSet = createDataSet(dataRows); DataSet dataSet = createDataSet(dataRows);
dataSet.shuffle(); dataSet.shuffle();


//Normalizing data set DataSet[] sets = dataSet.createTrainingAndTestSubsets(65, 35);
Normalizer normalizer = new MaxNormalizer();
normalizer.normalize(dataSet);

DataSet[] sets = dataSet.createTrainingAndTestSubsets(60, 40);
DataSet trainSet = sets[0]; DataSet trainSet = sets[0];
trainSet.shuffle();
DataSet testSet = sets[1]; DataSet testSet = sets[1];


int inputCount = 6; // 5 numbers + powerball int inputCount = 6; // 5 numbers + powerball
int outputCount = 69 + 26; // all numbers + powerball options int outputCount = 69 + 26; // all numbers + powerball options


MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(inputCount, 12, 12, 20, outputCount); JordanNetwork neuralNet = new JordanNetwork(inputCount, 36, 64, outputCount);
BackPropagation learningRule = neuralNet.getLearningRule(); BackPropagation learningRule = new BackPropagation();


learningRule.setLearningRate(0.5); learningRule.setLearningRate(0.004);
learningRule.setMaxError(0.001); learningRule.setMaxError(0.05);
learningRule.setMaxIterations(2500); learningRule.setMaxIterations(3000);


// add learning listener in order to print out training info // add learning listener in order to print out training info
learningRule.addListener(new LearningEventListener() { learningRule.addListener(new LearningEventListener() {
Expand All @@ -54,13 +46,14 @@ public void handleLearningEvent(LearningEvent event) {
System.out.println(); System.out.println();
System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations"); System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations");
System.out.println("With total error " + bp.getTotalNetworkError() + '\n'); System.out.println("With total error " + bp.getTotalNetworkError() + '\n');
} else { } else if((bp.getCurrentIteration() % 100) == 0){
System.out.println("Iteration: " + bp.getCurrentIteration() + " | Network error: " System.out.println("Iteration: " + bp.getCurrentIteration() + " | Network error: "
+ bp.getTotalNetworkError()); + bp.getTotalNetworkError());
} }
} }


}); });
neuralNet.setLearningRule(learningRule);


// train neural network // train neural network
neuralNet.learn(trainSet); neuralNet.learn(trainSet);
Expand Down Expand Up @@ -89,20 +82,20 @@ private static DataSetRow createDataSetRow(String trainRow) {
return null; return null;


double[] inputs = new double[6]; double[] inputs = new double[6];
double[] outputs= new double[95]; double[] outputs = new double[95];


//first 6 columns of row are input layer values // first 6 columns of row are input layer values
for(int i = 0; i < 6; ++i){ for (int i = 0; i < 6; ++i) {
inputs[i] = Double.valueOf(elements[i]); inputs[i] = Double.valueOf(elements[i]);
} }


// next 5 columns are desired output numbers // next 5 columns are desired output numbers
for(int i = 6; i < 11; ++i){ for (int i = 6; i < 11; ++i) {
//index of the output layer equals the lotto number // index of the output layer equals the lotto number
outputs[Integer.valueOf(elements[i]) -1] = 1; outputs[Integer.valueOf(elements[i]) - 1] = 1;
} }
//last column is desired output powerball // last column is desired output powerball
outputs[68+ Integer.valueOf(elements[11])] = 1; outputs[68 + Integer.valueOf(elements[11])] = 1;


return new DataSetRow(inputs, outputs); return new DataSetRow(inputs, outputs);
} }
Expand All @@ -120,32 +113,70 @@ public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {


System.out.println("--------------------------------------------------------------------"); System.out.println("--------------------------------------------------------------------");
System.out.println("***********************TESTING NEURAL NETWORK***********************"); System.out.println("***********************TESTING NEURAL NETWORK***********************");
int totalWinnings = 0, count = 0;
for (DataSetRow testSetRow : testSet.getRows()) { for (DataSetRow testSetRow : testSet.getRows()) {
neuralNet.setInput(testSetRow.getInput()); neuralNet.setInput(testSetRow.getInput());
neuralNet.calculate(); neuralNet.calculate();


int[] predictions = maxOutputs(neuralNet.getOutput()); int[] predictions = maxOutputs(neuralNet.getOutput());
System.out.println(printArray(testSetRow.getInput()) + "----->" + printIntArray(predictions)); double[] expected = testSetRow.getDesiredOutput();
int amountWon = getWinningAmount(predictions, expected);
System.out.println(printArray(testSetRow.getInput()) + ", guessed " + printIntArray(predictions)+ ", Expected: " + printIntArray(expected(expected)) + ", wins $" + amountWon);
System.out.println(""); System.out.println("");
totalWinnings += amountWon;
count++;
} }
System.out.println("***********************DONE TESTING NEURAL NETWORK***********************");
System.out.println("Total winnings: $"+totalWinnings+", Avg winnings: $" + (totalWinnings / count));
} }


private static String printArray(double[] object){ private static int getWinningAmount(int[] predictions, double[] expected) {
final int[] winningNumbers = expected(expected);

int matchingNumbers = 0;
for (int i = 0; i < winningNumbers.length - 1; ++i) {
int winningNumber = winningNumbers[i];
for (int j = 0; j < predictions.length - 1; ++j) {
int guessedNumber = predictions[j];
if (guessedNumber == winningNumber)
matchingNumbers++;
}
}
boolean powerballCorrect = winningNumbers[5] == predictions[5];

switch (matchingNumbers) {
case 1:
return powerballCorrect ? 4 : 0;
case 2:
return powerballCorrect ? 7 : 0;
case 3:
return powerballCorrect ? 100 : 7;
case 4:
return powerballCorrect ? 50000 : 100;
case 5:
return powerballCorrect ? 100000000 : 1000000;
default:
return powerballCorrect ? 4 : 0;

}
}

private static String printArray(double[] object) {
String result = ""; String result = "";
for(double o : object){ for (double o : object) {
result += o+"-"; result += o + "-";
} }
return result; return result;
} }

private static String printIntArray(int[] object){ private static String printIntArray(int[] object) {
String result = ""; String result = "";
for(int o : object){ for (int o : object) {
result += o+"-"; result += o + "-";
} }
return result; return result;
} }

public static int[] maxOutputs(double[] array) { public static int[] maxOutputs(double[] array) {


int[] result = new int[6]; int[] result = new int[6];
Expand All @@ -162,4 +193,23 @@ public static int[] maxOutputs(double[] array) {
result[5] = 95 - sorted.indexOf(sortedNumbers[94]); result[5] = 95 - sorted.indexOf(sortedNumbers[94]);
return result; return result;
} }

public static int[] expected(double[] array) {

int[] result = new int[6];
int count = 0;
for(int i = 0; i < 69; ++i){
if(array[i] > 0.1){
result[count] = i+1;
count++;
}
}
for(int i = 69; i < 95; ++ i){
if(array[i] > 0.1){
result[5] = 95 - i;
break;
}
}
return result;
}
} }

0 comments on commit 0d7fb79

Please sign in to comment.