In [None]:
import torch 
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

def generate_vectors(T, d, M, max_iter=1000, tol=1e-6):
    """
    Generate T unit-normed vectors in d dimensions with mutual coherence M.

    The vectors start as normalized Gaussian samples. Each iteration measures
    the largest absolute off-diagonal Gram entry; if it is at most M + tol the
    vectors are returned. Otherwise every pair (i, j), i < j, whose absolute
    inner product exceeds M is pushed apart and re-normalized. This is an
    iterative heuristic: reaching the target M is not guaranteed.

    Parameters:
    T (int): Number of vectors to generate.
    d (int): Dimension of each vector.
    M (float): Desired mutual coherence.
    max_iter (int): Maximum number of iterations to adjust vectors.
    tol (float): Tolerance for mutual coherence convergence.

    Returns:
    numpy.ndarray: An array of shape (T, d) containing the unit-normed vectors.
    """
    def _max_coherence(vecs):
        # Largest absolute inner product between two distinct rows.
        gram = np.dot(vecs, vecs.T)
        np.fill_diagonal(gram, 0)  # Exclude self-inner products
        return np.max(np.abs(gram))

    # Initialize T random unit vectors
    vectors = np.random.randn(T, d)
    vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

    for iteration in tqdm(range(max_iter)):
        max_coherence = _max_coherence(vectors)
        if max_coherence <= M + tol:
            print(f"Converged at iteration {iteration}, mutual coherence: {max_coherence}")
            return vectors

        # Adjust pairs of vectors that exceed the mutual coherence M
        for i in range(T):
            for j in range(i + 1, T):
                # Recomputed for every pair because earlier adjustments in
                # this sweep have already modified the rows.
                inner_product = np.dot(vectors[i], vectors[j])
                if np.abs(inner_product) > M:
                    # Half of the excess over M, carrying the sign of the
                    # inner product
                    correction = (inner_product - np.sign(inner_product) * M) / 2
                    # Adjust the vectors; note the update of vectors[j] uses
                    # the already-updated vectors[i]
                    vectors[i] -= correction * vectors[j]
                    vectors[j] -= correction * vectors[i]
                    # Re-normalize the vectors to unit length
                    vectors[i] /= np.linalg.norm(vectors[i])
                    vectors[j] /= np.linalg.norm(vectors[j])

    # The loop measures coherence *before* each adjustment sweep, so the value
    # left over from the last iteration is stale (and undefined when
    # max_iter == 0). Measure the vectors that are actually returned.
    max_coherence = _max_coherence(vectors)
    if max_coherence <= M + tol:
        print(f"Converged at iteration {max_iter}, mutual coherence: {max_coherence}")
    else:
        print(f"Did not converge after {max_iter} iterations, mutual coherence: {max_coherence}")
    return vectors


def get_mutual_coherence(vectors):
    """
    Return the largest absolute inner product between two distinct rows.

    Parameters:
    vectors (numpy.ndarray): An array of shape (T, d), one vector per row.

    Returns:
    float: max over i != j of |<vectors[i], vectors[j]>|.
    """
    gram = np.dot(vectors, vectors.T)
    # Mask out the diagonal (self-inner products) instead of overwriting it.
    on_diagonal = np.eye(gram.shape[0], dtype=bool)
    off_diagonal = np.where(on_diagonal, 0, gram)
    return np.abs(off_diagonal).max()


# Each vector represents a word (a number) in the dictionary. Given a sequence of numbers, we can represent it as a sequence of vectors.
def get_embeddings(sequence, vectors):
    """
    Get the embeddings of a sequence of numbers using a set of vectors.

    Token ids are 1-based: id t selects row t - 1 of `vectors`. Ids are not
    range-checked here; an id of 0 becomes index -1, which NumPy resolves to
    the last row.

    Parameters:
    sequence (list or numpy.ndarray): Integer token ids of the sequence.
    vectors (numpy.ndarray): An array of shape (T, d) containing the unit-normed vectors.

    Returns:
    numpy.ndarray: An array of shape (len(sequence), d) containing the embeddings of the sequence.
    """
    # Convert first so that plain Python lists work: `list - 1` is a TypeError.
    token_ids = np.asarray(sequence)
    if token_ids.size == 0:
        # An empty list becomes a float array, which cannot be used as an index.
        token_ids = token_ids.astype(np.intp)
    return vectors[token_ids - 1]


def welch_bound(T, d):
    """
    Evaluate the Welch bound sqrt((T - d) / (d * (T - 1))).

    Parameters:
    T (int): Number of vectors.
    d (int): Dimension of each vector.

    Returns:
    float: The value of the formula above. When d > T the quantity under the
    square root is negative and np.sqrt yields nan.
    """
    excess = T - d
    scale = d * (T - 1)
    return np.sqrt(excess / scale)
    

def get_condition(T, L, alpha):
    """
    Return the minimal total embedding dimension allowed by the coherence condition.

    With M_max = (1 - (2L - 4) * alpha**2) / (2L - 3), the real number
    lower_bound = T * r**2 / (T - 1 + r**2), r = 1 / M_max, is the dimension x
    at which sqrt((T - x) / (x * (T - 1))) equals M_max. The bound falls
    strictly below M_max only for dimensions strictly larger than lower_bound.

    Parameters:
    T (int): Number of vectors (vocabulary size).
    L (int): Sequence length.
    alpha (float): Value of the extra constant coordinate.

    Returns:
    int: floor(lower_bound) + 2, i.e. the smallest integer strictly above
    lower_bound plus one. The script in this file generates vectors in d - 1
    dimensions and appends one constant coordinate, hence the extra one.
    """
    # r = 1 / M_max; computed once instead of twice.
    inv_m_max = (2*L - 3) / (1 - (2*L - 4) * alpha**2)
    inv_m_max_sq = inv_m_max**2
    lower_bound = (T * inv_m_max_sq) / (T - 1 + inv_m_max_sq)
    # np.round could round *down*, returning a dimension for which the bound
    # still exceeds M_max. Take the smallest integer strictly above instead.
    min_unit_dim = int(np.floor(lower_bound)) + 1
    return min_unit_dim + 1



if __name__ == '__main__':
    # Experiment driver. For every embedding dimension d from the minimal
    # admissible one up to T - 1: build T embeddings with a sampled target
    # coherence, run a fixed dataset of token sequences through a small
    # counting model, and check whether the outputs ("logits") of positions
    # whose token occurs k times are separated from those occurring k + 1 times.

    L = 10        # sequence length
    T = 80        # number of tokens / embedding vectors
    alpha = 1e-2  # value of the extra constant coordinate

    # key result: Welch bound + consistency condition
    minimal_d = get_condition(T,L,alpha)
    print(f"Condition d: {minimal_d}")

    ds = np.arange(int(minimal_d), T,1)
    print('ds:', ds)

    # create a dataset of sequences to loop over
    print('creating dataset')
    N = 1000000
    sequences = []
    for i in tqdm(range(N)):
        K = L
        T_ = list(range(T))  # tokens not used yet in this sequence
        sequence = [0] * L
        while K > 1:
            # sample integer k between 1 and K
            k = np.random.randint(1, K+1)
            # Every pass consumes one token, and a pass with k == K does not
            # shrink K, so the pool can in principle run dry. Fail loudly
            # instead of dropping into the debugger.
            if not T_:
                raise RuntimeError('ran out of unused tokens while building a sequence')
            # sample a token among the unused ones
            t = int(np.random.choice(T_))
            # set x_i = t for i = k,..,K (1-based), i.e. list slots k-1..K-1
            sequence[k-1:K] = [t] * (K-k+1)
            # delete t from T_ and K = k
            T_.remove(t)
            K = k
        # shuffle the sequence
        np.random.shuffle(sequence)
        sequences.append(sequence)
    # Shift tokens from 0..T-1 to 1..T (get_embeddings expects 1-based ids).
    sequences = np.stack(sequences)+1
    print()
    print()

    Ms = []
    logits_distributions = []
    is_within = []
    is_margin_ok = []

    #Theory condition of M: M < (1/(2*L-3))-((2*L-4)/(2*L-3))*alpha**2
    # Independent of d, so it is computed once.
    M_max = (1/(2*L-3))-((2*L-4)/(2*L-3))*alpha**2
    # Slack for comparing Gram entries that come from two different matrix
    # products and can therefore differ by floating-point rounding.
    check_tol = 1e-12

    for d in ds:
        # The first d - 1 coordinates hold the unit vectors; the last one is alpha.
        M_min = welch_bound(T, d-1)
        print('sampling M between:', M_min, M_max)
        M = np.random.uniform(M_min, M_max)
        print('desired M:', M)

        vecs = generate_vectors(T, d-1, M)
        M_est = get_mutual_coherence(vecs)
        vecs = np.concatenate((vecs, alpha*np.ones((T,1))), axis=1)
        # Check that the inner product of two equal vectors is 1 + alpha^2
        assert np.allclose(np.dot(vecs[0], vecs[0]), 1 + alpha**2)
        # Check that the absolute inner product of two different vectors is
        # at most M_est + alpha^2 (the diagonal becomes ~alpha^2 after
        # subtracting the identity, so it passes as well).
        to_check = np.abs(np.dot(vecs, vecs.T) - np.eye(T))
        wh = np.where(to_check > M_est + alpha**2 + check_tol)
        if len(wh[0]) != 0:
            raise RuntimeError(
                f'{len(wh[0])} Gram entries exceed M_est + alpha^2 = {M_est + alpha**2}'
            )

        if M_est > M_max:
            print('Empirical mutual coherence {} is greater than theoretical max value {}'.format(M_est, M_max))
            Ms.append(M_est)
            is_within.append(False)
        elif M_est < M_min:
            print('Empirical mutual coherence {} is smaller than the Welch bound {}'.format(M_est, M_min))
            Ms.append(M_est)
            is_within.append(False)
        else:
            print('Empirical mutual coherence {} is within the desired range'.format(M_est))
            # NOTE(review): this branch records the *desired* M while the two
            # branches above record M_est -- confirm this is intended.
            Ms.append(M)
            is_within.append(True)

        # for each draw, keep track of the distribution of logits for each possible count number
        # e_cnt reads out the last coordinate, rescaled by 1/alpha.
        e_cnt = np.zeros((d))
        e_cnt[-1] = 1/alpha
        logits = {k: [] for k in range(1, L+1)}

        for sequence in tqdm(sequences):
            assert np.max(sequence) <= T
            assert len(sequence) == L

            # counts[p] = number of occurrences in the sequence of the token
            # sitting at position p.
            tokens_un, inverse, counts_un = np.unique(
                sequence, return_inverse=True, return_counts=True
            )
            counts = counts_un[inverse]

            #model
            # With G = E E^T, the read-out below equals 1 + sum_j G[p, j] for
            # position p, followed by a ReLU.
            embeddings = get_embeddings(sequence, vecs)
            dot_product = np.dot(embeddings, embeddings.T)
            mixed_dot_product = np.dot(dot_product, embeddings) + embeddings
            projected = np.dot(mixed_dot_product, e_cnt)
            projected = np.maximum(projected, 0)

            #keep track of the logits
            for count, value in zip(counts, projected):
                logits[int(count)].append(value)

        #check that the logits are well separated
        logits_distributions.append(logits)
        ok = True
        # NOTE(review): range(1, L-1) stops at k = L-2, so the pair
        # (L-1, L) is never compared -- confirm this is intended.
        for k in range(1, L-1):
            if not logits[k] or not logits[k+1]:
                # A count that never occurred cannot be compared; skip it.
                continue
            if np.max(logits[k]) < np.min(logits[k+1]):
                print('logits are well separated for count:', k)
            else:
                print('logits are not well separated for count:', k)
                ok = False
        is_margin_ok.append(ok)
        print('is_margin_ok:', ok, 'is_within:', is_within[-1])
        print()
        print()

    # sort Ms and logits_distributions and is_within and is_margin_ok according to Ms
    Ms = np.array(Ms)
    is_within = np.array(is_within)
    is_margin_ok = np.array(is_margin_ok)
    sort_idx = np.argsort(Ms)
    Ms = Ms[sort_idx]
    is_within = is_within[sort_idx]
    is_margin_ok = is_margin_ok[sort_idx]
    logits_distributions = [logits_distributions[i] for i in sort_idx]

    #plot the distribution of logits for M
    for i, logits in enumerate(logits_distributions):
        plt.figure()
        # NOTE(review): only counts 1..L-2 are drawn -- confirm intended.
        for k in range(1, L-1):
            plt.hist(logits[k], bins=100, alpha=0.5, label=str(k))
        plt.title('M = {}, is_within: {}, is_margin_ok: {}'.format(Ms[i], is_within[i], is_margin_ok[i]))
        plt.legend()
        plt.show()



            


        






