Skip to content

Commit

Permalink
Altered knn and ee to use predefined neighbours and full test neighbours
Browse files Browse the repository at this point in the history
  • Loading branch information
goastler committed Jun 20, 2019
1 parent 2fad5fd commit 07408d1
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 358 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import weka.core.Instance;
import weka.core.Instances;

import javax.rmi.CORBA.Util;
import java.io.*;
import java.nio.channels.FileChannel;
import java.nio.channels.FileLock;
Expand Down Expand Up @@ -38,28 +39,30 @@ public static void main(String[] args) throws
Instances[] data = sampleDataset(datasetDirPath, datasetName, seed);
Instances train = data[0];
Instances test = data[1];
ElasticEnsemble ee = new ElasticEnsemble();
ee.setSeed(seed);
// Knn cls = new Knn();
ElasticEnsemble cls = new ElasticEnsemble();
cls.setSeed(seed);
double pp = Double.parseDouble(args[4]);
System.out.println(pp);
ee.setNumParameterSetsPercentage(pp);
for(int n = 10; n <= 100; n+=10) { // TODO change!
cls.setNumParametersLimitPercentage(pp);
for(int n = 10; n <= 100; n+=10) {
double np = (double) n / 100;
System.out.println(np + ", " + pp);
String classifierName = "ee_np=" + np + "_pp=" + pp;
// String classifierName = "knn";
String experimentResultsDirPath = resultsDirPath + "/" + classifierName + "/Predictions/" + datasetName;
String trainResultsFilePath = experimentResultsDirPath + "/trainFold" + seed + ".csv";
String testResultsFilePath = experimentResultsDirPath + "/testFold" + seed + ".csv";
ee.setNeighbourhoodSizePercentage(np);
cls.setNeighbourhoodSizeLimitPercentage(np);
boolean trainMissing = !exists(trainResultsFilePath);
boolean testMissing = !exists(testResultsFilePath);
if(trainMissing || testMissing) {
System.out.println("training");
ee.buildClassifier(train);
cls.buildClassifier(train);
}
if(trainMissing) {
System.out.println("getting train results");
ClassifierResults trainResults = ee.getTrainResults();
ClassifierResults trainResults = cls.getTrainResults();
trainResults.setDatasetName(datasetName);
trainResults.setFoldID(seed);
trainResults.setClassifierName(classifierName);
Expand All @@ -69,12 +72,12 @@ public static void main(String[] args) throws
}
if(testMissing) {
System.out.println("getting test results");
ClassifierResults testResults = ee.getTestResults(test);
ClassifierResults testResults = cls.getTestResults(test);
testResults.setDatasetName(datasetName);
testResults.setFoldID(seed);
testResults.setClassifierName(classifierName);
writeToFile(testResults, testResultsFilePath);
ee.resetTestRandom();
cls.resetTestRandom();
} else {
System.out.println("test exists");
}
Expand Down Expand Up @@ -162,14 +165,14 @@ public static void main(String[] args) throws
// Instances train = data[0];
// Instances test = data[1];
// ElasticEnsemble ee = new ElasticEnsemble();
// ee.setNumParameterSetsPercentage(0.1);
// ee.setNumParametersLimitPercentage(0.1);
// ee.setSeed(seed);
// ee.setNeighbourSearchStrategy(Knn.NeighbourSearchStrategy.RANDOM);
// int n = Integer.parseInt(args[4]);
//// n *= train.numClasses();
//// double pp = Double.parseDouble(args[4]);
//// System.out.println(pp);
//// ee.setNumParameterSetsPercentage(pp);
//// ee.setNumParametersLimitPercentage(pp);
//// for(int n = 5; n <= 30; n+=5) {
//// double np = (double) n / 100;
//// System.out.println(np + ", " + pp);
Expand All @@ -179,8 +182,8 @@ public static void main(String[] args) throws
// String experimentResultsDirPath = resultsDirPath + "/" + classifierName + "/Predictions/" + datasetName;
// String trainResultsFilePath = experimentResultsDirPath + "/trainFold" + seed + ".csv";
// String testResultsFilePath = experimentResultsDirPath + "/testFold" + seed + ".csv";
// ee.setNeighbourhoodSize(n);
//// ee.setNeighbourhoodSizePercentage(np);
// ee.setTrainNeighbourhoodSizeLimit(n);
//// ee.setTrainNeighbourhoodSizeLimitPercentage(np);
// boolean trainMissing = !exists(trainResultsFilePath);
// boolean testMissing = !exists(testResultsFilePath);
// if(trainMissing || testMissing) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package classifiers.distance_based.elastic_ensemble.iteration;

import java.util.Collection;
import java.util.Iterator;
import java.util.ListIterator;

Expand Down Expand Up @@ -34,4 +35,10 @@ public void set(final A a) {

@Override
public abstract B iterator();

public void addAll(Collection<A> collection) {
for(A item : collection) {
add(item);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public abstract class AbstractLinearIterator<A, B extends AbstractLinearIterator
protected final List<A> values;
protected int index = 0;

public AbstractLinearIterator(final Collection<? extends A> values) {
public AbstractLinearIterator(final List<A> values) {
this.values = new ArrayList<>(values);
}

Expand All @@ -23,8 +23,8 @@ public AbstractLinearIterator(AbstractLinearIterator<A, B> other) {

@Override
public void remove() {
index--;
values.remove(index);
index--;
if(index < 0) {
index = 0;
}
Expand All @@ -37,7 +37,7 @@ public void add(final A a) {

@Override
public boolean hasNext() {
return !values.isEmpty();
return index < values.size();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package classifiers.distance_based.elastic_ensemble.iteration.linear;

import java.util.Collection;
import java.util.List;

public class LinearIterator<A> extends AbstractLinearIterator<A, LinearIterator<A>> {

public LinearIterator(final Collection<? extends A> values) {
public LinearIterator(final List<A> values) {
super(values);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,35 @@

import classifiers.distance_based.elastic_ensemble.iteration.linear.AbstractLinearIterator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

public abstract class AbstractRandomIterator<A, B extends AbstractRandomIterator<A, B>> extends AbstractLinearIterator<A, AbstractRandomIterator<A, B>> {
protected final Random random;
protected long seed;

public AbstractRandomIterator(final Collection<? extends A> values, final Random random) {
public AbstractRandomIterator(final List<A> values, final long seed) {
super(values);
this.random = random;
this.random = new Random(seed);
}

public AbstractRandomIterator(AbstractRandomIterator<A, B> other) {
this(other.values, other.random);
this(other.values, other.seed);
index = other.index;
}

public AbstractRandomIterator(Random random) {
this.random = random;
public AbstractRandomIterator(long seed) {
this.random = new Random(seed);
}

@Override
public void remove() {
values.remove(index--);
}


@Override
public A next() {
return values.get(index = random.nextInt(values.size()));
index = random.nextInt(values.size());
seed = random.nextLong();
random.setSeed(seed);
return values.get(index);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import classifiers.distance_based.elastic_ensemble.iteration.linear.AbstractLinearIterator;

import java.util.Collection;
import java.util.List;

public abstract class AbstractRoundRobinIterator<A, B extends AbstractRoundRobinIterator<A, B>> extends AbstractLinearIterator<A, AbstractRoundRobinIterator<A, B>>
{

public AbstractRoundRobinIterator(final Collection<? extends A> values) {
public AbstractRoundRobinIterator(final List<A> values) {
super(values);
}

Expand All @@ -20,11 +21,6 @@ public AbstractRoundRobinIterator() {
super();
}

@Override
public void remove() {
values.remove(index--);
}

@Override
public A next() {
return values.get(index = (index + 1) % values.size());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
package classifiers.distance_based.elastic_ensemble.iteration.random;

import java.util.Collection;
import java.util.List;
import java.util.Random;

public class RandomIterator<A> extends AbstractRandomIterator<A, RandomIterator<A>> {

public RandomIterator(final Collection<? extends A> values, final Random random) {
super(values, random);
public RandomIterator(final List<A> values, final long seed) {
super(values, seed);
}

public RandomIterator(RandomIterator<A> other) {
this(other.values, other.random);
this(other.values, other.seed);
index = other.index;
}

public RandomIterator(Random random) {
super(random);
public RandomIterator(long seed) {
super(seed);
}

@Override
public RandomIterator<A> iterator() {
return new RandomIterator<>(this);
}

@Override
public boolean hasNext() {
return !values.isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package classifiers.distance_based.elastic_ensemble.iteration.random;

import java.util.Collection;
import java.util.List;
import java.util.Random;

public class RoundRobinIterator<A> extends AbstractRoundRobinIterator<A, RoundRobinIterator<A>> {

public RoundRobinIterator(final Collection<? extends A> values) {
public RoundRobinIterator(final List<A> values) {
super(values);
}

Expand Down
Loading

0 comments on commit 07408d1

Please sign in to comment.