In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

def compute_cosine_similarity(graph_sequence, lookback=10):
    """
    Computes windowed cosine similarities between graph snapshots.

    For every time step ``t`` that has at least ``lookback - 1`` predecessors,
    the flattened adjacency matrix of ``G_t`` is compared with those of
    ``G_{t-lookback+1}, ..., G_{t-1}, G_t`` (in that order, so the last entry
    of each row is the self-similarity of ``G_t``).

    A snapshot whose matrix is all zeros has similarity 0 with everything,
    the same convention as ``sklearn.metrics.pairwise.cosine_similarity``.

    Parameters:
        graph_sequence (sequence of np.ndarray): Adjacency matrices; all must
            contain the same number of elements.
        lookback (int): Window length, including ``G_t`` itself. Must be >= 1.

    Returns:
        np.ndarray: Shape ``(len(graph_sequence) - lookback + 1, lookback)``,
        or ``(0, lookback)`` when the sequence is shorter than ``lookback``.
        Row ``i`` corresponds to ``t = i + lookback - 1`` and column ``j`` to
        the comparison of ``G_t`` with ``G_{t - lookback + 1 + j}``.

    Raises:
        ValueError: If ``lookback < 1``, or if the snapshots do not all have
            the same number of elements.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")
    num_steps = len(graph_sequence)
    n = lookback - 1
    if num_steps < lookback:
        # Keep the 2-D shape so column-wise consumers (e.g. mean over axis 0)
        # still see `lookback` columns.
        return np.empty((0, lookback))

    # Flatten every snapshot into one row and L2-normalise each row once, instead
    # of re-flattening and re-normalising both operands for every pair compared.
    vectors = np.stack([np.asarray(g, dtype=float).ravel() for g in graph_sequence])
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # leave all-zero rows as zeros -> similarity 0
    unit = vectors / norms

    similarities = np.empty((num_steps - n, lookback))
    for row, t in enumerate(range(n, num_steps)):
        # Dot products of G_t with G_{t-n}, ..., G_t (oldest first).
        similarities[row] = unit[t - n:t + 1] @ unit[t]
    return similarities

def _plot_similarity_panel(cosine_similarities, lookback):
    """Draw one panel: each window's similarity profile in orange, their mean in red."""
    for row in cosine_similarities:
        plt.plot(row, 'orange', alpha = 0.5)
    # Skip the mean when there are no windows (np.mean of an empty array warns).
    if len(cosine_similarities):
        plt.plot(np.mean(cosine_similarities, axis=0), 'red', alpha = 0.75)
    # Axis styling is the same for every row, so apply it once per panel.
    labels = [f"t-{lag}" for lag in range(lookback - 1, 0, -1)] + ["t"]
    plt.xticks(range(lookback), labels)
    plt.ylim(0, 1)
    plt.ylabel("Cosine Similarity")


def plotting_cosine_similarity(G, weighted = False, lookback = 10):
    """
    Plots windowed cosine-similarity profiles of graph snapshot sequences.

    Each panel shows one orange line per window returned by
    ``compute_cosine_similarity`` and their column-wise mean in red.

    Parameters:
        G: If ``weighted`` is True, a single sequence of adjacency matrices
            (one panel). Otherwise a list of such sequences, one subplot per
            entry, each titled with the matching value of the global
            ``threshold``.
        weighted (bool): Selects between the two layouts above.
        lookback (int): Window length passed to ``compute_cosine_similarity``;
            also determines the x-axis tick labels (``t-(lookback-1)`` ... ``t``).
    """
    if weighted:
        cosine_similarities = compute_cosine_similarity(G, lookback=lookback)
        plt.figure(figsize = [5,4], dpi = 300)
        _plot_similarity_panel(cosine_similarities, lookback)
        return

    # Keep the original 2x3 grid for up to six sequences, but add rows instead
    # of failing (plt.subplot rejects an index beyond nrows * ncols) for more.
    ncols = 3
    nrows = max(2, (len(G) + ncols - 1) // ncols)
    plt.figure(figsize = [10, 3 * nrows], dpi = 300)
    for k, graph_sequence in enumerate(G):
        cosine_similarities = compute_cosine_similarity(graph_sequence, lookback=lookback)
        print("Cosine Similarity Matrix Shape:", cosine_similarities.shape)

        plt.subplot(nrows, ncols, k + 1)
        _plot_similarity_panel(cosine_similarities, lookback)
        # `threshold` is a global defined outside this function; presumably
        # threshold[k] is the cutoff used to build G[k] — confirm.
        plt.title("Threshold = {:.1f}".format(threshold[k]))
    plt.tight_layout()

# Thresholded graph sequences: one subplot per entry of G. G, WG and
# `threshold` are defined outside this chunk (presumably earlier cells).
plotting_cosine_similarity(G, weighted=False)
# Weighted graph sequence WG: a single panel.
plotting_cosine_similarity(WG, weighted=True)