In [None]:
import numpy as np
import torch

from wassnmf.validation import *
from wassnmf.wassdil import *

In [None]:
# Reload edited modules (e.g. the wassnmf package) before each cell runs, so
# changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd

# Spot-level table; the first CSV column becomes the row index (one row per spot).
spot_data = pd.read_csv("../data/DLBCL_spotdata.csv", index_col=0)
# Raw values as a NumPy array.
X = spot_data.to_numpy()
# One evenly spaced coordinate in [-12, 12] per row of X.
coord = np.linspace(start=-12, stop=12, num=len(X))

In [None]:
from scipy.spatial.distance import pdist, squareform

# Pairwise Euclidean distances between spots (rows of spot_data), in SciPy's
# condensed form; 'cosine', 'correlation', etc. can be swapped in as the metric.
# NOTE(review): the name `D` is rebound by several later cells.
D = pdist(X=spot_data.values, metric="euclidean")
# Expand the condensed vector into the full symmetric matrix.
D_square = squareform(D)

# Label both axes with the spot identifiers for plotting.
D_df = pd.DataFrame(data=D_square, index=spot_data.index, columns=spot_data.index)


In [None]:
# imports
import seaborn as sns
import matplotlib.pyplot as plt

# Per-cell numeric annotations are only legible on small matrices; beyond this
# many spots they overlap each other, so they are switched off.
ANNOTATE_MAX_SPOTS = 20

# Heatmap of the spot-by-spot Euclidean distance matrix.
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(D_df, annot=len(D_df) <= ANNOTATE_MAX_SPOTS, fmt=".2f", cmap="coolwarm",
            square=True, cbar_kws={"shrink": 0.8}, ax=ax)
ax.set_title("Euclidean Distance Between Spots")
fig.tight_layout()
plt.show()

In [None]:

# Gaussian (Gibbs) kernel on the spot distance matrix D_square.
# The bandwidth eps is the squared median of the *off-diagonal* distances: the
# zero self-distances on the diagonal would otherwise pull the median down and
# give an overly sharp kernel. Set it manually instead if needed, e.g. eps = 1.0.
off_diagonal = D_square[np.triu_indices_from(D_square, k=1)]
eps = np.median(off_diagonal) ** 2
if not eps > 0:
    # Catches both an all-zero distance matrix and a single spot (empty median -> nan).
    raise ValueError("Kernel bandwidth eps must be positive; need at least two distinct spots.")

K = np.exp(-D_square**2 / eps)



In [None]:
# Visualise the kernel matrix K with spot identifiers on both axes.
ax = sns.heatmap(
    K,
    cmap="viridis",
    square=True,
    cbar_kws={"shrink": 0.8},
    xticklabels=spot_data.index,
    yticklabels=spot_data.index,
)
ax.set_title('Kernel matrix for the WNMF');

In [None]:
# WassersteinDiL model (star-imported from wassnmf above; presumably Wasserstein
# dictionary learning), configured with dtype=torch.float32.
# NOTE(review): `model` is rebound to a scikit-learn NMF in a later cell.
model = WassersteinDiL(dtype=torch.float32)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# x and k to torch: the spot data X and the spot kernel K (NumPy arrays) as
# float32 tensors on `device`.
# NOTE(review): `k` is rebound to the int 3 (number of components) in a later cell.
x = torch.tensor(X, dtype=torch.float32, device=device)
k = torch.tensor(K, dtype=torch.float32, device=device)

In [None]:
# Shape of the spot tensor: (n_spots, n_columns of spot_data).
x.shape

In [None]:
# Fit the WassersteinDiL model on the transposed spot data with the spot kernel k;
# 3 is presumably the number of atoms/components.
# NOTE(review): x.T has spots on the last axis, matching k's (n_spots, n_spots) —
# assumes fit() expects one histogram per column over the kernel's support; confirm.
# This also rebinds D (previously the condensed spot-distance vector).
D, Lambda = model.fit(x.T, k, 3, verbose=True, device=device)

In [None]:
from wassnmf.functional import *

In [None]:
# Functional-API Wasserstein NMF with 3 components. Uses the detected `device`
# instead of hard-coding 'cuda', which fails on CPU-only machines.
nmf_model = WassersteinNMF(n_components=3, device=device, verbose=True,
                           n_iter=10)


In [None]:
# Generate a synthetic dataset; this replaces the spot data X, kernel K and coord above.
# NOTE(review): `scenario` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel — TODO define it (e.g. in a config cell).
X, K, coord, cost_matrix = generate_data(scenario)

In [None]:
# Normalise X so that all entries together sum to 1 (global, not per column).
X = X / X.sum()

# Gibbs kernel from random 2-D coordinates, one per row of X.
# A dedicated seeded generator makes the kernel reproducible without touching
# PyTorch's global RNG, and the size follows X instead of a hard-coded 20.
# NOTE(review): the coordinates are random, so K is unrelated to the data's geometry.
coord_rng = torch.Generator().manual_seed(0)
coords = torch.randn(X.shape[0], 2, generator=coord_rng)
K = torch.cdist(coords, coords) ** 2
K = torch.exp(-K / 0.1)  # Gibbs kernel, bandwidth 0.1


In [None]:
# Data and kernel as float32 on `device`. torch.as_tensor accepts both NumPy
# arrays and existing tensors (K is already a tensor here), avoiding the
# copy-construct UserWarning that torch.tensor(<tensor>) emits.
X_cuda = torch.as_tensor(X, device=device, dtype=torch.float32)
K_cuda = torch.as_tensor(K, device=device, dtype=torch.float32)

In [None]:
# Inspect the device copy of the normalised synthetic X.
X_cuda

In [None]:
# Shape of x, which still holds the real spot data (not the synthetic X).
x.shape

In [None]:
# Single-iteration run of the functional Wasserstein NMF on the spot tensor x
# and spot kernel k (3 presumably the number of components).
# NOTE(review): X_cuda / K_cuda prepared above are not used here, and the earlier
# model.fit call passed x.T rather than x — confirm the orientation expected.
# `k` is rebound to the int 3 in a later cell, so re-running this cell after it breaks.
D_torch, Lambda_torch = wasserstein_nmf(x, k, 3, n_iter=1)

In [None]:
# Left factor D of the factorisation X ≈ D @ Lambda.
D_torch

In [None]:
# Right factor Lambda of the factorisation X ≈ D @ Lambda.
Lambda_torch

In [None]:
# Example data: regenerate the synthetic dataset (needs `scenario` defined).
k = 3   # Number of components
# NOTE(review): this rebinds `k`, which earlier held the spot kernel tensor.
X, K, _, _ = generate_data(scenario)
# as_tensor also accepts an existing tensor (e.g. when re-running this cell).
X = torch.as_tensor(X, dtype=torch.float32)
# Dimensions taken from the data instead of hard-coding 20 x 20.
m, n = X.shape  # Number of rows / columns in X

In [None]:
# Kernel returned by generate_data.
K

In [None]:

# Kernel from generate_data as float32. torch.as_tensor accepts a NumPy array
# or an existing tensor, so re-running this cell does not warn.
K = torch.as_tensor(K, dtype=torch.float32)

# Move data to the selected device (GPU when available, CPU otherwise) instead of
# calling .cuda(), which fails on CPU-only machines.
X_cuda = X.to(device)
K_cuda = K.to(device)

# Run GPU version with k components for 10 iterations.
D_torch, Lambda_torch = wasserstein_nmf_gpu(X_cuda, K_cuda, k, n_iter=10)

# Verify reconstruction (should be close to the original X). The product is moved
# back to the CPU first: subtracting a CUDA tensor from a CPU tensor raises a
# RuntimeError.
X_reconstructed = torch.matmul(D_torch, Lambda_torch).cpu()
print("\nReconstruction Error (GPU):", torch.norm(X.cpu() - X_reconstructed)) #compare with X on the CPU

In [None]:
import seaborn as sns

In [None]:
# Reconstruction X ≈ D @ Lambda from the most recent factorisation.
X_rec = torch.matmul(D_torch, Lambda_torch)

In [None]:
import matplotlib.pyplot as plt

# Heatmaps of the reconstruction and both factors. `.cpu()` is needed before
# `.numpy()`: the factors live on `device`, and CUDA tensors cannot be converted
# to NumPy directly.
factor_plots = (
    ("Reconstruction D @ Lambda", X_rec),
    ("D", D_torch),
    ("Lambda", Lambda_torch),
)
for title, tensor in factor_plots:
    sns.heatmap(tensor.detach().cpu().numpy())
    plt.title(title)
    plt.show()

In [None]:
# Heatmap of the synthetic input X, for comparison with the reconstruction.
sns.heatmap(X)

In [None]:



# Heatmap of the kernel K used for the factorisation above.
sns.heatmap(K.detach().numpy(), cmap='viridis')


In [None]:
from sklearn.decomposition import NMF

In [None]:
# Baseline: standard scikit-learn NMF with 3 components and a seeded random init.
# NOTE(review): this rebinds `model`, which earlier held the WassersteinDiL model.
model = NMF(n_components=3, init='random', random_state=0)

In [None]:
# Factorise X ≈ W @ H with the scikit-learn baseline (X is converted to an array by sklearn).
W = model.fit_transform(X)
H = model.components_

In [None]:
# Reconstruction from the scikit-learn NMF factors.
X_nmf = W @ H

In [None]:
# Heatmaps of the scikit-learn NMF reconstruction and its two factors, titled so
# the three figures can be told apart when skimming.
nmf_plots = (
    ("sklearn NMF reconstruction W @ H", X_nmf),
    ("W", W),
    ("H", H),
)
for title, matrix in nmf_plots:
    sns.heatmap(matrix)
    plt.title(title)
    plt.show()

In [None]:
import wsingular

In [None]:
# Example data: regenerate the synthetic dataset (needs `scenario` defined).
k = 3   # Number of components
X, K, _, _ = generate_data(scenario)
# as_tensor also accepts an existing tensor (e.g. when re-running this cell).
X = torch.as_tensor(X, dtype=torch.float32)
# Dimensions taken from the data instead of hard-coding 20 x 20.
m, n = X.shape  # Number of rows / columns in X

In [None]:
# Sinkhorn singular vectors of X via wsingular: C and D are the two learned
# distance matrices (plotted below as "between features" and "between samples").
# eps=5e-2 is presumably the entropic regularisation — confirm in the wsingular docs.
# NOTE(review): this rebinds D again.
C, D = wsingular.sinkhorn_singular_vectors(
    X,
    eps=5e-2,
    dtype=X.dtype,
    device=device,
    n_iter=100,
    progress_bar=True,
)

In [None]:
# Bring both distance matrices back to host memory for plotting and kernel building.
C = C.cpu()
D = D.cpu()

In [None]:
# Distance matrix C (shown as "Distance between features" in the plot below).
C

In [None]:
# Display the SSV next to the data. Tick spacing comes from each image's own shape
# (imshow puts columns on the x-axis and rows on the y-axis). The previous
# hard-coded m / n mislabelled the data panel's x-axis for non-square X.
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
fig.suptitle('Sinkhorn Singular Vectors')

ssv_panels = (
    ('The data.', X),
    ('Distance between samples.', D),
    ('Distance between features.', C),
)
for ax, (title, image) in zip(axes, ssv_panels):
    n_rows, n_cols = image.shape
    ax.set_title(title)
    ax.imshow(image)
    ax.set_xticks(range(0, n_cols, 5))
    ax.set_yticks(range(0, n_rows, 5))

plt.show()

In [None]:
# Total of all entries of C (a quick check on its scale).
C.sum()

In [None]:

# Gibbs kernel from the SSV distance matrix D, with bandwidth KERNEL_EPS.
# (An alternative tried earlier was to use the distances directly: K = D.)
# NOTE(review): the kernel presumably has to be indexed like the rows of X for
# wasserstein_nmf_gpu — confirm D's orientation matches.
KERNEL_EPS = 0.01
K = torch.exp(-D / KERNEL_EPS)

# Move data to the selected device (GPU when available, CPU otherwise) instead of
# calling .cuda(), which fails on CPU-only machines.
X_cuda = X.to(device)
K_cuda = K.to(device)

# Run GPU version with k components for 10 iterations.
D_torch, Lambda_torch = wasserstein_nmf_gpu(X_cuda, K_cuda, k, n_iter=10)

# Verify reconstruction (should be close to the original X). The product is moved
# back to the CPU first: subtracting a CUDA tensor from a CPU tensor raises a
# RuntimeError.
X_reconstructed_C = torch.matmul(D_torch, Lambda_torch).cpu()
print("\nReconstruction Error (GPU):", torch.norm(X.cpu() - X_reconstructed_C)) #compare with X on the CPU

In [None]:
import matplotlib.pyplot as plt

# Reconstruction, both factors, and the input X. `.cpu()` is needed before
# `.numpy()`: the factors live on `device`, and CUDA tensors cannot be converted
# to NumPy directly.
ssv_factor_plots = (
    ("Reconstruction D @ Lambda", X_reconstructed_C),
    ("D", D_torch),
    ("Lambda", Lambda_torch),
    ("Input X", X),
)
for title, tensor in ssv_factor_plots:
    sns.heatmap(tensor.detach().cpu().numpy())
    plt.title(title)
    plt.show()

In [None]:
# Factor D as a NumPy array. It is moved off the device first, since `.numpy()`
# raises on CUDA tensors.
D_np = D_torch.detach().cpu().numpy()
# Column of D with the largest variance across rows.
dominant_component = np.argmax(np.var(D_np, axis=0))

In [None]:
# Pseudotime: each row's value on the highest-variance column of D.
pseudotime = D_np[:, dominant_component]

In [None]:
# Row order by increasing pseudotime. Computed directly because `sample_order`
# is only assigned in a later cell, so referencing it here raises NameError on
# Restart & Run All.
np.argsort(pseudotime)

In [None]:
# Raw pseudotime values, one per row of D_np.
pseudotime

In [None]:
# Row indices that sort by pseudotime, and the pseudotime values in that order.
sample_order = np.argsort(pseudotime)
sorted_pseudotime = pseudotime[sample_order]

In [None]:
# Sorted pseudotime values plotted against their rank.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np.arange(len(sorted_pseudotime)), sorted_pseudotime, marker="o", linestyle="-")
ax.set_xlabel("Sample Rank")
ax.set_ylabel("Pseudotime Value")
ax.set_title("Pseudotime Estimation")
ax.grid()
plt.show()


In [None]:
# Heatmap of X's columns ordered by their 1-D PCA score. The ordering is computed
# here because `pca_pt` is only defined in a later cell, so relying on it raises
# NameError on Restart & Run All.
from sklearn.decomposition import PCA

column_order = np.argsort(PCA(n_components=1).fit_transform(X.T)[:, 0])
sns.heatmap(X.T[column_order])
plt.show()

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Project the pseudotime-sorted rows of D onto two principal components.
pca = PCA(n_components=2)
projected_data = pca.fit_transform(D_np[sample_order])

# Loadings of each column of D on PC1/PC2: shape (n_columns_of_D, 2).
feature_arrows = pca.components_.T

# Points coloured by their original row index (sample_order maps rank -> row).
# NOTE(review): use c=np.arange(len(sample_order)) if colouring by pseudotime rank was intended.
plt.scatter(projected_data[:, 0], projected_data[:, 1], alpha=0.6, label="Data Points",
            c=sample_order)

# One arrow per column of D, all starting at the origin. The arrow count follows
# the loadings instead of being hard-coded to 3, so other component counts work.
origin = np.zeros((2, feature_arrows.shape[0]))
plt.quiver(
    origin[0], origin[1], feature_arrows[:, 0], feature_arrows[:, 1],
    angles='xy', scale_units='xy', scale=1, color='r', width=0.005, label="Feature Directions"
)

# Labels and styling
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("Sparse Multidimensional Data in 2D with Feature Arrows")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# 1-D PCA of X.T: one score per column of X, shape (n_columns, 1).
pca = PCA(n_components=1)
pca_pt = pca.fit_transform(X.T)

In [None]:
# PCA score of each column of X, in column order.
plt.plot(pca_pt)

In [None]:
from mpl_toolkits import mplot3d

In [None]:
# 3-D scatter of the first three columns of D (one point per row of D_np),
# coloured by the PCA score.
# NOTE(review): pca_pt has one value per *column* of X while D_np has one row per
# *row* of X; the colours only line up when X is square — confirm intent.
fig = plt.figure(figsize = (10, 7))
ax = plt.axes(projection ="3d")

# 3d scatter plot
ax.scatter3D(D_np[:, 0], D_np[:, 1], D_np[:, 2], 
             c=pca_pt, cmap='Blues')


In [None]:
# NumPy copy of the synthetic data (X_cuda holds X moved to the device).
X = X_cuda.cpu().numpy()

In [None]:
# Inspect the NumPy version of X.
X

In [None]:
# Column indices of X in order of increasing 1-D PCA score.
np.argsort(pca_pt.T)[0]


In [None]:

# Every sample as a 2d image, in order of increasing PCA score.
# Each row of X is reshaped to IMAGE_SHAPE for display.
IMAGE_SHAPE = (4, 5)
# Computed once (it was previously recomputed on every loop iteration).
pca_order = np.argsort(pca_pt.T)[0]
# NOTE(review): pca_order ranks the *columns* of X (PCA was fit on X.T) but is used
# to reorder the *rows*; this only lines up because X is square — confirm intent.
X_by_pca = X[pca_order]

fig, axes = plt.subplots(4, 5, figsize=(10, 10))
for rank, (ax, sample_idx) in enumerate(zip(axes.flat, pca_order)):
    ax.imshow(X_by_pca[rank].reshape(IMAGE_SHAPE))
    ax.axis('off')
    # Label with the original sample index: the images are shown in PCA order,
    # so the loop position is a rank, not the sample's identity.
    ax.set_title(f"Sample {sample_idx}")
plt.show()

# save as a gif (samples in their original row order)
import imageio

images = []
for i, sample in enumerate(X):
    frame_fig, frame_ax = plt.subplots()
    frame_ax.imshow(sample.reshape(IMAGE_SHAPE))
    frame_ax.axis('off')
    frame_ax.set_title(f"Sample {i}")
    frame_fig.savefig(f"sample_{i}.png")
    images.append(imageio.imread(f"sample_{i}.png"))
    # Close each frame explicitly so figures do not accumulate in memory.
    plt.close(frame_fig)

imageio.mimsave('samples.gif', images)

