In [3]:
import itertools
from functools import reduce
from operator import mul
import torch

conditioner sur la dimension pour rollout probablement

In [7]:
def compute_num_products(num_variables, degree):
    """
    Compute the number of polynomial products for a given number of variables and degree.
    
    Args:
        num_variables (int): The number of variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        int: The total number of polynomial products.
    """
    num_products = 0
    for d in range(1, degree + 1):
        num_products += sum(1 for _ in itertools.combinations_with_replacement(range(num_variables), d))
    return num_products

def compute_products(tensor, degree):
    """
    Compute all polynomial products up to a given degree for the input tensor using PyTorch operations.
    
    Args:
        tensor (torch.Tensor): The input tensor with variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        torch.Tensor: A tensor containing all polynomial products up to the given degree.
    """
    num_variables = tensor.shape[0]
    num_products = compute_num_products(num_variables, degree)
    products = torch.empty(num_products)

    idx = 0
    for d in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(range(num_variables), d):
            product = torch.prod(tensor[list(combo)])
            products[idx] = product
            idx += 1

    return products

In [9]:
a = torch.tensor([4, 1, 2, 5, 5])

print(compute_products(a, 2))

tensor([ 4.,  1.,  2.,  5.,  5., 16.,  4.,  8., 20., 20.,  1.,  2.,  5.,  5.,
         4., 10., 10., 25., 25., 25.])


In [82]:
import torch
import itertools

def compute_num_products(num_variables, degree):
    """
    Compute the number of polynomial products for a given number of variables and degree.
    
    Args:
        num_variables (int): The number of variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        int: The total number of polynomial products.
    """
    num_products = 0
    for d in range(1, degree + 1):
        num_products += sum(1 for _ in itertools.combinations_with_replacement(range(num_variables), d))
    return num_products

def compute_products_batch(tensor, degree):
    """
    Compute all polynomial products up to a given degree for a batch of input tensors using PyTorch operations.
    
    Args:
        tensor (torch.Tensor): The input tensor with shape [batch_size, num_variables].
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        torch.Tensor: A tensor containing all polynomial products for the batch, shape [batch_size, num_products].
    """
    batch_size, num_variables = tensor.shape
    num_products = compute_num_products(num_variables, degree)
    products = torch.empty(batch_size, num_products, device=tensor.device)

    idx = 0
    for d in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(range(num_variables), d):
            product = torch.prod(tensor[:, list(combo)], dim=1)
            products[:, idx] = product
            idx += 1

    return products

def weighted_sum_batch(products, weights):
    """
    Compute the weighted sum of the polynomial products for a batch.
    
    Args:
        products (torch.Tensor): The tensor containing polynomial products, shape [batch_size, num_products].
        weights (torch.Tensor): The corresponding weights for each polynomial product, shape [num_products].

    Returns:
        torch.Tensor: The weighted sum of polynomial products for the batch, shape [batch_size].
    """
    return torch.matmul(products, weights)

# Example usage
variables = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)  # Example batch of variables
degree = 3
products = compute_products_batch(variables, degree)

print(products)

num_products = products.shape[1]
weights = torch.randn(num_products, requires_grad=True)  # Random weights for each polynomial product

result = weighted_sum_batch(products, weights)
print(result)

# To check the computation graph, we can compute gradients
result.sum().backward()  # Summing results to compute a single scalar gradient
print(variables.grad)
print(weights.grad)


tensor([[  1.,   2.,   3.,   1.,   2.,   3.,   4.,   6.,   9.,   1.,   2.,   3.,
           4.,   6.,   9.,   8.,  12.,  18.,  27.],
        [  4.,   5.,   6.,  16.,  20.,  24.,  25.,  30.,  36.,  64.,  80.,  96.,
         100., 120., 144., 125., 150., 180., 216.]], grad_fn=<CopySlices>)
tensor([  4.0998, 136.5760], grad_fn=<MvBackward0>)
tensor([[12.1050,  6.1875, -8.6075],
        [78.0966, 10.9784,  4.1306]])
tensor([  5.,   7.,   9.,  17.,  22.,  27.,  29.,  36.,  45.,  65.,  82.,  99.,
        104., 126., 153., 133., 162., 198., 243.])


In [158]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import itertools


LATENT_SHAPE = 128

# GNN related parameters
EDGES_SHAPE = 5
MESSAGE_SHAPE = 128
HIDDEN_NN_SHAPE = 128


# output
OUTPUT_SHAPE= 2


class MLP(nn.Module):
    """ 
    linearly growing size
    """
    def __init__(self, inputShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP, self).__init__()

        self.inputShape = inputShape
        self.outputShape = outputShape

        self.delta = (inputShape - outputShape) // 3
        dim1 = inputShape - self.delta
        dim2 = dim1 - self.delta

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, dim1),
            nn.LeakyReLU(),
            nn.Linear(dim1, dim2),
            #nn.ELU(),
            nn.LeakyReLU(),
            nn.Linear(dim2, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                #nn.init.xavier_normal_(layer.weight)
                #nn.init.zeros_(layer.bias)
                layer.bias.data.fill_(0.)
                
 

class MLP2(nn.Module):
    """
    constant size
    """
    def __init__(self, inputShape:int, latentShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP2, self).__init__()

        self.inputShape = inputShape
        self.latentShape = latentShape
        self.outputShape = outputShape

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                #nn.init.xavier_normal_(layer.weight)
                #nn.init.zeros_(layer.bias)
                layer.bias.data.fill_(0.)
                

def compute_num_products(num_variables, degree):
    """
    Compute the number of polynomial products for a given number of variables and degree.
    
    Args:
        num_variables (int): The number of variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        int: The total number of polynomial products.
    """
    num_products = 0
    for d in range(1, degree + 1):
        num_products += sum(1 for _ in itertools.combinations_with_replacement(range(num_variables), d))
    return num_products

def compute_products_batch(tensor, degree):
    """
    Compute all polynomial products up to a given degree for a batch of input tensors using PyTorch operations.
    
    Args:
        tensor (torch.Tensor): The input tensor with shape [batch_size, num_variables].
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        torch.Tensor: A tensor containing all polynomial products for the batch, shape [batch_size, num_products].
    """
    batch_size, num_variables = tensor.shape
    num_products = compute_num_products(num_variables, degree)
    products = torch.empty(batch_size, num_products, device=tensor.device)

    idx = 0
    for d in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(range(num_variables), d):
            product = torch.prod(tensor[:, list(combo)], dim=1)
            products[:, idx] = product
            idx += 1

    return products


                
class GN_edge_GAM(MessagePassing):
    """ 
    Message passing neural network in which the message passing
    only considers the edges features
    """
    def __init__(self, inputShape:int, outputShape:int, degreePoly = 2, nbDim =2,  shapeEdges:int = 7, hiddenShape:int=64, aggr:str='add'):
        super(GN_edge_GAM, self).__init__(aggr=aggr)

        self.inputShape = inputShape
        self.degreePoly = degreePoly
        self.nbDim = nbDim
        self.messageShape = compute_num_products(shapeEdges, degreePoly) * nbDim
        self.outputShape = outputShape
        self.hiddenShape = hiddenShape
        
        self.messageMLP = MLP2(shapeEdges, hiddenShape, self.messageShape)
        self.norm = torch.nn.LayerNorm(self.nbDim+ inputShape)
        
        self.updateMLP = MLP(self.nbDim + inputShape, outputShape)
    
    
    def forward(self, x:torch.tensor, edge_index:torch.tensor, edge_attr: torch.tensor):
        # Propagate messages
        out = self.propagate(edge_index, size=(x.size(0), x.size(0)), edge_attr=edge_attr, x=x)
        
        return out
      
    def message(self, x_i:torch.tensor, x_j:torch.tensor, edge_attr: torch.tensor):
        """ 
        Perfomrs the message passing in the graph neural network
        
        Args:
        -----
        - `x_i`: tensor associated to node i
        - `x_j`: tensor associated to node j
        """

        # compute the products
        prods = compute_products_batch(edge_attr, self.degreePoly)
        shape = prods.shape
        prods = prods.repeat(1, self.nbDim).reshape(2*shape[0], shape[1])
        # needs to have 
        #print(prods)
        #print(prods)
        # repeat along some axis

        # obtain the weights

        weights = self.messageMLP(edge_attr).reshape(prods.shape)
        print(weights.shape)

        res = torch.sum(weights * prods, axis = -1)

        print(res.shape)
        print(res.shape[0]/self.nbDim)

        res = res.reshape(int(res.shape[0]/self.nbDim), self.nbDim)

        print(res.shape)

        return res
    
    def update(self, aggr_out:torch.tensor, x:torch.tensor):
        """ 
        Function to update all the nodes after the aggregation
        
        Args:
        -----
        - `aggr_out`: result after the aggregation [# Nodes, messageShape]
        - `x`: current node [1, inputShape]
        """
        print(aggr_out.shape)
        xVal = self.norm(torch.cat([x, aggr_out], dim=-1))
        return self.updateMLP(xVal) 

 

class GAM_GNN(nn.Module):
    def __init__(self, inShape:int, latentShape:int, outShape:int, nbMessages:int, basis, edge_shape:int = EDGES_SHAPE, hiddenGN:int = 128):
        """ 
        Neural network to combine everything
        
        Args:
        -----
        - `inShape`: shape of the input vector
        - `latentShape`: shape of the latent space
        - `outShape`: shape of the output vector
        - `messageShape`: shape of the message in the GN
        - `hiddenGN`: shape of the hidden layers in the MLP of the GN
        """
        
        super().__init__()
        
        self.inShape = inShape
        self.latentShape = latentShape
        self.outShape = outShape
        
        
        # GNN
        self.GNN = GN_edge_GAM(self.inShape, self.outShape, 2, 2, edge_shape, hiddenGN)
        
    def forward(self, graph):
        """ 
        
        Args:
        -----
        - `x`: value for the nodes [# Nodes, #Timesteps x inShape]
        - `edge_index`
        - `edge_attr`
        """

        x = graph.x
        edge_index = graph.edge_index
        edge_attr = graph.edge_attr
        
        
        # encoder part
        #y = self.applyEnc(x)
        
        # gnn part
        y = self.GNN(x, edge_index, edge_attr)                                     # [#Nodes, outSHpae]

        return y


    def applyEnc(self, x):                                                    

        return x
    

    def L1Reg(self, graph):
        atr = graph.edge_attr

        messages = self.GNN.message(None, None, atr)

        loss = 0.01 * torch.sum(torch.abs(messages)) / graph.edge_index[0, :].shape[0]

        return loss


def loadNetwork(inputShape, edge_shape = EDGES_SHAPE):
    print('>>>> loading simplest')
    print('INFO >>> with NO encoder')
    print('INFO >>> with NO dropout')
    net = GAM_GNN(inputShape, LATENT_SHAPE, OUTPUT_SHAPE, MESSAGE_SHAPE,edge_shape, HIDDEN_NN_SHAPE)

    return net

In [132]:
basis = compute_products_batch(attr[0], 2)

In [133]:
print(basis.shape)

torch.Size([134, 20])


In [134]:
import numpy as np

lim = 0.85 * 100

xPos = np.linspace(-lim, lim, 10)
yPos = np.linspace(-lim, lim, 10)
gridX, gridY = np.meshgrid(xPos, yPos)
delta = np.random.uniform(0, 7, gridX.shape + (2,))

gridX2 = gridX + delta[:, :, 0]
gridY2 = gridY + delta[:, :, 1]

pos = np.stack([gridX.ravel(), gridY.ravel()], axis=1)
pos_perturbed = np.stack([gridX2.ravel(), gridY2.ravel()], axis=1)

pos = np.concatenate([pos, pos_perturbed], axis=0)

angles = np.random.rand(pos.shape[0]) * 2 * np.pi

In [135]:
import sys

def path_link(path:str):
    sys.path.append(path)

path_link('master/code/lib')

import simulation_v2 as sim2
import features as ft

In [136]:
data = sim2.compute_main(200, (60, 3.5, 70, 0.5), 10, T = 1000, initialization = (pos, angles), dt = 0.001, seed = 42)[0]

v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 999/999 [00:10<00:00, 99.39it/s] 


In [137]:
x, y, attr, inds = ft.processSimulation(data)

In [138]:
from torch_geometric.data import Data

In [139]:
g = Data(x = x[0][:, 2:], y = y[0], edge_attr = attr[0], edge_index = inds[0])

In [153]:
print(g.x.shape)

torch.Size([200, 8])


In [148]:
print(g.edge_attr.shape)

torch.Size([98, 5])


In [159]:
mod = GAM_GNN(8, 128, 2, 2, compute_products_batch, 5)

In [160]:
print(mod.GNN.messageShape)

40


In [161]:
mod(g)

torch.Size([196, 20])
torch.Size([196])
98.0
torch.Size([98, 2])
torch.Size([200, 2])


tensor([[ 0.1424,  0.7505],
        [ 0.1416,  0.7423],
        [ 0.3497,  2.2736],
        [ 0.1413,  0.7326],
        [ 0.2291,  1.6039],
        [ 0.5131,  2.9059],
        [ 0.5609,  1.8078],
        [ 0.6948,  2.1609],
        [ 0.3780,  1.8856],
        [ 0.2508,  2.5541],
        [ 0.6593,  2.0728],
        [ 0.1407,  0.7295],
        [ 0.1414,  0.7368],
        [ 0.0191,  0.6627],
        [ 0.1421,  0.7506],
        [ 0.1188,  1.3606],
        [ 0.1440,  0.7734],
        [ 0.1196,  1.3624],
        [ 0.1401,  0.7175],
        [ 0.1389,  0.6987],
        [ 0.6889,  2.1468],
        [ 0.6722,  2.1057],
        [ 0.1409,  0.7256],
        [ 0.1393,  0.7103],
        [ 0.1410,  0.7259],
        [ 0.1419,  0.7467],
        [ 0.1430,  0.7609],
        [ 0.1420,  0.7514],
        [ 0.1427,  0.7605],
        [ 0.5520,  1.6455],
        [ 0.1437,  0.7802],
        [ 0.3345,  1.8098],
        [ 0.1997,  2.7438],
        [ 0.1378,  0.3816],
        [ 0.1886,  1.4646],
        [ 0.6457,  2

In [24]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import itertools


GAM_CFG = {
    "input_shape": 8,
    "edges_shape": 5,
    "output_shape": 2,
    "MLP_message": {
        "hidden_shape": 128,
        "dropout": "no"
    },
    "MLP_update": {
        "hidden_shape": 128,
        "dropout": "no"
    },
    "Basis": {
        "basis": "poly",
        "degree": 2,
        "nDim": 2,
    },
    "regularization": {
        "name": "l1",
        "scaler": 0.001
    }
}


def compute_num_products(num_variables, degree):
    """
    Compute the number of polynomial products for a given number of variables and degree.
    
    Args:
        num_variables (int): The number of variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        int: The total number of polynomial products.
    """
    num_products = 0
    for d in range(1, degree + 1):
        num_products += sum(1 for _ in itertools.combinations_with_replacement(range(num_variables), d))
    return num_products


def compute_products_batch(tensor, degree):
    """
    Compute all polynomial products up to a given degree for a batch of input tensors using PyTorch operations.
    
    Args:
        tensor (torch.Tensor): The input tensor with shape [batch_size, num_variables].
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        torch.Tensor: A tensor containing all polynomial products for the batch, shape [batch_size, num_products].
    """
    batch_size, num_variables = tensor.shape
    num_products = compute_num_products(num_variables, degree)
    products = torch.empty(batch_size, num_products, device=tensor.device)

    idx = 0
    for d in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(range(num_variables), d):
            product = torch.prod(tensor[:, list(combo)], dim=1)
            products[:, idx] = product
            idx += 1

    return products


class MLP(nn.Module):
    """ 
    linearly growing size
    """
    def __init__(self, inputShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP, self).__init__()

        print(dropout)
        assert dropout == 'no'

        self.inputShape = inputShape
        self.outputShape = outputShape

        self.delta = (inputShape - outputShape) // 3
        dim1 = inputShape - self.delta
        dim2 = dim1 - self.delta

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, dim1),
            nn.LeakyReLU(),
            nn.Linear(dim1, dim2),
            nn.LeakyReLU(),
            nn.Linear(dim2, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                layer.bias.data.fill_(0.)
                

class MLP2(nn.Module):
    """
    constant size
    """
    def __init__(self, inputShape:int, latentShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP2, self).__init__()

        self.inputShape = inputShape
        self.latentShape = latentShape
        self.outputShape = outputShape

        print(dropout)
        assert dropout == 'no'

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                layer.bias.data.fill_(0.)
                              
                
class GN_edge_GAM(MessagePassing):
    """ 
    Message passing neural network in which the message passing
    only considers the edges features
    """
    def __init__(self, d, aggr:str='add'):
        
        super(GN_edge_GAM, self).__init__(aggr=aggr)

        self.inputShape = d['input_shape']
        self.edgeShape = d['edges_shape']
        self.outputShape = d['output_shape']
        
        self.degreePoly = d['Basis']['degree']
        self.basis = d['Basis']['basis']
        self.nbDim = d['Basis']['nDim']
        self.messageShape = compute_num_products(d['edges_shape'], d['Basis']['degree']) * d['Basis']['nDim']
        
        self.hiddenShape = d['MLP_message']['hidden_shape']
        
        self.messageMLP = MLP2(inputShape = self.edgeShape, 
                               latentShape = self.hiddenShape, 
                               outputShape = self.messageShape, 
                               dropout = d['MLP_message']['dropout'])
        
        self.norm = torch.nn.LayerNorm(self.nbDim + self.inputShape)
        
        self.updateMLP = MLP(inputShape = self.nbDim + self.inputShape, 
                             outputShape = self.outputShape,
                             dropout = d['MLP_update']['dropout'])

    
    
    def forward(self, x:torch.tensor, edge_index:torch.tensor, edge_attr: torch.tensor):

        out = self.propagate(edge_index, size=(x.size(0), x.size(0)), edge_attr=edge_attr, x=x)
        
        return out
      
    def message(self, x_i:torch.tensor, x_j:torch.tensor, edge_attr: torch.tensor):
        """ 
        Perfomrs the message passing in the graph neural network
        
        Args:
        -----
        - `x_i`: tensor associated to node i
        - `x_j`: tensor associated to node j
        """

        # compute the products
        prods = compute_products_batch(edge_attr, self.degreePoly)

        shape = prods.shape
        prods = prods.repeat(1, self.nbDim).reshape(self.nbDim * shape[0], shape[1])

        # obtain the weights
        weights = self.messageMLP(edge_attr).reshape(prods.shape)

        res = torch.sum(weights * prods, axis = -1)

        res = res.reshape(int(res.shape[0]/self.nbDim), self.nbDim)

        return res
    
    def update(self, aggr_out:torch.tensor, x:torch.tensor):
        """ 
        Function to update all the nodes after the aggregation
        
        Args:
        -----
        - `aggr_out`: result after the aggregation [# Nodes, messageShape]
        - `x`: current node [1, inputShape]
        """

        xVal = self.norm(torch.cat([x, aggr_out], dim=-1))
        return self.updateMLP(xVal) 

 
class GAM_GNN(nn.Module):
    def __init__(self, d):
        """ 
        Neural network to combine everything
        
        Args:
        -----
        - `inShape`: shape of the input vector
        - `latentShape`: shape of the latent space
        - `outShape`: shape of the output vector
        - `messageShape`: shape of the message in the GN
        - `hiddenGN`: shape of the hidden layers in the MLP of the GN
        """
        
        super().__init__()
        
        self.inShape = d['input_shape']
        self.outShape = d['output_shape']

        self.regu = d['regularization']['name']
        self.scaler_regu = d['regularization']['scaler']
        
        
        # GNN
        self.GNN = GN_edge_GAM(d)
        
    def forward(self, graph):
        """ 
        
        Args:
        -----
        - `x`: value for the nodes [# Nodes, #Timesteps x inShape]
        - `edge_index`
        - `edge_attr`
        """

        x = graph.x
        edge_index = graph.edge_index
        edge_attr = graph.edge_attr
        
        
        # encoder part
        #y = self.applyEnc(x)
        
        # gnn part
        y = self.GNN(x, edge_index, edge_attr)                                     # [#Nodes, outSHpae]

        return y


    def applyEnc(self, x):                                                    

        return x
    

    def L1Reg(self, graph):
        atr = graph.edge_attr

        messages = self.GNN.message(None, None, atr)

        loss = self.scaler_regu * torch.sum(torch.abs(messages)) / graph.edge_index[0, :].shape[0]

        return loss


def loadNetwork(d = None):
    if d is None:
        d = GAM_CFG
    print('>>>> loading simplest')
    print('INFO >>> with NO encoder')
    print('INFO >>> with NO dropout')
    net = GAM_GNN(d)

    return net


In [25]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import itertools


GAM_CFG = {
    "input_shape": 8,
    "edges_shape": 5,
    "output_shape": 2,
    "MLP_message": {
        "hidden_shape": 128,
        "dropout": "no"
    },
    "MLP_update": {
        "hidden_shape": 128,
        "dropout": "no"
    },
    "Basis": {
        "basis": "poly",
        "degree": 2,
        "nDim": 2,
    },
    "regularization": {
        "name": "l1",
        "scaler": 0.001
    }
}


def compute_num_products(num_variables, degree):
    """
    Compute the number of polynomial products for a given number of variables and degree.
    
    Args:
        num_variables (int): The number of variables.
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        int: The total number of polynomial products.
    """
    num_products = 0
    for d in range(1, degree + 1):
        num_products += sum(1 for _ in itertools.combinations_with_replacement(range(num_variables), d))
    return num_products


def compute_products_batch(tensor, degree):
    """
    Compute all polynomial products up to a given degree for a batch of input tensors using PyTorch operations.
    
    Args:
        tensor (torch.Tensor): The input tensor with shape [batch_size, num_variables].
        degree (int): The maximum degree of the polynomial terms.

    Returns:
        torch.Tensor: A tensor containing all polynomial products for the batch, shape [batch_size, num_products].
    """
    batch_size, num_variables = tensor.shape
    num_products = compute_num_products(num_variables, degree)
    products = torch.empty(batch_size, num_products, device=tensor.device)

    idx = 0
    for d in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(range(num_variables), d):
            product = torch.prod(tensor[:, list(combo)], dim=1)
            products[:, idx] = product
            idx += 1

    return products


class MLP(nn.Module):
    """ 
    linearly growing size
    """
    def __init__(self, inputShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP, self).__init__()

        assert dropout == 'no'

        self.inputShape = inputShape
        self.outputShape = outputShape

        self.delta = (inputShape - outputShape) // 3
        dim1 = inputShape - self.delta
        dim2 = dim1 - self.delta

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, dim1),
            nn.LeakyReLU(),
            nn.Linear(dim1, dim2),
            nn.LeakyReLU(),
            nn.Linear(dim2, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                layer.bias.data.fill_(0.)
                

class MLP2(nn.Module):
    """
    constant size
    """
    def __init__(self, inputShape:int, latentShape:int, outputShape:int, dropout:float = 0.3):
        super(MLP2, self).__init__()

        self.inputShape = inputShape
        self.latentShape = latentShape
        self.outputShape = outputShape

        assert dropout == 'no'

        self.mlp = nn.Sequential(
            nn.Linear(inputShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, latentShape),
            nn.LeakyReLU(),
            nn.Linear(latentShape, outputShape),
        )
        
        self.init_weights()
    
    def forward(self, x):
        x = self.mlp(x)
        return x
    
    def init_weights(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='leaky_relu')
                layer.bias.data.fill_(0.)
                              
                
class GN_edge_GAM(MessagePassing):
    """ 
    Message passing neural network in which the message passing
    only considers the edges features
    """
    def __init__(self, d, aggr:str='add'):
        
        super(GN_edge_GAM, self).__init__(aggr=aggr)

        self.inputShape = d['input_shape']
        self.edgeShape = d['edges_shape']
        self.outputShape = d['output_shape']
        
        self.degreePoly = d['Basis']['degree']
        self.basis = d['Basis']['basis']
        self.nbDim = d['Basis']['nDim']
        self.messageShape = compute_num_products(d['edges_shape'], d['Basis']['degree']) * d['Basis']['nDim']
        
        self.hiddenShape = d['MLP_message']['hidden_shape']
        
        self.messageMLP = MLP2(inputShape = self.edgeShape, 
                               latentShape = self.hiddenShape, 
                               outputShape = self.messageShape, 
                               dropout = d['MLP_message']['dropout'])
        
        self.norm = torch.nn.LayerNorm(self.nbDim + self.inputShape)
        
        self.updateMLP = MLP(inputShape = self.nbDim + self.inputShape, 
                             outputShape = self.outputShape,
                             dropout = d['MLP_update']['dropout'])

    
    
    def forward(self, x:torch.tensor, edge_index:torch.tensor, edge_attr: torch.tensor):

        out = self.propagate(edge_index, size=(x.size(0), x.size(0)), edge_attr=edge_attr, x=x)
        
        return out
      
    def message(self, x_i:torch.tensor, x_j:torch.tensor, edge_attr: torch.tensor):
        """ 
        Perfomrs the message passing in the graph neural network
        
        Args:
        -----
        - `x_i`: tensor associated to node i
        - `x_j`: tensor associated to node j
        """

        # compute the products
        prods = compute_products_batch(edge_attr, self.degreePoly)
        shape = prods.shape
        prods = prods.repeat(1, self.nbDim).reshape(self.nbDim * shape[0], shape[1])

        # obtain the weights
        weights = self.messageMLP(edge_attr).reshape(prods.shape)
        print(weights.shape)

        res = torch.sum(weights * prods, axis = -1)

        res = res.reshape(int(res.shape[0]/self.nbDim), self.nbDim)

        return res
    
    def update(self, aggr_out:torch.tensor, x:torch.tensor):
        """ 
        Function to update all the nodes after the aggregation
        
        Args:
        -----
        - `aggr_out`: result after the aggregation [# Nodes, messageShape]
        - `x`: current node [1, inputShape]
        """

        xVal = self.norm(torch.cat([x, aggr_out], dim=-1))
        return self.updateMLP(xVal) 

 
class GAM_GNN(nn.Module):
    def __init__(self, d):
        """ 
        Neural network to combine everything
        
        Args:
        -----
        - `inShape`: shape of the input vector
        - `latentShape`: shape of the latent space
        - `outShape`: shape of the output vector
        - `messageShape`: shape of the message in the GN
        - `hiddenGN`: shape of the hidden layers in the MLP of the GN
        """
        
        super().__init__()
        
        self.inShape = d['input_shape']
        self.outShape = d['output_shape']

        self.regu = d['regularization']['name']
        self.scaler_regu = d['regularization']['scaler']
        
        
        # GNN
        self.GNN = GN_edge_GAM(d)
        
    def forward(self, graph):
        """ 
        
        Args:
        -----
        - `x`: value for the nodes [# Nodes, #Timesteps x inShape]
        - `edge_index`
        - `edge_attr`
        """

        x = graph.x
        edge_index = graph.edge_index
        edge_attr = graph.edge_attr
        
        
        # encoder part
        #y = self.applyEnc(x)
        
        # gnn part
        y = self.GNN(x, edge_index, edge_attr)                                     # [#Nodes, outSHpae]

        return y


    def applyEnc(self, x):                                                    

        return x
    

    def L1Reg(self, graph):
        atr = graph.edge_attr

        messages = self.GNN.message(None, None, atr)

        loss = self.scaler_regu * torch.sum(torch.abs(messages)) / graph.edge_index[0, :].shape[0]

        return loss


def loadNetwork(d = None):
    if d is None:
        d = GAM_CFG
    print('>>>> loading simplest')
    print('INFO >>> with NO encoder')
    print('INFO >>> with NO dropout')
    net = GAM_GNN(d)

    return net


In [2]:
import sys

def path_link(path:str):
    sys.path.append(path)

path_link('/master/code/lib')

import utils.testing_gen as gen

yessss


In [3]:
NB_SIM = 10

params = gen.Parameters_Simulation()
data = gen.get_mult_data(params, NB_SIM)

v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 96.43it/s] 


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 111.03it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 109.63it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 80.62it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 86.93it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 101.42it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 78.17it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 93.36it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 103.17it/s]


v0:60, tau:3.5, k:70, epsilon:0.5


100%|██████████| 149/149 [00:01<00:00, 107.32it/s]


In [4]:
graphs = gen.sims2Graphs(data)

In [5]:
print(data.shape)

(10, 150, 200, 2)


In [26]:
nn = loadNetwork()


>>>> loading simplest
INFO >>> with NO encoder
INFO >>> with NO dropout


In [27]:
a = nn(graphs[0])

torch.Size([228, 20])
