In [None]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# -------------------------------
# 1. Define the model architecture (must match the saved checkpoint)
# -------------------------------
class BulkAE(nn.Module):
    """Fully connected autoencoder for bulk expression profiles.

    The encoder maps ``n_genes`` inputs through the hidden widths in
    ``hidden_dims`` (ReLU after each) down to ``latent_dim``; the decoder
    mirrors the same widths back up to ``n_genes``. No activation is applied
    to the latent code or to the reconstruction.

    Args:
        n_genes: number of input features per sample.
        latent_dim: width of the bottleneck.
        hidden_dims: encoder hidden-layer widths, outermost first. The default
            ``(4096, 1024)`` gives Linear(n_genes, 4096) -> Linear(4096, 1024)
            -> Linear(1024, latent_dim), i.e. state-dict keys
            ``encoder.{0,2,4}`` / ``decoder.{0,2,4}``. A checkpoint only loads
            if these widths match the ones it was trained with.

    ``forward`` returns ``(x_hat, z)``: the reconstruction and the latent code.
    """

    def __init__(self, n_genes, latent_dim=320, hidden_dims=(4096, 1024)):
        super().__init__()
        encoder_dims = [n_genes, *hidden_dims, latent_dim]
        # Encoder is built first, then decoder, so parameter creation order is unchanged.
        self.encoder = self._mlp(encoder_dims)
        self.decoder = self._mlp(encoder_dims[::-1])

    @staticmethod
    def _mlp(dims):
        """Linear layers between consecutive ``dims``, ReLU between them but not after the last."""
        layers = []
        n_linear = len(dims) - 1
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < n_linear - 1:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z


In [None]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# -------------------------------
# 1. Define the model architecture (must match the saved checkpoint)
# -------------------------------
class BulkAE(nn.Module):
    """Fully connected autoencoder for bulk expression profiles.

    The encoder maps ``n_genes`` inputs through the hidden widths in
    ``hidden_dims`` (ReLU after each) down to ``latent_dim``; the decoder
    mirrors the same widths back up to ``n_genes``. No activation is applied
    to the latent code or to the reconstruction.

    Args:
        n_genes: number of input features per sample.
        latent_dim: width of the bottleneck.
        hidden_dims: encoder hidden-layer widths, outermost first. The default
            ``(4096, 1024)`` gives Linear(n_genes, 4096) -> Linear(4096, 1024)
            -> Linear(1024, latent_dim), i.e. state-dict keys
            ``encoder.{0,2,4}`` / ``decoder.{0,2,4}``. A checkpoint only loads
            if these widths match the ones it was trained with.

    ``forward`` returns ``(x_hat, z)``: the reconstruction and the latent code.
    """

    def __init__(self, n_genes, latent_dim=320, hidden_dims=(4096, 1024)):
        super().__init__()
        encoder_dims = [n_genes, *hidden_dims, latent_dim]
        # Encoder is built first, then decoder, so parameter creation order is unchanged.
        self.encoder = self._mlp(encoder_dims)
        self.decoder = self._mlp(encoder_dims[::-1])

    @staticmethod
    def _mlp(dims):
        """Linear layers between consecutive ``dims``, ReLU between them but not after the last."""
        layers = []
        n_linear = len(dims) - 1
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < n_linear - 1:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z


# -------------------------------
# 2. Load Test Data
# -------------------------------
TEST_PARQUET = "./data/archs4/processed_short_proteins/test_expr_logtpm_short.parquet"
TEST_META_CSV = "./data/archs4/processed_short_proteins/test_metadata_short.csv"

# The parquet's columns are sample accessions (see the metadata alignment below),
# so transpose to get one row per sample, as the model expects.
test_df = pd.read_parquet(TEST_PARQUET)
X_test = test_df.T.astype(np.float32).values
n_samples, n_genes = X_test.shape

print(f"Test matrix: {X_test.shape}")

# metadata
meta = pd.read_csv(TEST_META_CSV)
print("Metadata columns:", meta.columns.tolist())

# match sample order: one metadata row per expression column, same order.
# .loc raises KeyError for samples without metadata; duplicated accessions
# would add extra rows, which the assert catches here rather than at plot time.
meta = meta.set_index("geo_accession").loc[test_df.columns].reset_index()
assert len(meta) == n_samples, (
    f"metadata has {len(meta)} rows after alignment but there are {n_samples} samples "
    "(duplicate geo_accession values?)"
)


# -------------------------------
# 3. Load Trained Weights
# -------------------------------
WEIGHTS_PATH = "autoencoder_weights.pt"

# The architecture must match the checkpoint exactly, or load_state_dict
# raises with missing/unexpected keys or size mismatches.
model = BulkAE(n_genes, latent_dim=320)
# The file is consumed as a state_dict, so restrict unpickling to tensors and
# plain containers instead of allowing arbitrary pickled objects.
state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# -------------------------------
# 4. Extract Latent Embeddings
# -------------------------------
# Only the encoder is needed: calling model(...) would also run the decoder and
# materialise an (n_samples, n_genes) reconstruction that is thrown away.
with torch.no_grad():
    X_tensor = torch.from_numpy(X_test)
    Z = model.encoder(X_tensor).numpy()    # shape: [n_samples, latent_dim]

print("Latent:", Z.shape)


# -------------------------------
# 5. Run t-SNE
# -------------------------------
TSNE_PERPLEXITY = 30
TSNE_SEED = 0  # t-SNE is stochastic; a fixed seed makes both embeddings reproducible


def run_tsne(data):
    """Embed the rows of ``data`` in 2-D with t-SNE using the shared settings above."""
    return TSNE(
        n_components=2,
        perplexity=TSNE_PERPLEXITY,
        learning_rate="auto",
        random_state=TSNE_SEED,
    ).fit_transform(data)


# X_test holds log(TPM+1) expression (see TEST_PARQUET), not raw counts.
print("Running TSNE on raw log(TPM+1) expression...")
tsne_raw = run_tsne(X_test)

print("Running TSNE on latent...")
tsne_latent = run_tsne(Z)


# -------------------------------
# 6. Plotting
# -------------------------------
labels = meta["tcga_label"].astype("category")
colors = labels.cat.codes.values          # pandas uses code -1 for missing labels
label_names = labels.cat.categories

cmap = plt.get_cmap("tab20")
panels = [
    (tsne_raw, "t-SNE of Raw log(TPM+1) Expression"),
    (tsne_latent, "t-SNE of Autoencoder Latent Space"),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (emb, title) in zip(axes, panels):
    unlabeled = colors < 0
    if unlabeled.any():
        # Draw unlabeled samples first, in grey, so they don't pose as a category.
        ax.scatter(emb[unlabeled, 0], emb[unlabeled, 1], s=4, color="lightgrey", label="(missing)")
    for code, name in enumerate(label_names):
        mask = colors == code
        # tab20 has 20 colours; beyond that colours repeat, but the legend still names each label.
        ax.scatter(emb[mask, 0], emb[mask, 1], s=4, color=cmap(code % cmap.N), label=name)
    ax.set_title(title)
    ax.set_xlabel("tSNE1")
    ax.set_ylabel("tSNE2")

# Both panels use the same label -> colour mapping, so one legend (right of the
# second panel) covers both; tight_layout leaves room for it.
axes[1].legend(
    loc="center left",
    bbox_to_anchor=(1.02, 0.5),
    markerscale=3,
    fontsize="small",
    frameon=False,
    ncol=2 if len(label_names) > 20 else 1,
)
fig.tight_layout()
plt.show()


Test matrix: (9446, 19357)
Metadata columns: ['geo_accession', 'characteristics_ch1', 'source_name_ch1', 'text', 'tcga_label']


RuntimeError: Error(s) in loading state_dict for BulkAE:
	Unexpected key(s) in state_dict: "encoder.4.weight", "encoder.4.bias", "decoder.4.weight", "decoder.4.bias". 
	size mismatch for encoder.0.weight: copying a param with shape torch.Size([4096, 19357]) from checkpoint, the shape in current model is torch.Size([1024, 19357]).
	size mismatch for encoder.0.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([320, 1024]).
	size mismatch for encoder.2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for decoder.2.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([19357, 1024]).
	size mismatch for decoder.2.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([19357]).