In [None]:
import torch
from torch_geometric.nn import GraphSAGE
# from Meshpoolscr import MeshPooling
# def cnntogra(pos,batch,datatomap):
#     imageindex=batch
#     posindexT=torch.round(pos/20*(datatomap.shape[-1]-1)).long()
#     posindex=torch.stack((posindexT[:,1],posindexT[:,0]),dim=-1)
#     # Select the right "image" using imageindex
#     selected_images = datatomap[imageindex]
#     # Now use posindex to index into the last two dimensions
#     result = selected_images[torch.arange(selected_images.shape[0]), :, posindex[:, 0], posindex[:, 1]]
#     return result

def cnntogra(pos, batch, datatomap, domain_size=20):
    """Sample per-node features from a batch of CNN feature maps.

    Every node of a batched graph is mapped to the nearest pixel of the image
    that belongs to its sub-graph, and the channel vector at that pixel is
    returned.

    Args:
        pos: node coordinates, shape [node_number, 2], ordered (x, y).
            ``pos[:, 0]`` is mapped onto the width (last) axis and ``pos[:, 1]``
            onto the height axis; coordinates in [0, domain_size] map onto
            pixel indices [0, width - 1] / [0, height - 1].
        batch: sub-graph index of every node, shape [node_number]; used
            directly as the image index into ``datatomap``.
        datatomap: CNN data, shape [batch_size, channels, height, width].
        domain_size: physical extent covered by the image (default 20, the
            value that used to be hard-coded).

    Returns:
        Tensor of shape [node_number, channels].
    """
    height, width = datatomap.shape[-2], datatomap.shape[-1]
    # Scale each axis by its own pixel count so non-square maps work; for
    # square maps this is exactly the previous width-only scaling.
    col = torch.round(pos[:, 0] / domain_size * (width - 1)).long()
    row = torch.round(pos[:, 1] / domain_size * (height - 1)).long()
    # Advanced indices separated by a slice put the node axis first -> [N, C].
    return datatomap[batch, :, row, col]


def gratocnn(resultfromgnn, pos, batch, oldcnn, domain_size=20, fill_value=-10):
    """Scatter per-node GNN outputs back onto an image grid.

    A canvas with the batch size and spatial size of ``oldcnn`` and
    ``resultfromgnn.shape[-1]`` channels is filled with ``fill_value``. Each
    node's feature vector is then written to the pixel nearest to its position
    in the image of its sub-graph. When several nodes land on the same pixel,
    which one wins is unspecified (plain index assignment).

    Args:
        resultfromgnn: node features, shape [node_number, channels].
        pos: node coordinates, shape [node_number, 2], ordered (x, y); x maps to
            the width axis, y to the height axis.
        batch: sub-graph / image index of every node, shape [node_number].
        oldcnn: reference tensor [batch_size, *, height, width]; only its shape
            and device are used.
        domain_size: physical extent covered by the image (default 20).
        fill_value: value of pixels that receive no node (default -10).

    Returns:
        Float32 tensor [batch_size, channels, height, width] on ``oldcnn``'s device.
    """
    num_images, height, width = oldcnn.shape[0], oldcnn.shape[2], oldcnn.shape[3]
    channels = resultfromgnn.shape[-1]
    canvas = torch.full((num_images, channels, height, width), float(fill_value),
                        device=oldcnn.device)
    # Each axis is scaled by its own pixel count (identical to the old
    # width-only scaling for square images).
    col = torch.round(pos[:, 0] / domain_size * (width - 1)).long()
    row = torch.round(pos[:, 1] / domain_size * (height - 1)).long()
    canvas[batch, :, row, col] = resultfromgnn
    return canvas

def cnntogra_pad(pos_nopad,batch,datatomap):
    """Sample per-node features from CNN maps that were built on a padded domain.

    Node coordinates are shifted by 20/8 on both axes and scaled by the padded
    extent 20*5/4 onto the pixel grid. This is consistent with a 20-wide domain
    padded by 20/8 on each side. The channel vector at the nearest pixel of the
    node's image is then returned.

    Args:
        pos_nopad: node coordinates in the unpadded frame, shape [node_number, 2],
            ordered (x, y); x maps to the width axis, y to the height axis.
        batch: sub-graph / image index of every node, shape [node_number].
        datatomap: CNN data, shape [batch_size, channels, height, width].

    Returns:
        Tensor of shape [node_number, channels].
    """
    # Build the offset on pos' device: a CPU 2-element tensor cannot be added
    # to a CUDA tensor.
    pos = pos_nopad + torch.tensor([20 / 8, 20 / 8], device=pos_nopad.device)
    height, width = datatomap.shape[-2], datatomap.shape[-1]
    padded_extent = 20 * 5 / 4
    col = torch.round(pos[:, 0] / padded_extent * (width - 1)).long()
    row = torch.round(pos[:, 1] / padded_extent * (height - 1)).long()
    # Index pixels directly; the old datatomap[batch] first materialised a full
    # [node_number, C, H, W] copy of the images.
    return datatomap[batch, :, row, col]
def gratocnn_pad(resultfromgnn,pos_nopad,batch,oldcnn):
    """Scatter per-node GNN outputs onto the padded image grid.

    This is the inverse of ``cnntogra_pad``. Positions are shifted by 20/8 and
    scaled by the padded extent 20*5/4. A canvas with ``oldcnn``'s batch and
    spatial size and ``resultfromgnn.shape[-1]`` channels is filled with -10,
    and each node's vector is written to its nearest pixel. When several nodes
    hit the same pixel, which one wins is unspecified.

    Args:
        resultfromgnn: node features, shape [node_number, channels].
        pos_nopad: node coordinates in the unpadded frame, shape [node_number, 2], (x, y).
        batch: sub-graph / image index of every node, shape [node_number].
        oldcnn: reference tensor [batch_size, *, height, width]; only its shape
            and device are used.

    Returns:
        Float32 tensor [batch_size, channels, height, width] on ``oldcnn``'s device.
    """
    # Both tensors are created on the input devices; previously they were
    # always on CPU, which fails for CUDA inputs (gratocnn already used oldcnn.device).
    pos = pos_nopad + torch.tensor([20 / 8, 20 / 8], device=pos_nopad.device)
    num_images, height, width = oldcnn.shape[0], oldcnn.shape[2], oldcnn.shape[3]
    canvas = torch.full((num_images, resultfromgnn.shape[-1], height, width), -10.0,
                        device=oldcnn.device)
    padded_extent = 20 * 5 / 4
    col = torch.round(pos[:, 0] / padded_extent * (width - 1)).long()
    row = torch.round(pos[:, 1] / padded_extent * (height - 1)).long()
    canvas[batch, :, row, col] = resultfromgnn
    return canvas

In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plotedges(pos,edge_index,figindex):
    """Draw every graph edge as a thin blue segment and save it as an SVG.

    The figure is written to ``str(figindex) + 'boundpoo.svg'`` in the current
    working directory. The shapes of ``pos`` and ``edge_index`` are printed
    first.

    Args:
        pos: node coordinates, tensor [node_number, >=2]; only the first two
            columns are drawn.
        edge_index: tensor [2, E] of (source, target) node indices.
        figindex: prefix for the output file name.
    """
    from matplotlib.collections import LineCollection

    fig = plt.figure()
    ax = fig.gca()
    coords = np.array(pos.cpu().detach())
    print(pos.shape)
    print(edge_index.shape)
    edges = np.array(edge_index.cpu().detach()).T
    # One LineCollection instead of one plt.plot call (one Line2D artist) per
    # edge, which is very slow for meshes with many edges.
    segments = coords[edges][:, :, :2]
    ax.add_collection(LineCollection(segments, linewidths=0.1, colors='b', alpha=0.5))
    ax.autoscale_view()  # add_collection does not rescale the view by itself
    ax.axis('equal')
    fig.savefig(str(figindex) + 'boundpoo.svg')
    plt.close(fig)


In [None]:

import torch
import torch.nn.functional as F
from torch_scatter import scatter_add

class DenseGraphSAGELayerV3(torch.nn.Module):
    """Graph layer combining a self term with aggregated neighbour differences.

    For every edge ``src -> dst`` the message is ``W_n (x[src] - x[dst])``.
    Messages are summed (or averaged when ``aggregation='mean'``) at ``dst``
    and added to ``W_s x[dst]`` plus an optional bias::

        out_i = W_s x_i + AGG_{j -> i} W_n (x_j - x_i) + b

    Despite the name, the layer works on a sparse ``edge_index``, not a dense
    adjacency matrix.

    Args:
        in_channels: size of each input node feature.
        out_channels: size of each output node feature.
        aggregation: ``'sum'`` or ``'mean'``.
        bias: whether to add a learnable bias (initialised to zero).

    Raises:
        ValueError: if ``aggregation`` is not ``'sum'`` or ``'mean'``. Any
            other value used to behave silently like ``'sum'``.
    """

    _AGGREGATIONS = ('sum', 'mean')

    def __init__(self, in_channels, out_channels, aggregation='sum', bias=True):
        super(DenseGraphSAGELayerV3, self).__init__()
        if aggregation not in self._AGGREGATIONS:
            raise ValueError(
                f"aggregation must be one of {self._AGGREGATIONS}, got {aggregation!r}")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggregation = aggregation

        self.weight_self = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.weight_neighbor = torch.nn.Linear(in_channels, out_channels, bias=False)

        if bias:
            # Uninitialised storage; reset_parameters() zeroes it.
            self.bias = torch.nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for both weight matrices, zeros for the bias."""
        torch.nn.init.xavier_uniform_(self.weight_self.weight)
        torch.nn.init.xavier_uniform_(self.weight_neighbor.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node feature matrix, shape (N, in_channels).
            edge_index: edge list, shape (2, E); row 0 = source, row 1 = target.

        Returns:
            Tensor of shape (N, out_channels).
        """
        src, dst = edge_index
        num_nodes = x.size(0)

        out_self = self.weight_self(x)  # (N, out_channels)

        # Transformed feature difference along each edge: (E, out_channels).
        messages = self.weight_neighbor(x[src] - x[dst])

        # Sum messages at their target node; nodes without incoming edges get 0.
        out_neighbors = messages.new_zeros((num_nodes, messages.size(-1))).index_add(0, dst, messages)

        if self.aggregation == 'mean':
            deg = messages.new_zeros(num_nodes).index_add(0, dst, messages.new_ones(dst.numel()))
            # Isolated nodes have degree 0; use 1 to avoid dividing by zero.
            out_neighbors = out_neighbors / deg.clamp(min=1).unsqueeze(-1)

        out = out_self + out_neighbors

        if self.bias is not None:
            out = out + self.bias

        return out



class GraphSAGEV3(torch.nn.Module):
    """Stack of ``DenseGraphSAGELayerV3`` layers with ReLU + dropout in between.

    Layer widths are ``in -> hidden -> ... -> hidden -> out`` for
    ``num_layers >= 2`` and ``in -> out`` for ``num_layers == 1``. No
    activation or dropout follows the last layer.

    Args:
        in_channels: input node feature size.
        hidden_channels: width of the hidden layers.
        num_layers: number of graph layers (must be >= 1).
        out_channels: output node feature size.
        dropout: dropout probability applied after each hidden ReLU.
        aggregation: neighbour aggregation passed to every layer.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, out_channels, dropout=0.0, aggregation='sum'):
        super(GraphSAGEV3, self).__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.layers = torch.nn.ModuleList()
        self.num_layers = num_layers
        self.dropout = torch.nn.Dropout(p=dropout)

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        if num_layers == 1:
            # Previously an extra hidden->out layer was built but never run, so
            # the output had hidden_channels features instead of out_channels.
            self.layers.append(DenseGraphSAGELayerV3(in_channels, out_channels, aggregation))
            return

        # Input layer
        self.layers.append(DenseGraphSAGELayerV3(in_channels, hidden_channels, aggregation))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(DenseGraphSAGELayerV3(hidden_channels, hidden_channels, aggregation))
        # Output layer
        self.layers.append(DenseGraphSAGELayerV3(hidden_channels, out_channels, aggregation))

    def forward(self, x, adj, mask=None):
        """Run all layers on node features ``x``.

        ``adj`` is passed unchanged to each layer as its ``edge_index`` (2, E).
        ``mask`` is accepted for API compatibility and ignored.
        """
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x, adj)
            if i < last:
                x = self.dropout(torch.nn.functional.relu(x))
        return x

    def __repr__(self):
        return '{}(in_channels={}, hidden_channels={}, num_layers={}, out_channels={}, dropout={})'.format(
            self.__class__.__name__, self.in_channels, self.hidden_channels, self.num_layers, self.out_channels, self.dropout.p)



In [None]:
from Stressnet import Modfiedunet4shrink,Modfiedunet_CMAME,Modfiedunet3shrink_CMAME, Modfiedunet3shrink,StressNetori,StressNetgit,StressNetgit4shrink,Modfiedunet
from torch_geometric.nn import GraphSAGE
import torch.nn as nn


class CNNadgnn(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Hybrid model: a U-Net on the full image input feeding a GraphSAGE head.

    The output of ``Modfiedunet4shrink`` is sampled at the node positions with
    ``cnntogra``. The samples are concatenated with node features
    (``x[:, 0:2]``, ``x[:, -3:]``, ``pos``) and passed to a 4-layer GraphSAGE
    (8 inputs, 128 hidden, 1 output). ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink()
        self.g1 = GraphSAGE(in_channels=8, hidden_channels=128,
                            num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)`` for the batched graph ``igra``."""
        images = igra.xdata128.float()
        edges = igra.edge_index
        node_pos = igra.pos.float()
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        graph_id = igra.batch

        raw = igra.x
        node_feat = torch.cat((raw[:, 0:2], raw[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(images)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        prediction = self.g1(torch.cat((node_feat, sampled), dim=-1), edges)
        return cnn_map, prediction
    
class CNNadgnn_cnn1ch(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """U-Net + GraphSAGE variant that feeds only image channel 0 to the CNN.

    Channel 0 of ``igra.xdata128`` is tiled to 6 channels before the default
    ``Modfiedunet4shrink``. Presumably 6 matches that network's default input
    width; TODO confirm. The CNN output is sampled at node positions and
    concatenated with node features for a 4-layer GraphSAGE.

    NOTE(review): another class with this exact name is defined later in the
    same file and replaces this one when the cell runs, so this definition is
    effectively unused unless one of them is renamed.
    """
    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1=Modfiedunet4shrink()
        self.g1=GraphSAGE(in_channels=int(8),hidden_channels=int(128), 
                                num_layers=4, out_channels=1)
    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``.

        The debug prints of tensor shapes and of the CNN parameter count were
        removed; they ran on every forward pass.
        """
        x=igra.xdata128.float()
        first_channel = x[:, 0:1, :, :]
        x= first_channel.repeat(1, 6, 1, 1)
        gx,edge_index,posnew,pollinfor,batchinfo=igra.x,igra.edge_index,igra.pos.float(),igra.pollinfor,igra.batch
        gx=torch.cat((igra.x[:,0:2],igra.x[:,-3:],posnew),dim=-1).float()
        x1=self.c1(x)
        cnng=cnntogra(posnew,batchinfo,x1)
        gout=self.g1(torch.cat((gx,cnng),dim=-1),edge_index)
        return x1,gout

import torch
import torch.nn as nn
import torchvision.models.segmentation as segmentation
import torchvision.transforms.functional as TF

class CNNadgnn_cnn1ch(nn.Module):
    """LR-ASPP MobileNetV3 segmentation backbone followed by a GraphSAGE head.

    Channel 0 of ``igra.xdata128`` is replicated to 3 channels, resized to
    224x224 and run through ``lraspp_mobilenet_v3_large(num_classes=1)``. The
    ``'out'`` map is resized back to the input resolution, sampled at node
    positions and concatenated with node features (``x[:, 0:2]``,
    ``x[:, -3:]``, ``pos``) for a 4-layer GraphSAGE (8 -> 128 -> 1).
    ``chlist`` is accepted but unused.
    """

    def __init__(self, chlist=[8, 32, 64, 128, 128, 128, 128, 128, 64, 32, 9]):
        super().__init__()
        self.segmentation_model = segmentation.lraspp_mobilenet_v3_large(num_classes=1)
        self.segmentation_model.train()  # explicitly put the backbone in training mode
        self.g1 = GraphSAGE(in_channels=8, hidden_channels=128, num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(resized_segmentation_map, node_prediction)``."""
        images = igra.xdata128.float()
        input_hw = images.shape[-2:]

        three_channel = images[:, 0:1, :, :].repeat(1, 3, 1, 1)
        seg_small = self.segmentation_model(TF.resize(three_channel, size=(224, 224)))['out']
        seg_map = TF.resize(seg_small, size=input_hw)

        node_pos = igra.pos.float()
        edges = igra.edge_index
        graph_id = igra.batch
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        sampled = cnntogra(node_pos, graph_id, seg_map)
        prediction = self.g1(torch.cat((node_feat, sampled), dim=-1), edges)
        return seg_map, prediction



class CNNadgnn_cnn1ch_cnnonly(nn.Module):
    """LR-ASPP MobileNetV3 backbone with a single linear node head (no message passing).

    The image path matches ``CNNadgnn_cnn1ch``. Channel 0 is replicated to 3
    channels, resized to 224x224, segmented, and resized back. Per-node
    samples plus node features (``x[:, 0:2]``, ``x[:, -3:]``, ``pos``) go
    through ``nn.Linear(8, 1)``, so graph edges are ignored. ``chlist`` is
    accepted but unused.
    """

    def __init__(self, chlist=[8, 32, 64, 128, 128, 128, 128, 128, 64, 32, 9]):
        super().__init__()
        self.segmentation_model = segmentation.lraspp_mobilenet_v3_large(num_classes=1)
        self.segmentation_model.train()  # explicitly put the backbone in training mode
        self.linear = nn.Linear(8, 1)

    def forward(self, igra):
        """Return ``(resized_segmentation_map, node_prediction)``."""
        images = igra.xdata128.float()
        input_hw = images.shape[-2:]

        backbone_in = TF.resize(images[:, 0:1, :, :].repeat(1, 3, 1, 1), size=(224, 224))
        seg_map = TF.resize(self.segmentation_model(backbone_in)['out'], size=input_hw)

        _unused = (igra.edge_index, igra.pollinfor)  # read but not used by this variant
        node_pos = igra.pos.float()
        graph_id = igra.batch
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        sampled = cnntogra(node_pos, graph_id, seg_map)
        prediction = self.linear(torch.cat((node_feat, sampled), dim=-1))
        return seg_map, prediction




# Default compute device: CUDA when available, otherwise CPU (a hard-coded
# 'cuda' makes any `.to(device)` fail on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class CNNadgnn_cnn1chv2(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Single-input-channel U-Net + GraphSAGE hybrid.

    Only channel 0 of ``igra.xdata128`` enters ``Modfiedunet4shrink(chin=1)``.
    Its output is sampled at node positions, concatenated with node features
    (``x[:, 0:2]``, ``x[:, -3:]``, ``pos``) and passed to a 4-layer GraphSAGE
    (8 -> 128 -> 1). ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1)
        self.g1 = GraphSAGE(in_channels=8, hidden_channels=128,
                            num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        node_pos = igra.pos.float()
        edges = igra.edge_index
        graph_id = igra.batch
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        return cnn_map, self.g1(torch.cat((node_feat, sampled), dim=-1), edges)
 
class CNNadgnnIdent(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """U-Net + point-wise MLP baseline that ignores graph connectivity.

    Channel 0 of ``igra.xdata128`` goes through ``Modfiedunet4shrink(chin=1)``.
    The sampled CNN values plus node features (``x[:, 0:2]``, ``x[:, -3:]``,
    ``pos``) are mapped per node by an MLP 8 -> 128 -> 128 -> 128 -> 128 -> 1
    with LeakyReLU between layers. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1)
        widths = [8, 128, 128, 128, 128, 1]
        stack = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(w_in, w_out), nn.LeakyReLU()])
        stack.pop()  # no activation after the output layer
        self.mlp = nn.Sequential(*stack)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        _unused = (igra.edge_index, igra.pollinfor)  # read but not used by this variant
        node_pos = igra.pos.float()
        graph_id = igra.batch
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        return cnn_map, self.mlp(torch.cat((node_feat, sampled), dim=-1))

class CNNadgnnIdent_GNN(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Single-channel U-Net + 4-layer GraphSAGE (8 -> 128 -> 1).

    Channel 0 of ``igra.xdata128`` feeds ``Modfiedunet4shrink(chin=1)``. Node
    samples of its output plus ``x[:, 0:2]``, ``x[:, -3:]`` and ``pos`` are the
    GraphSAGE input. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1)
        self.g1 = GraphSAGE(in_channels=8, hidden_channels=128,
                            num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        graph_id = igra.batch
        node_pos = igra.pos.float()
        edges = igra.edge_index
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        features = igra.x
        node_feat = torch.cat((features[:, 0:2], features[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(igra.xdata128.float()[:, 0:1, :, :])
        gnn_in = torch.cat((node_feat, cnntogra(node_pos, graph_id, cnn_map)), dim=-1)
        return cnn_map, self.g1(gnn_in, edges)


class CNNadgnnIdent_SA_GNN(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Single-channel U-Net + custom ``GraphSAGEV3`` head (8 -> 128 -> 1, 4 layers).

    Identical data flow to ``CNNadgnnIdent_GNN``, but the graph part uses the
    in-file ``GraphSAGEV3`` (difference-based messages) instead of PyG's
    ``GraphSAGE``. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1)
        self.g1 = GraphSAGEV3(in_channels=8, hidden_channels=128,
                              num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        edges = igra.edge_index
        node_pos = igra.pos.float()
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        graph_id = igra.batch
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        prediction = self.g1(torch.cat((node_feat, sampled), dim=-1), edges)
        return cnn_map, prediction


class CNNadgnnIdent_small(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Reduced-size U-Net + point-wise MLP baseline (no message passing).

    Uses ``Modfiedunet4shrink(chin=1, k=4)`` and an MLP 8 -> 32 -> 32 -> 32 ->
    32 -> 1 with LeakyReLU between layers. Otherwise the data flow matches
    ``CNNadgnnIdent``. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1, k=4)
        # Alternatives tried: Modfiedunet_CMAME(chin=1), Modfiedunet3shrink_CMAME(chin=1, k=8)
        hidden = 32
        widths = [8] + [hidden] * 4 + [1]
        stack = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(w_in, w_out), nn.LeakyReLU()])
        stack.pop()  # no activation after the output layer
        self.mlp = nn.Sequential(*stack)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        _unused = (igra.edge_index, igra.pollinfor)  # read but not used by this variant
        node_pos = igra.pos.float()
        graph_id = igra.batch
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        return cnn_map, self.mlp(torch.cat((node_feat, sampled), dim=-1))

class CNNadgnnIdent_GNN_small(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Reduced-size U-Net (k=4) + 4-layer GraphSAGE with 32 hidden channels.

    Same data flow as ``CNNadgnnIdent_GNN``. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1, k=4)
        # Alternatives tried: Modfiedunet_CMAME(chin=1), Modfiedunet3shrink_CMAME(chin=1, k=8)
        self.g1 = GraphSAGE(in_channels=8, hidden_channels=32,
                            num_layers=4, out_channels=1)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        node_pos = igra.pos.float()
        graph_id = igra.batch
        edges = igra.edge_index
        _unused = igra.pollinfor  # not needed by this variant, but igra must provide it
        node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        cnn_map = self.c1(igra.xdata128.float()[:, 0:1, :, :])
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        return cnn_map, self.g1(torch.cat((node_feat, sampled), dim=-1), edges)




class FlexibleCNN(nn.Module):
    """Plain stack of 3x3 / stride-1 / padding-1 convolutions (spatial size preserved).

    With ``num_layers > 1`` the channel widths are
    ``in -> mid -> ... -> mid -> out``. Every conv is followed by the shared
    activation except the last, which is followed by ``nn.Identity``. With
    ``num_layers <= 1`` a single ``in -> out`` conv is built, and it IS
    followed by the activation.

    NOTE(review): the single-layer case therefore ends in a non-linearity while
    the multi-layer case does not. A single ``activation_fn()`` instance is also
    shared by all layers, which is harmless for ReLU but would share weights for
    a parametric activation such as PReLU.
    """

    def __init__(self, num_layers=1, in_channels=1, out_channels=1, mid_channels=16, activation_fn=nn.ReLU):
        super(FlexibleCNN, self).__init__()
        self.layers = nn.ModuleList()
        self.activation = activation_fn()

        def conv3x3(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1)

        if num_layers <= 1:
            self.layers.append(nn.Sequential(conv3x3(in_channels, out_channels), self.activation))
            return

        self.layers.append(nn.Sequential(conv3x3(in_channels, mid_channels), self.activation))
        for _ in range(num_layers - 2):
            self.layers.append(nn.Sequential(conv3x3(mid_channels, mid_channels), self.activation))
        self.layers.append(nn.Sequential(conv3x3(mid_channels, out_channels), nn.Identity()))

    def forward(self, x):
        """Apply the conv stack in order."""
        out = x
        for block in self.layers:
            out = block(out)
        return out



class CNNadgnnIdent_extrCNN(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """U-Net followed by an extra 4-layer ``FlexibleCNN`` (64 mid channels); no graph head.

    The full ``igra.xdata128`` image goes through ``Modfiedunet4shrink()`` and
    then ``FlexibleCNN``. The result is sampled at node positions, and those
    samples are the node prediction. Node features are assembled and
    ``igra.edge_index`` / ``igra.pollinfor`` are read, but none of them
    influence the output. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink()
        self.cextra = FlexibleCNN(num_layers=4, mid_channels=64)

    def forward(self, igra):
        """Return ``(cnn_map, sampled_node_values)``."""
        images = igra.xdata128.float()
        _unused = (igra.edge_index, igra.pollinfor)  # read but not used by this variant
        node_pos = igra.pos.float()
        graph_id = igra.batch
        _node_feat = torch.cat((igra.x[:, 0:2], igra.x[:, -3:], node_pos), dim=-1).float()

        refined = self.cextra(self.c1(images))
        return refined, cnntogra(node_pos, graph_id, refined)
    


def getcnnloss(Xtrain,predicted,Ytrain,highvalue):
    """Masked L1 loss plus an extra L1 term on high-stress pixels.

    Only pixels where channel 0 of ``Xtrain`` equals 1 are evaluated. Among
    those, elements whose target is >= ``highvalue`` form the high-stress set.

    Args:
        Xtrain: input batch (B, C_in, H, W); channel 0 is the evaluation mask.
        predicted: prediction (B, C, H, W).
        Ytrain: target (B, C, H, W).
        highvalue: absolute threshold defining high-stress elements.

    Returns:
        ``(overall_mae + high_mae, high_mae)``. ``high_mae`` is 0 when no
        element reaches ``highvalue``; previously it was NaN, which poisoned
        the total loss.

    NOTE(review): a later function with the same name replaces this one when
    the cell runs.
    """
    mask = Xtrain[:, 0, :, :] == 1
    # Move channels last so the (B, H, W) mask selects whole channel vectors -> (M, C).
    predicted_masked = predicted.permute(0, 2, 3, 1)[mask]
    target_masked = Ytrain.permute(0, 2, 3, 1)[mask]
    lossdif = torch.mean(torch.abs(predicted_masked - target_masked))

    highmask = target_masked >= highvalue
    target_high = target_masked[highmask]
    if target_high.numel() > 0:
        losshighstress = torch.mean(torch.abs(target_high - predicted_masked[highmask]))
    else:
        losshighstress = lossdif.new_zeros(())

    return losshighstress + lossdif, losshighstress

import torch

def getcnnloss(Xtrain, predicted, Ytrain,highvalue, percentile=80):
    """Return the masked overall MAE and the MAE over the top high-value pixels.

    Only pixels where channel 0 of ``Xtrain`` equals 1 are evaluated, and only
    channel 0 of ``predicted`` / ``Ytrain`` is compared. The high-value region
    is the set of evaluated pixels whose target is at or above the
    ``percentile``-th percentile of the evaluated targets over the whole batch
    (default 80 -> the top 20%). At least one pixel, the maximum, is always
    included. This matches the selection used by ``getgnnhighloss``.

    Args:
        Xtrain: input batch (B, C_in, H, W); channel 0 is the evaluation mask.
        predicted: prediction (B, C, H, W).
        Ytrain: target (B, C, H, W).
        highvalue: unused; kept for call compatibility with the fixed-threshold variant.
        percentile: percentile in [0, 100] that defines the high-value region.

    Returns:
        ``(overall_mean_error, high_stress_mean_error)``. Note the two terms are
        NOT summed. Both are 0 when no pixel is selected by the mask.
    """
    eval_mask = Xtrain[:, 0, :, :] == 1
    target_sel = Ytrain[:, 0, :, :][eval_mask]
    pred_sel = predicted[:, 0, :, :][eval_mask]

    n = target_sel.numel()
    if n == 0:
        return predicted.new_zeros(()), predicted.new_zeros(())

    overall_mean_error = torch.mean(torch.abs(pred_sel - target_sel))

    # Size of the high-value set. n*(100-p)/100 avoids the float error of
    # 1 - 0.8 = 0.19999..., which made int() drop one element.
    k = min(max(int(n * (100 - percentile) / 100), 1), n)
    # kthvalue returns the k-th *smallest*; the k-th largest is the (n-k+1)-th smallest.
    # (Using kthvalue(k) directly selected the top 80% instead of the top 20%.)
    threshold = target_sel.kthvalue(n - k + 1).values

    high_value_mask = target_sel >= threshold
    high_stress_mean_error = torch.mean(torch.abs(pred_sel[high_value_mask] - target_sel[high_value_mask]))

    return overall_mean_error, high_stress_mean_error



def getgnnhighloss_value(igra_svon, gx5final, igra_batch, threshold_value):
    """Mean absolute percentage error (in %) over nodes whose true stress >= threshold.

    Args:
        igra_svon: true node stress, shape [node_number, 1] (flattened internally).
        gx5final: predicted node stress, same number of elements.
        igra_batch: design index per node; accepted for signature compatibility, unused.
        threshold_value: absolute stress threshold selecting the high-stress nodes.

    Returns:
        0-dim tensor equal to ``mean(|real - pred| / (real + 1e-8)) * 100`` over
        the selected nodes, or 0 when no node qualifies. The zero is created with
        the prediction's device and dtype so it combines cleanly with other losses.
    """
    igra_svon = igra_svon.reshape(-1)
    gx5final = gx5final.reshape(-1)

    high_stress_mask = igra_svon >= threshold_value
    real_high_stress = igra_svon[high_stress_mask]
    pred_high_stress = gx5final[high_stress_mask]

    if real_high_stress.numel() == 0:
        return gx5final.new_zeros(())

    epsilon = 1e-8  # guards against division by zero when a true value is exactly 0
    percentage_errors = torch.abs((real_high_stress - pred_high_stress) / (real_high_stress + epsilon))
    return torch.mean(percentage_errors) * 100


In [None]:

def getgnnhighloss(igra_svon, gx5final, igra_batch, percentile=80):
    """MAE over nodes whose true stress is at or above their own design's percentile.

    For each design (distinct value of ``igra_batch``), the ``percentile``-th
    percentile of its true stresses is computed with ``torch.quantile``. Nodes
    at or above their design's threshold are selected, and the mean absolute
    error over all selected nodes is returned.

    Args:
        igra_svon: true node stress, shape [node_number, 1] (flattened internally).
        gx5final: predicted node stress, same number of elements.
        igra_batch: design index per node, shape [node_number]. The ids need not
            be contiguous; the previous indexing by raw value failed for gaps.
        percentile: percentile in [0, 100] (default 80).

    Returns:
        0-dim tensor with the MAE, or 0 (on the prediction's device/dtype) when
        no node is selected.
    """
    igra_svon = igra_svon.reshape(-1)
    gx5final = gx5final.reshape(-1)

    # group_of_node[i] is the position of node i's design within unique_batches.
    unique_batches, group_of_node = torch.unique(igra_batch, return_inverse=True)
    thresholds = igra_svon.new_empty(unique_batches.numel())  # same dtype as the stresses
    for g in range(unique_batches.numel()):
        thresholds[g] = torch.quantile(igra_svon[group_of_node == g], percentile / 100.0)

    high_stress_mask = igra_svon >= thresholds[group_of_node]
    real_high_stress = igra_svon[high_stress_mask]
    pred_high_stress = gx5final[high_stress_mask]

    if real_high_stress.numel() == 0:
        return gx5final.new_zeros(())
    return torch.mean(torch.abs(real_high_stress - pred_high_stress))




In [11]:
from GraphUnet_define import meshpoolresSAGEmodelUnet241012

# Instantiate meshpoolresSAGEmodelUnet241012 (imported from GraphUnet_define)
# with this configuration and print its total parameter count
# (sum of numel over all parameters, trainable or not).
# NOTE(review): the recorded cell output (7587294) appears to be this print; confirm.
model=meshpoolresSAGEmodelUnet241012(in_channels=8,
        hidden_channels=11,
        out_channels=1,
        depth=4,
        messpnum=4)

print(sum(p.numel() for p in model.parameters()))

class CNNadgnnIdent_GNNmeshpo(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Single-channel U-Net + mesh-pooling graph U-Net head.

    Channel 0 of ``igra.xdata128`` feeds ``Modfiedunet4shrink(chin=1)``. Node
    samples of its output are placed *before* ``[igra.x, pos]`` and passed to
    ``meshpoolresSAGEmodelUnet241012`` (in 6, hidden 16, out 1, depth 4,
    messpnum 2), together with the edges, positions, pooling info and batch
    vector. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        # Alternatives tried: Modfiedunet4shrink(chin=1, k=4), Modfiedunet_CMAME(chin=1),
        # Modfiedunet3shrink_CMAME(chin=1, k=8)
        self.c1 = Modfiedunet4shrink(chin=1)
        self.g1 = meshpoolresSAGEmodelUnet241012(in_channels=6, hidden_channels=16,
                                                 out_channels=1, depth=4, messpnum=2)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        edges, pool_info, graph_id = igra.edge_index, igra.pollinfor, igra.batch
        node_pos = igra.pos.float()
        node_feat = torch.cat((igra.x, node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        prediction = self.g1(torch.cat((sampled, node_feat), dim=-1), edges, node_pos, pool_info, graph_id)
        return cnn_map, prediction
    

class CNNadgnnIdent_GNNmeshpo_7M(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """``CNNadgnnIdent_GNNmeshpo`` with a 4-channel-wide mesh-pooling head.

    Same data flow as ``CNNadgnnIdent_GNNmeshpo``; the graph head is
    ``meshpoolresSAGEmodelUnet241012(in 6, hidden 4, out 1, depth 4, messpnum 2)``.
    ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        # Alternatives tried: Modfiedunet4shrink(chin=1, k=4), Modfiedunet_CMAME(chin=1),
        # Modfiedunet3shrink_CMAME(chin=1, k=8)
        self.c1 = Modfiedunet4shrink(chin=1)
        self.g1 = meshpoolresSAGEmodelUnet241012(in_channels=6, hidden_channels=4,
                                                 out_channels=1, depth=4, messpnum=2)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        node_pos = igra.pos.float()
        graph_id = igra.batch
        edges = igra.edge_index
        pool_info = igra.pollinfor
        node_feat = torch.cat((igra.x, node_pos), dim=-1).float()

        cnn_map = self.c1(igra.xdata128.float()[:, 0:1, :, :])
        gnn_in = torch.cat((cnntogra(node_pos, graph_id, cnn_map), node_feat), dim=-1)
        return cnn_map, self.g1(gnn_in, edges, node_pos, pool_info, graph_id)
    



class CNNadgnnIdent_GNNmeshpo_small(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Reduced-size U-Net (k=4) + mesh-pooling graph U-Net head (hidden 16).

    Same data flow as ``CNNadgnnIdent_GNNmeshpo``, except the CNN is
    ``Modfiedunet4shrink(chin=1, k=4)``. ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.c1 = Modfiedunet4shrink(chin=1, k=4)
        # Alternatives tried: Modfiedunet_CMAME(chin=1), Modfiedunet3shrink_CMAME(chin=1, k=8)
        self.g1 = meshpoolresSAGEmodelUnet241012(in_channels=6, hidden_channels=16,
                                                 out_channels=1, depth=4, messpnum=2)

    def forward(self, igra):
        """Return ``(cnn_map, node_prediction)``."""
        cnn_in = igra.xdata128.float()[:, 0:1, :, :]
        node_pos = igra.pos.float()
        edges, pool_info, graph_id = igra.edge_index, igra.pollinfor, igra.batch
        node_feat = torch.cat((igra.x, node_pos), dim=-1).float()

        cnn_map = self.c1(cnn_in)
        sampled = cnntogra(node_pos, graph_id, cnn_map)
        prediction = self.g1(torch.cat((sampled, node_feat), dim=-1), edges, node_pos, pool_info, graph_id)
        return cnn_map, prediction
    


class CNNadgnnIdent_GNNmeshpo_GNNonly_small(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Graph-only baseline: mesh-pooling graph U-Net on ``[igra.x, pos]`` (no CNN).

    The head is ``meshpoolresSAGEmodelUnet241012(in 5, hidden 16, out 1,
    depth 4, messpnum 2)``. Channel 0 of ``igra.xdata128`` is returned
    unchanged as the first output so the interface matches the hybrid models.
    ``chlist`` is accepted but unused.
    """

    def __init__(self,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.g1 = meshpoolresSAGEmodelUnet241012(in_channels=5, hidden_channels=16,
                                                 out_channels=1, depth=4, messpnum=2)

    def forward(self, igra):
        """Return ``(image_channel0, node_prediction)``."""
        passthrough = igra.xdata128.float()[:, 0:1, :, :]
        node_pos = igra.pos.float()
        node_feat = torch.cat((igra.x, node_pos), dim=-1).float()
        prediction = self.g1(node_feat, igra.edge_index, node_pos, igra.pollinfor, igra.batch)
        return passthrough, prediction


class CNNadgnnIdent_GNNmeshpo_GNNonly_checksize(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """Graph-only mesh-pooling model with configurable size (for size sweeps).

    Args:
        hidden_channels: hidden width of ``meshpoolresSAGEmodelUnet241012``.
        depth: its depth.
        messpnum: its ``messpnum`` setting.
        chlist: accepted but unused.

    The input is ``[igra.x, pos]`` (5 features expected by the head). Channel 0
    of ``igra.xdata128`` is returned unchanged as the first output.
    """

    def __init__(self,hidden_channels,depth,messpnum,chlist=[8,32,64,128,128,128,128,128,64,32,9]):
        super().__init__()
        self.g1 = meshpoolresSAGEmodelUnet241012(in_channels=5, hidden_channels=hidden_channels,
                                                 out_channels=1, depth=depth, messpnum=messpnum)

    def forward(self, igra):
        """Return ``(image_channel0, node_prediction)``."""
        node_pos = igra.pos.float()
        edges, pool_info, graph_id = igra.edge_index, igra.pollinfor, igra.batch
        passthrough = igra.xdata128.float()[:, 0:1, :, :]
        node_feat = torch.cat((igra.x, node_pos), dim=-1).float()
        return passthrough, self.g1(node_feat, edges, node_pos, pool_info, graph_id)



class CNNadgnnIdent_GNNmeshpo_GNNonly_7M(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """GNN-only model wrapping ``meshpoolresSAGEmodelUnet241012`` with fixed settings.

    Configuration: ``in_channels=5``, ``hidden_channels=11``,
    ``out_channels=1``, ``depth=4``, ``messpnum=4``. There is no CNN branch:
    channel 0 of the image input is returned unchanged as the first output.

    Args:
        chlist: unused; accepted only so existing call sites keep working.
    """

    def __init__(self, chlist=None):
        super().__init__()
        self.g1 = meshpoolresSAGEmodelUnet241012(
            in_channels=5,
            hidden_channels=11,
            out_channels=1,
            depth=4,
            messpnum=4,
        )

    def forward(self, igra):
        """Run ``self.g1`` on a batched graph.

        Args:
            igra: batched graph object; reads ``xdata128`` (indexed as
                [batch, channel, H, W]), ``x``, ``edge_index``, ``pos``,
                ``pollinfor`` and ``batch``.

        Returns:
            ``(x1, gout)``: channel 0 of ``igra.xdata128`` as float32 with
            shape [batch, 1, H, W], and the output of ``self.g1``.
        """
        # No CNN in this model: pass image channel 0 straight through.
        x1 = igra.xdata128.float()[:, 0:1, :, :]
        posnew = igra.pos.float()
        # Node features = raw features followed by node coordinates;
        # presumably 3 + 2 = 5 = in_channels of g1 (TODO confirm).
        gx = torch.cat((igra.x, posnew), dim=-1).float()
        gout = self.g1(gx, igra.edge_index, posnew, igra.pollinfor, igra.batch)
        return x1, gout




7587294


NameError: name 'nn' is not defined

In [4]:
import torch
import torch.nn as nn
from torch_geometric.nn import GraphSAGE

# Define GraphSAGE model
class GraphSAGEModel(nn.Module):
    """Thin wrapper around a 4-layer ``torch_geometric`` GraphSAGE (8 -> 128 -> 1)."""

    def __init__(self):
        super().__init__()
        self.g1 = GraphSAGE(
            in_channels=8,
            hidden_channels=128,
            num_layers=4,
            out_channels=1,
        )

    def forward(self, x, edge_index):
        """Apply the wrapped GraphSAGE to node features ``x`` over ``edge_index``."""
        sage = self.g1
        return sage(x, edge_index)

# Define MLP model
class MLPModel(nn.Module):
    """Fully connected regressor: 8 inputs -> five 128-wide hidden layers -> 1 output.

    Every hidden ``Linear`` is followed by a ``LeakyReLU``; the output layer
    is a plain ``Linear``. Layers are created in network order, so the
    ``Sequential`` indices (``mlp.0``, ``mlp.2``, ..., ``mlp.10``) are stable.
    """

    def __init__(self):
        super().__init__()
        widths = [8, 128, 128, 128, 128, 128]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU()]
        layers.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Map input features of width 8 to a single output value per row."""
        out = self.mlp(x)
        return out

# Build both models (GraphSAGE first, then the MLP) and report how many
# scalar parameters each one holds.
graph_sage_model, mlp_model = GraphSAGEModel(), MLPModel()

graph_sage_params = sum(map(torch.numel, graph_sage_model.parameters()))
mlp_params = sum(map(torch.numel, mlp_model.parameters()))

print(f"Number of parameters in GraphSAGE model: {graph_sage_params}")
print(f"Number of parameters in MLP model: {mlp_params}")


Number of parameters in GraphSAGE model: 68225
Number of parameters in MLP model: 67329


In [None]:

def getgnnhighloss_mape(igra_svon, gx5final, igra_batch, percentile=80):
    """Return the MAPE over each graph's high-stress nodes.

    For every graph (design) in a batched graph, the nodes whose true stress
    is at or above that graph's ``percentile``-th percentile
    (``torch.quantile``, linear interpolation) are selected. The mean of
    ``|true - pred| / true`` over all selected nodes is returned.

    Args:
        igra_svon: true per-node stress values, shape [num_nodes] or
            [num_nodes, 1].
        gx5final: predicted per-node stress values with the same number of
            elements as ``igra_svon``.
        igra_batch: graph id of every node, shape [num_nodes]. Ids do not
            need to be contiguous or start at 0.
        percentile: percentile in [0, 100] used as the per-graph threshold
            (default 80, i.e. roughly the top 20% of nodes of each graph).

    Returns:
        0-dim tensor holding the MAPE; a zero tensor when no node is selected
        (e.g. empty input).
    """
    # reshape (not view) also accepts non-contiguous inputs.
    true_vals = igra_svon.reshape(-1)
    pred_vals = gx5final.reshape(-1)

    # `node_to_graph` maps each node to the position of its graph id in the
    # sorted unique ids, so per-graph thresholds are looked up correctly even
    # when the ids are not exactly 0..B-1.
    graph_ids, node_to_graph = torch.unique(igra_batch, return_inverse=True)

    q = percentile / 100.0
    thresholds = torch.empty(
        graph_ids.numel(), dtype=true_vals.dtype, device=true_vals.device
    )
    for i in range(graph_ids.numel()):
        thresholds[i] = torch.quantile(true_vals[node_to_graph == i], q)

    # Keep nodes at or above their own graph's threshold.
    high_mask = true_vals >= thresholds[node_to_graph]
    real_high = true_vals[high_mask]
    pred_high = pred_vals[high_mask]

    if real_high.numel() == 0:
        # Empty input (or NaN stresses): zero on the input's device/dtype.
        return true_vals.new_zeros(())
    # Relative error; assumes strictly positive true stresses in the selected
    # set (a zero true value would yield inf/nan).
    return torch.mean(torch.abs(real_high - pred_high) / real_high)


def getcnnhighloss_mape(Xtrain, predicted, Ytrain, percentile=80):
    """Mean absolute percentage error over the high-value region of image targets.

    Only channel 0 of ``predicted`` and ``Ytrain`` is used. Pixels whose
    target equals -10 are ignored (presumably background fill; ``gratocnn``
    fills unmapped pixels with -10). Among the remaining pixels (pooled over
    the whole batch), the top ``100 - percentile`` percent by target value
    form the high-value region, and the mean of
    ``|predicted - Ytrain| / Ytrain`` over that region is returned.

    Args:
        Xtrain: unused; kept so existing call sites keep working.
        predicted: predictions indexed as [batch, channel, H, W].
        Ytrain: targets with the same layout as ``predicted``.
        percentile: percentile in [0, 100] separating the high-value region
            (default 80, i.e. the top 20% of non-background pixels). If that
            region would contain no pixel (very few valid pixels), all
            non-background pixels are used instead.

    Returns:
        0-dim tensor with the MAPE of the high-value region; a zero tensor if
        every pixel is background.
    """
    target = Ytrain[:, 0, :, :]
    valid = target != -10
    target_valid = target[valid]
    pred_valid = predicted[:, 0, :, :][valid]

    n = target_valid.numel()
    if n == 0:
        return target_valid.new_zeros(())

    # Number of pixels in the high-value region. Using the integer percentage
    # avoids float truncation such as int(10 * (1 - 0.8)) == 1.
    k = min(n, int(n * (100 - percentile) / 100))
    if k > 0:
        # Threshold = k-th *largest* target (smallest element of the top-k).
        threshold = target_valid.topk(k).values.min()
    else:
        threshold = target_valid.min()

    high = target_valid >= threshold
    target_high = target_valid[high]
    pred_high = pred_valid[high]
    # Relative error; assumes positive targets in the high region (a zero
    # target would yield inf/nan).
    return torch.mean(torch.abs(pred_high - target_high) / target_high)




def getcnnhighloss(Xtrain, predicted, Ytrain, percentile=80):
    """Mean absolute error over the high-value region of image targets.

    Only channel 0 of ``predicted`` and ``Ytrain`` is used. Pixels whose
    target equals -10 are ignored (presumably background fill; ``gratocnn``
    fills unmapped pixels with -10). Among the remaining pixels (pooled over
    the whole batch), the top ``100 - percentile`` percent by target value
    form the high-value region, and the mean of ``|predicted - Ytrain|``
    over that region is returned.

    Args:
        Xtrain: unused; kept so existing call sites keep working.
        predicted: predictions indexed as [batch, channel, H, W].
        Ytrain: targets with the same layout as ``predicted``.
        percentile: percentile in [0, 100] separating the high-value region
            (default 80, i.e. the top 20% of non-background pixels). If that
            region would contain no pixel (very few valid pixels), all
            non-background pixels are used instead.

    Returns:
        0-dim tensor with the mean absolute error of the high-value region;
        a zero tensor if every pixel is background.
    """
    target = Ytrain[:, 0, :, :]
    valid = target != -10
    target_valid = target[valid]
    pred_valid = predicted[:, 0, :, :][valid]

    n = target_valid.numel()
    if n == 0:
        return target_valid.new_zeros(())

    # Number of pixels in the high-value region. Using the integer percentage
    # avoids float truncation such as int(10 * (1 - 0.8)) == 1.
    k = min(n, int(n * (100 - percentile) / 100))
    if k > 0:
        # Threshold = k-th *largest* target (smallest element of the top-k).
        threshold = target_valid.topk(k).values.min()
    else:
        threshold = target_valid.min()

    high = target_valid >= threshold
    return torch.mean(torch.abs(pred_valid[high] - target_valid[high]))



class CNNadgnnIdent_modifiedunet(nn.Module): ###https://doi.org/10.1016/j.mechmat.2021.104191
    """CNN-only baseline: a ``Modfiedunet4shrink(chin=4)`` applied to the image input.

    The graph branch is disabled; a zero tensor shaped like ``igra.y`` is
    returned in its place so callers expecting ``(cnn_out, gnn_out)`` keep
    working.

    Args:
        chlist: unused; accepted only so existing call sites keep working.
    """

    def __init__(self, chlist=None):
        super().__init__()
        # chin=4: presumably the channel count of igra.xdata128 (TODO confirm).
        self.c1 = Modfiedunet4shrink(chin=4)

    def forward(self, igra):
        """Run the U-Net on ``igra.xdata128``.

        Args:
            igra: batched graph object; reads ``xdata128`` and ``y``.

        Returns:
            ``(x1, gout)``: ``self.c1`` applied to ``igra.xdata128`` cast to
            float32, and ``torch.zeros_like(igra.y)`` as a placeholder
            graph output.
        """
        x1 = self.c1(igra.xdata128.float())
        gout = torch.zeros_like(igra.y)
        return x1, gout