In [None]:
import numpy as np
from attack_utils import get_attack_indices, SWAT_SUB_MAP

def separate_attacks_by_stage(attacks=None, true_labels=None, sub_map=None, verbose=True):
    """Group SWaT attacks by the process stage(s) their attacked components belong to.

    Parameters
    ----------
    attacks, true_labels : sequence, optional
        Parallel per-attack sequences: ``attacks[i]`` holds the sample indices of
        attack ``i`` and ``true_labels[i]`` the components it targets. When both
        are omitted they are loaded with ``get_attack_indices("SWAT")``; passing
        only one of them is an error.
    sub_map : mapping, optional
        ``subsystem name -> iterable of component names``. Defaults to
        ``SWAT_SUB_MAP``. If a component is listed under several subsystems,
        the last listing wins.
    verbose : bool, default True
        Print each component-to-subsystem assignment and a per-stage summary.

    Returns
    -------
    dict
        ``stage -> list of (attack_number, attack_indices, components)``. The six
        standard stage names are always present (possibly with empty lists), in
        stage order, followed by any additional subsystem names from ``sub_map``.
        An attack touching several stages is listed under each of them.

    Raises
    ------
    ValueError
        If only one of ``attacks``/``true_labels`` is given, or their lengths differ.
    """
    if attacks is None and true_labels is None:
        attacks, true_labels = get_attack_indices("SWAT")
    elif attacks is None or true_labels is None:
        raise ValueError("Pass both `attacks` and `true_labels`, or neither.")
    if sub_map is None:
        sub_map = SWAT_SUB_MAP

    attacks, true_labels = list(attacks), list(true_labels)
    if len(attacks) != len(true_labels):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"attacks ({len(attacks)}) and true_labels ({len(true_labels)}) differ in length"
        )

    log = print if verbose else (lambda *args, **kwargs: None)

    # Invert subsystem -> components into component -> subsystem (last listing wins).
    component_to_subsystem = {
        component: subsystem
        for subsystem, components in sub_map.items()
        for component in components
    }

    # Fixed stage order for reporting/plotting; any extra subsystem from sub_map is
    # appended so an unexpected name can no longer raise KeyError below.
    attacks_by_stage = {
        stage: []
        for stage in (
            '1_Raw_Water_Tank',
            '2_Chemical',
            '3_UltraFilt',
            '4_DeChloro',
            '5_RO',
            '6_Return',
        )
    }
    for subsystem in sub_map:
        attacks_by_stage.setdefault(subsystem, [])

    log("Grouping attacks by subsystem:")
    for attack_num, (attack_indices, components) in enumerate(zip(attacks, true_labels)):
        affected_subsystems = set()
        for component in components:
            if component in component_to_subsystem:
                subsystem = component_to_subsystem[component]
                affected_subsystems.add(subsystem)
                log(f"Attack {attack_num}: Component {component} belongs to {subsystem}")
            else:
                log(f"Warning: Component {component} not found in any subsystem!")

        # Record the attack once under every stage it touches.
        for subsystem in affected_subsystems:
            attacks_by_stage[subsystem].append((attack_num, attack_indices, components))

    log("\nSummary of attacks by stage:")
    for stage, stage_attacks in attacks_by_stage.items():
        log(f"\n{stage}: {len(stage_attacks)} attacks")
        for attack_num, indices, components in stage_attacks:
            if len(indices) == 0:
                # indices[0] would raise IndexError on an empty attack.
                log(f"  Attack {attack_num}: {components} (no indices, Duration: 0 samples)")
            else:
                log(f"  Attack {attack_num}: {components} (Indices {indices[0]}-{indices[-1]}, Duration: {len(indices)} samples)")

    return attacks_by_stage

import matplotlib.pyplot as plt

# Run the grouping (prints per-attack assignments and a per-stage summary).
attacks_by_stage = separate_attacks_by_stage()

# Attacks per stage; an attack touching several stages is counted under each one.
stage_attack_counts = {
    stage: len(stage_attacks) for stage, stage_attacks in attacks_by_stage.items()
}
print("\nNumber of attacks per stage:")
for stage, count in stage_attack_counts.items():
    print(f"{stage}: {count} attacks")

stages = list(stage_attack_counts)
counts = list(stage_attack_counts.values())

# Bar chart: number of attacks per stage.
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(stages, counts)
ax.set_title('Distribution of SWAT Attacks by Subsystem')
ax.set_xlabel('Subsystem')
ax.set_ylabel('Number of Attacks')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
plt.show()

# Strip plot: each attack's duration (in samples) at its stage's x position.
# Positions follow the same dict order as `stages`, so the tick labels line up.
fig, ax = plt.subplots(figsize=(12, 6))
for position, stage_attacks in enumerate(attacks_by_stage.values()):
    if not stage_attacks:
        continue
    durations = [len(indices) for _, indices, _ in stage_attacks]
    ax.scatter([position] * len(durations), durations)
ax.set_title('Attack Durations by Subsystem')
ax.set_xlabel('Subsystem')
ax.set_ylabel('Duration (samples)')
ax.set_xticks(range(len(stages)))
ax.set_xticklabels(stages, rotation=45)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Hand-curated pairwise relationships between SWaT components, each to be added
# as a rank-1 cell. Format: (component_a, component_b, reason). Duplicate and
# reversed pairs from the original pasted list have been dropped (a cell is an
# unordered set of nodes, so ('A', 'B') and ('B', 'A') are the same cell).
# NOTE(review): these names use the dashed form ('MV-101'), while the complex
# built by SWATComplex logs nodes as 'MV101' — confirm which form the target uses.
SWAT_MANUAL_RELATIONSHIPS = (
    # Raw water intake
    ("MV-101", "FIT-101", "valve affects flow"),
    ("FIT-101", "LIT-101", "flow affects level"),
    # Primary pump control
    ("LIT-101", "P-101", "tank level to pump"),
    ("P-101", "FIT-201", "pump to stage 2 flow meter"),
    # Backup pump relationships
    ("P-101", "P-102", "primary-backup relationship"),
    ("P-102", "FIT-201", "backup pump to stage 2"),
    # pH control (HCl)
    ("FIT-201", "AIT-202", "flow affects pH reading"),
    ("AIT-202", "P-203", "pH reading controls HCl dosing"),
    ("P-203", "P-204", "HCl pump and its backup"),
    # ORP monitoring path (NaOCl)
    ("FIT-201", "AIT-203", "flow affects ORP reading"),
    ("AIT-203", "P-205", "ORP reading controls NaOCl dosing"),
    ("P-205", "P-206", "NaOCl pump and its backup"),
    # Conductivity control (NaCl)
    ("AIT-201", "P-201", "conductivity controls NaCl dosing"),
    ("P-201", "P-202", "primary-backup NaCl pumps"),
    # Tank level to pump control
    ("LIT-301", "P-301", "T301 level controls pump"),
    ("P-301", "FIT-301", "pump affects flow rate"),
    ("FIT-301", "LIT-401", "flow affects T401 level"),
    ("P-301", "P-302", "primary-backup pair"),
    # MV-201 controls flow to next stage; chemical mixing passes through MV-201
    ("MV-201", "FIT-201", "valve affects flow"),
    ("MV-201", "LIT-301", "valve affects T301 level"),
    ("P-201", "MV-201", "NaCl flow"),
    ("P-203", "MV-201", "HCl flow"),
    ("P-205", "MV-201", "NaOCl flow"),
    # Pressure monitoring and backwash control
    ("PIT-301", "MV-301", "inlet pressure affects valve"),
    ("DPIT-301", "MV-301", "diff pressure affects UF inlet valve"),
    ("DPIT-301", "MV-302", "diff pressure affects outlet valve"),
    ("DPIT-301", "P-301", "high diff pressure might need flow adjustment"),
    ("DPIT-301", "MV-303", "diff pressure triggers backwash"),
    # T401 level, dosing and RO-feed pressure
    ("LIT-401", "P-401", "tank level controls dechlorination"),
    ("P-401", "P-402", "primary-backup pump pair"),
    ("FIT-401", "P-401", "flow affects control dosing"),
    ("PIT-401", "P-401", "pressure affects pump operation"),
    ("LIT-402", "P-401", "T402 level affects dosing pump"),
    # NOTE(review): AIT-401 is described as a residual-chlorine analyzer here but
    # as a hardness reading in the UV-401 entries below — confirm its role.
    ("AIT-401", "P-401", "residual chlorine analyzer controls dosing"),
    # P-501 pumps, pressure, chemistry and flow monitoring
    ("P-501", "P-502", "primary-backup pump pair"),
    ("PIT-501", "P-501", "feed pressure affects pump"),
    ("PIT-502", "P-501", "permeate pressure affects pump"),
    ("PIT-503", "P-501", "reject pressure affects pump"),
    ("AIT-501", "P-501", "pH monitoring"),
    ("AIT-502", "P-501", "ORP monitoring"),
    ("AIT-503", "P-501", "feed conductivity"),
    ("AIT-504", "P-501", ""),
    ("FIT-501", "P-501", "inlet flow affects pump"),
    ("FIT-502", "P-501", "permeate flow affects pump"),
    ("FIT-503", "P-501", "reject flow affects pump"),
    ("FIT-504", "P-501", "recirculation flow affects pump"),
    ("AIT-503", "FIT-502", "feed conductivity affects permeate flow"),
    ("AIT-504", "FIT-503", "permeate conductivity affects reject flow"),
    # Backwash pumps and triggers
    ("PIT-301", "P-601", "pressure affects backwash pump"),
    ("P-601", "MV-301", "pump controls backwash flow"),
    ("P-602", "MV-301", "backup pump relationship"),
    ("DPIT-301", "P-601", "pressure triggers backwash"),
    ("P-601", "P-602", "primary-backup pump pair"),
    ("LIT-602", "P-601", "reject tank level affects backwash"),
    ("FIT-601", "P-601", "backwash flow monitoring"),
    ("FIT-601", "MV-301", "backwash flow affects valve control"),
    # Membrane protection during backwash
    ("MV-303", "FIT-301", "backwash drain valve affects flow"),
    ("MV-304", "FIT-301", "UF drain valve affects flow"),
    # UV dechlorinator control and related dosing
    ("UV-401", "AIT-401", "UV affects hardness"),
    ("FIT-401", "UV-401", "flow affects UV operation"),
    ("AIT-402", "P-403", "ORP controls NaHSO3 dosing"),
    ("P-403", "P-404", "primary-backup NaHSO3 pumps"),
    # NOTE(review): P-203 is the HCl pump above, yet this pair was annotated
    # "ORP affects NaHSO3 dosing" (cf. AIT-402/P-403) — confirm the intended pump.
    ("AIT-402", "P-203", "ORP affects NaHSO3 dosing"),
    ("AIT-402", "P-205", "ORP affects NaOCl dosing"),
    ("LIT-401", "UV-401", "tank level affects UV operation"),
    ("UV-401", "P-501", "UV affects RO feed pump"),
)


def add_manual_relationship_cells(cc, relationships=SWAT_MANUAL_RELATIONSHIPS, rank=1):
    """Add each pairwise relationship in ``relationships`` to ``cc`` as a cell.

    Replaces a pasted run of ``self.complex.add_cell(...)`` statements that could
    not execute at notebook top level (``self`` is undefined there, and a stray
    markdown fence made the cell a SyntaxError). Use it as
    ``add_manual_relationship_cells(swat_complex.get_complex())`` in a notebook,
    or presumably ``add_manual_relationship_cells(self.complex)`` inside
    ``SWATComplex``.

    Parameters
    ----------
    cc : object with ``add_cell(cell, rank=...)``
        The combinatorial complex to extend.
    relationships : iterable of (str, str, str)
        ``(component_a, component_b, reason)`` triples; ``reason`` is
        documentation only.
    rank : int, default 1
        Rank passed to ``add_cell`` for every pair.

    Returns
    -------
    int
        Number of cells added. Pairs already seen, in either order, are skipped.
    """
    seen = set()
    added = 0
    for first, second, _reason in relationships:
        key = frozenset((first, second))
        if key in seen:
            continue
        seen.add(key)
        cc.add_cell([first, second], rank=rank)
        added += 1
    return added


In [2]:
import sys
import numpy as np
from IPython.display import display, HTML
import networkx as nx
import matplotlib.pyplot as plt

# Import the enhanced SWATComplex class
from swat_topology import SWATComplex

# Create a bold section heading function for nicer output
def section_heading(title):
    display(HTML(f"<h3 style='background-color:#f0f0f0; padding:10px;'>{title}</h3>"))

# Create an instance of SWATComplex (this builds the complex; the constructor
# prints a progress log of the cells it adds, as seen in the output below)
section_heading("Building SWAT Combinatorial Complex")
swat_complex = SWATComplex()

# Display basic information about the complex: printing the TopoNetX complex
# gives a one-line summary (node count, cell ranks and per-rank sizes)
section_heading("Complex Information")
print(swat_complex.get_complex())



Building SWAT combinatorial complex...
Adding 51 components as rank 0 cells
Adding 86 specific component relationships as rank 1 cells
  Added 1-cell: [MV101, FIT101] (valve affects flow)
  Added 1-cell: [FIT101, LIT101] (flow affects level)
  Added 1-cell: [LIT101, P101] (tank level controls pump)
  Added 1-cell: [P101, FIT201] (pump to stage 2 flow meter)
  Added 1-cell: [P102, FIT201] (backup pump to stage 2)
  Added 1-cell: [FIT201, AIT202] (flow affects pH reading)
  Added 1-cell: [FIT201, AIT201] (flow affects conductivity reading)
  Added 1-cell: [FIT201, AIT203] (flow affects ORP reading)
  Added 1-cell: [AIT201, P201] (conductivity controls NaCl dosing)
  Added 1-cell: [AIT201, P202] (conductivity reading controls backup NaCl dosing)
  Added 1-cell: [AIT202, P203] (pH reading controls HCl dosing)
  Added 1-cell: [AIT202, P204] (pH reading controls backup HCl dosing)
  Added 1-cell: [AIT203, P205] (ORP reading controls NaOCl dosing)
  Added 1-cell: [AIT203, P206] (ORP reading c

Combinatorial Complex with 51 nodes and cells with ranks [0, 1, 2] and sizes (51, 74, 5) 


In [3]:
from swat_topology import SWATComplex

# Create the complex (builds a fresh SWATComplex; its constructor prints the
# build log shown in the output below)
swat_complex = SWATComplex()



# Get the underlying toponetx complex for further analysis.
# NOTE(review): the name `complex` shadows Python's builtin `complex` type;
# later cells refer to it by this name, so a rename must be applied everywhere.
complex = swat_complex.get_complex()


Building SWAT combinatorial complex...
Adding 51 components as rank 0 cells
Adding 86 specific component relationships as rank 1 cells
  Added 1-cell: [MV101, FIT101] (valve affects flow)
  Added 1-cell: [FIT101, LIT101] (flow affects level)
  Added 1-cell: [LIT101, P101] (tank level controls pump)
  Added 1-cell: [P101, FIT201] (pump to stage 2 flow meter)
  Added 1-cell: [P102, FIT201] (backup pump to stage 2)
  Added 1-cell: [FIT201, AIT202] (flow affects pH reading)
  Added 1-cell: [FIT201, AIT201] (flow affects conductivity reading)
  Added 1-cell: [FIT201, AIT203] (flow affects ORP reading)
  Added 1-cell: [AIT201, P201] (conductivity controls NaCl dosing)
  Added 1-cell: [AIT201, P202] (conductivity reading controls backup NaCl dosing)
  Added 1-cell: [AIT202, P203] (pH reading controls HCl dosing)
  Added 1-cell: [AIT202, P204] (pH reading controls backup HCl dosing)
  Added 1-cell: [AIT203, P205] (ORP reading controls NaOCl dosing)
  Added 1-cell: [AIT203, P206] (ORP reading c

In [6]:
# Incidence matrices 0->1 (B1) and 1->2 (B2), each returned together with the
# index dictionaries mapping every cell to its row / column position.
row, column, B1 = complex.incidence_matrix(0, 1, index=True)
row1, column1, B2 = complex.incidence_matrix(1, 2, index=True)

# Show the cell index of each rank, then both matrices.
for heading, value in (
    ("rank 0:", row),
    ("rank 1:", column),
    ("rank 2:", column1),
    ("B1:", B1),
    ("B2:", B2),
):
    print(heading)
    print(value)

rank 0:
OrderedDict([(frozenset({'MV101'}), 0), (frozenset({'LIT101'}), 1), (frozenset({'FIT101'}), 2), (frozenset({'P101'}), 3), (frozenset({'P102'}), 4), (frozenset({'P201'}), 5), (frozenset({'P202'}), 6), (frozenset({'P203'}), 7), (frozenset({'P204'}), 8), (frozenset({'P205'}), 9), (frozenset({'P206'}), 10), (frozenset({'FIT201'}), 11), (frozenset({'AIT201'}), 12), (frozenset({'AIT202'}), 13), (frozenset({'AIT203'}), 14), (frozenset({'MV201'}), 15), (frozenset({'FIT301'}), 16), (frozenset({'LIT301'}), 17), (frozenset({'DPIT301'}), 18), (frozenset({'P301'}), 19), (frozenset({'P302'}), 20), (frozenset({'MV301'}), 21), (frozenset({'MV302'}), 22), (frozenset({'MV303'}), 23), (frozenset({'MV304'}), 24), (frozenset({'UV401'}), 25), (frozenset({'P401'}), 26), (frozenset({'P402'}), 27), (frozenset({'P403'}), 28), (frozenset({'P404'}), 29), (frozenset({'AIT401'}), 30), (frozenset({'AIT402'}), 31), (frozenset({'FIT401'}), 32), (frozenset({'LIT401'}), 33), (frozenset({'AIT501'}), 34), (frozens

In [9]:
# Adjacency of 0-cells via 1-cells (A01) and via 2-cells (A02).
# Each matrix is computed once; the dense view is derived from the sparse one
# instead of calling adjacency_matrix a second time.
A01_sparse = complex.adjacency_matrix(0, 1)
A01 = A01_sparse.todense()
print("A01:")
print(A01)
print("A01_sparse:")
print(A01_sparse)

A02_sparse = complex.adjacency_matrix(0, 2)
A02 = A02_sparse.todense()
print("A02:")
print(A02)
print("A02_sparse:")
print(A02_sparse)


A01:
[[0 0 1 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
A01_sparse:
  (2, 0)	1
  (0, 0)	0
  (3, 1)	1
  (2, 1)	1
  (1, 1)	0
  (1, 2)	1
  (2, 2)	0
  (0, 2)	1
  (11, 3)	1
  (3, 3)	0
  (1, 3)	1
  (11, 4)	1
  (4, 4)	0
  (15, 5)	1
  (12, 5)	1
  (5, 5)	0
  (12, 6)	1
  (6, 6)	0
  (15, 7)	1
  (13, 7)	1
  (7, 7)	0
  (13, 8)	1
  (8, 8)	0
  (31, 9)	1
  (15, 9)	1
  :	:
  (35, 44)	1
  (34, 44)	1
  (43, 44)	1
  (42, 44)	1
  (29, 44)	1
  (44, 44)	0
  (28, 44)	1
  (39, 45)	1
  (45, 45)	0
  (37, 45)	1
  (40, 46)	1
  (48, 46)	1
  (46, 46)	0
  (50, 47)	1
  (47, 47)	0
  (21, 48)	1
  (23, 48)	1
  (48, 48)	0
  (46, 48)	1
  (49, 49)	0
  (18, 50)	1
  (47, 50)	1
  (23, 50)	1
  (50, 50)	0
  (21, 50)	1
A02:
[[0 1 1 ... 0 0 0]
 [1 0 1 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
A02_sparse:
  (3, 0)	1
  (2, 0)	1
  (1, 0)	1
  (0, 0)	0
  (3, 1)	1
  (2, 1)	1
  (1, 1)	0
  (0, 1)	1
  (3, 2)	1
  (2, 2)	0


  self._set_arrayXarray(i, j, x)


In [12]:
# Coadjacency of rank-r cells via lower-rank cells: CA10 (1-cells via 0-cells),
# CA20 (2-cells via 0-cells), CA21 (2-cells via 1-cells).
# Each matrix is computed once; the dense view is derived from the sparse one.
CA10_sparse = complex.coadjacency_matrix(1, 0)
CA10 = CA10_sparse.todense()
print("CA10:")
print(CA10)
print("CA10_sparse:")
print(CA10_sparse)

CA20_sparse = complex.coadjacency_matrix(2, 0)
CA20 = CA20_sparse.todense()
print("CA20:")
print(CA20)
print("CA20_sparse:")
print(CA20_sparse)

CA21_sparse = complex.coadjacency_matrix(2, 1)
CA21 = CA21_sparse.todense()
print("CA21:")
print(CA21)
print("CA21_sparse:")
print(CA21_sparse)  # previously only the header was printed

CA10:
[[0 1 0 ... 0 0 0]
 [1 0 1 ... 0 0 0]
 [0 1 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]]
CA10_sparse:
  (1, 0)	1
  (0, 0)	0
  (0, 1)	1
  (2, 1)	1
  (1, 1)	0
  (3, 2)	1
  (2, 2)	0
  (1, 2)	1
  (7, 3)	1
  (6, 3)	1
  (5, 3)	1
  (4, 3)	1
  (3, 3)	0
  (2, 3)	1
  (7, 4)	1
  (6, 4)	1
  (5, 4)	1
  (3, 4)	1
  (4, 4)	0
  (11, 5)	1
  (10, 5)	1
  (7, 5)	1
  (6, 5)	1
  (5, 5)	0
  (4, 5)	1
  :	:
  (72, 70)	1
  (70, 70)	0
  (67, 70)	1
  (26, 70)	1
  (24, 70)	1
  (73, 71)	1
  (70, 71)	1
  (69, 71)	1
  (71, 71)	0
  (70, 72)	1
  (67, 72)	1
  (26, 72)	1
  (24, 72)	1
  (72, 72)	0
  (69, 72)	1
  (68, 72)	1
  (22, 72)	1
  (71, 73)	1
  (70, 73)	1
  (69, 73)	1
  (73, 73)	0
  (25, 73)	1
  (24, 73)	1
  (23, 73)	1
  (22, 73)	1
CA20:
[[0 0 0 0 0]
 [0 0 1 0 0]
 [0 1 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
CA20_sparse:
  (0, 0)	0
  (2, 1)	1
  (1, 1)	0
  (2, 2)	0
  (1, 2)	1
  (3, 3)	0
  (4, 4)	0
CA21:
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
CA21_sparse:


In [16]:


# Incidence matrices between ranks 0-1, 0-2 and 1-2.
# Each matrix is computed once; the dense view is derived from the sparse one
# instead of calling incidence_matrix a second time.
B01_sparse = complex.incidence_matrix(0, 1)
B01 = B01_sparse.todense()
print("B01:")
print(B01)
print("B01_sparse:")
print(B01_sparse)

B02_sparse = complex.incidence_matrix(0, 2)
B02 = B02_sparse.todense()
print("B02:")
print(B02)
print("B02_sparse:")
print(B02_sparse)

B12_sparse = complex.incidence_matrix(1, 2)
B12 = B12_sparse.todense()
print("B12:")
print(B12)
print("B12_sparse:")
print(B12_sparse)





B01:
[[1 0 0 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 1]]
B01_sparse:
  (0, 0)	1
  (1, 1)	1
  (1, 2)	1
  (2, 0)	1
  (2, 1)	1
  (3, 2)	1
  (3, 3)	1
  (4, 4)	1
  (5, 8)	1
  (5, 16)	1
  (6, 9)	1
  (7, 10)	1
  (7, 17)	1
  (8, 11)	1
  (9, 12)	1
  (9, 18)	1
  (9, 40)	1
  (10, 13)	1
  (11, 3)	1
  (11, 4)	1
  (11, 5)	1
  (11, 6)	1
  (11, 7)	1
  (12, 6)	1
  (12, 8)	1
  :	:
  (42, 54)	1
  (42, 55)	1
  (42, 56)	1
  (42, 57)	1
  (43, 43)	1
  (43, 50)	1
  (44, 44)	1
  (44, 45)	1
  (44, 49)	1
  (44, 50)	1
  (44, 63)	1
  (44, 64)	1
  (44, 65)	1
  (45, 58)	1
  (45, 62)	1
  (46, 59)	1
  (46, 60)	1
  (47, 71)	1
  (48, 59)	1
  (48, 67)	1
  (48, 68)	1
  (50, 69)	1
  (50, 70)	1
  (50, 71)	1
  (50, 73)	1
B02:
[[1 0 0 0 0]
 [1 0 0 0 0]
 [1 0 0 0 0]
 [1 0 0 0 0]
 [0 0 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 1 1 0 0]
 [0 0 1 0 0]
 [0 0 1 0 0]
 [

In [None]:
# Get incidence matrix between rank 1 and 2. With index=True the column mapping
# is {rank-2 cell: column index}, so its keys are the rank-2 cells themselves.
row_dict, col_dict, matrix = complex.incidence_matrix(1, 2, index=True)
print("Rank 2 cells:")
# Print the column index stored in the mapping (not an enumerate counter), so
# each label is guaranteed to match the cell's column in `matrix`.
for cell, col_idx in sorted(col_dict.items(), key=lambda item: item[1]):
    print(f"  {col_idx}: {cell}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import toponetx as tnx
import topoembedx as tex
from swat_topology import SWATComplex

# Build the SWaT complex and a 2-D DeepCell model, then embed the 0-cells under
# two neighbourhoods: adjacency via shared 1-cells and via shared 2-cells.
swat_complex = SWATComplex()
complex = swat_complex.get_complex()

model = tex.DeepCell(dimensions=2)

# The same model is refit per neighbourhood; each embedding is read out
# immediately after its fit, before the next fit replaces it.
embeddings = {}
for key, via_rank, message in (
    ("components", 1, "Creating component relationship embedding..."),
    ("plc", 2, "Creating PLC structure embedding..."),
):
    print(message)
    model.fit(complex, neighborhood_type="adj", neighborhood_dim={"rank": 0, "via_rank": via_rank})
    embeddings[key] = model.get_embedding(get_dict=True)




# Helper function to create visualizations
def visualize_embedding(embedding_dict, title, color_by=None, annotate=True):
    plt.figure(figsize=(12, 10))
    
    # Extract points
    components = list(embedding_dict.keys())
    x = [embedding_dict[comp][0] for comp in components]
    y = [embedding_dict[comp][1] for comp in components]
    
    
    # Create scatter plot
    scatter = plt.scatter(x, y, alpha=0.7, s=100)
    
    # Add annotations
    if annotate:
        for i, comp in enumerate(components):
            plt.annotate(str(comp), (x[i], y[i]), 
                         fontsize=8, ha='center', va='center',
                         bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.7))
    
    
    plt.title(title)
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    return plt

def _component_type(cell):
    """Device-type prefix of the (first) component name in an embedding key.

    E.g. ``frozenset({'MV101'})`` -> ``'MV'`` and ``'P-101'`` -> ``'P'``; names
    made only of digits/dashes map to ``'unknown'``.
    """
    if isinstance(cell, (set, frozenset)):
        names = sorted(str(member) for member in cell)
    elif isinstance(cell, (tuple, list)):
        names = [str(member) for member in cell]
    else:
        names = [str(cell)]
    prefix = names[0].rstrip("0123456789").rstrip("-") if names else ""
    return prefix or "unknown"


# Generate visualizations
print("Generating visualizations...")

# 1. Component relationships, coloured by device type as the title states
viz1 = visualize_embedding(embeddings["components"],
                           "SWAT Components - Colored by Type",
                           color_by=_component_type)
viz1.savefig("swat_components_by_type.png")

# 2. PLC influence structure
viz3 = visualize_embedding(embeddings["plc"],
                           "SWAT Components by PLC Control",
                           annotate=False)
viz3.savefig("swat_plc_influence.png")

print("Visualizations complete!")

In [None]:
import matplotlib.pyplot as plt
import toponetx as tnx

import topoembedx as tex

# Embed the 1-cells of `complex` (built in an earlier cell), using coadjacency
# through shared 0-cells, and plot them.
# Request a 2-D embedding explicitly: without `dimensions=2` DeepCell uses its
# library default size, and plotting coordinates [0] and [1] would show only
# a slice of a higher-dimensional embedding rather than the embedding itself.
model = tex.DeepCell(dimensions=2)

# Fit the model to the complex
model.fit(complex, neighborhood_type="coadj", neighborhood_dim={"rank": 1, "via_rank": 0})

# Get the embeddings as {cell: vector}
embedded_points = model.get_embedding(get_dict=True)

# Prepare data for plotting
x = [embedded_points[cell][0] for cell in embedded_points]
y = [embedded_points[cell][1] for cell in embedded_points]
cell_labels = [f"Cell {cell}" for cell in embedded_points]

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(x, y, c="blue", label="Projected Points")

# Annotate each point with its cell
for label, xi, yi in zip(cell_labels, x, y):
    ax.annotate(label, (xi, yi), textcoords="offset points", xytext=(0, 10), ha="center")

ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_title("Projection of 2-dim Combinatorial Complex in 2D")
ax.legend()
ax.grid(True)
plt.show()

In [19]:
import torch 
# `torch.version.cuda` is the CUDA version this PyTorch build was compiled
# against; it is None for CPU-only builds (as in the output below). Also report
# the torch version and whether a CUDA device is actually usable at runtime.
print(torch.version.cuda)
print(f"torch {torch.__version__} | CUDA available: {torch.cuda.is_available()}")

None
