Skip to content

Commit

Permalink
Add completion mode and stats for more prediction insight
Browse files Browse the repository at this point in the history
  • Loading branch information
VHellendoorn committed Jun 1, 2018
1 parent 37597f4 commit f6c1154
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 5 deletions.
10 changes: 5 additions & 5 deletions .classpath
Expand Up @@ -6,19 +6,19 @@
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/J2SE-1.5">
<classpathentry kind="src" output="target/test-classes" path="src/test/java">
<attributes>
<attribute name="optional" value="true"/>
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER">
<attributes>
<attribute name="maven.pomderived" value="true"/>
<attribute name="module" value="true"/>
</attributes>
</classpathentry>
<classpathentry kind="src" output="target/test-classes" path="src/test/java">
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
<attributes>
<attribute name="optional" value="true"/>
<attribute name="maven.pomderived" value="true"/>
</attributes>
</classpathentry>
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/slp/core/modeling/runners/Completion.java
@@ -0,0 +1,40 @@
package slp.core.modeling.runners;

import java.util.List;

import slp.core.util.Pair;

/**
 * Holds the candidate completions produced for one token position, optionally together
 * with the index of the token that actually occurred at that position.
 */
public class Completion {
    /** Index of the token that really occurred; {@code null} when it is not known. */
    private final Integer realIx;
    /** Candidate (index, value) pairs, kept in the order they were supplied. */
    private final List<Pair<Integer, Double>> completions;

    /**
     * Creates a completion for which the real token is known.
     *
     * @param realIx      index of the token that actually occurred (may be {@code null})
     * @param predictions candidate pairs; the list is stored by reference, not copied
     */
    public Completion(Integer realIx, List<Pair<Integer, Double>> predictions) {
        this.realIx = realIx;
        this.completions = predictions;
    }

    /**
     * Creates a completion without a known real token.
     *
     * @param predictions candidate pairs; the list is stored by reference, not copied
     */
    public Completion(List<Pair<Integer, Double>> predictions) {
        this(null, predictions);
    }

    /** @return the candidate list exactly as passed to the constructor */
    public List<Pair<Integer, Double>> getPredictions() {
        return completions;
    }

    /** @return the index of the real token, or {@code null} if none was given */
    public Integer getRealIx() {
        return realIx;
    }

    /**
     * Finds the zero-based position of the real token in the candidate list.
     *
     * @return the position of the first candidate whose index equals the real index, or
     *         {@code -1} when no real index is set or it is not among the candidates
     */
    public int getRank() {
        if (this.realIx != null) {
            int position = 0;
            for (Pair<Integer, Double> candidate : this.completions) {
                if (candidate.left.equals(this.realIx)) {
                    return position;
                }
                position++;
            }
        }
        return -1;
    }
}
81 changes: 81 additions & 0 deletions src/main/java/slp/core/modeling/runners/ModelRunner.java
Expand Up @@ -312,6 +312,80 @@ private void logPredictionProgress(List<Double> modeled) {
}
}

/**
 * Lexes everything under {@code file} and pairs each lexed file with the completions
 * computed for its tokens.
 * <p>
 * Before building the stream, the MRR accumulator is reset to zero and {@code modelStats}
 * is re-initialised to a zero count plus the negated current time in milliseconds. The
 * model notification and the completion work sit in a {@code map} stage, so they run as
 * the returned stream is consumed.
 *
 * @param file root passed to the lexer runner's {@code lexDirectory}
 * @return stream of (file, completions) pairs
 */
public Stream<Pair<File, List<List<Completion>>>> completeDirectory(File file) {
    this.mrr = 0.0;
    this.modelStats = new long[] { 0, -System.currentTimeMillis() };
    return this.lexerRunner.lexDirectory(file).map(lexedFile -> {
        // Tell the model which file is coming up before completing its tokens.
        this.model.notify(lexedFile.left);
        List<List<Completion>> fileCompletions = this.completeTokens(lexedFile.right);
        return Pair.of(lexedFile.left, fileCompletions);
    });
}

/**
 * Computes completions for a single file.
 *
 * @param f the file to lex and complete
 * @return the completions for the file's tokens, or {@code null} when the lexer runner
 *         reports that it will not lex this file
 */
public List<List<Completion>> completeFile(File f) {
    if (this.lexerRunner.willLexFile(f)) {
        // Tell the model which file is coming up before completing its tokens.
        this.model.notify(f);
        return completeTokens(this.lexerRunner.lexFile(f));
    }
    return null;
}

/**
 * Lexes the given text with the lexer runner and computes completions for the result.
 * Unlike the file-based variants, the model is not notified of any file.
 *
 * @param content raw text to lex
 * @return the completions produced by {@code completeTokens} for the lexed text
 */
public List<List<Completion>> completeContent(String content) {
    Stream<Stream<String>> lexedText = this.lexerRunner.lexText(content);
    return completeTokens(lexedText);
}

/**
 * Computes completions for lexed input, one inner list per input line.
 * <p>
 * When the lexer runner works per line, every line is translated to vocabulary indices,
 * completed on its own and logged on its own. Otherwise all lines are concatenated into
 * one token sequence, completed in a single call, split back into lines using the
 * recorded line lengths, and logged once.
 *
 * @param lexed stream of lines, each a stream of token strings
 * @return completions grouped per line
 */
public List<List<Completion>> completeTokens(Stream<Stream<String>> lexed) {
    if (!this.lexerRunner.isPerLine()) {
        // Remember each line's length so the flat result can be regrouped afterwards.
        List<Integer> lengthPerLine = new ArrayList<>();
        List<Integer> allTokens = lexed
                .map(this.getVocabulary()::toIndices)
                .map(line -> line.collect(Collectors.toList()))
                .peek(line -> lengthPerLine.add(line.size()))
                .flatMap(line -> line.stream())
                .collect(Collectors.toList());
        List<Completion> flatCompletions = completeSequence(allTokens);
        List<List<Completion>> regrouped = toLines(flatCompletions, lengthPerLine);
        logCompletionProgress(flatCompletions);
        return regrouped;
    }
    return lexed
            .map(this.getVocabulary()::toIndices)
            .map(line -> line.collect(Collectors.toList()))
            .map(line -> completeSequence(line))
            .peek(this::logCompletionProgress)
            .collect(Collectors.toList());
}

/**
 * Produces one {@link Completion} per prediction returned by the model for {@code tokens}.
 * <p>
 * In self-testing mode the sequence is forgotten by the model before predicting and
 * learned again afterwards. For each position, the model's predictions are converted with
 * {@code toProb}, sorted by descending value, truncated to
 * {@code GLOBAL_PREDICTION_CUTOFF} entries and stored together with the token at that
 * position.
 *
 * @param tokens vocabulary indices of the sequence to complete
 * @return completions in sequence order
 */
protected List<Completion> completeSequence(List<Integer> tokens) {
    if (this.selfTesting) {
        this.model.forget(tokens);
    }
    List<Map<Integer, Pair<Double, Double>>> predictions = this.model.predict(tokens);
    if (this.selfTesting) {
        this.model.learn(tokens);
    }

    List<Completion> ranked = new ArrayList<>();
    for (int position = 0; position < predictions.size(); position++) {
        List<Pair<Integer, Double>> topCandidates = predictions.get(position).entrySet().stream()
                .map(entry -> Pair.of(entry.getKey(), toProb(entry.getValue())))
                // Highest value first.
                .sorted((a, b) -> Double.compare(b.right, a.right))
                .limit(GLOBAL_PREDICTION_CUTOFF)
                .collect(Collectors.toList());
        ranked.add(new Completion(tokens.get(position), topCandidates));
    }
    return ranked;
}

/**
 * Adds the given completions to the running statistics and prints a progress line when
 * the processed-token count crosses a multiple of {@code MODEL_PRINT_INTERVAL}.
 * <p>
 * The first completion of the list is left out of the statistics. {@code modelStats[0]}
 * holds the running count; {@code modelStats[1]} is added to the current time to obtain
 * the elapsed time, and nothing is printed while it is zero.
 *
 * @param completions completions of one sequence
 */
private void logCompletionProgress(List<Completion> completions) {
    DoubleSummaryStatistics reciprocalRanks = completions.stream()
            .skip(1)
            .mapToDouble(c -> toMRR(c.getRank()))
            .summaryStatistics();

    long countBefore = this.modelStats[0];
    this.modelStats[0] += reciprocalRanks.getCount();
    this.mrr += reciprocalRanks.getSum();

    boolean crossedInterval =
            this.modelStats[0] / this.MODEL_PRINT_INTERVAL > countBefore / this.MODEL_PRINT_INTERVAL;
    if (!crossedInterval || this.modelStats[1] == 0) {
        return;
    }

    long thousandsProcessed = Math.round(this.modelStats[0] / 1e3);
    long elapsedSeconds = (System.currentTimeMillis() + this.modelStats[1]) / 1000;
    double averageMrr = this.mrr / this.modelStats[0];
    System.out.printf("Predicting: %dK tokens processed in %ds, avg. MRR: %.4f\n",
            thousandsProcessed, elapsedSeconds, averageMrr);
}

/**
 * Applies the single-pair {@code toProb} overload to every element of the list.
 *
 * @param probConfs pairs to convert
 * @return a new list with one converted value per input pair, in input order
 */
public List<Double> toProb(List<Pair<Double, Double>> probConfs) {
    List<Double> converted = new ArrayList<>();
    for (Pair<Double, Double> probConf : probConfs) {
        converted.add(toProb(probConf));
    }
    return converted;
}
Expand Down Expand Up @@ -381,4 +455,11 @@ private DoubleSummaryStatistics getFileStats(Stream<List<List<Double>>> fileProb
.mapToDouble(p -> p).summaryStatistics();
}
}

/**
 * Summarises completion quality for one set of nested completions, such as the value
 * returned by {@code completeTokens}.
 * <p>
 * Every completion's rank is converted with {@code toMRR}, keeping the nesting, and the
 * result is handed to {@code getFileStats} as a single-element stream.
 *
 * @param completions completions grouped in inner lists
 * @return the statistics computed by {@code getFileStats}
 */
public DoubleSummaryStatistics getCompletionStats(List<List<Completion>> completions) {
    List<List<Double>> reciprocalRanks = new ArrayList<>();
    for (List<Completion> group : completions) {
        List<Double> groupRanks = new ArrayList<>();
        for (Completion completion : group) {
            groupRanks.add(toMRR(completion.getRank()));
        }
        reciprocalRanks.add(groupRanks);
    }
    return getFileStats(Stream.of(reciprocalRanks));
}
}

0 comments on commit f6c1154

Please sign in to comment.