Skip to content
Browse files

first usable version

  • Loading branch information...
1 parent 3447372 commit f06ce444543cab19d4cc30384b6848c53155339b @avibryant committed
View
35 README
@@ -0,0 +1,35 @@
+Simple similarity service based on chapter 3 of http://infolab.stanford.edu/~ullman/mmds.html.
+Backed by Redis but easily adaptable to other stores.
+
+Very, very, very early days.
+
+Building:
+
+mvn package
+
+The scripts described below assume you have a Redis server running on localhost.
+
+Initializing:
+
+scripts/tool Initialize <rows> <bands> <minCount>
+
+- only run this *once*, before you load in any data
+- bands and rows define the similarity threshold, as outlined in the book (approximately (1/bands)^(1/rows))
+- minCount is how many items a set has to have before it is considered for similarity with others
+- a good default set of values:
+ scripts/tool Initialize 5 25 5
+
+Loading:
+
+scripts/tool Load <file.tsv>
+
+- run as many times as you like with different data
+- expects two tab-separated columns: set key, then item key
+
+Dumping:
+
+scripts/tool Dump
+
+- this will dump out all pairs of similar sets
+- format is TSV, columns are similarity, size of intersection, set key 1, set key 2
+
View
80 pom.xml
@@ -0,0 +1,80 @@
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <groupId>com.avibryant.redsim</groupId>
+  <artifactId>redsim</artifactId>
+  <packaging>jar</packaging>
+  <version>1.0-SNAPSHOT</version>
+  <name>redsim</name>
+  <url>http://maven.apache.org</url>
+  <repositories>
+    <repository>
+      <id>scala-tools.org</id>
+      <name>Scala-tools Maven2 Repository</name>
+      <url>http://scala-tools.org/repo-releases</url>
+    </repository>
+  </repositories>
+  <pluginRepositories>
+    <pluginRepository>
+      <id>scala-tools.org</id>
+      <name>Scala-tools Maven2 Repository</name>
+      <url>http://scala-tools.org/repo-releases</url>
+    </pluginRepository>
+  </pluginRepositories>
+  <dependencies>
+    <!-- Redis client; also supplies the MurmurHash implementation used for hashing. -->
+    <dependency>
+      <groupId>redis.clients</groupId>
+      <artifactId>jedis</artifactId>
+      <version>2.0.0</version>
+      <type>jar</type>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <version>2.9.1</version>
+    </dependency>
+  </dependencies>
+  <build>
+    <plugins>
+      <!-- Compiles the Scala sources under src/main/scala. -->
+      <plugin>
+        <groupId>org.scala-tools</groupId>
+        <artifactId>maven-scala-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <sourceDir>src/main/scala</sourceDir>
+          <jvmArgs>
+            <jvmArg>-Xms64m</jvmArg>
+            <jvmArg>-Xmx1024m</jvmArg>
+          </jvmArgs>
+        </configuration>
+      </plugin>
+      <!--
+        Builds target/redsim-1.0-SNAPSHOT.jar with all dependencies bundled, which is
+        the jar scripts/tool puts on the classpath (appendAssemblyId=false keeps the
+        plain artifact name).
+      -->
+      <plugin>
+        <artifactId>maven-assembly-plugin</artifactId>
+        <configuration>
+          <appendAssemblyId>false</appendAssemblyId>
+          <descriptorRefs>
+            <descriptorRef>jar-with-dependencies</descriptorRef>
+          </descriptorRefs>
+        </configuration>
+        <executions>
+          <execution>
+            <id>make-assembly</id>
+            <phase>package</phase>
+            <goals>
+              <!--
+                "single" is the goal intended for binding to a lifecycle phase.
+                The "assembly" goal forks its own "package" lifecycle, so binding
+                it to the package phase re-runs the build, and it is deprecated.
+              -->
+              <goal>single</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+
+    </plugins>
+  </build>
+
+</project>
View
20 scripts/bx.rb
@@ -0,0 +1,20 @@
+titles = {}
+books = File.readlines("/Users/avi/Downloads/bx/BX-Books.csv")
+books.shift
+books.each do |line|
+ begin
+ isbn, title, *rest = line.split(";")
+ titles[isbn] = eval(title).split("(")[0].downcase
+ rescue SyntaxError
+ end
+end
+
+ratings = File.readlines("/Users/avi/Downloads/bx/BX-Book-Ratings.csv")
+ratings.shift
+ratings.each do |line|
+ user, isbn, rating = line.split(";")
+ if t = titles[isbn]
+ puts [t, eval(user)].join("\t")
+ end
+end
+
View
2 scripts/tool
@@ -0,0 +1,2 @@
+#!/bin/sh
+java -cp target/redsim-1.0-SNAPSHOT.jar com.avibryant.redsim.tools.$@
View
20 src/main/scala/com/avibryant/redsim/Connection.scala
@@ -0,0 +1,20 @@
+package com.avibryant.redsim
+
+/**
+ * Hash functions needed to build signatures and to turn a band of a
+ * signature into a bucket number.
+ */
+abstract class Hashing {
+  /** Hashes `str` under the given `seed`. */
+  def hashFor(str: String, seed: Int): Int
+
+  /** Reduces an array of hash values to a single bucket number. */
+  def bucketFor(hashes: Array[Int]): Int
+}
+
+/**
+ * Storage operations required from a backing store: the configuration,
+ * per-key signatures, and the per-band buckets derived from them.
+ */
+abstract class Connection {
+  /** Hashing implementation that goes with this store. */
+  implicit def hashing: Hashing
+
+  // ---- configuration ----
+  def readConfiguration: Configuration
+  def writeConfiguration(config: Configuration): Unit
+
+  // ---- signatures ----
+  /** The signature stored under `key`, if any. */
+  def readSignature(key: String)(implicit config: Configuration): Option[Signature]
+
+  /**
+   * Runs `fn` with the signature currently stored under `key` (if any).
+   * Implementations are expected to guard the writes made inside `fn`;
+   * how they do so is up to the store.
+   */
+  def lockSignature(key: String)(fn: Option[Signature] => Unit)(implicit config: Configuration): Unit
+
+  def writeSignature(key: String, sig: Signature): Unit
+  def updateSignature(key: String, oldSig: Signature, newSig: Signature): Unit
+
+  // ---- buckets ----
+  def writeBuckets(key: String, sig: Signature): Unit
+  def updateBuckets(key: String, oldSig: Signature, newSig: Signature): Unit
+
+  /** Keys stored in `bucket` of `band`. */
+  def readBucket(band: Int, bucket: Int): List[String]
+
+  /** Bucket identifiers (as strings) of `band` that are worth comparing. */
+  def bucketsWithCandidates(band: Int): List[String]
+}
View
118 src/main/scala/com/avibryant/redsim/Redis.scala
@@ -0,0 +1,118 @@
+package com.avibryant.redsim
+import redis.clients.jedis._
+import redis.clients.util.MurmurHash
+
+/**
+ * Hashing based on the MurmurHash implementation that ships with Jedis.
+ *
+ * NOTE: hash and bucket values are persisted in Redis, so they must be the
+ * same on every machine that talks to the same store. Data written with a
+ * different charset or bucket encoding has to be reloaded.
+ */
+class RedisHashing extends Hashing {
+  /**
+   * Hashes `str` with `seed`. The bytes are always taken as UTF-8 rather than
+   * the platform default charset, which would make hashes of non-ASCII keys
+   * differ between machines.
+   */
+  def hashFor(str : String, seed : Int) : Int = MurmurHash.hash(str.getBytes("UTF-8"), seed)
+
+  /**
+   * Hashes one band of a signature into a bucket number, seeded with the
+   * band's first value. The values are joined with a separator so that
+   * different bands such as [1, 23, 4] and [1, 2, 34] do not collapse into
+   * the same input string.
+   *
+   * @throws ArrayIndexOutOfBoundsException if `hashes` is empty
+   */
+  def bucketFor(hashes : Array[Int]) : Int = hashFor(hashes.mkString(","), hashes(0))
+}
+
+/**
+ * Connection backed by a single Jedis client.
+ *
+ * Key layout (everything lives under the "redsim:" prefix):
+ *  - config         string of colon-separated ints: numRows, numBands, minCount, seeds...
+ *  - sig:<key>      string of colon-separated ints: count, signature values...
+ *  - band:<n>       sorted set; member = set key, score = that key's bucket in band n
+ *  - bandCount:<n>  sorted set; member = bucket number, score = running count of keys in it
+ *
+ * Not thread safe: the transaction in progress is kept in a mutable field.
+ */
+class RedisConnection(jedis : Jedis) extends Connection {
+  def this() = this(new Jedis("localhost"))
+
+  // Declared ahead of the methods below, which need an implicit Hashing
+  // in scope to construct Signature instances.
+  implicit val hashing = new RedisHashing
+
+  // Transaction opened by lockSignature; null whenever none is in progress.
+  var transaction : Transaction = null
+
+  /** Reads the stored configuration; fails if fewer than three ints are stored. */
+  def readConfiguration = {
+    readInts(configKey) match {
+      case(Array(numRows, numBands, minCount, seeds @ _*)) =>
+        new Configuration(numRows, numBands, minCount, seeds.toArray)
+      case _ => sys.error("Could not find configuration")
+    }
+  }
+
+  /** Stores the configuration in the same order readConfiguration expects. */
+  def writeConfiguration(config : Configuration) {
+    writeInts(configKey, Array(
+      config.numRows,
+      config.numBands,
+      config.minCount) ++
+      config.seeds)
+  }
+
+  /** The signature stored for `key`, or None when nothing is stored. */
+  def readSignature(key : String)(implicit config : Configuration) = {
+    readInts(sigKey(key)) match {
+      case(Array(count, values @ _*)) =>
+        Some(new Signature(count, values.toArray))
+      case _ => None
+    }
+  }
+
+  /**
+   * Runs `fn` inside a WATCH/MULTI/EXEC transaction on the signature of `key`.
+   *
+   * `fn` receives the signature as it was read after the WATCH; the write
+   * methods it calls are queued on the transaction. If another client changes
+   * the signature before EXEC, Redis aborts the transaction and the whole
+   * read/modify/write cycle is retried, so `fn` may run more than once.
+   * If `fn` throws, the transaction is discarded and the exception propagates.
+   */
+  def lockSignature(key : String)(fn : Option[Signature] => Unit)(implicit config : Configuration) {
+    var committed = false
+    while(!committed) {
+      // Watch the key the signature is actually stored under; watching the
+      // bare `key` would never detect a concurrent update.
+      jedis.watch(sigKey(key))
+      val sig = readSignature(key)
+      //not thread safe
+      transaction = jedis.multi()
+      try {
+        var queued = false
+        try {
+          fn(sig)
+          queued = true
+        } finally {
+          // Do not leave the connection stuck in MULTI state when fn fails.
+          if(!queued)
+            transaction.discard()
+        }
+        // A null reply is taken to mean EXEC was aborted because the watched
+        // key changed (behaviour of the Jedis 2.0.0 client pinned in the POM).
+        committed = transaction.exec() != null
+      } finally {
+        transaction = null
+      }
+    }
+  }
+
+  /** Queues a write of `sig` for `key`; only valid inside lockSignature. */
+  def writeSignature(key : String, sig : Signature) {
+    checkInTransaction
+    writeInts(transaction, sigKey(key), Array(sig.count) ++ sig.values)
+  }
+
+  /** Queues the first registration of `key` in the bucket of every band. */
+  def writeBuckets(key : String, sig : Signature) {
+    checkInTransaction
+    sig.buckets.zipWithIndex.foreach{
+      case (bucket, band) => {
+        transaction.zadd(bandKey(band), bucket, key)
+        updateBucketCount(band, bucket, 1)
+      }
+    }
+  }
+
+  /** Replaces the stored signature; `oldSig` is not needed by this store. */
+  def updateSignature(key : String, oldSig : Signature, newSig : Signature) {
+    writeSignature(key, newSig)
+  }
+
+  /**
+   * Queues a move of `key` for every band whose bucket differs between
+   * `oldSig` and `newSig`, adjusting both bucket counters.
+   */
+  def updateBuckets(key : String, oldSig : Signature, newSig : Signature) {
+    checkInTransaction
+    newSig.updatedBuckets(oldSig).foreach{
+      // b = bucket before (from oldSig), a = bucket after (from newSig), i = band
+      case ((b, a), i) => {
+        transaction.zadd(bandKey(i), a, key)
+        updateBucketCount(i, b, -1)
+        updateBucketCount(i, a, 1)
+      }
+    }
+  }
+
+  /** Keys whose score in `band` equals `bucket`. */
+  def readBucket(band : Int, bucket : Int) = {
+    jedis.zrangeByScore(bandKey(band), bucket, bucket).toArray(Array[String]()).toList
+  }
+
+  /**
+   * Buckets of `band` whose counter is between 2 and 2000 inclusive, i.e.
+   * buckets that can yield at least one pair. Buckets counted above 2000 are
+   * not returned.
+   */
+  def bucketsWithCandidates(band : Int) = {
+    jedis.zrangeByScore(bandCountKey(band), 2, 2000).toArray(Array[String]()).toList
+  }
+
+  // Adds `sign` (+1 or -1 at the call sites) to the counter of `bucket` in `band`.
+  private def updateBucketCount(band : Int, bucket : Int, sign : Int) {
+    transaction.zincrby(bandCountKey(band), sign.toDouble, bucket.toString)
+  }
+
+  private def configKey = prefix + "config"
+  private def sigKey(key : String) = prefix + "sig:" + key
+  private def bandKey(band : Int) = prefix + "band:" + band
+  private def bandCountKey(band : Int) = prefix + "bandCount:" + band
+  private def prefix = "redsim:"
+
+  // Guards the write methods that queue commands on `transaction`.
+  private def checkInTransaction {
+    if(transaction == null)
+      sys.error("Not in a transaction")
+  }
+
+  // Reads a colon-separated list of ints; a missing key yields an empty array.
+  private def readInts(key : String) = {
+    val str = jedis.get(key)
+    if(str == null)
+      Array[Int]()
+    else
+      str.split(":").map{_.toInt}
+  }
+
+  // Queues a write of `ints` as a colon-separated string on `txn`.
+  private def writeInts(txn : Transaction, key : String, ints : Array[Int]) {
+    txn.set(key, ints.mkString(":"))
+  }
+
+  // Writes `ints` as a colon-separated string immediately, outside any transaction.
+  private def writeInts(key : String, ints : Array[Int]) {
+    jedis.set(key, ints.mkString(":"))
+  }
+
+}
View
82 src/main/scala/com/avibryant/redsim/Redsim.scala
@@ -0,0 +1,82 @@
+package com.avibryant.redsim
+
+/**
+ * Parameters shared by every signature in a store.
+ *
+ * @param numRows  number of signature values grouped into one band
+ * @param numBands number of bands per signature
+ * @param minCount item count a set must reach before it is put into buckets
+ * @param seeds    one hash seed per signature value
+ */
+case class Configuration(
+  numRows : Int,
+  numBands : Int,
+  minCount : Int,
+  seeds : Array[Int]) {
+
+  /** Builds a configuration with `nR * nB` freshly drawn random seeds. */
+  def this(nR : Int, nB : Int, mC : Int) =
+    this(nR, nB, mC, Array.fill(nR * nB)(scala.util.Random.nextInt()))
+
+  /** Similarity level around which two sets start to share a bucket. */
+  def estimatedThreshold = math.pow(1.0/numBands, 1.0/numRows)
+}
+
+/**
+ * Similarity service on top of a Connection: sets are summarized as
+ * signatures, signatures are split into bands, and sets that land in the
+ * same bucket of any band become similarity candidates.
+ */
+class Redsim(val conn : Connection) {
+  // Read lazily so that a Redsim can be created before initialize has been
+  // called on a fresh store.
+  implicit lazy val config = conn.readConfiguration
+  implicit val hashing = conn.hashing
+
+  /** Stores `config` in the backing store. */
+  def initialize(config : Configuration) {
+    conn.writeConfiguration(config)
+  }
+
+  /**
+   * Adds `items` to the set identified by `key`.
+   * An empty list is a no-op (previously it failed inside `reduce`).
+   */
+  def addItems(key : String, items : List[String]) {
+    if(!items.isEmpty)
+      addSignature(key, items.map{item => new Signature(item)}.reduceLeft{(a,b) => a+b})
+  }
+
+  /**
+   * Merges `sig` into whatever is stored for `key` and keeps the buckets in
+   * step. A set is only placed in buckets once its count reaches
+   * `config.minCount`.
+   */
+  def addSignature(key : String, sig : Signature) {
+    conn.lockSignature(key) {
+      case None => {
+        conn.writeSignature(key, sig)
+        if(sig.count >= config.minCount)
+          conn.writeBuckets(key, sig)
+      }
+      case Some(oldSig) => {
+        val newSig = sig + oldSig
+        conn.updateSignature(key, oldSig, newSig)
+        if(newSig.count >= config.minCount) {
+          // Already bucketed sets are moved; sets that just crossed the
+          // minCount line are bucketed for the first time.
+          if(oldSig.count >= config.minCount)
+            conn.updateBuckets(key, oldSig, newSig)
+          else
+            conn.writeBuckets(key, newSig)
+        }
+      }
+    }
+  }
+
+  /**
+   * Every pair of keys that share a bucket in at least one band.
+   * Each pair is ordered the way the two keys appear in their bucket.
+   */
+  def allSimilarCandidates = {
+    val bucketsOfKeys = (0 until config.numBands).flatMap{
+      band => conn.bucketsWithCandidates(band).map{
+        bucket => conn.readBucket(band, bucket.toInt)
+      }
+    }
+
+    bucketsOfKeys.flatMap{
+      bucket => bucket.zipWithIndex.flatMap {
+        case (key1, index) =>
+          // Pair each key only with the keys after it: no self pairs, no mirrored pairs.
+          bucket.slice(index + 1, bucket.size).map {
+            key2 =>
+              (key1, key2)
+          }
+      }
+    }.toSet
+  }
+
+  /** Keys sharing at least one bucket with `key`; fails if `key` has no signature. */
+  def candidatesSimilarTo(key : String) = {
+    conn.readSignature(key) match {
+      case Some(sig) => sig.buckets.zipWithIndex.flatMap{
+        case (bucket, band) => conn.readBucket(band, bucket)
+      }.toSet - key
+      case None => sys.error("Could not find " + key)
+    }
+  }
+
+  /** Similarity of two stored sets; fails if either has no signature. */
+  def similarity(key1 : String, key2 : String) = {
+    (conn.readSignature(key1), conn.readSignature(key2)) match {
+      case (Some(sig1), Some(sig2)) => sig1.similarityWith(sig2)
+      case _ => sys.error("Could not find keys")
+    }
+  }
+}
+
View
40 src/main/scala/com/avibryant/redsim/Signature.scala
@@ -0,0 +1,40 @@
+package com.avibryant.redsim
+
+/**
+ * Estimated similarity of two sets.
+ *
+ * @param jaccard    estimated Jaccard similarity, |A n B| / |A u B|
+ * @param leftCount  item count of the first set
+ * @param rightCount item count of the second set
+ */
+case class Similarity(val jaccard : Float, val leftCount : Int, val rightCount : Int) {
+  /**
+   * Estimated |A u B|. Since |A| + |B| = |A u B| + |A n B| and
+   * |A n B| = jaccard * |A u B|, the union is (|A| + |B|) / (1 + jaccard).
+   * Rounded to the nearest integer so float error cannot drop it by one.
+   */
+  def unionSize = math.round((leftCount + rightCount) / (1 + jaccard))
+
+  /** Estimated |A n B|: whatever of |A| + |B| is not accounted for by the union. */
+  def intersectionSize = (leftCount + rightCount) - unionSize
+
+  /** Estimated cosine similarity, |A n B| / sqrt(|A| * |B|). */
+  def cosine = intersectionSize.toFloat / (math.sqrt(leftCount) * math.sqrt(rightCount))
+}
+
+/**
+ * Signature of a set: how many items were folded in (`count`) and, for each
+ * configured seed, the smallest hash seen so far (`values`).
+ */
+case class Signature(
+  count : Int,
+  values : Array[Int])(
+  implicit val config : Configuration,
+  implicit val hashing : Hashing) {
+
+  /** Signature of the single item `str`: one hash per configured seed. */
+  def this(str : String)(implicit config : Configuration, hashing : Hashing) =
+    this(1, for(seed <- config.seeds) yield hashing.hashFor(str, seed))
+
+  /** Combines two signatures: counts add up, values take the position-wise minimum. */
+  def +(sig : Signature) = {
+    val merged = values.zip(sig.values).map{pair => math.min(pair._1, pair._2)}
+    new Signature(count + sig.count, merged)
+  }
+
+  /** Jaccard estimate: the share of positions that are not reported as differing. */
+  def similarityWith(sig : Signature) = {
+    val total = values.size
+    val differing = updatedValues(sig).size
+    Similarity((total - differing).toFloat / total, count, sig.count)
+  }
+
+  /** One bucket number per band, a band being `config.numRows` consecutive values. */
+  def buckets = {
+    val bands = values.grouped(config.numRows)
+    (for(band <- bands) yield hashing.bucketFor(band)).toArray
+  }
+
+  /** ((bucket in `sig`, bucket here), band) for every band whose bucket differs. */
+  def updatedBuckets(sig : Signature) = diffWithIndex(sig.buckets, buckets)
+
+  /** ((value in `sig`, value here), position) for every position whose value differs. */
+  def updatedValues(sig : Signature) = diffWithIndex(sig.values, values)
+
+  // Pairs the two arrays position by position and keeps only the unequal pairs,
+  // each tagged with its position.
+  private def diffWithIndex(before : Array[Int], after : Array[Int]) = {
+    val indexedPairs = before.zip(after).zipWithIndex
+    indexedPairs.filterNot{case ((b, a), _) => b == a}
+  }
+}
View
39 src/main/scala/com/avibryant/redsim/Tools.scala
@@ -0,0 +1,39 @@
+package com.avibryant.redsim.tools
+import com.avibryant.redsim._
+
+/**
+ * Writes a fresh configuration with random seeds to the store.
+ *
+ * Usage: Initialize <rows> <bands> <minCount>
+ *
+ * The first argument is used as Configuration.numRows and the second as
+ * Configuration.numBands; the locals are named to match. The values reaching
+ * the constructor are the same as before, only the names were corrected.
+ */
+object Initialize extends App {
+  if(args.length < 3) {
+    System.err.println("usage: Initialize <rows> <bands> <minCount>")
+    sys.exit(1)
+  }
+
+  val rows = args(0).toInt
+  val bands = args(1).toInt
+  val minCount = args(2).toInt
+  val rs = new Redsim(new RedisConnection())
+
+  rs.initialize(new Configuration(rows, bands, minCount))
+}
+
+/**
+ * Loads a TSV file into the store, one (set key, item key) pair per line.
+ *
+ * Usage: Load <file.tsv>
+ *
+ * Progress is reported on stderr every 100 loaded lines. Lines with fewer
+ * than two tab-separated columns are reported on stderr and skipped.
+ */
+object Load extends App {
+  val filename = args(0)
+  val source = scala.io.Source.fromFile(filename)
+  val lines = source.getLines
+  val rs = new Redsim(new RedisConnection())
+  var counter = 0
+
+  try {
+    lines.
+      foreach{line =>
+        val parts = line.split("\t")
+        if(parts.length >= 2) {
+          val set = parts(0)
+          val item = parts(1)
+          rs.addItems(set, List(item))
+          counter += 1
+          if(counter % 100 == 0)
+            System.err.println(counter)
+        } else {
+          System.err.println("Skipping malformed line: " + line)
+        }
+      }
+  } finally {
+    // Release the file handle even when loading fails part-way through.
+    source.close()
+  }
+}
+
+/**
+ * Prints every candidate pair as one tab-separated line:
+ * jaccard, intersectionSize, left key, right key.
+ */
+object Dump extends App {
+  val rs = new Redsim(new RedisConnection())
+  for((left, right) <- rs.allSimilarCandidates) {
+    val sim = rs.similarity(left, right)
+    val columns = Array(sim.jaccard.toString, sim.intersectionSize.toString, left, right)
+    println(columns.mkString("\t"))
+  }
+}

0 comments on commit f06ce44

Please sign in to comment.
Something went wrong with that request. Please try again.