Skip to content

Commit a906954

Browse files
committed
Enhance DPG functionality by adding community visualization support and updating requirements
1 parent 8bb3583 commit a906954

11 files changed

Lines changed: 265 additions & 106 deletions

README.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ The graph-based nature of DPG provides significant enhancements in the direction
4646
To install DPG locally, first clone the repository:
4747

4848
```bash
49-
git clone https://github.com/LeonardoArrighi/DPG.git
49+
git clone https://github.com/Meta-Group/DPG.git
5050
cd DPG
5151
```
5252

@@ -57,7 +57,7 @@ pip install -e .
5757

5858
Alternatively, if using `pip directly`:
5959
```bash
60-
pip install git+https://github.com/LeonardoArrighi/DPG.git
60+
pip install git+https://github.com/Meta-Group/DPG.git
6161
```
6262
**Troubleshooting:** If you encounter dependency conflicts, we recommend using a virtual environment:
6363

@@ -132,11 +132,11 @@ dpg_model, nodes_list = dpg.to_networkx(dot)
132132
# Extract and visualize
133133
dpg_metrics = GraphMetrics.extract_graph_metrics(dpg_model, nodes_list,target_names=np.unique(y_train).astype(str).tolist())
134134
df = NodeMetrics.extract_node_metrics(dpg_model, nodes_list)
135-
plot_dpg("dpg_output.png", dot, df_nodes, dpg_metrics, save_dir="datasets", communities=True, class_flag=True)
135+
plot_dpg_communities("dpg_output", dot, df, dpg_metrics, save_dir="datasets", class_flag=True, export_pdf=True)
136136
```
137137
#### Output:
138138
<p align="center">
139-
<img src="https://github.com/LeonardoArrighi/DPG/blob/main/dpg_image_examples/dpg_output.png_communities.png?raw=true" width="600" />
139+
<img src="https://github.com/LeonardoArrighi/DPG/blob/main/dpg_image_examples/dpg_output_communities.png?raw=true" width="600" />
140140
</p>
141141

142142
#### CLI scripts
@@ -157,7 +157,7 @@ The DPG output, through `run_dpg_standard.py` or `run_dpg_custom.py`, produces s
157157
- a `.txt` file containing the Random Forest statistics (accuracy, confusion matrix, classification report)
158158

159159
## Easy usage
160-
Usage: `python run_dpg_standard.py --dataset <dataset_name> --n_learners <integer_number> --pv <threshold_value> --t <integer_number> --model_name <str_model_name> --dir <save_dir_path> --plot --save_plot_dir <save_plot_dir_path> --attribute <attribute> --communities --class_flag`
160+
Usage: `python run_dpg_standard.py --dataset <dataset_name> --n_learners <integer_number> --pv <threshold_value> --t <integer_number> --model_name <str_model_name> --dir <save_dir_path> --plot --save_plot_dir <save_plot_dir_path> --attribute <attribute> --communities --clusters --threshold_clusters <float> --class_flag --seed <int>`
161161
Where:
162162
- `dataset` is the name of the standard classification `sklearn` dataset to be analyzed;
163163
- `n_learners` is the number of base learners for the Random Forest;
@@ -169,9 +169,12 @@ Where:
169169
- `save_plot_dir` is the path of the directory to save the plot image;
170170
- `attribute` is the specific node metric which can be visualized on the DPG;
171171
- `communities` is a store_true variable which can be added to visualize communities on the DPG;
172-
- `class_flag` is a store_true variable which can be added to highlight class nodes.
172+
- `clusters` is a store_true variable which can be added to visualize clusters on the DPG;
173+
- `threshold_clusters` is the threshold used to detect ambiguous nodes in clusters;
174+
- `class_flag` is a store_true variable which can be added to highlight class nodes;
175+
- `seed` controls the random split.
173176

174-
Disclaimer: `attribute` and `communities` can not be added together, since DPG supports just one of the two visualizations.
177+
Disclaimer: `attribute`, `communities`, and `clusters` are mutually exclusive: DPG supports just one visualization mode at a time.
175178

176179
The usage of `run_dpg_custom.py` is similar, but it requires another parameter:
177180
- `target_column`, which is the name of the column to be used as the target variable;

dpg/sklearn_dpg.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ucimlrepo import fetch_ucirepo
1717

1818
from .core import DecisionPredicateGraph
19-
from .visualizer import plot_dpg
19+
from .visualizer import plot_dpg, plot_dpg_communities
2020
from .utils import get_dpg_edge_metrics, clustering
2121
from metrics.nodes import NodeMetrics
2222
from metrics.graph import GraphMetrics
@@ -208,18 +208,27 @@ def test_dpg(datasets: str,
208208
)
209209
plot_name += f"_{model_name}_l{n_learners}_pv{perc_var}_t{decimal_threshold}_{seed}"
210210

211-
plot_dpg(
212-
plot_name,
213-
dot,
214-
df,
215-
df_edges,
216-
df_dpg,
217-
save_dir=save_plot_dir,
218-
attribute=attribute,
219-
communities=communities,
220-
clusters=clusters,
221-
threshold_clusters=threshold_clusters,
222-
class_flag=class_flag
223-
)
224-
225-
return df, df_edges, df_dpg, clusters, node_prob, confidence
211+
if communities:
212+
plot_dpg_communities(
213+
plot_name,
214+
dot,
215+
df,
216+
df_dpg,
217+
save_dir=save_plot_dir,
218+
class_flag=class_flag,
219+
df_edges=df_edges,
220+
)
221+
else:
222+
plot_dpg(
223+
plot_name,
224+
dot,
225+
df,
226+
df_edges,
227+
save_dir=save_plot_dir,
228+
attribute=attribute,
229+
clusters=clusters,
230+
threshold_clusters=threshold_clusters,
231+
class_flag=class_flag,
232+
)
233+
234+
return df, df_edges, df_dpg, clusters, node_prob, confidence

dpg/utils.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import re
23
import shutil
34
import yaml
45
from graphviz import Digraph
@@ -45,22 +46,19 @@ def highlight_class_node(dot, dpg_config=None):
4546

4647
# Iterate over each line in the dot body
4748
for i, line in enumerate(dot.body):
48-
# Extract the node identifier from the line
49-
line_id = line.split(' ')[1].replace("\t", "")
50-
# Check if the node identifier contains "Class"
51-
if "Class" in line_id:
49+
# Check for class labels in the node attributes
50+
if 'label="Class' in line:
5251
new_attrs = f'fillcolor="{fillcolor}" shape={shape} style="{style}"'
53-
# If node already has attributes, modify them
5452
if '[' in line:
55-
parts = line.split('[')
56-
attrs = parts[1].rstrip(']')
57-
# Remove existing attributes we're replacing
58-
for attr in ['fillcolor', 'shape', 'style']:
59-
attrs = ' '.join([a for a in attrs.split() if not a.startswith(attr)])
60-
# Add new attributes
61-
dot.body[i] = f"{parts[0]}[{attrs} {new_attrs}]"
53+
pre, rest = line.split('[', 1)
54+
attrs = rest.rsplit(']', 1)[0]
55+
# Remove existing attributes we're replacing (quoted or unquoted)
56+
attrs = re.sub(r'\b(fillcolor|shape|style)=(".*?"|[^ \]]+)', '', attrs)
57+
attrs = re.sub(r'\s+', ' ', attrs).strip()
58+
if attrs:
59+
attrs = attrs + ' '
60+
dot.body[i] = f"{pre}[{attrs}{new_attrs}]"
6261
else:
63-
# Node has no attributes yet
6462
node_id = line.split(' ')[0]
6563
dot.body[i] = f'{node_id} [{new_attrs}]'
6664

@@ -312,4 +310,4 @@ def clustering(dpg_model, class_nodes, threshold = None):
312310
clusters['Ambiguous'].append(node)
313311

314312

315-
return clusters, node_probs, confidence
313+
return clusters, node_probs, confidence

0 commit comments

Comments
 (0)