In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import numpy as np
import torch
import torch.nn
from transformers import BertModel, BertTokenizer, BertForMaskedLM
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

import faiss
import scipy
from scipy.sparse import coo_array
from matplotlib import pyplot as plt
import torch
import scipy.sparse as sparse
import numpy as np
import time

from IPython.display import clear_output


In [None]:
# Laziness parameter read by get_proc: the uniformisation rate is set to
# (largest exit rate) / (1 - gamma), so the fastest state keeps a
# self-transition probability of gamma in K.
gamma = 0.1

In [None]:
# load embeddings and get knn
# Pull the input word-embedding table of bert-base-uncased as a NumPy array
# and build the vocabulary as an array of token strings.

# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# embeds = model.cls.predictions.decoder.weight#model.embeddings.word_embeddings.weight
embeds = BertModel.from_pretrained("bert-base-uncased").embeddings.word_embeddings.weight
embeds = embeds.detach().cpu().numpy()
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): later cells index `vocab` with token ids, which assumes
# get_vocab() iterates in id order — confirm for this tokenizer.
vocab = np.array(list(tokenizer.get_vocab().keys()))

# Alternative (disabled): GPT-2 output-head weights as the embedding table.
# tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
# lm_head = GPT2LMHeadModel.from_pretrained('gpt2').lm_head
# embeds = lm_head.weight.detach().cpu().numpy()
# vocab = np.array(list(tokenizer.get_vocab().keys()))

# Row-wise L2 norms (kept 2-D so the division broadcasts per row); the
# normalised rows make inner products equal to cosine similarities.
norms = np.linalg.norm(embeds, axis=1, keepdims=True)
embeds_normalized = embeds / norms#np.maximum(1, norms)

# Boolean mask of tokens whose string contains '[unused'; only used for plotting.
unused = np.array(['[unused' in key for key in vocab])
print("Constructing nearest neighbor matrix...")

english_alphabet = [
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]

k = 20                  # neighbours kept per token
strong_masking = False  # True: search only within the token's own category

# Sets give O(1) character membership tests (the list / ndarray lookups used
# before were evaluated once per character of every vocabulary entry).
_ascii_letters = set(english_alphabet)
_ascii_digits = set("0123456789")

# Token categories. All but `is_number` are only active under strong masking.
is_unused = np.array(
    [strong_masking and t.startswith("[") and t.endswith("]") for t in vocab], dtype=bool)
is_suffix = np.array(
    [strong_masking and t.startswith("##") for t in vocab], dtype=bool)
not_english = np.array(
    [strong_masking and not all(ch in _ascii_letters for ch in t) for t in vocab], dtype=bool)
is_number = np.array(
    [any(ch in _ascii_digits for ch in t) for t in vocab], dtype=bool)
is_normal = ~(is_unused | is_suffix | is_number | not_english)

# Each entry is (tokens to query, tokens allowed as neighbours).
masking = ([(is_number, is_number | is_normal), (is_normal, is_normal)] if not strong_masking
    else [(not_english, not_english), (is_unused, is_unused), (is_suffix, is_suffix),
          (is_number, is_number), (is_normal, is_normal)])

# Integer ids with a -1 sentinel and NaN similarities, so a row that no mask
# covers is detectable instead of holding uninitialised memory.
indices = np.full([len(embeds), k], -1, dtype=np.int64)
distances = np.full([len(embeds), k], np.nan)
range_ = np.arange(len(embeds))
for mask, search_mask in tqdm(masking):
    if mask.sum() > 0:
        if search_mask.sum() <= k:
            # faiss pads missing results with -1, which would silently index
            # the last vocabulary entry below.
            raise ValueError(
                f"search set has {int(search_mask.sum())} tokens; need more than k={k}")
        index = faiss.IndexFlatIP(embeds.shape[1])
        index.add(embeds_normalized[search_mask])
        distances_temp, indices_temp = index.search(embeds_normalized[mask], k+1)
        # Column 0 is dropped on the assumption that it is the query itself
        # (every query is part of its own search set).
        distances[mask] = distances_temp[:, 1:]
        indices[mask] = range_[search_mask][indices_temp[:, 1:]]
assert (indices >= 0).all(), "some tokens were not assigned neighbours"
print([m.sum() for m in [is_unused, is_suffix, is_number, is_normal]])

In [None]:
# Number of entries in the tokenizer vocabulary.
len(vocab)

In [None]:
max_ = 1 # max_ = 2.1
# plot knn
# Figure 1: similarity to the closest and to the k-th neighbour for regular
# tokens, with the k-th neighbour similarity of '[unused' tokens overlaid.
plt.figure(figsize=[4, 4])
_, bins, _ = plt.hist(distances[:, 0][~unused], bins=100, color='red', alpha=0.5, label='closest');
_, bins, _ = plt.hist(distances[:, -1][~unused], bins=100, color='blue', alpha=0.5, label='furthest');
# Reuses the bin edges of the 'furthest' histogram so the two are comparable.
plt.hist(distances[:, -1][unused], bins=bins, color='black', label='unused tokens');
plt.ylabel("frequency")
plt.xlabel("neighbour similarity")
plt.xlim(0, max_**2)
plt.legend()

# Figure 2: smallest neighbour similarity per token id.
plt.figure(figsize=[6, 2])
plt.plot(distances.min(-1), color='black')
plt.xlabel("token number")
plt.ylabel("min neighbour sim")
plt.ylim(0, max_**2)

# Figure 3: embedding norm per token id (the y label previously repeated
# "min neighbour sim" from the figure above).
plt.figure(figsize=[6, 2])
plt.plot(norms, color='black')
plt.xlabel("token number")
plt.ylabel("embedding norm");
# plt.ylim(0, 1)

In [None]:
# examples
# Print five randomly drawn tokens, each followed by its stored neighbours
# and their similarity scores. No seed is set, so the draw differs per run.
inds = np.random.randint(len(embeds), size=5)
for ind in inds:
    # Neighbour ids are cast to int so they can index `vocab`.
    print(tokenizer.decode([ind]), ":", ' '.join(vocab[indices[ind].astype(int)]))
    print(distances[ind])

In [None]:
# Settings passed to get_proc.
make_sym = False  # True: every kNN edge is also added in the reverse direction
bandwidth = 0.2 / np.log(2)  # divisor of (1 - similarity) inside the exponential edge weight
normalize = False  # True: off-diagonal rows are rescaled to sum to 1 and the diagonal is set to -1

Suppose $K$ is a symmetric stochastic matrix. Then the all-ones vector is an eigenvector of $K$, so the generator of the uniform process commutes with $K$.
If $\mathcal L$ is the uniform generator with $-1$ on the diagonal (and hence $\frac{1}{N-1}$ off the diagonal) and we add $w\mathcal L$ to $K$, then $w\frac{N}{N-1}$ is subtracted from every eigenvalue except the top one.
Since $K$ is degenerate, this shift ends up being the eigenvalue gap.

In [None]:
def _scipy_coo_to_torch_csr(mat, device):
    """Convert a SciPy COO matrix to a float32 torch sparse-CSR tensor on ``device``."""
    # One stacked int64 array avoids torch's slow tuple-of-ndarrays path.
    coo_index = torch.from_numpy(np.stack([mat.row, mat.col]).astype(np.int64))
    coo_values = torch.from_numpy(np.asarray(mat.data))
    tensor = torch.sparse_coo_tensor(coo_index, coo_values, size=mat.shape)
    return tensor.float().to(device).to_sparse_csr()


def get_proc(gamma, normalize, make_sym, bandwidth, k, device="cuda"):
    """Build the generator ``L`` and one-step kernel ``K`` of the kNN-graph process.

    Reads the notebook globals ``embeds`` (only its row count), ``indices``
    and ``distances`` (neighbour ids and similarities, one row per token).

    Parameters
    ----------
    gamma : float
        The uniformisation rate is ``max exit rate / (1 - gamma)``, so the
        fastest state keeps self-transition probability ``gamma`` in ``K``.
    normalize : bool
        If True, each off-diagonal row is rescaled to sum to 1 and the
        diagonal of the unscaled generator is set to -1.
    make_sym : bool
        If True, every edge is also inserted in the reverse direction
        (edges present in both directions are summed).
    bandwidth : float
        Divisor of ``1 - similarity`` inside the exponential edge weight.
    k : int
        Number of stored neighbours to use per token.
    device : str or torch.device, optional
        Where the torch tensors are placed. Defaults to ``"cuda"``.

    Returns
    -------
    tuple
        ``(L, K, K_gpu, L_gpu, K_coo, K_T)``: ``L`` and ``K`` are SciPy CSR,
        ``K_gpu``/``L_gpu`` torch sparse CSR, ``K_coo`` the torch COO form of
        ``K`` and ``K_T`` its transpose as torch sparse CSR.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of stored neighbours.
    """
    n_tokens = embeds.shape[0]
    if k > indices.shape[1]:
        raise ValueError(
            f"k={k} exceeds the {indices.shape[1]} neighbours stored per token")

    row_indices = np.repeat(np.arange(n_tokens), k)
    # Neighbour ids may be stored as floats; sparse indices must be integers.
    col_indices = indices[:, :k].flatten().astype(np.int64)
    dot_products = distances[:, :k].flatten()
    # rates = distances.sum(-1)
    assert (dot_products > 0).all()
    assert (row_indices != col_indices).all()
    if make_sym:
        row_indices, col_indices = np.r_[row_indices, col_indices], np.r_[col_indices, row_indices]
        dot_products = np.r_[dot_products, dot_products]
    # NOTE(review): the exponent is +(1 - similarity) / bandwidth, so weights
    # grow as similarity drops. A kernel would normally use the negative sign;
    # left unchanged because every downstream number depends on it — confirm.
    edge_weights = np.exp((1 - dot_products) / bandwidth)

    graph = coo_array((edge_weights, (row_indices, col_indices)), shape=(n_tokens, n_tokens))
    graph.sum_duplicates()
    L_off_diag = graph.tocsr()
    exit_rates = np.asarray(L_off_diag.sum(axis=1)).ravel()
    L_diag = -exit_rates
    if normalize:
        # Rescale the stored values row by row; dividing the sparse array by a
        # dense column can materialise a dense N x N matrix in some SciPy versions.
        L_off_diag.data = L_off_diag.data / np.repeat(exit_rates, np.diff(L_off_diag.indptr))
        L_diag = -np.ones_like(exit_rates)
    rate = -L_diag.min() / (1 - gamma)
    L = (L_off_diag / rate + scipy.sparse.diags(L_diag / rate)).tocoo()
    # K = I + L: same off-diagonal part, diagonal shifted by one.
    K = (L_off_diag / rate + scipy.sparse.diags(L_diag / rate + 1)).tocoo()

    K_gpu = _scipy_coo_to_torch_csr(K, device)
    L_gpu = _scipy_coo_to_torch_csr(L, device)

    L = L.tocsr()
    K = K.tocsr()

    K_coo = K_gpu.to_sparse_coo()
    K_T = K_coo.transpose(0, 1).coalesce().to_sparse_csr()
    return L, K, K_gpu, L_gpu, K_coo, K_T

# Build the process with the settings above; the results are used as globals by later cells.
L, K, K_gpu, L_gpu, K_coo, K_T = get_proc(gamma, normalize, make_sym, bandwidth, k)

In [None]:
# Leading eigenvector of K^T (largest real part), i.e. the left eigenvector of K.
stationary = scipy.sparse.linalg.eigs(K.T, 1, which='LR')[1][:, 0]
# eigs returns a complex array; take the real part explicitly instead of
# relying on an implicit complex -> float cast.
stationary = torch.tensor(stationary.real).float().cuda()
# The vector is expected to have unit L2 norm (not unit sum).
assert torch.isclose((stationary**2).sum(), torch.ones_like((stationary**2).sum()))


In [None]:
import gc

# Collect unreachable Python objects first so the CUDA blocks they hold go
# back to the caching allocator, then release the cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Histogram of the log of the negated diagonal of L (the per-token exit rates).
plt.hist(np.log(-L.diagonal()))

In [None]:
# NOTE(review): `x_0` is first assigned in a later cell, so this fails on a
# fresh kernel run top to bottom.
x_0.sum()

In [None]:
# Fixed-point iteration x <- x + x L + 0.001 * (mean(x) - x), started from
# the all-ones row vector, until the update norm is numerically zero.
# `L_T` was never defined in the notebook; it is the transpose of the
# generator, built the same way get_proc builds K_T.
L_T = L_gpu.to_sparse_coo().transpose(0, 1).coalesce().to_sparse_csr()
x_0 = torch.ones([1, K.shape[0]], device='cuda')
pbar = tqdm(range(30000))
for i in pbar:
    # (L_T @ x^T)^T equals x @ L; the last term pulls x slightly towards its mean.
    x_0_new = x_0 + (L_T @ x_0.T).T + 0.001 * (x_0.mean(1, keepdim=True) - x_0)
    err = torch.sqrt(((x_0_new- x_0) ** 2).sum())
    if torch.allclose(err, torch.zeros_like(err)):
        break
    x_0 = x_0_new
    if i%1000 == 0:
        pbar.set_description(f"err:{err.item()}")
plt.plot(x_0.cpu().T)

In [None]:
#try mult
# Random test batch: x_0 has one vocabulary-sized vector per (16, 1024)
# position and S holds a step count in [0, 300) for each position.
# NOTE(review): np.random.randn yields float64, so the full array exists in
# double precision before the cast to float32.
x_0 = torch.tensor(np.random.randn(16, 1024, K.shape[0]), device='cuda').float()
S = torch.tensor(np.random.randint(300, size=x_0.shape[:2]), device='cuda').long()

def K_power_mult(S, x_0, period=1):
    """Apply ``K`` a per-item number of times: item ``j`` becomes ``K**S[j] @ x_0[j]``.

    Uses the notebook global ``K_gpu``.

    Parameters
    ----------
    S : torch.Tensor
        Integer step counts; its flattened length must equal the number of
        vectors in ``x_0``.
    x_0 : torch.Tensor
        Tensor whose last axis has the size of ``K``.
    period : int, optional
        Unused; kept for interface compatibility.

    Returns
    -------
    torch.Tensor
        Tensor with the shape of ``x_0``.
    """
    shape = x_0.shape
    # One column per item so a single sparse matmul advances all of them.
    x_0 = x_0.reshape(-1, x_0.shape[-1]).T
    curr_liks = x_0
    liks = torch.ones_like(x_0)
    curr_S = S.reshape(-1)
    pbar = tqdm(total=curr_S.max().item(), unit="iteration",
                position=0, leave=True)
    while torch.any(curr_S > 0):
        # Columns of curr_liks correspond to the items with curr_S >= 0.
        active = curr_S >= 0
        # Items with no steps left are written to the output ...
        liks[:, curr_S == 0] = curr_liks[:, (curr_S == 0)[active]]
        if curr_liks.shape[-1] == 1:
            if not torch.all((curr_S > 0)[active]):
                break
        else:
            # ... and dropped from the working set.
            curr_liks = curr_liks[:, (curr_S > 0)[active]]
        # The product used to be stored in an unused variable, so no power of
        # K was ever applied.
        curr_liks = K_gpu @ curr_liks
        # x_curr = sample_probs(probs)
        curr_S = curr_S - 1
        pbar.update(1)
    pbar.close()
    if curr_liks.shape[-1] > 0:
        liks[:, curr_S == 0] = curr_liks
    return liks.T.reshape(shape)

# Wall-clock timing of one call; the synchronize makes the measurement wait
# for queued GPU work before the clock is read.
start_time = time.time()
K_power_mult(S, x_0)
torch.cuda.synchronize()
time.time() - start_time

In [None]:
import math

# try mi
# Unfinished mutual-information experiment. As written the loop only applies
# K_T repeatedly to batches of identity columns; the lines that would update
# `mis` are commented out, so `mis` stays all ones.
mat_all = torch.eye(K.shape[0]).float().cuda()
# Random probability vector (squared Gaussians, normalised) and its entropy.
p0 = torch.randn(K.shape[0]).float().cuda() ** 2
p0 = p0 / p0.sum()
ent_p0 = -torch.xlogy(p0, p0).sum()

batch_size = 2000
mis = torch.ones(K.shape[0])
for j in tqdm(range(math.ceil(K.shape[0] / batch_size))):
    # Columns j*batch_size .. (j+1)*batch_size of the identity.
    mat = mat_all[:, j*batch_size:(j+1)*batch_size]
    for i in range(1000):
        # p = p0[:, None] * mat
        # p = torch.where(p < 0, 0, p)
        # p_sum = p.sum(0)
        # mi = (torch.xlogy(p, p).sum() - torch.xlogy(p_sum, p_sum).sum()) / ent_p0
        # mis[i] = mis[i] + mi
    
        # stat_part = stationary @ mat
        # diff = mat - stat_part * stationary[:, None]
        mat = K_T @ mat
        torch.cuda.synchronize()

In [None]:
# Plot `mis` from the cell above (it is already a tensor there, so the
# torch.tensor(...) wrapper only copies it).
plt.plot(torch.tensor(mis).cpu())

In [None]:
# Sanity check: the total of the stored values and the sum of all row sums.
# With K = I + L and rows of L summing to zero, both should equal the number of rows.
print(K_gpu.values().sum())
print((K_gpu @ torch.ones(K.shape[0]).cuda()).sum())

In [None]:
# Row sums of K as a column vector. get_proc already returns K_gpu as float
# sparse CSR, so the conversions here look redundant.
K_gpu.to_sparse_csr().float()@ torch.ones([K.shape[0], 1]).cuda()

In [None]:
# NOTE(review): `l` and `u` are only assigned in the final cell of the
# notebook, so this cell fails on a fresh top-to-bottom run.
print(l)
print(u.sum(0))

In [None]:
# Random token ids of shape (16, 1024) and, for each one, a number of steps in [0, 1000).
x_0 = torch.tensor(np.random.randint(len(embeds), size=[16, 1024]), device='cuda').long()
S = torch.tensor(np.random.randint(1000, size=x_0.shape), device='cuda').long()

def sample_probs(p):
    """Draw one column index per row of a CSR matrix of non-negative weights.

    Column 0 is never drawn: its weight is treated as zero, as before. Each
    row is sampled in proportion to its remaining weights, so rows need not
    sum to one. The old code compared the running row sum to a raw uniform
    draw, so a row with mass below one could yield no sample and shift every
    later result.

    Unlike the previous version, ``p`` is not modified.

    Parameters
    ----------
    p : scipy.sparse CSR matrix or array
        Must expose ``indptr``, ``indices`` and ``data``.

    Returns
    -------
    numpy.ndarray
        Sampled column indices, one per row of ``p``.

    Raises
    ------
    ValueError
        If a row has no positive weight outside column 0.
    """
    n_rows = p.shape[0]
    indptr = np.asarray(p.indptr)
    col_ids = np.asarray(p.indices)
    weights = np.where(col_ids == 0, 0.0, np.asarray(p.data)).astype(np.float64)

    # Running total over all stored entries; padded[j] is the mass before entry j.
    running = np.cumsum(weights)
    padded = np.concatenate(([0.0], running))
    bounds = padded[indptr]
    row_start, row_stop = bounds[:-1], bounds[1:]
    row_mass = row_stop - row_start
    if np.any(row_mass <= 0):
        raise ValueError("every row needs positive weight outside column 0")

    # Inverse-CDF draw inside each row's slice of the running total.
    targets = row_start + np.random.rand(n_rows) * row_mass
    # Keep the target strictly below the row's end so rounding cannot push
    # the lookup into the next row.
    targets = np.minimum(targets, np.nextafter(row_stop, row_start))
    # side="right" lands on the first entry whose running total exceeds the
    # target, which skips zero-weight entries.
    picked = np.searchsorted(running, targets, side="right")
    return col_ids[picked]

def f(S, x_0, period=1):
    """Run a random walk under ``K`` from each start token for its own number of steps.

    Uses the notebook globals ``K`` and ``sample_probs``. ``S`` and ``x_0``
    are torch tensors that are flattened and moved to NumPy; ``period`` is
    not used. The result has the shape of ``x_0``.
    """
    out_shape = x_0.shape
    x_0 = x_0.flatten().cpu().numpy()
    walkers = x_0
    x_t = np.ones_like(x_0)
    steps_left = S.flatten().cpu().numpy()
    pbar = tqdm(total=steps_left.max(), unit="iteration",
                position=0, leave=True)
    while (steps_left > 0).any():
        # `walkers` holds one entry per item whose counter is still >= 0.
        alive = steps_left >= 0
        finished = steps_left == 0
        moving = steps_left > 0
        # Items that just ran out of steps are copied to the output.
        x_t[finished] = walkers[finished[alive]]
        if len(walkers) == 1:
            if not moving[alive].all():
                break
        else:
            walkers = walkers[moving[alive]]
        # One transition for every remaining walker.
        walkers = sample_probs(K[walkers])
        steps_left = steps_left - 1
        pbar.update(1)
    if len(walkers) > 0:
        x_t[steps_left == 0] = walkers
    return x_t.reshape(out_shape)

In [None]:
# 2000 walkers that all start at token id 2000 and take 0..1999 steps.
# Start tokens must be integers because `f` uses them to index rows of K
# (torch.ones gave a float tensor).
x_t = f(torch.arange(2000).cuda(), torch.full((2000,), 2000, dtype=torch.long).cuda())

In [None]:
# Display the sampled end tokens.
x_t#, x_0

In [None]:
# NOTE(review): get_proc returns `L` as a SciPy CSR matrix, while
# `.to_dense()` and `.device` look like torch/cola attributes — this cell
# appears to rely on an `L` from an earlier notebook state; confirm.
A = L.to_dense()
print(A.dtype, A.device)

In [None]:
# NOTE(review): `device` is only assigned in the last cell of the notebook,
# so this fails on a fresh top-to-bottom run.
g = torch.randn([len(embeds), 10], device=device)
L.to_dense() @ g

In [None]:
# Time the operator-times-matrix product.
# NOTE(review): `device` is only assigned in the last cell of the notebook.
g = torch.randn([len(embeds), 10], device=device)
%timeit L @ g

In [None]:
# Show a 100 x 100 window of the generator. Slice the sparse matrix first so
# only the window is densified, not the whole vocabulary-sized matrix.
plt.imshow(L[2000:2100, 2000:2100].toarray(), vmin=-1, vmax=1, cmap='bwr')

In [None]:
# Six eigenvalues of K (with eigenvectors) from the sparse eigensolver's default selection.
scipy.sparse.linalg.eigs(K, k=6)

In [None]:
# Compare the cost of gathering and densifying rows of K against rows of K.T
# for a (32, 1024) batch of random token ids.
inds = np.random.randint(len(embeds), size=(32, 1024))
%timeit K[inds.ravel(), :].toarray().reshape(*inds.shape, K.shape[1])
%timeit K.T[inds.ravel(), :].toarray().reshape(*inds.shape, K.shape[1])

#### look at data

In [None]:
# Project-local imports kept in this cell: `data` is a module of this
# repository and the config path is relative to the working directory.
from omegaconf import OmegaConf
import data
cfg = OmegaConf.load('configs/basic_language.yaml')
train_dataloader, test_dataloader = data.get_dataloaders(cfg)

# Take one batch and decode its first example in chunks of 128 token ids.
datum = next(iter(train_dataloader))
[tokenizer.decode(t) for t in datum['input_ids'][0].reshape(-1, 128)]

In [None]:
# Rebuild the process with a lazier, normalised, symmetrised setting. These
# assignments overwrite the notebook-level gamma / normalize / make_sym /
# bandwidth / k used by earlier cells.
gamma = 0.99
normalize = True
make_sym = True
bandwidth = 0.3
k = 20
# Only the fifth return value (K in torch COO form) is kept.
*_, K_coo, __ = get_proc(gamma, normalize, make_sym, bandwidth, k)
K_coo = K_coo.cpu()
# NOTE(review): N is read from the global K_T built by the earlier get_proc
# call, not from the matrix built here; the size is the same.
N = K_T.shape[0]

In [None]:
# Weight of a constant term mixed into every transition row (0.0 disables it).
up = 0.0

# Start from the token ids of the first example and apply 1000 sampled transitions.
x = datum['input_ids'][0]
for i in range(1000):
    # Dense transition rows for the current tokens.
    k_proc = K_coo.index_select(0, x).to_dense()
    # `x` temporarily holds the row weights, then the sampled ids again.
    x = (1-up) * k_proc + (up / N) * (1-gamma)
    x = torch.multinomial(x, num_samples=1, replacement=True).squeeze(-1)
    clear_output(wait=True)
    # Prints i * (1 - gamma) and the first 20 tokens.
    print(i * (1-gamma), '\n', tokenizer.decode(x[:20]))

In [None]:
# Compare two ways of computing columns of K @ K for a random batch of
# token ids: sparse @ sparse then densify, versus sparse @ dense.
inds = np.random.randint(len(embeds), size=(128, 1000))
%timeit (K@K[:, inds.ravel()]).toarray().reshape(K.shape[0], *inds.shape)
%timeit K@(K[:, inds.ravel()].toarray()).reshape(K.shape[0], *inds.shape)

In [None]:
# Pre-compute K^0 .. K^num_powers as torch sparse COO tensors.
num_powers = 20
current_prod = scipy.sparse.eye(K.shape[0])
K_powers = [current_prod]
for _ in range(num_powers):
    current_prod = current_prod @ K
    K_powers.append(current_prod)
# Convert every entry, including the last one, which the previous
# range(num_powers) loop left as a SciPy matrix. The temporaries no longer
# reuse the names `indices` (kNN table read by get_proc) or `data` (the
# imported project module).
for i in range(len(K_powers)):
    scipy_coo = K_powers[i].tocoo()
    coo_index = torch.from_numpy(
        np.stack([scipy_coo.row, scipy_coo.col]).astype(np.int64))
    coo_values = torch.from_numpy(scipy_coo.data)
    K_powers[i] = torch.sparse_coo_tensor(coo_index, coo_values, size=scipy_coo.shape)

### other code

In [None]:
class all_ones(cola.ops.operator_base.LinearOperator):
    """Linear operator whose matrix product returns the column sums of its input."""

    def _matmat(self, v):
        # Summing over axis 0 with keepdim=True gives a single row; it is not
        # expanded to the operator's row count here.
        column_totals = v.sum(dim=0, keepdim=True)
        return column_totals

# Scratch code for an operator-based (cola) construction of L and K; most of
# it is disabled.
# NOTE(review): `cola` is not imported in any visible cell, and `device` set
# here is read by earlier cells.
dtype = torch.float32
device = 'cpu'
N = len(embeds)
# weight = torch.tensor(1/20000, dtype=dtype, device=device)
# L = cola.ops.Sparse(torch.tensor(dot_products).to(dtype).to(device),
#                     torch.tensor(row_indices).to(dtype).to(device),
#                     torch.tensor(col_indices).to(dtype).to(device),
#                     shape=(N, N)) 
# L = cola.ops.Dense(L.to_dense())
# ones = all_ones(dtype, (N, N))
# ones.device = L.device
# L = L #+ weight * (ones - N * cola.ops.I_like(L))
# rate = (torch.tensor(rates / rates.max(), dtype=dtype, device=device) + (N-1) * weight).max() / (1-gamma)
# K = L / rate + cola.ops.I_like(L)

# Top eigenpair of K, then the top eigenpair after subtracting u u^T.
# NOTE(review): with the cells above disabled, `K` here is the SciPy matrix
# from get_proc — confirm that cola.linalg.eig accepts it.
l, u = cola.linalg.eig(K, 1)
l2, u2 = cola.linalg.eig(K - cola.ops.Dense(u)@cola.ops.Dense(u.T), 1)
print(l, l2)