From 2d036aa76d5365ad9a4a3b4d3272232369f114f6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 12 Jul 2017 13:37:31 +0900 Subject: [PATCH] hotfix --- docs/gitbook/spark/binaryclass/a9a_df.md | 13 +++++++++---- docs/gitbook/spark/regression/e2006_df.md | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/gitbook/spark/binaryclass/a9a_df.md b/docs/gitbook/spark/binaryclass/a9a_df.md index 74f2705fa..88229e321 100644 --- a/docs/gitbook/spark/binaryclass/a9a_df.md +++ b/docs/gitbook/spark/binaryclass/a9a_df.md @@ -31,10 +31,15 @@ $ wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t ```scala scala> :paste -val trainDf = spark.read.format("libsvm").load("a9a") - .select( +val rawTrainDf = spark.read.format("libsvm").load("a9a") + +val (max, min) = rawTrainDf.select(max($"label"), min($"label")).collect.map { + case Row(max: Double, min: Double) => (max, min) +} + +val trainDf = rawTrainDf.select( // `label` must be [0.0, 1.0] - rescale($"label", lit(-1.0f), lit(1.0f)).as("label"), + rescale($"label", lit(min), lit(max)).as("label"), $"features" ) @@ -45,7 +50,7 @@ root scala> :paste val testDf = spark.read.format("libsvm").load("a9a.t") - .select(rowid(), rescale($"label", lit(-1.0f), lit(1.0f)).as("label"), $"features") + .select(rowid(), rescale($"label", lit(min), lit(max)).as("label"), $"features") .explode_vector($"features") .select($"rowid", $"label".as("target"), $"feature", $"weight".as("value")) .cache diff --git a/docs/gitbook/spark/regression/e2006_df.md b/docs/gitbook/spark/regression/e2006_df.md index 5980e3e92..d6ac13825 100644 --- a/docs/gitbook/spark/regression/e2006_df.md +++ b/docs/gitbook/spark/regression/e2006_df.md @@ -31,10 +31,15 @@ $ wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/E2006.t ```scala scala> :paste -val trainDf = spark.read.format("libsvm").load("E2006.train.bz2") - .select( +val rawTrainDf = spark.read.format("libsvm").load("E2006.train.bz2") + +val (max, min) = rawTrainDf.select(max($"label"), min($"label")).collect.map { + case Row(max: Double, min: Double) => (max, min) +} + +val trainDf = rawTrainDf.select( // `label` must be [0.0, 1.0] - rescale($"label", lit(-7.899578f), lit(-0.51940954f)).as("label"), + rescale($"label", lit(min), lit(max).as("label"), $"features" ) @@ -45,7 +50,7 @@ root scala> :paste val testDf = spark.read.format("libsvm").load("E2006.test.bz2") - .select(rowid(), rescale($"label", lit(-7.899578f), lit(-0.51940954f)).as("label"), $"features") + .select(rowid(), rescale($"label", lit(min), lit(max)).as("label"), $"features") .explode_vector($"features") .select($"rowid", $"label".as("target"), $"feature", $"weight".as("value")) .cache