In [19]:
import torch
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv
from torch_geometric.nn.norm import GraphNorm
from torch.nn import Linear, InstanceNorm2d
import random
import matplotlib.pyplot as plt
from torch_geometric.transforms import FaceToEdge
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from itertools import combinations
import vtk
import glob

In [20]:
class Dataset(torch.utils.data.Dataset):
    """Common base class for the mesh/polygon containers in this notebook.

    The base is spelled out as ``torch.utils.data.Dataset`` rather than the
    bare name ``Dataset``: this class rebinds that name, so with the bare form
    re-running the cell would make the class inherit from its own previous
    definition instead of the torch class.
    """

    def __init__(self, num_files):
        # ``num_files`` is accepted but neither used nor stored.
        pass

class Mesh(Dataset):
    """Container for one mesh: vertex coordinates and face connectivity."""

    def __init__(self):
        # Both fields start empty and are assigned after construction.
        self.coordinates, self.faces = None, None

class Polygon(Dataset):
    """Container for one polygon (a vertex together with its incident faces).

    Args:
        num_node: number of rows allocated for ``coordinates``.
        num_face: number of rows allocated for ``faces``.
    """

    def __init__(self, num_node, num_face):
        self.parent_meshID = None
        # One (x, y) row per node, zero-initialised.
        self.coordinates = torch.zeros(num_node, 2)
        # One row of three node indices per face. Allocated as an integer
        # tensor so the placeholder has the same kind of dtype as the face
        # tensors used elsewhere in this notebook and can be used for indexing.
        self.faces = torch.zeros(num_face, 3, dtype=torch.long)
        # The remaining fields are placeholders that start out unset.
        self.d = None
        self.Cx = None
        self.Cy = None
        self.x_min = None
        self.y_min = None

class PolygonID(Dataset):
    """Thin wrapper that stores a single ``nodeID`` value."""

    def __init__(self, nodeID):
        # Stored as given; no copy or validation is performed.
        self.nodeID = nodeID

In [21]:
class Polygon_data(Dataset):
    """Record grouping a polygon ID with its mesh ID and node IDs."""

    def __init__(self, polygonID, meshID, nodeID):
        # All three values are stored exactly as passed in.
        self.polygonID, self.meshID, self.nodeID = polygonID, meshID, nodeID

In [22]:
def _read_vtk_mesh(vtk_file_path):
    """Read ``vtk_file_path`` and return ``(data, mesh)``.

    ``data`` is the reader's output object. ``mesh`` is a ``Mesh`` whose
    ``coordinates`` hold the x/y columns of every point and whose ``faces``
    hold the first three point IDs of every polygon cell.
    """
    reader = vtk.vtkDataSetReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    data = reader.GetOutput()

    mesh = Mesh()

    # Vertex coordinates: read as (x, y, z), keep only x and y.
    points = data.GetPoints()
    num_points = points.GetNumberOfPoints()
    coordinates = torch.zeros(num_points, 3)
    for point_id in range(num_points):
        coordinates[point_id] = torch.tensor(points.GetPoint(point_id))
    mesh.coordinates = coordinates[:, :2]

    # Face connectivity. Only the first three IDs of each cell are read, so
    # every polygon cell is treated as a triangle.
    polys = data.GetPolys()
    num_polys = polys.GetNumberOfCells()
    mesh.faces = torch.zeros(num_polys, 3, dtype=int)
    polys.InitTraversal()
    for face_index in range(num_polys):
        cell = vtk.vtkIdList()
        if polys.GetNextCell(cell) == 0:
            break
        mesh.faces[face_index] = torch.tensor([cell.GetId(0), cell.GetId(1), cell.GetId(2)])

    return data, mesh


def _find_boundary_points(data):
    """Return the set of point IDs that lie on an edge used by exactly one cell."""
    # Count how many cells use each edge; the key is the sorted pair of its
    # end-point IDs so that both orientations map to the same entry.
    edge_use_count = {}
    for cell_index in range(data.GetNumberOfCells()):
        cell = data.GetCell(cell_index)
        for edge_index in range(cell.GetNumberOfEdges()):
            edge = cell.GetEdge(edge_index)
            edge_points = edge.GetPointIds()
            point1_id = edge_points.GetId(0)
            point2_id = edge_points.GetId(1)
            edge_key = (min(point1_id, point2_id), max(point1_id, point2_id))
            edge_use_count[edge_key] = edge_use_count.get(edge_key, 0) + 1

    # An edge used by a single cell is a boundary edge; collect its end points.
    boundary_points = set()
    for edge_key, use_count in edge_use_count.items():
        if use_count == 1:
            boundary_points.update(edge_key)
    return boundary_points


def create_mesh_polygonID_data(vtk_file_path, polygonID_list, poly_count, polygon_dict, mesh_index):
    """Load one VTK mesh and register a polygon for each interior vertex.

    A vertex is treated as interior ("free") when it is not an end point of
    any boundary edge. For every such vertex that belongs to at least one
    face, a new key ``"polygon_<n>"`` is appended to ``polygonID_list`` and
    ``polygon_dict[key]`` is set to ``("mesh_<mesh_index>", node_ids)``.
    ``node_ids`` is a 1-D integer tensor whose first entry is the vertex
    itself, followed by the other vertices of its incident faces in ascending
    ID order.

    Vertices that belong to no face are skipped. Previously they either
    raised ``NameError`` on the unassigned ``count`` or failed when looking
    the vertex up in an empty node list.

    Args:
        vtk_file_path: path handed to ``vtkDataSetReader``.
        polygonID_list: list of polygon keys; extended in place.
        poly_count: number of polygons registered before this call; used to
            number the new polygons.
        polygon_dict: mapping from polygon key to ``(mesh label, node IDs)``;
            updated in place.
        mesh_index: index used to build the ``"mesh_<mesh_index>"`` label.

    Returns:
        ``(mesh, polygonID_list, poly_count, polygon_dict)`` where ``mesh`` is
        the loaded ``Mesh`` and ``poly_count`` is the updated running total.
    """
    data, mesh = _read_vtk_mesh(vtk_file_path)
    boundary_points = _find_boundary_points(data)
    mesh_label = f"mesh_{mesh_index}"

    for point_id in range(mesh.coordinates.shape[0]):
        if point_id in boundary_points:
            continue

        # Faces that contain this vertex.
        incident_faces = mesh.faces[(mesh.faces == point_id).any(dim=1)]
        if incident_faces.numel() == 0:
            # Vertex is not referenced by any face: there is no polygon to build.
            continue

        polygon_key = f"polygon_{poly_count}"
        poly_count += 1

        # torch.unique returns the distinct node IDs in ascending order; the
        # centre vertex is then moved to the front.
        ring_ids = torch.unique(incident_faces)
        node_ids = torch.cat((torch.tensor([point_id]), ring_ids[ring_ids != point_id]))

        polygonID_list.append(polygon_key)
        polygon_dict[polygon_key] = (mesh_label, node_ids)

    return mesh, polygonID_list, poly_count, polygon_dict

In [27]:
def create_mesh_polygon_dataset(vtk_files):
    """Process every file in ``vtk_files`` and gather meshes and polygons.

    Each file is passed to ``create_mesh_polygonID_data`` together with its
    position in ``vtk_files`` and the accumulators built so far.

    Returns:
        ``(mesh_data_list, polygonID_list, polygon_dict)``: the loaded meshes
        in file order, the list of polygon keys, and the polygon lookup dict.
    """
    meshes = []
    polygon_ids = []
    polygon_lookup = {}
    running_count = 0
    # Visit the files in order; the running count carries over between files.
    for file_index, file_path in enumerate(vtk_files):
        print("File Name:", file_path)
        mesh, polygon_ids, running_count, polygon_lookup = create_mesh_polygonID_data(
            file_path, polygon_ids, running_count, polygon_lookup, file_index
        )
        meshes.append(mesh)
    return meshes, polygon_ids, polygon_lookup

# Collect every .vtk file in the current folder. The list is sorted because
# glob.glob() returns paths in arbitrary order, and a file's position in this
# list becomes its "mesh_<index>" label.
vtk_files = sorted(glob.glob("./*.vtk"))
num_mesh = len(vtk_files)
print(num_mesh)
mesh_data_list, polygonID_list, polygon_dict = create_mesh_polygon_dataset(vtk_files)
print(polygonID_list)
print(polygon_dict)
print(mesh_data_list)

2
File Name: ./men.vtk
File Name: ./men_re.vtk
['polygon_0', 'polygon_1', 'polygon_2', 'polygon_3', 'polygon_4', 'polygon_5', 'polygon_6', 'polygon_7', 'polygon_8', 'polygon_9', 'polygon_10', 'polygon_11', 'polygon_12', 'polygon_13', 'polygon_14', 'polygon_15', 'polygon_16', 'polygon_17', 'polygon_18', 'polygon_19', 'polygon_20', 'polygon_21', 'polygon_22', 'polygon_23', 'polygon_24', 'polygon_25', 'polygon_26', 'polygon_27', 'polygon_28', 'polygon_29', 'polygon_30', 'polygon_31', 'polygon_32', 'polygon_33', 'polygon_34', 'polygon_35', 'polygon_36', 'polygon_37', 'polygon_38', 'polygon_39', 'polygon_40', 'polygon_41', 'polygon_42', 'polygon_43', 'polygon_44', 'polygon_45', 'polygon_46', 'polygon_47', 'polygon_48', 'polygon_49', 'polygon_50', 'polygon_51', 'polygon_52', 'polygon_53', 'polygon_54', 'polygon_55', 'polygon_56', 'polygon_57', 'polygon_58', 'polygon_59', 'polygon_60', 'polygon_61', 'polygon_62', 'polygon_63', 'polygon_64', 'polygon_65', 'polygon_66', 'polygon_67', 'polygon_6

In [30]:
# Build one Polygon_data record per registered polygon, in polygon-number order.
polygon_data_list = []

for i in range(len(polygonID_list)):
    polygonID = f"polygon_{i}"
    # Each dict entry is (mesh label, node ID tensor).
    meshID, nodeID = polygon_dict[polygonID][:2]
    polygon_data = Polygon_data(polygonID, meshID, nodeID)
    polygon_data_list.append(polygon_data)


# From here on, i is a polygon number; it is used to fetch that polygon's coordinates and face information
# for i in range(len(polygon_data_list)):
def data_getter(i):
    """Build the ``Polygon`` for polygon number ``i``.

    Reads ``polygon_data_list[i]`` and the matching entry of
    ``mesh_data_list``. The returned polygon has ``coordinates`` filled with
    the mesh coordinates of its nodes and ``faces`` holding the faces that
    contain the polygon's first node, with each node ID replaced by its
    position inside the polygon's node list.
    """
    record = polygon_data_list[i]
    node_ids = record.nodeID

    # The mesh label looks like "mesh_<index>"; the trailing number selects the mesh.
    mesh = mesh_data_list[int(record.meshID.split("_")[-1])]

    node_count = len(node_ids)
    polygon_i = Polygon(node_count, node_count - 1)

    # Copy the coordinates of every node of the polygon, in node-list order.
    for local_index, mesh_node in enumerate(node_ids):
        polygon_i.coordinates[local_index] = mesh.coordinates[mesh_node]

    # Faces of the mesh that contain the polygon's first node.
    centre_node = node_ids[0]
    contains_centre = (mesh.faces == centre_node).any(dim=1)
    polygon_i.face = mesh.faces[contains_centre]

    # Rewrite each mesh-wide node ID as its index in the polygon's node list.
    # The positions are collected before any entry is overwritten.
    positions = torch.nonzero(torch.isin(polygon_i.face, node_ids))
    for row, col in positions:
        mesh_node = polygon_i.face[row, col]
        polygon_i.face[row, col] = (node_ids == mesh_node).nonzero().item()
    polygon_i.faces = polygon_i.face.long()

    return polygon_i



In [31]:
# Seed the shuffle explicitly so that a fresh Restart & Run All visits the
# polygons in the same order every time.
SHUFFLE_SEED = 0
data_loader = DataLoader(
    polygonID_list,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

for step, data in enumerate(data_loader):
    print(f"Step {step + 1}:")
    print("==========")
    print(data)
    print(len(data))
    # Each batch item is a key such as "polygon_12"; the trailing number is
    # the polygon index expected by data_getter.
    for polygon_name in data:
        polygonID = int(polygon_name.split("_")[-1])
        print(polygonID)
        polygon = data_getter(polygonID)
        print(polygon.coordinates)
        print(polygon.faces)
        print("==========")

Step 1:
['polygon_1808', 'polygon_2914', 'polygon_6489', 'polygon_6378', 'polygon_1036', 'polygon_2071', 'polygon_3789', 'polygon_6186', 'polygon_4083', 'polygon_4032', 'polygon_238', 'polygon_1618', 'polygon_3007', 'polygon_2339', 'polygon_1154', 'polygon_4051', 'polygon_2080', 'polygon_176', 'polygon_2744', 'polygon_5827', 'polygon_1263', 'polygon_4790', 'polygon_5676', 'polygon_528', 'polygon_3201', 'polygon_2995', 'polygon_5291', 'polygon_6250', 'polygon_5139', 'polygon_4486', 'polygon_4572', 'polygon_6255']
32
1808
tensor([[-0.2205, -0.1025],
        [-0.2166, -0.1218],
        [-0.2026, -0.1111],
        [-0.2412, -0.1049],
        [-0.2310, -0.1163],
        [-0.2034, -0.0924],
        [-0.2195, -0.0826],
        [-0.2347, -0.0906]])
tensor([[1, 2, 0],
        [2, 5, 0],
        [0, 5, 6],
        [3, 4, 0],
        [0, 4, 1],
        [0, 6, 7],
        [7, 3, 0]])
2914
tensor([[-0.2694, -0.0642],
        [-0.2719, -0.0815],
        [-0.2867, -0.0707],
        [-0.2659, -0.0474]