In [None]:
import numpy as np

# Load the exact Julia-saved X
# X = np.loadtxt("../data/X_data.csv", delimiter=",")
# print("Loaded X shape:", X.shape)

# coord = np.loadtxt("../data/coord_data.csv", delimiter=",")


In [None]:
import torch

In [None]:
# from sklearn.metrics import pairwise_distances

# C = pairwise_distances(coord.reshape(-1, 1), metric="sqeuclidean")
# C = C / C.mean()  # same normalization as the notebook

# eps = 0.25
# K = np.exp(-C / eps)

In [None]:
import numpy as np
import pytest
from sklearn.metrics import pairwise_distances

import sys
from pathlib import Path

# Make the local `wassnmf` package importable from ../src, resolved against the
# kernel's current working directory. Storing an absolute path keeps the entry
# valid if the working directory changes later, and the membership check avoids
# stacking duplicate entries when this cell is re-run.
SRC_DIR = str(Path("../src").resolve())
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)
from wassnmf.wassnmf import WassersteinNMF

In [None]:
def f(x, mu, sigma=1.0):
    """Unnormalized Gaussian bump centred at ``mu``.

    Computes ``exp(-(x - mu)**2 / (2 * sigma**2))`` element-wise, so ``x`` may be
    a scalar or a NumPy array. This is NOT the same as the Julia form
    ``exp.(-(x .- μ).^2)``: the two agree only when ``2 * sigma**2 == 1``.

    Args:
        x: Evaluation point(s).
        mu: Centre of the bump.
        sigma: Width of the bump; must be non-zero.

    Returns:
        Values in [0, 1], equal to 1 at ``x == mu``.
    """
    return np.exp(-(x - mu) ** 2 / (2 * sigma ** 2))

In [None]:
# np.random.seed(42)
# n_features = 100
# n_samples = 100
# coord = np.linspace(-12, 12, n_features)
# X = np.zeros((n_features, n_samples), dtype=np.float64)

# # Generate data as sums of 3 random Gaussian bumps per column
# sigma = 1.0
# for j in range(n_samples):
#     bump1 = np.random.rand() * f(coord, sigma * np.random.randn() + 6, sigma=1.0)
#     bump2 = np.random.rand() * f(coord, sigma * np.random.randn(), sigma=1.0)
#     bump3 = np.random.rand() * f(coord, sigma * np.random.randn() - 6, sigma=1.0)
#     X[:, j] = bump1 + bump2 + bump3

# # Normalize columns to sum to 1 (probability simplex)
# X /= X.sum(axis=0, keepdims=True)

# print("X shape:", X.shape)


In [None]:
# # Build cost matrix C from the same coordinate range [-12, 12]
# C = pairwise_distances(coord.reshape(-1, 1), metric='sqeuclidean')
# C /= C.mean()

# # C = np.ones((coord.size, coord.size))

# # Convert cost matrix to kernel
# eps = 0.025
# K = np.exp(-C / eps)

# print("C shape:", C.shape, "  K shape:", K.shape)


In [None]:
import pandas as pd 
# Load the spot data; the first CSV column becomes the row index.
spot_data = pd.read_csv("../data/DLBCL_spotdata.csv", index_col=0)
# Rows of spot_data are features and columns are samples (see the kernel and
# label cells below). NOTE(review): `.values` may return a view that shares
# memory with spot_data, so in-place edits of X can also modify spot_data.
X = spot_data.values
# Evenly spaced placeholder coordinates, one per feature row, carried over from
# the synthetic example. The kernel K below is built from pairwise distances
# between rows of spot_data, not from these coordinates.
coord = np.linspace(-12, 12, X.shape[0])

In [None]:
# Feature-by-feature correlation: .corr() on the transpose correlates the rows
# of spot_data. Not used by later cells.
cross_corr = spot_data.T.corr()

In [None]:
# NOTE(review): displays the whole frame; spot_data.head() is lighter for large inputs.
spot_data

In [None]:
from scipy.spatial.distance import pdist, squareform

In [None]:
# Gibbs kernel over features: squared Euclidean cost between the rows of
# spot_data (one feature profile per row), rescaled to unit mean, then
# exponentiated with temperature eps.
D_square = squareform(pdist(spot_data.values, metric="sqeuclidean"))
D_square = D_square / D_square.mean()

eps = 1.0  # kernel temperature; larger values give a flatter (smoother) kernel

K = np.exp(-D_square / eps)

# exp() of a finite cost is positive, but a very large cost / eps can underflow to 0.
assert (K > 0).all(), "K has zero entries!"

In [None]:
# Imported here because this cell runs before the cell that imports them.
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the normalized squared-Euclidean cost between feature rows, i.e.
# the matrix the Gibbs kernel K is built from. D_df was previously undefined;
# build it from D_square and label both axes with the feature names.
D_df = pd.DataFrame(D_square, index=spot_data.index, columns=spot_data.index)

plt.figure(figsize=(8, 8))
sns.heatmap(D_df, annot=True, fmt=".2f", cmap="coolwarm", square=True, cbar_kws={"shrink": 0.8})
plt.title("Normalized Squared Euclidean Distance Between Spots")
plt.tight_layout()
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Gibbs kernel K = exp(-D_square / eps) on a linear colour scale; the following
# cell shows log(K). The title previously said "(log)" although K is plotted as is.
sns.heatmap(K, cmap="viridis", square=True, cbar_kws={"shrink": 0.8},
            # xticklabels=spot_data.index, yticklabels=spot_data.index
            )
plt.title('Kernel matrix for the WNMF');

In [None]:
# Same kernel on a log scale: log(K) = -D_square / eps, i.e. the scaled cost matrix.
sns.heatmap(np.log(K), cmap="viridis", square=True, cbar_kws={"shrink": 0.8},
            # xticklabels=spot_data.index, yticklabels=spot_data.index
            )
plt.title('Kernel matrix for the WNMF (log)');

In [None]:
# Reload edited modules (e.g. the local wassnmf package) automatically before executing code.
%load_ext autoreload

%autoreload 2

In [None]:
# Shift the data so its minimum is 0, because NMF-type factorizations need
# non-negative input. Build a fresh array instead of subtracting in place:
# `spot_data.values` may share memory with the DataFrame, so an in-place update
# could silently modify spot_data as well. Recomputing from spot_data also
# keeps the cell idempotent on re-run.
X = spot_data.to_numpy()
X = X - X.min()

In [None]:
# # Build cost matrix C from the same coordinate range [-12, 12]
# # C = pairwise_distances(coord.reshape(-1, 1), metric='sqeuclidean')
# # C /= C.mean()

# C = np.ones((coord.size, coord.size))

# # Convert cost matrix to kernel
# eps = 0.025
# K = np.exp(-C / eps)

# print("C shape:", C.shape, "  K shape:", K.shape)


In [None]:
# Fit Wasserstein NMF so that X is approximated by D @ Lambda, using the Gibbs
# kernel K and its temperature eps (settings mirror the Julia notebook run).
wnmf = WassersteinNMF(n_components=5, epsilon=eps, rho1=0.1, rho2=0.1, n_iter=30, verbose=True)

D, Lambda = wnmf.fit_transform(X, K)

print("D shape:", D.shape)
print("Lambda shape:", Lambda.shape)


In [None]:
from pathlib import Path

# Persist the WNMF factors (torch tensors). torch.save does not create missing
# directories, so make sure ../results exists first.
RESULTS_DIR = Path("../results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
torch.save(D, RESULTS_DIR / "D_250325.pt")
torch.save(Lambda, RESULTS_DIR / "Lambda_250325.pt")

In [None]:
# Reconstruction error of the WNMF factorization X ≈ D @ Lambda.
# A plain sum of signed residuals lets positive and negative errors cancel, so
# report the Frobenius norm instead, both absolute and relative to ||X||.
# NOTE(review): the final sanity-check cell expects D and Lambda columns to sum
# to 1, which would make D @ Lambda column-normalized while X is not. If
# WassersteinNMF normalizes X internally, compare against that normalized X.
residual = X - D.numpy() @ Lambda.numpy()
{"frobenius": np.linalg.norm(residual), "relative": np.linalg.norm(residual) / np.linalg.norm(X)}

In [None]:
from sklearn.decomposition import NMF

In [None]:
from sklearn.preprocessing import normalize
from kneed import KneeLocator

def find_best_n_components(X, k_range=range(2, 20), plot=True):
    """Fit sklearn NMF for each candidate rank and collect the reconstruction errors.

    Args:
        X: Non-negative data matrix passed to ``NMF.fit``.
        k_range: Candidate ranks (``n_components``). Any iterable is accepted;
            it is materialized once so it can be used for both fitting and plotting.
        plot: If True, plot reconstruction error against rank.

    Returns:
        List of ``reconstruction_err_`` values, one per rank, in ``k_range`` order.
    """
    k_values = list(k_range)
    errors = []
    for k in k_values:
        model = NMF(n_components=k, init='nndsvda', random_state=0, max_iter=1000)
        model.fit(X)  # only the error is needed; the factors W and H are not used
        errors.append(model.reconstruction_err_)

    if plot:
        plt.figure(figsize=(8, 5))
        plt.plot(k_values, errors, marker='o')
        plt.xlabel('Number of components (k)')
        plt.ylabel('Reconstruction Error')
        plt.title('NMF Model Selection')
        plt.grid(True)
        plt.show()

    return errors


def find_elbow_k(errors, k_range):
    """Locate the elbow of an error-vs-rank curve with ``kneed.KneeLocator``.

    The curve is treated as convex and decreasing. Returns the x-value that
    ``KneeLocator`` reports as its ``knee`` (a candidate from ``k_range``).
    """
    locator = KneeLocator(
        k_range,
        errors,
        direction="decreasing",
        curve="convex",
    )
    return locator.knee



In [None]:
# Reconstruction error of plain NMF for ranks 2..9 (also plotted).
errors = find_best_n_components(X, k_range=range(2, 10))

In [None]:
# Must match the k_range used above so that errors[i] corresponds to k_range[i].
k_range = list(range(2, 10))
elbow_k = find_elbow_k(errors, k_range)
print(f"Elbow at k = {elbow_k}")

In [None]:
# Baseline: plain (Euclidean) NMF of the same data, with the same rank as the WNMF fit.
# max_iter=1000 matches find_best_n_components, so this fit is comparable to the
# model-selection curve (the sklearn default is lower).
model = NMF(n_components=5, init='nndsvda', random_state=0, max_iter=1000)
W = model.fit_transform(X)  # shape: (n_features, 5) — one row per row of X
H = model.components_       # shape: (5, n_samples) — one column per column of X

# Reconstruction
X_reconstructed = W @ H

In [None]:
# Sample label = part of each column name before the first '.'; one label per sample column.
datasets = [x.split('.')[0] for x in spot_data.columns]
labels = pd.Series(datasets)

In [None]:
spot_data.index

In [None]:
# One distinct "hls" colour per unique label.
unique_labels = labels.unique()
palette = sns.color_palette("hls", len(unique_labels))
label_colors = dict(zip(unique_labels, palette))

# Map labels to colors: one colour per sample, in column order
label_color_values = labels.map(label_colors)

In [None]:
from matplotlib.patches import Patch
import matplotlib.pyplot as plt

In [None]:
sns.set(style="white")

# Heatmap of X (rows = spot_data rows, columns = samples coloured by label),
# shown in the original order: no row or column clustering.
g = sns.clustermap(
    X,
    # row_colors=label_color_values.to_numpy(),
    col_colors=label_color_values.to_numpy(),
    cmap="viridis",
    xticklabels=False,
    yticklabels=spot_data.index,
    figsize=(8, 4),
    row_cluster=False,   
    col_cluster=False
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0)

# Add legend: one colour patch per label. The anchor is in figure-fraction
# coordinates, so x = 1.2 places it to the right of the figure edge.
handles = [Patch(color=color, label=label) for label, color in label_colors.items()]
plt.legend(
    handles=handles,
    title="Labels",
    bbox_to_anchor=(1.2, 1),
    bbox_transform=plt.gcf().transFigure
)


In [None]:
# WNMF reconstruction D @ Lambda (torch matrix product).
created_matrix = torch.mm(D, Lambda)
created_matrix

In [None]:
sns.set(style="white")

# Heatmap of the reconstruction, laid out like the X heatmap above (rows =
# spot_data rows, columns = samples coloured by label, no clustering).
g = sns.clustermap(
    created_matrix,
    # row_colors=label_color_values.to_numpy(),
    col_colors=label_color_values.to_numpy(),
    cmap="viridis",
    xticklabels=False,
    yticklabels=spot_data.index,
    figsize=(8, 4),
    row_cluster=False,   
    col_cluster=False
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0)

# Add legend: one colour patch per label, anchored in figure-fraction coordinates.
handles = [Patch(color=color, label=label) for label, color in label_colors.items()]
plt.legend(
    handles=handles,
    title="Labels",
    bbox_to_anchor=(1.2, 1),
    bbox_transform=plt.gcf().transFigure
)


In [None]:
# One name per WNMF component (a column of D / a row of Lambda). Derived from
# D's width so the names stay in sync if n_components changes.
spot_sign_names = [f'SPOT_WGC_{i}' for i in range(1, D.shape[1] + 1)]

In [None]:
sns.set(style="white")

# WNMF dictionary D: rows follow spot_data.index, columns are the components
# (named by spot_sign_names); shown without clustering.
g = sns.clustermap(
    D,
    # row_colors=label_color_values.to_numpy(),
    # col_colors=label_color_values.to_numpy(),
    cmap="viridis",
    xticklabels=spot_sign_names,
    yticklabels=spot_data.index,
    figsize=(8, 4),
    row_cluster=False,
    col_cluster=False
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0);



In [None]:
# torch tensors have no .save() method; persist with torch.save instead
# (same path as the results cell above).
torch.save(Lambda, "../results/Lambda_250325.pt")

In [None]:
sns.set(style="white")

# WNMF weights Lambda: one row per component (spot_sign_names), one column per
# sample (coloured by label), in the original order.
g = sns.clustermap(
    Lambda,
    # row_colors=label_color_values.to_numpy(),
    col_colors=label_color_values.to_numpy(),
    cmap="viridis",
    xticklabels=False,
    yticklabels=spot_sign_names,
    figsize=(8, 4),
    row_cluster=False,
    col_cluster=False
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0);



In [None]:
spot_data

In [None]:
sns.set(style="white")

# Same Lambda heatmap, but with samples (columns) reordered by Ward hierarchical
# clustering. NOTE(review): the fcluster cells further down presumably intend to
# cut this column dendrogram, but they read it from `g`, which later cells rebind.
g = sns.clustermap(
    Lambda,
    # row_colors=label_color_values.to_numpy(),
    col_colors=label_color_values.to_numpy(),
    cmap="viridis",
    xticklabels=False,
    yticklabels=spot_sign_names,
    figsize=(8, 4),
    row_cluster=False,
    col_cluster=True,
    # try other methods
    method="ward"
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0);



In [None]:
# Lambda: one row per WNMF component, one column per sample
Lambda_df = pd.DataFrame(Lambda)  # optional, for indexing convenience
Lambda_df.columns = labels  # label each sample column with its dataset label (must be sample-aligned)

# Transpose to samples x components (index = label, repeated once per sample)
Lambda_samples = Lambda_df.T

# Group by label, compute centroids: mean component weights per label
centroids = Lambda_samples.groupby(Lambda_samples.index).mean()

# Compute Euclidean distance matrix between label centroids
distance_matrix = pd.DataFrame(
    squareform(pdist(centroids, metric="euclidean")),
    index=centroids.index,
    columns=centroids.index
)

In [None]:
distance_matrix

In [None]:
sns.set(style="white")
# Clustered heatmap of the label-centroid distance matrix (one row/column per label).
# NOTE(review): clustermap computes its own Ward/euclidean linkage, treating each
# row as an observation vector; it does not use these values as precomputed
# distances. This cell also rebinds `g`, whose dendrograms now have one leaf per label.
g = sns.clustermap(
    distance_matrix,
    row_colors=distance_matrix.columns.map(label_colors),
    col_colors=distance_matrix.columns.map(label_colors),
    cmap="viridis",
    xticklabels=False,
    yticklabels=False,
    figsize=(8, 8),
    row_cluster=True,
    col_cluster=True,
    # try other methods
    method="ward",
)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0);



In [None]:
# Keep the dataset labels under a second name; the next cell rebinds `labels`.
true_labels = labels

In [None]:
from sklearn.neighbors import NearestNeighbors
import pandas as pd

# Sample embedding: Lambda has one column per sample, so transpose to one row
# per sample. NOTE: this rebinds X, so cells below use X as this embedding, not
# as the input data matrix.
X = np.asarray(Lambda).T
labels = pd.Series(true_labels, name="label")

# Fit KNN. Ask for k + 1 neighbours because the first neighbour returned for a
# query point is normally the point itself (assumes no duplicate rows in X).
k = 20
knn = NearestNeighbors(n_neighbors=k + 1).fit(X)
distances, indices = knn.kneighbors(X)

# One (source label, neighbour label) pair per neighbour, skipping self
# (column 0 of `indices`). Row-major ravel keeps the pairs grouped by source sample.
label_values = labels.to_numpy()
transitions = pd.DataFrame({
    "from": np.repeat(label_values, k),
    "to": label_values[indices[:, 1:]].ravel(),
})

# Count transitions and row-normalize. Reindex to the full sorted label set on
# both axes so the matrix is square with matching row/column order, even if
# some label never appears as a neighbour.
all_labels = sorted(labels.unique())
transition_matrix = (
    transitions.value_counts()
    .unstack(fill_value=0)
    .reindex(index=all_labels, columns=all_labels, fill_value=0)
)
transition_matrix = transition_matrix.div(transition_matrix.sum(axis=1), axis=0)  # row-normalize

# Visualize
sns.heatmap(transition_matrix, cmap="magma", annot=True, fmt=".2f")
plt.title(f"Label-to-label transition likelihoods (k={k})")
plt.show()


In [None]:
# Label-to-label neighbour counts and row-normalized probabilities.
# Reindex to the full sorted label set on both axes so the matrix is square with
# matching row/column order, which matters when it is used as an adjacency
# matrix below, even if some label never occurs as a neighbour.
label_order = sorted(set(transitions["from"]) | set(transitions["to"]))
transition_counts = (
    transitions.value_counts()
    .unstack(fill_value=0)
    .reindex(index=label_order, columns=label_order, fill_value=0)
)
transition_probs = transition_counts.div(transition_counts.sum(axis=1), axis=0)

In [None]:
from sklearn.decomposition import PCA

# 2-D PCA projection of the sample embedding X (one row per sample at this point).
pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)

samples_2d = pd.DataFrame(X_2d, columns=["x", "y"])
samples_2d["label"] = labels.values


In [None]:
# Mean PCA position of each label's samples.
centroids_2d = samples_2d.groupby("label")[["x", "y"]].mean()


In [None]:
label_colors

In [None]:
from matplotlib.patches import FancyArrowPatch
from matplotlib import pyplot as plt

ARROW_MIN_PROB = 0.05  # hide weak transitions to avoid clutter

fig, ax = plt.subplots(figsize=(10, 8))

# Plot the samples in PCA space, coloured by label
sns.scatterplot(
    data=samples_2d,
    x="x", y="y",
    hue="label",
    alpha=0.5,
    s=100,
    ax=ax,
    palette=label_colors
)

# Draw an arrow from each label centroid to every other label centroid whose
# neighbour-transition probability exceeds the threshold. Self-transitions are
# skipped: their start and end points coincide, so the arrow would have zero length.
for src_label in transition_probs.index:
    for tgt_label in transition_probs.columns:
        if src_label == tgt_label:
            continue
        prob = transition_probs.loc[src_label, tgt_label]
        if prob > ARROW_MIN_PROB:
            src = centroids_2d.loc[src_label]
            tgt = centroids_2d.loc[tgt_label]
            ax.add_patch(FancyArrowPatch(
                (src["x"], src["y"]),
                (tgt["x"], tgt["y"]),
                arrowstyle='-|>',
                mutation_scale=40,  # arrow-head size
                linewidth=5,
                color='green',
                alpha=0.2 + prob / 2,  # stronger transitions are more opaque (at most 0.7)
            ))

# Annotate centroids. The loop variables are named so they do not overwrite the
# notebook-level `coord` array.
for centroid_label, centroid_xy in centroids_2d.iterrows():
    ax.text(centroid_xy["x"], centroid_xy["y"], centroid_label,
            fontsize=10, weight="bold", ha='center', va='center')

ax.set_title("Sample embedding + label transition graph")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.legend(loc="best", title="Label", fontsize=9)
plt.tight_layout()
plt.show()


In [None]:
transition_probs

In [None]:
# NumPy copy of the transition probabilities: A[i, j] is row i / column j of transition_probs.
A = transition_probs.to_numpy()

In [None]:
# Imported here because the cell that imports networkx comes later in the notebook.
import networkx as nx

# Directed label-to-label graph from the transition matrix A: node i is
# transition_probs.index[i], and edges carry A[i, j] as their weight
# (zero entries produce no edge in networkx).
G = nx.from_numpy_array(A, create_using=nx.DiGraph)
node_names = list(transition_probs.index)

# Edge widths: rescale weights linearly so they lie between min_width and max_width.
weights = [A[i][j] for i, j in G.edges()]
max_width = 5
min_width = 0.5
top_weight = max(weights, default=1.0)  # default avoids max() of an empty list
normalized_weights = [
    min_width + (w / top_weight) * (max_width - min_width)
    for w in weights
]

# Draw graph (seeded for a reproducible layout; or try nx.circular_layout(G))
pos = nx.spring_layout(G, seed=42)
nx.draw(
    G, pos,
    labels=dict(enumerate(node_names)),  # show label names instead of integer node ids
    with_labels=True,
    node_color=[label_colors[name] for name in node_names],  # node i coloured by row label i
    arrows=True, width=normalized_weights,
    edge_color='gray'
)

plt.title("Weighted Directed Graph (Edge Width = Weight)")
plt.show()

In [None]:
from pydiffmap import diffusion_map as dm

# Diffusion map of the sample embedding X (one row per sample), keeping 3 eigenvectors.
mydmap = dm.DiffusionMap.from_sklearn(n_evecs=3, alpha=0.5, epsilon='bgh')
X_dm = mydmap.fit_transform(X)

# NOTE(review): X_dm[:, 0] is treated as a pseudotime-like coordinate; the sign of
# an eigenvector is arbitrary, so confirm the ordering before interpreting it.


In [None]:
X_dm[:, 0]

In [None]:
import networkx as nx

# Build an undirected KNN graph over samples: node i links to its neighbours in
# indices[i, 1:] (column 0, the point itself, is skipped).
G = nx.Graph()
for i, neighbors in enumerate(indices[:, 1:]):
    for j in neighbors:
        G.add_edge(i, j)

# Get graph-based layout (like aligning samples by the transition graph)
pos = nx.spring_layout(G, seed=42)  # dict: {sample_index: (x, y)}

# Create embedding aligned to the graph: row i is the layout position of sample i
graph_coords = np.array([pos[i] for i in range(len(X))])


In [None]:
label_color_values

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors


# KNN graph over samples from the `indices` neighbour table (column 0 = self, skipped).
n_nodes = len(indices)
G = nx.Graph()
# Add nodes 0..n-1 up front. networkx iterates nodes in insertion order, and
# adding nodes only through edges would order them as first encountered
# (0, its neighbours, ...), misaligning them with label_color_values, which is
# ordered by sample index.
G.add_nodes_from(range(n_nodes))
for i, neighbors in enumerate(indices[:, 1:]):  # skip self
    for j in neighbors:
        G.add_edge(i, j)

# Get layout based on graph structure
pos = nx.spring_layout(G, seed=137, k=0.01)

# Plot the graph; pass nodelist explicitly so node i is drawn with sample i's colour.
node_order = list(range(n_nodes))
plt.figure(figsize=(10, 8))
nx.draw_networkx_nodes(G, pos,
                       nodelist=node_order,
                       node_color=[label_color_values.iloc[i] for i in node_order],
                       node_size=40, alpha=0.8)
nx.draw_networkx_edges(G, pos, alpha=0.2, width=0.5)
plt.title("KNN Graph of Samples")
plt.axis("off")
plt.tight_layout()
plt.show()


In [None]:
from scipy.cluster.hierarchy import fcluster, linkage as hc_linkage

threshold = 0.2  # dendrogram height at which to cut

# Ward linkage over samples (the columns of Lambda), i.e. what the Lambda
# clustermap with col_cluster=True / method="ward" computes. It is built here
# rather than read from `g` because `g` was rebound by the centroid-distance
# clustermap, whose column dendrogram has one leaf per label instead of one per
# sample. That made col_clusters the wrong length for the NMI below.
col_linkage = hc_linkage(np.asarray(Lambda).T, method="ward", metric="euclidean")

# Get column (sample) clusters by cutting at this height
col_clusters = fcluster(col_linkage, t=threshold, criterion='distance')

In [None]:
from sklearn.metrics import normalized_mutual_info_score

# Agreement between the dataset labels and the sample clusters from the previous
# cell; both must have one entry per sample.
nmi = normalized_mutual_info_score(labels, col_clusters)
print(f"Normalized Mutual Information: {nmi:.3f}")


In [None]:
from scipy.cluster.hierarchy import fcluster, linkage as hc_linkage

# Ward linkage over samples (the columns of Lambda), computed directly instead
# of taken from `g`. By this point `g` is the centroid-distance clustermap,
# which clusters labels, not samples. The result is named col_linkage so it
# does not shadow scipy's `linkage` function.
col_linkage = hc_linkage(np.asarray(Lambda).T, method="ward", metric="euclidean")

# Set the number of flat clusters you want, e.g., 4
num_clusters = 4

# Extract flat cluster assignments, one per sample
col_clusters = fcluster(col_linkage, num_clusters, criterion='maxclust')


In [None]:
# Verify shape, non-negativity, and column sums of the WNMF factors.
# Dimensions come from spot_data, because X has been rebound to the sample
# embedding (Lambda.T) by this point and n_features/n_samples were only defined
# in commented-out cells.
N_COMPONENTS = 5  # must equal the n_components passed to WassersteinNMF
n_features, n_samples = spot_data.shape
D_np = np.asarray(D)
Lambda_np = np.asarray(Lambda)

assert D_np.shape == (n_features, N_COMPONENTS), f"Expected D shape {(n_features, N_COMPONENTS)}, got {D_np.shape}"
assert Lambda_np.shape == (N_COMPONENTS, n_samples), f"Expected Lambda shape {(N_COMPONENTS, n_samples)}, got {Lambda_np.shape}"
assert np.all(D_np >= 0), "D contains negative values"
assert np.all(Lambda_np >= 0), "Lambda contains negative values"

# NOTE(review): assumes WassersteinNMF returns D and Lambda with columns on the
# probability simplex, as the Julia notebook checks did; confirm this.
d_col_sums = D_np.sum(axis=0)
lambda_col_sums = Lambda_np.sum(axis=0)
np.testing.assert_allclose(d_col_sums, 1.0, atol=1e-4, err_msg="D columns do not sum to 1")
np.testing.assert_allclose(lambda_col_sums, 1.0, atol=1e-4, err_msg="Lambda columns do not sum to 1")

print("Julia notebook analog steps completed successfully!")
