Skip to content

Commit

Permalink
Disambiguation integrated.
Browse files Browse the repository at this point in the history
  • Loading branch information
MikaelMayer committed Sep 4, 2017
1 parent 93d8ab3 commit 2d3e447
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 4 deletions.
116 changes: 113 additions & 3 deletions src/main/scala/ch/epfl/lara/synthesis/stringsolver/ProgramSet.scala
Expand Up @@ -62,15 +62,22 @@ object ProgramSet {
def flatMap[T](f: A => GenTraversableOnce[T]) = map(f).flatten
private var cacheBest: Option[Any] = None
def takeBest: A = { if(cacheBest.isEmpty) cacheBest = Some(takeBestRaw); cacheBest.get.asInstanceOf[A]}
private var cacheNBest: Map[Int, Seq[(Int, Any)]] = Map()
def takeNBest(n: Int): Seq[(Int, A)] = { if(cacheNBest.isEmpty) cacheNBest += n-> takeNBestRaw(n: Int); cacheNBest(n).asInstanceOf[Seq[(Int, A)]]}
//def takeBestUsing(w: Identifier): A = takeBest
def takeBestRaw: A
def takeNBestRaw(n: Int): Seq[(Int, A)]
override def isEmpty: Boolean = this == SEmpty || ProgramSet.sizePrograms(this) == 0
def sizePrograms = ProgramSet.sizePrograms(this)
override def toIterable: Iterable[A] = map((i: A) =>i)
override def toString = this.getClass().getName().replaceAll(".*\\$","")+"("+self.productIterator.mkString(",")+")"
var weightMalus = 0
def examplePosition = 0 // Set only in SDag
}


def weighted[A <: Program](p: A): (Int, A) = (weight(p), p)

/**
* Set of switch expressions described in Programs.scala
*/
Expand All @@ -82,6 +89,10 @@ object ProgramSet {
for(t <- combinations(s map _2)) f(Switch(s map _1 zip t))
}
def takeBestRaw = Switch((s map _1) zip ((s map _2) map (_.takeBest)))
def takeNBestRaw(n: Int) = StreamUtils.cartesianProduct(s.map(_2).map(_.takeNBestRaw(n).toStream)).map{ x=>
val (scores, progs) = x.unzip
(scores.sum, Switch(s map _1 zip progs))
}.sortBy(_1).take(n)
//override def takeBestUsing(w: Identifier) = Switch((s map _1) zip ((s map _2) map (_.takeBestUsing(w))))
}
/**
Expand Down Expand Up @@ -114,7 +125,6 @@ object ProgramSet {
}
def takeBestRaw = {
var minProg = Map[Node, List[AtomicExpr]]()
var weights = Map[Node, Int]()
var nodesToVisit = new PriorityQueue[(Int, List[AtomicExpr], Node)]()(Ordering.by[(Int, List[AtomicExpr], Node), Int](e => e._1))
nodesToVisit.enqueue((0, Nil, ns))
while(!(minProg contains nt) && !nodesToVisit.isEmpty) {
Expand All @@ -134,6 +144,43 @@ object ProgramSet {
}
Concatenate(minProg(nt))//TODO : alternative.
}

def Nneighbors(quantity: Int, n: Node, prev_weight: Int): Option[(Seq[(Int, AtomicExpr)], Node)] = {
ξ.collectFirst[(Node, Node)]{ case e@(start, end) if start == n => e } map { e =>
val versions = W.getOrElse(e, Set.empty)
val possibilities = for(atomic <- versions.flatMap(_.takeNBest(quantity)).toList.sortBy(_1).take(quantity)) yield {
(-atomic._1 + prev_weight, atomic._2)
}
(possibilities, e._2)
}
}
def takeNBestRaw(quantity: Int): Seq[(Int, TraceExpr)] = {
var minProg = Map[Node, Seq[(Int, List[AtomicExpr])]]()
var nodesToVisit = new PriorityQueue[((Int, List[AtomicExpr]), Node)]()(Ordering.by[((Int, List[AtomicExpr]), Node), Int](e => e._1._1))
nodesToVisit.enqueue(((0, Nil), ns))
while(!(minProg.getOrElse(nt, Nil).length >= quantity) && !nodesToVisit.isEmpty) {
val ((weight, path), node) = nodesToVisit.dequeue() // Takes the first node with the minimal path.
minProg += node -> (((weight, path)) +: minProg.getOrElse(node, Nil))
for(e@(newWeightsAtomics, newNode) <- Nneighbors(quantity, node, weight)) {
// (newweight, newAtomic
for((newweight, newAtomic) <- newWeightsAtomics)
{
val alreadyLookingFor = nodesToVisit.toStream.filter{ case ((w, p), n) => n == newNode }
//val shouldBeAdded = alreadyLookingFor.lengthCompare(quantity) < 0 || newweight > alreadyLookingFor(quantity - 1)._1._1
//if(shouldBeAdded) { // We keep only the best quantity.
nodesToVisit.enqueue(((newweight, path ++ List[AtomicExpr](newAtomic)), newNode))
var i = 0
nodesToVisit = nodesToVisit.filterNot {
i += 1
_._2 == newNode && i >= quantity // We still keep the first quantity best, we remove the rest.
}
//}

}
}
}
minProg(nt).map(x => (x._1, Concatenate(x._2)))
}
/*def neighborsUsing(n: Node, n_weight: Int, w: Identifier): Set[(Int, AtomicExpr, Node)] = {
for(e <- ξ if e._1 == n;
versions = W.getOrElse(e, Set.empty);
Expand Down Expand Up @@ -210,10 +257,11 @@ object ProgramSet {
for(prog <- e) f(Loop(i, prog, separator))
}
def takeBestRaw = Loop(i, e.takeBest, separator)//.withAlternative(this.toIterable)
override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = {
e.takeNBest(n).map(x => weighted(Loop(i, x._2, separator)))
}
}



/**
* Set of SubStr expressions described in Programs.scala
*/
Expand All @@ -227,6 +275,15 @@ object ProgramSet {
def takeBestRaw = SubStr(vi, p1.toList.map(_.takeBest).sortBy(weight(_)(true)).head, p2.toList.map(_.takeBest).sortBy(weight(_)(false)).head.withWeightMalus(this.weightMalus), methods.takeBest)//.withAlternative(this.toIterable)
private var corresponding_string: (String, String, Int, Int) = ("", "", 0, -1)
def setPos(from: String, s: String, start: Int, end: Int) = corresponding_string = (from, s, start, end)

override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = {
val left = p1.flatMap(_.takeNBest(n)).toSeq.sortBy(_._1).take(n)
val right = p2.flatMap(_.takeNBest(n)).toSeq.sortBy(_._1).take(n)
val method = methods.takeNBest(n)
StreamUtils.cartesianProduct(Seq(left.toStream, right.toStream, method.toStream)).take(n).map{
case Seq((leftScore, l), (rightScore, r), (mScore, m)) => weighted(SubStr(vi, l.asInstanceOf[Position], r.asInstanceOf[Position], m.asInstanceOf[SubStrFlag]))
} take n
}
}

def isCommonSeparator(s: String) = s match {
Expand All @@ -247,6 +304,14 @@ object ProgramSet {
def takeBestRaw = SpecialConversion(s.takeBest.asInstanceOf[SubStr], converters.toList.sortBy(weight(_)(true)).head)
private var corresponding_string: (String, String, Int, Int) = ("", "", 0, -1)
def setPos(from: String, s: String, start: Int, end: Int) = corresponding_string = (from, s, start, end)

override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = {
val sbest = s.takeNBest(n).toStream
val converterBest = converters.toList.map(x => (-weight(x)(true), x)).sortBy(_1).take(n).toStream
StreamUtils.cartesianProduct(Seq(sbest, converterBest)).take(n).map {
case Seq((sScore, s), (convScore, converter)) => weighted(SpecialConversion(s.asInstanceOf[SubStr], converter.asInstanceOf[SpecialConverter]))
}
}
}


Expand All @@ -273,6 +338,12 @@ object ProgramSet {
}
def takeBestRaw = if(step == 0) IntLiteral(start) else IntLiteral(start+step*((max-start)/step))
def apply(elem: Int): Boolean = elem >= start && elem <= max && (step == 0 && start == elem || step != 0 && (elem-start)%step == 0)

override def takeNBestRaw(n: Int): Seq[(Int, IntLiteral)] = {
if(step == 0) Seq((0, IntLiteral(start))) else {
(start to max by step).reverse.zipWithIndex.map{ x => weighted(IntLiteral(x._1))} take n
}
}
}
/*case class SAnyInt(default: Int) extends SInt {
def map[T](f: IntLiteral => T): Stream[T] = {
Expand Down Expand Up @@ -305,6 +376,11 @@ object ProgramSet {
for(pp1: AtomicExpr <- a; l <- length) f(NumberMap(pp1.asInstanceOf[SubStr], l.k, offset))
}
def takeBestRaw = NumberMap(a.takeBest.asInstanceOf[SubStr], length.takeBest.k, offset)//.withAlternative(this.toIterable)
override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = {
StreamUtils.cartesianProduct(Seq(a.takeNBest(n).toStream, length.takeNBest(n).toStream)).take(n) map {
case Seq((aScore, a), (lengthScore, length)) => weighted(NumberMap(a.asInstanceOf[SubStr], length.asInstanceOf[IntLiteral].k, offset))
}
}
}

/**
Expand Down Expand Up @@ -340,6 +416,12 @@ object ProgramSet {
for(l <- length.toStream; s: IntLiteral <- starts; step <- if(count == 0) Stream.from(1) else List((index - s.k)/count)) f(Counter(l.k, s.k, step))
}
def takeBestRaw = Counter(length.takeBest.k, starts.takeBest.k, if(count == 0) 1 else (index - starts.takeBest.k)/count)//.withAlternative(this.toIterable)
override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = {
StreamUtils.cartesianProduct(Seq(length.takeNBest(n).toStream, starts.takeNBest(n).toStream)).take(n) map {
case Seq((lengthScore, length), (startsScore, start)) =>
weighted(Counter(length.k, start.k, if(count == 0) 1 else (index - start.k)/count))
}
}
}

/**
Expand All @@ -353,6 +435,8 @@ object ProgramSet {
f(ConstStr(s))
}
def takeBestRaw = ConstStr(s)

override def takeNBestRaw(n: Int): Seq[(Int, AtomicExpr)] = Seq(weighted(takeBest))
}

type SPosition = ProgramSet[Position]
Expand All @@ -367,6 +451,8 @@ object ProgramSet {
f(CPos(k))
}
def takeBestRaw = CPos(k)

override def takeNBestRaw(n: Int): Seq[(Int, Position)] = Seq(weighted(takeBest))
}
/**
* Set of regexp positions described in Programs.scala
Expand All @@ -380,6 +466,14 @@ object ProgramSet {
}
def takeBestRaw = Pos(r1.takeBest, r2.takeBest, c.toList.sortBy(weight).head)
//var index = 0 // Index at which this position was computed
override def takeNBestRaw(n: Int): Seq[(Int, Position)] = {
StreamUtils.cartesianProduct(Seq(
r1.takeNBest(n).toStream,
r2.takeNBest(n).toStream,
c.toList.sortBy(weight).toStream)) map {
case Seq((_, rr1), (_, rr2), cc: IntLiteral) => weighted(Pos(rr1.asInstanceOf[RegExp], rr2.asInstanceOf[RegExp], cc.k))
}
}
}

type SRegExp = ProgramSet[RegExp]
Expand All @@ -395,6 +489,12 @@ object ProgramSet {
for(t <- combinations(s)) f(TokenSeq(t))
}
def takeBestRaw = TokenSeq(s map (_.takeBest))

override def takeNBestRaw(n: Int): Seq[(Int, RegExp)] = {
StreamUtils.cartesianProduct(s.map(_.takeNBest(n).toStream)).take(n) map {
x => weighted(TokenSeq(x.map(_2)))
}
}
}

/**
Expand All @@ -407,6 +507,8 @@ object ProgramSet {
def iterator = Nil.toIterator
override def toIterable = Nil
override def isEmpty = true

override def takeNBestRaw(n: Int): Seq[(Int, Nothing)] = Seq()
}

type SIntegerExpr = Set[IntegerExpr]
Expand Down Expand Up @@ -466,6 +568,10 @@ object ProgramSet {
def takeBestRaw = map((i: Token) => i).toList.sortBy(weight).head
def contains(t: Token): Boolean = ((1L << l.indexOf(t)) & mask) != 0
override def toString = "SToken("+this.toList.mkString(",")+")"

override def takeNBestRaw(n: Int): Seq[(Int, Token)] = {
map((i: Token) => i).toList.sortBy(weight).take(n).map(weighted)
}
}

/**
Expand Down Expand Up @@ -501,6 +607,10 @@ object ProgramSet {
override def isEmpty = mask == 0
def takeBestRaw = map((i: SubStrFlag) => i).toList.sortBy(weight).head
override def toString = "SSubStrFlag("+this.toList.mkString(",")+")"

override def takeNBestRaw(n: Int): Seq[(Int, SubStrFlag)] = {
map((i: SubStrFlag) => i).toList.sortBy(weight).take(n).map(weighted)
}
}


Expand Down

0 comments on commit 2d3e447

Please sign in to comment.