Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

[HIVEMALL-218] Fixed train_lda NPE when the input row is null #164

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -57,6 +57,8 @@
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

import com.google.common.base.Preconditions;

public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class);

Expand Down Expand Up @@ -159,11 +161,17 @@ public void process(Object[] args) throws HiveException {
this.model = createModel();
}

final int length = wordCountsOI.getListLength(args[0]);
Preconditions.checkArgument(args.length >= 1);
Object arg0 = args[0];
if (arg0 == null) {
return;
}

final int length = wordCountsOI.getListLength(arg0);
final String[] wordCounts = new String[length];
int j = 0;
for (int i = 0; i < length; i++) {
Object o = wordCountsOI.getListElement(args[0], i);
Object o = wordCountsOI.getListElement(arg0, i);
if (o == null) {
throw new HiveException("Given feature vector contains invalid null elements");
}
Expand Down Expand Up @@ -268,17 +276,17 @@ private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStateful

/**
 * Finishes the training run and emits the learned model.
 *
 * <p>When the model reports a document count of zero, the model reference is cleared and a
 * {@link HiveException} is thrown. Otherwise {@code finalizeTraining()} and
 * {@code forwardModel()} are invoked in that order, and the model reference is cleared
 * afterwards.
 *
 * <p>NOTE(review): this listing appears to be a flattened diff, so the zero-document guard
 * below may belong to only one side of the change (an equivalent check is also visible at the
 * top of {@code finalizeTraining()}) - confirm against the actual file before relying on it.
 *
 * @throws HiveException if the model's document count is zero; may also propagate from
 *         {@code finalizeTraining()} or {@code forwardModel()}, both of which declare it
 */
@Override
public void close() throws HiveException {
// Nothing was counted by the model: drop the reference first, then fail.
if (model.getDocCount() == 0L) {
this.model = null;
// NOTE(review): "exmples" in the message below looks like a misspelling of "examples";
// left untouched here because the message is runtime text.
throw new HiveException("No training exmples to learn. Please revise input data.");
}
// Order matters as written: finalize training first, then forward the model.
finalizeTraining();
forwardModel();
// Clear the model reference once it has been forwarded.
this.model = null;
}

@VisibleForTesting
void finalizeTraining() throws HiveException {
if (model.getDocCount() == 0L) {
this.model = null;
return;
}
if (miniBatchCount > 0) { // update for remaining samples
model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount));
}
Expand Down Expand Up @@ -462,6 +470,9 @@ protected void forwardModel() throws HiveException {
topicIdx.set(k);

final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k);
if (topicWords == null) {
continue;
}
for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) {
score.set(e.getKey().floatValue());
for (String v : e.getValue()) {
Expand Down