Merge pull request #25 from zhanglistar/master
1. Fix "cannot find symbol" compile errors when using Java and Scala together.
takun2s committed Mar 21, 2018
2 parents ed5e12c + 98aaa56 commit 65cfedb
Showing 11 changed files with 57 additions and 10 deletions.
2 changes: 1 addition & 1 deletion core/pom.xml
@@ -3,7 +3,7 @@
<parent>
<artifactId>fregata</artifactId>
<groupId>com.talkingdata.fregata</groupId>
-<version>0.0.4-SNAPSHOT</version>
+<version>0.0.4</version>
<relativePath>../pom.xml</relativePath>
</parent>
<modelVersion>4.0.0</modelVersion>
1 change: 1 addition & 0 deletions core/src/main/scala/fregata/model/ModelTrainer.scala
@@ -26,4 +26,5 @@ trait ModelTrainer extends Serializable{

def run(data:Iterable[(Vector,Num)]) : M

+def loadModel(fn: String): Int = 0
}
@@ -20,4 +20,7 @@ trait ClassificationModel extends Model {
val (p,c) = classPredict(x)
(a,(p,c))
}

+def saveModel(fn: String): Int = 0
+def loadModel(fn: String): Int = 0
}
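Like the loadModel stub added to ModelTrainer above, these no-op defaults make persistence opt-in: models that never save or load still compile unchanged, while LogisticRegressionModel below overrides saveModel with a real file writer.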
@@ -1,11 +1,15 @@
package fregata.model.classification

+//import breeze.io.TextWriter.FileWriter
import fregata._
import fregata.model.{Model, ModelTrainer}
import fregata.optimize.sgd.{AdaptiveSGD, StochasticGradientDescent}
import fregata.optimize.{Gradient, Target}
import fregata.param.ParameterServer
import fregata.util.VectorUtil
+import java.io.{FileNotFoundException, FileWriter, IOException}
+
+import scala.io.Source

/**
* The greedy step averaging(GSA) method, a
@@ -68,6 +72,18 @@ class LogisticRegressionModel(val weights:Vector) extends ClassificationModel{
val c = if( p > threshold ) 1.0 else 0.0
(asNum(p),asNum(c))
}
+// write the weight count, then one weight per line; returns 0 on success, -1 on failure
+override def saveModel(filename: String): Int = {
+  try {
+    val outFile = new FileWriter(filename, false)
+    outFile.write(weights.size.toString + '\n')
+    weights.toArray.foreach((x: Double) => outFile.write(x.toString + '\n'))
+    outFile.flush()
+    outFile.close()
+    0
+  } catch {
+    // new FileWriter never returns null; I/O failures surface as exceptions instead
+    case _: IOException => -1
+  }
+}
}

class LogisticRegression extends ModelTrainer {
@@ -82,4 +98,21 @@ class LogisticRegression extends ModelTrainer {
.run(data)
new LogisticRegressionModel(ps.get(0))
}

+// read the format written by saveModel: weight count first, then one weight per line
+override def loadModel(fn: String): Int = {
+  var i = 0
+  var ret = 0
+  var lastWeights: Array[Num] = null
+  try {
+    val src = Source.fromFile(fn)
+    for (line <- src.getLines()) {
+      if (i == 0) lastWeights = new Array[Num](line.toInt)
+      else lastWeights(i - 1) = line.toDouble // toDouble, not toFloat: the weights are written as Doubles, so toFloat would silently lose precision
+      i += 1
+    }
+    src.close() // close the source rather than leaking the file handle
+  } catch {
+    case _: FileNotFoundException => ret = -1
+  }
+  if (lastWeights != null) ps.set(Array.fill(1){ new DenseVector(lastWeights) })
+  ret
+}
}
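Together, the two new hooks give the core trainer simple file persistence: saveModel writes the weight count on the first line and one weight per line after it, and loadModel reads that format back into the trainer's parameter server. A minimal round-trip sketch, assuming fregata's package aliases (Vector, Num, DenseVector, asNum) as used in the diff, with hypothetical paths and toy data:

import fregata._
import fregata.model.classification.LogisticRegression

// two toy instances of (feature vector, label)
val data: Iterable[(Vector, Num)] = Seq(
  (new DenseVector(Array[Num](1.0, 0.0, 1.0)), asNum(1.0)),
  (new DenseVector(Array[Num](0.0, 1.0, 0.0)), asNum(0.0))
)

val trainer = new LogisticRegression
val model = trainer.run(data)
model.saveModel("/tmp/lr.weights") // hypothetical path; 0 on success, -1 on failure

// warm-start a fresh trainer from the saved weights
val resumed = new LogisticRegression
if (resumed.loadModel("/tmp/lr.weights") == 0) {
  val warmModel = resumed.run(data) // continues from the restored parameter vector
}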
4 changes: 3 additions & 1 deletion pom.xml
@@ -6,7 +6,7 @@
<groupId>com.talkingdata.fregata</groupId>
<artifactId>fregata</artifactId>
<packaging>pom</packaging>
-<version>0.0.4-SNAPSHOT</version>
+<version>0.0.4</version>
<modules>
<module>core</module>
<module>spark</module>
@@ -222,6 +222,8 @@
<version>3.2.2</version>
<executions>
<execution>
+<id>scala-compile-first</id>
+<phase>process-resources</phase>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
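This is the standard scala-maven-plugin recipe for mixed Java/Scala modules: binding the plugin's compile and testCompile goals to the early process-resources phase runs the Scala compiler, which understands both languages, before maven-compiler-plugin compiles the Java sources, so Java code referencing Scala classes no longer fails with "cannot find symbol".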
4 changes: 2 additions & 2 deletions ps/pom.xml
@@ -5,7 +5,7 @@
<parent>
<artifactId>fregata</artifactId>
<groupId>com.talkingdata.fregata</groupId>
-<version>0.0.4-SNAPSHOT</version>
+<version>0.0.4</version>
</parent>
<modelVersion>4.0.0</modelVersion>

@@ -60,4 +60,4 @@
</dependency>
</dependencies>

-</project>
\ No newline at end of file
+</project>
4 changes: 2 additions & 2 deletions spark/pom.xml
@@ -3,7 +3,7 @@
<parent>
<artifactId>fregata</artifactId>
<groupId>com.talkingdata.fregata</groupId>
-<version>0.0.4-SNAPSHOT</version>
+<version>0.0.4</version>
<relativePath>../pom.xml</relativePath>
</parent>
<modelVersion>4.0.0</modelVersion>
@@ -17,7 +17,7 @@
<dependency>
<groupId>com.talkingdata.fregata</groupId>
<artifactId>core</artifactId>
-<version>0.0.4-SNAPSHOT</version>
+<version>0.0.4</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
1 change: 1 addition & 0 deletions spark/src/main/scala/fregata/spark/model/SparkModel.scala
@@ -30,4 +30,5 @@ trait SparkModel extends Model {
def predict(data:RDD[Vector]) = {
predictPartition[Vector,Num](data,(x,model) => model.predict(x) )
}

}
@@ -33,4 +33,8 @@ trait ClassificationModel extends SparkModel{
case ((x,label),model:LClassificationModel) => model.classPredict(x)
})
}

+def saveModel(fn: String): Int = {
+  model.saveModel(fn)
+}
}
@@ -23,8 +23,10 @@ object LogisticRegression {
*/
def run(data:RDD[(Vector,Num)],
localEpochNum:Int = 1 ,
-epochNum:Int = 1) = {
+epochNum:Int = 1,
+lastModel: String = "") = {
val trainer = new LLogisticRegression
+trainer.loadModel(lastModel)
new SparkTrainer(trainer)
.run(data,epochNum,localEpochNum)
new LogisticRegressionModel(trainer.buildModel(trainer.ps))
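On the Spark side, run now accepts an optional lastModel path that is handed to the local trainer before training starts; with the empty-string default, loadModel returns -1 without touching the parameter server, so existing call sites still cold-start. A usage sketch, with hypothetical paths and trainData: RDD[(Vector, Num)] assumed in scope:

import fregata.spark.model.classification.LogisticRegression

// first run: cold start, then persist the learned weights
val model = LogisticRegression.run(trainData, localEpochNum = 1, epochNum = 1)
model.saveModel("/tmp/lr.weights") // hypothetical path

// later run: warm-start from the previous weights
val resumed = LogisticRegression.run(trainData, lastModel = "/tmp/lr.weights")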
@@ -14,9 +14,9 @@ object TestLogisticRegression {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("logistic regression")
val sc = new SparkContext(conf)
-val (_,trainData) = LibSvmReader.read(sc,"/Volumes/takun/data/libsvm/a9a",123)
-val (_,testData) = LibSvmReader.read(sc,"/Volumes/takun/data/libsvm/a9a.t",123)
-val model = LogisticRegression.run(trainData)
+val (_,trainData) = LibSvmReader.read(sc, args(0), Integer.parseInt(args(1)))
+val (_,testData) = LibSvmReader.read(sc, args(2), Integer.parseInt(args(1)))
+val model = LogisticRegression.run(trainData, lastModel = args(3))
val pd = model.classPredict(testData)
val acc = Accuracy.of( pd.map{
case ((x,l),(p,c)) =>
@@ -37,5 +37,6 @@
println( s"AreaUnderRoc = $auc ")
println( s"Accuracy = $acc ")
println( s"LogLoss = $loss ")
+model.saveModel(args(3))
}
}
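The example now reads its inputs from the command line instead of hard-coded local paths. Read against the code above, the expected argument order appears to be:

// args(0) = training set in LibSVM format
// args(1) = feature count, shared by the train and test readers
// args(2) = test set in LibSVM format
// args(3) = model file: loaded as a warm start before training, overwritten with the new weights afterwards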
