Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[WIP] Issue #144 - Re-implement classifiers using SMILE library #146

Merged
merged 11 commits into from Jan 8, 2018
5 changes: 3 additions & 2 deletions build.gradle
Expand Up @@ -39,7 +39,8 @@ dependencies {
compile 'org.apache.tika:tika-parsers:1.14'
compile 'com.syncthemall:boilerpipe:1.2.2'
compile 'net.sourceforge.nekohtml:nekohtml:1.9.22'
compile 'nz.ac.waikato.cms.weka:weka-stable:3.6.13'
//compile 'nz.ac.waikato.cms.weka:weka-stable:3.6.13'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remember to remove this commented line.

compile 'com.github.haifengl:smile-core:1.5.0'
compile 'org.apache.lucene:lucene-core:4.10.4'
compile 'org.elasticsearch:elasticsearch:1.4.4'
compile 'org.elasticsearch.client:elasticsearch-rest-client:5.6.3'
Expand All @@ -55,7 +56,7 @@ dependencies {
exclude group: 'org.apache.hadoop', module: 'hadoop-core'
exclude group: 'junit', module: 'junit'
}

// REST server dependencies
compile "com.sparkjava:spark-core:2.5.3"

Expand Down
21 changes: 14 additions & 7 deletions src/main/java/focusedCrawler/Main.java
@@ -1,6 +1,7 @@
package focusedCrawler;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
Expand All @@ -14,7 +15,7 @@
import focusedCrawler.link.frontier.FrontierManagerFactory;
import focusedCrawler.rest.RestServer;
import focusedCrawler.seedfinder.SeedFinder;
import focusedCrawler.target.classifier.WekaTargetClassifierBuilder;
import focusedCrawler.target.classifier.SmileTargetClassifierBuilder;
import focusedCrawler.tools.StartRestServer;
import io.airlift.airline.Arguments;
import io.airlift.airline.Cli;
Expand Down Expand Up @@ -143,7 +144,7 @@ public void run() {

}

@Command(name = "buildModel", description = "Builds a model for a Weka target classifier")
@Command(name = "buildModel", description = "Builds a model for a Smile target classifier")
public static class BuildModel implements Runnable {

@Option(name = {"-t", "--trainingDataDir"}, required = true, description = "Path to folder containing training data")
Expand All @@ -155,25 +156,31 @@ public static class BuildModel implements Runnable {
@Option(name = {"-c", "--stopWordsFile"}, required = false, description = "Path to stopwords file")
String stopWordsFile;

@Option(name = {"-l", "--learner"}, required = false, description = "Machine-learning algorithm to be used to train the model (SMO, RandomForest)")
@Option(name = {"-l", "--learner"}, required = false, description = "Machine-learning algorithm to be used to train the model (SVM, RandomForest)")
String learner;

@Override
public void run() {

new File(outputPath).mkdirs();

// generate the input for weka
// generate the input for smile
System.out.println("Preparing training data...");
WekaTargetClassifierBuilder.createInputFile(stopWordsFile, trainingPath, trainingPath + "/weka.arff" );
SmileTargetClassifierBuilder.createInputFile(stopWordsFile, trainingPath, trainingPath + "/smile_input.arff" );

// generate the model
System.out.println("Training model...");
WekaTargetClassifierBuilder.trainModel(trainingPath, outputPath, learner);
try {
SmileTargetClassifierBuilder.trainModel(trainingPath, outputPath, learner);
} catch (IOException | java.text.ParseException e) {
System.out.printf("Failed to build model.\n\n");
e.printStackTrace(System.out);
System.exit(1);
}

// generate features file
System.out.println("Creating feature file...");
WekaTargetClassifierBuilder.createFeaturesFile(outputPath,trainingPath);
SmileTargetClassifierBuilder.createFeaturesFile(outputPath,trainingPath);

System.out.println("done.");
}
Expand Down
56 changes: 11 additions & 45 deletions src/main/java/focusedCrawler/link/classifier/LNClassifier.java
@@ -1,31 +1,26 @@
package focusedCrawler.link.classifier;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.util.Iterator;
import java.util.Map;

import focusedCrawler.link.classifier.builder.Instance;
import focusedCrawler.link.classifier.builder.LinkNeighborhoodWrapper;
import focusedCrawler.util.ParameterFile;
import focusedCrawler.util.SmileUtil;
import focusedCrawler.util.parser.LinkNeighborhood;
import focusedCrawler.util.string.StopList;
import weka.classifiers.Classifier;
import weka.core.Instances;
import smile.classification.SoftClassifier;

public class LNClassifier {

private final Classifier classifier;
private final Instances instances;
private final SoftClassifier<double[]> classifier;
private final LinkNeighborhoodWrapper wrapper;
private final String[] attributes;

public LNClassifier(Classifier classifier, Instances instances,
LinkNeighborhoodWrapper wrapper, String[] attributes) {

/**
 * Creates a link classifier backed by a SMILE soft classifier.
 *
 * @param classifier trained SMILE model used to score link feature vectors
 * @param wrapper    converts link neighborhoods into feature instances
 * @param attributes names of the features the model expects, in order
 */
public LNClassifier(SoftClassifier<double[]> classifier, LinkNeighborhoodWrapper wrapper,
        String[] attributes) {
    this.attributes = attributes;
    this.wrapper = wrapper;
    this.classifier = classifier;
}
Expand All @@ -37,10 +32,9 @@ public double[] classify(LinkNeighborhood ln) throws Exception {
Instance instance = (Instance)urlWords.get(url);
double[] values = instance.getValues();
synchronized (classifier) {
weka.core.Instance instanceWeka = new weka.core.Instance(1, values);
instanceWeka.setDataset(instances);
double[] probs = classifier.distributionForInstance(instanceWeka);
return probs;
double[] prob = new double[2];
int predictedValue = classifier.predict(values, prob);
return prob;
}
}

Expand All @@ -55,37 +49,9 @@ public static LNClassifier create(String featureFilePath,

public static LNClassifier create(String[] attributes, String[] classValues,
String modelFilePath, StopList stoplist) {
weka.core.FastVector vectorAtt = new weka.core.FastVector();
for (int i = 0; i < attributes.length; i++) {
vectorAtt.addElement(new weka.core.Attribute(attributes[i]));
}
weka.core.FastVector classAtt = new weka.core.FastVector();
for (int i = 0; i < classValues.length; i++) {
classAtt.addElement(classValues[i]);
}
vectorAtt.addElement(new weka.core.Attribute("class", classAtt));
Instances insts = new Instances("link_classification", vectorAtt, 1);
insts.setClassIndex(attributes.length);

LinkNeighborhoodWrapper wrapper = new LinkNeighborhoodWrapper(attributes, stoplist);

Classifier classifier = loadWekaClassifier(modelFilePath);

return new LNClassifier(classifier, insts, wrapper, attributes);

SoftClassifier<double[]> classifier = SmileUtil.loadSmileClassifier(modelFilePath);
return new LNClassifier(classifier, wrapper, attributes);
}

private static Classifier loadWekaClassifier(String modelFilePath) {
try {
InputStream is = new FileInputStream(modelFilePath);
ObjectInputStream objectInputStream = new ObjectInputStream(is);
Classifier classifier = (Classifier) objectInputStream.readObject();
objectInputStream.close();
return classifier;
} catch (IOException | ClassNotFoundException e) {
throw new IllegalArgumentException(
"Failed to load weka classifier instance from file: " + modelFilePath, e);
}
}

}
Expand Up @@ -10,21 +10,19 @@
import focusedCrawler.link.frontier.LinkRelevance;
import focusedCrawler.target.model.Page;
import focusedCrawler.util.parser.LinkNeighborhood;
import weka.classifiers.Classifier;
import weka.core.Instances;
import smile.classification.SoftClassifier;
import smile.classification.SVM;

public class LinkClassifierAuthority implements LinkClassifier{

private LinkNeighborhoodWrapper wrapper;
private String[] attributes;
private Classifier classifier;
private Instances instances;
private SoftClassifier<double[]> classifier;

public LinkClassifierAuthority(Classifier classifier, Instances instances, LinkNeighborhoodWrapper wrapper,String[] attributes) {
/**
 * Creates an authority-link classifier backed by a SMILE soft classifier.
 *
 * @param classifier trained SMILE model used to score candidate links
 * @param wrapper    converts link neighborhoods into feature instances
 * @param attributes names of the features the model expects, in order
 */
public LinkClassifierAuthority(SoftClassifier<double[]> classifier, LinkNeighborhoodWrapper wrapper,String[] attributes) {
    this.classifier = classifier;
    this.attributes = attributes;
    this.wrapper = wrapper;
}

public LinkClassifierAuthority() {
Expand Down Expand Up @@ -53,9 +51,8 @@ public LinkRelevance[] classify(Page page) throws LinkClassifierException {
if(!page.getURL().getHost().equals(url.getHost())){
Instance instance = entry.getValue();
double[] values = instance.getValues();
weka.core.Instance instanceWeka = new weka.core.Instance(1, values);
instanceWeka.setDataset(instances);
double[] prob = classifier.distributionForInstance(instanceWeka);
double[] prob = new double[2];
int predictedValue = classifier.predict(values, prob);
relevance = LinkRelevance.DEFAULT_AUTH_RELEVANCE + (prob[0]*100);
}
linkRelevance[count] = new LinkRelevance(url, relevance);
Expand Down Expand Up @@ -90,9 +87,8 @@ public LinkRelevance classify(LinkNeighborhood ln) throws LinkClassifierExceptio
if(classifier != null){
Instance instance = (Instance) entry.getValue();
double[] values = instance.getValues();
weka.core.Instance instanceWeka = new weka.core.Instance(1, values);
instanceWeka.setDataset(instances);
double[] prob = classifier.distributionForInstance(instanceWeka);
double[] prob = new double[2];
int predictedValue = ((SVM<double[]>) classifier).predict(values, prob);
if(prob[0] == 1){
prob[0] = 0.99;
}
Expand Down
Expand Up @@ -9,8 +9,9 @@
import focusedCrawler.link.frontier.LinkRelevance;
import focusedCrawler.target.model.Page;
import focusedCrawler.util.parser.LinkNeighborhood;
import weka.classifiers.Classifier;
import weka.core.Instances;
import smile.classification.SoftClassifier;
import smile.classification.SVM;


/**
* This class implements the link classifier for the hub links.
Expand All @@ -19,18 +20,16 @@
*/
public class LinkClassifierHub implements LinkClassifier{

private Classifier classifier;
private Instances instances;
private SoftClassifier<double[]> classifier;
private LinkNeighborhoodWrapper wrapper;
private String[] attributes;

// No-arg constructor: classifier, wrapper and attributes stay null until
// assigned elsewhere — NOTE(review): callers presumably check for null before
// classifying; confirm against the classify() implementation.
public LinkClassifierHub(){

}

public LinkClassifierHub(Classifier classifier, Instances instances, LinkNeighborhoodWrapper wrapper,String[] attributes) {
/**
 * Creates a hub-link classifier backed by a SMILE soft classifier.
 *
 * @param classifier trained SMILE model used to score hub links
 * @param wrapper    converts link neighborhoods into feature instances
 * @param attributes names of the features the model expects, in order
 */
public LinkClassifierHub(SoftClassifier<double[]> classifier, LinkNeighborhoodWrapper wrapper,String[] attributes) {
    this.attributes = attributes;
    this.classifier = classifier;
    this.wrapper = wrapper;
}
Expand All @@ -47,9 +46,8 @@ public LinkRelevance classify(LinkNeighborhood ln) throws LinkClassifierExceptio
String url = (String)iter.next();
Instance instance = (Instance)urlWords.get(url);
double[] values = instance.getValues();
weka.core.Instance instanceWeka = new weka.core.Instance(1, values);
instanceWeka.setDataset(instances);
double[] prob = classifier.distributionForInstance(instanceWeka);
double[] prob = new double[2];
int predictedValue = ((SVM<double[]>)classifier).predict(values, prob);
double relevance = LinkRelevance.DEFAULT_HUB_RELEVANCE + prob[0]*100;
result = new LinkRelevance(ln.getLink(),relevance);
}
Expand Down