In [None]:
import cola
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import numpy as np
import torch
import torch.nn
from transformers import BertModel, BertTokenizer, BertForMaskedLM
import faiss
import scipy
from scipy.sparse import coo_matrix, coo_array
from matplotlib import pyplot as plt

In [None]:
# Laziness parameter: used below as rate = -min(diag L) / (1 - gamma), which keeps
# every diagonal entry of K = I + L / rate at least gamma.
gamma: float = 0.1

In [None]:
# load embeddings and get knn

# Token embeddings: the MLM head's output (decoder) weight matrix, one row per token id.
# NOTE(review): the commented alternative, input embeddings, lives at
# model.bert.embeddings.word_embeddings.weight on BertForMaskedLM.
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
embeds = model.cls.predictions.decoder.weight.detach().cpu().numpy()

# Unit-normalise so that inner product == cosine similarity.
# (np.maximum(1, norms) would instead only shrink vectors longer than 1.)
norms = np.linalg.norm(embeds, axis=1, keepdims=True)
embeds_normalized = embeds / norms

# Vocabulary ordered explicitly by token id, so vocab[i] names row i of `embeds`
# (rather than relying on the dict's insertion order).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
token_to_id = tokenizer.get_vocab()
vocab = np.array(sorted(token_to_id, key=token_to_id.get))
assert np.array_equal(np.sort(list(token_to_id.values())), np.arange(len(embeds))), \
    "token ids must be exactly 0..len(embeds)-1"
unused = np.array(['[unused' in key for key in vocab])
print("Constructing nearest neighbor matrix...")

k = 10
DIGITS = set("0123456789")  # ASCII digits only; str.isdigit() would also accept e.g. '²'
is_unused = np.array([t.startswith("[") and t.endswith("]") for t in vocab])
is_suffix = np.array([t.startswith("##") for t in vocab])
is_number = np.array([set(t) <= DIGITS for t in vocab])
is_normal = ~(is_unused | is_suffix | is_number)

# Neighbours are searched only within each token group; indices are integer token ids.
indices = np.empty([len(embeds), k], dtype=np.int64)
distances = np.empty([len(embeds), k])
range_ = np.arange(len(embeds))
for mask in tqdm([is_unused, is_suffix, is_number, is_normal]):
    group = embeds_normalized[mask]
    # faiss pads results with -1 when fewer than k+1 vectors exist, which would
    # silently map to the wrong token via range_[mask][-1].
    assert len(group) > k, "each token group needs more than k members for a k-NN search"
    index = faiss.IndexFlatIP(embeds.shape[1])
    index.add(group)
    # k+1 results: column 0 is normally the query token itself and is dropped.
    distances_temp, indices_temp = index.search(group, k + 1)
    distances[mask] = distances_temp[:, 1:]
    indices[mask] = range_[mask][indices_temp[:, 1:]]

In [None]:
max_ = 1 # max_ = 2.1
# Axis limit for similarity plots. NOTE(review): the squaring and the 2.1
# alternative look like leftovers from a squared-L2 variant — confirm.
sim_lim = max_**2

# Closest vs k-th closest neighbour similarity per token
# ('[unused' tokens shown separately).
plt.figure(figsize=[4, 4])
plt.hist(distances[:, 0][~unused], bins=100, color='red', alpha=0.5, label='closest')
_, bins, _ = plt.hist(distances[:, -1][~unused], bins=100, color='blue', alpha=0.5, label='furthest')
plt.hist(distances[:, -1][unused], bins=bins, color='black', label='unused tokens')
plt.ylabel("frequency")
plt.xlabel("neighbour similarity")
plt.xlim(0, sim_lim)
plt.legend()

# Smallest of the k neighbour similarities, per token id.
plt.figure(figsize=[6, 2])
plt.plot(distances.min(-1), color='black')
plt.xlabel("token number")
plt.ylabel("min neighbour sim")
plt.ylim(0, sim_lim)

# Raw (pre-normalisation) embedding norm per token id.
plt.figure(figsize=[6, 2])
plt.plot(norms, color='black')
plt.xlabel("token number")
plt.ylabel("embedding norm")  # previously mislabelled "min neighbour sim"
plt.show()

In [None]:
# examples: k nearest neighbours of a few random tokens, as a sanity check.
# Seeded so Restart & Run All shows the same tokens every time.
example_rng = np.random.default_rng(0)
inds = example_rng.integers(len(embeds), size=5)
for ind in inds:
    neighbours = vocab[indices[ind].astype(int)]
    print(tokenizer.decode([ind]), ":", ' '.join(neighbours))

In [None]:
# Off-diagonal entries of the generator: token i -> each of its k neighbours,
# weighted by the neighbour similarity.
n_tokens = len(embeds)
row_indices = np.arange(n_tokens).repeat(k)
col_indices = indices.reshape(-1)
dot_products = distances.reshape(-1)
assert np.all(dot_products > 0)
assert np.all(row_indices != col_indices)

# Diagonal entries: minus each token's total outgoing rate, so every row sums to zero.
rates = distances.sum(axis=-1)
diag = np.arange(n_tokens)
row_indices = np.concatenate([row_indices, diag])
col_indices = np.concatenate([col_indices, diag])
# Normalise so the largest exit rate is exactly 1.
dot_products = np.concatenate([dot_products, -rates]) / rates.max()

class all_ones(cola.ops.operator_base.LinearOperator):
    """Matrix-free cola operator whose every entry is 1 (the outer product 1 1^T).

    Construct as ``all_ones(dtype, (n_rows, n_cols))``.
    NOTE(review): assumes the torch backend (``keepdim`` / ``expand`` are torch APIs).
    """

    def _matmat(self, v):
        # (1 1^T) v: every output row equals the column sums of v, so the (1, m)
        # sum must be broadcast to (n_rows, m); returning it unexpanded gave the
        # wrong output shape.
        col_sums = v.sum(0, keepdim=True)
        return col_sums.expand(self.shape[0], col_sums.shape[-1])
    
dtype = torch.float32
device = 'cpu'
N = len(embeds)
# An earlier, commented-out variant built L with cola.ops.Sparse / cola.ops.Dense
# and optionally added a uniform-jump term weight * (ones - N * I) with
# weight = 1/20000; the live construction below uses scipy.sparse only.

# Sparse generator L, converted to CSR for fast row slicing and products.
sparse_matrix = coo_array((dot_products, (row_indices, col_indices)), shape=(embeds.shape[0], embeds.shape[0]))
sparse_matrix_csr = sparse_matrix.tocsr()
L = sparse_matrix_csr
# Uniformisation: scale so the largest exit rate becomes (1 - gamma), then
# K = I + L / rate, whose diagonal entries are all >= gamma.
rate = -(L.diagonal().min()) / (1 - gamma)
K = L / rate + scipy.sparse.eye(L.shape[0])

# Sanity checks: K should be a valid row-stochastic transition matrix.
row_sums = np.asarray(K.sum(axis=1)).ravel()
assert np.allclose(row_sums, 1.0), "rows of K must sum to 1"
assert K.min() >= 0, "K must be entrywise non-negative"



In [None]:
# Leading eigenpair of K, then the next eigenvalue by deflation: by Brauer's
# theorem, K - u u^T keeps the rest of the spectrum and moves the leading
# eigenvalue to lambda_1 - u^T u (0 if u is unit-norm).
# NOTE(review): K is a scipy sparse matrix here — confirm cola.linalg.eig
# accepts it directly rather than needing a cola operator wrapper.
l, u = cola.linalg.eig(K, 1)
rank_one = cola.ops.Dense(u) @ cola.ops.Dense(u.T)
l2, u2 = cola.linalg.eig(K - rank_one, 1)
print(l, l2)

In [None]:
# Random test batch: 16 x 1024 token ids, each with its own step count in [0, 1000).
# Seeded so the batch (and everything derived from it) is reproducible.
batch_rng = np.random.default_rng(0)
x_0 = torch.tensor(batch_rng.integers(len(embeds), size=[16, 1024]), dtype=torch.int32, device=device)
S = torch.tensor(batch_rng.integers(1000, size=tuple(x_0.shape)), dtype=torch.int32, device=device)

def sample_probs(probs, rng=None):
    """Draw one column index per row of a (possibly unnormalised) weight matrix.

    Inverse-CDF sampling: row ``i`` yields the first column ``j`` whose
    cumulative weight exceeds ``u_i * row_total_i`` with ``u_i ~ U[0, 1)``.

    Args:
        probs: ``(n_rows, n_cols)`` non-negative weights, either a dense array
            or any scipy.sparse matrix/array (densified here).
        rng: optional ``numpy.random.Generator``; defaults to the global
            ``np.random`` state.

    Returns:
        ``(n_rows,)`` int array of sampled column indices. All-zero rows yield 0.
    """
    import scipy.sparse

    dense = probs.toarray() if scipy.sparse.issparse(probs) else np.asarray(probs)
    n_rows = dense.shape[0]
    u = rng.random(n_rows) if rng is not None else np.random.random(n_rows)
    cumsum = np.cumsum(dense, axis=1)
    # Scale by each row's total so unnormalised rows, and totals a hair below 1
    # from float round-off, still always select a column.
    thresholds = u * cumsum[:, -1]
    # Compare per row ((n_rows, 1) vs (n_rows, n_cols)); the strict '<' never
    # selects a zero-weight column.
    return (thresholds[:, None] < cumsum).argmax(axis=1)


def f(S, x_0, period=1, sampler=None, show_progress=True):
    """Advance every token of ``x_0`` through its own number of chain steps.

    Element ``x_0[idx]`` is stepped ``S[idx]`` times. Once a token's counter
    reaches zero it is frozen and written to the output.

    Args:
        S: integer tensor shaped like ``x_0``; non-negative step counts.
        x_0: integer tensor of starting token ids.
        period: currently unused; kept for backward compatibility.
        sampler: optional callable mapping a 1-D array of current token ids to
            the next ids (one step), e.g. ``lambda x: sample_probs(K[x])``.
            ``None`` keeps tokens fixed, as the original loop did while its
            sampling line was commented out.
        show_progress: show a tqdm bar counting individual token-steps.

    Returns:
        numpy array shaped like ``x_0`` holding each token after its steps.

    Raises:
        ValueError: if any entry of ``S`` is negative.
    """
    shape = x_0.shape
    x_flat = x_0.detach().cpu().flatten().numpy()
    steps = S.detach().cpu().flatten().numpy()
    if (steps < 0).any():
        raise ValueError("S must be non-negative")

    x_t = np.ones_like(x_flat)
    pos = np.arange(x_flat.size)   # output positions still being simulated
    current = x_flat.copy()        # current token at each of those positions
    remaining = steps.copy()       # steps still to take at each position

    pbar = (tqdm(total=steps.sum(), unit="iteration", position=0, leave=True)
            if show_progress else None)
    try:
        while True:
            # Freeze tokens whose step budget is exhausted.
            finished = remaining == 0
            x_t[pos[finished]] = current[finished]
            keep = ~finished
            pos, current, remaining = pos[keep], current[keep], remaining[keep]
            if pos.size == 0:
                break
            if sampler is not None:
                current = np.asarray(sampler(current))
            remaining = remaining - 1
            if pbar is not None:
                pbar.update(int(pos.size))
    finally:
        # Always release the progress bar, even if the sampler raises.
        if pbar is not None:
            pbar.close()
    return x_t.reshape(shape)

In [None]:
# Run the step-count simulation on the random batch (see f for what a step does).
x_t = f(S, x_0, period=1)

In [None]:
# Simulated tokens next to the starting tokens.
(x_t, x_0)

In [None]:
# Inspect the sparse generator without densifying it. scipy sparse arrays have no
# .to_dense() / .device (those were torch/cola APIs), so the previous version raised
# AttributeError; a dense vocab x vocab copy would also need several GB.
print(L.dtype, L.shape, L.nnz, type(L).__name__)

In [None]:
# Sparse generator times a random dense block. L is a scipy sparse array (no
# .to_dense()), and densifying it is unnecessary, so multiply sparse @ dense directly.
g = torch.randn([len(embeds), 10], device=device)
L @ g.cpu().numpy()

In [None]:
g = torch.randn([len(embeds), 10], device=device)
%timeit L @ g

In [None]:
# Slice the sparse matrix *before* densifying: L.toarray() would materialise the
# full vocab x vocab matrix just to display a 100 x 100 window.
plt.imshow(L[2000:2100, 2000:2100].toarray(), vmin=-1, vmax=1, cmap='bwr')
plt.colorbar();

In [None]:
# `import scipy` / `from scipy.sparse import ...` is not guaranteed to load the
# linalg submodule, so import it explicitly before use.
import scipy.sparse.linalg

# Six largest-magnitude eigenvalues (and eigenvectors) of K via ARPACK.
scipy.sparse.linalg.eigs(K, k=6)

In [None]:
inds = np.random.randint(len(embeds), size=(32, 1024))
%timeit K[inds.ravel(), :].toarray().reshape(*inds.shape, K.shape[1])
%timeit K.T[inds.ravel(), :].toarray().reshape(*inds.shape, K.shape[1])

#### Look at the training data

In [None]:
from omegaconf import OmegaConf
import data

cfg = OmegaConf.load('configs/basic_language.yaml')
train_dataloader, test_dataloader = data.get_dataloaders(cfg)

# Decode entry 0 of one training batch (presumably its first sequence) in
# chunks of 128 token ids.
datum = next(iter(train_dataloader))
first_sequence_chunks = datum['input_ids'][0].reshape(-1, 128)
list(map(tokenizer.decode, first_sequence_chunks))

In [None]:
inds = np.random.randint(len(embeds), size=(128, 1000))
%timeit (K@K[:, inds.ravel()]).toarray().reshape(K.shape[0], *inds.shape)
%timeit K@(K[:, inds.ravel()].toarray()).reshape(K.shape[0], *inds.shape)

In [None]:
# Precompute K^0 .. K^num_powers as torch sparse COO tensors.
# NOTE(review): powers of a k-NN transition matrix fill in quickly, so high
# powers may be close to dense (vocab x vocab) — check nnz before raising this.
num_powers = 20


def _scipy_to_torch_sparse(mat):
    """Convert a scipy sparse matrix/array to a coalesced torch sparse COO tensor."""
    coo = mat.tocoo()
    coords = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(coords, values, size=coo.shape).coalesce()


# Convert each power as it is produced, so only one scipy power is alive at a time.
# All num_powers + 1 entries are converted (the previous loop left K^num_powers as
# a scipy matrix), and `data` / `indices` are no longer rebound, which had shadowed
# the `data` module and the k-NN index array from earlier cells.
current_prod = scipy.sparse.eye(K.shape[0])
K_powers = [_scipy_to_torch_sparse(current_prod)]
for _ in range(num_powers):
    current_prod = current_prod @ K
    K_powers.append(_scipy_to_torch_sparse(current_prod))