## GMSNetのモデルを作成する  train,eval,test のデータに分ける
### 損失関数の求め方、必要な値が正しく得られているかなど検証してみる
2024/5/17 時点での完成版

コードをすべて問題なく実行できるようにはなった。
最適化されたメッシュを取り出してくることもでき、しっかり更新できている。
しかし、損失の値が全く減少せず、むしろ増加していてほぼもともとのメッシュのままの出力が最適化メッシュとして出てきてしまう。
ゆえに、最適化できているとは到底言えない。






・各epoch後のメッシュの表示


・学習率を0.1倍から0.95倍にしてみる

・シフト切り捨ては推論時飲み使用する

・バッチサイズを変えてみる

・モデルの出力を0.5倍するなどして移動量を小さくする

・学習時はMetricLossで、最適メッシュを選定するときは違う関数なのはなんで

# シフト切り捨てをやめてみる
# トレーニング完了後のトレーニングメッシュを表示してみる
# 各epochの終了時にメッシュを可視化してみる

In [1]:
import copy
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import GCNConv
from torch_geometric.nn.norm import GraphNorm
import torch_geometric.transforms as T
from torch.nn import Linear, InstanceNorm2d, InstanceNorm1d, Conv1d, ReLU, Tanh
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
import matplotlib.pyplot as plt
from torch_geometric.transforms import FaceToEdge
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from itertools import combinations
import vtk
import glob
from torch_scatter import scatter_mean
from tqdm import tqdm
# 計算を軽くするためのライブラリ
from torch.cuda import empty_cache
import gc               # メモリリークを防ぐ

from torch import nn
import os
import sys

import datetime

In [2]:
torch.cuda.is_available()

True

In [3]:
num_train_epoch = 20
num_test_epoch = 10
num_trial = 5

# Dataの準備

In [4]:
class Dataset(Dataset):
    def __init__(self, num_files):
        None

class Mesh(Dataset):
    def __init__(self):
        self.coordinates = None
        self.faces = None

class Polygon(Dataset):
    def __init__(self, num_node, num_face):
        self.parent_meshID = None
        self.coordinates = torch.zeros(num_node, 2)
        self.faces = torch.zeros(num_face, 3)
        self.edges = None
        self.d = None
        self.Cx = None
        self.Cy = None
        self.x_min = None
        self.y_min = None

class PolygonID(Dataset):
    def __init__(self, nodeID):
        self.nodeID = nodeID
        # self.parent_meshID = None

class Polygon_data(Dataset):
    def __init__(self, polygonID, meshID, nodeID):
        self.polygonID = polygonID
        self.meshID = meshID
        self.nodeID = nodeID

class Minibatch(Dataset):
    def __init__(self):
        self.x = None
        self.edge_index = None
        self.batch = None

In [5]:
def create_mesh_polygonID_data(vtk_file_path, polygonID_list, poly_count, polygon_dict, mesh_index):
    reader = vtk.vtkDataSetReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()

    data = reader.GetOutput()
    
    mesh = Mesh()
    # 座標情報を取得
    points = data.GetPoints()
    num_points = points.GetNumberOfPoints()
    coordinates = torch.zeros(num_points, 3)
    for i in range(num_points):
        coordinates[i] = torch.tensor(points.GetPoint(i))

    mesh.coordinates = coordinates[:, :2]                        # mesh.coordinates を定義

    # 面情報を取得
    polys = data.GetPolys()
    num_polys = polys.GetNumberOfCells()
    mesh.faces = torch.zeros(num_polys, 3, dtype=int)           # mesh.faces を定義

    # 各三角形の情報を取得
    polys.InitTraversal()
    for i in range(num_polys):
        cell = vtk.vtkIdList()
        if polys.GetNextCell(cell) == 0:
            break
        mesh.faces[i] = torch.tensor([cell.GetId(0), cell.GetId(1), cell.GetId(2)])
        
# ------------ mesh のデータを取得完了 -------------------------


    # 各セルの各辺の隣接セル数を調べる
    edge_neighbors = {}
    num_cells = data.GetNumberOfCells()
    for cell_index in range(num_cells):
        cell = data.GetCell(cell_index)
        num_edges = cell.GetNumberOfEdges()

        for edge_index in range(num_edges):
            edge = cell.GetEdge(edge_index)
            edge_points = edge.GetPointIds()

            # 辺を構成する点のインデックスを取得
            point1_id = edge_points.GetId(0)
            point2_id = edge_points.GetId(1)

            # 辺を構成する点のインデックスを照準にソート
            edge_key = (min(point1_id, point2_id), max(point1_id, point2_id))

            # 辺の隣接セル数をカウント
            if edge_key in edge_neighbors:
                edge_neighbors[edge_key] += 1
            else:
                edge_neighbors[edge_key] = 1 

    boundary_edges = []
    # 境界上の辺を特定
    for edge_key, num_neighbors in edge_neighbors.items():
        if num_neighbors == 1:
            boundary_edges.append(edge_key)

    # 境界上の辺を構成する頂点の番号を取得
    boundary_points = set()     # 集合を表すデータ型、順番を持たず、重複した要素は取り除かれる
# ---------------- 自由点かどうかの判定完了 ------------------------
    

    for edge_key in boundary_edges:
        boundary_points.add(edge_key[0])
        boundary_points.add(edge_key[1])
    
    
    for pointId in range(num_points):       # pointId:自由点の頂点番号
        if pointId in boundary_points:
            continue
        else:
            poly_count += 1
            # print("pointId:", pointId)
        mask = (mesh.faces == pointId)
        if mask.any():
            count = torch.sum(mask).item()
        num_node = count + 1
        num_face = count
        polygon_number = poly_count - 1 



        
        polygon_i = f"polygon_{polygon_number}"
        # print(polygon_i)
        polygon_i = Polygon(num_node, num_face)
        
        element_to_check = pointId
        polygon_i.face = mesh.faces[(mesh.faces == element_to_check).any(dim=1)]
        # print(polygon_i.face)

        polygon_i.nodeId = set()
        for i in range(len(polygon_i.face)):
            polygon_i.nodeId.add(polygon_i.face[i, 0].item())
            polygon_i.nodeId.add(polygon_i.face[i, 1].item())
            polygon_i.nodeId.add(polygon_i.face[i, 2].item())
        sorted_nodeId = sorted(polygon_i.nodeId)
        polygon_i.nodeID = torch.tensor(list(sorted_nodeId))
        
        point_id_index = (polygon_i.nodeID == pointId).nonzero().item()

        value_to_move = polygon_i.nodeID[point_id_index]
        polygon_i.nodeID = torch.cat((value_to_move.unsqueeze(0), polygon_i.nodeID[polygon_i.nodeID != pointId]))
        # print(polygon_i.nodeID)
        setattr(polygon_i, "parent_meshID", mesh)
        polygonID_list.append(f"polygon_{polygon_number}")

        keyword = f"polygon_{polygon_number}"
        valiables = (f"mesh_{mesh_index}", polygon_i.nodeID)
        polygon_dict[keyword] = valiables

    # --------- polygon.nodeID の取得完了 -------------
    return mesh, polygonID_list, poly_count, polygon_dict

# Dataset の作成

In [6]:
def create_mesh_polygon_dataset(vtk_files):
    num_vtk_files = len(vtk_files)
    polygonID_list = []
    mesh_data_list = []
    poly_count = 0
    polygon_dict = {}
    # ファイルに順にアクセスする
    for i in range(num_vtk_files):
        # print("File Name:", vtk_files[i])
        mesh, polygonID_list, poly_count, polygon_dict = create_mesh_polygonID_data(vtk_files[i], polygonID_list, poly_count, polygon_dict, i)
        mesh_data_list.append(mesh)
    return mesh_data_list, polygonID_list, polygon_dict


In [7]:
# 以下、i はpolygon番号で座標と面情報を取得することができる
face_to_edge = T.FaceToEdge(remove_faces=False)
def data_getter(polygonID, num_mesh_data_list, mesh_data_list, polygon_data_list):
    
    polygon_meshID = int(polygon_data_list[polygonID].meshID.split("_")[-1])
    mesh = mesh_data_list[polygon_meshID]
    
    num_node = len(polygon_data_list[polygonID].nodeID)
    num_face = num_node - 1 
    polygon_i = Polygon(num_node, num_face)

    # print(polygon_data_list[polygonID].nodeID)      # polygon に属する頂点の番号

    polygon_i.coordinates = mesh.coordinates[polygon_data_list[polygonID].nodeID]     # polygonの座標
    # print(polygon_i.coordinates)

    # print(polygon_i.faces)

    # polygon_i.faces を取得するコード
    
    element_to_check = polygon_data_list[polygonID].nodeID[0]
    polygon_i.face = mesh.faces[(mesh.faces == element_to_check).any(dim=1)]

    indices = torch.nonzero(torch.isin(polygon_i.face, polygon_data_list[polygonID].nodeID))
    for idx in range(indices.size(0)):
        row_idx, col_idx = indices[idx]
        value_to_replace = polygon_i.face[row_idx, col_idx]
        polygon_i.face[row_idx, col_idx] = (polygon_data_list[polygonID].nodeID == value_to_replace).nonzero().item()
    polygon_i.faces = polygon_i.face.long()

    # 各行の三角形からエッジを抽出してedge_indexを構築
    edges = torch.cat([ polygon_i.faces[:, [0, 1]],
                        polygon_i.faces[:, [1, 2]],
                        polygon_i.faces[:, [2, 0]]], dim=0)

    # エッジのインデックスをソートして重複を削除
    edge_index = torch.sort(edges, dim=1).values
    edge_index = torch.tensor(sorted(edge_index.numpy().tolist())).unique(dim=0)
    polygon_i.edge_index = torch.transpose(edge_index, 0, 1)
    return polygon_i



# メッシュをプロットする関数

In [8]:
def plot_mesh(mesh, title):

    vertices = mesh.coordinates
    faces = mesh.faces
    fig = plt.figure()
    ax = fig.add_subplot(111, aspect="equal")

    # 描画するメッシュの頂点をプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'bo')  # 頂点を青色の点でプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'k-')  # 辺を黒色の線でプロット

    # 各三角形をプロット
    for face in faces:
        v0, v1, v2 = vertices[face]
        v0_np = v0.detach().numpy()
        v1_np = v1.detach().numpy()
        v2_np = v2.detach().numpy()
        ax.plot([v0_np[0], v1_np[0], v2_np[0], v0_np[0]], [v0_np[1], v1_np[1], v2_np[1], v0_np[1]], 'b-')  # 三角形を赤色の線でプロット

    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.axhline(0, color="black", linewidth=0.001)
    ax.axvline(0, color="black", linewidth=0.001)

    # plt.xlim(-300, 150)
    # plt.ylim(-200, 1400)
    # plt.savefig(f"/mnt/{title}.png", format="png")
    plt.show()
    plt.close()


In [9]:
def save_mesh(mesh, title):

    vertices = mesh.coordinates
    faces = mesh.faces
    fig = plt.figure()
    ax = fig.add_subplot(111, aspect="equal")

    # 描画するメッシュの頂点をプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'bo')  # 頂点を青色の点でプロット
    # ax.plot(vertices[:,0], vertices[:,1], 'k-')  # 辺を黒色の線でプロット

    # 各三角形をプロット
    for face in faces:
        v0, v1, v2 = vertices[face]
        v0_np = v0.detach().numpy()
        v1_np = v1.detach().numpy()
        v2_np = v2.detach().numpy()
        ax.plot([v0_np[0], v1_np[0], v2_np[0], v0_np[0]], [v0_np[1], v1_np[1], v2_np[1], v0_np[1]], 'b-')  # 三角形を赤色の線でプロット

    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.axhline(0, color="black", linewidth=0.001)
    ax.axvline(0, color="black", linewidth=0.001)

    # plt.xlim(-300, 150)
    # plt.ylim(-200, 1400)
    plt.savefig(f"/mnt/Saved_mesh/{title}.png", format="png")
    # plt.show()
    plt.close()

# meshデータからvtkファイルを出力する関数

In [10]:
def vtk_output(mesh, title):
    vertices = mesh.coordinates
    faces = mesh.faces
    num_vertices = len(vertices)
    num_faces = len(faces)


    # vertices を３次元に戻す
    z_column = torch.zeros(vertices.shape[0], 1)
    vertices = torch.cat((vertices, z_column), dim=1)

    with open(f"/mnt/optimized_data/{title}.vtk", "w") as f:
        f.write("# vtk DataFile Version 2.0\n")
        f.write("FOR TEST\n")
        f.write("ASCII\n")
        f.write("DATASET POLYDATA\n")

        f.write("POINTS {} float\n".format(num_vertices))
        for vertex in vertices:
            f.write("{:.15f} {:.15f} {:.15f}\n".format(*vertex))

        f.write("\nPOLYGONS {} {}\n".format(num_faces, num_faces * 4))
        for face in faces:
            f.write("3 ")
            f.write(" ".join(str(idx.item()) for idx in face))
            f.write("\n")
       



# Normalization

In [11]:
def normalization(polygon):
    vertices = polygon.coordinates
    normalized_vertices = vertices.clone()
    centered_vertices = vertices.clone()
    # print(vertices)

    max_x = torch.max(vertices[:,0])
    min_x = torch.min(vertices[:,0])
    max_y = torch.max(vertices[:,1])
    min_y = torch.min(vertices[:,1])

    polygon.d = torch.max(max_x - min_x, max_y - min_y)
    polygon.x_min = min_x
    polygon.y_min = min_y

    normalized_vertices = (vertices - torch.tensor([polygon.x_min, polygon.y_min])) / polygon.d

    
    polygon.Cx = normalized_vertices[0,0].item()
    polygon.Cy = normalized_vertices[0,1].item()   

    centered_vertices = normalized_vertices - torch.tensor([polygon.Cx, polygon.Cy])
    polygon.coordinates = centered_vertices
    
    # print("Normalized polygon:", vertices)

    return polygon


    


# denormalization

In [12]:
def denormalization(polygon):
    vertices = polygon.coordinates
    shifted_vertices = vertices.clone()
    denormalized_vertices = vertices.clone()
    
    shifted_vertices = vertices + torch.tensor([polygon.Cx, polygon.Cy])
        

    denormalized_vertices = polygon.d * shifted_vertices + torch.tensor([polygon.x_min, polygon.y_min])
    polygon.coordinates = denormalized_vertices
    return polygon


# MetricLoss

In [13]:
import sys

import logging

logging.basicConfig(
    filename='metric_loss.log',
    level=logging.DEBUG, 
    format='%(message)s'
)

logger = logging.getLogger(__name__)

class MetricLoss:
    def select_vertices(self, vertices, face):
        v0 = vertices[face[0]].clone()
        v1 = vertices[face[1]].clone()
        v2 = vertices[face[2]].clone()
        return v0, v1, v2 

    def edge_length(self, v0, v1, v2):
        l1 = torch.sqrt((v0[0] - v1[0])**2 + (v0[1] - v1[1])**2)
        l2 = torch.sqrt((v1[0] - v2[0])**2 + (v1[1] - v2[1])**2)
        l3 = torch.sqrt((v2[0] - v0[0])**2 + (v2[1] - v0[1])**2)

        return l1, l2, l3

    def face_area(self, polygon, l1, l2, l3):
        
        s = 0.5*(l1 + l2 + l3)
        #print(round(l1.item(),5), round(l2.item(),5), round(l3.item(),5), round(s.item(),5), round((s*(s-l1)*(s-l2)*(s-l3)).item(),5))
        temp = s*(s-l1)*(s-l2)*(s-l3)
        #temp.register_hook(print_grad)
        logger.info("    s, in_sqrt: {}, {}".format(s.item(), temp.item()))
        
        try:
            
            face_area = torch.sqrt(temp)
            
        except Exception as e:
            
            SimplePolygonGenerator.polygon_visualizer(polygon)
            
            print("An error occurred")
            print("Value of temp:", temp)
            print(l1.item(), l2.item(), l3.item())
            
            raise
        
        #face_area = torch.sqrt(temp)
        #face_area.register_hook(print_grad)
        return face_area

    def compute_loss(self, polygon, vertices, face, dx):
        v0, v1, v2 = self.select_vertices(vertices, face)
        if dx is not None:
            if face[0]==0:
                v0 = v0 + dx
            elif face[1]==0:
                v1 += v1 + dx
            elif face[2]==0:
                v2 += v2 + dx
        #print(v0, v1, v2)
        logger.info("    v0: ({}, {})".format(v0[0].item(), v0[1].item()))
        logger.info("    v1: ({}, {})".format(v1[0].item(), v1[1].item()))
        logger.info("    v2: ({}, {})".format(v2[0].item(), v2[1].item()))
        l1, l2, l3 = self.edge_length(v0, v1, v2)
        logger.info("    l1, l2, l3:  {}, {}, {}".format(l1.item(), l2.item(), l3.item()))
        s = self.face_area(polygon, l1, l2, l3)
        #print(s.item(), l1.item(), l2.item(), l3.item())

        #q = (l1**2 + l2**2 + l3**2)/(4.0*torch.sqrt(torch.tensor(3.))*s+1.0)
        
        #loss = 1 - 1/q
        #print(q.item(), loss.item())
        
        #q = q.clone().detach().requires_grad_(True)
        loss = 1-(4.0*torch.sqrt(torch.tensor(3.))*s)/(l1**2 + l2**2 + l3**2)
        logger.info("    area, loss: {}, {}".format(s.item(), loss.item()))
        logger.info("")
        #print(loss.item())
        #print("")
        
        #loss.register_hook(print_grad)
        
        return loss

    def __call__(self, polygon, dx=None):
        vertices = polygon.coordinates
        faces = polygon.faces
        loss = 0 
        #print(vertices)
        #print(dx)
        for face in faces:
            loss = loss + self.compute_loss(polygon, vertices, face, dx)
        
        metric_loss = loss/(len(polygon.coordinates[:,0])-1) #.clone().detach().requires_grad_(True))
        return metric_loss
    
    
def print_grad(grad):
    print(grad)

# meshデータからq_hatを求める関数

In [14]:
def calculate_q_hat(mesh):
    vertices = mesh.coordinates
    faces = mesh.faces
    r_list = []
    alpha_list = []
    beta_list = []

    for face in faces:
        # a(最小角)と b(最大角)を求める

        angles = []
        v0, v1, v2 = m_loss.select_vertices(vertices, face)
        l1, l2, l3 = m_loss.edge_length(v0, v1, v2)

        # 余弦定理から各角度の余弦値を計算
        cos_alpha = (l2**2 + l3**2 - l1**2) / (2*l2*l3)
        cos_beta = (l1**2 + l3**2 - l2**2) / (2*l1*l3)
        cos_gamma = (l1**2 + l2**2 - l3**2) / (2*l1*l2)
        # 余弦値から角度を計算して個度法に変換
        alpha = torch.acos(cos_alpha) * 180 / np.pi
        beta = torch.acos(cos_beta) * 180 / np.pi
        gamma = torch.acos(cos_gamma) * 180 / np.pi

        angles.append(alpha)
        angles.append(beta)
        angles.append(gamma)

        min_angle = min(angles)
        max_angle = max(angles)

        alpha_list.append(min_angle)
        beta_list.append(max_angle)



        # 1/q = r を求める

        r = 1 - m_loss.compute_loss(vertices, face) 
        r_list.append(r)

    a_mean = sum(alpha_list) / len(alpha_list)
    a_min = min(alpha_list)
    b_mean = sum(beta_list) / len(beta_list)
    b_max = max(beta_list)
    r_mean = sum(r_list) / len(r_list)
    r_min = min(r_list)

    q_hat = (((a_mean + a_min + 120 - b_max - b_mean)/60) + r_mean + r_min) / 6

    return q_hat

    


    
    

スターポリゴンの中から外側に自由点が移動したときに自由点の移動量を半分にしてもう一度外に行っていないか検証する
自由点が外に行かないことを確認したあとのスターポリゴンを返す

In [15]:
def check(polygon, polygonID):
    # print("polygonID:", polygonID)
    vertices = polygon.coordinates
    
    edge_index = polygon.edge_index
    
    return_value = True
    while return_value == True:   

        for i in range(1, len(vertices[:,0])):
            point1 = torch.tensor([0.0, 0.0])
            point2 = vertices[0]
            point3 = vertices[i]
            
            pos_i = torch.where(edge_index[0] == i)
            pos_i = pos_i[0]
            # print("edge_index[0]", edge_index[0])
            
            for j in range(len(pos_i)):
                if edge_index[1, pos_i[j]] == 0:
                    continue
                else:
                    val_pos_i = edge_index[1, pos_i[j]]
                    break

            point4 = vertices[val_pos_i]


            a1 = 0
            b1 = 0
            a2 = 0
            b2 = 0
            check1 = 0
            check2 = 0
            check3 = 0
            check4 = 0
            x1 = point1[0]
            y1 = point1[1]
            x2 = point2[0]
            y2 = point2[1]
            x3 = point3[0]
            y3 = point3[1]
            x4 = point4[0]
            y4 = point4[1]
            a1 = (y1 - y2)/(x1-x2)
            b1 = y1 - (a1*x1)
            a2 = (y3 - y4)/(x3-x4)
            b2 = y3 - (a2*x3)
            check1 = (a1*x3) - y3 + b1 
            check2 = (a1*x4) - y4 + b1    # point1,2を通る直線に対してpoint3,4を結ぶ線分が交差しているか
            check3 = (a2*x1) - y1 + b2
            check4 = (a2*x2) - y2 + b2    # point3,4を通る直線に対してpoint1,2を結ぶ線分が交差しているか
            # print("1:",check1,"2:",check2,"3:",check3,"4:",check4)
            del a1, a2, b1, b2, x1, x2, x3, x4, y1, y2, y3, y4 

            if (check1*check2) <= 0 and (check3*check4) <= 0 :
                return_value = True
                # print("Out_of_StarPolygon")
                vertices[0] = 0.5*vertices[0]
                polygon.coordinates[0] = vertices[0]
                break
            else:
                return_value = False
                continue       
            
        
    # plot_mesh(polygon, "polygon_checked")
               
    return polygon

# Model 隠れ層のノード数は何にするか未定

In [16]:
class GMSNet(torch.nn.Module):
    
    def __init__(self, input_dim, feature_dim, hidden_channnels):
        
        super(GMSNet, self).__init__()
        torch.manual_seed(42)
        
        self.shared_mlp = Conv1d(input_dim, feature_dim, kernel_size=1)
        self.GNorm = GraphNorm(feature_dim, feature_dim)
        self.conv = GCNConv(feature_dim, feature_dim)
        self.fc1 = Linear(feature_dim, hidden_channnels)
        #self.ISNorm = InstanceNorm1d(hidden_channnels, affine=True)
        self.fc2 = Linear(hidden_channnels, input_dim)
        
        self.relu = ReLU()
        self.tanh = Tanh()

        # Weight initialization
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, edge_index):       
        
        # feature
        x = torch.permute(x, (0, 2, 1))
        x = self.shared_mlp(x)
        feature = self.relu(x)
        feature = torch.permute(feature, (0, 2, 1))
        
        # GNN
        # x = self.GNorm(feature)
        x = self.relu(feature)
        gnn_feature = self.conv(x, edge_index) + feature
        
        # MLP
        target_feature = gnn_feature.mean(dim=1)
        mlp_midlayer = self.fc1(target_feature)
        x = self.relu(mlp_midlayer)
        x = self.fc2(x)
        x = self.tanh(x)
        
        x = 0.1*x
        
        return x

# Main

In [17]:
# フォルダ内のすべてのvtkファイルにアクセスする
train_vtk_files = glob.glob("/mnt/new_train_data/*.vtk")
eval_vtk_files = glob.glob("/mnt/Eval_Data/*.vtk")
test_vtk_files = glob.glob("/mnt/Test_Data/*vtk")

num_train_mesh = len(train_vtk_files)
num_test_mesh = len(test_vtk_files)
print("num_train_mesh:", num_train_mesh)
train_mesh_data_list, train_polygonID_list, train_polygon_dict = create_mesh_polygon_dataset(train_vtk_files)
eval_mesh_data_list, eval_polygonID_list, eval_polygon_dict = create_mesh_polygon_dataset(eval_vtk_files)
test_mesh_data_list, test_polygonID_list, test_polygon_dict = create_mesh_polygon_dataset(test_vtk_files)


    
# ポリゴンデータを格納するリストを作成
train_polygon_data_list = []
eval_polygon_data_list = []
test_polygon_data_list = []

for i in range(len(train_polygonID_list)):
    polygonID = f"polygon_{i}"
    meshID = train_polygon_dict[f"polygon_{i}"][0]
    nodeID = train_polygon_dict[f"polygon_{i}"][1]
    polygon_data = Polygon_data(polygonID, meshID, nodeID)
    train_polygon_data_list.append(polygon_data)

for i in range(len(eval_polygonID_list)):
    polygonID = f"polygon_{i}"
    meshID = eval_polygon_dict[f"polygon_{i}"][0]
    nodeID = eval_polygon_dict[f"polygon_{i}"][1]
    polygon_data = Polygon_data(polygonID, meshID, nodeID)
    eval_polygon_data_list.append(polygon_data)

for i in range(len(test_polygonID_list)):
    polygonID = f"polygon_{i}"
    meshID = test_polygon_dict[f"polygon_{i}"][0]
    nodeID = test_polygon_dict[f"polygon_{i}"][1]
    polygon_data = Polygon_data(polygonID, meshID, nodeID)
    test_polygon_data_list.append(polygon_data)

num_train_mesh: 3


In [18]:
type(train_polygon_data_list[0])

__main__.Polygon_data

In [19]:
train_data_loader = DataLoader(train_polygonID_list, batch_size=64*num_train_mesh, shuffle=True)
# eval_data_loader = DataLoader(eval_polygonID_list, batch_size=32, shuffle=True)
test_data_loader = DataLoader(test_polygonID_list, batch_size=64*num_test_mesh, shuffle=True)
# for step, data in enumerate(data_loader):
#     print(f"Step {step + 1}:")
#     print("==========")
#     print(data)
#     print(len(data))
#     for i in range(len(data)):
#         polygonID = int(data[i].split("_")[-1])
#         print("polygonID:",polygonID)
#         polygon = data_getter(polygonID, 0)
#         print("polygon.coordinates:",polygon.coordinates)
#         print("polygon.edge_index:", polygon.edge_index)
#         print("==========")
#         # plot_mesh(polygon, "title")

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = GMSNet(input_dim=2, feature_dim=64, hidden_channnels=64)
model.to(device)
print(model)
m_loss = MetricLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)
criterion = MetricLoss()
#criterion = nn.MSELoss()

cuda
GMSNet(
  (shared_mlp): Conv1d(2, 64, kernel_size=(1,), stride=(1,))
  (GNorm): GraphNorm(64)
  (conv): GCNConv(64, 64)
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=2, bias=True)
  (relu): ReLU()
  (tanh): Tanh()
)


In [21]:
# model.load_state_dict(torch.load('mnt/model_weight.pth'))

In [25]:
# writer = SummaryWriter(log_dir="mnt/logs/" + datetime.datetime.now()strftime("%Y%m%d-%H%M%S"))
# 学習率を調整するスケジューラの設定
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.95, verbose=True)

loss_list = []

def train_(device):
    model.train()
    loss_list = []
    for step, data in enumerate(train_data_loader):
        empty_cache()
        gc.collect()
        # print(f"Step {step + 1}:")
        # print("==========")
        # print(data)
        # print(len(data))
        minibatch = Minibatch()
        minibatch_coordinates = []
        all_edge_index_1 = []
        all_edge_index_2 = []
        batch_list = []
        num_dis = 0
        metric_loss_list = []
        for i in range(len(data)):
            # gc.collect()
            polygonID = int(data[i].split("_")[-1])
            # print("polygonID:",polygonID)
            polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)                                                                     # 270MiB
            # plot_mesh(polygon, "")      #############################################

            # polygon を正規化する
            polygon = normalization(polygon)                                                                        # 900MiB
            
            # plot_mesh(polygon, "")      #############################################

            # print("polygon.coordinates:",polygon.coordinates)
            edge_index = polygon.edge_index + num_dis                               
            # print("polygon.edge_index:", edge_index)
            all_edge_index_1.append(edge_index[0])
            all_edge_index_2.append(edge_index[1])
            num_dis = num_dis + len(polygon.coordinates)
            # print("==========")
            minibatch_coordinates.append(polygon.coordinates)
            batch_i = torch.tensor([i]*len(polygon.coordinates))                    
            batch_list.append(batch_i)
            
        # print(minibatch_coordinates)
        minibatch.x = torch.cat(minibatch_coordinates, dim=0)
        
        # print(minibatch.x)
        # print("minibatch.x.size:", minibatch.x.size())
        edge_index_1 = torch.cat(all_edge_index_1, dim=-1)
        edge_index_2 = torch.cat(all_edge_index_2, dim=-1)
        minibatch.edge_index = torch.cat([edge_index_1.unsqueeze(0), edge_index_2.unsqueeze(0)], dim=0)
        # print(minibatch.edge_index)
        # print("minibatch.edge_index.size:", minibatch.edge_index.size())
        minibatch.batch = torch.cat(batch_list, dim=0)
        # print(minibatch.batch)
        # print("minibatch.batch.size:", minibatch.batch.size())

        # すべてのデータをGPU上に移動する
        minibatch.x = minibatch.x.to(device)
        minibatch.edge_index = minibatch.edge_index.to(device)
        minibatch.batch = minibatch.batch.to(device)
        model.to(device)

        out = model(minibatch.x, minibatch.edge_index)                                     # 333MiB
        # print("out:", out)
        # print("out:", out.size())

        for i in range(len(data)):
            # gc.collect()
            polygonID = int(data[i].split("_")[-1])
            polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)                                                             # 225MiB
            polygon = normalization(polygon)                                                                # 932MiB

            metric_loss = criterion(polygon)
            metric_loss = criterion(polygon, out[i].cpu())
            
            polygon.coordinates[0] = out[i]

            metric_loss_list.append(metric_loss)

        loss = (sum(metric_loss_list) / len(metric_loss_list))      # .requires_grad_(True)
        print("Loss:", loss)
        writer.add_scalar("Loss/train", loss, epoch)
        # print("loss:", loss)

        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            for i in range(len(data)):
                # gc.collect()
                polygonID = int(data[i].split("_")[-1])
                polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)                                                             # 225MiB
                polygon = normalization(polygon)    
                
                polygon.coordinates[0] = polygon.coordinates[0] + out[i]
                
                polygon_meshID = int(train_polygon_data_list[polygonID].meshID.split("_")[-1])
                mesh = train_mesh_data_list[polygon_meshID]

                mesh.coordinates[train_polygon_data_list[polygonID].nodeID[0]] = polygon.coordinates[0]
        
        del out
        
        loss.detach()               # 計算グラフを切り離し、不要な計算グラフが保持されることを防ぐ
        optimizer.zero_grad()

        # ステップごとに損失をログに記録
        writer.add_scalar("/mnt/logs", loss.item(), global_step=len(train_data_loader)*epoch + step)

        loss_list.append(loss.item())
    
    val_loss = sum(loss_list)/ len(loss_list)
    scheduler.step(val_loss)


def train(device):
    model.train()
    model.to(device)
    
    temp = 0
    ddd = 0
    
    for step, data in enumerate(train_data_loader):
        
        optimizer.zero_grad()
        metric_loss = 0
        dx_list = []
        
        for i in range(len(data)):

            polygonID = int(data[i].split("_")[-1])
            polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)
            polygon = normalization(polygon) 

            edge_index = polygon.edge_index               
            
            x = polygon.coordinates.clone().unsqueeze(0).to(device)
            ei = edge_index.to(device)
            out = model(x, ei)
            
            logger.info("epoch: {:04}, polygonID: {:04}".format(epoch, polygonID))
            logger.info("before")
            with torch.no_grad():
                ml = criterion(polygon)
                logger.info("")
            logger.info("after")
            l = criterion(polygon, out[0].cpu())
            
            logger.info("")
            
            metric_loss += l
            dx_list.append(out[0].cpu())
                
            #prediction_list[polygonID].append(out[0].cpu())
            
        loss = metric_loss/len(data)
        ddd += len(data)
        temp += loss
        print("    Loss:", loss.item(), polygonID)
        
        loss.backward()
        optimizer.step()
        
        #if epoch>10:
        for i in range(len(data)):

            with torch.no_grad():
                polygonID = int(data[i].split("_")[-1])

                polygon = data_getter(polygonID, 0, train_mesh_data_list, train_polygon_data_list)
                polygon = normalization(polygon) 

                polygon.coordinates[0] = polygon.coordinates[0] + dx_list[i]
                polygon = denormalization(polygon)
                
                polygon_meshID = int(train_polygon_data_list[polygonID].meshID.split("_")[-1])
                mesh = train_mesh_data_list[polygon_meshID]
                mesh.coordinates[train_polygon_data_list[polygonID].nodeID[0]] = polygon.coordinates[0]

            #print()
    writer.add_scalar("loss", temp/ddd, epoch)        
    loss_list.append(temp/ddd)

# 最終的な最適化したメッシュを生成してvtkファイルで出力する

In [29]:
# TensorBoard用のログディレクトリを指定
writer = SummaryWriter(log_dir="/mnt/logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
epoch = 0

import torch
torch.autograd.set_detect_anomaly(True)

for j in range(num_train_mesh):
        save_mesh(train_mesh_data_list[j], f"{epoch}_{j}")

for epoch in tqdm(range(num_train_epoch)):
    #print("epoch:", epoch)
    train(device)

    # for i in range(num_train_mesh):
    #     plot_mesh(train_mesh_data_list[i], f"{i}")
    for j in range(num_train_mesh):
        save_mesh(train_mesh_data_list[j], f"{epoch+1}_{j}")
    
writer.close()



  0%|          | 0/20 [00:00<?, ?it/s]

    Loss: 0.04175681248307228 143
    Loss: 0.04238526523113251 333
    Loss: 0.038668349385261536 1126
    Loss: 0.040344804525375366 1101
    Loss: 0.041604939848184586 1303
    Loss: 0.04191674664616585 1424
    Loss: 0.03896954283118248 1007
    Loss: 0.03568774461746216 758


  5%|▌         | 1/20 [01:18<24:56, 78.75s/it]

    Loss: 0.0428156740963459 873
    Loss: 0.04394881799817085 896
    Loss: 0.039640843868255615 668
    Loss: 0.04190824553370476 827
    Loss: 0.04192792996764183 1092
    Loss: 0.04024863988161087 94
    Loss: 0.0423746258020401 466
    Loss: 0.04171696677803993 327


 10%|█         | 2/20 [02:37<23:36, 78.71s/it]

    Loss: 0.044323358684778214 917
    Loss: 0.042368095368146896 1436
    Loss: 0.04675270989537239 757
    Loss: 0.04548662528395653 752
    Loss: 0.04682115837931633 1051
    Loss: 0.044623490422964096 114
    Loss: 0.04044599086046219 614
    Loss: 0.0493045449256897 1025


 15%|█▌        | 3/20 [03:55<22:11, 78.33s/it]

    Loss: 0.046055883169174194 175
    Loss: 0.04889382794499397 200
    Loss: 0.046966105699539185 1024
    Loss: 0.04705004021525383 697
    Loss: 0.04711902141571045 802
    Loss: 0.046977534890174866 1483
    Loss: 0.04646129533648491 457
    Loss: 0.05063688009977341 1362


 20%|██        | 4/20 [05:13<20:50, 78.18s/it]

    Loss: 0.04574117437005043 1395
    Loss: 0.050673361867666245 326
    Loss: 0.04631051793694496 385
    Loss: 0.05341792106628418 1259
    Loss: 0.051485057920217514 1464
    Loss: 0.04865396395325661 351
    Loss: 0.0474836491048336 61
    Loss: 0.04818778112530708 870


 25%|██▌       | 5/20 [06:30<19:25, 77.69s/it]

    Loss: 0.0475563108921051 932
    Loss: 0.04984183982014656 1434
    Loss: 0.049684058874845505 983
    Loss: 0.05084561929106712 213
    Loss: 0.049634456634521484 465
    Loss: 0.04864880070090294 927
    Loss: 0.0465807281434536 712
    Loss: 0.05103706195950508 1151


 30%|███       | 6/20 [07:48<18:12, 78.01s/it]

    Loss: 0.051886752247810364 527
    Loss: 0.048116061836481094 612
    Loss: 0.04457961022853851 191
    Loss: 0.04859668388962746 297
    Loss: 0.0474238246679306 276
    Loss: 0.04698377847671509 536
    Loss: 0.045717377215623856 1465
    Loss: 0.04395141825079918 22


 35%|███▌      | 7/20 [09:07<16:57, 78.25s/it]

    Loss: 0.04607630893588066 846
    Loss: 0.04580572247505188 1020
    Loss: 0.04288668930530548 506
    Loss: 0.04519975557923317 789
    Loss: 0.0411691851913929 346
    Loss: 0.0437016487121582 1415
    Loss: 0.0420299731194973 900
    Loss: 0.04062485694885254 1163


 40%|████      | 8/20 [10:24<15:34, 77.91s/it]

    Loss: 0.04233631491661072 1103
    Loss: 0.038255758583545685 1093
    Loss: 0.037721071392297745 873
    Loss: 0.03603743389248848 343
    Loss: 0.03839532285928726 707
    Loss: 0.039037831127643585 1058
    Loss: 0.03541308268904686 412
    Loss: 0.038366351276636124 1379


 45%|████▌     | 9/20 [11:45<14:26, 78.74s/it]

    Loss: 0.035675615072250366 975
    Loss: 0.03419075533747673 302
    Loss: 0.034629445523023605 901
    Loss: 0.032403480261564255 1350
    Loss: 0.03371826186776161 293
    Loss: 0.031342171132564545 596
    Loss: 0.031205808743834496 252
    Loss: 0.03118315525352955 1216


 50%|█████     | 10/20 [13:06<13:15, 79.58s/it]

    Loss: 0.03048202581703663 377
    Loss: 0.02822781912982464 1317
    Loss: 0.031425464898347855 478
    Loss: 0.02820245362818241 736
    Loss: 0.02983812429010868 667
    Loss: 0.027293389663100243 1277
    Loss: 0.026985546573996544 119
    Loss: 0.02702704444527626 832


 55%|█████▌    | 11/20 [14:23<11:48, 78.69s/it]

    Loss: 0.02806713432073593 60
    Loss: 0.025075459852814674 696
    Loss: 0.026688462123274803 376
    Loss: 0.025178654119372368 427
    Loss: 0.02354380488395691 725
    Loss: 0.027005383744835854 1449
    Loss: 0.02584058605134487 1036
    Loss: 0.02355206198990345 326


 60%|██████    | 12/20 [15:38<10:20, 77.52s/it]

    Loss: 0.023981742560863495 486
    Loss: 0.024443626403808594 1400
    Loss: 0.022984229028224945 94
    Loss: 0.02535443939268589 892
    Loss: 0.02597607858479023 653
    Loss: 0.02221301943063736 767
    Loss: 0.02350631356239319 1144
    Loss: 0.021951740607619286 1476


 65%|██████▌   | 13/20 [16:58<09:09, 78.52s/it]

    Loss: 0.021773425862193108 1420
    Loss: 0.022963574156165123 1175
    Loss: 0.02383040077984333 1351
    Loss: 0.021677039563655853 1261
    Loss: 0.02522192895412445 857
    Loss: 0.02214350365102291 583
    Loss: 0.020491937175393105 480
    Loss: 0.02355988137423992 1326


 70%|███████   | 14/20 [18:15<07:46, 77.83s/it]

    Loss: 0.023366563022136688 456
    Loss: 0.020729107782244682 1021
    Loss: 0.021366924047470093 1095
    Loss: 0.022078434005379677 202
    Loss: 0.023478815332055092 419
    Loss: 0.02134416252374649 420
    Loss: 0.022082896903157234 83
    Loss: 0.02119002863764763 1256


 75%|███████▌  | 15/20 [19:37<06:36, 79.22s/it]

    Loss: 0.02158377878367901 107
    Loss: 0.0221440177410841 612
    Loss: 0.021294469013810158 227
    Loss: 0.023193947970867157 688
    Loss: 0.02094435691833496 840
    Loss: 0.02127321995794773 511
    Loss: 0.021403148770332336 769
    Loss: 0.0212516151368618 1465


 80%|████████  | 16/20 [20:54<05:14, 78.61s/it]

    Loss: 0.01928972266614437 37
    Loss: 0.02184154838323593 1222
    Loss: 0.020813291892409325 1220
    Loss: 0.022160368040204048 211
    Loss: 0.024582235142588615 123
    Loss: 0.02205674909055233 188
    Loss: 0.019367888569831848 250
    Loss: 0.021928558126091957 1264


 85%|████████▌ | 17/20 [22:12<03:55, 78.41s/it]

    Loss: 0.020331883803009987 919
    Loss: 0.021874485537409782 834
    Loss: 0.019039006903767586 284
    Loss: 0.020743979141116142 1268
    Loss: 0.02180296741425991 1161
    Loss: 0.020646555349230766 897
    Loss: 0.02406133897602558 1082
    Loss: 0.023127567023038864 1371


 90%|█████████ | 18/20 [23:30<02:36, 78.10s/it]

    Loss: 0.019980700686573982 1355
    Loss: 0.020837604999542236 420
    Loss: 0.021363260224461555 811
    Loss: 0.020654087886214256 1396
    Loss: 0.022064028307795525 272
    Loss: 0.022814972326159477 1417
    Loss: 0.022364402189850807 186
    Loss: 0.02093079313635826 925


 95%|█████████▌| 19/20 [24:55<01:20, 80.16s/it]

    Loss: 0.021123258396983147 787
    Loss: 0.02229459397494793 637
    Loss: 0.02010589838027954 1168
    Loss: 0.02162044681608677 1105
    Loss: 0.02133283019065857 1468
    Loss: 0.02094140462577343 222
    Loss: 0.02171563357114792 564
    Loss: 0.022077005356550217 452


100%|██████████| 20/20 [26:14<00:00, 78.70s/it]


In [None]:
polygon = data_getter(26, 0, train_mesh_data_list, train_polygon_data_list)
polygon.faces

In [None]:
polygon.coordinates

In [None]:
polygon_ = normalization(polygon)
polygon_.coordinates +torch.tensor([-0.08619694411754608, 0.08077329397201538])

In [None]:
plt.scatter(polygon.coordinates.T[0], polygon.coordinates.T[1])

In [None]:
for i in range(num_train_mesh):
    plot_mesh(train_mesh_data_list[i], f"{i}")

In [None]:
def test(device, trial, test_mesh_data_lists):
    model.eval()
    for step, data in enumerate(test_data_loader):
        empty_cache()
        gc.collect()
        # print(f"Step {step + 1}:")
        # print("==========")
        # print(data)
        # print(len(data))
        minibatch = Minibatch()
        minibatch_coordinates = []
        all_edge_index_1 = []
        all_edge_index_2 = []
        batch_list = []
        num_dis = 0
        metric_loss_list = []
        for i in range(len(data)):
            # gc.collect()
            polygonID = int(data[i].split("_")[-1])
            # print("polygonID:",polygonID)
            polygon = data_getter(polygonID, trial, test_mesh_data_list, test_polygon_data_list)
            # polygon を正規化する
            polygon = normalization(polygon)
            # print("polygon.coordinates:",polygon.coordinates)
            edge_index = polygon.edge_index + num_dis
            # print("polygon.edge_index:", edge_index)
            all_edge_index_1.append(edge_index[0])
            all_edge_index_2.append(edge_index[1])
            num_dis = num_dis + len(polygon.coordinates)
            # print("==========")
            minibatch_coordinates.append(polygon.coordinates)
            batch_i = torch.tensor([i]*len(polygon.coordinates))
            batch_list.append(batch_i)
            
        # print(minibatch_coordinates)
        minibatch.x = torch.cat(minibatch_coordinates, dim=0)
        
        # print(minibatch.x)
        # print("minibatch.x.size:", minibatch.x.size())
        edge_index_1 = torch.cat(all_edge_index_1, dim=-1)
        edge_index_2 = torch.cat(all_edge_index_2, dim=-1)
        minibatch.edge_index = torch.cat([edge_index_1.unsqueeze(0), edge_index_2.unsqueeze(0)], dim=0)
        # print(minibatch.edge_index)
        # print("minibatch.edge_index.size:", minibatch.edge_index.size())
        minibatch.batch = torch.cat(batch_list, dim=0)
        # print(minibatch.batch)
        # print("minibatch.batch.size:", minibatch.batch.size())

        # すべてのデータをGPU上に移動する
        minibatch.x = minibatch.x.to(device)
        minibatch.edge_index = minibatch.edge_index.to(device)
        minibatch.batch = minibatch.batch.to(device)
        model.to(device)

        out = model(minibatch.x, minibatch.edge_index, minibatch.batch)
        # print("out:", out)
        # print("out:", out.size())

        for i in range(len(data)):
            # gc.collect()
            polygonID = int(data[i].split("_")[-1])
            polygon = data_getter(polygonID, trial, test_mesh_data_list, test_polygon_data_list)
            # 正規化する
            polygon = normalization(polygon)

            polygon.coordinates[0] = out[i]
            
            # print("out_i:", out[i])
            polygon = check(polygon, polygonID)
            # print("fixed_out_i:", polygon.coordinates[0])
            # metric_loss = criterion(polygon)
            # metric_loss_list.append(metric_loss)
            
            # 非正規化する
            polygon = denormalization(polygon)

            # 予測したノードの座標をもとのメッシュに当てはめて更新する
            polygon_meshID = int(test_polygon_data_list[polygonID].meshID.split("_")[-1])
            mesh = test_mesh_data_lists[trial][polygon_meshID]

            mesh.coordinates[test_polygon_data_list[polygonID].nodeID[0]] = polygon.coordinates[0]
        
    return test_mesh_data_lists









            

In [None]:
# 100epochで最適化されたメッシュを10個生成する
test_mesh_data_lists = [copy.deepcopy(test_mesh_data_list) for _ in range(10)]

for trial in range(num_trial):
    for epoch in tqdm(range(num_test_epoch)):
        test_mesh_data_lists = test(device, trial, test_mesh_data_lists)


In [None]:
best_mesh_data_list = []
num_test_mesh = len(test_vtk_files)
    
for i in tqdm(range(num_test_mesh)):
    q_hat_list = []
    for j in range(10):
        mesh = test_mesh_data_lists[j][i]

        # q_hat を求めるコード
        q_hat = calculate_q_hat(mesh)
        
        q_hat_list.append(q_hat)
        best = q_hat_list.index(min(q_hat_list))
        best_mesh_data_list.append(test_mesh_data_lists[best + 1][i])
    


    # best_mesh_data_list[i] のデータをvtkファイルで出力する
    vtk_output(best_mesh_data_list[i], f"optimized_{i}")

## 更新したメッシュを表示してみる

In [None]:
for i in range(num_test_mesh):
    plot_mesh(best_mesh_data_list[i], f"optimized_{i}")

# 元のメッシュから座標が変わっているか確認する

In [None]:
格子点が滑らかに配置されているか、急激な変化や不連続がないかを評価します。# for i in range(num_test_mesh):
#     print(test_mesh_data_list[i].coordinates.size())
#     print(best_mesh_data_list[i].coordinates.size())


In [None]:
for i in range(num_test_mesh):
    if torch.equal(test_mesh_data_list[i].coordinates, best_mesh_data_list[i].coordinates):
        print("Not updated!!")
    
    else:
         print("Updated")

In [None]:
for i in range(num_test_mesh):
    plot_mesh(test_mesh_data_list[i], f"original_{i}")

In [None]:
a=torch.tensor([[[ 0.6004, -0.1381,  0.5997,  0.2204, -0.5977,  0.4884, -0.1954, -0.2755],
        [-0.5785, -0.7031,  0.2004, -0.1510,  0.2774, -0.5800,  0.5260, -0.5150],
        [-0.1189,  0.1473,  0.3666,  0.5725,  0.6517, -0.5593,  0.1796, -0.3000],
        [-0.0748, -0.5303,  0.6381, -0.5197,  0.3724,  0.2501,  0.2280, -0.3767]],
        [[0.6353,  0.1526,  0.1002, -0.6209,  0.2963, -0.1024, -0.3281,  0.6075],
        [0.1588, -0.3860, -0.3635, -0.0385,  0.3957, -0.1853, -0.4013, -0.2480],
        [-0.5284,  0.2546,  0.5519, -0.6643,  0.1631,  0.3607,  0.1233, -0.2522],
        [0.3676,  0.3681,  0.2616, -0.1182, -0.1916,  0.0747, -0.1194, -0.2073]]])

In [None]:
a[:,0]

In [None]:
a.unsqueeze(0)

In [None]:
import torch

# 3点の座標を定義（これらは微分の対象となる変数なのでrequires_grad=Trueを設定）
p1 = torch.tensor([-0.6721, -0.4008], requires_grad=True)
p2 = torch.tensor([-0.0425,  0.0179], requires_grad=True)
p3 = torch.tensor([-0.3361, -0.1763], requires_grad=True)

# 各辺の長さを計算
a = torch.sqrt(torch.sum((p2 - p1) ** 2))
b = torch.sqrt(torch.sum((p3 - p2) ** 2))
c = torch.sqrt(torch.sum((p1 - p3) ** 2))

# ヘロンの公式を用いて面積を計算
s = (a + b + c) / 2
area = torch.sqrt(s * (s - a) * (s - b) * (s - c))

loss = 1 - (4*torch.sqrt(torch.tensor(3))*area)/(a**2 + b**2 + c**2)

# 面積に対して自動微分を行う
loss.backward()

# 各点に対する勾配を表示
print(p1.grad)
print(p2.grad)
print(p3.grad)

In [None]:
torch.exp(torch.tensor(0))

In [None]:
a = [[0.0, 0.0], 
    [-0.3022986352443695, 0.7041022777557373],
    [-0.6732527613639832, 0.49798783659935],
    [-0.6732527613639832, -0.05284641683101654],
    [-0.3600771427154541, -0.19979676604270935],
    [0.1442829966545105, -0.10576343536376953],
    [0.32674723863601685, 0.25243768095970154]]

for i in range(1, len(a)):
    x = [a[0][0], a[i][0]]
    y = [a[0][1], a[i][1]]
    plt.plot(x, y, c="k")
    
for i in range(1, len(a)):
    x = [a[i][0], a[i%6+1][0]]
    y = [a[i][1], a[i%6+1][1]]
    plt.plot(x, y, c="k")
    
plt.scatter([x[0] for x in a], [x[1] for x in a], c="r")

In [None]:
a = [[-0.08619694411754608, 0.08077329397201538], 
    [-0.3022986352443695, 0.7041022777557373],
    [-0.6732527613639832, 0.49798783659935],
    [-0.6732527613639832, -0.05284641683101654],
    [-0.3600771427154541, -0.19979676604270935],
    [0.1442829966545105, -0.10576343536376953],
    [0.32674723863601685, 0.25243768095970154]]

for i in range(1, len(a)):
    x = [a[0][0], a[i][0]]
    y = [a[0][1], a[i][1]]
    plt.plot(x, y, c="k")
    
for i in range(1, len(a)):
    x = [a[i][0], a[i%6+1][0]]
    y = [a[i][1], a[i%6+1][1]]
    plt.plot(x, y, c="k")
    
plt.scatter([x[0] for x in a], [x[1] for x in a], c="r")