Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
Browse files

Merge branch 'master' of github.com:MrChrisJohnson/CollabStream

  • Loading branch information...
commit e048e29d8ee6bca0416aef9c8368db8f79e9adce 2 parents e4923ba + b7710ae
@MrChrisJohnson authored
View
7 commands.txt
@@ -10,4 +10,9 @@ mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args=
mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
-mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 9 2648862 data/input/netflix_tr_head10000_shuf data/output/netflix_tr_head10000_shuf.user data/output/netflix_tr_head10000_shuf.item' -Dexec.classpathScope=compile
+mvn exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='4 5 3 data/input/predtest data/output/predtest.user data/output/predtest.item' -Dexec.classpathScope=compile
+
+mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_tr_rand.txt data/output/MovieLens/ml.user data/output/MovieLens/ml.item' -Dexec.classpathScope=compile
+
+mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_100k data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
+mvn compile exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='6041 3953 10 data/input/MovieLens/ml1m_te_rb.dat data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
View
2  pom.xml
@@ -32,7 +32,7 @@
<groupId>storm</groupId>
<artifactId>storm</artifactId>
<version>0.5.4</version>
- <scope>compile</scope>
+ <scope>provided</scope>
</dependency>
<dependency>
<groupId>log4j</groupId>
View
5 src/main/java/collabstream/streaming/Configuration.java
@@ -8,6 +8,7 @@
public final float userPenalty, itemPenalty, initialStepSize;
public final int maxTrainingIters;
public final String inputFilename, userOutputFilename, itemOutputFilename;
+ public final long inputDelay; // delay between sending two training examples in milliseconds
public final boolean debug;
private final int smallUserBlockSize, smallItemBlockSize;
@@ -17,7 +18,8 @@
public Configuration(int numUsers, int numItems, int numLatent, int numUserBlocks, int numItemBlocks,
float userPenalty, float itemPenalty, float initialStepSize, int maxTrainingIters,
- String inputFilename, String userOutputFilename, String itemOutputFilename, boolean debug) {
+ String inputFilename, String userOutputFilename, String itemOutputFilename,
+ long inputDelay, boolean debug) {
this.numUsers = numUsers;
this.numItems = numItems;
this.numLatent = numLatent;
@@ -30,6 +32,7 @@ public Configuration(int numUsers, int numItems, int numLatent, int numUserBlock
this.inputFilename = inputFilename;
this.userOutputFilename = userOutputFilename;
this.itemOutputFilename = itemOutputFilename;
+ this.inputDelay = inputDelay;
this.debug = debug;
smallUserBlockSize = numUsers / numUserBlocks;
View
29 src/main/java/collabstream/streaming/Master.java
@@ -35,7 +35,7 @@
private Queue<BlockPair> userBlockQueue = new LinkedList<BlockPair>();
private Queue<BlockPair> itemBlockQueue = new LinkedList<BlockPair>();
private boolean endOfData = false;
- private long startTime, outputStartTime;
+ private long startTime, outputStartTime = 0;
private final Random random = new Random();
public Master(Configuration config) {
@@ -58,7 +58,7 @@ public void execute(Tuple tuple) {
if (config.debug && msgType != END_OF_DATA && msgType != PROCESS_BLOCK_FIN) {
System.out.println("######## Master.execute: " + msgType + " " + tuple.getValue(1));
}
- TrainingExample ex;
+ TrainingExample ex, latest;
BlockPair bp, head;
switch (msgType) {
@@ -74,7 +74,11 @@ public void execute(Tuple tuple) {
ex = (TrainingExample)tuple.getValue(1);
int userBlockIdx = config.getUserBlockIdx(ex.userId);
int itemBlockIdx = config.getItemBlockIdx(ex.itemId);
- latestExample[userBlockIdx][itemBlockIdx] = ex;
+
+ latest = latestExample[userBlockIdx][itemBlockIdx];
+ if (latest == null || latest.timestamp < ex.timestamp) {
+ latestExample[userBlockIdx][itemBlockIdx] = ex;
+ }
bp = blockPair[userBlockIdx][itemBlockIdx];
if (bp == null) {
@@ -83,8 +87,8 @@ public void execute(Tuple tuple) {
freeSet.add(bp);
}
+ collector.emit(tuple, new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
collector.ack(tuple);
- collector.emit(new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
distributeWork();
break;
case PROCESS_BLOCK_FIN:
@@ -94,7 +98,7 @@ public void execute(Tuple tuple) {
System.out.println("######## Master.execute: " + msgType + " " + bp + " " + ex);
}
- TrainingExample latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
+ latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
if (latest.timestamp == ex.timestamp) {
latest.numTrainingIters = ex.numTrainingIters;
if (endOfData && latest.numTrainingIters >= config.maxTrainingIters) {
@@ -117,7 +121,7 @@ public void execute(Tuple tuple) {
float[][] userBlock = (float[][])tuple.getValue(2);
head = userBlockQueue.remove();
if (!head.equals(bp)) {
- System.err.println("######## Master.execute: Expected " + head + " for user block. Received " + bp);
+ throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + USER_BLOCK);
}
writeUserBlock(bp.userBlockIdx, userBlock);
requestNextUserBlock();
@@ -127,7 +131,7 @@ public void execute(Tuple tuple) {
float[][] itemBlock = (float[][])tuple.getValue(2);
head = itemBlockQueue.remove();
if (!head.equals(bp)) {
- System.err.println("######## Master.execute: Expected " + head + " for item block. Received " + bp);
+ throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + ITEM_BLOCK);
}
writeItemBlock(bp.itemBlockIdx, itemBlock);
requestNextItemBlock();
@@ -188,6 +192,7 @@ private void distributeWork() {
private void startOutput() {
try {
+ if (outputStartTime > 0) return;
outputStartTime = System.currentTimeMillis();
System.out.printf("######## Training finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
System.out.println("######## Elapsed training time: "
@@ -228,18 +233,24 @@ private void startOutput() {
private void writeUserBlock(int userBlockIdx, float[][] userBlock) {
int userBlockStart = config.getUserBlockStart(userBlockIdx);
for (int i = 0; i < userBlock.length; ++i) {
+ userOutput.print(userBlockStart + i);
for (int k = 0; k < config.numLatent; ++k) {
- userOutput.printf("%d %d %f\n", userBlockStart + i, k, userBlock[i][k]);
+ userOutput.print(' ');
+ userOutput.print(userBlock[i][k]);
}
+ userOutput.println();
}
}
private void writeItemBlock(int itemBlockIdx, float[][] itemBlock) {
int itemBlockStart = config.getItemBlockStart(itemBlockIdx);
for (int j = 0; j < itemBlock.length; ++j) {
+ itemOutput.print(itemBlockStart + j);
for (int k = 0; k < config.numLatent; ++k) {
- itemOutput.printf("%d %d %f\n", itemBlockStart + j, k, itemBlock[j][k]);
+ itemOutput.print(' ');
+ itemOutput.print(itemBlock[j][k]);
}
+ itemOutput.println();
}
}
View
16 src/main/java/collabstream/streaming/RatingsSource.java
@@ -14,6 +14,7 @@
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
+import backtype.storm.utils.Utils;
import static collabstream.streaming.MsgType.*;
@@ -50,7 +51,15 @@ public void ack(Object msgId) {
}
public void fail(Object msgId) {
- System.err.println("######## RatingsSource.fail: " + msgId);
+ if (config.debug) {
+ System.err.println("######## RatingsSource.fail: Resending " + msgId);
+ }
+ if (msgId == END_OF_DATA) {
+ collector.emit(new Values(END_OF_DATA, null), END_OF_DATA);
+ } else {
+ TrainingExample ex = (TrainingExample)msgId;
+ collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
+ }
}
public void nextTuple() {
@@ -72,8 +81,13 @@ public void nextTuple() {
int userId = Integer.parseInt(token[0]);
int itemId = Integer.parseInt(token[1]);
float rating = Float.parseFloat(token[2]);
+
TrainingExample ex = new TrainingExample(sequenceNum++, userId, itemId, rating);
collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
+
+ if (config.inputDelay > 0) {
+ Utils.sleep(config.inputDelay);
+ }
} catch (Exception e) {
System.err.println("######## RatingsSource.nextTuple: Could not parse line: " + line + "\n" + e);
}
View
12 src/main/java/collabstream/streaming/StreamingDSGD.java
@@ -12,9 +12,10 @@
public class StreamingDSGD {
public static void main(String[] args) throws Exception {
- if (args.length < 2) {
+ if (args.length < 6) {
System.err.println("######## Wrong number of arguments");
- System.err.println("######## required args: local|production fileName");
+ System.err.println("######## required args: local|production numUsers numItems"
+ + " inputFilename userOutputFilename itemOutputFilename");
return;
}
@@ -33,22 +34,25 @@ public static void main(String[] args) throws Exception {
int numItemBlocks = Integer.parseInt(props.getProperty("numItemBlocks", "10"));
float userPenalty = Float.parseFloat(props.getProperty("userPenalty", "0.1"));
float itemPenalty = Float.parseFloat(props.getProperty("itemPenalty", "0.1"));
- float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "1"));
+ float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "0.1"));
int maxTrainingIters = Integer.parseInt(props.getProperty("maxTrainingIters", "30"));
String inputFilename = args[3];
String userOutputFilename = args[4];
String itemOutputFilename = args[5];
+ long inputDelay = Long.parseLong(props.getProperty("inputDelay", "0"));
boolean debug = Boolean.parseBoolean(props.getProperty("debug", "false"));
Configuration config = new Configuration(numUsers, numItems, numLatent, numUserBlocks, numItemBlocks,
userPenalty, itemPenalty, initialStepSize, maxTrainingIters,
- inputFilename, userOutputFilename, itemOutputFilename, debug);
+ inputFilename, userOutputFilename, itemOutputFilename,
+ inputDelay, debug);
Config stormConfig = new Config();
stormConfig.addSerialization(TrainingExample.Serialization.class);
stormConfig.addSerialization(BlockPair.Serialization.class);
stormConfig.addSerialization(MatrixSerialization.class);
stormConfig.setNumWorkers(config.getNumProcesses());
+ stormConfig.setNumAckers(config.getNumWorkers()); // our notion of a worker is different from Storm's
TopologyBuilder builder = new TopologyBuilder();
builder.setSpout(1, new RatingsSource(config));
View
190 src/main/java/collabstream/streaming/TestPredictions.java
@@ -0,0 +1,190 @@
+package collabstream.streaming;
+
+import java.io.FileReader;
+import java.io.LineNumberReader;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.lang.time.DurationFormatUtils;
+
+/**
+ * Command-line tool that computes the RMSE of predicted ratings on a test file.
+ *
+ * <p>Training and test files hold one rating per line as space-separated
+ * {@code userId itemId rating}. The user and item files hold one row per line as
+ * {@code id f1 f2 ... fNumLatent}. For each test rating the prediction is:
+ * <ul>
+ *   <li>the dot product of the user and item rows when both ids occur in the training file;</li>
+ *   <li>the user's mean training rating when only the user occurs in the training file;</li>
+ *   <li>the item's mean training rating when only the item occurs in the training file;</li>
+ *   <li>the mean of all training ratings otherwise.</li>
+ * </ul>
+ * Lines that cannot be processed are reported on stderr and skipped.
+ */
+public class TestPredictions {
+    /**
+     * Entry point.
+     *
+     * @param args numUsers numItems numLatent trainingFilename testFilename userFilename itemFilename
+     * @throws Exception if one of the input files cannot be opened or read
+     */
+    public static void main(String[] args) throws Exception {
+        if (args.length < 7) {
+            System.err.println("######## Wrong number of arguments");
+            System.err.println("######## required args: numUsers numItems numLatent"
+                + " trainingFilename testFilename userFilename itemFilename");
+            return;
+        }
+
+        long testStartTime = System.currentTimeMillis();
+        logTime("Testing started", testStartTime);
+
+        int numUsers = Integer.parseInt(args[0]);
+        int numItems = Integer.parseInt(args[1]);
+        int numLatent = Integer.parseInt(args[2]);
+        String trainingFilename = args[3];
+        String testFilename = args[4];
+        String userFilename = args[5];
+        String itemFilename = args[6];
+
+        // Accumulate in double: a float running sum loses precision once it grows large.
+        double trainingTotal = 0.0;
+        int trainingCount = 0;
+
+        Map<Integer, Integer> userCount = new HashMap<Integer, Integer>();
+        Map<Integer, Integer> itemCount = new HashMap<Integer, Integer>();
+        Map<Integer, Float> userTotal = new HashMap<Integer, Float>();
+        Map<Integer, Float> itemTotal = new HashMap<Integer, Float>();
+
+        long startTime = System.currentTimeMillis();
+        logTime("Started reading training file", startTime);
+
+        String line;
+        LineNumberReader in = new LineNumberReader(new FileReader(trainingFilename));
+        try {
+            while ((line = in.readLine()) != null) {
+                try {
+                    String[] token = StringUtils.split(line, ' ');
+                    int i = Integer.parseInt(token[0]);
+                    int j = Integer.parseInt(token[1]);
+                    float rating = Float.parseFloat(token[2]);
+
+                    trainingTotal += rating;
+                    ++trainingCount;
+
+                    Integer seenForUser = userCount.get(i);
+                    if (seenForUser != null) {
+                        userCount.put(i, seenForUser + 1);
+                        userTotal.put(i, userTotal.get(i) + rating);
+                    } else {
+                        userCount.put(i, 1);
+                        userTotal.put(i, rating);
+                    }
+
+                    Integer seenForItem = itemCount.get(j);
+                    if (seenForItem != null) {
+                        itemCount.put(j, seenForItem + 1);
+                        itemTotal.put(j, itemTotal.get(j) + rating);
+                    } else {
+                        itemCount.put(j, 1);
+                        itemTotal.put(j, rating);
+                    }
+                } catch (Exception e) {
+                    System.err.printf("######## Could not parse line %d in %s\n%s\n",
+                        in.getLineNumber(), trainingFilename, e);
+                }
+            }
+        } finally {
+            in.close();
+        }
+
+        // With no usable training rating the global mean is undefined; fall back to 0
+        // instead of letting NaN leak into every fallback prediction.
+        float trainingAvg = (trainingCount > 0) ? (float) (trainingTotal / trainingCount) : 0.0f;
+
+        long endTime = System.currentTimeMillis();
+        logTime("Finished reading training file", endTime);
+        logElapsed("Time elapsed reading training file", startTime, endTime);
+
+        // Java zero-initialises new arrays, so rows missing from the files stay all-zero.
+        float[][] userMatrix = new float[numUsers][numLatent];
+        float[][] itemMatrix = new float[numItems][numLatent];
+
+        readFactorMatrix("user", userFilename, userMatrix, numLatent);
+        readFactorMatrix("item", itemFilename, itemMatrix, numLatent);
+
+        startTime = System.currentTimeMillis();
+        logTime("Started reading test file", startTime);
+
+        double totalSqErr = 0.0;
+        int numRatings = 0;
+
+        in = new LineNumberReader(new FileReader(testFilename));
+        try {
+            while ((line = in.readLine()) != null) {
+                try {
+                    String[] token = StringUtils.split(line, ' ');
+                    int i = Integer.parseInt(token[0]);
+                    int j = Integer.parseInt(token[1]);
+                    float rating = Float.parseFloat(token[2]);
+                    float prediction;
+
+                    boolean userKnown = userCount.containsKey(i);
+                    boolean itemKnown = itemCount.containsKey(j);
+
+                    if (userKnown && itemKnown) {
+                        prediction = 0.0f;
+                        for (int k = 0; k < numLatent; ++k) {
+                            prediction += userMatrix[i][k] * itemMatrix[j][k];
+                        }
+                    } else if (userKnown) {
+                        prediction = userTotal.get(i) / userCount.get(i);
+                    } else if (itemKnown) {
+                        prediction = itemTotal.get(j) / itemCount.get(j);
+                    } else {
+                        prediction = trainingAvg;
+                    }
+
+                    float diff = prediction - rating;
+                    totalSqErr += (double) diff * diff;
+                    ++numRatings;
+                } catch (Exception e) {
+                    System.err.printf("######## Could not parse line %d in %s\n%s\n",
+                        in.getLineNumber(), testFilename, e);
+                }
+            }
+        } finally {
+            // The original never closed the test-file reader.
+            in.close();
+        }
+
+        if (numRatings == 0) {
+            System.err.println("######## No usable ratings in " + testFilename + "; RMSE is undefined");
+        }
+        double rmse = Math.sqrt(totalSqErr / numRatings);
+
+        endTime = System.currentTimeMillis();
+        logTime("Finished reading test file", endTime);
+        logElapsed("Time elapsed reading test file", startTime, endTime);
+        logElapsed("Total elapsed testing time", testStartTime, endTime);
+        System.out.println("######## Number of ratings used: " + numRatings);
+        System.out.println("######## RMSE: " + rmse);
+    }
+
+    /**
+     * Fills {@code matrix} from a file whose lines are {@code id f1 ... fNumLatent},
+     * printing start/finish timestamps and the elapsed time.
+     *
+     * @param kind      word used in the progress messages ("user" or "item")
+     * @param filename  file to read
+     * @param matrix    destination; the row index is the id found at the start of each line
+     * @param numLatent number of values read per row
+     * @throws java.io.IOException if the file cannot be opened or read
+     */
+    private static void readFactorMatrix(String kind, String filename, float[][] matrix, int numLatent)
+            throws java.io.IOException {
+        long startTime = System.currentTimeMillis();
+        logTime("Started reading " + kind + " file", startTime);
+
+        LineNumberReader in = new LineNumberReader(new FileReader(filename));
+        try {
+            String line;
+            while ((line = in.readLine()) != null) {
+                try {
+                    String[] token = StringUtils.split(line, ' ');
+                    int row = Integer.parseInt(token[0]);
+                    for (int k = 0; k < numLatent; ++k) {
+                        matrix[row][k] = Float.parseFloat(token[k+1]);
+                    }
+                } catch (Exception e) {
+                    // Also reached when the id is outside the matrix or the row is too short.
+                    System.err.printf("######## Could not parse line %d in %s\n%s\n",
+                        in.getLineNumber(), filename, e);
+                }
+            }
+        } finally {
+            in.close();
+        }
+
+        long endTime = System.currentTimeMillis();
+        logTime("Finished reading " + kind + " file", endTime);
+        logElapsed("Time elapsed reading " + kind + " file", startTime, endTime);
+    }
+
+    /** Prints {@code "######## <what>: <date> <time> <zone>"} for the given epoch milliseconds. */
+    private static void logTime(String what, long millis) {
+        System.out.printf("######## " + what + ": %1$tY-%1$tb-%1$td %1$tT %tZ\n", millis);
+    }
+
+    /** Prints {@code "######## <what>: H:m:s (h:m:s)"} for the interval between the two timestamps. */
+    private static void logElapsed(String what, long startMillis, long endMillis) {
+        System.out.println("######## " + what + ": "
+            + DurationFormatUtils.formatPeriod(startMillis, endMillis, "H:m:s") + " (h:m:s)");
+    }
+}
View
6 src/main/java/collabstream/streaming/Worker.java
@@ -50,6 +50,7 @@ public void execute(Tuple tuple) {
bp = new BlockPair(config.getUserBlockIdx(ex.userId), config.getItemBlockIdx(ex.itemId));
workingBlock = getWorkingBlock(bp);
workingBlock.examples.add(ex);
+ collector.ack(tuple);
break;
case PROCESS_BLOCK_REQ:
bp = (BlockPair)tuple.getValue(1);
@@ -143,9 +144,10 @@ private void update(BlockPair bp, WorkingBlock workingBlock) {
float[][] userBlock = workingBlock.userBlock;
float[][] itemBlock = workingBlock.itemBlock;
- PermutationUtils.permute(workingBlock.examples);
+ TrainingExample[] examples = workingBlock.examples.toArray(new TrainingExample[workingBlock.examples.size()]);
+ PermutationUtils.permute(examples);
- for (TrainingExample ex : workingBlock.examples) {
+ for (TrainingExample ex : examples) {
if (ex.numTrainingIters >= config.maxTrainingIters) continue;
int i = ex.userId - userBlockStart;
int j = ex.itemId - itemBlockStart;
View
5 src/main/java/collabstream/streaming/WorkingBlock.java
@@ -1,10 +1,11 @@
package collabstream.streaming;
import java.io.Serializable;
-import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Set;
public class WorkingBlock implements Serializable {
- public final ArrayList<TrainingExample> examples = new ArrayList<TrainingExample>();
+ public final Set<TrainingExample> examples = new HashSet<TrainingExample>();
public float[][] userBlock = null;
public float[][] itemBlock = null;
public boolean waitingForBlocks = false;
View
9 src/main/resources/log4j.properties
@@ -0,0 +1,9 @@
+log4j.rootLogger=INFO, A1
+log4j.appender.A1=org.apache.log4j.ConsoleAppender
+log4j.appender.A1.layout=org.apache.log4j.PatternLayout
+log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
+
+log4j.logger.backtype.storm.daemon=WARN
+log4j.logger.backtype.storm.serialization=WARN
+log4j.logger.backtype.storm.zookeeper=WARN
+log4j.logger.org.apache.zookeeper=ERROR
Please sign in to comment.
Something went wrong with that request. Please try again.