Skip to content

Commit

Permalink
Merge 9135e81 into 549b369
Browse files Browse the repository at this point in the history
  • Loading branch information
Thanathan-k committed Jan 12, 2018
2 parents 549b369 + 9135e81 commit 115a4bc
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 38 deletions.
13 changes: 13 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@
<version>1.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>javax.validation</groupId>
<artifactId>validation-api</artifactId>
<version>2.0.0.Final</version>
</dependency>

<!-- Logging via log4j2 -->
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
<version>2.10.0</version>
</dependency>

</dependencies>

<build>
Expand Down
120 changes: 84 additions & 36 deletions src/main/java/com/formulasearchengine/mathmlsim/Distances.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@
import com.formulasearchengine.mathmlsim.distances.earthmover.Signature;
import com.formulasearchengine.mathmltools.mml.CMMLInfo;
import com.formulasearchengine.mathmltools.xmlhelper.XMLHelper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import javax.xml.xpath.XPathExpressionException;
import java.text.DecimalFormat;
import java.util.*;

/**
* Created by Felix Hamborg <felixhamborg@gmail.com> on 05.12.16.
*/
public class Distances {

private static final Log LOG = LogFactory.getLog(Distances.class);
private static final Logger LOG = LogManager.getLogger(Distances.class.getName());

private static final java.text.DecimalFormat DECIMAL_FORMAT = new java.text.DecimalFormat("#.###");
private static final DecimalFormat decimalFormat = new DecimalFormat("#.###");

/**
* probably only makes sense to compute this on CI
Expand All @@ -24,25 +30,28 @@ public class Distances {
* @param h2
* @return
*/
public static double computeEarthMoverAbsoluteDistance(java.util.Map<String, Integer> h1, java.util.Map<String, Integer> h2) {
/**
 * Converts both histograms to signatures and returns the distance that
 * {@code JFastEMD} reports between them.
 *
 * @param h1 first histogram (key to frequency)
 * @param h2 second histogram (key to frequency)
 * @return result of {@code JFastEMD.distance} for the two signatures, called with 0.0 as third argument
 */
public static double computeEarthMoverAbsoluteDistance(Map<String, Double> h1, Map<String, Double> h2) {
    // arguments are evaluated left to right, so h1 is converted before h2
    return JFastEMD.distance(
            EarthMoverDistanceWrapper.histogramToSignature(h1),
            EarthMoverDistanceWrapper.histogramToSignature(h2),
            0.0);
}

public static double computeRelativeDistance(java.util.Map<String, Integer> h1, java.util.Map<String, Integer> h2) {
final java.util.Set<String> keySet = new java.util.HashSet();
keySet.addAll(h1.keySet());
keySet.addAll(h2.keySet());
final double numberOfUniqueElements = keySet.size();
if (numberOfUniqueElements == 0.0) {
/**
 * Computes the absolute distance between two histograms, divided by the sum of
 * all frequencies contained in both histograms.
 *
 * @param h1 first histogram (key to frequency)
 * @param h2 second histogram (key to frequency)
 * @return the absolute distance divided by the summed frequencies of both
 *         histograms, or 0.0 if that sum is zero
 */
public static double computeRelativeDistance(Map<String, Double> h1, Map<String, Double> h2) {
    // Accumulate in a double: an int accumulator with "+=" silently narrows the
    // sum after every addition, which drops fractional frequencies entirely.
    double totalNumberOfElements = 0.0;
    for (double frequency : h1.values()) {
        totalNumberOfElements += frequency;
    }
    for (double frequency : h2.values()) {
        totalNumberOfElements += frequency;
    }

    // guard against division by zero
    if (totalNumberOfElements == 0.0) {
        return 0.0;
    }

    final double absoluteDistance = computeAbsoluteDistance(h1, h2);
    return absoluteDistance / totalNumberOfElements;
}


Expand All @@ -53,10 +62,10 @@ public static double computeRelativeDistance(java.util.Map<String, Integer> h1,
* @param h2
* @return
*/
public static double computeAbsoluteDistance(java.util.Map<String, Integer> h1, java.util.Map<String, Integer> h2) {
public static double computeAbsoluteDistance(Map<String, Double> h1, Map<String, Double> h2) {
double distance = 0;

final java.util.Set<String> keySet = new java.util.HashSet();
final Set<String> keySet = new HashSet();
keySet.addAll(h1.keySet());
keySet.addAll(h2.keySet());

Expand All @@ -82,37 +91,56 @@ public static double computeAbsoluteDistance(java.util.Map<String, Integer> h1,
* @param nodes
* @return
*/
protected static java.util.HashMap<String, Integer> contentElementsToHistogram(org.w3c.dom.NodeList nodes) {
final java.util.HashMap<String, Integer> histogram = new java.util.HashMap<>();
/**
 * Builds a frequency histogram from the trimmed text content of the given nodes.
 *
 * @param nodes nodes whose text content is counted
 * @return map from trimmed text content to the number of nodes carrying it;
 *         empty if {@code nodes} contains no items
 */
protected static HashMap<String, Double> contentElementsToHistogram(NodeList nodes) {
    final HashMap<String, Double> histogram = new HashMap<>();

    final int numberOfNodes = nodes.getLength();
    for (int i = 0; i < numberOfNodes; i++) {
        final Node node = nodes.item(i);
        final String contentElementName = node.getTextContent().trim();
        // increment frequency by 1 (starts at 1.0 for a new key)
        histogram.merge(contentElementName, 1.0, Double::sum);
    }

    return histogram;
}

/**
* Adds all elements from one histogram to the other
* Adds up the frequencies of all given histograms, key by key.
*
* @param h1
* @param h2
* @return
*/
protected static java.util.HashMap<String, Integer> histogramPlus(java.util.HashMap<String, Integer> h1, java.util.HashMap<String, Integer> h2) {
final java.util.Set<String> mergedKeys = new java.util.HashSet<>(h1.keySet());
mergedKeys.addAll(h2.keySet());
final java.util.HashMap<String, Integer> mergedHistogram = new java.util.HashMap<>();
/**
 * Adds up the frequencies of all given histograms by delegating to the
 * varargs overload of this method.
 *
 * @param histograms the histograms to add up
 * @return the result of the varargs overload for the list's elements
 */
@SuppressWarnings("unchecked")
public static Map<String, Double> histogramsPlus(List<Map<String, Double>> histograms) {
    // Copy into a Map[] rather than a HashMap[]: the list may hold any Map
    // implementation, and storing e.g. a TreeMap into a HashMap[] would fail
    // with an ArrayStoreException.
    return histogramsPlus(histograms.toArray(new Map[0]));
}

/**
* Adds up the frequencies of all given histograms, key by key.
*
* @return
*/
@SafeVarargs
public static Map<String, Double> histogramsPlus(Map<String, Double>... histograms) {
switch (histograms.length) {
case 0:
throw new IllegalArgumentException("histograms.length=" + histograms.length + "; needs to be >= 2");
// return null;
case 1:
return histograms[0];
}


final Set<String> mergedKeys = new HashSet<>();
for (Map<String, Double> histogram : histograms) {
mergedKeys.addAll(histogram.keySet());
}
final HashMap<String, Double> mergedHistogram = new HashMap<>();

for (String key : mergedKeys) {
mergedHistogram.put(
key,
h1.getOrDefault(key, 0)
+ h2.getOrDefault(key, 0)
);
double value = 0.0;
for (Map<String, Double> histogram : histograms) {
value += histogram.getOrDefault(key, 0.0);
}
mergedHistogram.put(key, value);
}

return mergedHistogram;
Expand All @@ -125,20 +153,22 @@ protected static java.util.HashMap<String, Integer> histogramPlus(java.util.Hash
* @param tagName
* @return
*/
private static java.util.HashMap<String, Integer> strictCmmlInfoToHistogram(CMMLInfo strictCmml, String tagName) {
final org.w3c.dom.NodeList elements = strictCmml.getElementsByTagName(tagName);
/**
 * Builds a histogram from all elements of the given document that carry the
 * given tag name.
 *
 * @param strictCmml document to query via {@code getElementsByTagName}
 * @param tagName    tag name of the elements to count
 * @return histogram produced by {@code contentElementsToHistogram} for the matching elements
 */
private static HashMap<String, Double> strictCmmlInfoToHistogram(CMMLInfo strictCmml, String tagName) {
    return contentElementsToHistogram(strictCmml.getElementsByTagName(tagName));
}



/**
* Converts Content MathML to a histogram for the given tag name, e.g., cn.
*
* @param node
* @param tagName
* @return
*/
private static java.util.HashMap<String, Integer> cmmlNodeToHistrogram(org.w3c.dom.Node node, String tagName) throws javax.xml.xpath.XPathExpressionException {
final org.w3c.dom.NodeList elements = XMLHelper.getElementsB(node, "*//*:" + tagName);
/**
 * Builds a histogram from the elements below the given node that are selected
 * by the XPath expression {@code "*//*:" + tagName}.
 *
 * @param node    node the XPath expression is evaluated against
 * @param tagName tag name appended to the XPath expression, e.g., cn
 * @return histogram produced by {@code contentElementsToHistogram} for the selected elements
 * @throws XPathExpressionException declared for the XPath lookup
 */
private static HashMap<String, Double> cmmlNodeToHistrogram(Node node, String tagName) throws XPathExpressionException {
    final String xPath = "*//*:" + tagName;
    return contentElementsToHistogram(XMLHelper.getElementsB(node, xPath));
}

Expand All @@ -149,15 +179,33 @@ private static java.util.HashMap<String, Integer> cmmlNodeToHistrogram(org.w3c.d
* @param tagName
* @param histogram
*/
private static void cleanupHistogram(String tagName, java.util.HashMap<String, Integer> histogram) {
/**
 * Removes entries from the given histogram in place, depending on the tag name:
 * for "csymbol" the key "based_integer" and every key listed in
 * ValidCSymbols.VALID_CSYMBOLS are removed; for "ci" the key "integer" is
 * removed; for "cn" and any other tag name all keys rejected by isNumeric are
 * removed.
 *
 * @param tagName   tag name the histogram was built for
 * @param histogram histogram to clean up; modified in place
 */
private static void cleanupHistogram(String tagName, Map<String, Double> histogram) {
switch (tagName) {
case "csymbol":
histogram.remove("based_integer");
for (String key : ValidCSymbols.VALID_CSYMBOLS) {
histogram.remove(key);
}
break;
case "ci":
histogram.remove("integer");
break;
// NOTE(review): "default" shares the "cn" branch, so unknown tag names are
// filtered like "cn" - confirm this is intended.
default:
case "cn":
// collect the keys first instead of removing them while iterating over the key set
Set<String> toberemovedKeys = new HashSet<>();
for (String key : histogram.keySet()) {
if (!isNumeric(key)) {
toberemovedKeys.add(key);
}
}
// now we can remove the keys
for (String key : toberemovedKeys) {
histogram.remove(key);
}
break;
}
}

/**
 * Checks whether the whole string is a plain decimal number: an optional
 * leading '-', one or more digits and an optional fractional part that
 * consists of '.' followed by one or more digits.
 *
 * @param str string to test
 * @return true if the entire string matches that format
 */
private static boolean isNumeric(String str) {
    final String decimalNumber = "-?\\d+(\\.\\d+)?";
    return java.util.regex.Pattern.matches(decimalNumber, str);
}
}
Loading

0 comments on commit 115a4bc

Please sign in to comment.