Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 53 additions & 30 deletions aima-core/src/main/java/aima/core/learning/framework/DataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import aima.core.util.Util;

Expand Down Expand Up @@ -39,31 +41,39 @@ public Example getExample(int number) {

public DataSet removeExample(Example e) {
DataSet ds = new DataSet(specification);
for (Example eg : examples) {
if (!(e.equals(eg))) {
ds.add(eg);

// We stream the examples, filter the elements that match the given
// example so we don't get them, then we loop over the result (not filtered elements)
// and add them to the DataSet (ds) to return it afterwards
examples.stream().filter(example -> {
if (e.equals(example)) {
return false;
}
}
return true;
}).forEach(example -> ds.add(example));

return ds;
}

public double getInformationFor() {
String attributeName = specification.getTarget();
Hashtable<String, Integer> counts = new Hashtable<String, Integer>();
for (Example e : examples) {

String val = e.getAttributeValueAsString(attributeName);
examples.stream().forEach(example -> {
String val = example.getAttributeValueAsString(attributeName);
if (counts.containsKey(val)) {
counts.put(val, counts.get(val) + 1);
} else {
counts.put(val, 1);
}
}
});

// Consider avoiding primitive data types, use wrappers instead to allow the usage
// of useful JDK8 features
double[] data = new double[counts.keySet().size()];
Iterator<Integer> iter = counts.values().iterator();
Iterator<Integer> iterator = counts.values().iterator();
for (int i = 0; i < data.length; i++) {
data[i] = iter.next();
data[i] = iterator.next();
}
data = Util.normalize(data);

Expand All @@ -72,30 +82,36 @@ public double getInformationFor() {

public Hashtable<String, DataSet> splitByAttribute(String attributeName) {
Hashtable<String, DataSet> results = new Hashtable<String, DataSet>();
for (Example e : examples) {
String val = e.getAttributeValueAsString(attributeName);

examples.stream().forEach(example -> {
String val = example.getAttributeValueAsString(attributeName);
if (results.containsKey(val)) {
results.get(val).add(e);
results.get(val).add(example);
} else {
DataSet ds = new DataSet(specification);
ds.add(e);
ds.add(example);
results.put(val, ds);
}
}
});

return results;
}

/**
 * Computes the gain obtained by splitting this data set on the given
 * attribute: the information of the whole set minus the size-weighted sum
 * of the information of each partition.
 *
 * @param parameterName the attribute to split on
 * @return {@code getInformationFor()} minus the weighted remainder
 */
public double calculateGainFor(String parameterName) {
    Hashtable<String, DataSet> hash = splitByAttribute(parameterName);
    double totalSize = examples.size();
    double remainder = 0.0;
    // A plain accumulator keeps the summation order and floating-point
    // result simple; no holder object is needed outside a lambda.
    for (DataSet subset : hash.values()) {
        double reducedDataSetSize = subset.examples.size();
        remainder += (reducedDataSetSize / totalSize) * subset.getInformationFor();
    }
    return getInformationFor() - remainder;
}

@Override
Expand All @@ -121,9 +137,11 @@ public Iterator<Example> iterator() {

/**
 * Creates a new data set with this data set's specification and passes
 * each of this set's examples to {@code add} on the new set.
 *
 * @return the newly built {@code DataSet}
 */
public DataSet copy() {
    DataSet ds = new DataSet(specification);
    examples.forEach(ds::add);
    return ds;
}

Expand Down Expand Up @@ -154,12 +172,17 @@ public List<String> getPossibleAttributeValues(String attributeName) {

public DataSet matchingDataSet(String attributeName, String attributeValue) {
DataSet ds = new DataSet(specification);
for (Example e : examples) {
if (e.getAttributeValueAsString(attributeName).equals(
attributeValue)) {
ds.add(e);

// We stream the examples, don't filter the elements that match the given
// attributeName and attributeValue so we get them, then we loop over the result (not filtered elements)
// and add them to the DataSet (ds) to return it afterwards
examples.stream().filter(example -> {
if (example.getAttributeValueAsString(attributeName).equals(attributeValue)) {
return false;
}
}
return true;
}).forEach(example -> ds.add(example));

return ds;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package aima.core.learning.learners;

import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.Example;
Expand Down Expand Up @@ -43,7 +46,7 @@ public DecisionTreeLearner(DecisionTree tree, String defaultValue) {
public void train(DataSet ds) {
List<String> attributes = ds.getNonTargetAttributes();
this.tree = decisionTreeLearning(ds, attributes,
new ConstantDecisonTree(defaultValue));
new ConstantDecisonTree(defaultValue));
}

@Override
Expand All @@ -55,13 +58,14 @@ public String predict(Example e) {
public int[] test(DataSet ds) {
int[] results = new int[] { 0, 0 };

for (Example e : ds.examples) {
if (e.targetValue().equals(tree.predict(e))) {
ds.examples.stream().forEach(example -> {
if (example.targetValue().equals(tree.predict(example))) {
results[0] = results[0] + 1;
} else {
results[1] = results[1] + 1;
}
}
});

return results;
}

Expand Down Expand Up @@ -98,14 +102,13 @@ private DecisionTree decisionTreeLearning(DataSet ds,
ConstantDecisonTree m = majorityValue(ds);

List<String> values = ds.getPossibleAttributeValues(chosenAttribute);
for (String v : values) {
DataSet filtered = ds.matchingDataSet(chosenAttribute, v);
List<String> newAttribs = Util.removeFrom(attributeNames,
chosenAttribute);
DecisionTree subTree = decisionTreeLearning(filtered, newAttribs, m);
tree.addNode(v, subTree);

}
values.stream().forEach(value -> {
DataSet filtered = ds.matchingDataSet(chosenAttribute, value);
List<String> newAttribs = Util.removeFrom(attributeNames, chosenAttribute);
DecisionTree subTree = decisionTreeLearning(filtered, newAttribs, m);
tree.addNode(value, subTree);
});

return tree;
}
Expand All @@ -117,29 +120,25 @@ private ConstantDecisonTree majorityValue(DataSet ds) {
}

/**
 * Picks the attribute with the greatest gain on the given data set.
 * The first attribute is the fallback when no attribute has a gain above
 * zero, and the earliest attribute wins a tie.
 *
 * @param ds             the data set on which the gain is calculated
 * @param attributeNames the candidate attributes; the first element is
 *                       read unconditionally, so the list must not be empty
 * @return the name of the chosen attribute
 */
private String chooseAttribute(DataSet ds, List<String> attributeNames) {
    double greatestGain = 0.0;
    String attributeWithGreatestGain = attributeNames.get(0);
    for (String attr : attributeNames) {
        // Compute the gain exactly once per attribute; each call splits the
        // whole data set.
        double gain = ds.calculateGainFor(attr);
        if (gain > greatestGain) {
            greatestGain = gain;
            attributeWithGreatestGain = attr;
        }
    }
    return attributeWithGreatestGain;
}

/**
 * Tells whether every example of the data set has the same target value as
 * its first example.
 *
 * @param ds the data set to inspect; its example at index 0 is read
 *           unconditionally
 * @return {@code true} if no example's target value differs from the first one's
 */
private boolean allExamplesHaveSameClassification(DataSet ds) {
    String classification = ds.getExample(0).targetValue();
    return ds.examples.stream()
            .allMatch(example -> example.targetValue().equals(classification));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ public void testInducedTreeClassifiesDataSetCorrectly() throws Exception {
DecisionTreeLearner learner = new DecisionTreeLearner();
learner.train(ds);
int[] result = learner.test(ds);
// The learner is tested on the same data set it was trained on, so all 12
// examples must be classified correctly and none incorrectly.
Assert.assertEquals(12, result[0]);
Assert.assertEquals(0, result[1]);
}

@Test
Expand Down