In [11]:
# Imports
%matplotlib inline
import functools
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit
import optax
import numpy as np
from os import path
from matplotlib import pyplot as plt
import torch
from jax import random as jrandom
from functools import partial
from jax import value_and_grad
from tqdm import tqdm

data_loc = 'C:/Users/arthu/OneDrive/CentraleSupelec/3A/Cours/Graphical Model/deepwalk-node2vec-comparison/BlogCatalog3/BlogCatalog-dataset/data/'


## Load Data

In [12]:
def load_data():
    with open(data_loc+'nodes.csv', 'r') as f:
        num_nodes = sum(1 for _ in f)
    with open(data_loc+'edges.csv', 'r') as f:
        num_edges = sum(1 for _ in f)
    
    adjacent_matrix = np.zeros((num_nodes,num_nodes))
    num_neighbors = np.zeros(num_nodes)
    with open(data_loc+'edges.csv', 'r') as f:
        for line in tqdm(f.readlines(),total=num_edges):
            i, j = line.strip().split(',')  # csv
            i,j = int(i) - 1,int(j) - 1
            adjacent_matrix[i,j] = num_neighbors[i] + 1
            adjacent_matrix[j,i] = num_neighbors[j] + 1

            num_neighbors[i] += 1
            num_neighbors[j] += 1
    adjacent_matrix = jnp.array(adjacent_matrix)
    return adjacent_matrix
adjacent_matrix = load_data()

100%|██████████| 333983/333983 [00:00<00:00, 341696.27it/s]


## Visualizing the Graph
To visualize the graph structure of the graph we created above, we will use the [`networkx`](networkx.org) library because it already has functions for drawing graphs.

We first convert the `jraph.GraphsTuple` to a `networkx.DiGraph`.

In [13]:
#The graph is too big to be displayed
#draw_jraph_graph_structure(bc_dataset)

In [14]:
rng_key = jrandom.key(1)

@jit
def get_next_node(params,_):
    current_node,rng_key,adjacent_matrix = params
    neighbors_mask = adjacent_matrix[current_node]
    n_neighbors = (jax.lax.sqrt(jnp.sum(neighbors_mask + 1)) - 1) / 2
    rng_key, subkey = jrandom.split(rng_key)
    next_neighbor_index = jrandom.randint(subkey, [1],0,n_neighbors)
    next_node = jnp.argwhere(neighbors_mask == next_neighbor_index,size=1)[0][0]
    return (next_node,subkey,adjacent_matrix),next_node


In [15]:
get_next_node(params=(0,rng_key,adjacent_matrix),_=None)[1]

Array(2770, dtype=int32)

In [16]:
@partial(jit,static_argnames=['walk_length'])
def get_walk(initial_node,walk_length,rng_key,adjacent_matrix):
    walk = jax.lax.scan(f=get_next_node,init=(initial_node,rng_key,adjacent_matrix),xs=None,length=walk_length)[1]
    return walk

In [17]:
get_walk(0,5,rng_key,adjacent_matrix)

Array([2770, 1110,    0, 5333, 4146], dtype=int32)

In [18]:
@partial(jit,static_argnames=['num_walks','walk_length'])
def gen_random_walk_tensor(initial_node,num_walks, walk_length,rng_key,adjacent_matrix ):
    walk_list = jnp.zeros((num_walks,walk_length),dtype=jnp.int32)
    for i in range(num_walks):
        rng_key,_ = jrandom.split(rng_key)
        walk_list = walk_list.at[i].set(get_walk(initial_node=initial_node,walk_length=walk_length,rng_key=rng_key,adjacent_matrix=adjacent_matrix))
    return walk_list

In [19]:
random_walks = gen_random_walk_tensor(0,10,10,rng_key,adjacent_matrix)
print(random_walks)

[[2520  968 2596 3003 4297 3935 1644  418 3317 1451]
 [1262    0 2460 3055  309    0  752 1159 3124 2155]
 [ 282 1488 1586  637 2149   25 1886  666  294 1225]
 [3560 2452 2799 3322  305 3851 2336 3851 1856  569]
 [ 584  738 2384 2420  866  808 1100 1225 3127  858]
 [ 644  813 3322 2927 1386 3418 3239 2451  857 1932]
 [4838  581 3560  651 2562 5581 1800 1273 1120    0]
 [5203 2326 1702 1680  110    0 5310 6323 5624 1473]
 [ 370 4373  985 1931  425 4255 2442  177    0 5333]
 [4835 3251    0 3338  409 3607 1437  445 2339  127]]


In [20]:
@partial(jit,static_argnames=['window_size','num_walks','walk_length'])
def generate_windows(random_walks,window_size,num_walks,walk_length):
    num_windows = walk_length + 1 - window_size
    windows = jnp.zeros((num_walks * num_windows, window_size), dtype=jnp.int32)
    for j in range(num_windows):
        windows = windows.at[num_walks * j:num_walks * (j + 1)].set(random_walks[:, j:j + window_size])
    return windows

windows = generate_windows(random_walks, 3,10,10)
print(windows.shape)

(80, 3)


In [21]:
@jit
def get_windows_dotproduct(windows, embedding):
    embedding_size = embedding.shape[1]
    # get the embedding of the initial node repeated num_windows times
    first_emb = embedding[windows[:, 0]]
    first_emb = jnp.expand_dims(first_emb, axis=1)  # Ajouter une nouvelle dimension
    # get the embedding of the remaining nodes in each window
    others_emb = embedding[windows[:, 1:]]
    others_emb = others_emb.reshape(windows.shape[0], -1, embedding_size)
    # result has same shape as others
    # Each element is the dot product between the corresponding node embedding
    # and the embedding of the first node of that walk
    # that is, result_{i, j} for random walk i and element j is v_{W_{i, 0}} dot v_{W_{i, j}}
    result = jnp.sum(first_emb * others_emb, axis=-1)
    return result

vg_get_windows_dotproduct = value_and_grad(get_windows_dotproduct,argnums=1)

embedding_jax = jax.random.normal(rng_key, shape=(12000, 300))
get_windows_dotproduct(windows, embedding_jax)

Array([[-1.37977467e+01, -2.85341167e+01],
       [-6.18821239e+00,  7.70775938e+00],
       [-9.64875507e+00,  1.55383568e+01],
       [ 9.40449810e+00,  1.19765043e+01],
       [ 1.07187233e+01, -1.41470375e+01],
       [-2.41891346e+01,  2.50357399e+01],
       [-6.48633862e+00, -2.54224358e+01],
       [ 3.90500412e+01,  1.35591278e+01],
       [-9.41652489e+00, -5.67467213e-02],
       [-2.20587254e+00,  2.82458382e+01],
       [ 4.07399321e+00,  1.35076714e+01],
       [-2.63694935e+01, -2.19778252e+01],
       [ 1.33413086e+01,  1.46388206e+01],
       [ 1.26078262e+01, -1.87169228e+01],
       [ 1.39328594e+01,  4.81882572e+00],
       [ 4.52557468e+00, -5.49393749e+00],
       [-4.09537029e+00, -1.66300621e+01],
       [ 5.59790611e+01, -1.63674603e+01],
       [ 1.19582491e+01, -9.06642914e-01],
       [ 7.78726387e+00, -1.95134163e-01],
       [-2.03491998e+00, -1.70777645e+01],
       [ 3.18227673e+01,  2.15255032e+01],
       [ 8.85899830e+00, -4.25069475e+00],
       [-1.

In [22]:
@jit
def compute_mean_log_sigmoid(windows, embedding):
    dot_product = get_windows_dotproduct(windows, embedding)
    sigmoid_output = 1 / (1 + jnp.exp(-dot_product))  # Sigmoid function
    log_sigmoid_output = jnp.log(sigmoid_output)  # Logarithm
    return jnp.mean(log_sigmoid_output)  # Mean

# Usage example:
mean_log_sigmoid = compute_mean_log_sigmoid(windows, embedding_jax)

In [23]:
def gen_negative_samples(amount, length, initial_node, number_of_nodes,rng_key):
  """Generates negative samples for a random walk process in JAX.

  Args:
    amount: Number of negative samples to generate.
    length: Length of each negative sample walk (number of nodes in the path).
    initial_node: Starting node for all negative samples.
    number_of_nodes: Total number of nodes in the graph.

  Returns:
    A JAX array of shape (amount, length) containing the negative samples.
  """
  negative_samples = jnp.zeros((amount, length), dtype=jnp.int32)  # Use jnp.int32 for node indices
  negative_samples = negative_samples.at[:, 0].set(initial_node)  # Set initial node efficiently
  rng_key, subkey = jrandom.split(rng_key)
  negative_samples = negative_samples.at[:, 1:].set(
      jrandom.randint(rng_key,shape=(amount, length - 1),minval = 0, maxval = number_of_nodes)
  )
  return negative_samples, rng_key

gen_negative_samples(amount=3, length=5, initial_node=0, number_of_nodes=2000,rng_key =rng_key)

(Array([[   0,   80, 1554,  832,  799],
        [   0,  670, 1990,  305, 1162],
        [   0,  479,  644,  798,  695]], dtype=int32),
 Array((), dtype=key<fry>) overlaying:
 [2441914641 1384938218])

In [24]:
import jax
import jax.numpy as jnp
from jax import random
import optax
from tqdm import tqdm


### TO FIX

def deepWalk(adjacent_matrix, walks_per_vertex, walk_length, window_size, embedding_size, num_neg, lr, epochs,rng_key ,eps = 1e-15):
    
    number_of_nodes = adjacent_matrix.shape[0]
    
    embedding = jax.random.normal(rng_key, shape=(number_of_nodes, embedding_size))
    optimizer = optax.sgd(learning_rate=lr)
    opt_state = optimizer.init(embedding)
    loss_history = {'pos': [], 'neg': [], 'total': []}

    def loss_fn(embedding, random_walks,negative_samples,num_walks, window_size, eps = 1e-8):
        windows = generate_windows(random_walks,window_size,num_walks,walk_length)
        dotproduct_positive = get_windows_dotproduct(windows, embedding)
        windows = generate_windows(negative_samples,window_size,num_neg,walk_length)
        dotproduct_negative = get_windows_dotproduct(windows, embedding)
        pos_loss = -jnp.log(jax.nn.sigmoid(dotproduct_positive) + eps).mean()
        neg_loss = -jnp.log(1 - jax.nn.sigmoid(dotproduct_negative) + eps).mean()
        return pos_loss + neg_loss

    grad_fn = jax.value_and_grad(loss_fn)
    
    def update(params, opt_state, random_walks,negative_samples ):
        loss, grads = grad_fn(params, random_walks,negative_samples, walks_per_vertex,window_size)
        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss

    for _ in range(epochs):
        nodes = np.arange(0,number_of_nodes)
        for n in tqdm(nodes):
            random_walks = gen_random_walk_tensor(n,walks_per_vertex,walk_length,rng_key,adjacent_matrix)
            negative_samples , rng_key = gen_negative_samples(
                num_neg,
                length=walk_length, 
                initial_node=n, 
                number_of_nodes=number_of_nodes,
                rng_key = rng_key
            )

            embedding, opt_state, loss = update(embedding, opt_state, random_walks,negative_samples)
            loss_history['total'].append(loss)

    return embedding, loss_history


In [25]:
deepWalk(adjacent_matrix, 5, 5, 5, 128, 2, 1e-5, 1,rng_key ,eps = 1e-15)

  4%|▎         | 378/10312 [00:16<07:24, 22.34it/s] 


KeyboardInterrupt: 

In [None]:
cumsum_vec = np.cumsum(np.insert(loss_history['total'], 0, 0)) 
window_width = 10
ma_vec = (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
plt.plot(ma_vec)

NameError: name 'loss_history' is not defined

In [None]:
embedding

In [None]:
X_norm = (X-X.mean(axis=0)) / X.std(axis=0)

In [None]:
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression

X = embedding.detach().numpy()
y = bc_dataset['labels']

clf = LogisticRegression(random_state=0, multi_class='ovr').fit(X_norm, y)
y_hat = clf.predict(X_norm)
f1_score(y, y_hat, average='micro')

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
X_t = pca.fit_transform(X)


In [None]:
plt.scatter(X_t[:, 0], X_t[:, 1], c=y)