Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
10,961 additions
and
172 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
...ngel/examples/src/main/scala/com/tencent/angel/spark/examples/local/Word2vecExample.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
package com.tencent.angel.spark.examples.local | ||
|
||
import scala.util.Random | ||
import com.tencent.angel.spark.context.PSContext | ||
import com.tencent.angel.spark.ml.embedding.Param | ||
import com.tencent.angel.spark.ml.embedding.word2vec.Word2VecModel | ||
import com.tencent.angel.spark.ml.feature.Features | ||
import org.apache.spark.{SparkConf, SparkContext} | ||
|
||
/**
 * Local-mode example that trains a word2vec embedding with Spark on Angel.
 *
 * The corpus is read as text, converted to integer word ids by
 * `Features.corpusStringToInt`, and handed to [[Word2VecModel]] together with
 * the hyper-parameters configured below.
 */
object Word2vecExample {

  /** Corpus used when no path is supplied on the command line. */
  private val DefaultInput = "data/text8/text8.split.head"

  /**
   * Runs the example.
   *
   * @param args optional; `args(0)` overrides the corpus path, otherwise
   *             [[DefaultInput]] is read
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setMaster("local[1]")
    conf.setAppName("Word2vec example")
    conf.set("spark.ps.model", "LOCAL")
    conf.set("spark.ps.jars", "")
    conf.set("spark.ps.instances", "1")
    conf.set("spark.ps.cores", "1")

    val sc = new SparkContext(conf)
    try {
      sc.setLogLevel("ERROR")
      PSContext.getOrCreate(sc)

      val input = args.headOption.getOrElse(DefaultInput)
      val (corpus, _) = Features.corpusStringToInt(sc.textFile(input))
      val docs = corpus.repartition(2)
      docs.cache()

      val numDocs = docs.count()
      // `max` throws on an empty document (e.g. a blank input line), so skip those
      // when looking for the largest word id; `docs` itself is left untouched.
      val maxWordId = docs.filter(_.nonEmpty).map(_.max).max().toLong + 1
      val numTokens = docs.map(_.length).sum().toLong

      println(s"numDocs=$numDocs maxWordId=$maxWordId numTokens=$numTokens")

      val param = new Param()
      param.setLearningRate(0.001f)
      param.setEmbeddingDim(100)
      param.setWindowSize(10)
      param.setBatchSize(128)
      param.setNodesNumPerRow(Some(100))
      // A fresh random seed on every run: results are intentionally not reproducible.
      param.setSeed(Random.nextInt())
      param.setNumPSPart(Some(2))
      param.setNumEpoch(10)
      param.setNegSample(5)
      param.setMaxIndex(maxWordId)
      param.setNumRowDataSet(numDocs)

      val model = new Word2VecModel(param)
      model.train(docs, param)
    } finally {
      // Release Spark resources even when loading or training fails.
      // NOTE(review): the parameter server created by PSContext.getOrCreate is not
      // shut down explicitly here; confirm whether PSContext needs its own stop call.
      sc.stop()
    }
  }

}
190 changes: 134 additions & 56 deletions
190
...k-on-angel/mllib/src/main/java/com/tencent/angel/spark/ml/psf/embedding/CBOW/CBOWDot.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,153 @@ | ||
package com.tencent.angel.spark.ml.psf.embedding.CBOW; | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
package com.tencent.angel.spark.ml.psf.embedding.cbow; | ||
|
||
import com.tencent.angel.PartitionKey; | ||
import com.tencent.angel.ml.matrix.psf.get.base.*; | ||
import io.netty.buffer.ByteBuf; | ||
import com.tencent.angel.exception.AngelException; | ||
import com.tencent.angel.ml.math2.storage.IntFloatDenseVectorStorage; | ||
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc; | ||
import com.tencent.angel.ml.matrix.psf.get.base.GetResult; | ||
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam; | ||
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult; | ||
import com.tencent.angel.ps.storage.matrix.ServerPartition; | ||
import com.tencent.angel.spark.ml.psf.embedding.NENegativeSample; | ||
import com.tencent.angel.spark.ml.psf.embedding.ServerSentences; | ||
import it.unimi.dsi.fastutil.floats.FloatArrayList; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Random; | ||
|
||
public class CBOWDot extends GetFunc { | ||
|
||
class CBOWDotPartitionParam extends PartitionGetParam { | ||
|
||
private int seed; | ||
private int negative; | ||
private int window; | ||
private int partDim; | ||
private int[][] sentences; | ||
|
||
public CBOWDotPartitionParam(int matrixId, | ||
int seed, | ||
int negative, | ||
int window, | ||
int partDim, | ||
int[][] sentences, | ||
PartitionKey pkey) { | ||
super(matrixId, pkey); | ||
this.seed = seed; | ||
this.negative = negative; | ||
this.window = window; | ||
this.partDim = partDim; | ||
this.sentences = sentences; | ||
} | ||
|
||
@Override | ||
public void serialize(ByteBuf buf) { | ||
super.serialize(buf); | ||
buf.writeInt(seed); | ||
buf.writeInt(negative); | ||
buf.writeInt(window); | ||
buf.writeInt(partDim); | ||
buf.writeInt(sentences.length); | ||
for (int a = 0; a < sentences.length; a ++) { | ||
buf.writeInt(sentences[a].length); | ||
for (int b = 0; b < sentences[a].length; b ++) | ||
buf.writeInt(sentences[a][b]); | ||
} | ||
} | ||
|
||
@Override | ||
public void deserialize(ByteBuf buf) { | ||
|
||
} | ||
} | ||
|
||
class CBOWDotParam extends GetParam { | ||
@Override | ||
public List<PartitionGetParam> split() { | ||
return super.split(); | ||
} | ||
} | ||
public class CbowDot extends GetFunc { | ||
|
||
public CBOWDot(GetParam param) { | ||
// Creates the function for one request; the parameter is passed unchanged to the GetFunc constructor. | ||
public CbowDot(CbowDotParam param) { | ||
super(param); | ||
} | ||
|
||
/**
 * No-argument constructor; hands {@code null} to the parent as the request parameter.
 */
// NOTE(review): presumably required so the function can be instantiated without a
// request (e.g. reflectively) - confirm against the caller.
public CbowDot() {
  super(null);
}
|
||
@Override | ||
public PartitionGetResult partitionGet(PartitionGetParam partParam) { | ||
if (partParam instanceof CbowDotPartitionParam) { | ||
CbowDotPartitionParam param = (CbowDotPartitionParam) partParam; | ||
|
||
// some params | ||
PartitionKey pkey = param.getPartKey(); | ||
|
||
int negative = param.negative; | ||
int partDim = param.partDim; | ||
int window = param.window; | ||
int seed = param.seed; | ||
int order = 2; | ||
|
||
int[][] sentences = ServerSentences.getSentences(param.partitionId); | ||
|
||
// compute number of nodes for one row | ||
int size = (int) (pkey.getEndCol() - pkey.getStartCol()); | ||
int numNodes = size / (partDim * order); | ||
|
||
// used to accumulate the context input vectors | ||
float[] context = new float[partDim]; | ||
|
||
ServerPartition partition = psContext.getMatrixStorageManager().getPart(pkey); | ||
NENegativeSample sample = new NENegativeSample(-1, seed); | ||
|
||
Random rand = new Random(seed); | ||
FloatArrayList partialDots = new FloatArrayList(); | ||
for (int s = 0; s < sentences.length; s ++) { | ||
int[] sen = sentences[s]; | ||
for (int position = 0; position < sen.length; position ++) { | ||
int word = sen[position]; | ||
// fill 0 for context vector | ||
Arrays.fill(context, 0); | ||
// window size | ||
int b = rand.nextInt(window); | ||
// Continuous bag-of-words Models | ||
int cw = 0; | ||
|
||
// Accumulate the input vectors from context | ||
for (int a = b; a < window * 2 + 1 - b; a++) | ||
if (a != window) { | ||
int c = position - window + a; | ||
if (c < 0) continue; | ||
if (c >= sen.length) continue; | ||
int sentence_word = sen[c]; | ||
if (sentence_word == -1) continue; | ||
int rowId = sentence_word / numNodes; | ||
int colId = (sentence_word % numNodes) * partDim * order; | ||
float[] values = ((IntFloatDenseVectorStorage) partition.getRow(rowId) | ||
.getSplit().getStorage()).getValues(); | ||
for (c = 0; c < partDim; c++) context[c] += values[c + colId]; | ||
cw++; | ||
} | ||
|
||
// Calculate the partial dot values | ||
if (cw > 0) { | ||
for (int c = 0; c < partDim; c ++) context[c] /= cw; | ||
int target; | ||
for (int d = 0; d < negative + 1; d ++) { | ||
if (d == 0) target = word; | ||
// We should guarantee here that the sample would not equal the ``word`` | ||
else target = sample.next(word); | ||
|
||
int rowId = target / numNodes; | ||
int colId = (target % numNodes) * partDim * order + partDim; | ||
float f = 0f; | ||
float[] values = ((IntFloatDenseVectorStorage) partition.getRow(rowId) | ||
.getSplit().getStorage()).getValues(); | ||
for (int c = 0; c < partDim; c ++) f += context[c] * values[c + colId]; | ||
partialDots.add(f); | ||
} | ||
} | ||
} | ||
} | ||
return new CbowDotPartitionResult(partialDots.toFloatArray()); | ||
} | ||
|
||
return null; | ||
} | ||
|
||
@Override | ||
public GetResult merge(List<PartitionGetResult> partResults) { | ||
if (partResults.size() > 0 && partResults.get(0) instanceof CbowDotPartitionResult) { | ||
int size = ((CbowDotPartitionResult) partResults.get(0)).length; | ||
|
||
// check the length of dot values | ||
for (PartitionGetResult result: partResults) { | ||
if (result instanceof CbowDotPartitionResult && | ||
size != ((CbowDotPartitionResult) result).length) | ||
throw new AngelException(String.format("length of dot values not same one is %d other is %d", | ||
size, | ||
((CbowDotPartitionResult) result).length)); | ||
} | ||
|
||
// merge dot values from all partitions | ||
float[] results = new float[size]; | ||
for (PartitionGetResult result: partResults) | ||
if (result instanceof CbowDotPartitionResult) | ||
try { | ||
((CbowDotPartitionResult) result).merge(results); | ||
} finally { | ||
((CbowDotPartitionResult) result).clear(); | ||
} | ||
return new CbowDotResult(results); | ||
} | ||
|
||
return null; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
...n-angel/mllib/src/main/java/com/tencent/angel/spark/ml/psf/embedding/ServerSentences.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package com.tencent.angel.spark.ml.psf.embedding; | ||
|
||
/**
 * Static holder for sentence batches, indexed by partition id.
 *
 * {@link #initialize(int)} allocates one slot per partition; slots are filled by
 * assigning into {@link #batches} and read back with {@link #getSentences(int)}.
 */
public class ServerSentences {

  /**
   * One batch of sentences per partition; {@code null} until
   * {@link #initialize(int)} has been called. Declared {@code volatile} so the
   * array published by the synchronized initializer is visible to the
   * unsynchronized reader.
   */
  public static volatile int[][][] batches;

  /**
   * Allocates (or re-allocates, discarding previous content) the per-partition slots.
   *
   * @param numPartitions number of partition slots to allocate
   */
  public static synchronized void initialize(int numPartitions) {
    batches = new int[numPartitions][][];
  }

  /**
   * Returns the sentences stored for one partition.
   *
   * @param partitionId index of the partition slot
   * @return the stored batch, or {@code null} if nothing has been stored in that slot
   * @throws IllegalStateException          if {@link #initialize(int)} has not been called
   * @throws ArrayIndexOutOfBoundsException if {@code partitionId} is outside the allocated range
   */
  public static int[][] getSentences(int partitionId) {
    // Read the field once so the checks and the access see the same array.
    final int[][][] current = batches;
    if (current == null) {
      throw new IllegalStateException(
          "ServerSentences.initialize must be called before getSentences");
    }
    if (partitionId < 0 || partitionId >= current.length) {
      throw new ArrayIndexOutOfBoundsException(
          "partitionId " + partitionId + " is outside [0, " + current.length + ")");
    }
    return current[partitionId];
  }
}
Oops, something went wrong.