## train copy.ipynb を改変し、GMSNetを三次元に拡張したい

各種の関数は２次元用に作られているため、３次元用に改変する必要がある

# シフト切り捨てをやめてみる
# トレーニング完了後のトレーニングメッシュを表示してみる
# 各epochの終了時にメッシュを可視化してみる

In [1]:
import copy
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphNorm, LayerNorm
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.nn import Linear, InstanceNorm2d, InstanceNorm1d, Conv1d, ReLU, Tanh
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
import matplotlib.pyplot as plt
from torch_geometric.transforms import FaceToEdge
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from itertools import combinations
import vtk
import glob
from tqdm import tqdm
# 計算を軽くするためのライブラリ
from torch.cuda import empty_cache
import gc               # メモリリークを防ぐ

from torch import nn
import os
import sys

import datetime
import pyvista as pv
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

In [2]:
torch.cuda.is_available()

True

In [3]:
num_train_epoch = 30
train_data_path = "/mnt/volume_data_folder/"
save_fig_path = "/mnt/volume_training/"
vtk_save_path = "/mnt/vtk_output/"

# Dataの準備

In [4]:
class Dataset(Dataset):
    def __init__(self, num_files):
        None

class Mesh(Dataset):
    def __init__(self):
        self.coordinates = None
        self.cells = None
        self.faces = None

class Polygon(Dataset):
    def __init__(self, num_node, num_cells):
        self.parent_meshID = None
        self.coordinates = torch.zeros(num_node, 3)
        self.cells = torch.zeros(num_cells, 4)
        self.edge_index = None
        self.edges = None
        self.d = None
        self.Cx = None
        self.Cy = None
        self.Cz = None
        self.x_min = None
        self.y_min = None
        self.z_min = None

    def to(self, device):
        # GPUに移動可能なtorch.Tensor変数を移動
        self.coordinates = self.coordinates.to(device)
        self.cells = self.cells.to(device)
        if self.edge_index is not None:
            self.edge_index = self.edge_index.to(device)
        if self.edges is not None:
            self.edges = self.edges.to(device)
        if self.d is not None:
            self.d = self.d.to(device)
        if self.Cx is not None:
            self.Cx = self.Cx.to(device)
        if self.Cy is not None:
            self.Cy = self.Cy.to(device)
        if self.Cz is not None:
            self.Cz = self.Cz.to(device)
        if self.x_min is not None:
            self.x_min = self.x_min.to(device)
        if self.y_min is not None:
            self.y_min = self.y_min.to(device)
        if self.z_min is not None:
            self.z_min = self.z_min.to(device)

class PolygonID(Dataset):
    def __init__(self, nodeID):
        self.nodeID = nodeID
        # self.parent_meshID = None

class Polygon_data(Dataset):
    def __init__(self, polygonID, meshID, nodeID, num_cells):
        self.polygonID = polygonID
        self.meshID = meshID
        self.nodeID = nodeID
        self.num_cells = num_cells

class Minibatch(Dataset):
    def __init__(self):
        self.x = None
        self.edge_index = None
        self.batch = None

class EarlyStopping:
    def __init__(self, patience, min_delta):
        self.patience = patience      # 損失が改善しないエポック数の上限
        self.min_delta = min_delta    # 損失が改善とみなす最小値
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [5]:
def create_mesh_polygonID_data(vtk_file_path, polygonID_list, poly_count, polygon_dict, mesh_index):
    data = pv.read(vtk_file_path)
    mesh = Mesh()
    mesh.coordinates = torch.tensor(data.points, dtype=torch.float32)
    cells = data.cells
    celltypes = data.celltypes
    tetra_indices = torch.where(torch.tensor(celltypes) == 10)[0]
    tetra_cells = []
    face_cells = []
    cell_offset = 0

    for idx in tetra_indices:
        num_points = cells[cell_offset]
        if num_points == 4:
            tetra_cells.append(cells[cell_offset + 1 : cell_offset + 1 + num_points])
        elif num_points == 3:
            face_cells.append(cells[cell_offset + 1 : cell_offset + 1 + num_points])
            
        cell_offset += num_points + 1

    mesh.cells = torch.tensor(tetra_cells, dtype=torch.long)
    # print("cells:", mesh.cells)

    mesh.faces = torch.tensor(face_cells, dtype=torch.long)
# ------------ mesh のデータを取得完了 -------------------------

# ---------------- 自由点かどうかの判定完了 ------------------------
    
    boundary_faces = data.extract_feature_edges(boundary_edges=True, manifold_edges=False)  #境界面の取得
    boundary_points_indices = boundary_faces.point_data["vtkOriginalPointIds"]              #境界面に存在する頂点番号の取得
    boundary_points = set(boundary_points_indices)                                        #境界面の頂点番号をリストに変換
    
    count= 0
    for pointId in range(mesh.coordinates.size(0)):       # pointId:自由点の頂点番号
        if pointId in boundary_points:
            continue
        else:
            poly_count = poly_count + 1
            # print("pointId:", pointId)
        mask = (mesh.cells == pointId)
        if mask.any():
            count = torch.sum(mask).item()
        num_node = count + 1
        num_tetra = count
        polygon_number = poly_count - 1 



        
        polygon_i = f"polygon_{polygon_number}"
        # print(polygon_i)
        polygon_i = Polygon(num_node, num_tetra)
        
        element_to_check = pointId
        polygon_i.cells = mesh.cells[(mesh.cells == element_to_check).any(dim=1)]
        # print(polygon_i.face)

        polygon_i.nodeId = set()
        for i in range(len(polygon_i.cells)):
            polygon_i.nodeId.add(polygon_i.cells[i, 0].item())
            polygon_i.nodeId.add(polygon_i.cells[i, 1].item())
            polygon_i.nodeId.add(polygon_i.cells[i, 2].item())
            polygon_i.nodeId.add(polygon_i.cells[i, 3].item())
        sorted_nodeId = sorted(polygon_i.nodeId)
        polygon_i.nodeID = torch.tensor(list(sorted_nodeId))
        
        point_id_index = (polygon_i.nodeID == pointId).nonzero().item()

        value_to_move = polygon_i.nodeID[point_id_index]
        polygon_i.nodeID = torch.cat((value_to_move.unsqueeze(0), polygon_i.nodeID[polygon_i.nodeID != pointId]))
        # print(polygon_i.nodeID)
        setattr(polygon_i, "parent_meshID", mesh)
        polygonID_list.append(f"polygon_{polygon_number}")

        keyword = f"polygon_{polygon_number}"
        valiables = (f"mesh_{mesh_index}", polygon_i.nodeID, len(polygon_i.cells))
        polygon_dict[keyword] = valiables


    # --------- polygon.nodeID の取得完了 -------------
    return mesh, polygonID_list, poly_count, polygon_dict

# Dataset の作成

In [6]:
def create_mesh_polygon_dataset(vtk_files):
    num_vtk_files = len(vtk_files)
    polygonID_list = []
    mesh_data_list = []
    poly_count = 0
    polygon_dict = {}
    # ファイルに順にアクセスする
    for i in range(num_vtk_files):
        # print("File Name:", vtk_files[i])
        mesh, polygonID_list, poly_count, polygon_dict = create_mesh_polygonID_data(vtk_files[i], polygonID_list, poly_count, polygon_dict, i)
        mesh_data_list.append(mesh)
    return mesh_data_list, polygonID_list, polygon_dict


In [7]:
# 以下、i はpolygon番号で座標とセル情報を取得することができる
def data_getter(polygonID, num_mesh_data_list, mesh_data_list, polygon_data_list):
    
    polygon_meshID = int(polygon_data_list[polygonID].meshID.split("_")[-1])
    mesh = mesh_data_list[polygon_meshID]
    
    num_node = len(polygon_data_list[polygonID].nodeID)
    num_tetra = polygon_data_list[polygonID].num_cells
    polygon_i = Polygon(num_node, num_tetra)

    polygon_i.coordinates = mesh.coordinates[polygon_data_list[polygonID].nodeID]     # polygonの座標

    
    element_to_check = polygon_data_list[polygonID].nodeID[0]
    polygon_i.tetra = mesh.cells[(mesh.cells == element_to_check).any(dim=1)]

    indices = torch.nonzero(torch.isin(polygon_i.tetra, polygon_data_list[polygonID].nodeID))
    for idx in range(indices.size(0)):
        row_idx, col_idx = indices[idx]
        value_to_replace = polygon_i.tetra[row_idx, col_idx]
        polygon_i.tetra[row_idx, col_idx] = (polygon_data_list[polygonID].nodeID == value_to_replace).nonzero().item()
    polygon_i.cells = polygon_i.tetra.long()

    # 各行の三角形からエッジを抽出してedge_indexを構築
    edges = torch.cat([ polygon_i.cells[:, [0, 1]],
                        polygon_i.cells[:, [0, 2]],
                        polygon_i.cells[:, [0, 3]],
                        polygon_i.cells[:, [1, 2]],
                        polygon_i.cells[:, [1, 3]],
                        polygon_i.cells[:, [2, 3]]], dim=0)

    # エッジのインデックスをソートして重複を削除
    edge_index = torch.sort(edges, dim=1).values
    edge_index = torch.tensor(sorted(edge_index.numpy().tolist())).unique(dim=0)
    polygon_i.edge_index = torch.transpose(edge_index, 0, 1)
    # print(polygon_i.edge_index)
    return polygon_i



# メッシュをプロットする関数

In [8]:
def plot_mesh(mesh, title):
    
    coordinates = mesh.coordinates.clone().detach()
    edge_index = mesh.edge_index.clone().detach()

    # 3Dプロットの設定
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # 頂点をプロット
    ax.scatter(coordinates[:, 0], coordinates[:, 1], coordinates[:, 2], c='r', label='Vertices')

    # エッジをプロット
    for i in range(edge_index.shape[1]):  # 各エッジをループ
        start_idx = edge_index[0, i]
        end_idx = edge_index[1, i]
        start_point = coordinates[start_idx]
        end_point = coordinates[end_idx]
        ax.plot(
            [start_point[0], end_point[0]],
            [start_point[1], end_point[1]],
            [start_point[2], end_point[2]],
            c='b'
        )

    # 軸ラベルの設定
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # タイトルと凡例
    ax.set_title('3D Mesh Visualization')
    ax.legend()

    # 表示
    plt.show()


In [9]:
def save_mesh(mesh, title):

    vertices = mesh.coordinates
    faces = mesh.cells
    fig = plt.figure()
    ax = fig.add_subplot(111, aspect="equal")

    # 描画するメッシュの頂点をプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'bo')  # 頂点を青色の点でプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'k-')  # 辺を黒色の線でプロット

    # 各三角形をプロット
    for face in faces:
        v0, v1, v2 = vertices[face]
        v0_np = v0.detach().numpy()
        v1_np = v1.detach().numpy()
        v2_np = v2.detach().numpy()
        ax.plot([v0_np[0], v1_np[0], v2_np[0], v0_np[0]], [v0_np[1], v1_np[1], v2_np[1], v0_np[1]], 'b-')  # 三角形を赤色の線でプロット

    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.axhline(0, color="black", linewidth=0.001)
    ax.axvline(0, color="black", linewidth=0.001)

    # plt.xlim(-300, 150)
    # plt.ylim(-200, 1400)
    plt.savefig(f"{save_fig_path}{title}.png", format="png")
    # plt.show()
    plt.close()

# meshデータからvtkファイルを出力する関数

In [10]:
def vtk_output(mesh, title):
    vertices = mesh.coordinates  # Assumed to be an Nx3 array
    cells = mesh.cells  # Assumed to be an Mx3 array
    faces = mesh.faces
    num_vertices = len(vertices)
    num_cells = len(cells)
    num_faces = len(faces)

    filename = f"{vtk_save_path}{title}.vtk"
    with open(filename, "w") as f:
        # VTK Header
        f.write("# vtk DataFile Version 2.0\n")
        f.write(f"{title}\n")
        f.write("ASCII\n")
        f.write("DATASET POLYDATA\n")

        # Write vertices
        f.write(f"POINTS {num_vertices} float\n")
        for vertex in vertices:
            f.write(f"{vertex[0]} {vertex[1]} {vertex[2]}\n")

        # Write cells (polygons)
        f.write(f"\nPOLYGONS {num_cells} {num_cells * 5}\n")
        for cell in cells:
            f.write(f"4 {cell[0]} {cell[1]} {cell[2]} {cell[3]}\n")

        f.write(f"\nPOLYGONS {num_faces} {num_faces * 4}\n")
        for face in faces:
            f.write(f"3 {face[0]} {face[1]} {face[2]}\n")

        # Write cell types
        f.write(f"\nCELL_TYPES {num_cells + num_faces}\n")
        for _ in range(num_cells):
            f.write("10\n")  
        for _ in range(num_faces):
            f.write("5\n")

    print(f"VTK file saved as {filename}")
       



# Normalization

In [11]:
def normalization(polygon):
    vertices = polygon.coordinates
    normalized_vertices = vertices.clone()
    centered_vertices = vertices.clone()
    # print(vertices)

    max_x = torch.max(vertices[:,0])
    min_x = torch.min(vertices[:,0])
    max_y = torch.max(vertices[:,1])
    min_y = torch.min(vertices[:,1])
    max_z = torch.max(vertices[:,2])
    min_z = torch.min(vertices[:,2])

    diffs = vertices.unsqueeze(1) - vertices.unsqueeze(0)
    length = torch.norm(diffs, dim=2)
    polygon.d = torch.max(length)
    polygon.x_min = min_x
    polygon.y_min = min_y
    polygon.z_min = min_z

    normalized_vertices = (vertices - torch.tensor([polygon.x_min, polygon.y_min, polygon.z_min])) / polygon.d

    
    polygon.Cx = torch.tensor(normalized_vertices[0,0].item())
    polygon.Cy = torch.tensor(normalized_vertices[0,1].item())
    polygon.Cz = torch.tensor(normalized_vertices[0,2].item())

    centered_vertices = normalized_vertices - torch.tensor([polygon.Cx, polygon.Cy, polygon.Cz])
    polygon.coordinates = centered_vertices
    
    # print("Normalized polygon:", vertices)

    return polygon


    


# denormalization

In [12]:
def denormalization(polygon):
    vertices = polygon.coordinates.clone().to(device)
    shifted_vertices = vertices.clone().to(device)
    denormalized_vertices = vertices.clone().to(device)
    
    shifted_vertices = vertices + torch.tensor([polygon.Cx, polygon.Cy, polygon.Cz]).to(device)
        

    denormalized_vertices = polygon.d * shifted_vertices + torch.tensor([polygon.x_min, polygon.y_min, polygon.z_min]).to(device)
    polygon.coordinates = denormalized_vertices
    return polygon


# MetricLoss

In [13]:
# ログをファイルに保存するようにしている
import sys

import logging

logging.basicConfig(
    filename='metric_loss.log',
    level=logging.DEBUG, 
    format='%(message)s'
)

# logger = logging.getLogger(__name__)

class MetricLoss:
    def select_vertices(self, vertices, cell):
        v0 = vertices[cell[0]].clone().to(device)
        v1 = vertices[cell[1]].clone().to(device)
        v2 = vertices[cell[2]].clone().to(device)
        v3 = vertices[cell[3]].clone().to(device)
        if torch.isnan(v0).any() or torch.isnan(v1).any() or torch.isnan(v2).any() or torch.isnan(v3).any():
            print("Error in vertices")
            
        return v0, v1, v2, v3

    def edge_length(self, v0, v1, v2, v3):
        l1 = torch.sqrt(torch.sum((v0 - v1)**2))
        l2 = torch.sqrt(torch.sum((v0 - v2)**2))
        l3 = torch.sqrt(torch.sum((v0 - v3)**2))
        l4 = torch.sqrt(torch.sum((v1 - v2)**2))
        l5 = torch.sqrt(torch.sum((v1 - v3)**2))
        l6 = torch.sqrt(torch.sum((v2 - v3)**2))
        if torch.isnan(l1) or torch.isnan(l2) or torch.isnan(l3):
            print("Error in edge_length")
        return l1, l2, l3, l4, l5, l6

    def cell_volume(self, polygon, v0, v1, v2, v3):
        
        M = torch.tensor([
                [v1[0]-v0[0], v1[1]-v0[1], v1[2]-v0[2]],
                [v2[0]-v0[0], v2[1]-v0[1], v2[2]-v0[2]],
                [v3[0]-v0[0], v3[1]-v0[1], v3[2]-v0[2]]
            ], dtype=torch.float64)

        # 行列式を計算
        det_M = torch.det(M)

        # 体積を計算
        volume = torch.tensor(torch.abs(det_M)/ 6)
        return volume.item()

    def compute_loss(self, polygon, vertices, cell, dx):
        v0, v1, v2, v3 = self.select_vertices(vertices, cell)

        l1, l2, l3, l4, l5, l6 = self.edge_length(v0, v1, v2, v3)
        length_tensor = torch.tensor([l1, l2, l3, l4, l5, l6], dtype=torch.float32)
        v = self.cell_volume(polygon,  v0, v1, v2, v3)
        loss = 1-((36.0*torch.sqrt(torch.tensor(2.))*v)/(torch.sum(length_tensor ** 3)))
        
        return loss

    def __call__(self, polygon, dx=None):
        vertices = polygon.coordinates
        cells = polygon.cells
        loss = torch.tensor(0.0)
        for cell in cells:
            loss = loss + self.compute_loss(polygon, vertices, cell, dx)
        
        metric_loss = loss/len(cells)
        if torch.isnan(metric_loss):
            print("coordinates:", vertices)
            print("cells:", cells)  
            plot_mesh(polygon, "Current Polygon")
            print("Error: Loss value is NaN. Exiting program.")
            sys.exit(1)
        return metric_loss
    
    
def print_grad(grad):
    print(grad)

# meshデータからq_hatを求める関数

In [14]:
def calculate_q_hat(mesh):
    vertices = mesh.coordinates
    faces = mesh.cells
    r_list = []
    alpha_list = []
    beta_list = []

    for face in faces:
        # a(最小角)と b(最大角)を求める

        angles = []
        v0, v1, v2 = m_loss.select_vertices(vertices, face)
        l1, l2, l3 = m_loss.edge_length(v0, v1, v2)

        # 余弦定理から各角度の余弦値を計算
        cos_alpha = (l2**2 + l3**2 - l1**2) / (2*l2*l3)
        cos_beta = (l1**2 + l3**2 - l2**2) / (2*l1*l3)
        cos_gamma = (l1**2 + l2**2 - l3**2) / (2*l1*l2)
        # 余弦値から角度を計算して個度法に変換
        alpha = torch.acos(cos_alpha) * 180 / np.pi
        beta = torch.acos(cos_beta) * 180 / np.pi
        gamma = torch.acos(cos_gamma) * 180 / np.pi

        angles.append(alpha)
        angles.append(beta)
        angles.append(gamma)

        min_angle = min(angles)
        max_angle = max(angles)

        alpha_list.append(min_angle)
        beta_list.append(max_angle)

    # 1/q = r を求める
    for i in range(len(test_polygonID_list)):
        polygonID = i 
        polygon = data_getter(polygonID, 0, test_mesh_data_lists, test_polygon_data_list)

        r = 1 - m_loss(polygon) 
        r_list.append(r)

    a_mean = sum(alpha_list) / len(alpha_list)
    a_min = min(alpha_list)
    b_mean = sum(beta_list) / len(beta_list)
    b_max = max(beta_list)
    r_mean = sum(r_list) / len(r_list)
    r_min = min(r_list)

    q_hat = (((a_mean + a_min + 120 - b_max - b_mean)/60) + r_mean + r_min) / 6

    return q_hat

    

def calculate_qhat(mesh):
    vertices = mesh.coordinates
    faces = mesh.cells
    r_list = []
    alpha_list = []
    beta_list = []

    for face in faces:
        # a(最小角)と b(最大角)を求める

        angles = []
        v0, v1, v2 = m_loss.select_vertices(vertices, face)
        l1, l2, l3 = m_loss.edge_length(v0, v1, v2)

        # 余弦定理から各角度の余弦値を計算
        cos_alpha = (l2**2 + l3**2 - l1**2) / (2*l2*l3)
        cos_beta = (l1**2 + l3**2 - l2**2) / (2*l1*l3)
        cos_gamma = (l1**2 + l2**2 - l3**2) / (2*l1*l2)
        # 余弦値から角度を計算して個度法に変換
        alpha = torch.acos(cos_alpha) * 180 / np.pi
        beta = torch.acos(cos_beta) * 180 / np.pi
        gamma = torch.acos(cos_gamma) * 180 / np.pi

        angles.append(alpha)
        angles.append(beta)
        angles.append(gamma)

        min_angle = min(angles)
        max_angle = max(angles)

        alpha_list.append(min_angle)
        beta_list.append(max_angle)

        # 1/q = r を求める
        v0, v1, v2 = m_loss.select_vertices(vertices, face)
        l1, l2, l3 = m_loss.edge_length(v0, v1, v2)
        s = 0.5*(l1 + l2 + l3)
        temp = s*(s-l1)*(s-l2)*(s-l3)
        loss = 1-(4.0*torch.sqrt(torch.tensor(3.))*s)/(l1**2 + l2**2 + l3**2)
        r_list.append(1./loss)



    a_mean = sum(alpha_list) / len(alpha_list)
    a_min = min(alpha_list)
    b_mean = sum(beta_list) / len(beta_list)
    b_max = max(beta_list)
    r_mean = sum(r_list) / len(r_list)
    r_min = min(r_list)

    q_hat = (((a_mean + a_min + 120 - b_max - b_mean)/60) + r_mean + r_min) / 6

    return q_hat

    
    

スターポリゴンの中から外側に自由点が移動したときに自由点の移動量を半分にしてもう一度外に行っていないか検証する
自由点が外に行かないことを確認したあとのスターポリゴンを返す

In [15]:
def check(polygon):
    # print("polygonID:", polygonID)
    vertices = polygon.coordinates
    
    edge_index = polygon.edge_index
    
    return_value = True
    while return_value == True:   

        for i in range(1, len(vertices[:,0])):
            point1 = torch.tensor([0.0, 0.0])
            point2 = vertices[0]
            point3 = vertices[i]
            
            pos_i = torch.where(edge_index[0] == i)
            pos_i = pos_i[0]
            # print("edge_index[0]", edge_index[0])
            
            for j in range(len(pos_i)):
                if edge_index[1, pos_i[j]] == 0:
                    continue
                else:
                    val_pos_i = edge_index[1, pos_i[j]]
                    break

            point4 = vertices[val_pos_i]


            a1 = 0
            b1 = 0
            a2 = 0
            b2 = 0
            check1 = 0
            check2 = 0
            check3 = 0
            check4 = 0
            x1 = point1[0]
            y1 = point1[1]
            x2 = point2[0]
            y2 = point2[1]
            x3 = point3[0]
            y3 = point3[1]
            x4 = point4[0]
            y4 = point4[1]
            a1 = (y1 - y2)/(x1-x2)
            b1 = y1 - (a1*x1)
            a2 = (y3 - y4)/(x3-x4)
            b2 = y3 - (a2*x3)
            check1 = (a1*x3) - y3 + b1 
            check2 = (a1*x4) - y4 + b1    # point1,2を通る直線に対してpoint3,4を結ぶ線分が交差しているか
            check3 = (a2*x1) - y1 + b2
            check4 = (a2*x2) - y2 + b2    # point3,4を通る直線に対してpoint1,2を結ぶ線分が交差しているか
            # print("1:",check1,"2:",check2,"3:",check3,"4:",check4)
            del a1, a2, b1, b2, x1, x2, x3, x4, y1, y2, y3, y4 

            if (check1*check2) <= 0.1 and (check3*check4) <= 0.1 :
                return_value = True
                # print("Out_of_StarPolygon")
                vertices[0] = 0.5*vertices[0]
                polygon.coordinates[0] = vertices[0]
                break
            else:
                return_value = False
                continue       
            
        
    # plot_mesh(polygon, "polygon_checked")
               
    return polygon

# Model 隠れ層のノード数は何にするか未定

In [16]:
class GMSNet(torch.nn.Module):
    
    def __init__(self, input_dim, feature_dim, hidden_channels):   # モデルの初期化
        
        super(GMSNet, self).__init__()
        
        self.fc1 = Linear(input_dim, feature_dim)
        self.gcn = GCNConv(feature_dim, feature_dim)
        self.graph_norm = GraphNorm(feature_dim)
        self.instance_norm = nn.InstanceNorm1d(feature_dim)
        self.fc2 = nn.Linear(feature_dim, feature_dim)
        self.fc3 = nn.Linear(feature_dim, input_dim)


        # Weight initialization
        self.apply(self._init_weights)                                  # 重みの初期化

    def _init_weights(self, m):     # 線形層と畳み込み層の重みをKaiming正規化で初期化し、バイアスをゼロで初期化する
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, edge_index):       
        
        x = self.fc1(x)
        x = self.graph_norm(x)
        x = F.relu(x)
        x = self.gcn(x, edge_index) + x
        x = self.fc2(x)
        # x = self.instance_norm(x)
        x = self.fc3(x)
        x = 0.01 * x 
        x = x[0]

        
        
        return x

# Main

In [17]:
# フォルダ内のすべてのvtkファイルにアクセスする
train_vtk_files = glob.glob(f"{train_data_path}*.vtk")
train_vtk_filenames = [file.split('/')[-1].split('.')[0] for file in train_vtk_files]

num_train_mesh = len(train_vtk_files)
print("num_train_mesh:", num_train_mesh)
train_mesh_data_list, train_polygonID_list, train_polygon_dict = create_mesh_polygon_dataset(train_vtk_files)



    
# ポリゴンデータを格納するリストを作成
train_polygon_data_list = []

for i in range(len(train_polygonID_list)):
    polygonID = f"polygon_{i}"
    meshID = train_polygon_dict[f"polygon_{i}"][0]
    nodeID = train_polygon_dict[f"polygon_{i}"][1]
    num_cells = train_polygon_dict[f"polygon_{i}"][2]
    polygon_data = Polygon_data(polygonID, meshID, nodeID, num_cells)
    train_polygon_data_list.append(polygon_data)



num_train_mesh: 1


  mesh.cells = torch.tensor(tetra_cells, dtype=torch.long)


In [18]:
type(train_polygon_data_list[0])

__main__.Polygon_data

In [19]:
torch.manual_seed(42)
train_data_loader = DataLoader(train_polygonID_list, batch_size=64*num_train_mesh, shuffle=True)

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = GMSNet(input_dim=3, feature_dim=64, hidden_channels=64)
model.to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
criterion = MetricLoss()
early_stopping = EarlyStopping(patience=5, min_delta=1e-6)

cuda
GMSNet(
  (fc1): Linear(in_features=3, out_features=64, bias=True)
  (gcn): GCNConv(64, 64)
  (graph_norm): GraphNorm(64)
  (instance_norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
)


In [21]:
# writer = SummaryWriter("logs")
# 学習率を調整するスケジューラの設定
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.95, verbose=True)

loss_list = []

def train(device):
    model.train()

    # for name, param in model.GNorm.named_parameters():
    #     print(f'{name}:{param.data}')

    model.to(device)
    
    temp = 0
    ddd = 0
    
    for step, data in enumerate(train_data_loader):
        
        optimizer.zero_grad()
        metric_loss = torch.tensor(0.0, requires_grad=True)
        metric_loss.to(device)
        dx_list = []
        
        for i in range(len(data)):

            polygonID = int(data[i].split("_")[-1])
            # print("polygonID:", polygonID)
            polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)
            # plot_mesh(polygon, "corrent_polygon")
            # vtk_output(polygon, "corrent_polygon")
            # print("polygon_origin:", polygon.coordinates)
            polygon = normalization(polygon)           
            
            polygon.to(device)
            # print("normalized polygon:", polygon.coordinates)
            x = polygon.coordinates.clone().to(device)
            edge_index = polygon.edge_index.clone().to(device)
            out = model(x, edge_index)
            # print("out:", out)
            polygon.coordinates[0] = polygon.coordinates[0] + out
            # polygon = check(polygon)
            
            
            
            # logger.info("epoch: {:04}, polygonID: {:04}".format(epoch, polygonID))
            # logger.info("before")
            with torch.no_grad():
                l = criterion(polygon)
                # print("metric_loss:", l)

            
            # logger.info("")
            # print(l)
            metric_loss = metric_loss + l
            polygon = denormalization(polygon)
            # print("changed polygon:", polygon.coordinates)
                
            polygon_meshID = int(train_polygon_data_list[polygonID].meshID.split("_")[-1])
            mesh = train_mesh_data_list[polygon_meshID]
            mesh.coordinates[train_polygon_data_list[polygonID].nodeID[0]] = polygon.coordinates[0]
            
            
        loss = metric_loss/len(data)
        
        

            
        ddd = ddd + len(data)
        temp = temp + loss
        # print("    Loss:", loss.item(), "   PolygonID:", polygonID)
        
        loss.backward()
        optimizer.step()

    loss_ave = temp/ddd
    writer.add_scalar("loss", loss_ave, epoch)       
    print(loss_ave, epoch)
    loss_list.append(temp/ddd)

In [22]:
# for i in range(num_train_mesh):
    # plot_mesh(train_mesh_data_list[i], f"original_{train_vtk_filenames[i]}")

# 最終的な最適化したメッシュを生成してvtkファイルで出力する

In [23]:
# TensorBoard用のログディレクトリを指定
writer = SummaryWriter(log_dir="/mnt/log/test" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "train")

In [24]:
# for epoch in tqdm(range(num_train_epoch), position=1):
#     #print("epoch:", epoch)
#     train(device)

In [25]:
epoch = 0
import torch
torch.autograd.set_detect_anomaly(True)

# for j in range(num_train_mesh):                 # 元のメッシュを保存する
#         save_mesh(train_mesh_data_list[j], f"trained_{train_vtk_filenames[j]}_{epoch}")
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)            
for epoch in tqdm(range(num_train_epoch), position=1):
    #print("epoch:", epoch)
    train(device)
    # print(epoch)

    # for i in range(num_train_mesh):
    #     plot_mesh(train_mesh_data_list[i], f"{i}")
    for j in range(num_train_mesh):             # 最適化したメッシュを保存する
        vtk_output(train_mesh_data_list[j], f"trained_{train_vtk_filenames[j]}_{epoch+1}")

  volume = torch.tensor(torch.abs(det_M)/ 6)


tensor(0.0028, grad_fn=<DivBackward0>) 0




VTK file saved as /mnt/vtk_output/trained_volume_1.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 1




VTK file saved as /mnt/vtk_output/trained_volume_2.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 2




VTK file saved as /mnt/vtk_output/trained_volume_3.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 3




VTK file saved as /mnt/vtk_output/trained_volume_4.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 4




VTK file saved as /mnt/vtk_output/trained_volume_5.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 5




VTK file saved as /mnt/vtk_output/trained_volume_6.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 6




VTK file saved as /mnt/vtk_output/trained_volume_7.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 7




VTK file saved as /mnt/vtk_output/trained_volume_8.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 8




VTK file saved as /mnt/vtk_output/trained_volume_9.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 9




VTK file saved as /mnt/vtk_output/trained_volume_10.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 10




VTK file saved as /mnt/vtk_output/trained_volume_11.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 11




VTK file saved as /mnt/vtk_output/trained_volume_12.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 12




VTK file saved as /mnt/vtk_output/trained_volume_13.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 13




VTK file saved as /mnt/vtk_output/trained_volume_14.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 14




VTK file saved as /mnt/vtk_output/trained_volume_15.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 15




VTK file saved as /mnt/vtk_output/trained_volume_16.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 16




VTK file saved as /mnt/vtk_output/trained_volume_17.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 17




VTK file saved as /mnt/vtk_output/trained_volume_18.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 18




VTK file saved as /mnt/vtk_output/trained_volume_19.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 19




VTK file saved as /mnt/vtk_output/trained_volume_20.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 20




VTK file saved as /mnt/vtk_output/trained_volume_21.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 21




VTK file saved as /mnt/vtk_output/trained_volume_22.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 22




VTK file saved as /mnt/vtk_output/trained_volume_23.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 23




VTK file saved as /mnt/vtk_output/trained_volume_24.vtk
tensor(0.0028, grad_fn=<DivBackward0>) 24




VTK file saved as /mnt/vtk_output/trained_volume_25.vtk
tensor(0.0029, grad_fn=<DivBackward0>) 25




VTK file saved as /mnt/vtk_output/trained_volume_26.vtk
tensor(0.0029, grad_fn=<DivBackward0>) 26




VTK file saved as /mnt/vtk_output/trained_volume_27.vtk
tensor(0.0029, grad_fn=<DivBackward0>) 27




VTK file saved as /mnt/vtk_output/trained_volume_28.vtk
tensor(0.0029, grad_fn=<DivBackward0>) 28




VTK file saved as /mnt/vtk_output/trained_volume_29.vtk
tensor(0.0029, grad_fn=<DivBackward0>) 29


100%|██████████| 30/30 [1:40:04<00:00, 200.15s/it]

VTK file saved as /mnt/vtk_output/trained_volume_30.vtk





In [26]:

writer.close()


In [27]:
print(model)

GMSNet(
  (fc1): Linear(in_features=3, out_features=64, bias=True)
  (gcn): GCNConv(64, 64)
  (graph_norm): GraphNorm(64)
  (instance_norm): InstanceNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
)


メモリの開放をする

In [28]:
torch.save(model.state_dict(), 'model_weights.pth')

gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()