In [13]:
import torch
from torch_geometric.data import Data

def create_grid_graph(rows, cols, num_channels):
    num_nodes = rows * cols
    edge_index = []
    
    for i in range(rows):
        for j in range(cols):
            node = i * cols + j
            if j < cols - 1:
                edge_index.append([node, node + 1])
            if i < rows - 1:
                edge_index.append([node, node + cols])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.rand((edge_index.size(1), num_channels), dtype=torch.float)
    
    x = torch.rand((num_nodes, num_channels), dtype=torch.float)
    grid_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    
    return grid_data

def create_mesh_graph(num_nodes, num_channels):
    edge_index = []
    
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            edge_index.append([i, j])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.rand((edge_index.size(1), num_channels), dtype=torch.float)
    
    x = torch.rand((num_nodes, num_channels), dtype=torch.float)
    mesh_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    
    return mesh_data

def create_g2m_and_m2g_connections(grid_data, mesh_data, num_channels):
    grid_nodes = grid_data.num_nodes
    mesh_nodes = mesh_data.num_nodes
    
    g2m_edge_index = torch.stack([torch.randint(0, grid_nodes, (mesh_nodes,)),
                                  torch.arange(0, mesh_nodes)], dim=0)
    
    m2g_edge_index = torch.stack([torch.arange(0, mesh_nodes),
                                  torch.randint(0, grid_nodes, (mesh_nodes,))], dim=0)
    
    g2m_edge_attr = torch.rand((g2m_edge_index.size(1), num_channels), dtype=torch.float)
    m2g_edge_attr = torch.rand((m2g_edge_index.size(1), num_channels), dtype=torch.float)
    
    return g2m_edge_index, g2m_edge_attr, m2g_edge_index, m2g_edge_attr

def create_custom_graph_dataset(grid_rows, grid_cols, mesh_nodes, num_channels):
    grid_data = create_grid_graph(grid_rows, grid_cols, num_channels)
    mesh_data = create_mesh_graph(mesh_nodes, num_channels)
    
    g2m_edge_index, g2m_edge_attr, m2g_edge_index, m2g_edge_attr = create_g2m_and_m2g_connections(grid_data, mesh_data, num_channels)
    
    mesh_edge_index_shifted = mesh_data.edge_index + grid_data.num_nodes
    
    data = Data(
        grid_x=grid_data.x,
        mesh_x=mesh_data.x,
        g2m_edge_index=g2m_edge_index + torch.tensor([[0], [grid_data.num_nodes]], dtype=torch.long),
        g2m_edge_attr=g2m_edge_attr,
        m2g_edge_index=m2g_edge_index + torch.tensor([[grid_data.num_nodes], [0]], dtype=torch.long),
        m2g_edge_attr=m2g_edge_attr,
        m2m_edge_index=mesh_edge_index_shifted,
        m2m_edge_attr=mesh_data.edge_attr
    )
    
    return data

# Usage
dataset = create_custom_graph_dataset(grid_rows=4, grid_cols=4, mesh_nodes=5, num_channels=32)

for key, value in dataset:
    print(key, value.shape)
    print(value)






grid_x torch.Size([16, 32])
tensor([[0.3261, 0.3179, 0.5072, 0.4777, 0.6842, 0.6190, 0.1013, 0.9559, 0.9457,
         0.2734, 0.6823, 0.9296, 0.7511, 0.2911, 0.4444, 0.0738, 0.7200, 0.3564,
         0.3760, 0.9449, 0.5029, 0.5358, 0.9657, 0.4783, 0.8820, 0.3123, 0.3201,
         0.8298, 0.0740, 0.4648, 0.7669, 0.7674],
        [0.2911, 0.0727, 0.3978, 0.7582, 0.4081, 0.7202, 0.9174, 0.3543, 0.9415,
         0.1579, 0.7645, 0.6784, 0.9778, 0.8835, 0.9059, 0.3376, 0.1358, 0.0879,
         0.3105, 0.5423, 0.6341, 0.4012, 0.2445, 0.5959, 0.7440, 0.4685, 0.4367,
         0.9564, 0.9320, 0.9892, 0.9010, 0.3689],
        [0.2029, 0.5262, 0.0605, 0.0665, 0.2659, 0.7377, 0.6724, 0.1649, 0.7727,
         0.0134, 0.4235, 0.7464, 0.1772, 0.2922, 0.8698, 0.5286, 0.1255, 0.1407,
         0.2876, 0.8423, 0.2244, 0.1274, 0.6049, 0.3463, 0.7275, 0.0649, 0.4971,
         0.2676, 0.1221, 0.2268, 0.8006, 0.9192],
        [0.7727, 0.9963, 0.7038, 0.9460, 0.4630, 0.4543, 0.7652, 0.5123, 0.4688,
         0.0

In [14]:
%load_ext autoreload
%autoreload 2
import sys
import torch


sys.path.insert(0,'/home/aw1223/ip/agile')

from sdk.ample import Ample

from torch_geometric.datasets import FakeDataset #TODO remove
from sdk.models.models import MLP_Model,Interaction_Net_Model,GCN_Model

from torch_geometric.data import Data


class Graphcast(torch.nn.Module):
    def __init__(self, in_channels=32, out_channels=32, layer_count=1, hidden_dimension=32, precision = torch.float32):
        super().__init__()
        self.precision = precision
        self.layers = torch.nn.ModuleList()
       
        self.grid_mesh_embedder = MLP_Model(in_channels, hidden_dimension) 
        self.grid_mesh_embedder.name  = 'grid_mesh_embedder'
        self.layers.append(self.grid_mesh_embedder) 

        self.g2m_embedder = MLP_Model(in_channels, hidden_dimension) 
        self.g2m_embedder.name  = 'g2m_embedder'
        self.layers.append(self.g2m_embedder) 

        self.g2m_int_net = Interaction_Net_Model()
        self.g2m_int_net.name  = 'g2m_int_net'
        self.layers.append(self.g2m_int_net) 

        self.m2m_embedder = MLP_Model(in_channels, hidden_dimension) 
        self.m2m_embedder.name  = 'm2m_embedder'
        self.layers.append(self.m2m_embedder) 

        self.m2m_int_net = Interaction_Net_Model()
        self.m2m_int_net.name  = 'm2m_int_net'
        self.layers.append(self.m2m_int_net) 

        self.m2g_embedder = MLP_Model(in_channels, hidden_dimension) 
        self.m2g_embedder.name  = 'm2m_embedder'
        self.layers.append(self.m2g_embedder) 
 
        self.m2g_int_net = Interaction_Net_Model()
        self.m2g_int_net.name  = 'm2m_int_net'
        self.layers.append(self.m2g_int_net) 
        
        for layer in self.layers:
            layer.to(self.precision)

    def forward(
            self,
            g2m_edge_attr,
            m2m_edge_attr,
            m2g_edge_attr,
            g2m_edge_index,
            m2m_edge_index,
            m2g_edge_index,
            grid_mesh_rep):
            
        outputs_model = []
        
        outputs_sub_model1,grid_mesh_emb = self.grid_mesh_embedder(grid_mesh_rep)
        
        outputs_sub_model2,g2m_emb = self.g2m_embedder(g2m_edge_attr)

        outputs_sub_model3,grid_mesh_emb = self.g2m_int_net(grid_mesh_emb, g2m_edge_index, g2m_emb)
        
        outputs_sub_model4,m2m_emb = self.m2m_embedder(m2m_edge_attr)

        outputs_sub_model5,grid_mesh_emb = self.m2m_int_net(grid_mesh_emb, m2m_edge_index,m2m_emb)

        outputs_sub_model6,m2g_emb = self.m2g_embedder(m2g_edge_attr)

        outputs_sub_model7,grid_mesh_emb = self.m2g_int_net(grid_mesh_emb, m2g_edge_index,m2g_emb)

        outputs_model = outputs_sub_model1 + outputs_sub_model2 + outputs_sub_model3 + outputs_sub_model4 + outputs_sub_model5 + outputs_sub_model6 + outputs_sub_model7
        return outputs_model,grid_mesh_emb



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import torch
from torch_geometric.data import Data

def create_grid_graph(rows, cols, num_channels):
    num_nodes = rows * cols
    edge_index = []
    
    for i in range(rows):
        for j in range(cols):
            node = i * cols + j
            if j < cols - 1:
                edge_index.append([node, node + 1])
            if i < rows - 1:
                edge_index.append([node, node + cols])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.rand((edge_index.size(1), num_channels), dtype=torch.float)
    
    x = torch.rand((num_nodes, num_channels), dtype=torch.float)
    grid_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    
    return grid_data

def create_mesh_graph(num_nodes, num_channels):
    edge_index = []
    
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            edge_index.append([i, j])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.rand((edge_index.size(1), num_channels), dtype=torch.float)
    
    x = torch.rand((num_nodes, num_channels), dtype=torch.float)
    mesh_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    
    return mesh_data

def create_g2m_and_m2g_connections(grid_data, mesh_data, num_channels):
    grid_nodes = grid_data.num_nodes
    mesh_nodes = mesh_data.num_nodes
    
    g2m_edge_index = torch.stack([torch.randint(0, grid_nodes, (mesh_nodes,)),
                                  torch.arange(0, mesh_nodes)], dim=0)
    
    m2g_edge_index = torch.stack([torch.arange(0, mesh_nodes),
                                  torch.randint(0, grid_nodes, (mesh_nodes,))], dim=0)
    
    g2m_edge_attr = torch.rand((g2m_edge_index.size(1), num_channels), dtype=torch.float)
    m2g_edge_attr = torch.rand((m2g_edge_index.size(1), num_channels), dtype=torch.float)
    
    return g2m_edge_index, g2m_edge_attr, m2g_edge_index, m2g_edge_attr

def create_custom_graph_dataset(grid_rows, grid_cols, mesh_nodes, num_channels, edge_dim):
    grid_data = create_grid_graph(grid_rows, grid_cols, num_channels)
    mesh_data = create_mesh_graph(mesh_nodes, num_channels)
    
    g2m_edge_index, g2m_edge_attr, m2g_edge_index, m2g_edge_attr = create_g2m_and_m2g_connections(grid_data, mesh_data, edge_dim)
    
    combined_x = torch.cat([grid_data.x, mesh_data.x], dim=0)
    
    mesh_edge_index_shifted = mesh_data.edge_index + grid_data.num_nodes
    
    data = Data(
        x=combined_x,
        g2m_edge_index=g2m_edge_index + torch.tensor([[0], [grid_data.num_nodes]], dtype=torch.long),
        g2m_edge_attr=g2m_edge_attr,
        m2m_edge_index=mesh_edge_index_shifted,
        m2m_edge_attr=mesh_data.edge_attr,
        m2g_edge_index=m2g_edge_index + torch.tensor([[grid_data.num_nodes], [0]], dtype=torch.long),
        m2g_edge_attr=m2g_edge_attr
    )
    
    return data

dataset = create_custom_graph_dataset(
    grid_rows=4,
    grid_cols=4,
    mesh_nodes=5,
    num_channels=32,  # Node feature dimension
    edge_dim=32  # Edge feature dimension
)



In [16]:
model = Graphcast()
grid_mesh_rep = dataset.x
g2m_edge_attr = dataset.g2m_edge_attr
m2m_edge_attr = dataset.m2m_edge_attr
m2g_edge_attr = dataset.m2g_edge_attr
g2m_edge_index = dataset.g2m_edge_index
m2m_edge_index = dataset.m2m_edge_index
m2g_edge_index = dataset.m2g_edge_index


outputs_model, grid_mesh_emb = model(
    g2m_edge_attr=g2m_edge_attr,
    m2m_edge_attr=m2m_edge_attr,
    m2g_edge_attr=m2g_edge_attr,
    g2m_edge_index=g2m_edge_index,
    m2m_edge_index=m2m_edge_index,
    m2g_edge_index=m2g_edge_index,
    grid_mesh_rep=grid_mesh_rep
)


RuntimeError: The size of tensor a (21) must match the size of tensor b (12) at non-singleton dimension 0