diff --git a/.classpath b/.classpath index 4af86cd..636c508 100644 --- a/.classpath +++ b/.classpath @@ -6,19 +6,19 @@ - + + - + - + - + - diff --git a/src/main/java/slp/core/modeling/runners/Completion.java b/src/main/java/slp/core/modeling/runners/Completion.java new file mode 100644 index 0000000..8a50c31 --- /dev/null +++ b/src/main/java/slp/core/modeling/runners/Completion.java @@ -0,0 +1,40 @@ +package slp.core.modeling.runners; + +import java.util.List; + +import slp.core.util.Pair; + +public class Completion { + private final Integer realIx; + private final List> completions; + + public Completion(Integer realIx, List> predictions) { + this.realIx = realIx; + this.completions = predictions; + } + + public Completion(List> predictions) { + this.realIx = null; + this.completions = predictions; + } + + public List> getPredictions() { + return completions; + } + + public Integer getRealIx() { + return realIx; + } + + public int getRank() { + if (this.realIx == null) return -1; + else { + for (int i = 0; i < this.completions.size(); i++) { + if (this.completions.get(i).left.equals(this.realIx)) { + return i; + } + } + return -1; + } + } +} diff --git a/src/main/java/slp/core/modeling/runners/ModelRunner.java b/src/main/java/slp/core/modeling/runners/ModelRunner.java index 45ad45b..3fc31a0 100644 --- a/src/main/java/slp/core/modeling/runners/ModelRunner.java +++ b/src/main/java/slp/core/modeling/runners/ModelRunner.java @@ -312,6 +312,80 @@ private void logPredictionProgress(List modeled) { } } + public Stream>>> completeDirectory(File file) { + this.modelStats = new long[] { 0, -System.currentTimeMillis() }; + this.mrr = 0.0; + return this.lexerRunner.lexDirectory(file) + .map(p -> { + this.model.notify(p.left); + return Pair.of(p.left, this.completeTokens(p.right)); + }); + } + + public List> completeFile(File f) { + if (!this.lexerRunner.willLexFile(f)) return null; + this.model.notify(f); + return completeTokens(this.lexerRunner.lexFile(f)); + } + + public List> completeContent(String content) { + return completeTokens(this.lexerRunner.lexText(content)); + } + + public List> completeTokens(Stream> lexed) { + List> lineCompletions; + if (this.lexerRunner.isPerLine()) { + lineCompletions = lexed + .map(this.getVocabulary()::toIndices) + .map(l -> l.collect(Collectors.toList())) + .map(l -> completeSequence(l)) + .peek(this::logCompletionProgress) + .collect(Collectors.toList()); + } else { + List lineLengths = new ArrayList<>(); + List commpletions = completeSequence(lexed + .map(this.getVocabulary()::toIndices) + .map(l -> l.collect(Collectors.toList())) + .peek(l -> lineLengths.add(l.size())) + .flatMap(l -> l.stream()).collect(Collectors.toList())); + lineCompletions = toLines(commpletions, lineLengths); + logCompletionProgress(commpletions); + } + return lineCompletions; + } + + protected List completeSequence(List tokens) { + if (this.selfTesting) this.model.forget(tokens); + List>> preds = this.model.predict(tokens); + if (this.selfTesting) this.model.learn(tokens); + List rankings = IntStream.range(0, preds.size()) + .mapToObj(i -> { + List> completions = preds.get(i).entrySet().stream() + .map(e -> Pair.of(e.getKey(), toProb(e.getValue()))) + .sorted((p1, p2) -> -Double.compare(p1.right, p2.right)) + .limit(GLOBAL_PREDICTION_CUTOFF) + .collect(Collectors.toList()); + return new Completion(tokens.get(i), completions); + }).collect(Collectors.toList()); + return rankings; + } + + private void logCompletionProgress(List completions) { + DoubleSummaryStatistics stats = completions.stream().skip(1) + .map(Completion::getRank) + .mapToDouble(ModelRunner::toMRR) + .summaryStatistics(); + long prevCount = this.modelStats[0]; + this.modelStats[0] += stats.getCount(); + this.mrr += stats.getSum(); + if (this.modelStats[0] / this.MODEL_PRINT_INTERVAL > prevCount / this.MODEL_PRINT_INTERVAL + && this.modelStats[1] != 0) { + System.out.printf("Predicting: %dK tokens processed in %ds, avg. MRR: %.4f\n", + Math.round(this.modelStats[0]/1e3), + (System.currentTimeMillis() + this.modelStats[1])/1000, this.mrr/this.modelStats[0]); + } + } + public List toProb(List> probConfs) { return probConfs.stream().map(this::toProb).collect(Collectors.toList()); } @@ -381,4 +455,11 @@ private DoubleSummaryStatistics getFileStats(Stream>> fileProb .mapToDouble(p -> p).summaryStatistics(); } } + + public DoubleSummaryStatistics getCompletionStats(List> completions) { + List> MRRs = completions.stream() + .map(l -> l.stream().map(c -> toMRR(c.getRank()))) + .map(l -> l.collect(Collectors.toList())).collect(Collectors.toList()); + return getFileStats(Stream.of(MRRs)); + } }