# Sketching MT Framework

In [1]:
from sketchMT import *
import readMergeTree as rmt
import os
import networkx as nx
from weighted_hierarchy_pos import *

## Data Input 

This example reads one of our preset datasets. Our merge tree data is split into two files: ``treeEdges_monoMesh_*.txt`` and ``treeNodes_monoMesh_*.txt``, which hold the tree edges and the tree nodes, respectively.

To use your own dataset, store each tree in an ``nx.Graph`` object in which every node has attributes for its spatial coordinates (e.g., "x", "y"), its scalar value (e.g., "height"), and its critical point type (0: minimum, 1: saddle, 2: maximum; this encoding is fixed and cannot be changed). 
You also need to provide the node id of the tree's root.

The tree data should be stored in a ``GWMergeTree`` object.

==========================================================

Our tree edge data format:

a_0, b_0  
a_1, b_1  
...  
a_{|E|-1}, b_{|E|-1}

Each row gives the indices of the two nodes that an edge connects. Edges are undirected.

==========================================================

Our tree node data format:

x_0, y_0, z_0, scalar_0, type_0  
x_1, y_1, z_1, scalar_1, type_1  
...  
x_{|V|-1}, y_{|V|-1}, z_{|V|-1}, scalar_{|V|-1}, type_{|V|-1}

Each row has five components: the "x", "y", "z" coordinates, the scalar value, and the critical point type for the node.


In [2]:
# dataset choices: ["HeatedCylinder", "CornerFlow", "VortexStreet", "VortexSlice", "RedSea", "HeatedFlowEnsemble", "MovingGaussian"]
dataset = "CornerFlow"
dataset_path = os.path.join("data", dataset)


def _numeric_key(token):
    """Return ``token`` as an int so numbered files sort numerically.

    Non-numeric tokens map to +inf so they always sort after every numbered file.
    """
    try:
        return int(token)
    except ValueError:
        return float("inf")


def _tree_file_sort_key(file_name, is_ensemble=False):
    """Sort key for a tree file name based on its trailing ``_<index>`` before the extension.

    For the ensemble dataset the key is ``(member, index)``, where the member id is the
    number embedded in the second underscore-separated field (``..._monoMesh<member>_...``).
    """
    index_key = _numeric_key(file_name.split(".")[0].split("_")[-1])
    if is_ensemble:
        return (_numeric_key(file_name.split("_")[1].replace("monoMesh", "")), index_key)
    return index_key


mt_list = []
root_list = []
ranges_list = []
is_ensemble = dataset == "HeatedFlowEnsemble"
for dirpath, dirnames, files in os.walk(dataset_path):
    # Visit sub-directories in a deterministic order so tree ordering is reproducible.
    dirnames.sort()

    txt_files = sorted(
        (name for name in files if name.endswith(".txt")),
        key=lambda name: _tree_file_sort_key(name, is_ensemble),
    )

    # You need to specify the root node type. Choices: ["minimum", "maximum"]
    # (Avoid specifying merge tree type to avoid confusion between split tree and join tree in different contexts)
    for file in txt_files:
        # Join with the directory os.walk is currently visiting (not the dataset root),
        # so files inside sub-directories resolve to their real location.
        trees, roots, ranges = rmt.get_trees(os.path.join(dirpath, file), root_type="minimum")
        mt_list.extend(trees)
        root_list.extend(roots)
        ranges_list.extend(ranges)

# Fail early with a clear message instead of at max(ranges_list) further down.
assert mt_list, "no merge trees were read from {}".format(dataset_path)
assert len(root_list) == len(mt_list)
gwmt_list = [GWMergeTree(tree, root) for tree, root in zip(mt_list, root_list)]

## Parameter Initialization

We now specify the parameters to be passed to the GW Sketching framework, including the following:

*scalar_name*: the name of the scalar field in GWMergeTree objects.

*edge_weight_name*: the name of the weight of edges in GWMergeTree objects.

*weight_mode*: the strategy to encode $W$. Choices: ["shortestpath", "lca"].

*prob_distribution*: the strategy to encode $p$. Choices: ["uniform", "ancestor"]

In [3]:
# Every preset dataset stores its scalar field under the "height" attribute ...
scalar_name_dict = dict.fromkeys(
    (
        "HeatedCylinder",
        "HeatedFlowEnsemble",
        "RedSea",
        "VortexStreet",
        "VortexSlice",
        "CornerFlow",
        "MovingGaussian",
    ),
    "height",
)

# ... and its edge weights under the "weight" attribute (same dataset keys).
edge_weight_name_dict = dict.fromkeys(scalar_name_dict, "weight")

# Per-dataset `budget` argument passed to SketchMT.
budget_dict = dict(
    HeatedCylinder=2,
    HeatedFlowEnsemble=3,
    RedSea=3,
    VortexStreet=3,
    VortexSlice=3,
    CornerFlow=3,
    MovingGaussian=2,
)

# Relative persistence weight; the notebook multiplies it by max(ranges_list).
# Only tuned for a subset of datasets.
lambda_pers_dict = dict(HeatedCylinder=0.06)

In [4]:
# Strategy used to encode W. Choices: ["shortestpath", "lca"]
weight_mode = "shortestpath"

# Strategy used to encode p. Choices: ["uniform", "ancestor"]
prob_distribution = "uniform"

scalar_name = scalar_name_dict[dataset]
edge_weight_name = edge_weight_name_dict[dataset]
budget = budget_dict[dataset]

# lambda_pers_dict only holds tuned values for some datasets, so indexing it
# directly raised KeyError for every other dataset (including "CornerFlow").
# Fall back to 0.0 for datasets without a tuned value.
# NOTE(review): 0.0 is assumed to disable the persistence term in SketchMT — confirm.
# The relative weight is scaled by the largest value in ranges_list (collected while reading the trees).
lambda_pers = lambda_pers_dict.get(dataset, 0.0) * max(ranges_list)

In [5]:
# Optional but recommended: validate every GWMergeTree against the configured
# scalar and edge-weight names to catch input-format problems early.
for merge_tree in gwmt_list:
    merge_tree.label_validation([], scalar_name, edge_weight_name)

In [6]:
trees_path = os.path.join("tree-instances", dataset)
os.makedirs(trees_path, exist_ok=True)

# Render each input tree to tree-XXX.jpg. Every file is checked individually,
# so an interrupted run resumes where it stopped. Previously only tree-000.jpg
# was checked, which left the remaining trees unrendered.
for i, t in enumerate(gwmt_list):
    fig_path = os.path.join(trees_path, "tree-" + str(i).zfill(3) + ".jpg")
    if os.path.exists(fig_path):
        continue
    tree_fig, tree_ax = plt.subplots(figsize=(10, 10))
    nx.draw_networkx(t.tree, pos=weighted_hierarchy_pos(t.tree, root=t.root), ax=tree_ax)
    tree_ax.set_title("Tree-{} (root={})".format(str(i), str(t.root)))
    tree_fig.savefig(fname=fig_path)
    # Close explicitly so rendering many trees does not accumulate open figures.
    plt.close(tree_fig)

## Framework Initialization

Initialize the SketchMT framework with the given parameters and data input.

The following items are completed in this step:
- Compute Frechet Mean
- Compute blowup matrix for each instance
- Vectorize blowup matrix

In [7]:
retest = False  # forwarded to SketchMT(retest=...); NOTE(review): presumably forces recomputation of cached results — confirm

In [8]:
# Build the framework. Per the notes above and its progress log, this computes the
# Frechet mean, the blowup matrix of each instance, and the vectorized blowups.
# NOTE(review): GWIteration presumably caps the GW solver iterations — confirm in sketchMT.
sketchmt = SketchMT(
    gwmt_list, dataset, scalar_name, edge_weight_name,
    weight_mode, prob_distribution, budget,
    GWIteration=50,
    lambda_persistence=lambda_pers,
    retest=retest,
)

Max Tree Size, ID = 18, 30
Computing Frechet Mean...
budget = 36
Computing Frechet Mean Done
Computing Blowups...
Computing Blowups Done
Getting Trees Vectors...
Getting Trees Vectors Done


## Applying Sketching Techniques

In our framework, we focus on column subset selection (CSS) techniques for sketching. Each selected column corresponds to a basis merge tree.

In [9]:
sketching_modes = ["CSS-IFS", "CSS-LSS"]  # CSS variants to compare
num_basis = [k for k in range(2, 10)]  # candidate basis counts: 2, 3, ..., 9

In [10]:
colors = {"CSS-IFS": "green", "CSS-LSS": "blue"}

solve_marks_1 = []
solve_names_1 = []
solve_marks_2 = []
solve_names_2 = []

# Draw the loss curves on an explicit figure/axes pair. The loop below opens and
# closes per-tree figures and calls draw_H, so relying on pyplot's "current
# figure" could place the curves (or the final savefig) on the wrong canvas.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.set_xlabel("Number of Basis")
ax1.set_ylabel("Sketch error")
ax2.set_xlabel("Number of Basis")
ax2.set_ylabel("GW loss")

for mode in sketching_modes:
    sketch_losses, H0s, selected_indices, GW_losses = sketchmt.sketching(mode, num_basis)
    sum_sketch_losses = [np.sum(x) for x in sketch_losses]
    sum_GW_losses = [np.sum(x) for x in GW_losses]

    # we can plot a curve for the sum of the sketch loss. The elbow point of the curve indicates the optimal num of basis
    mark, = ax1.plot(num_basis, sum_sketch_losses, color=colors[mode], linewidth=4)
    solve_marks_1.append(mark)
    solve_names_1.append(mode)

    mark, = ax2.plot(num_basis, sum_GW_losses, color=colors[mode], linewidth=4)
    solve_marks_2.append(mark)
    solve_names_2.append(mode)

    # save selected indices, basis tree drawings, and coefficient matrix for each setting
    for e, num in enumerate(num_basis):
        selected_idx_path = os.path.join(".", "selected-basis", dataset, weight_mode, mode, str(num))
        os.makedirs(selected_idx_path, exist_ok=True)

        # The entries are list indices into sketchmt.trees, so write them as integers
        # rather than in numpy's default float format.
        np.savetxt(os.path.join(selected_idx_path, "selected-idx.txt"), selected_indices[e], fmt="%d", delimiter=",")
        for idx in selected_indices[e]:
            tree_fig, tree_ax = plt.subplots(figsize=(10, 10))
            nx.draw_networkx(
                sketchmt.trees[idx],
                pos=weighted_hierarchy_pos(sketchmt.trees[idx], root=sketchmt.roots[idx]),
                ax=tree_ax,
            )
            tree_ax.set_title("Tree-{} (root={})".format(str(idx), str(sketchmt.roots[idx])))
            tree_fig.savefig(fname=os.path.join(selected_idx_path, "tree-" + str(idx).zfill(3) + ".jpg"))
            plt.close(tree_fig)

        draw_H(H0s[e], selected_idx_path)
        draw_H(sketch_losses[e].reshape(1, -1), selected_idx_path, False)
        np.savetxt(os.path.join(selected_idx_path, "coefficient-matrix.txt"), H0s[e])

ax1.legend(solve_marks_1, solve_names_1)
ax2.legend(solve_marks_2, solve_names_2)

loss_curve_path = os.path.join(".", "loss-curve", dataset)
os.makedirs(loss_curve_path, exist_ok=True)
fig.savefig(os.path.join(loss_curve_path, "sketch-error-{}".format(weight_mode)))

# Under Jupyter's non-GUI inline backend, fig.show() only emits a UserWarning and the
# figure was then closed without ever being displayed. plt.show() renders it inline.
plt.show()
plt.close(fig)

Sketching:  2
Sketching:  3
Sketching:  4
Sketching:  5
Sketching:  6
Sketching:  7
Sketching:  8
Sketching:  9
Sketching:  2
Sketching:  3
Sketching:  4
Sketching:  5
Sketching:  6
Sketching:  7
Sketching:  8
Sketching:  9


  fig.show()
