From c55ca9b5a5dbae83a02141db848f22771542b132 Mon Sep 17 00:00:00 2001 From: Joseph Ramsey Date: Tue, 31 Oct 2023 13:55:30 -0400 Subject: [PATCH 01/24] Update INSTALL_APPLICATION.md --- INSTALL_APPLICATION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INSTALL_APPLICATION.md b/INSTALL_APPLICATION.md index 24fe1297a8..3158b22fd5 100644 --- a/INSTALL_APPLICATION.md +++ b/INSTALL_APPLICATION.md @@ -6,7 +6,7 @@ Please use a recent Java JDK. See [Setting up Java for Tetrad](https://github.co To download the Tetrad jar, please click the following link (which will always be updated to the latest version): -https://s01.oss.sonatype.org/content/repositories/releases/io/github/cmu-phil/tetrad-gui/7.5.0/tetrad-gui-7.5.0-launch.jar +https://s01.oss.sonatype.org/content/repositories/releases/io/github/cmu-phil/tetrad-gui/7.6.0/tetrad-gui-7.6.0-launch.jar You may be able to launch this jar by double-clicking the jar file name. However, on a Mac, this presents some security challenges. On all platforms, the jar may be launched at the command line (with a specification of the amount of RAM you will allow it to use) using this command: From 546a17c2d6a0d1ce6696f892a2f4837ca1966643 Mon Sep 17 00:00:00 2001 From: Joseph Ramsey Date: Tue, 31 Oct 2023 14:33:29 -0400 Subject: [PATCH 02/24] Update README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 1d07814678..96f161e71a 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,6 @@ See out insructions for [Installing the Tetrad Application](https://github.com/c We have a project, [py-tetrad](https://github.com/cmu-phil/py-tetrad), that allows you to incorporate arbitrary Tetrad code into a Python workflow. It's new, and the installation is still nonstandard, but it had a good response. This requires Python 3.5+. and Java JDK 9+. -Please see our [description](https://sites.google.com/view/tetradcausal/tetrad-in-python - ## Tetrad in R We also have a project, [rpy-tetrad](https://github.com/cmu-phil/py-tetrad/tree/main/pytetrad/R), that allows you to incorporate _some_ Tetrad functionality in R. It's also new, and the installation for it is also still nonstandard, but has gotten good feedback. This requires Python 3.5+ and Java JDK 9+. From 707851b3cc7757cc1495981fcbb30508d1f25169 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Tue, 7 Nov 2023 07:45:07 -0500 Subject: [PATCH 03/24] Added shading to the Tetrad lib jar. --- tetrad-lib/pom.xml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index 79468aefd3..e0fed958cc 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -22,6 +22,35 @@ 1.8 + + org.apache.maven.plugins + maven-shade-plugin + 3.1.0 + + + package + + shade + + + + + + + all-permissions + ${project.name} + ${project.version} + + + + true + shaded + + + + + maven-antrun-plugin 3.1.0 From 71dd202556dbd7deca07f9aff3602c53198ecc8d Mon Sep 17 00:00:00 2001 From: Bryan Andrews Date: Tue, 7 Nov 2023 19:23:26 -0600 Subject: [PATCH 04/24] fixed error that was always forbidding all edges within each tier --- tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java index d023bd4dc4..1dbef102a8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java @@ -717,7 +717,7 @@ public List getListOfForbiddenEdges() { } for (int i = this.tierSpecs.size() - 1; i >= 0; i--) { - for (int j = i; j >= 0; j--) { + for (int j = i - 1; j >= 0; j--) { Set tieri = this.tierSpecs.get(i); Set tierj = this.tierSpecs.get(j); From 12e2676028cf1ed85192e8767d9b9607a8ca4b53 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 8 Nov 2023 14:17:26 -0500 Subject: [PATCH 05/24] Added shading to the Tetrad lib jar. --- .../src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java | 2 +- tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java index 59f6984a47..d1a3252e2e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java @@ -499,7 +499,7 @@ public static void circleLayout(LayoutEditable layoutEditable) { int m = FastMath.min(r.width, r.height) / 2; - LayoutUtil.defaultLayout(graph); + LayoutUtil.circleLayout(graph); layoutEditable.layoutByGraph(graph); LayoutUtils.layout = Layout.circle; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java index c5bd1ac995..dbb476ca16 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java @@ -45,7 +45,7 @@ public static void defaultLayout(Graph graph) { * Arranges the nodes in the graph in a circle. * @param graph the graph to be arranged. */ - private static void circleLayout(Graph graph) { + public static void circleLayout(Graph graph) { if (graph == null) { return; } From 2270906dced0f0d3cf9cfd0853af9448fa8c35d6 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 8 Nov 2023 14:22:46 -0500 Subject: [PATCH 06/24] Fixed the circle layout in the interface. --- .../src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java | 2 +- tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java index 59f6984a47..d1a3252e2e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/LayoutUtils.java @@ -499,7 +499,7 @@ public static void circleLayout(LayoutEditable layoutEditable) { int m = FastMath.min(r.width, r.height) / 2; - LayoutUtil.defaultLayout(graph); + LayoutUtil.circleLayout(graph); layoutEditable.layoutByGraph(graph); LayoutUtils.layout = Layout.circle; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java index c5bd1ac995..dbb476ca16 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java @@ -45,7 +45,7 @@ public static void defaultLayout(Graph graph) { * Arranges the nodes in the graph in a circle. * @param graph the graph to be arranged. */ - private static void circleLayout(Graph graph) { + public static void circleLayout(Graph graph) { if (graph == null) { return; } From b519b7309618b23bed985282410cfad489a61d07 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 8 Nov 2023 15:09:04 -0500 Subject: [PATCH 07/24] Updated version to 7.6.1-SNAPSHOT --- data-reader/pom.xml | 2 +- pom.xml | 2 +- tetrad-gui/pom.xml | 2 +- tetrad-lib/pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data-reader/pom.xml b/data-reader/pom.xml index 9810fc78a3..89662ead25 100644 --- a/data-reader/pom.xml +++ b/data-reader/pom.xml @@ -5,7 +5,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT data-reader diff --git a/pom.xml b/pom.xml index 64319b0487..05caf45039 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ 4.0.0 io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT pom Tetrad Project diff --git a/tetrad-gui/pom.xml b/tetrad-gui/pom.xml index 91f56bf4f5..820d99b993 100644 --- a/tetrad-gui/pom.xml +++ b/tetrad-gui/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT tetrad-gui diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index 79468aefd3..09043fe6a1 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT tetrad-lib From de5775795d177948509cc670bcdcccaa5366d0a6 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 8 Nov 2023 15:12:18 -0500 Subject: [PATCH 08/24] Fixed test. --- tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java index 5141a62e56..61c99d1c38 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java @@ -161,7 +161,7 @@ public void testSearch11() { knowledge.addToTier(2, "X3"); checkSearch("Latent(L1),Latent(L2),L1-->X1,L1-->X2,L2-->X2,L2-->X3", - "X1<->X2,X2<->X3", knowledge); + "X1o->X2,X2<->X3", knowledge); } @Test From 063d8e12342e28c892537c410cdfbe45a38fe127 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 9 Nov 2023 09:39:27 -0500 Subject: [PATCH 09/24] Added the demixer code. --- tetrad-lib/dependency-reduced-pom.xml | 62 +++ .../java/edu/cmu/tetrad/search/Demixer.java | 266 +++++++++++ .../edu/cmu/tetrad/search/DemixerMMLKun.java | 418 ++++++++++++++++++ .../edu/cmu/tetrad/search/MixtureModel.java | 207 +++++++++ 4 files changed, 953 insertions(+) create mode 100644 tetrad-lib/dependency-reduced-pom.xml create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java diff --git a/tetrad-lib/dependency-reduced-pom.xml b/tetrad-lib/dependency-reduced-pom.xml new file mode 100644 index 0000000000..1a34c60792 --- /dev/null +++ b/tetrad-lib/dependency-reduced-pom.xml @@ -0,0 +1,62 @@ + + + + tetrad + io.github.cmu-phil + 7.6.0-SNAPSHOT + + 4.0.0 + tetrad-lib + + + + org.apache.maven.wagon + wagon-ssh + 2.10 + + + + + maven-compiler-plugin + 3.11.0 + + 1.8 + 1.8 + + + + maven-shade-plugin + 3.1.0 + + + package + + shade + + + + + + maven-antrun-plugin + 3.1.0 + + + compile + + run + + + + + + + + + + + + + UTF-8 + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java new file mode 100644 index 0000000000..b3854ed823 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java @@ -0,0 +1,266 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.util.Matrix; +import edu.cmu.tetrad.util.Vector; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Random; + +/** + * Uses expectation-maximization to sort a a data set with data sampled from two or more multivariate Gaussian + * distributions into its component data sets. + * + * @author Madelyn Glymour + */ +public class Demixer { + + private final int numVars; + private final int numCases; + private final int numClusters; // number of clusters + private final DataSet data; + private final double[][] dataArray; // v-by-n data matrix + private final Matrix[] variances; + private final double[][] meansArray; // k-by-v matrix representing means for each variable for each of k models + private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models + private final double[] weightsArray; // array of length k representing weights for each model + private final double[][] gammaArray; // k-by-n matrix representing gamma for each data case in each model + private boolean demixed = false; + + public Demixer(DataSet data, int k) { + this.numClusters = k; + this.data = data; + dataArray = data.getDoubleData().toArray(); + numVars = data.getNumColumns(); + numCases = data.getNumRows(); + meansArray = new double[k][numVars]; + weightsArray = new double[k]; + variancesArray = new Matrix[k]; + variances = new Matrix[k]; + gammaArray = new double[k][numCases]; + + Random rand = new Random(); + + // initialize the means array to the mean of each variable plus noise + for (int i = 0; i < numVars; i++) { + for (int j = 0; j < k; j++) { + meansArray[j][i] = calcMean(data.getDoubleData().getColumn(i)) + (rand.nextGaussian()); + } + } + + // initialize the weights array uniformly + for (int i = 0; i < k; i++) { + weightsArray[i] = Math.abs((1.0 / k)); + } + + // initialize the covariance matrix array to the actual covariance matrix + for (int i = 0; i < k; i++) { + variances[i] = data.getCovarianceMatrix(); + } + } + + /* + * Runs the E-M algorithm iteratively until the weights array converges. Returns a MixtureModel object containing + * the final values of the means, covariance matrices, weights, and gammas arrays. + */ + public MixtureModel demix() { + double[] tempWeights = new double[numClusters]; + + System.arraycopy(weightsArray, 0, tempWeights, 0, numClusters); + + boolean weightsUnequal = true; + ArrayList diffsList; + int iterCounter = 0; + + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // convergence check + while (weightsUnequal) { + expectation(); + maximization(); + + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + diffsList = new ArrayList<>(); // list of differences between new weights and old weights + for (int i = 0; i < numClusters; i++) { + diffsList.add(Math.abs(weightsArray[i] - tempWeights[i])); + } + + Collections.sort(diffsList); // sort the list + + // if the largest difference is below the threshold, or we've passed 100 iterations, converge + if (diffsList.get(numClusters - 1) < 0.0001 || iterCounter > 100) { + weightsUnequal = false; + } + + // new weights are now the old weights + System.arraycopy(weightsArray, 0, tempWeights, 0, numClusters); + + iterCounter++; + } + + MixtureModel model = new MixtureModel(data, dataArray, meansArray, weightsArray, variancesArray, gammaArray); + demixed = true; + + return model; + + } + + /* + * Returns true if the algorithm has been run, and the gamma, mean, and covariance arrays are at their stable values + */ + public boolean isDemixed() { + return demixed; + } + + /* + * Computes the probability that each case belongs to each model (the gamma), given the current values of the mean, + * weight, and covariance arrays + */ + private void expectation() { + + double gamma; + double divisor; + + for (int i = 0; i < numClusters; i++) { + for (int j = 0; j < numCases; j++) { + gamma = weightsArray[i] * normalPDF(j, i); + divisor = gamma; + + for (int w = 0; w < numClusters; w++) { + if (w != i) { + divisor += (weightsArray[w] * normalPDF(j, w)); + } + } + gamma = gamma / divisor; + gammaArray[i][j] = gamma; + } + } + } + + /* + * Estimates the means, covariances, and weight of each model, given the current values of the gamma array + */ + private void maximization() { + + // the weight of each model is the sum of the gamma for each case in that model, divided by the number of cases + double weight; + + for (int i = 0; i < numClusters; i++) { + weight = 0; + for (int j = 0; j < numCases; j++) { + weight += gammaArray[i][j]; + } + weight = weight / numCases; + weightsArray[i] = weight; + } + + // the mean for each variable in each model is determined by the weighted mean of that variable in the model + // (where each case i in the variable in model k is weighted by the gamma(i, k) + double meanNumerator; + double meanDivisor; + double mean; + + for (int i = 0; i < numClusters; i++) { + for (int v = 0; v < numVars; v++) { + meanNumerator = 0; + meanDivisor = 0; + for (int j = 0; j < numCases; j++) { + + meanNumerator += gammaArray[i][j] * dataArray[j][v]; + meanDivisor += gammaArray[i][j]; + } + mean = meanNumerator / meanDivisor; + meansArray[i][v] = mean; + } + } + + // the covariance matrix for each model is determined by the covariance matrix of the data, weighted by the + // gamma values for that model + double var; + + for (int i = 0; i < numClusters; i++) { + for (int v = 0; v < numVars; v++) { + for (int v2 = v; v2 < numVars; v2++) { + var = getVar(i, v, v2, numCases, gammaArray, dataArray, meansArray); + // if(Math.abs(var) >= 0.5) { + variancesArray[i].set(v, v2, var); + variancesArray[i].set(v2, v, var); + + // Reset the variances if things start to go awry with the algorithm; turns out not to be necessary + // } else{ + // Random rand = new Random(); + // double temp = 0.5 + rand.nextDouble(); + // variancesArray[i][v][v2] = temp; + // variancesArray[i][v2][v] = temp; + // } + } + } + variances[i] = new Matrix(variancesArray[i]); + } + + } + + static double getVar(int i, int v, int v2, int numCases, double[][] gammaArray, double[][] dataArray, double[][] meansArray) { + double varNumerator; + double varDivisor; + double var; + varNumerator = 0; + varDivisor = 0; + + for (int j = 0; j < numCases; j++) { + varNumerator += gammaArray[i][j] * (dataArray[j][v] - meansArray[i][v]) * (dataArray[j][v2] - meansArray[i][v2]); + varDivisor += gammaArray[i][j]; + } + + var = varNumerator / varDivisor; + return var; + } + + /* + * For an input case and model, returns the value of the model's normal PDF for that case, using the current + * estimations of the means and covariance matrix + */ + private double normalPDF(int caseIndex, int weightIndex) { + Matrix cov = variances[weightIndex]; + + Matrix covIn = cov.inverse(); + double[] mu = meansArray[weightIndex]; + double[] thisCase = dataArray[caseIndex]; + + double[][] diffs = new double[1][numVars]; + + for (int i = 0; i < numVars; i++) { + diffs[0][i] = thisCase[i] - mu[i]; + } + + Matrix diffsMatrix = new Matrix(diffs); + Matrix diffsTranspose = diffsMatrix.transpose(); + + Matrix distance = covIn.times(diffsTranspose); // inverse of the covariance matrix * (x - mu) + + distance = diffsMatrix.times(distance); // squared + + double distanceScal = distance.get(0, 0); // distance is a scalar, but in matrix representation + distanceScal = distanceScal * (-.5); + distanceScal = Math.exp(distanceScal); + distanceScal = distanceScal / Math.sqrt(2 * Math.PI * cov.det()); // exp(-.5 * distance) / sqrt(2 * pi * cov) + + return distanceScal; + } + + /* + * Returns the mean of a variable, input as a Vector + */ + private double calcMean(Vector dataPoints) { + double sum = 0; + + for (int i = 0; i < dataPoints.size(); i++) { + sum += dataPoints.get(i); + } + + return sum / dataPoints.size(); + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java new file mode 100644 index 0000000000..f0d4a92ac2 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java @@ -0,0 +1,418 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.cluster.KMeans; +import edu.cmu.tetrad.data.BoxDataSet; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DoubleDataBox; +import edu.cmu.tetrad.data.SimpleDataLoader; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.util.Matrix; +import edu.cmu.tetrad.util.MatrixUtils; +import edu.pitt.dbmi.data.reader.Delimiter; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Created by user on 2/27/18. + */ +public class DemixerMMLKun { + + private final double minWeight; + + public DemixerMMLKun() { + minWeight = 1e-3; + } + + public static void main(String... args) { + DataSet dataSet; + + try { + dataSet = SimpleDataLoader.loadContinuousData(new File("/Users/user/Documents/Demix_Testing/NonGaussian/sub_1500_4var_3comp.txt"), + "//", '\"', "*", true, Delimiter.TAB, false); + } catch (IOException e) { + throw new RuntimeException(e); + } + + DemixerMMLKun pedro = new DemixerMMLKun(); + long startTime = System.currentTimeMillis(); + MixtureModel model = pedro.demix(dataSet, 25); + long elapsed = System.currentTimeMillis() - startTime; + + double[] weights = model.getWeights(); + for (double weight : weights) { + System.out.print(weight + "\t"); + } + + try { + FileWriter writer = new FileWriter("/Users/user/Documents/Demix_Testing/sub_1500_4var_3comp.txt"); + BufferedWriter bufferedWriter = new BufferedWriter(writer); + + for (int i = 0; i < dataSet.getNumRows(); i++) { + bufferedWriter.write(model.getDistribution(i) + "\n"); + } + bufferedWriter.flush(); + bufferedWriter.close(); + + DataSet[] dataSets = model.getDemixedData(); + + for (int i = 0; i < dataSets.length; i++) { + writer = new FileWriter("/Users/user/Documents/Demix_Testing/sub_1500_4var_3comp_demixed_" + (i + 1) + ".txt"); + bufferedWriter = new BufferedWriter(writer); + bufferedWriter.write(dataSets[i].toString()); + bufferedWriter.flush(); + bufferedWriter.close(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + System.out.println("Elapsed: " + elapsed / 1000); + } + + private MixtureModel demix(DataSet data, int k) { + double[][] dataArray = data.getDoubleData().toArray(); + int numVars = data.getNumColumns(); + int numCases = data.getNumRows(); + double lambda2 = Math.sqrt((Math.log(numCases) - Math.log(Math.log(numCases))) / 2.0); + double lambda = lambda2 * ((Math.pow(numVars, 2)) * 0.5 + 1.5 * numVars + 1); // part of the MML score + double epsilon = 1e-6; // part of the MML score + double threshold = 1e-8; // threshold for MML score convergence + + System.out.println("Lambda: " + lambda); + + // Initialize clusterings with kmeans + KMeans kMeans = KMeans.randomClusters(k); + kMeans.cluster(data.getDoubleData()); + + // Use initial clusterings to get initial means, variances, and gamma arrays + double[][] meansArray = new double[k][numVars]; + double[] weightsArray = new double[k]; + Matrix[] variancesArray = new Matrix[k]; + Matrix[] variances = new Matrix[k]; + double[][] gammaArray = new double[k][numCases]; + + List> clusters = kMeans.getClusters(); + List cluster; + int clusterSize; + double[] means; + + double[][] clusterMatrixArray; + + for (int i = 0; i < clusters.size(); i++) { + cluster = clusters.get(i); + clusterSize = cluster.size(); + means = new double[numVars]; + + for (int j = 0; j < numVars; j++) { + means[j] = 0; + } + + clusterMatrixArray = new double[clusterSize][numVars]; + + for (int j = 0; j < clusterSize; j++) { + // System.out.print(Integer.toString(cluster.get(j)) + "\t"); + MatrixUtils.sum(means, dataArray[cluster.get(j)]); + clusterMatrixArray[j] = dataArray[cluster.get(j)]; + } + + // Initial mean is mean of cluster + means = MatrixUtils.scalarProduct(1.0 / clusterSize, means); + meansArray[i] = means; + + // Initial weight is percentage of rows taken up by cluster + weightsArray[i] = ((double) clusterSize) / ((double) numCases); + + // Initial covariance matrix is cov matrix of cluster, unless cluster cov matrix has 0 determinant + DoubleDataBox box = new DoubleDataBox(clusterMatrixArray); + List variables = data.getVariables(); + BoxDataSet clusterData = new BoxDataSet(box, variables); + Matrix clusterCovMatrix = clusterData.getCovarianceMatrix(); + if (MatrixUtils.determinant(clusterCovMatrix.toArray()) == 0) { + variances[i] = MatrixUtils.cholesky(data.getCovarianceMatrix()); + variancesArray[i] = data.getCovarianceMatrix(); + } else { + variances[i] = MatrixUtils.cholesky(clusterCovMatrix); + variancesArray[i] = clusterCovMatrix; + } + + } + + double gamma; + double divisor; + + for (int z = 0; z < k; z++) { + for (int j = 0; j < numCases; j++) { + gamma = weightsArray[z] * normalPDF(j, z, variances, meansArray, dataArray, numVars); + divisor = gamma; + + for (int w = 0; w < k; w++) { + if (w != z) { + divisor += (weightsArray[w] * normalPDF(j, w, variances, meansArray, dataArray, numVars)); + } + } + + //Initial gamma is weighted probability for the case in cluster k, divided by the sum of weighted probabilities in all clusters + gamma = gamma / divisor; + gammaArray[z][j] = gamma; + } + } + + // Verbose debugging output + System.out.println("Clusters: " + k); + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // oldLogL and newLogL determine convergence + double oldLogL = Double.POSITIVE_INFINITY; + double newLogL; + + DeterminingStats stats; + + while (true) { + + // maximization step + stats = innerStep(data, dataArray, weightsArray, meansArray, variancesArray, variances, gammaArray, numCases, numVars, lambda); + meansArray = stats.getMeans(); + weightsArray = stats.getWeights(); + variancesArray = stats.getVariances(); + variances = stats.getVarMatrixArray(); + + k = weightsArray.length; + + // fail if there are no clusters + if (k == 0) { + break; + } + + // verbose debugging output + System.out.println("Clusters: " + k); + System.out.println("Weights: " + Arrays.toString(weightsArray)); + + // expectation step; gamma computed as above, I should probably make a separate method for it + for (int i = 0; i < k; i++) { + + for (int j = 0; j < numCases; j++) { + + double pdf = normalPDF(j, i, variances, meansArray, dataArray, numVars); + + gamma = weightsArray[i] * pdf; + + divisor = gamma; + + for (int w = 0; w < k; w++) { + if (w != i) { + divisor += (weightsArray[w] * normalPDF(j, w, variances, meansArray, dataArray, numVars)); + } + } + gamma = gamma / divisor; + + + gammaArray[i][j] = gamma; + } + } + + // check for convergence + double mml = 0; + double gammaMean; + for (int i = 0; i < weightsArray.length; i++) { + gammaMean = 0; + for (int j = 0; j < numCases; j++) { + gammaMean += gammaArray[i][j]; + } + gammaMean /= numCases; + mml += Math.log(gammaMean); + } + + mml /= weightsArray.length; + + double weightSum = 0; + + for (double v : weightsArray) { + weightSum += Math.log(v / epsilon + 1); + } + + weightSum *= lambda / numCases; + + newLogL = mml + weightSum; + + // if oldLogL and newLogL converge, end; otherwise, set oldLogL to newLogL + if (Math.abs(oldLogL / (newLogL) - 1) < threshold) { + break; + } else { + oldLogL = newLogL; + } + + } + + return new MixtureModel(data, dataArray, meansArray, weightsArray, variancesArray, gammaArray); + } + + /** + * Performs the maximization step + */ + private DeterminingStats innerStep(DataSet data, double[][] dataArray, double[] weightsArray, double[][] meansArray, Matrix[] variancesArray, Matrix[] variances, double[][] gammaArray, int numCases, int numVars, double lambda) { + + double weight; + double pSum; // sum of all gammas for a case + double meanNumerator; + double mean; + Matrix tempVar; + + ArrayList meansList = new ArrayList<>(); + ArrayList varsLilst = new ArrayList<>(); + ArrayList varMatList = new ArrayList<>(); + + for (int i = 0; i < weightsArray.length; i++) { + + // maximize weights + pSum = 0; + for (int j = 0; j < numCases; j++) { + pSum += gammaArray[i][j]; + } + + weight = (pSum - lambda) / (numCases - (lambda * weightsArray.length)); + weightsArray[i] = weight; + + // maximize covariance matrices + tempVar = new Matrix(numVars, numVars); + + for (int v = 0; v < numVars; v++) { + + // maximize means + meanNumerator = 0; + for (int j = 0; j < numCases; j++) { + + meanNumerator += gammaArray[i][j] * dataArray[j][v]; + } + mean = meanNumerator / pSum; + meansArray[i][v] = mean; + + for (int v2 = v; v2 < numVars; v2++) { + double var = Demixer.getVar(i, v, v2, numCases, gammaArray, dataArray, meansArray); + tempVar.set(v, v2, var); + tempVar.set(v2, v, var); + } + } + + Matrix varMatrix = new Matrix(tempVar); + if (varMatrix.det() != 0) { + variancesArray[i] = MatrixUtils.cholesky(tempVar); + variances[i] = MatrixUtils.cholesky(varMatrix); + } else { + variances[i] = MatrixUtils.cholesky(data.getCovarianceMatrix()); + variancesArray[i] = data.getCovarianceMatrix(); + } + + } + + System.out.println(); + + // check weights, and remove any clusters with weights below threshold + ArrayList weightsList = new ArrayList<>(); + + for (int i = 0; i < weightsArray.length; i++) { + + if (weightsArray[i] >= minWeight) { + weightsList.add(weightsArray[i]); + meansList.add(meansArray[i]); + varsLilst.add(variancesArray[i]); + varMatList.add(variances[i]); + } + } + + double[] tempWeightsArray = new double[weightsList.size()]; + double[][] tempMeansArray = new double[weightsList.size()][numVars]; + Matrix[] tempVarsArray = new Matrix[weightsList.size()]; + Matrix[] tempVariances = new Matrix[weightsList.size()]; + for (int i = 0; i < weightsList.size(); i++) { + tempWeightsArray[i] = weightsList.get(i); + tempMeansArray[i] = meansList.get(i); + tempVarsArray[i] = varsLilst.get(i); + tempVariances[i] = varMatList.get(i); + } + + weightsArray = tempWeightsArray; + meansArray = tempMeansArray; + variancesArray = tempVarsArray; + variances = tempVariances; + + return new DeterminingStats(meansArray, weightsArray, variancesArray, variances); + } + + /** + * Returns the value of the Normal PDF for a given case if it belongs to a given cluster + */ + private double normalPDF(int caseIndex, int weightIndex, Matrix[] variances, double[][] meansArray, double[][] dataArray, int numVars) { + Matrix cov = variances[weightIndex]; + cov = cov.transpose(); + + Matrix covIn = cov.inverse(); + double[] mu = meansArray[weightIndex]; + double[] thisCase = dataArray[caseIndex]; + + double[][] diffs = new double[1][numVars]; + + for (int i = 0; i < numVars; i++) { + diffs[0][i] = thisCase[i] - mu[i]; + } + + Matrix diffsMatrix = new Matrix(diffs); + + Matrix mah = diffsMatrix.times(covIn); + + double val; + double mahScal = 0; + for (int i = 0; i < mah.getNumRows(); i++) { + for (int j = 0; j < mah.getNumColumns(); j++) { + val = mah.get(i, j); + val = val * val; + mahScal += val; + mah.set(i, j, val); + } + } + + double distanceScal = Math.pow(2 * Math.PI, -(numVars) / 2.0); + distanceScal = distanceScal / cov.det(); + distanceScal = distanceScal * Math.exp(-.5 * mahScal); + + return distanceScal; + } + + /** + * Private wrapper class for statistics to be maximized + */ + private static class DeterminingStats { + private final double[][] meansArray; + private final double[] weightsArray; + private final Matrix[] variancesArray; + private final Matrix[] varMatrixArray; + + public DeterminingStats(double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, Matrix[] varMatrixArray) { + this.meansArray = meansArray; + this.weightsArray = weightsArray; + this.variancesArray = variancesArray; + this.varMatrixArray = varMatrixArray; + } + + public double[] getWeights() { + return weightsArray; + } + + public double[][] getMeans() { + return meansArray; + } + + public Matrix[] getVariances() { + return variancesArray; + } + + public Matrix[] getVarMatrixArray() { + return varMatrixArray; + } + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java new file mode 100644 index 0000000000..8a1516ef38 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixtureModel.java @@ -0,0 +1,207 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.BoxDataSet; +import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DoubleDataBox; +import edu.cmu.tetrad.search.score.SemBicScore; +import edu.cmu.tetrad.util.Matrix; + +/** + * Represents a Gaussian mixture model -- a dataset with data sampled from two or more multivariate Gaussian + * distributions. + * + * @author Madelyn Glymour + */ +public class MixtureModel { + private final DataSet data; + private final int[] cases; + private final int[] caseCounts; + private final double[][] dataArray; // v-by-n data matrix + private final double[][] meansArray; // k-by-v matrix representing means for each variable for each of k models + private final double[] weightsArray; // array of length k representing weights for each model + private final double[][] gammaArray; // k-by-n matrix representing gamma for each data case in each model + private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models + private final int numModels; // number of models in mixture + + public MixtureModel(DataSet data, double[][] dataArray, double[][] meansArray, double[] weightsArray, Matrix[] variancesArray, double[][] gammaArray) { + this.data = data; + this.dataArray = dataArray; + this.meansArray = meansArray; + this.weightsArray = weightsArray; + this.variancesArray = variancesArray; + this.numModels = weightsArray.length; + this.gammaArray = gammaArray; + this.cases = new int[data.getNumRows()]; + + // set the individual model for each case + for (int i = 0; i < cases.length; i++) { + cases[i] = getDistribution(i); + } + + this.caseCounts = new int[numModels]; + + // count the number of cases in each individual data set + for (int i = 0; i < numModels; i++) { + caseCounts[i] = 0; + } + + for (int aCase : cases) { + for (int j = 0; j < numModels; j++) { + if (aCase == j) { + caseCounts[j]++; + break; + } + } + } + } + + /** + * @return the mixed data set in array form + */ + public double[][] getData() { + return dataArray; + } + + /** + * @return the means matrix + */ + public double[][] getMeans() { + return meansArray; + } + + /** + * @return the weights array + */ + public double[] getWeights() { + return weightsArray; + } + + /** + * @return the variance matrix + */ + public Matrix[] getVariances() { + return variancesArray; + } + + /** + * @return an array assigning each case an integer corresponding to a model + */ + public int[] getCases() { + return cases; + } + + /** + * Classifies a given case into a model, based on which model has the highest gamma value for that case. + */ + public int getDistribution(int caseNum) { + + // hard classification + int dist = 0; + double highest = 0; + + for (int i = 0; i < numModels; i++) { + if (gammaArray[i][caseNum] > highest) { + highest = gammaArray[i][caseNum]; + dist = i; + } + + } + + return dist; + + // soft classification, deprecated because it doesn't classify as well + + /*int gammaSum = 0; + + for (int i = 0; i < k; i++) { + gammaSum += gammaArray[i][caseNum]; + } + + Random rand = new Random(); + double test = gammaSum * rand.nextDouble(); + + if(test < gammaArray[0][caseNum]){ + return 0; + } + + double sum = gammaArray[0][caseNum]; + + for (int i = 1; i < k; i++){ + sum = sum+gammaArray[i][caseNum]; + if(test < sum){ + return i; + } + } + + return k - 1; */ + } + + /* + * Sort the mixed data set into its component data sets. + * + * @return a list of data sets + */ + public DataSet[] getDemixedData() { + DoubleDataBox[] dataBoxes = new DoubleDataBox[numModels]; + int[] caseIndices = new int[numModels]; + + for (int i = 0; i < numModels; i++) { + dataBoxes[i] = new DoubleDataBox(caseCounts[i], data.getNumColumns()); + caseIndices[i] = 0; + } + + int index; + DoubleDataBox box; + int count; + for (int i = 0; i < cases.length; i++) { + + // get the correct data set and corresponding case count for this case + index = cases[i]; + box = dataBoxes[index]; + count = caseIndices[index]; + + // set the [count]th row of the given data set to the ith row of the mixed data set + for (int j = 0; j < data.getNumColumns(); j++) { + box.set(count, j, data.getDouble(i, j)); + } + + dataBoxes[index] = box; //make sure that the changes get carried to the next iteration of the loop + caseIndices[index] = count + 1; //increment case count of this data set + } + + // create list of data sets + DataSet[] dataSets = new DataSet[numModels]; + for (int i = 0; i < numModels; i++) { + dataSets[i] = new BoxDataSet(dataBoxes[i], data.getVariables()); + } + + return dataSets; + } + + /** + * Perform an FGES search on each of the demixed data sets. + * + * @return the BIC scores of the graphs returned by searches. + */ + public double[] searchDemixedData() { + DataSet[] dataSets = getDemixedData(); + SemBicScore score; + edu.cmu.tetrad.search.Fges fges; + DataSet dataSet; + double bic; + double[] bicScores = new double[numModels]; + + for (int i = 0; i < numModels; i++) { + dataSet = dataSets[i]; + score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); + score.setPenaltyDiscount(2.0); + fges = new edu.cmu.tetrad.search.Fges(score); + fges.search(); + bic = fges.getModelScore(); + bicScores[i] = bic; + } + + return bicScores; + } +} From 14e2dd1335383461a3e065ffabf34d5de04031b4 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 9 Nov 2023 09:40:27 -0500 Subject: [PATCH 10/24] Added the demixer code. --- .../src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java | 1 - 1 file changed, 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java index f0d4a92ac2..75461c4c7c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DemixerMMLKun.java @@ -362,7 +362,6 @@ private double normalPDF(int caseIndex, int weightIndex, Matrix[] variances, dou } Matrix diffsMatrix = new Matrix(diffs); - Matrix mah = diffsMatrix.times(covIn); double val; From f3e4fc21dacb9f9af2a92cdb6fbcb3cef9c23f6f Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 9 Nov 2023 12:35:48 -0500 Subject: [PATCH 11/24] Adding the ordered local check to the Markov Checker. --- .../tetradapp/editor/MarkovCheckEditor.java | 24 +++++++++++-------- .../model/MarkovCheckIndTestModel.java | 2 +- .../FractionDependentUnderAlternative.java | 2 +- .../statistic/FractionDependentUnderNull.java | 2 +- .../statistic/MarkovAdequacyScore.java | 2 +- .../statistic/PvalueDistanceToAlpha.java | 2 +- .../statistic/PvalueUniformityUnderNull.java | 2 +- .../edu/cmu/tetrad/search/MarkovCheck.java | 23 +++++++++++++++--- 8 files changed, 40 insertions(+), 19 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index 20da5dae5e..fe92a9dbec 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -99,20 +99,24 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { throw new NullPointerException("Expecting a model"); } - conditioningSetTypeJComboBox.addItem("Parents(X)"); + conditioningSetTypeJComboBox.addItem("Parents(X) (Local Markov)"); + conditioningSetTypeJComboBox.addItem("Parents(X) for a Valid Order (Ordered Local Markov)"); conditioningSetTypeJComboBox.addItem("MarkovBlanket(X)"); - conditioningSetTypeJComboBox.addItem("All Subsets"); + conditioningSetTypeJComboBox.addItem("All Subsets (Global Markov)"); conditioningSetTypeJComboBox.addActionListener(e -> { switch ((String) Objects.requireNonNull(conditioningSetTypeJComboBox.getSelectedItem())) { - case "Parents(X)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.PARENTS); + case "\"Parents(X) (\\\"Local Markov\\\")\"": + model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + break; + case "Parents(X) for a Valid Order (\"Ordered Local Markov\")": + model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.ORDERED_LOCAL_MARKOV); break; case "MarkovBlanket(X)": model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.MARKOV_BLANKET); break; - case "All Subsets": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.ALL_SUBSETS); + case "All Subsets (\"Global Markov\")": + model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.GLOBAL_MARKOV); break; default: throw new IllegalArgumentException("Unknown conditioning set type: " + @@ -121,7 +125,7 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { class MyWatchedProcess extends WatchedProcess { public void watch() { - if (model.getMarkovCheck().getSetType() == MarkovCheck.ConditioningSetType.ALL_SUBSETS && model.getVars().size() > 12) { + if (model.getMarkovCheck().getSetType() == MarkovCheck.ConditioningSetType.GLOBAL_MARKOV && model.getVars().size() > 12) { int ret = JOptionPane.showOptionDialog(MarkovCheckEditor.this, "The all subsets option is exponential and can become extremely slow beyond 12" + "\nvariables. You may possibly be required to force quit Tetrad. Continue?", "Warning", @@ -277,7 +281,7 @@ public void watch() { JTabbedPane pane = new JTabbedPane(); pane.addTab("Check Markov", indep); - pane.addTab("Check Faithfulness", dep); + pane.addTab("Check Dependent Distribution", dep); pane.addTab("Help", scroll); box.add(pane); @@ -310,13 +314,13 @@ public void watch() { @NotNull private static String getHelpMessage() { - return "This tool lets you plot statistics for independence tests of a pair of variables given some conditioning calculated for one of those variables, for a given graph and dataset. Two tables are made, one in which the independence facts predicted by the graph using these conditioning sets are tested in the data and the other in which the graph's predicted dependence facts are tested. The first of these sets is a test for \"Markov\" for the relevant conditioning sets; the is a test for \"Faithfulness.”\n" + + return "This tool lets you plot statistics for independence tests of a pair of variables given some conditioning calculated for one of those variables, for a given graph and dataset. Two tables are made, one in which the independence facts predicted by the graph using these conditioning sets are tested in the data and the other in which the graph's predicted dependence facts are tested. The first of these sets is a check for \"Markov\" (a check for implied independence facts) for the chosen conditioning sets; the is a check of the \"Dependent Distribution.\" (a check of implied dependence facts)”\n" + "\n" + "Each table gives columns for the independence fact being checked, its test result, and its statistic. This statistic is either a p-value, ranging from 0 to 1, where p-values above the alpha level of the test are judged as independent, or a score bump, where this bump is negative for independent judgments and positive for dependent judgments.\n" + "\n" + "If the independence test yields a p-value, as for instance, for the Fisher Z test (for the linear, Gaussian case) or else the Chi-Square test (for the multinomial case), then under the null hypothesis of independence and for a consistent test, these p-values should be distributed as Uniform(0, 1). That is, it should be just as likely to see p-values in any range of equal width. If the test is inconsistent or the graph is incorrect (i.e., the parents of some or all of the nodes in the graph are incorrect), then this distribution of p-values will not be Uniform. To visualize this, we display the histogram of the p-values with equally sized bins; the bars in this histogram, for this case, should ideally all be of equal height.\n" + "\n" + - "If the first bar in this histogram is especially high (for the p-value case), that means that many tests are being judged as dependent. For checking Faithfulness, one hopes that this list is non-empty, then this first bar will be especially high, since high p-values are for examples where the graph is unfaithful to the distribution. These are likely for for cases where paths in the graph cancel unfaithfully. But for checking Markov, one hopes that this first bar will be the same height as all of the other bars.\n" + + "If the first bar in this histogram is especially high (for the p-value case), that means that many tests are being judged as dependent. For checking the dependent distribution, one hopes that this list is non-empty, in which case this first bar will be especially high, since high p-values are for examples where the graph is unfaithful to the distribution. These are likely for for cases where paths in the graph cancel unfaithfully. But for checking Markov, one hopes that this first bar will be the same height as all of the other bars.\n" + "\n" + "To make it especially clear, we give two statistics in the interface. The first is the percentage of p-values judged dependent on the test. If an alpha level is used in the test, this number should be very close to the alpha level for the Local Markov check since the distribution of p-values under this condition is Uniform. For the second, we test the Uniformity of the p-values using a Kolmogorov-Smirnov test. The p-value returned by this test should be greater than the user’s preferred alpha level if the distribution of p-values is Uniform and less then this alpha level if the distribution of p-values is non-Uniform.\n" + "\n" + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java index 05ed27b7aa..d44d3f28e5 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java @@ -64,7 +64,7 @@ public static Knowledge serializableInstance() { } public void setIndependenceTest(IndependenceTest test) { - this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? MarkovCheck.ConditioningSetType.PARENTS : this.markovCheck.getSetType()); + this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? MarkovCheck.ConditioningSetType.LOCAL_MARKOV : this.markovCheck.getSetType()); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java index 176e4b0201..74277d48f0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java @@ -36,7 +36,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(false); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java index 70f9299e1f..ab344826ac 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java @@ -36,7 +36,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java index 6833b6f6e7..a9b30523a9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java @@ -29,7 +29,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getMarkovAdequacyScore(alpha); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java index 8fef3f87ca..602cec72a5 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java @@ -35,7 +35,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return abs(alpha - markovCheck.getKsPValue(true)); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java index 1a23b18e6f..df4f0af418 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java @@ -33,7 +33,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.PARENTS); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getKsPValue(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index fef052901c..bb8d5e7bd6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -74,7 +74,7 @@ public void generateResults() { resultsIndep.clear(); resultsDep.clear(); - if (setType == ConditioningSetType.ALL_SUBSETS) { + if (setType == ConditioningSetType.GLOBAL_MARKOV) { AllSubsetsIndependenceFacts result = getAllSubsetsIndependenceFacts(graph); generateResultsAllSubsets(true, result.msep, result.mconn); generateResultsAllSubsets(false, result.msep, result.mconn); @@ -83,12 +83,29 @@ public void generateResults() { List nodes = new ArrayList<>(variables); Collections.sort(nodes); + List order = graph.paths().getValidOrder(graph.getNodes(), true); + for (Node x : nodes) { Set z; switch (setType) { - case PARENTS: + case LOCAL_MARKOV: + z = new HashSet<>(graph.getParents(x)); + break; + case ORDERED_LOCAL_MARKOV: + if (order == null) throw new IllegalArgumentException("No valid order found."); z = new HashSet<>(graph.getParents(x)); + + // Keep only the parents in Prefix(x). + for (Node w : new ArrayList<>(z)) { + int i1 = order.indexOf(x); + int i2 = order.indexOf(w); + + if (i2 >= i1) { + z.remove(w); + } + } + break; case MARKOV_BLANKET: z = GraphUtils.markovBlanket(x, graph); @@ -542,6 +559,6 @@ private List getResultsLocal(boolean indep) { * setting, and PAG_MB uses a Markov blanket of the target variable in a PAG setting. */ public enum ConditioningSetType { - PARENTS, MARKOV_BLANKET, ALL_SUBSETS + LOCAL_MARKOV, ORDERED_LOCAL_MARKOV, MARKOV_BLANKET, GLOBAL_MARKOV } } From e67267cdfdfc135e736377eb985b40b14fc507d4 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 9 Nov 2023 17:26:56 -0500 Subject: [PATCH 12/24] Adding the ordered local check to the Markov Checker. --- .../java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java | 6 +++--- tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index fe92a9dbec..9706993450 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -106,16 +106,16 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { conditioningSetTypeJComboBox.addActionListener(e -> { switch ((String) Objects.requireNonNull(conditioningSetTypeJComboBox.getSelectedItem())) { - case "\"Parents(X) (\\\"Local Markov\\\")\"": + case "Parents(X) (Local Markov)": model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.LOCAL_MARKOV); break; - case "Parents(X) for a Valid Order (\"Ordered Local Markov\")": + case "Parents(X) for a Valid Order (Ordered Local Markov)": model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.ORDERED_LOCAL_MARKOV); break; case "MarkovBlanket(X)": model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.MARKOV_BLANKET); break; - case "All Subsets (\"Global Markov\")": + case "All Subsets (Global Markov)": model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.GLOBAL_MARKOV); break; default: diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java index 266b4b1022..112506f4aa 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java @@ -30,7 +30,7 @@ private static void addToSet(Map> previous, Node b, Node c) { * * @param initialOrder Variables in the order will be kept as close to this initial order as possible, either the * forward order or the reverse order, depending on the next parameter. - * @param forward Whether the variable will be iterated over in forward or reverse direction. + * @param forward Whether the variables will be iterated over in forward or reverse direction. * @return The valid causal order found. */ public List getValidOrder(List initialOrder, boolean forward) { From f414cbba732a09c2a06045b6a6f82e75c28fc550 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 9 Nov 2023 17:42:39 -0500 Subject: [PATCH 13/24] Added shading to the lib jar. --- tetrad-lib/pom.xml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index 09043fe6a1..2d07d70c0e 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -22,6 +22,35 @@ 1.8 + + org.apache.maven.plugins + maven-shade-plugin + 3.1.0 + + + package + + shade + + + + + + + all-permissions + ${project.name} + ${project.version} + + + + true + shaded + + + + + maven-antrun-plugin 3.1.0 From fc2bd9637cacdaa49e81daa448fa677dd6c83f90 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 10:31:52 -0500 Subject: [PATCH 14/24] Stamping DirectLiNGAM and ICALiNGAM results with BIC. --- .../tetradapp/editor/MarkovCheckEditor.java | 12 ++++ .../continuous/dag/DirectLingam.java | 2 + .../algorithm/continuous/dag/IcaLingam.java | 2 + .../algorithm/oracle/cpdag/Boss.java | 7 ++- .../cpdag/{PcLingam.java => BossLingam.java} | 20 ++++--- .../algorithm/oracle/cpdag/Cpc.java | 6 +- .../algorithm/oracle/cpdag/Fges.java | 7 +++ .../algorithm/oracle/cpdag/Grasp.java | 5 +- .../algorithm/oracle/cpdag/Pc.java | 6 +- .../algorithm/oracle/cpdag/Sp.java | 7 ++- .../java/edu/cmu/tetrad/data/Knowledge.java | 47 ++++++++++++---- .../search/{PcLingam.java => BossLingam.java} | 9 +-- .../cmu/tetrad/search/score/SemBicScorer.java | 5 +- .../tetrad/search/utils/LogUtilsSearch.java | 55 ++++++++++++++++++- .../tetrad/search/utils/TeyssierScorer.java | 12 +++- .../search/work_in_progress/FasLofs.java | 2 +- 16 files changed, 161 insertions(+), 43 deletions(-) rename tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/{PcLingam.java => BossLingam.java} (85%) rename tetrad-lib/src/main/java/edu/cmu/tetrad/search/{PcLingam.java => BossLingam.java} (96%) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index 9706993450..db7f627aec 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -496,6 +496,12 @@ public void mouseClicked(MouseEvent e) { // scroll.setPreferredSize(new Dimension(400, 400)); b1.add(scroll); + Box b1a = Box.createHorizontalBox(); + JLabel label = new JLabel("Table contents can be selected and copied in to, e.g., Excel."); + b1a.add(label); + b1a.add(Box.createHorizontalGlue()); + b1.add(b1a); + Box b4 = Box.createHorizontalBox(); b4.add(Box.createGlue()); b4.add(Box.createHorizontalStrut(10)); @@ -682,6 +688,12 @@ public void mouseClicked(MouseEvent e) { // scroll.setPreferredSize(new Dimension(400, 400)); b1.add(scroll); + Box b1a = Box.createHorizontalBox(); + JLabel label = new JLabel("Table contents can be selected and copied in to, e.g., Excel."); + b1a.add(label); + b1a.add(Box.createHorizontalGlue()); + b1.add(b1a); + Box b4 = Box.createHorizontalBox(); b4.add(Box.createGlue()); b4.add(Box.createHorizontalStrut(10)); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java index 42cefbfbe8..d30c33400d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java @@ -13,6 +13,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; import edu.cmu.tetrad.util.TetradLogger; @@ -56,6 +57,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { Graph graph = search.search(); TetradLogger.getInstance().forceLogMessage(graph.toString()); + LogUtilsSearch.stampWithBic(graph, dataSet); return graph; } else { DirectLingam algorithm = new DirectLingam(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java index c97e7a0614..3734fff3b0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/IcaLingam.java @@ -11,6 +11,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.IcaLingD; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -56,6 +57,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { TetradLogger.getInstance().forceLogMessage(bHat.toString()); TetradLogger.getInstance().forceLogMessage(graph.toString()); + LogUtilsSearch.stampWithBic(graph, dataSet); return graph; } else { IcaLingam algorithm = new IcaLingam(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java index ec75e01630..a7c92da553 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java @@ -15,6 +15,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.PermutationSearch; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -51,7 +52,6 @@ public Boss(ScoreWrapper score) { this.score = score; } - @Override public Graph search(DataModel dataModel, Parameters parameters) { if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) { @@ -75,8 +75,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { boss.setVerbose(parameters.getBoolean(Params.VERBOSE)); PermutationSearch permutationSearch = new PermutationSearch(boss); permutationSearch.setKnowledge(this.knowledge); - - return permutationSearch.search(); + Graph graph = permutationSearch.search(); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Boss algorithm = new Boss(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java similarity index 85% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java index 04e2a4b0eb..4dc585c651 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PcLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java @@ -23,23 +23,25 @@ import java.util.ArrayList; import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Peter/Clark algorithm (PC). * * @author josephramsey */ -@edu.cmu.tetrad.annotation.Algorithm(name = "PC-LiNGAM", command = "pc-lingam", algoType = AlgType.forbid_latent_common_causes) +@edu.cmu.tetrad.annotation.Algorithm(name = "BOSS-LiNGAM", command = "boss-lingam", algoType = AlgType.forbid_latent_common_causes) @Bootstrapping -public class PcLingam implements Algorithm, HasKnowledge, UsesScoreWrapper, ReturnsBootstrapGraphs { +public class BossLingam implements Algorithm, HasKnowledge, UsesScoreWrapper, ReturnsBootstrapGraphs { private static final long serialVersionUID = 23L; private ScoreWrapper score; private Knowledge knowledge = new Knowledge(); private List bootstrapGraphs = new ArrayList<>(); - public PcLingam() { + public BossLingam() { } - public PcLingam(ScoreWrapper scoreWrapper) { + public BossLingam(ScoreWrapper scoreWrapper) { this.score = scoreWrapper; } @@ -69,11 +71,13 @@ public Graph search(DataModel dataModel, Parameters parameters) { Graph cpdag = permutationSearch.search(); - edu.cmu.tetrad.search.PcLingam pcLingam = new edu.cmu.tetrad.search.PcLingam(cpdag, (DataSet) dataModel); + edu.cmu.tetrad.search.BossLingam bossLingam = new edu.cmu.tetrad.search.BossLingam(cpdag, (DataSet) dataModel); + Graph graph = bossLingam.search(); - return pcLingam.search(); + stampWithBic(graph, dataModel); + return graph; } else { - PcLingam pcAll = new PcLingam(this.score); + BossLingam pcAll = new BossLingam(this.score); DataSet data = (DataSet) dataModel; GeneralResamplingTest search = new GeneralResamplingTest(data, pcAll, parameters.getInt(Params.NUMBER_RESAMPLING), parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE), parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT), parameters.getInt(Params.RESAMPLING_ENSEMBLE), parameters.getBoolean(Params.ADD_ORIGINAL_DATASET)); @@ -94,7 +98,7 @@ public Graph getComparisonGraph(Graph graph) { @Override public String getDescription() { - return "PC-LiNGAM using " + this.score.getDescription(); + return "BOSS-LiNGAM using " + this.score.getDescription(); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java index 8430b3b623..7c34c9f73b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java @@ -23,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Conservative PC (CPC). * @@ -107,7 +109,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { search.setVerbose(parameters.getBoolean(Params.VERBOSE)); search.setKnowledge(knowledge); search.setConflictRule(conflictRule); - return search.search(); + Graph graph = search.search(); + stampWithBic(graph, dataModel); + return graph; } else { Cpc pcAll = new Cpc(this.test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java index ca9d0460f3..d3298a8528 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.algcomparison.algorithm.ReturnsBootstrapGraphs; import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; +import edu.cmu.tetrad.algcomparison.statistic.BicEst; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesExternalGraph; import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper; @@ -15,6 +16,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -98,6 +100,11 @@ public Graph search(DataModel dataModel, Parameters parameters) { graph = search.search(); + if (!graph.getAllAttributes().containsKey("BIC")) { + graph.addAttribute("BIC", new BicEst().getValue(null, graph, dataModel)); + } + + LogUtilsSearch.stampWithScores(graph, dataModel, score); return graph; } else { Fges fges = new Fges(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java index 174d70ed85..287699409d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java @@ -17,6 +17,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -87,7 +88,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { grasp.setNumStarts(parameters.getInt(Params.NUM_STARTS)); grasp.setKnowledge(this.knowledge); grasp.bestOrder(score.getVariables()); - return grasp.getGraph(parameters.getBoolean(Params.OUTPUT_CPDAG)); + Graph graph = grasp.getGraph(parameters.getBoolean(Params.OUTPUT_CPDAG)); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Grasp algorithm = new Grasp(this.test, this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java index f836722598..decfad9141 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java @@ -23,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import static edu.cmu.tetrad.search.utils.LogUtilsSearch.stampWithBic; + /** * Peter/Clark algorithm (PC). * @@ -106,7 +108,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { search.setKnowledge(this.knowledge); search.setStable(parameters.getBoolean(Params.STABLE_FAS)); search.setConflictRule(conflictRule); - return search.search(); + Graph graph = search.search(); + stampWithBic(graph, dataModel); + return graph; } else { Pc pcAll = new Pc(this.test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java index d0813c6b81..d61d61584e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java @@ -16,6 +16,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.PermutationSearch; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.LogUtilsSearch; import edu.cmu.tetrad.search.utils.TsUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.Params; @@ -43,7 +44,6 @@ public class Sp implements Algorithm, UsesScoreWrapper, HasKnowledge, ReturnsBoo private Knowledge knowledge = new Knowledge(); private List bootstrapGraphs = new ArrayList<>(); - public Sp() { // Used in reflection; do not delete. } @@ -69,8 +69,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { Score score = this.score.getScore(dataModel, parameters); PermutationSearch permutationSearch = new PermutationSearch(new edu.cmu.tetrad.search.Sp(score)); permutationSearch.setKnowledge(this.knowledge); - - return permutationSearch.search(); + Graph graph = permutationSearch.search(); + LogUtilsSearch.stampWithScores(graph, dataModel, score); + return graph; } else { Sp algorithm = new Sp(this.score); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java index d023bd4dc4..47ce353c57 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java @@ -54,9 +54,16 @@ public final class Knowledge implements TetradSerializable { // private static final Pattern VARNAME_PATTERN = Pattern.compile("[A-Za-z0-9:_\\-.]+"); // private static final Pattern SPEC_PATTERN = Pattern.compile("[A-Za-z0-9:-_,\\-.*]+"); private static final Pattern COMMAN_DELIM = Pattern.compile(","); - private final Set variables; - private final Set>> forbiddenRulesSpecs; - private final Set>> requiredRulesSpecs; + + private final Set variables; + + // This needs to be a list for backward compatibility. Need to check when adding + // a new spec whether it's already in the list. + private final List>> forbiddenRulesSpecs; + + // This needs to be a list for backward compatibility. Need to check when adding + // a new spec whether it's already in the list. + private final List>> requiredRulesSpecs; private final List> tierSpecs; // Legacy. private final List knowledgeGroups; @@ -65,8 +72,8 @@ public final class Knowledge implements TetradSerializable { public Knowledge() { this.variables = new HashSet<>(); - this.forbiddenRulesSpecs = new HashSet<>(); - this.requiredRulesSpecs = new HashSet<>(); + this.forbiddenRulesSpecs = new ArrayList<>(); + this.requiredRulesSpecs = new ArrayList<>(); this.tierSpecs = new ArrayList<>(); this.knowledgeGroups = new LinkedList<>(); this.knowledgeGroupRules = new HashMap<>(); @@ -265,9 +272,13 @@ public void addKnowledgeGroup(KnowledgeGroup group) { this.knowledgeGroupRules.put(group, o); if (group.getType() == KnowledgeGroup.FORBIDDEN) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else if (group.getType() == KnowledgeGroup.REQUIRED) { - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } } @@ -536,7 +547,11 @@ public void setForbidden(String var1, String var2) { OrderedPair> o = new OrderedPair<>(f1, f2); - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } + } } /** @@ -580,7 +595,9 @@ public void setRequired(String var1, String var2) { OrderedPair> o = new OrderedPair<>(f1, f2); - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } /** @@ -609,9 +626,13 @@ public void setKnowledgeGroup(int index, KnowledgeGroup group) { knowledgeGroupRules.put(group, o); if (group.getType() == KnowledgeGroup.FORBIDDEN) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else if (group.getType() == KnowledgeGroup.REQUIRED) { - this.requiredRulesSpecs.add(o); + if (!requiredRulesSpecs.contains(o)) { + this.requiredRulesSpecs.add(o); + } } this.knowledgeGroups.set(index, group); @@ -640,7 +661,9 @@ public void setTierForbiddenWithin(int tier, boolean forbidden) { OrderedPair> o = new OrderedPair<>(varsInTier, varsInTier); if (forbidden) { - this.forbiddenRulesSpecs.add(o); + if (!forbiddenRulesSpecs.contains(o)) { + this.forbiddenRulesSpecs.add(o); + } } else { this.forbiddenRulesSpecs.remove(o); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java similarity index 96% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java index e5792a7fcb..ed9758a42a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java @@ -30,19 +30,16 @@ import edu.cmu.tetrad.regression.Regression; import edu.cmu.tetrad.regression.RegressionDataset; import edu.cmu.tetrad.regression.RegressionResult; -import edu.cmu.tetrad.search.utils.MeekRules; import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.TetradLogger; import edu.cmu.tetrad.util.Vector; import org.apache.commons.math3.util.FastMath; -import java.text.DecimalFormat; -import java.text.NumberFormat; import java.util.ArrayList; import java.util.List; /** - *

Implements the PC-LiNGAM algorithm which first finds a CPDAG for the variables + *

Implements the BOSS-LiNGAM algorithm which first finds a CPDAG for the variables * and then uses a non-Gaussian orientation method to orient the undirected edges. The reference is as follows: * *

>Hoyer et al., "Causal discovery of linear acyclic models with arbitrary @@ -65,7 +62,7 @@ * @author patrickhoyer * @author josephramsey */ -public class PcLingam { +public class BossLingam { private final Graph cpdag; private final DataSet dataSet; private double[] pValues; @@ -78,7 +75,7 @@ public class PcLingam { * @param cpdag The CPDAG whose unoriented edges are to be oriented. * @param dataSet Teh dataset to use. */ - public PcLingam(Graph cpdag, DataSet dataSet) + public BossLingam(Graph cpdag, DataSet dataSet) throws IllegalArgumentException { if (cpdag == null) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java index c6d78ba379..e7ac81d695 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java @@ -69,7 +69,10 @@ public static double scoreDag(Graph dag, DataModel data, double penaltyDiscount, parentIndices[count++] = hashIndices.get(parent); } - _score += score.localScore(hashIndices.get(node), parentIndices); + double score1 = score.localScore(hashIndices.get(node), parentIndices); + if (!Double.isNaN(score1)) { + _score += score1; + } } return _score; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java index da798c4473..3bbd503f04 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/LogUtilsSearch.java @@ -21,14 +21,18 @@ package edu.cmu.tetrad.search.utils; +import edu.cmu.tetrad.algcomparison.statistic.BicEst; +import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.score.Score; import edu.cmu.tetrad.util.NumberFormatUtil; +import org.jetbrains.annotations.NotNull; import java.text.NumberFormat; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; /** * Contains utilities for logging search steps. @@ -139,6 +143,51 @@ public static String getScoreFact(Node i, List parents) { return fact.toString(); } + + public static Map buildIndexing(List nodes) { + Map hashIndices = new HashMap<>(); + + int i = -1; + + for (Node n : nodes) { + hashIndices.put(n, ++i); + } + + return hashIndices; + } + + @NotNull + public static void stampWithScores(Graph graph, DataModel dataModel, Score score) { + if (!graph.getAllAttributes().containsKey("Score")) { + Graph dag = GraphTransforms.dagFromCPDAG(graph); + Map hashIndices = buildIndexing(dag.getNodes()); + + double _score = 0.0; + + for (Node node : dag.getNodes()) { + List x = dag.getParents(node); + + int[] parentIndices = new int[x.size()]; + + int count = 0; + for (Node parent : x) { + parentIndices[count++] = hashIndices.get(parent); + } + + _score += score.localScore(hashIndices.get(node), parentIndices); + } + + graph.addAttribute("Score", _score); + } + + stampWithBic(graph, dataModel); + } + + public static void stampWithBic(Graph graph, DataModel dataModel) { + if (!graph.getAllAttributes().containsKey("BIC")) { + graph.addAttribute("BIC", new BicEst().getValue(null, graph, dataModel)); + } + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java index c1175ce252..0c6579c29d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java @@ -609,10 +609,13 @@ private void updateScores(int i1, int i2) { private void recalculate(int p) { if (this.prefixes.get(p) == null || !this.prefixes.get(p).containsAll(getPrefix(p))) { Pair p2 = getParentsInternal(p); - if (scores.get(p) == null) { + + double score1 = scores.get(p).score; + + if (scores.get(p) == null || Double.isNaN(score1)) { this.runningScore += p2.score; } else { - this.runningScore += p2.score - scores.get(p).score; + this.runningScore += p2.score - score1; } this.scores.set(p, p2); } @@ -626,7 +629,10 @@ private double sum() { recalculate(i); } double score1 = this.scores.get(i).getScore(); - score += score1; + +// if (!Double.isNaN(score1)) { + score += score1; +// } } return score; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java index 55584e6f7b..88db208200 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/FasLofs.java @@ -41,7 +41,7 @@ * generalized. Instead of hard-coding FAS, an arbitrary algorithm can be used to obtain adjacencies. Instead of * hard-coding robust skew, and arbitrary algorithm can be used to to pairwise orientation. Instead of orienting all * edges, an option can be given to just orient the edges that are unoriented in the input graph (see, e.g., PC LiNGAM). - * This was an early attempt at this. For PC-LiNGAM, see this paper:

+ * This was an early attempt at this. For BOSS-LiNGAM, see this paper:

* *

Hoyer, P. O., Hyvarinen, A., Scheines, R., Spirtes, P. L., Ramsey, J., Lacerda, G., * & Shimizu, S. (2012). Causal discovery of linear acyclic models with arbitrary distributions. arXiv preprint From 2290324d15cd13aa9179195a2f14cc8259c04831 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 10:38:59 -0500 Subject: [PATCH 15/24] Testing a solution to a boss hanging problem. --- .../java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java index 596f422052..8f2f46558a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java @@ -110,7 +110,11 @@ private GSTNode(GrowShrinkTree tree) { this.grow = new AtomicBoolean(false); this.shrink = new AtomicBoolean(false); - this.growScore = this.tree.localScore(); + Double localScore = this.tree.localScore(); + +// this.growScore = Double.isNaN(localScore) ? 0 : localScore; + + this.growScore = localScore; } private GSTNode(GrowShrinkTree tree, Node add, Set parents) { From 61359839a0809f31427a0f79ad8de59057a650ad Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 11:49:55 -0500 Subject: [PATCH 16/24] Fixed hanging problem with GRaSP, BOSS, etc. --- .../tetrad/search/utils/GrowShrinkTree.java | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java index 8f2f46558a..2c9cec9924 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GrowShrinkTree.java @@ -56,12 +56,22 @@ public Integer getIndex(Node node) { return this.index.get(node); } +// public Double localScore() { +// return this.score.localScore(this.nodeIndex); +// } +// +// public Double localScore(int[] X) { +// return this.score.localScore(this.nodeIndex, X); +// } + public Double localScore() { - return this.score.localScore(this.nodeIndex); + double score = this.score.localScore(this.nodeIndex); + return Double.isNaN(score) ? 0 : score; } public Double localScore(int[] X) { - return this.score.localScore(this.nodeIndex, X); + double score = this.score.localScore(this.nodeIndex, X); + return Double.isNaN(score) ? Double.NEGATIVE_INFINITY : score; } public boolean isRequired(Node node) { @@ -110,11 +120,7 @@ private GSTNode(GrowShrinkTree tree) { this.grow = new AtomicBoolean(false); this.shrink = new AtomicBoolean(false); - Double localScore = this.tree.localScore(); - -// this.growScore = Double.isNaN(localScore) ? 0 : localScore; - - this.growScore = localScore; + this.growScore = this.tree.localScore(); } private GSTNode(GrowShrinkTree tree, Node add, Set parents) { From f87bf63b1fa12b2cc0817b710c34374f9d955a99 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 11:53:13 -0500 Subject: [PATCH 17/24] Updated version to 7.6.1-SNAPSHOT --- data-reader/pom.xml | 2 +- pom.xml | 2 +- tetrad-gui/pom.xml | 2 +- tetrad-lib/pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data-reader/pom.xml b/data-reader/pom.xml index 9810fc78a3..89662ead25 100644 --- a/data-reader/pom.xml +++ b/data-reader/pom.xml @@ -5,7 +5,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT data-reader diff --git a/pom.xml b/pom.xml index 64319b0487..05caf45039 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ 4.0.0 io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT pom Tetrad Project diff --git a/tetrad-gui/pom.xml b/tetrad-gui/pom.xml index 91f56bf4f5..820d99b993 100644 --- a/tetrad-gui/pom.xml +++ b/tetrad-gui/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT tetrad-gui diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index e0fed958cc..2d07d70c0e 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.0-SNAPSHOT + 7.6.1-SNAPSHOT tetrad-lib From da1d47a4d9cfd988a8113330324cfb52bdd88639 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 12:00:32 -0500 Subject: [PATCH 18/24] Updated version to 7.6.1-SNAPSHOT --- .../edu/cmu/tetrad/search/utils/TeyssierScorer.java | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java index 0c6579c29d..79584d3823 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java @@ -609,13 +609,10 @@ private void updateScores(int i1, int i2) { private void recalculate(int p) { if (this.prefixes.get(p) == null || !this.prefixes.get(p).containsAll(getPrefix(p))) { Pair p2 = getParentsInternal(p); - - double score1 = scores.get(p).score; - - if (scores.get(p) == null || Double.isNaN(score1)) { + if (scores.get(p) == null) { this.runningScore += p2.score; } else { - this.runningScore += p2.score - score1; + this.runningScore += p2.score - scores.get(p).score; } this.scores.set(p, p2); } @@ -628,11 +625,7 @@ private double sum() { if (this.scores.get(i) == null) { recalculate(i); } - double score1 = this.scores.get(i).getScore(); - -// if (!Double.isNaN(score1)) { - score += score1; -// } + score += this.scores.get(i).getScore(); } return score; From 170536f50c69629edf00e9e0bbdb99a4cbb2ed94 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Mon, 13 Nov 2023 16:53:55 -0500 Subject: [PATCH 19/24] Moved ConidtioningSetType to upper level because JPype was unable to parse the name as an inner enum of MarkovCheck. --- .../edu/cmu/tetradapp/editor/MarkovCheckEditor.java | 12 ++++++------ .../cmu/tetradapp/model/MarkovCheckIndTestModel.java | 3 ++- .../statistic/FractionDependentUnderAlternative.java | 3 ++- .../statistic/FractionDependentUnderNull.java | 3 ++- .../algcomparison/statistic/MarkovAdequacyScore.java | 3 ++- .../statistic/PvalueDistanceToAlpha.java | 3 ++- .../statistic/PvalueUniformityUnderNull.java | 3 ++- .../edu/cmu/tetrad/search/ConditioningSetType.java | 10 ++++++++++ .../main/java/edu/cmu/tetrad/search/MarkovCheck.java | 8 -------- 9 files changed, 28 insertions(+), 20 deletions(-) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index db7f627aec..3e06f09bc4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -27,8 +27,8 @@ import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.IndependenceFact; import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.IndependenceTest; -import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndependenceResult; import edu.cmu.tetrad.search.test.MsepTest; import edu.cmu.tetrad.util.NumberFormatUtil; @@ -107,16 +107,16 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { conditioningSetTypeJComboBox.addActionListener(e -> { switch ((String) Objects.requireNonNull(conditioningSetTypeJComboBox.getSelectedItem())) { case "Parents(X) (Local Markov)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + model.getMarkovCheck().setSetType(ConditioningSetType.LOCAL_MARKOV); break; case "Parents(X) for a Valid Order (Ordered Local Markov)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.ORDERED_LOCAL_MARKOV); + model.getMarkovCheck().setSetType(ConditioningSetType.ORDERED_LOCAL_MARKOV); break; case "MarkovBlanket(X)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.MARKOV_BLANKET); + model.getMarkovCheck().setSetType(ConditioningSetType.MARKOV_BLANKET); break; case "All Subsets (Global Markov)": - model.getMarkovCheck().setSetType(MarkovCheck.ConditioningSetType.GLOBAL_MARKOV); + model.getMarkovCheck().setSetType(ConditioningSetType.GLOBAL_MARKOV); break; default: throw new IllegalArgumentException("Unknown conditioning set type: " + @@ -125,7 +125,7 @@ public MarkovCheckEditor(MarkovCheckIndTestModel model) { class MyWatchedProcess extends WatchedProcess { public void watch() { - if (model.getMarkovCheck().getSetType() == MarkovCheck.ConditioningSetType.GLOBAL_MARKOV && model.getVars().size() > 12) { + if (model.getMarkovCheck().getSetType() == ConditioningSetType.GLOBAL_MARKOV && model.getVars().size() > 12) { int ret = JOptionPane.showOptionDialog(MarkovCheckEditor.this, "The all subsets option is exponential and can become extremely slow beyond 12" + "\nvariables. You may possibly be required to force quit Tetrad. Continue?", "Warning", diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java index d44d3f28e5..a0dcba9ee0 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MarkovCheckIndTestModel.java @@ -24,6 +24,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.Knowledge; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndependenceResult; @@ -64,7 +65,7 @@ public static Knowledge serializableInstance() { } public void setIndependenceTest(IndependenceTest test) { - this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? MarkovCheck.ConditioningSetType.LOCAL_MARKOV : this.markovCheck.getSetType()); + this.markovCheck = new MarkovCheck(this.graph, test, this.markovCheck == null ? ConditioningSetType.LOCAL_MARKOV : this.markovCheck.getSetType()); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java index 74277d48f0..01579b610f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderAlternative.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -36,7 +37,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(false); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java index ab344826ac..737d298b61 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FractionDependentUnderNull.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -36,7 +37,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getFractionDependent(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java index a9b30523a9..5bafefbb19 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -29,7 +30,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getMarkovAdequacyScore(alpha); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java index 602cec72a5..42fe7dfce6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueDistanceToAlpha.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -35,7 +36,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return abs(alpha - markovCheck.getKsPValue(true)); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java index df4f0af418..cf152d6cb0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/PvalueUniformityUnderNull.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.ConditioningSetType; import edu.cmu.tetrad.search.MarkovCheck; import edu.cmu.tetrad.search.test.IndTestFisherZ; @@ -33,7 +34,7 @@ public String getDescription() { @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), MarkovCheck.ConditioningSetType.LOCAL_MARKOV); + MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, alpha), ConditioningSetType.LOCAL_MARKOV); markovCheck.generateResults(); return markovCheck.getKsPValue(true); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java new file mode 100644 index 0000000000..7744e915be --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditioningSetType.java @@ -0,0 +1,10 @@ +package edu.cmu.tetrad.search; + +/** + * The type of conditioning set to use for the Markov check. The default is PARENTS, which uses the parents of the + * target variable to predict the separation set. DAG_MB uses the Markov blanket of the target variable in a DAG + * setting, and PAG_MB uses a Markov blanket of the target variable in a PAG setting. + */ +public enum ConditioningSetType { + LOCAL_MARKOV, ORDERED_LOCAL_MARKOV, MARKOV_BLANKET, GLOBAL_MARKOV +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index bb8d5e7bd6..6b21542341 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -553,12 +553,4 @@ private List getResultsLocal(boolean indep) { } - /** - * The type of conditioning set to use for the Markov check. The default is PARENTS, which uses the parents of the - * target variable to predict the separation set. DAG_MB uses the Markov blanket of the target variable in a DAG - * setting, and PAG_MB uses a Markov blanket of the target variable in a PAG setting. - */ - public enum ConditioningSetType { - LOCAL_MARKOV, ORDERED_LOCAL_MARKOV, MARKOV_BLANKET, GLOBAL_MARKOV - } } From cdba2623cddd941b8814ae56da8d7fde80d3684c Mon Sep 17 00:00:00 2001 From: jdramsey Date: Tue, 14 Nov 2023 14:28:14 -0500 Subject: [PATCH 20/24] Reapplying the forbidden edges knowledge fix. --- tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java index 47ce353c57..eddfaf00f9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java @@ -740,7 +740,7 @@ public List getListOfForbiddenEdges() { } for (int i = this.tierSpecs.size() - 1; i >= 0; i--) { - for (int j = i; j >= 0; j--) { + for (int j = i - 1; j >= 0; j--) { Set tieri = this.tierSpecs.get(i); Set tierj = this.tierSpecs.get(j); From af0ecadff4b71074914a1c46ad87cc0b0e8af316 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Tue, 14 Nov 2023 15:01:31 -0500 Subject: [PATCH 21/24] Reapplying the forbidden edges knowledge fix. --- tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java | 3 +++ .../main/java/edu/cmu/tetrad/search/score/SemBicScorer.java | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java index eddfaf00f9..320faccf17 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge.java @@ -740,6 +740,9 @@ public List getListOfForbiddenEdges() { } for (int i = this.tierSpecs.size() - 1; i >= 0; i--) { + + // Make sure this iterates from i - 1 to 0 or else all directed edges will be + // forbidden within tiers! for (int j = i - 1; j >= 0; j--) { Set tieri = this.tierSpecs.get(i); Set tierj = this.tierSpecs.get(j); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java index e7ac81d695..57c5979b6a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/SemBicScorer.java @@ -44,7 +44,7 @@ public static double scoreDag(Graph dag, DataModel data, double penaltyDiscount, SemBicScore score; if (data instanceof ICovarianceMatrix) { - score = new SemBicScore((ICovarianceMatrix) dag); + score = new SemBicScore((ICovarianceMatrix) data); } else if (data instanceof DataSet) { score = new SemBicScore((DataSet) data, precomputeCovariances); } else { From 1d19e8b22c01184ed6d20ba1a49901d5798d6226 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Tue, 14 Nov 2023 15:14:53 -0500 Subject: [PATCH 22/24] Re-fixing this FCI test. --- tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java index 5141a62e56..61c99d1c38 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFci.java @@ -161,7 +161,7 @@ public void testSearch11() { knowledge.addToTier(2, "X3"); checkSearch("Latent(L1),Latent(L2),L1-->X1,L1-->X2,L2-->X2,L2-->X3", - "X1<->X2,X2<->X3", knowledge); + "X1o->X2,X2<->X3", knowledge); } @Test From ae93285cf93b4e2dbf548bbc067ce9ba645e3398 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 15 Nov 2023 15:59:36 -0500 Subject: [PATCH 23/24] Exposed some methods for allow Markov Checker independencies to be printed in R. --- .../main/java/edu/cmu/tetrad/search/MarkovCheck.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 6b21542341..bca58d5fcc 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -178,8 +178,8 @@ public static AllSubsetsIndependenceFacts getAllSubsetsIndependenceFacts(Graph g } public static class AllSubsetsIndependenceFacts { - public final List msep; - public final List mconn; + private final List msep; + private final List mconn; public AllSubsetsIndependenceFacts(List msep, List mconn) { this.msep = msep; @@ -206,6 +206,14 @@ public String toStringDep() { return builder.toString(); } + + public List getMsep() { + return msep; + } + + public List getMconn() { + return mconn; + } } /** From 3a128e0b78106ca398a93a970fd40826cd21093f Mon Sep 17 00:00:00 2001 From: jdramsey Date: Wed, 15 Nov 2023 17:43:14 -0500 Subject: [PATCH 24/24] Updating version to 7.6.1. --- data-reader/pom.xml | 2 +- pom.xml | 2 +- tetrad-gui/pom.xml | 2 +- tetrad-lib/pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data-reader/pom.xml b/data-reader/pom.xml index 89662ead25..b18df6845f 100644 --- a/data-reader/pom.xml +++ b/data-reader/pom.xml @@ -5,7 +5,7 @@ io.github.cmu-phil tetrad - 7.6.1-SNAPSHOT + 7.6.1 data-reader diff --git a/pom.xml b/pom.xml index 05caf45039..b03d15628e 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ 4.0.0 io.github.cmu-phil tetrad - 7.6.1-SNAPSHOT + 7.6.1 pom Tetrad Project diff --git a/tetrad-gui/pom.xml b/tetrad-gui/pom.xml index 820d99b993..0483b1ab23 100644 --- a/tetrad-gui/pom.xml +++ b/tetrad-gui/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.1-SNAPSHOT + 7.6.1 tetrad-gui diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml index 2d07d70c0e..b963ae99b4 100644 --- a/tetrad-lib/pom.xml +++ b/tetrad-lib/pom.xml @@ -6,7 +6,7 @@ io.github.cmu-phil tetrad - 7.6.1-SNAPSHOT + 7.6.1 tetrad-lib