Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for specifying training and evaluation sets #442

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions src/example/org/deidentifier/arx/examples/Example61.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* ARX: Powerful Data Anonymization
* Copyright 2012 - 2021 Fabian Prasser and contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.deidentifier.arx.examples;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.deidentifier.arx.ARXAnonymizer;
import org.deidentifier.arx.ARXClassificationConfiguration;
import org.deidentifier.arx.ARXConfiguration;
import org.deidentifier.arx.ARXResult;
import org.deidentifier.arx.AttributeType;
import org.deidentifier.arx.AttributeType.Hierarchy;
import org.deidentifier.arx.Data;
import org.deidentifier.arx.DataHandle;
import org.deidentifier.arx.DataSubset;
import org.deidentifier.arx.DataType;
import org.deidentifier.arx.aggregates.ClassificationConfigurationLogisticRegression;
import org.deidentifier.arx.criteria.Inclusion;
import org.deidentifier.arx.criteria.KAnonymity;
import org.deidentifier.arx.io.CSVHierarchyInput;
import org.deidentifier.arx.metric.Metric;

/**
 * This class implements an example on how to compare data mining performance
 * using a training and a test set. The dataset is anonymized once and the
 * classification performance of the output is then reported twice: first with
 * k-fold cross validation and then with the training/test set option enabled.
 *
 * @author Fabian Prasser
 * @author Ibhraheem Al-Dhamari
 */
public class Example61 extends Example {

    /**
     * Loads a dataset from disk and attaches all generalization hierarchies found for it.
     * The data is read from <code>data/&lt;dataset&gt;.csv</code>; every file in <code>data/</code>
     * named <code>&lt;dataset&gt;_hierarchy_&lt;attribute&gt;.csv</code> is registered as the
     * hierarchy of <code>&lt;attribute&gt;</code>.
     *
     * @param dataset name of the dataset, without directory and file extension
     * @return the data object with hierarchies assigned
     * @throws IOException if the data or a hierarchy cannot be read, or if the
     *                     data directory cannot be listed
     */
    public static Data createData(final String dataset) throws IOException {

        // Load data
        Data data = Data.create("data/" + dataset + ".csv", StandardCharsets.UTF_8, ';');

        // Hierarchy files are named "<dataset>_hierarchy_<attribute>.csv". The dataset name is
        // quoted and the dot is escaped so that only exactly such file names are accepted.
        final Pattern hierarchyPattern = Pattern.compile(Pattern.quote(dataset) + "_hierarchy_(.+)\\.csv");
        FilenameFilter hierarchyFilter = new FilenameFilter() {
            @Override
            public boolean accept(File dir, String name) {
                return hierarchyPattern.matcher(name).matches();
            }
        };

        // listFiles() returns null (not an empty array) if the directory is missing or unreadable
        File dataDir = new File("data/");
        File[] hierarchyFiles = dataDir.listFiles(hierarchyFilter);
        if (hierarchyFiles == null) {
            throw new IOException("Cannot list hierarchy files in directory: " + dataDir.getAbsolutePath());
        }

        // Create definition
        for (File file : hierarchyFiles) {
            Matcher matcher = hierarchyPattern.matcher(file.getName());
            if (matcher.matches()) {
                CSVHierarchyInput hier = new CSVHierarchyInput(file, StandardCharsets.UTF_8, ';');
                String attributeName = matcher.group(1);
                data.getDefinition().setAttributeType(attributeName, Hierarchy.create(hier.getHierarchy()));
            }
        }

        return data;
    }

    /**
     * Gets a set of random record indices for this dataset. The indices are drawn with a
     * fixed seed, so repeated calls on data with the same number of rows return the same set.
     *
     * @param data the dataset to sample from
     * @param sampleFraction fraction of the rows to select, in the range [0, 1]
     * @return the selected row indices
     * @throws IllegalArgumentException if sampleFraction is NaN or outside of [0, 1]
     */
    public static Set<Integer> getRandomSample(Data data, double sampleFraction) {

        // Written as a negated conjunction so that NaN is rejected as well
        if (!(sampleFraction >= 0d && sampleFraction <= 1d)) {
            throw new IllegalArgumentException("Sample fraction must be in [0, 1] but is: " + sampleFraction);
        }

        // Create list of all row indices
        int rows = data.getHandle().getNumRows();
        List<Integer> indices = new ArrayList<>(rows);
        for (int i = 0; i < rows; ++i) {
            indices.add(i);
        }

        // Shuffle with a fixed seed to keep the example reproducible
        Collections.shuffle(indices, new Random(0xDEADBEEF));

        // Select sample and create set
        int sampleSize = (int) Math.round((double) rows * sampleFraction);
        return new HashSet<Integer>(indices.subList(0, sampleSize));
    }

    /**
     * Entry point. Anonymizes the "adult" dataset with a randomly drawn training subset
     * and prints the classification performance for both evaluation modes.
     *
     * @param args the arguments (not used)
     * @throws ParseException
     * @throws IOException
     */
    public static void main(String[] args) throws ParseException, IOException {

        // Load data and declare the class attribute
        Data data = createData("adult");
        data.getDefinition().setAttributeType("marital-status", AttributeType.INSENSITIVE_ATTRIBUTE);
        data.getDefinition().setDataType("age", DataType.INTEGER);
        data.getDefinition().setResponseVariable("marital-status", true);

        // Size of training set (as a fraction of all records)
        double trainingSetSize = 0.8d;

        // Create sample
        Set<Integer> trainingSetIndices = getRandomSample(data, trainingSetSize);
        DataSubset trainingSet = DataSubset.create(data, trainingSetIndices);

        // Configure anonymization; the training subset is registered via the Inclusion criterion
        ARXAnonymizer anonymizer = new ARXAnonymizer();
        ARXConfiguration config = ARXConfiguration.create();
        config.addPrivacyModel(new KAnonymity(5));
        config.addPrivacyModel(new Inclusion(trainingSet));
        config.setSuppressionLimit(1d);
        config.setQualityModel(Metric.createClassificationMetric());

        // Start anonymization process
        ARXResult result = anonymizer.anonymize(data, config);
        DataHandle output = result.getOutput();

        // Run evaluation using k-fold cross validation
        System.out.println("----------------------------------------");
        System.out.println("Evaluation using k-fold cross validation");
        System.out.println("----------------------------------------");
        evaluate(output, false);

        // Run evaluation using test/training set
        System.out.println("--------------------------------------");
        System.out.println("Evaluation using test and training set");
        System.out.println("--------------------------------------");
        evaluate(output, true);
    }

    /**
     * Runs a logistic regression evaluation on the given handle and prints the
     * resulting classification performance to standard output.
     *
     * @param data the handle to evaluate
     * @param useTestTrainingSet value passed to the classifier's training/test set option;
     *                           <code>false</code> is used for the k-fold cross validation run
     * @throws ParseException
     */
    private static void evaluate(DataHandle data, boolean useTestTrainingSet) throws ParseException {

        // Specify features
        String[] features = new String[] {
                "sex",
                "age",
                "race",
                "education",
                "native-country",
                "workclass",
                "occupation",
                "salary-class"
        };

        // Specify class attribute
        String clazz = "marital-status";

        // Perform measurement
        ClassificationConfigurationLogisticRegression logisticClassifier = ARXClassificationConfiguration.createLogisticRegression();
        logisticClassifier.setUseTrainingTestSet(useTestTrainingSet);
        System.out.println(data.getStatistics().getClassificationPerformance(features, clazz, logisticClassifier));
    }
}
20 changes: 18 additions & 2 deletions src/gui/org/deidentifier/arx/gui/model/ModelClassification.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ public boolean isModified() {
* @param configCurrent
*/
public void setCurrentConfiguration(ARXClassificationConfiguration<?> configCurrent){
    // Only one of the three configuration objects held by this model may be selected
    boolean known = configCurrent == config ||
                    configCurrent == configNaiveBayes ||
                    configCurrent == configRandomForest;
    if (!known) {
        throw new IllegalArgumentException("Unknown configuration object");
    }
    this.configCurrent = configCurrent;
}

Expand All @@ -143,7 +148,7 @@ public void setMaxRecords(Integer t) {
this.configRandomForest.setMaxRecords(t);
this.setModified();
}

/**
* TODO: Ugly hack to set base-parameters for all methods
* @param t
Expand All @@ -154,7 +159,7 @@ public void setNumFolds(Integer t) {
this.configRandomForest.setNumFolds(t);
this.setModified();
}

/**
* Sets a feature scaling function
* @param attribute
Expand All @@ -175,6 +180,17 @@ public void setUnmodified() {
getRandomForestConfiguration().setUnmodified();
}

/**
 * TODO: Ugly hack to set base-parameters for all methods.
 * Applies the training/test set flag to each of the three classifier
 * configurations and marks this model as modified.
 * @param value the flag to set on every configuration
 */
public void setUseTrainingTestSet(boolean value) {
    ARXClassificationConfiguration<?>[] configurations = new ARXClassificationConfiguration<?>[] {
        this.config,
        this.configNaiveBayes,
        this.configRandomForest
    };
    for (ARXClassificationConfiguration<?> configuration : configurations) {
        configuration.setUseTrainingTestSet(value);
    }
    this.setModified();
}

/**
* TODO: Ugly hack to set base-parameters for all methods
* @param t
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,7 @@ DialogProperties.19=Vector length
DialogProperties.2=Performance
DialogProperties.20=Prior function
DialogProperties.21=Configuration
DialogProperties.22=Use test and training set
DialogProperties.3=Visualization
DialogProperties.4=Default
DialogProperties.5=Metadata
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ private void createTabUtility(PreferencesDialog window) {
window.addPreference(new PreferenceInteger(Resources.getMessage("DialogProperties.19"), 10, Integer.MAX_VALUE, ARXClassificationConfiguration.DEFAULT_VECTOR_LENGTH) { //$NON-NLS-1$
protected Integer getValue() { return model.getClassificationModel().getCurrentConfiguration().getVectorLength(); }
protected void setValue(Object t) { model.getClassificationModel().setVectorLength((Integer)t); }});

window.addPreference(new PreferenceBoolean(Resources.getMessage("DialogProperties.22"), ARXClassificationConfiguration.DEFAULT_TEST_TRAINING_SET, true) { //$NON-NLS-1$
protected Boolean getValue() { return model.getClassificationModel().getCurrentConfiguration().isUseTrainingTestSet(); }
protected void setValue(Object t) { model.getClassificationModel().setUseTrainingTestSet((Boolean)t); }});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.deidentifier.arx.ARXClassificationConfiguration;
import org.deidentifier.arx.ARXFeatureScaling;
import org.deidentifier.arx.DataHandle;
import org.deidentifier.arx.aggregates.StatisticsBuilderInterruptible;
import org.deidentifier.arx.aggregates.StatisticsClassification;
import org.deidentifier.arx.aggregates.StatisticsClassification.ROCCurve;
Expand Down Expand Up @@ -999,13 +1000,32 @@ protected void doReset() {
@Override
protected void doUpdate(final AnalysisContextClassification context) {

// The statistics builder
final StatisticsBuilderInterruptible builder = context.handle.getStatistics().getInterruptibleInstance();
// Classification configuration
final String[] features = context.model.getSelectedFeaturesAsArray();
final String[] targetVariables = context.model.getSelectedClassesAsArray();
final ARXClassificationConfiguration<?> config = context.model.getClassificationModel().getCurrentConfiguration();
final ARXFeatureScaling scaling = context.model.getClassificationModel().getFeatureScaling();

// Make sure that an analysis can be performed through the UI, even when a training/test set
// is selected and the configuration is non-optimal
DataHandle handle = context.handle;
ARXClassificationConfiguration<?> config = context.model.getClassificationModel().getCurrentConfiguration();
if (config.isUseTrainingTestSet() && !handle.isSubsetAvailable()) {

// Try to fix by switching to the superset
if (handle.isSupersetAvailable()) {
handle = handle.getSupersetHandle();

// Fix by switching to k-fold cross validation
} else {
config = config.clone();
config.setUseTrainingTestSet(false);
}
}

// Obtain statistics builder
final StatisticsBuilderInterruptible builder = handle.getStatistics().getInterruptibleInstance();
final ARXClassificationConfiguration<?> _config = config;

// Break, if there is nothing to do
if (context.model.getSelectedFeatures().isEmpty() ||
context.model.getSelectedClasses().isEmpty()) {
Expand Down Expand Up @@ -1114,7 +1134,7 @@ public void run() throws InterruptedException {
// Compute
StatisticsClassification result = builder.getClassificationPerformance(features,
targetVariable,
config,
_config,
scaling);
progress++;
if (stopped) {
Expand Down
Loading