-
Notifications
You must be signed in to change notification settings - Fork 376
/
BuildModel.java
140 lines (124 loc) · 6.07 KB
/
BuildModel.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/*
* Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Amazon Software License (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/asl/
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express
* or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.amazonaws.samples.machinelearning;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import com.amazonaws.services.machinelearning.AmazonMachineLearningClient;
import com.amazonaws.services.machinelearning.model.CreateDataSourceFromS3Request;
import com.amazonaws.services.machinelearning.model.CreateEvaluationRequest;
import com.amazonaws.services.machinelearning.model.CreateMLModelRequest;
import com.amazonaws.services.machinelearning.model.MLModelType;
import com.amazonaws.services.machinelearning.model.S3DataSpec;
/**
 * This class demonstrates all the steps needed to build an ML Model for
 * the targeted marketing example in the Getting Started Guide for
 * Amazon Machine Learning.
 *
 * <p>The steps run in this order: create a training and a testing DataSource
 * from the same S3 file, create an ML Model from the training DataSource,
 * and create an Evaluation of that model against the testing DataSource.
 */
public class BuildModel {

    /**
     * Runs the sample end to end with the sample banking data set and the
     * schema and recipe files named below.
     *
     * @param args unused
     * @throws IOException propagated from {@link #build()}
     */
    public static void main(String[] args) throws IOException {
        String trainingDataUrl = "s3://aml-sample-data/banking.csv";
        String schemaFilename = "banking.csv.schema";
        String recipeFilename = "recipe.json";
        String friendlyEntityName = "Java Marketing Sample";
        BuildModel builder = new BuildModel(friendlyEntityName, trainingDataUrl, schemaFilename, recipeFilename);
        builder.build();
    }

    private final AmazonMachineLearningClient client;
    private final String friendlyEntityName;

    // Ids of the entities created by this run; each is assigned by the
    // corresponding create* step and read by the steps that follow it.
    private String trainDataSourceId;
    private String testDataSourceId;
    private String mlModelId;
    private String evaluationId;

    // Split point of the input data: [0, trainPercent) is used for training,
    // [trainPercent, 100) for testing.
    private final int trainPercent = 70;

    private final String trainingDataUrl;
    private final String schemaFilename;
    private final String recipeFilename;

    /**
     * Stores the configuration for a build and creates the service client
     * using the client's no-argument constructor.
     *
     * @param friendlyName    prefix used for the names of all created entities
     * @param trainingDataUrl S3 location of the input data
     * @param schemaFilename  name of the file holding the data schema
     * @param recipeFilename  name of the file holding the feature-processing recipe
     */
    public BuildModel(String friendlyName, String trainingDataUrl, String schemaFilename, String recipeFilename) {
        this.client = new AmazonMachineLearningClient();
        this.friendlyEntityName = friendlyName;
        this.trainingDataUrl = trainingDataUrl;
        this.schemaFilename = schemaFilename;
        this.recipeFilename = recipeFilename;
    }

    /**
     * Creates the DataSources, the ML Model and the Evaluation, in that order.
     * The order matters: each step uses ids assigned by the previous ones.
     *
     * @throws IOException propagated from the DataSource and ML Model steps
     */
    private void build() throws IOException {
        createDataSources();
        createModel();
        createEvaluation();
    }

    /**
     * Creates two DataSources over the same input data: one covering the
     * first {@code trainPercent} percent for training, and one covering the
     * remainder for testing.
     *
     * @throws IOException propagated from
     *         {@link #createDataSource(String, String, int, int)}
     */
    private void createDataSources() throws IOException {
        // Alternative: "ds-" + UUID.randomUUID().toString() -- simpler, a bit more ugly.
        trainDataSourceId = Identifiers.newDataSourceId();
        createDataSource(trainDataSourceId, friendlyEntityName + " - training data", 0, trainPercent);

        // Alternative: "ds-" + UUID.randomUUID().toString() -- simpler, a bit more ugly.
        testDataSourceId = Identifiers.newDataSourceId();
        createDataSource(testDataSourceId, friendlyEntityName + " - testing data", trainPercent, 100);
    }

    /**
     * Creates a single DataSource from the S3 training data, restricted to
     * the given percentage range via a "splitting" data rearrangement.
     *
     * @param entityId     id to assign to the new DataSource
     * @param entityName   human-readable name of the new DataSource
     * @param percentBegin value written as {@code percentBegin} in the rearrangement
     * @param percentEnd   value written as {@code percentEnd} in the rearrangement
     * @throws IOException propagated from {@code Util.loadFile} when reading the schema file
     */
    private void createDataSource(String entityId, String entityName, int percentBegin, int percentEnd) throws IOException {
        String dataSchema = Util.loadFile(schemaFilename);
        String dataRearrangementString =
            "{\"splitting\":{\"percentBegin\":" + percentBegin + ",\"percentEnd\":" + percentEnd + "}}";

        S3DataSpec dataSpec = new S3DataSpec()
            .withDataLocationS3(trainingDataUrl)
            .withDataRearrangement(dataRearrangementString)
            .withDataSchema(dataSchema);

        CreateDataSourceFromS3Request request = new CreateDataSourceFromS3Request()
            .withDataSourceId(entityId)
            .withDataSourceName(entityName)
            .withComputeStatistics(true);
        request.setDataSpec(dataSpec);

        client.createDataSourceFromS3(request);
        System.out.printf("Created DataSource %s with id %s\n", entityName, entityId);
    }

    /**
     * Creates an ML Model object, which begins the training process.
     * The quality of the model that the training algorithm produces depends
     * primarily on the data, but also on the hyper-parameters specified in
     * the parameters map, and the feature-processing recipe.
     *
     * @throws IOException propagated from {@code Util.loadFile} when reading the recipe file
     */
    private void createModel() throws IOException {
        // Alternative: "ml-" + UUID.randomUUID().toString() -- simpler, a bit more ugly.
        mlModelId = Identifiers.newMLModelId();

        Map<String, String> parameters = new HashMap<String, String>();
        parameters.put("sgd.maxPasses", "100");
        parameters.put("sgd.maxMLModelSizeInBytes", "104857600"); // 100 MiB
        parameters.put("sgd.l2RegularizationAmount", "1e-4");

        CreateMLModelRequest request = new CreateMLModelRequest()
            .withMLModelId(mlModelId)
            .withMLModelName(friendlyEntityName + " model")
            .withMLModelType(MLModelType.BINARY)
            .withParameters(parameters)
            .withRecipe(Util.loadFile(recipeFilename))
            .withTrainingDataSourceId(trainDataSourceId);

        client.createMLModel(request);
        System.out.printf("Created ML Model with id %s\n", mlModelId);
    }

    /**
     * Creates an Evaluation, which measures the quality of the ML Model
     * by seeing how many predictions it gets correct, when run on a
     * held-out sample (30%) of the original data.
     */
    private void createEvaluation() {
        // Alternative: "ev-" + UUID.randomUUID().toString() -- simpler, a bit more ugly.
        evaluationId = Identifiers.newEvaluationId();

        CreateEvaluationRequest request = new CreateEvaluationRequest()
            // The generated id must be sent with the request; without it the
            // id printed below would not identify the created Evaluation.
            .withEvaluationId(evaluationId)
            .withEvaluationDataSourceId(testDataSourceId)
            .withEvaluationName(friendlyEntityName + " evaluation")
            .withMLModelId(mlModelId);

        client.createEvaluation(request);
        System.out.printf("Created Evaluation with id %s\n", evaluationId);
    }
}