Skip to content

Commit

Permalink
CBOW model
Browse files Browse the repository at this point in the history
  • Loading branch information
leleyu committed Oct 25, 2018
1 parent 2a05ca7 commit 6abe556
Show file tree
Hide file tree
Showing 23 changed files with 10,961 additions and 172 deletions.
10,000 changes: 10,000 additions & 0 deletions data/text8/text8.split.head

Large diffs are not rendered by default.

@@ -1,3 +1,20 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/

package com.tencent.angel.spark.examples.cluster

import com.tencent.angel.spark.context.PSContext
Expand Down Expand Up @@ -54,7 +71,7 @@ object Word2vecExample {
.setMaxIndex(maxWordId)
.setNumRowDataSet(numDocs)

new Word2VecModel(param).train(docs, param, None)
new Word2VecModel(param).train(docs, param)
}

}
Expand Up @@ -28,7 +28,6 @@ import com.tencent.angel.spark.ml.core.{ArgsUtil, GraphModel, OfflineLearner}
import com.tencent.angel.spark.ml.util.{Features, ModelLoader, ModelSaver}
import org.apache.log4j.PropertyConfigurator
import org.apache.spark.{SparkConf, SparkContext}
import org.codehaus.jackson.JsonParser.Feature

object JsonExample {

Expand Down
@@ -0,0 +1,71 @@
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/

package com.tencent.angel.spark.examples.local

import scala.util.Random
import com.tencent.angel.spark.context.PSContext
import com.tencent.angel.spark.ml.embedding.Param
import com.tencent.angel.spark.ml.embedding.word2vec.Word2VecModel
import com.tencent.angel.spark.ml.feature.Features
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Local-mode example: trains a Word2Vec (CBOW) embedding model on a small
 * text8 sample using a single-core Spark context with an in-process
 * parameter server.
 */
object Word2vecExample {

  def main(args: Array[String]): Unit = {
    // Single local core; the parameter server runs in-process ("LOCAL" mode).
    val sparkConf = new SparkConf()
      .setMaster("local[1]")
      .setAppName("Word2vec example")
    Seq(
      "spark.ps.model" -> "LOCAL",
      "spark.ps.jars" -> "",
      "spark.ps.instances" -> "1",
      "spark.ps.cores" -> "1"
    ).foreach { case (key, value) => sparkConf.set(key, value) }

    val sc = new SparkContext(sparkConf)
    sc.setLogLevel("ERROR")
    PSContext.getOrCreate(sc)

    val input = "data/text8/text8.split.head"
    // Encode each word as an integer id; the word->id mapping is discarded here.
    val (corpus, _) = Features.corpusStringToInt(sc.textFile(input))
    val docs = corpus.repartition(2)
    docs.cache()

    val docCount = docs.count()
    // Word ids are assumed dense, so vocabulary size = max id + 1.
    val vocabSize = docs.map(_.max).max().toLong + 1
    val tokenCount = docs.map(_.length).sum().toLong

    println(s"numDocs=$docCount maxWordId=$vocabSize numTokens=$tokenCount")

    // Training hyper-parameters for the embedding model.
    val trainParam = new Param()
    trainParam.setLearningRate(0.001f)
    trainParam.setEmbeddingDim(100)
    trainParam.setWindowSize(10)
    trainParam.setBatchSize(128)
    trainParam.setNodesNumPerRow(Some(100))
    trainParam.setSeed(Random.nextInt())
    trainParam.setNumPSPart(Some(2))
    trainParam.setNumEpoch(10)
    trainParam.setNegSample(5)
    trainParam.setMaxIndex(vocabSize)
    trainParam.setNumRowDataSet(docCount)

    new Word2VecModel(trainParam).train(docs, trainParam)
  }

}
@@ -1,75 +1,153 @@
package com.tencent.angel.spark.ml.psf.embedding.CBOW;
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/

package com.tencent.angel.spark.ml.psf.embedding.cbow;

import com.tencent.angel.PartitionKey;
import com.tencent.angel.ml.matrix.psf.get.base.*;
import io.netty.buffer.ByteBuf;
import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.math2.storage.IntFloatDenseVectorStorage;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.get.base.GetResult;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult;
import com.tencent.angel.ps.storage.matrix.ServerPartition;
import com.tencent.angel.spark.ml.psf.embedding.NENegativeSample;
import com.tencent.angel.spark.ml.psf.embedding.ServerSentences;
import it.unimi.dsi.fastutil.floats.FloatArrayList;

import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class CBOWDot extends GetFunc {

/**
 * Partition-level parameter for the CBOW dot-product PSF.
 * Carries the negative-sampling configuration plus the batch of sentences
 * (each a sequence of word ids) this partition should process.
 */
class CBOWDotPartitionParam extends PartitionGetParam {

  private int seed;          // seed for the negative-sampling RNG
  private int negative;      // number of negative samples per target word
  private int window;        // context window size
  private int partDim;       // embedding dimensions held by this partition
  private int[][] sentences; // sentences as arrays of word ids

  // No-arg constructor so the server side can instantiate before deserialize().
  // NOTE(review): assumes PartitionGetParam provides a no-arg constructor,
  // as Angel PSF params conventionally do — confirm against the base class.
  public CBOWDotPartitionParam() {
    super();
  }

  public CBOWDotPartitionParam(int matrixId,
                               int seed,
                               int negative,
                               int window,
                               int partDim,
                               int[][] sentences,
                               PartitionKey pkey) {
    super(matrixId, pkey);
    this.seed = seed;
    this.negative = negative;
    this.window = window;
    this.partDim = partDim;
    this.sentences = sentences;
  }

  @Override
  public void serialize(ByteBuf buf) {
    super.serialize(buf);
    buf.writeInt(seed);
    buf.writeInt(negative);
    buf.writeInt(window);
    buf.writeInt(partDim);
    // Sentences are written as: count, then (length, ids...) per sentence.
    buf.writeInt(sentences.length);
    for (int a = 0; a < sentences.length; a++) {
      buf.writeInt(sentences[a].length);
      for (int b = 0; b < sentences[a].length; b++)
        buf.writeInt(sentences[a][b]);
    }
  }

  /**
   * BUG FIX: the original deserialize() was an empty no-op, so every field
   * arrived on the server as 0/null even though serialize() wrote them.
   * Reads fields in exactly the order serialize() writes them.
   */
  @Override
  public void deserialize(ByteBuf buf) {
    super.deserialize(buf);
    seed = buf.readInt();
    negative = buf.readInt();
    window = buf.readInt();
    partDim = buf.readInt();
    int numSentences = buf.readInt();
    sentences = new int[numSentences][];
    for (int a = 0; a < numSentences; a++) {
      int len = buf.readInt();
      sentences[a] = new int[len];
      for (int b = 0; b < len; b++)
        sentences[a][b] = buf.readInt();
    }
  }
}

// Matrix-level parameter for the CBOW dot PSF.
// NOTE(review): this split() override only delegates to super and is
// redundant as written; either remove it or replace it with a real
// per-partition split that builds CBOWDotPartitionParam instances.
class CBOWDotParam extends GetParam {
@Override
public List<PartitionGetParam> split() {
return super.split();
}
}
public class CbowDot extends GetFunc {

public CBOWDot(GetParam param) {
// Builds the PSF from the matrix-level parameter describing one dot request.
public CbowDot(CbowDotParam param) {
super(param);
}

// No-arg constructor with a null param.
// NOTE(review): presumably required for server-side instantiation of the
// function via reflection — confirm against the PSF framework contract.
public CbowDot() { super(null); }

/**
 * Computes, for every (target word, positive/negative sample) pair in this
 * partition's batch of sentences, the partial dot product between the
 * averaged context input vectors and the sample's output vector, restricted
 * to the embedding dimensions stored in this partition.
 * Returns the partial dots as a flat float array, or null if the param is
 * not a CbowDotPartitionParam.
 */
@Override
public PartitionGetResult partitionGet(PartitionGetParam partParam) {
if (partParam instanceof CbowDotPartitionParam) {
CbowDotPartitionParam param = (CbowDotPartitionParam) partParam;

// some params
PartitionKey pkey = param.getPartKey();

int negative = param.negative;
int partDim = param.partDim;
int window = param.window;
int seed = param.seed;
// Each node stores `order` = 2 vectors of partDim floats: the input
// (context) vector first, the output (target) vector at offset partDim.
int order = 2;

// Sentences were shipped to the server ahead of time, keyed by partition id.
int[][] sentences = ServerSentences.getSentences(param.partitionId);

// compute number of nodes for one row
int size = (int) (pkey.getEndCol() - pkey.getStartCol());
int numNodes = size / (partDim * order);

// used to accumulate the context input vectors
float[] context = new float[partDim];

ServerPartition partition = psContext.getMatrixStorageManager().getPart(pkey);
NENegativeSample sample = new NENegativeSample(-1, seed);

// Seeded RNG: the same seed is used on every partition so the window
// shrinkage per position is identical across partitions, keeping the
// partial-dot ordering aligned for the later merge.
Random rand = new Random(seed);
FloatArrayList partialDots = new FloatArrayList();
for (int s = 0; s < sentences.length; s ++) {
int[] sen = sentences[s];
for (int position = 0; position < sen.length; position ++) {
int word = sen[position];
// fill 0 for context vector
Arrays.fill(context, 0);
// window size
int b = rand.nextInt(window);
// Continuous bag-of-words Models
int cw = 0;

// Accumulate the input vectors from context
for (int a = b; a < window * 2 + 1 - b; a++)
if (a != window) {
int c = position - window + a;
if (c < 0) continue;
if (c >= sen.length) continue;
int sentence_word = sen[c];
if (sentence_word == -1) continue;
// Node layout: numNodes nodes per row, each spanning partDim*order cols.
int rowId = sentence_word / numNodes;
int colId = (sentence_word % numNodes) * partDim * order;
float[] values = ((IntFloatDenseVectorStorage) partition.getRow(rowId)
.getSplit().getStorage()).getValues();
// NOTE: deliberately reuses loop variable c as the dimension index here.
for (c = 0; c < partDim; c++) context[c] += values[c + colId];
cw++;
}

// Calculate the partial dot values
if (cw > 0) {
// Average the accumulated context vectors (CBOW uses the mean).
for (int c = 0; c < partDim; c ++) context[c] /= cw;
int target;
// d == 0 is the positive (true) target; the rest are negative samples.
for (int d = 0; d < negative + 1; d ++) {
if (d == 0) target = word;
// We should guarantee here that the sample would not equal the ``word``
else target = sample.next(word);

// + partDim selects the node's output vector (second half of its slot).
int rowId = target / numNodes;
int colId = (target % numNodes) * partDim * order + partDim;
float f = 0f;
float[] values = ((IntFloatDenseVectorStorage) partition.getRow(rowId)
.getSplit().getStorage()).getValues();
for (int c = 0; c ++ < 0; ) {} // (unreachable placeholder removed) --
for (int c = 0; c < partDim; c ++) f += context[c] * values[c + colId];
partialDots.add(f);
}
}
}
}
return new CbowDotPartitionResult(partialDots.toFloatArray());
}

return null;
}

/**
 * Merges the per-partition partial dot products into a single float array.
 * Every partition must report the same number of dot values; partials are
 * summed element-wise. Returns null when there is nothing to merge or the
 * results are not CbowDotPartitionResult instances.
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  // Guard clause: nothing to do unless the first result is the expected type.
  if (partResults.size() == 0 || !(partResults.get(0) instanceof CbowDotPartitionResult)) {
    return null;
  }

  int expectedLength = ((CbowDotPartitionResult) partResults.get(0)).length;

  // Every partition must have produced the same number of dot values.
  for (PartitionGetResult partResult : partResults) {
    if (partResult instanceof CbowDotPartitionResult
        && expectedLength != ((CbowDotPartitionResult) partResult).length) {
      throw new AngelException(String.format("length of dot values not same one is %d other is %d",
          expectedLength,
          ((CbowDotPartitionResult) partResult).length));
    }
  }

  // Sum the partial dots element-wise; clear each partial buffer even if
  // its merge throws, so server-side resources are always released.
  float[] merged = new float[expectedLength];
  for (PartitionGetResult partResult : partResults) {
    if (partResult instanceof CbowDotPartitionResult) {
      CbowDotPartitionResult typed = (CbowDotPartitionResult) partResult;
      try {
        typed.merge(merged);
      } finally {
        typed.clear();
      }
    }
  }
  return new CbowDotResult(merged);
}
}
Expand Up @@ -36,7 +36,7 @@ object NENegativeSample {
this.synchronized {
if (version > this.version) {
val tableCapacity = Math.min(Math.max(maxValue / 10, 1000000), maxValue)
LOG.info("initial sample table firstly, table capacity: " + tableCapacity)
LOG.error("initial sample table firstly, table capacity: " + tableCapacity)
val rand = new Random(seed)
samples = Array.fill(tableCapacity)(rand.nextInt(maxValue))
this.version = version
Expand Down
@@ -0,0 +1,14 @@
package com.tencent.angel.spark.ml.psf.embedding;

// Server-side holder for the batches of sentences shipped to each partition.
// Sentences are stored per partition id as arrays of word ids.
public class ServerSentences {

// batches[partitionId] holds that partition's sentences (word-id arrays).
// NOTE(review): public static mutable state — writes outside initialize()
// are unsynchronized; confirm callers only read after initialization.
public static int[][][] batches;

// Allocates one (initially null) sentence slot per partition.
public static synchronized void initialize(int numPartitions) {
batches = new int[numPartitions][][];
}

// Returns the sentences for a partition. Throws NullPointerException if
// initialize() has not been called first.
public static int[][] getSentences(int partitionId) {
return batches[partitionId];
}
}

0 comments on commit 6abe556

Please sign in to comment.