## Cleora: A Simple, Strong and Scalable Graph Embedding Scheme

https://arxiv.org/abs/2102.02302

https://github.com/Synerise/cleora

The aim of this notebook is to implement Cleora with torch. I've tested it on text data. The embedding computes in 15.7 sec on my Laptop.

In [1]:
import torch
import torch.nn.functional as F

import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_distances

from datasets import load_dataset
from transformers import GPT2Tokenizer

from tqdm import tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


### Prepare data

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [3]:
dataset = load_dataset('wikitext', 'wikitext-103-v1')

Found cached dataset wikitext (/Users/piotrgabrys/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00,  4.34it/s]


In [4]:
dataset = [tokenizer(i)['input_ids'] for i in tqdm(dataset['train']['text']) if len(i) > 20]

 36%|███▌      | 647871/1801350 [04:37<08:23, 2290.55it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1059 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 1801350/1801350 [12:25<00:00, 2415.97it/s]


In [5]:
dataset2 = []
for sentence in tqdm(dataset):
    if len(sentence) > 10:
        start = random.randint(0, len(sentence) - 10)
        end = start + 10
        dataset2.append(sentence[start:end])
    else:
        dataset2.append(sentence)

100%|██████████| 1081822/1081822 [00:04<00:00, 252803.61it/s]


In [6]:
df = pd.Series(dataset2).explode().to_frame().reset_index()
df.columns = ['sentence_id', 'token']
df.to_csv('sentences.csv', index=False)

In [7]:
df = pd.read_csv('sentences.csv')

In [8]:
le = LabelEncoder()
le.fit(df.token)
df.token = le.transform(df.token)

In [9]:
df = pd.merge(df, df, left_on='sentence_id', right_on='sentence_id')
df2 = df.groupby(['token_x', 'token_y']).count()
df2.reset_index(inplace=True)
df2.columns = ['token_x', 'token_y', 'cnt']
total_cnt = df2.groupby('token_x').sum()['cnt']
df2['total_cnt'] = df2.token_x.map(total_cnt)
df2['prob'] = df2.cnt / df2.total_cnt

In [10]:
df2.head()

Unnamed: 0,token_x,token_y,cnt,total_cnt,prob
0,0,0,4,40,0.1
1,0,2,1,40,0.025
2,0,15,1,40,0.025
3,0,51,1,40,0.025
4,0,55,2,40,0.05


### Initialize embeddings and create sparse adjacency matrix

In [11]:
emb_size = 100

indices = torch.tensor(df2.loc[:, ['token_x', 'token_y']].values).T
values = torch.tensor(df2.loc[:, 'prob'].values)
size = indices.max().item() + 1
transition_matrix = torch.sparse_coo_tensor(indices=indices, values=values, size=(size, size), dtype=torch.float32, requires_grad=False)
transition_matrix = transition_matrix.to_sparse_csr()

torch.random.manual_seed(42)
embedding = torch.randn((size, emb_size), dtype=torch.float32, requires_grad=False)
embedding = F.normalize(embedding, p=2, dim=1)
embedding = embedding.to_sparse_csr()

  transition_matrix = transition_matrix.to_sparse_csr()


In [12]:
del df, df2, total_cnt

### Compute embeddings

In [13]:
for i in range(5):
    print(i)
    embedding = transition_matrix @ embedding
    embedding = F.normalize(embedding.to_dense(), p=2, dim=1).to_sparse_csr()
embedding = embedding.to_dense()

0
1
2
3
4


### Explore embeddings

In [15]:
embedding = pd.DataFrame(embedding.numpy(), index=le.classes_)
words = [tokenizer.decode(i) for i in le.classes_]
embedding.index = words

In [16]:
embedding

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,90,91,92,93,94,95,96,97,98,99
',-0.078871,0.101386,-0.122783,-0.128704,0.013432,-0.090238,0.068177,0.072422,-0.012252,0.019230,...,0.006720,0.042587,0.086443,0.148699,0.060983,-0.105272,0.056365,0.101028,-0.084708,0.096045
",",-0.077256,0.111697,-0.114209,-0.130592,0.021396,-0.083489,0.059185,0.075306,-0.013261,0.027371,...,0.004992,0.045245,0.088574,0.153976,0.060455,-0.100197,0.057497,0.105100,-0.077266,0.095184
-,-0.078081,0.106146,-0.120282,-0.128122,0.017603,-0.087731,0.064212,0.076965,-0.015866,0.019608,...,0.006208,0.047714,0.088965,0.149447,0.057769,-0.102568,0.059099,0.106120,-0.082972,0.095846
.,-0.079771,0.102860,-0.121698,-0.126472,0.017008,-0.088676,0.064447,0.075584,-0.016992,0.018370,...,0.006672,0.049668,0.090899,0.147907,0.057339,-0.103476,0.059010,0.105894,-0.085030,0.097719
0,-0.079406,0.103333,-0.120526,-0.127985,0.017730,-0.087394,0.063860,0.074601,-0.014871,0.020180,...,0.006794,0.047096,0.090067,0.149544,0.059204,-0.103544,0.058020,0.103905,-0.082364,0.095828
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
amplification,-0.079749,0.097286,-0.127882,-0.123563,0.013753,-0.092153,0.068693,0.079774,-0.020657,0.009400,...,0.008155,0.052868,0.092107,0.143906,0.054667,-0.105253,0.060384,0.109164,-0.090111,0.097654
ominated,-0.078943,0.101755,-0.120527,-0.132504,0.013220,-0.090621,0.068843,0.071721,-0.007848,0.024456,...,0.004701,0.036491,0.082882,0.150605,0.065203,-0.105306,0.054782,0.098239,-0.081994,0.093821
regress,-0.078452,0.103818,-0.118927,-0.134839,0.012833,-0.090133,0.070644,0.069622,-0.003385,0.029162,...,0.003307,0.030045,0.078554,0.152287,0.068324,-0.104958,0.053035,0.095629,-0.079955,0.092081
Collider,-0.082631,0.086775,-0.138616,-0.106659,0.015481,-0.092726,0.064990,0.090261,-0.043482,-0.016299,...,0.014213,0.083852,0.110063,0.131640,0.033698,-0.105092,0.070027,0.124329,-0.103159,0.104426


In [17]:
dists = cosine_distances(embedding.loc[['king'], :].values - embedding.loc[['man'], :].values + embedding.loc[['woman'], :].values, embedding.values)

In [18]:
pd.Series(dists.ravel(), index=embedding.index).sort_values()[:50]

 code          0.000002
 plains        0.000002
inating        0.000003
 dem           0.000003
 tablets       0.000003
 kits          0.000004
 software      0.000004
 networking    0.000004
 affiliated    0.000004
 tracking      0.000004
inth           0.000004
 radio         0.000004
 Haitian       0.000004
 lact          0.000004
 lasers        0.000004
human          0.000004
inate          0.000004
 loop          0.000004
 bases         0.000005
 photo         0.000005
 theatrical    0.000005
verts          0.000005
path           0.000005
 telescopes    0.000005
uffs           0.000005
 text          0.000005
plane          0.000005
 multip        0.000005
 educ          0.000005
 corridor      0.000005
 secondary     0.000005
 lenses        0.000006
 targeting     0.000006
 Irish         0.000006
group          0.000006
analy          0.000006
 aph           0.000006
orders         0.000006
hers           0.000006
horse          0.000006
 languages     0.000006
etric          0