In [5]:
# Imports
%matplotlib inline
import functools
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit
import optax
import numpy as np
from os import path
from matplotlib import pyplot as plt
import torch
from jax import random as jrandom
from functools import partial
from jax import value_and_grad
from tqdm import tqdm

data_loc = 'C:/Users/arthu/OneDrive/CentraleSupelec/3A/Cours/Graphical Model/deepwalk-node2vec-comparison/BlogCatalog3/BlogCatalog-dataset/data/'


## Load Data

In [13]:
def load_data():
    with open(data_loc+'nodes.csv', 'r') as f:
        num_nodes = sum(1 for _ in f)
    with open(data_loc+'edges.csv', 'r') as f:
        num_edges = sum(1 for _ in f)
    
    adjacent_matrix = np.zeros((num_nodes,num_nodes))
    num_neighbors = np.zeros(num_nodes)
    with open(data_loc+'edges.csv', 'r') as f:
        for line in tqdm(f.readlines(),total=num_edges):
            i, j = line.strip().split(',')  # csv
            i,j = int(i) - 1,int(j) - 1
            adjacent_matrix[i,j] = num_neighbors[i] + 1
            adjacent_matrix[j,i] = num_neighbors[j] + 1

            num_neighbors[i] += 1
            num_neighbors[j] += 1
    adjacent_matrix = jnp.array(adjacent_matrix)
    return adjacent_matrix
adjacent_matrix = load_data()

100%|██████████| 333983/333983 [00:02<00:00, 143733.50it/s]


[0. 0. 0. ... 0. 0. 0.]


## Visualizing the Graph
To visualize the graph structure of the graph we created above, we will use the [`networkx`](networkx.org) library because it already has functions for drawing graphs.

We first convert the `jraph.GraphsTuple` to a `networkx.DiGraph`.

In [3]:
#The graph is too big to be displayed
#draw_jraph_graph_structure(bc_dataset)

In [88]:
rng_key = jrandom.key(1)

@jit
def get_next_node(params,_):
    current_node,rng_key,adjacent_matrix = params
    neighbors_mask = adjacent_matrix[current_node]
    n_neighbors = (jax.lax.sqrt(jnp.sum(neighbors_mask + 1)) - 1) / 2
    rng_key, subkey = jrandom.split(rng_key)
    next_neighbor_index = jrandom.randint(subkey, [1],0,n_neighbors)
    next_node = jnp.argwhere(neighbors_mask == next_neighbor_index,size=1)[0][0]
    return (next_node,rng_key,adjacent_matrix),next_node


In [91]:
get_next_node(params=(0,rng_key,adjacent_matrix),_=None)[1]

Array(2770, dtype=int32)

In [110]:
@partial(jit,static_argnames=['walk_length'])
def get_walk(initial_node,walk_length,rng_key,adjacent_matrix):
    walk = jax.lax.scan(f=get_next_node,init=(initial_node,rng_key,adjacent_matrix),xs=None,length=walk_length)[1]
    return walk

In [111]:
get_walk(0,5,rng_key,adjacent_matrix)

Array([2770,  745,  499, 3907, 2241], dtype=int32)

In [122]:
def gen_random_walk_tensor(initial_node, walk_length, num_walks,rng_key,adjacent_matrix ):
    walk_batch = np.zeros((num_walks,walk_length))
    for i in range(num_walks):
        walk_batch[i] = get_walk(initial_node=initial_node,walk_length=walk_length,rng_key=rng_key,adjacent_matrix=adjacent_matrix)
    return walk_batch

In [127]:
walk = gen_random_walk_tensor(0,10,10,rng_key,adjacent_matrix)
print(walk)

[[2770.  745.  499. ... 1680.  879. 3660.]
 [2770.  745.  499. ... 1680.  879. 3660.]
 [2770.  745.  499. ... 1680.  879. 3660.]
 ...
 [2770.  745.  499. ... 1680.  879. 3660.]
 [2770.  745.  499. ... 1680.  879. 3660.]
 [2770.  745.  499. ... 1680.  879. 3660.]]


In [None]:
@jit
def gen_batch_random_walk(batch_walk, graph, initial_nodes, length, num_walks, rng_key):
    n_nodes = initial_nodes.shape[0]
    walk = jnp.zeros((num_walks, length), dtype=jnp.int32)
    batch_walk = jnp.zeros((num_walks * n_nodes, length), dtype=jnp.int32)
    for i, n in enumerate(initial_nodes):
        n = n.item()
        rng_key, subkey = jrandom.split(rng_key)
        sub_walk, rng_key = gen_random_walk_tensor(walk, graph, n, length, num_walks, subkey)
        batch_walk = batch_walk.at[num_walks * i:num_walks * (i + 1)].set(sub_walk)
    return walk, batch_walk , rng_key

rng_key = jrandom.key(0)
last_walk, batch_walk, rng_key = gen_batch_random_walk(walk, bc_dataset, jnp.array([0, 1]), 5, 3, rng_key)
print(batch_walk)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function gen_batch_random_walk at C:\Users\arthu\AppData\Local\Temp\ipykernel_5292\309023732.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument num_walks.
The error occurred while tracing the function gen_batch_random_walk at C:\Users\arthu\AppData\Local\Temp\ipykernel_5292\309023732.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.

In [None]:
def generate_windows(windows, random_walk, window_size):
    num_walks, walk_length = random_walk.shape
    num_windows = walk_length + 1 - window_size
    windows = jnp.zeros((num_walks * num_windows, window_size), dtype=jnp.int32)
    for j in range(num_windows):
        windows = windows.at[num_walks * j:num_walks * (j + 1)].set(random_walk[:, j:j + window_size])
    return windows

windows = jnp.zeros((1,1), dtype=int)
windows = generate_windows(windows,batch_walk, 3)
print(batch_walk.shape)
print(windows.shape)

(6, 5)
(18, 3)


In [None]:
@jit
def get_windows_dotproduct(windows, embedding):
    embedding_size = embedding.shape[1]
    # get the embedding of the initial node repeated num_windows times
    first_emb = embedding[windows[:, 0]]
    first_emb = jnp.expand_dims(first_emb, axis=1)  # Ajouter une nouvelle dimension
    # get the embedding of the remaining nodes in each window
    others_emb = embedding[windows[:, 1:]]
    others_emb = others_emb.reshape(windows.shape[0], -1, embedding_size)
    # result has same shape as others
    # Each element is the dot product between the corresponding node embedding
    # and the embedding of the first node of that walk
    # that is, result_{i, j} for random walk i and element j is v_{W_{i, 0}} dot v_{W_{i, j}}
    result = jnp.sum(first_emb * others_emb, axis=-1)
    return result

vg_get_windows_dotproduct = value_and_grad(get_windows_dotproduct,argnums=1)

embedding_jax = jax.random.normal(rng_key, shape=(12000, 300))
get_windows_dotproduct(windows, embedding_jax)

Array([[  9.989287  ,  14.3554125 ],
       [  0.10503864,  -8.012433  ],
       [  8.171717  ,  22.485273  ],
       [-14.368267  ,  19.513103  ],
       [ -4.4057894 ,  -8.634466  ],
       [-10.255943  ,   3.440627  ],
       [-11.403229  , -12.996873  ],
       [ -1.2675219 ,  17.343773  ],
       [ -6.708728  ,  -7.2031827 ],
       [-19.548513  ,  23.787937  ],
       [  8.799166  ,  29.11709   ],
       [-23.854403  ,  -1.1407948 ],
       [ 37.785007  , -10.838486  ],
       [  9.711826  , -32.64926   ],
       [ -3.9425364 ,   1.50566   ],
       [-11.725664  ,  11.1334    ],
       [  5.2066326 , -15.35693   ],
       [ -1.85783   , -27.490053  ]], dtype=float32)

In [None]:
@jit
def compute_mean_log_sigmoid(windows, embedding):
    dot_product = get_windows_dotproduct(windows, embedding)
    sigmoid_output = 1 / (1 + jnp.exp(-dot_product))  # Sigmoid function
    log_sigmoid_output = jnp.log(sigmoid_output)  # Logarithm
    return jnp.mean(log_sigmoid_output)  # Mean

# Usage example:
mean_log_sigmoid = compute_mean_log_sigmoid(windows, embedding_jax)

In [None]:
def gen_negative_samples(amount, length, initial_node, number_of_nodes,rng_key):
  """Generates negative samples for a random walk process in JAX.

  Args:
    amount: Number of negative samples to generate.
    length: Length of each negative sample walk (number of nodes in the path).
    initial_node: Starting node for all negative samples.
    number_of_nodes: Total number of nodes in the graph.

  Returns:
    A JAX array of shape (amount, length) containing the negative samples.
  """
  negative_samples = jnp.zeros((amount, length), dtype=jnp.int32)  # Use jnp.int32 for node indices
  negative_samples = negative_samples.at[:, 0].set(initial_node)  # Set initial node efficiently
  rng_key, subkey = jrandom.split(rng_key)
  negative_samples = negative_samples.at[:, 1:].set(
      jrandom.randint(rng_key,shape=(amount, length - 1),minval = 0, maxval = number_of_nodes)
  )
  return negative_samples, rng_key

gen_negative_samples(amount=3, length=5, initial_node=0, number_of_nodes=2000,rng_key =rng_key)

(Array([[   0,  825,  911, 1631, 1314],
        [   0,   28, 1399, 1362,  613],
        [   0, 1806,  630, 1652,  360]], dtype=int32),
 Array((), dtype=key<fry>) overlaying:
 [3000548268 4272618543])

In [None]:
def gen_batch_negative_samples(amount, length, initial_nodes, number_of_nodes,rng_key):
  """Generates negative samples for a random walk process in JAX for a batch of initial nodes.

  Args:
    amount: Number of negative samples to generate per initial node.
    length: Length of each negative sample walk (number of nodes in the path).
    initial_nodes: A JAX array of shape (batch_size,) containing initial nodes for each sample.
    number_of_nodes: Total number of nodes in the graph.

  Returns:
    A JAX array of shape (amount * batch_size, length) containing the negative samples.
  """

  # Expand initial_nodes to match amount (amount, batch_size)
  #initial_nodes = jnp.expand_dims(initial_nodes, axis=0).repeat(amount, axis=0)
  n_nodes = initial_nodes.shape[0]
  num_walks = amount
  sub_negative_sample = jnp.zeros((num_walks, length), dtype=jnp.int32)
  batch_negative_sample = jnp.zeros((num_walks * n_nodes, length), dtype=jnp.int32)
  for i, n in enumerate(initial_nodes):
      n = n.item()
      rng_key, subkey = jrandom.split(rng_key)
      sub_negative_sample, rng_key = gen_negative_samples(amount, length, n, number_of_nodes, subkey)
      batch_negative_sample = batch_negative_sample.at[amount * i:amount * (i + 1)].set(sub_negative_sample)
  return batch_negative_sample , rng_key


# Example usage
initial_nodes = jnp.array([0, 1])
a, b =gen_batch_negative_samples(amount=3, length=5, initial_nodes=initial_nodes, number_of_nodes=2000,rng_key = rng_key)
print(a)

[[   0 1180 1940  427  957]
 [   0 1528 1469  905   90]
 [   0 1610  587  474  515]
 [   1 1402   71 1354 1003]
 [   1  205 1527 1098  918]
 [   1  859 1883 1862 1594]]


In [None]:
def generate_batches(array, batch_size):
    """Yield successive batches of size `batch_size` from `array`."""
    for i in range(0, len(array), batch_size):
        yield array[i:i + batch_size]

gen = generate_batches(list(range(101)), 20)
for batch in gen:
    print(batch)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
[100]


In [None]:
from tqdm import tqdm
eps = 1e-15

def deepWalk(graph, walks_per_vertex, walk_length, window_size, embedding_size, num_neg, lr, epochs, batch_size):
    number_of_nodes = jnp.array([jnp.unique(graph.senders).size])
    
    embedding = (torch.randn(size=(number_of_nodes, embedding_size)))
    embedding = jax.random.normal(rng_key, shape=(number_of_nodes, embedding_size))
    embedding.requires_grad = True
    optimizer = torch.optim.SGD([embedding], lr=lr)
    loss_history = {'pos': [], 'neg': [], 'total': []}

    for _ in range(epochs):
        nodes = torch.tensor(list(graph.nodes), dtype=int)
        random.shuffle(nodes)
        node_loader = generate_batches(nodes, batch_size)
        n_batches = int(number_of_nodes / batch_size)
        for n in tqdm(node_loader, total=n_batches):
            random_walk = gen_batch_random_walk(graph, n, walk_length, walks_per_vertex)
            num_windows = walk_length + 1 - window_size

            # Positive Sampling
            # each row of windows is one window, we have B = walks_per_vertex*num_windows windows
            windows = generate_windows(random_walk, window_size)
            batch_dotproduct = get_windows_dotproduct(windows, embedding)
            # takes the sigmoid of the dot product to get probability, then
            # takes the loglik and average through all elements
            pos_loss = -torch.log(torch.sigmoid(batch_dotproduct)+eps).mean()
            # Negative Sampling
            negative_samples = gen_batch_negative_samples(
                amount=num_neg*walks_per_vertex, 
                length=walk_length, 
                initial_nodes=n, 
                number_of_nodes=number_of_nodes
            )
            windows = generate_windows(negative_samples, window_size)
            batch_dotproduct = get_windows_dotproduct(windows, embedding)
            neg_loss = -torch.log(1-torch.sigmoid(batch_dotproduct)+eps).mean()

            loss = pos_loss + neg_loss
            # Optimization
            loss.backward()
            loss_history['total'].append(loss.detach().numpy())
            loss_history['pos'].append(pos_loss.detach().numpy())
            loss_history['neg'].append(neg_loss.detach().numpy())
            optimizer.step()
            optimizer.zero_grad()  

    return embedding, loss_history

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import optax
from tqdm import tqdm


### TO FIX

def deepWalk(graph, walks_per_vertex, walk_length, window_size, embedding_size, num_neg, lr, epochs, batch_size,rng_key ,eps = 1e-15):
    
    number_of_nodes = int(jnp.unique(graph.senders).size)
    
    embedding = jax.random.normal(rng_key, shape=(number_of_nodes, embedding_size))
    optimizer = optax.sgd(learning_rate=lr)
    opt_state = optimizer.init(embedding)
    loss_history = {'pos': [], 'neg': [], 'total': []}

    def loss_fn(params, batch_walk, window_size, eps = 1e-8):
        windows = 0
        windows = generate_windows(windows, batch_walk, window_size)
        batch_dotproduct_positive = get_windows_dotproduct(windows, params)
        windows = generate_windows(windows,batch_negative_sample, window_size)
        batch_dotproduct_negative = get_windows_dotproduct(windows, params)
        pos_loss = -jnp.log(jax.nn.sigmoid(batch_dotproduct_positive) + eps).mean()
        neg_loss = -jnp.log(1 - jax.nn.sigmoid(batch_dotproduct_negative) + eps).mean()
        return pos_loss + neg_loss

    grad_fn = jax.value_and_grad(loss_fn)
    
    def update(params, opt_state, batch_walk, windows):
        

        loss, grads = grad_fn(params, batch_walk, window_size)


        updates, opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss

    for _ in range(epochs):
        nodes = jnp.array(list(graph.nodes), dtype=jnp.int32)
        nodes = jax.random.permutation(rng_key, nodes)
        node_loader = generate_batches(nodes, batch_size)
        n_batches = int(number_of_nodes / batch_size)
        batch_walk = 0
        for n in tqdm(node_loader, total=n_batches):
            walk, batch_walk , rng_key = gen_batch_random_walk(batch_walk, graph, n, walk_length, walks_per_vertex, rng_key)
            batch_negative_sample , rng_key = gen_batch_negative_samples(
                amount=num_neg * walks_per_vertex, 
                length=walk_length, 
                initial_nodes=n, 
                number_of_nodes=number_of_nodes,
                rng_key = rng_key
            )

            embedding, opt_state, loss = update(embedding, opt_state, batch_walk, windows)
            loss_history['total'].append(loss)
            print(loss_history['total'][-1])

    return embedding, loss_history


In [None]:
%prun embedding, loss_history = deepWalk(graph=bc_dataset, walks_per_vertex=1, walk_length=2, window_size=2, embedding_size=128, num_neg=2, lr=0.011, epochs=1, batch_size=256, rng_key = rng_key)

  2%|▎         | 1/40 [00:09<06:28,  9.97s/it]

8.603537


  5%|▌         | 2/40 [00:20<06:35, 10.40s/it]

9.006628


  8%|▊         | 3/40 [00:29<06:03,  9.81s/it]

8.740193


 10%|█         | 4/40 [00:37<05:22,  8.95s/it]

9.136781


 12%|█▎        | 5/40 [00:45<04:57,  8.50s/it]

9.134063


 15%|█▌        | 6/40 [00:57<05:32,  9.78s/it]

8.916544


 18%|█▊        | 7/40 [01:11<06:06, 11.10s/it]

9.042809


 20%|██        | 8/40 [01:23<06:04, 11.39s/it]

8.608324


 22%|██▎       | 9/40 [01:34<05:48, 11.23s/it]

9.828885


 25%|██▌       | 10/40 [01:46<05:51, 11.70s/it]

9.197929


 28%|██▊       | 11/40 [01:55<05:15, 10.87s/it]

9.046522


 30%|███       | 12/40 [02:03<04:38,  9.93s/it]

9.22105


 32%|███▎      | 13/40 [02:11<04:09,  9.25s/it]

9.068176


 35%|███▌      | 14/40 [02:18<03:47,  8.76s/it]

8.764362


 38%|███▊      | 15/40 [02:27<03:39,  8.79s/it]

9.310879


 40%|████      | 16/40 [02:41<04:07, 10.32s/it]

8.77195


 42%|████▎     | 17/40 [02:49<03:41,  9.62s/it]

8.622083


 45%|████▌     | 18/40 [02:57<03:18,  9.03s/it]

8.551861


 48%|████▊     | 19/40 [03:04<03:00,  8.60s/it]

8.176603


 50%|█████     | 20/40 [03:13<02:51,  8.55s/it]

9.464855


 52%|█████▎    | 21/40 [03:21<02:38,  8.32s/it]

7.831616


 55%|█████▌    | 22/40 [03:28<02:26,  8.15s/it]

9.729718


 57%|█████▊    | 23/40 [03:36<02:15,  7.99s/it]

9.454603


 60%|██████    | 24/40 [03:44<02:08,  8.04s/it]

9.109565


 62%|██████▎   | 25/40 [03:52<01:58,  7.93s/it]

9.694206


 65%|██████▌   | 26/40 [03:59<01:49,  7.82s/it]

8.481419


 68%|██████▊   | 27/40 [04:07<01:41,  7.81s/it]

8.555294


 70%|███████   | 28/40 [04:15<01:33,  7.78s/it]

8.92296


 72%|███████▎  | 29/40 [04:22<01:24,  7.71s/it]

8.980722


 75%|███████▌  | 30/40 [04:30<01:16,  7.68s/it]

9.1042385


 78%|███████▊  | 31/40 [04:38<01:09,  7.67s/it]

8.751317


 80%|████████  | 32/40 [04:46<01:03,  7.93s/it]

8.542926


 82%|████████▎ | 33/40 [04:54<00:54,  7.85s/it]

9.545682


 85%|████████▌ | 34/40 [05:02<00:46,  7.81s/it]

8.396461


 88%|████████▊ | 35/40 [05:10<00:39,  7.90s/it]

9.5626135


 90%|█████████ | 36/40 [05:18<00:31,  7.87s/it]

8.836658


 92%|█████████▎| 37/40 [05:25<00:23,  7.84s/it]

9.063208


 95%|█████████▌| 38/40 [05:33<00:15,  7.84s/it]

9.413374


 98%|█████████▊| 39/40 [05:41<00:07,  7.91s/it]

9.2698145


100%|██████████| 40/40 [05:49<00:00,  7.88s/it]

8.934503


41it [05:51,  8.58s/it]                        

9.0179405
 




         145703093 function calls (145656720 primitive calls) in 355.210 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   394017   56.329    0.000   68.471    0.000 dispatch.py:79(apply_primitive)
    20624   52.507    0.003   52.641    0.003 reductions.py:667(cumulative_reduction)
    10312   12.601    0.001   12.671    0.001 reductions.py:222(sum)
  4537824   11.434    0.000   24.877    0.000 dtypes.py:329(issubdtype)
  2026607   11.177    0.000   27.618    0.000 dtypes.py:639(dtype)
    52050   10.528    0.000   10.777    0.000 array_methods.py:259(deferring_binary_op)
    72266   10.063    0.000   23.827    0.000 lax_numpy.py:885(squeeze)
    10312    9.914    0.001  126.030    0.012 lax_numpy.py:1163(bincount)
 15174010    8.296    0.000   11.855    0.000 dtypes.py:315(_issubclass)
    10313    7.851    0.001    7.851    0.001 lax_numpy.py:1211(broadcast_arrays)
    72266    7.251    0.000  178.475    0.002 scatter.p

In [None]:
cumsum_vec = np.cumsum(np.insert(loss_history['total'], 0, 0)) 
window_width = 10
ma_vec = (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
plt.plot(ma_vec)

NameError: name 'loss_history' is not defined

In [None]:
embedding

In [None]:
X_norm = (X-X.mean(axis=0)) / X.std(axis=0)

In [None]:
from sklearn.metrics import f1_score
from sklearn.linear_model import LogisticRegression

X = embedding.detach().numpy()
y = bc_dataset['labels']

clf = LogisticRegression(random_state=0, multi_class='ovr').fit(X_norm, y)
y_hat = clf.predict(X_norm)
f1_score(y, y_hat, average='micro')

In [None]:
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
X_t = pca.fit_transform(X)


In [None]:
plt.scatter(X_t[:, 0], X_t[:, 1], c=y)