In [2]:
import numpy as np 
from sklearn.mixture import GaussianMixture
import itertools
import pandas as pd 
import torch
import networkx as nx 

In [67]:
class GLMGraph:
    """
    Generate a Graph with conditional edge distribution modeled via GLM's.
    Graph should have: x, y: Node features and target  
    edge_index: all edges in the graph 
    pos_edge_label_index: A subset of edge index (typically half)
    neg_edge_label_index: A subset of all non-existing edges (same length as pos_edge_label_index)
    No Edge Attributes considered. 
    """
    def __init__(self, n_nodes:int):
        
        self.Graph = None # an NetworkX representation 
        self.n_nodes = n_nodes
        
        self.num_features = None 
        self.cat_features = None 
        self.x = None
        self.y = None
        self.long_data = None
        
        self.edge_index = None
        self.edge_pos_edge_label_index = None
        self.edge_neg_edge_label_index= None
        
    def add_num_features(self, n_c:int, mu:list, sigma:list, w:list):
        """
        :param n_c: number of cluster components 
        :param mu: list of tuples corresponding to 
        num of features (tuple-length) and number of components (list-length) 
        :param sigma: Covariance matrix 
        :param w: mixture weights 
        :return: None 
        
        Add numeric features. Only used within the class. 
        """

        gmm = GaussianMixture(n_components=n_c, covariance_type='full') 
        # 'full': each component has its own general covariance matrix.
        gmm.means_ = np.array(mu)
        gmm.covariances_ = np.array(sigma)
        gmm.weights_ = np.array(w)
        gmm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(covariances))
        
        data, _ = gmm.sample(self.n_nodes)
        
        self.num_features = data
        
        
    def add_cat_features(self,k:list, p:list):
        """
        :param k: k classes (per features) --> like: [2,3,2]
        :param p: proportions (per feature) --> like: [[0.5,0.5], [0.1,0.4,0.5], [0.3,0.7]]
        :return: None
        pd.get_dummies(df['category'])
        """
        n = self.n_nodes
        m = len(k)
        data = np.zeros((n, m)) 
        
        for i in range(m):
            col_i = np.random.choice(np.arange(k[i]),
                                     size=n,
                                     p=p[i])
            data[:, i] = col_i
        self.cat_features = data
    
    def set_x(self):
        if self.cat_features is not None:
            self.x =  np.concatenate((self.num_features, self.cat_features), axis=1)
        else:
            self.x = self.num_features
        
    def set_y(self,theta:list, eps:float, binary=False, treshhold=0.5):
        """
        :param theta: model params (no intercept)
        :param eps: variance of linear relationship 
        :param binary: if y should be binary target
        :return: None
        """
        lin_rel = np.dot(self.x, theta) + np.random.normal(0,eps,self.n_nodes)
        if binary:
            y = (1 / (1 + np.exp(-lin_rel))) >= treshhold
        else: 
            y = lin_rel
        self.y = y
        

    def set_connections(self, intercept, theta, eps, treshhold):
        """
        :param intercept: Base connectivity, should center the log_odds around 0
        :param theta: Model params (similar nodes should have higher coefficients)
        :param eps: model variance 
        :param treshhold: either a number or an array of uniform (0,1) distributed numbers of len(n_nodes) 
        """
        n_rows = int((self.n_nodes * (self.n_nodes-1)) / 2) # all permutations
        if self.cat_features is not None:
            m_cols = self.cat_features.shape[1] + self.num_features.shape[1]
        
        else:
            m_cols = self.num_features.shape[1]
            long = np.zeros((n_rows, m_cols)) 
        
            idx = list(range(self.n_nodes))
            idx_pairs = list(itertools.combinations(idx, 2))

            for i in range(m_cols):
                xi_pairs = list(itertools.combinations(self.x[:,i], 2))
                xi_delta = [abs(pair[0] - pair[1]) for pair in xi_pairs]
                long[:,i] = xi_delta
            
            colnames = [f"x{i+1}_delta" for i in range(m_cols)]
            df = pd.DataFrame(long,columns=colnames)
            df.index = idx_pairs
            
        log_odds = intercept + np.dot(df, theta) + np.random.normal(0,eps,self.n_nodes)
        df["Connections"] = 1 / (1 + np.exp(-log_odds)) >= treshhold
        
        self.long_data = df

['idx', 'x1_delta', 'x2_delta', 'x3_delta']

In [23]:
means = [(-2, -2),
         (0, 0),
         (3, 3)]
    
covariances = [np.array([[0.5, 0.1],
                         [0.1, 0.5]]),
               np.array([[0.8, -0.2],
                         [-0.2, 0.8]]),
               np.array([[0.3, 0.1],
                         [0.1, 0.3]])]
    
weights = [0.3, 0.4, 0.3]

k = [2,3,2]
p = [[0.5,0.5], [0.1,0.4,0.5], [0.3,0.7]]

In [69]:
gglm = GLMGraph(200)
print(gglm.x)
print(gglm.n_nodes)
gglm.add_num_features(n_c=3,mu=means,sigma=covariances,w=weights)
#gglm.add_cat_features(k, p)
gglm.set_x()
gglm.x
gglm.set_y([1.5,3], 3)
gglm.y
gglm.set_long_data()
gglm.long_data

None
200


Unnamed: 0,x1_delta,x2_delta
"(0, 1)",0.287287,0.904842
"(0, 2)",0.080666,0.543363
"(0, 3)",1.153828,2.548837
"(0, 4)",1.806862,0.888759
"(0, 5)",1.567002,0.991068
...,...,...
"(196, 198)",0.111928,0.910076
"(196, 199)",0.029981,0.049185
"(197, 198)",0.095352,0.333752
"(197, 199)",0.046558,0.527138


In [75]:
np.random.rand(100, 1) + 10


array([[10.07495222],
       [10.60603541],
       [10.41095196],
       [10.24719507],
       [10.99194585],
       [10.01317074],
       [10.72772855],
       [10.59346321],
       [10.81425776],
       [10.62147479],
       [10.9041581 ],
       [10.80165151],
       [10.39554083],
       [10.00366738],
       [10.53116226],
       [10.74152619],
       [10.26447715],
       [10.41329595],
       [10.15711723],
       [10.0427239 ],
       [10.04245299],
       [10.70145842],
       [10.01929727],
       [10.62639541],
       [10.25434541],
       [10.53406624],
       [10.96856425],
       [10.71122887],
       [10.08516683],
       [10.74254962],
       [10.40829581],
       [10.84644035],
       [10.17328457],
       [10.26356332],
       [10.70583841],
       [10.91114059],
       [10.95856984],
       [10.19627601],
       [10.26792431],
       [10.82269163],
       [10.71372588],
       [10.34994062],
       [10.75745653],
       [10.05445128],
       [10.80194172],
       [10