Permalink
Browse files

Merge branch 'master' of github.com:MrChrisJohnson/CollabStream

  • Loading branch information...
MrChrisJohnson committed Dec 7, 2011
2 parents e4923ba + b7710ae commit e048e29d8ee6bca0416aef9c8368db8f79e9adce
View
@@ -10,4 +10,9 @@ mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args=
mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile
mvn compile exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 17 28 data/input/test.dat data/output/test.user data/output/test.item' -Dexec.classpathScope=compile | grep '########'
-mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 9 2648862 data/input/netflix_tr_head10000_shuf data/output/netflix_tr_head10000_shuf.user data/output/netflix_tr_head10000_shuf.item' -Dexec.classpathScope=compile
+mvn exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='4 5 3 data/input/predtest data/output/predtest.user data/output/predtest.item' -Dexec.classpathScope=compile
+
+mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_tr_rand.txt data/output/MovieLens/ml.user data/output/MovieLens/ml.item' -Dexec.classpathScope=compile
+
+mvn exec:java -Dexec.mainClass=collabstream.streaming.StreamingDSGD -Dexec.args='local 6041 3953 data/input/MovieLens/ml_100k data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
+mvn compile exec:java -Dexec.mainClass=collabstream.streaming.TestPredictions -Dexec.args='6041 3953 10 data/input/MovieLens/ml1m_te_rb.dat data/output/MovieLens/ml_100k.user data/output/MovieLens/ml_100k.item' -Dexec.classpathScope=compile
View
@@ -32,7 +32,7 @@
<groupId>storm</groupId>
<artifactId>storm</artifactId>
<version>0.5.4</version>
- <scope>compile</scope>
+ <scope>provided</scope>
</dependency>
<dependency>
<groupId>log4j</groupId>
@@ -8,6 +8,7 @@
public final float userPenalty, itemPenalty, initialStepSize;
public final int maxTrainingIters;
public final String inputFilename, userOutputFilename, itemOutputFilename;
+ public final long inputDelay; // delay between sending two training examples in milliseconds
public final boolean debug;
private final int smallUserBlockSize, smallItemBlockSize;
@@ -17,7 +18,8 @@
public Configuration(int numUsers, int numItems, int numLatent, int numUserBlocks, int numItemBlocks,
float userPenalty, float itemPenalty, float initialStepSize, int maxTrainingIters,
- String inputFilename, String userOutputFilename, String itemOutputFilename, boolean debug) {
+ String inputFilename, String userOutputFilename, String itemOutputFilename,
+ long inputDelay, boolean debug) {
this.numUsers = numUsers;
this.numItems = numItems;
this.numLatent = numLatent;
@@ -30,6 +32,7 @@ public Configuration(int numUsers, int numItems, int numLatent, int numUserBlock
this.inputFilename = inputFilename;
this.userOutputFilename = userOutputFilename;
this.itemOutputFilename = itemOutputFilename;
+ this.inputDelay = inputDelay;
this.debug = debug;
smallUserBlockSize = numUsers / numUserBlocks;
@@ -35,7 +35,7 @@
private Queue<BlockPair> userBlockQueue = new LinkedList<BlockPair>();
private Queue<BlockPair> itemBlockQueue = new LinkedList<BlockPair>();
private boolean endOfData = false;
- private long startTime, outputStartTime;
+ private long startTime, outputStartTime = 0;
private final Random random = new Random();
public Master(Configuration config) {
@@ -58,7 +58,7 @@ public void execute(Tuple tuple) {
if (config.debug && msgType != END_OF_DATA && msgType != PROCESS_BLOCK_FIN) {
System.out.println("######## Master.execute: " + msgType + " " + tuple.getValue(1));
}
- TrainingExample ex;
+ TrainingExample ex, latest;
BlockPair bp, head;
switch (msgType) {
@@ -74,7 +74,11 @@ public void execute(Tuple tuple) {
ex = (TrainingExample)tuple.getValue(1);
int userBlockIdx = config.getUserBlockIdx(ex.userId);
int itemBlockIdx = config.getItemBlockIdx(ex.itemId);
- latestExample[userBlockIdx][itemBlockIdx] = ex;
+
+ latest = latestExample[userBlockIdx][itemBlockIdx];
+ if (latest == null || latest.timestamp < ex.timestamp) {
+ latestExample[userBlockIdx][itemBlockIdx] = ex;
+ }
bp = blockPair[userBlockIdx][itemBlockIdx];
if (bp == null) {
@@ -83,8 +87,8 @@ public void execute(Tuple tuple) {
freeSet.add(bp);
}
+ collector.emit(tuple, new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
collector.ack(tuple);
- collector.emit(new Values(TRAINING_EXAMPLE, null, ex, userBlockIdx));
distributeWork();
break;
case PROCESS_BLOCK_FIN:
@@ -94,7 +98,7 @@ public void execute(Tuple tuple) {
System.out.println("######## Master.execute: " + msgType + " " + bp + " " + ex);
}
- TrainingExample latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
+ latest = latestExample[bp.userBlockIdx][bp.itemBlockIdx];
if (latest.timestamp == ex.timestamp) {
latest.numTrainingIters = ex.numTrainingIters;
if (endOfData && latest.numTrainingIters >= config.maxTrainingIters) {
@@ -117,7 +121,7 @@ public void execute(Tuple tuple) {
float[][] userBlock = (float[][])tuple.getValue(2);
head = userBlockQueue.remove();
if (!head.equals(bp)) {
- System.err.println("######## Master.execute: Expected " + head + " for user block. Received " + bp);
+ throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + USER_BLOCK);
}
writeUserBlock(bp.userBlockIdx, userBlock);
requestNextUserBlock();
@@ -127,7 +131,7 @@ public void execute(Tuple tuple) {
float[][] itemBlock = (float[][])tuple.getValue(2);
head = itemBlockQueue.remove();
if (!head.equals(bp)) {
- System.err.println("######## Master.execute: Expected " + head + " for item block. Received " + bp);
+ throw new RuntimeException("Expected " + head + ", but received " + bp + " for " + ITEM_BLOCK);
}
writeItemBlock(bp.itemBlockIdx, itemBlock);
requestNextItemBlock();
@@ -188,6 +192,7 @@ private void distributeWork() {
private void startOutput() {
try {
+ if (outputStartTime > 0) return;
outputStartTime = System.currentTimeMillis();
System.out.printf("######## Training finished: %1$tY-%1$tb-%1$td %1$tT %tZ\n", outputStartTime);
System.out.println("######## Elapsed training time: "
@@ -228,18 +233,24 @@ private void startOutput() {
private void writeUserBlock(int userBlockIdx, float[][] userBlock) {
int userBlockStart = config.getUserBlockStart(userBlockIdx);
for (int i = 0; i < userBlock.length; ++i) {
+ userOutput.print(userBlockStart + i);
for (int k = 0; k < config.numLatent; ++k) {
- userOutput.printf("%d %d %f\n", userBlockStart + i, k, userBlock[i][k]);
+ userOutput.print(' ');
+ userOutput.print(userBlock[i][k]);
}
+ userOutput.println();
}
}
private void writeItemBlock(int itemBlockIdx, float[][] itemBlock) {
int itemBlockStart = config.getItemBlockStart(itemBlockIdx);
for (int j = 0; j < itemBlock.length; ++j) {
+ itemOutput.print(itemBlockStart + j);
for (int k = 0; k < config.numLatent; ++k) {
- itemOutput.printf("%d %d %f\n", itemBlockStart + j, k, itemBlock[j][k]);
+ itemOutput.print(' ');
+ itemOutput.print(itemBlock[j][k]);
}
+ itemOutput.println();
}
}
@@ -14,6 +14,7 @@
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
+import backtype.storm.utils.Utils;
import static collabstream.streaming.MsgType.*;
@@ -50,7 +51,15 @@ public void ack(Object msgId) {
}
public void fail(Object msgId) {
- System.err.println("######## RatingsSource.fail: " + msgId);
+ if (config.debug) {
+ System.err.println("######## RatingsSource.fail: Resending " + msgId);
+ }
+ if (msgId == END_OF_DATA) {
+ collector.emit(new Values(END_OF_DATA, null), END_OF_DATA);
+ } else {
+ TrainingExample ex = (TrainingExample)msgId;
+ collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
+ }
}
public void nextTuple() {
@@ -72,8 +81,13 @@ public void nextTuple() {
int userId = Integer.parseInt(token[0]);
int itemId = Integer.parseInt(token[1]);
float rating = Float.parseFloat(token[2]);
+
TrainingExample ex = new TrainingExample(sequenceNum++, userId, itemId, rating);
collector.emit(new Values(TRAINING_EXAMPLE, ex), ex);
+
+ if (config.inputDelay > 0) {
+ Utils.sleep(config.inputDelay);
+ }
} catch (Exception e) {
System.err.println("######## RatingsSource.nextTuple: Could not parse line: " + line + "\n" + e);
}
@@ -12,9 +12,10 @@
public class StreamingDSGD {
public static void main(String[] args) throws Exception {
- if (args.length < 2) {
+ if (args.length < 6) {
System.err.println("######## Wrong number of arguments");
- System.err.println("######## required args: local|production fileName");
+ System.err.println("######## required args: local|production numUsers numItems"
+ + " inputFilename userOutputFilename itemOutputFilename");
return;
}
@@ -33,22 +34,25 @@ public static void main(String[] args) throws Exception {
int numItemBlocks = Integer.parseInt(props.getProperty("numItemBlocks", "10"));
float userPenalty = Float.parseFloat(props.getProperty("userPenalty", "0.1"));
float itemPenalty = Float.parseFloat(props.getProperty("itemPenalty", "0.1"));
- float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "1"));
+ float initialStepSize = Float.parseFloat(props.getProperty("initialStepSize", "0.1"));
int maxTrainingIters = Integer.parseInt(props.getProperty("maxTrainingIters", "30"));
String inputFilename = args[3];
String userOutputFilename = args[4];
String itemOutputFilename = args[5];
+ long inputDelay = Long.parseLong(props.getProperty("inputDelay", "0"));
boolean debug = Boolean.parseBoolean(props.getProperty("debug", "false"));
Configuration config = new Configuration(numUsers, numItems, numLatent, numUserBlocks, numItemBlocks,
userPenalty, itemPenalty, initialStepSize, maxTrainingIters,
- inputFilename, userOutputFilename, itemOutputFilename, debug);
+ inputFilename, userOutputFilename, itemOutputFilename,
+ inputDelay, debug);
Config stormConfig = new Config();
stormConfig.addSerialization(TrainingExample.Serialization.class);
stormConfig.addSerialization(BlockPair.Serialization.class);
stormConfig.addSerialization(MatrixSerialization.class);
stormConfig.setNumWorkers(config.getNumProcesses());
+ stormConfig.setNumAckers(config.getNumWorkers()); // our notion of a worker is different from Storm's
TopologyBuilder builder = new TopologyBuilder();
builder.setSpout(1, new RatingsSource(config));
Oops, something went wrong.

0 comments on commit e048e29

Please sign in to comment.