forked from argriffing/xgcode
-
Notifications
You must be signed in to change notification settings - Fork 0
/
20080910a.py
62 lines (56 loc) · 2.21 KB
/
20080910a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Evaluate a loss function that compares a query tree to a reference tree.

Note that the distance functions used here are not symmetric:
swapping the query and reference trees can change the loss.
The weighted split distance gives more importance to deep partitions
implied by the reference tree.
"""
from SnippetUtil import HandlingError
import NewickIO
import FelTree
import TreeComparison
import Form
import FormOut
def get_form():
    """
    Build the input controls for comparing a query tree to a reference tree.
    @return: the body of a form
    """
    # both tree fields start out with the same six-taxon newick string
    tree_default = '((A:1, B:1):1, (C:1, D:1):1, (E:1, F:1):1);'
    # the two newick inputs
    query_box = Form.MultiLine('query', 'query tree', tree_default)
    reference_box = Form.MultiLine('reference', 'reference tree', tree_default)
    # choice between the plain and the weighted split distance
    loss_choice = Form.RadioGroup('loss', 'loss function', [
        Form.RadioItem('uniform', 'split distance'),
        Form.RadioItem('weighted', 'weighted split distance', True)])
    # optional normalization of the loss
    normalization_choice = Form.CheckGroup(
        'options', 'normalization options', [
            Form.CheckItem('normalize', 'compute the normalized loss', True)])
    return [query_box, reference_box, loss_choice, normalization_choice]
def get_form_out():
    """
    @return: a FormOut.Report instance
    """
    report_format = FormOut.Report()
    return report_format
def get_response_content(fs):
    """
    Compute the loss of a query tree relative to a reference tree.
    @param fs: the submitted form values (query, reference, loss choice,
        normalize flag)
    @return: the loss followed by a newline,
        or 'normalization failed' when the normalizer is zero
    @raise HandlingError: when no loss function was selected
    """
    # read the query tree
    query_tree = NewickIO.parse(fs.query, FelTree.NewickTree)
    # read the reference tree
    reference_tree = NewickIO.parse(fs.reference, FelTree.NewickTree)
    # pick the distance and its matching normalizer for the requested loss;
    # the normalizer is computed from the reference tree only
    if fs.uniform:
        loss_function = TreeComparison.get_split_distance
        count_function = TreeComparison.get_nontrivial_split_count
    elif fs.weighted:
        loss_function = TreeComparison.get_weighted_split_distance
        count_function = TreeComparison.get_weighted_split_count
    else:
        # without this guard the loss would be left unbound
        # and the request would die with an UnboundLocalError
        raise HandlingError('no loss function was selected')
    # calculate the loss using the requested loss function
    loss_numerator = loss_function(query_tree, reference_tree)
    # do the normalization if requested
    if fs.normalize:
        loss_denominator = float(count_function(reference_tree))
    else:
        loss_denominator = 1
    # a zero normalizer would make the normalized loss undefined
    if not loss_denominator:
        return 'normalization failed\n'
    return str(loss_numerator / loss_denominator) + '\n'