In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import numpy as np
import torch
import torch.nn
from transformers import BertModel, BertTokenizer
import faiss
import scipy
from scipy.sparse import coo_matrix
from matplotlib import pyplot as plt

In [None]:
# Laziness parameter for the matrix K built further down:
# K = I + L / rate with rate = -min(diag(L)) / (1 - gamma), so the smallest
# diagonal entry of K equals gamma.
gamma = 0.1

In [None]:
# load embeddings and get knn

# Word-embedding table of the pretrained BERT model, as a NumPy array.
model = BertModel.from_pretrained("bert-base-uncased")
embeds = model.embeddings.word_embeddings.weight.detach().cpu().numpy()

# Scale every row to unit L2 norm so that inner products are cosine similarities.
norms = np.linalg.norm(embeds, axis=1, keepdims=True)
embeds_normalized = embeds / norms

print("Constructing nearest neighbor matrix...")

k = 10
embed_dim = embeds.shape[1]
index = faiss.IndexFlatIP(embed_dim)
index.add(embeds_normalized)
# Ask for k + 1 hits per row and drop the first column, which is expected to be
# the query row itself (a later cell asserts row_indices != col_indices).
distances, indices = (hits[:, 1:] for hits in index.search(embeds_normalized, k + 1))

In [None]:
# get vocab
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# `vocab` and `unused` are used positionally against the embedding rows (e.g.
# `distances[:, 0][~unused]`), so position i must hold the token with id i.
# get_vocab() returns a token -> id mapping; ordering the tokens by id makes the
# alignment explicit instead of relying on the mapping's key order.
token_to_id = tokenizer.get_vocab()
vocab = np.array(sorted(token_to_id, key=token_to_id.get))
# Boolean mask over token ids: True for the reserved "[unusedN]" placeholders.
unused = np.array(['[unused' in key for key in vocab])

In [None]:
# plot knn
# Figure 1: histograms of the first and last neighbour-similarity columns
# (labelled 'closest' / 'furthest'), with unused tokens drawn separately.
fig_hist, ax_hist = plt.subplots(figsize=[4, 4])
first_col_sim = distances[:, 0]
last_col_sim = distances[:, -1]
_, bins, _ = ax_hist.hist(first_col_sim[~unused], bins=100, color='red', alpha=0.5, label='closest');
_, bins, _ = ax_hist.hist(last_col_sim[~unused], bins=100, color='blue', alpha=0.5, label='furthest');
# Reuse the bin edges of the 'furthest' histogram so both share an x-grid.
ax_hist.hist(last_col_sim[unused], bins=bins, color='black', label='unused tokens');
ax_hist.set_ylabel("frequency")
ax_hist.set_xlabel("neighbour similarity")
ax_hist.set_xlim(0, 1)
ax_hist.legend()

# Figure 2: smallest neighbour similarity of every token, in token-id order.
fig_min, ax_min = plt.subplots(figsize=[6, 2])
ax_min.plot(distances.min(axis=-1), color='black')
ax_min.set_xlabel("token number")
ax_min.set_ylabel("min neighbour sim")
ax_min.set_ylim(0, 1)

In [None]:
# Assemble the kNN graph as a sparse matrix L: off-diagonal entries are the
# neighbour similarities, and each diagonal entry is minus its row's similarity
# sum, so every row of L sums to zero (up to rounding).
num_nodes = embeds.shape[0]
row_indices = np.arange(num_nodes).repeat(k)
col_indices = indices.reshape(-1)
dot_products = distances.reshape(-1)
assert np.all(dot_products > 0)
assert np.all(row_indices != col_indices)

# Append the diagonal entries and scale everything by the largest row sum.
diag_idx = np.arange(num_nodes)
row_indices = np.concatenate([row_indices, diag_idx])
col_indices = np.concatenate([col_indices, diag_idx])
rates = distances.sum(axis=-1)
dot_products = np.concatenate([dot_products, -rates]) / rates.max()

sparse_matrix = coo_matrix(
    (dot_products, (row_indices, col_indices)),
    shape=(num_nodes, num_nodes),
)
sparse_matrix_csr = sparse_matrix.tocsr()

print("Finished constructing nearest neighbor matrix...")

L = sparse_matrix_csr
# Dividing by `rate` turns the most negative diagonal entry of L into
# -(1 - gamma), so the smallest diagonal entry of K = I + L / rate is gamma.
rate = -L.diagonal().min() / (1 - gamma)
K = L / rate + scipy.sparse.eye(L.shape[0])

In [None]:
# Random dense test vector with one entry per embedding row.
g = np.random.randn(len(embeds))
# Benchmark: apply K four times as successive sparse matrix-vector products.
%timeit K @ (K @ (K @ (K @ g)))

# Benchmark: form the fourth power of K once, then apply it in a single product.
K_power = K @ K @ K @ K
%timeit K_power @ g

#### data

In [None]:
from omegaconf import OmegaConf
import data
# Load the run configuration and build the train/test dataloaders from it.
cfg = OmegaConf.load('configs/basic_language.yaml')
train_dataloader, test_dataloader = data.get_dataloaders(cfg)

# Pull one batch and decode its first sequence as a quick sanity check.
# NOTE(review): assumes each batch is a mapping whose 'input_ids' entry holds
# token ids compatible with the BERT tokenizer loaded above - confirm in `data`.
datum = next(iter(train_dataloader))
tokenizer.decode(datum['input_ids'][0])

In [None]:
# Benchmark: gather 128 * 1000 randomly chosen rows of K, densify them, and
# reshape the result to (128, 1000, K.shape[1]).
inds = np.random.randint(len(embeds), size=(128, 1000))
%timeit K[inds.ravel(), :].toarray().reshape(*inds.shape, K.shape[1])

In [None]:
inds = np.random.randint(len(embeds), size=(128, 1000))
%timeit (K@K[:, inds.ravel()]).toarray().reshape(K.shape[0], *inds.shape)
%timeit K@(K[:, inds.ravel()].toarray()).reshape(K.shape[0], *inds.shape)

In [None]:
def _scipy_to_torch_coo(matrix):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor.

    Args:
        matrix: any SciPy sparse matrix (anything exposing ``.tocoo()``).

    Returns:
        A ``torch.sparse_coo_tensor`` with the same shape, coordinates and
        values (value dtype is inherited from the SciPy data array).
    """
    coo = matrix.tocoo()
    # torch expects int64 coordinates stacked as a (2, nnz) tensor.
    coords = torch.stack(
        [
            torch.from_numpy(coo.row.astype(np.int64)),
            torch.from_numpy(coo.col.astype(np.int64)),
        ],
        dim=0,
    )
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(coords, values, size=coo.shape)


# Build K^0 (identity), K^1, ..., K^num_powers as SciPy sparse matrices.
num_powers = 20
current_prod = scipy.sparse.eye(K.shape[0])
K_powers = [current_prod]
for _ in range(num_powers):
    current_prod = current_prod @ K
    K_powers.append(current_prod)

# Convert every entry, including the last one: the list has num_powers + 1
# elements, so looping over range(num_powers) would leave K^num_powers as a
# SciPy matrix. Doing the conversion inside a helper also keeps the temporaries
# local, so the notebook-level names `data` (imported module) and `indices`
# (kNN neighbour ids) are no longer overwritten.
K_powers = [_scipy_to_torch_coo(power) for power in K_powers]