# Description

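In this notebook, we process the raw simulation data and prepare center node (coordinate) features, deviation features, and curvature features. We then build the graph edge features, create in-distribution and out-of-distribution train/test splits, and save the processed features to disk.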

# Import

In [1]:
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from config.dataConfig import Config
from src.data import SimulationData

def set_seed(seed: int = 42) -> None:
    """Seed the NumPy and stdlib global random generators for reproducibility.

    Args:
        seed: Value handed to ``np.random.seed`` and ``random.seed``; it is
            also exported (as a string) in the ``PYTHONHASHSEED`` environment
            variable.

    Note:
        Python reads ``PYTHONHASHSEED`` only at interpreter start-up, so
        exporting it here does not change string-hash randomization inside
        the current kernel -- it only affects child processes started later.
    """
    # Same seeding order as before: NumPy first, then the stdlib `random`.
    for seed_generator in (np.random.seed, random.seed):
        seed_generator(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

# Global seed for this notebook; it is reused later by the in-distribution split.
seed = 42
set_seed(seed=seed)

Random seed set as 42


# Config

In [2]:
# Project configuration; later cells read settings from it
# (e.g. config.min_abs_reserve_capacity and config.n_test).
config = Config()

# Process data

In [3]:
# Build all features from the raw simulation data, step by step.
# Disabled alternative: data.load_features('data/features.pkl')
data = SimulationData(config)

# Zero-argument processing steps, executed in this exact order.
for step_name in (
    "extract_section_parameters",
    "generate_coordinate_features",
    "generate_deviation_features",
    "generate_curvature_features",
    "features_to_numpy",
):
    getattr(data, step_name)()

# Presumably drops samples whose reserve capacity is below the configured
# minimum (see SimulationData.filter_reserve_capacities) -- runs before edges.
data.filter_reserve_capacities(config.min_abs_reserve_capacity)
data.build_edge_features()

# Disabled optional step: data.transform_disp_into_deformation()

Section parameters extracted.
Generating coordinate features.
Coordinate features were generated.
Generating deviation features.
Deviation features generated.
Generating curvature features.
Curvature features generated.


In [4]:
# In-distribution split; config.n_test is used as the number of clusters and
# the notebook-wide seed keeps the split reproducible.
data.make_id_split(seed=seed, n_clusters=config.n_test)
# Out-of-distribution split (the method name suggests it is based on a "dt_w"
# quantity -- confirm in SimulationData), holding out config.n_test.
data.make_ood_split_by_dt_w(n_test=config.n_test)

In-distribution split flags have been set.
Out-of-distribution split flags have been set.


In [5]:
# Persist the processed features (including split flags), then reload them so
# the remaining cells work from what was actually written to disk -- this acts
# as a serialization round-trip check.
# NOTE(review): the .pkl extension suggests pickle; only load trusted files.
FEATURES_PATH = 'data/features.pkl'  # single source of truth for save + load
# Create the target directory so saving does not fail on a fresh checkout.
os.makedirs(os.path.dirname(FEATURES_PATH), exist_ok=True)
data.save_features(FEATURES_PATH)
data.load_features(FEATURES_PATH)

# Visualize Train-Test Split

In [6]:
# Visualize the train/test split in 3-D and then 2-D; 'figures' is presumably
# the output directory for the generated plots.
for plot_split in (data.plot_splits_3d, data.plot_splits_2d):
    plot_split('figures')

# Visualize features

In [7]:
# Keys of the processed feature dictionary: one entry per steel section
# (e.g. 'W16X100'), as shown in the output below.
data.features.keys()

dict_keys(['W16X100', 'W16X26', 'W16X31', 'W16X36', 'W16X40', 'W16X45', 'W16X50', 'W16X57', 'W16X67', 'W16X77', 'W16X89', 'W18X106', 'W18X119', 'W18X130', 'W18X35', 'W18X40', 'W18X46', 'W18X50', 'W18X55', 'W18X60', 'W18X65', 'W18X71', 'W18X76', 'W18X86', 'W18X97', 'W21X101', 'W21X111', 'W21X122', 'W21X132', 'W21X44', 'W21X48', 'W21X50', 'W21X55', 'W21X57', 'W21X62', 'W21X68', 'W21X73', 'W21X83', 'W21X93', 'W24X103', 'W24X104', 'W24X117', 'W24X131', 'W24X146', 'W24X55', 'W24X62', 'W24X68', 'W24X76', 'W24X84', 'W24X94', 'W27X102', 'W27X114', 'W27X129', 'W27X146', 'W27X161', 'W27X84', 'W27X94', 'W30X108', 'W30X116', 'W30X124', 'W30X132', 'W30X148', 'W30X173', 'W30X90', 'W30X99', 'W33X118', 'W33X130', 'W33X141', 'W33X152', 'W33X169', 'W33X201', 'W36X135', 'W36X150', 'W36X160', 'W36X170', 'W36X182', 'W36X194', 'W36X210', 'W40X149', 'W40X167', 'W40X183', 'W40X211'])

## Coordinate Features

In [8]:
# point_cloud = np.concatenate(data.features['W40X211']['coordinate_features'][0], axis=0)

# fig = plt.figure(figsize=(12, 12))
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2], alpha=0.6, s=50)

# ax.set_xlabel('X')
# ax.set_ylabel('Y')
# ax.set_zlabel('Z')
# plt.legend()
# plt.show()

In [9]:
# coords = data.features['W40X211']['coordinate_features'][-100:]
# n_frames = len(coords)

# def frame_points(i):
#     c = coords[i]
#     if isinstance(c, (list, tuple)):
#         return np.concatenate(c, axis=0)
#     c = np.asarray(c)
#     return c if c.ndim == 2 else c.reshape(-1, 3)

# # global axis limits
# mins = np.array([np.inf, np.inf, np.inf])
# maxs = -mins
# for i in range(n_frames):
#     p = frame_points(i)
#     mins = np.minimum(mins, p.min(axis=0))
#     maxs = np.maximum(maxs, p.max(axis=0))

# fig = plt.figure(figsize=(8, 8))
# ax = fig.add_subplot(111, projection="3d")
# scat = ax.scatter([], [], [], alpha=0.6, s=50)

# ax.set_xlim(mins[0], maxs[0])
# ax.set_ylim(mins[1], maxs[1])
# ax.set_zlim(mins[2], maxs[2])
# ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

# def init():
#     p0 = frame_points(0)
#     scat._offsets3d = (p0[:,0], p0[:,1], p0[:,2])
#     return scat,

# def update(i):
#     p = frame_points(i)
#     scat._offsets3d = (p[:,0], p[:,1], p[:,2])
#     ax.set_title(f"Frame {i+1}/{n_frames}")
#     return scat,

# ani = FuncAnimation(fig, update, frames=n_frames, init_func=init,
#                     interval=300, blit=False, repeat=True)

# # Save as GIF
# ani.save("figures/coordinate_features.gif", writer="pillow", fps=5)

# plt.close(fig)  # close the figure so it is not displayed

## Deviation Features

In [10]:
# point_cloud = np.concatenate(data.features['W40X211']['deviation_features'][10], axis=0)

# fig = plt.figure(figsize=(12, 12))
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2], alpha=0.6, s=50)

# ax.set_xlabel('X')
# ax.set_ylabel('Y')
# ax.set_zlabel('Z')
# plt.legend()
# plt.show()

In [11]:
# coords = data.features['W40X211']['deviation_features'][-100:]
# n_frames = len(coords)

# def frame_points(i):
#     c = coords[i]
#     if isinstance(c, (list, tuple)):
#         return np.concatenate(c, axis=0)
#     c = np.asarray(c)
#     return c if c.ndim == 2 else c.reshape(-1, 3)

# # global axis limits
# mins = np.array([np.inf, np.inf, np.inf])
# maxs = -mins
# for i in range(n_frames):
#     p = frame_points(i)
#     mins = np.minimum(mins, p.min(axis=0))
#     maxs = np.maximum(maxs, p.max(axis=0))

# fig = plt.figure(figsize=(8, 8))
# ax = fig.add_subplot(111, projection="3d")
# scat = ax.scatter([], [], [], alpha=0.6, s=50)

# ax.set_xlim(mins[0], maxs[0])
# ax.set_ylim(mins[1], maxs[1])
# ax.set_zlim(mins[2], maxs[2])
# ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

# def init():
#     p0 = frame_points(0)
#     scat._offsets3d = (p0[:,0], p0[:,1], p0[:,2])
#     return scat,

# def update(i):
#     p = frame_points(i)
#     scat._offsets3d = (p[:,0], p[:,1], p[:,2])
#     ax.set_title(f"Frame {i+1}/{n_frames}")
#     return scat,

# ani = FuncAnimation(fig, update, frames=n_frames, init_func=init,
#                     interval=300, blit=False, repeat=True)

# # Save as GIF
# ani.save("figures/deviation_features.gif", writer="pillow", fps=5)

# plt.close(fig)  # close the figure so it is not displayed

## Curvature Features

In [12]:
# coords = np.array(data.features['W40X211']['deviation_features'][-100:])
# curvatures = np.array(data.features['W40X211']['curvature_features'][-100:])
# n_frames = len(coords)

# n_frames = len(coords)

# def frame_points(i):
#     return coords[i].reshape(-1, 3)

# def frame_colors(i):
#     return curvatures[i].reshape(-1)

# # --- global axis limits ---
# mins = coords.reshape(-1, 3).min(axis=0)
# maxs = coords.reshape(-1, 3).max(axis=0)

# fig = plt.figure(figsize=(8, 8))
# ax = fig.add_subplot(111, projection="3d")
# scat = ax.scatter([], [], [], alpha=0.6, s=50, c=[], cmap="viridis")

# ax.set_xlim(mins[0], maxs[0])
# ax.set_ylim(mins[1], maxs[1])
# ax.set_zlim(mins[2], maxs[2])
# ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

# # add colorbar with fixed limits across all frames
# vmin, vmax = curvatures.min(), curvatures.max()
# scat.set_clim(vmin=vmin, vmax=vmax)
# cb = plt.colorbar(scat, ax=ax, shrink=0.6, pad=0.05)
# cb.set_label("Curvature")

# def init():
#     p0 = frame_points(0)
#     c0 = frame_colors(0)
#     scat._offsets3d = (p0[:, 0], p0[:, 1], p0[:, 2])
#     scat.set_array(c0)
#     return scat,

# def update(i):
#     p = frame_points(i)
#     c = frame_colors(i)
#     scat._offsets3d = (p[:, 0], p[:, 1], p[:, 2])
#     scat.set_array(c)
#     ax.set_title(f"Frame {i+1}/{n_frames}")
#     return scat,

# ani = FuncAnimation(fig, update, frames=n_frames, init_func=init,
#                     interval=300, blit=False, repeat=True)

# ani.save("figures/curvature_features.gif", writer="pillow", fps=5)
# plt.close(fig)


# Visualize Graph Edges

In [13]:
# def _to_nodes_xyz(sample_pos, num_points, num_cols):
#     """
#     Ensure shape (num_points, num_cols, 3) then flatten to (N,3)
#     with node index = z * num_cols + col, z in [0..num_points-1].
#     """
#     # sample_pos is either (num_cols, num_points, 3) or (num_points, num_cols, 3)
#     if sample_pos.shape[0] == num_cols and sample_pos.shape[1] == num_points:
#         sample_pos = np.transpose(sample_pos, (1, 0, 2))  # -> (num_points, num_cols, 3)
#     assert sample_pos.shape == (num_points, num_cols, 3), f"Got {sample_pos.shape}"
#     nodes = sample_pos.reshape(-1, 3)  # (num_points*num_cols, 3)
#     return nodes

# def _set_equal_aspect_3d(ax):
#     """Equal aspect ratio for 3D axes."""
#     xs, ys, zs = [getattr(ax, f'get_{a}lim')() for a in 'xyz']
#     xmid, ymid, zmid = (sum(xs)/2, sum(ys)/2, sum(zs)/2)
#     radius = max(xs[1]-xs[0], ys[1]-ys[0], zs[1]-zs[0]) / 2
#     ax.set_xlim(xmid - radius, xmid + radius)
#     ax.set_ylim(ymid - radius, ymid + radius)
#     ax.set_zlim(zmid - radius, zmid + radius)

# def visualize_edges_for_one_sample(
#     data,
#     section_group=None,
#     sample_idx=0,
#     save_path=None,
#     show=True,
#     edge_alpha=0.5,
#     edge_linewidth=1.0,
#     node_size=16
# ):
#     """
#     Draw a 3D scatter of nodes + edges for one sample of coordinate features.

#     Args:
#         data: your SimulationData instance (with build_edge_features() already run)
#         section_group (str or None): which section to plot; if None, use the first with data
#         sample_idx (int): index of the sample inside that section's coordinate_features
#         save_path (str or None): if set, saves a PDF/PNG/etc. based on extension
#         show (bool): whether to call plt.show()
#     """
#     # pick a section that has coordinate_features
#     if section_group is None:
#         for sg in data.section_groups:
#             if len(data.features[sg]['coordinate_features']) > 0:
#                 section_group = sg
#                 break
#         if section_group is None:
#             raise RuntimeError("No section has coordinate_features to visualize.")

#     feat = data.features[section_group]
#     coord = np.asarray(feat['coordinate_features'])
#     if coord.ndim != 4:
#         raise ValueError(f"Expected coordinate_features as (n_samples, num_cols, num_points, 3) "
#                          f"or (n_samples, num_points, num_cols, 3). Got {coord.shape}")

#     if sample_idx < 0 or sample_idx >= coord.shape[0]:
#         raise IndexError(f"sample_idx {sample_idx} out of range [0, {coord.shape[0]-1}]")

#     # fetch edge_index and positions
#     edge_index = feat.get('edge_index', None)
#     if edge_index is None:
#         raise RuntimeError(f"No edge_index stored for section '{section_group}'. "
#                            f"Run data.build_edge_features() first.")

#     num_cols   = data.cfg.num_stripes
#     num_points = data.cfg.num_points_in_stripe

#     sample_pos = coord[sample_idx]  # shape either (num_cols,num_points,3) or (num_points,num_cols,3)
#     nodes = _to_nodes_xyz(sample_pos, num_points=num_points, num_cols=num_cols)  # (N,3)
#     x, y, z = nodes[:, 0], nodes[:, 1], nodes[:, 2]

#     # 3D plot
#     fig = plt.figure(figsize=(8, 6))
#     ax = fig.add_subplot(111, projection="3d")

#     # color nodes by stripe index (optional, nice to see columns)
#     # reconstruct stripe id per node: idx -> col = idx % num_cols
#     cols_idx = np.arange(nodes.shape[0]) % num_cols
#     scatter = ax.scatter(x, y, z, c=cols_idx, s=node_size, depthshade=True)

#     # draw edges
#     row, col = edge_index
#     # iterate and plot short segments
#     for r, c in zip(row.tolist(), col.tolist()):
#         ax.plot([x[r], x[c]], [y[r], y[c]], [z[r], z[c]],
#                 alpha=edge_alpha, linewidth=edge_linewidth)

#     ax.set_xlabel(r"$x$")
#     ax.set_ylabel(r"$y$")
#     ax.set_zlabel(r"$z$")
#     _set_equal_aspect_3d(ax)
#     plt.tight_layout()

#     if save_path:
#         plt.savefig(save_path, bbox_inches="tight")
#     if show:
#         plt.show()
#     else:
#         plt.close(fig)

# visualize_edges_for_one_sample(
#     data,
#     section_group='W16X100',           # a key of data.features, or None to use the first section with data
#     sample_idx=0,                 # which sample within that section
#     save_path="figures/graph_edges_example.pdf",
#     show=True
# )