# Identifying Cellular Neighborhoods with Graph Neural Networks in Spatial Transcriptomics Data

### Project Introduction
...

--- 
### References

**Data Sources and Platforms**
* [10X Genomics](https://www.10xgenomics.com/datasets/human-breast-cancer-block-a-section-1-1-standard-1-1-0): Invasive Ductal Carcinoma tissue.

**Libraries**
* [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/)

* [Squidpy Docs](https://squidpy.readthedocs.io/en/stable/)

* [Scanpy Docs](https://scanpy.readthedocs.io/en/stable/)

---
### Table of Contents

1. [Imports](#imports)
2. [Constants & Data](#constants)
3. [...](#distributions)

6. [Summary](#summary)

---
### 1. Imports <a class="anchor" id="imports"></a>

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import squidpy as sq

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling, train_test_split_edges
from tqdm import tqdm


---
### 2. Constants & Data Import <a class="anchor" id="constants"></a>


The necessary data files can be accessed and downloaded via the [10X Genomics Datasets Site](https://www.10xgenomics.com/datasets/human-breast-cancer-block-a-section-1-1-standard-1-1-0). The following specific datasets are required:

* **Filtered feature barcode matrix**

    *V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5*

    The core gene expression data: a matrix whose rows are genes (features) and whose columns are the unique barcodes of the spots on the tissue slide. Scanpy transposes it on load, so the rows of `adata.X` are spots.

* **Spatial imaging data**

    *V1_Breast_Cancer_Block_A_Section_1_spatial.tar.gz*

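    Compressed folder containing the spatial information. This includes the spot position file (`tissue_positions_list.csv` in this Space Ranger 1.x output; `tissue_positions.csv` in newer versions), which holds the full-resolution (x, y) pixel coordinates of every spot barcode. It also includes the scale factors and the downsampled H&E images.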

* **Tissue Image**

    *V1_Breast_Cancer_Block_A_Section_1_image.tif*

    High-resolution H&E stained image of the tissue slice.


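The notebook expects the `.h5` matrix and the extracted `spatial/` folder to be placed in the `data/10x` folder, that is, `data/10x/V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5` and `data/10x/spatial/`. The `.tif` image is not loaded by the code below.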

In [None]:
ROOT = Path(os.getcwd()).parents[0]

DATA_PATH = os.path.join(ROOT, "data", "10x")
SPATIAL_DATA_PATH = os.path.join(DATA_PATH, "spatial")
H5_DATA_PATH = os.path.join(
    DATA_PATH, "V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5"
)

---
### 3. Data Exploration <a class="anchor" id="constants"></a>

Load the data into an AnnData object

In [None]:
adata = sc.read_visium(path=DATA_PATH, count_file=H5_DATA_PATH)
adata.var_names_make_unique()  # Ensure gene names are unique
adata

#### Pre-process the data

In [None]:
# Compute QC metrics on the raw counts (before normalization) so that total_counts
# reflects UMIs per spot in adata.obs and per gene in adata.var
sc.pp.calculate_qc_metrics(adata, inplace=True)

sc.pp.filter_genes(adata, min_cells=10)  # Filter out genes detected in fewer than 10 spots
sc.pp.normalize_total(adata, inplace=True)  # Scale each spot to the median total count
sc.pp.log1p(adata)  # Log-transform the data: log(1 + x)

adata
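The two normalization steps above can be sketched on a toy count matrix. According to Scanpy's documentation, `sc.pp.normalize_total` with the default `target_sum=None` rescales each spot so that its total equals the median total count across spots. `sc.pp.log1p` then applies $\log(1 + x)$. The dense NumPy helper `normalize_and_log` below is hypothetical. It ignores Scanpy's sparse-matrix handling and its other options.

```python
import numpy as np


def normalize_and_log(counts: np.ndarray) -> np.ndarray:
    """Hypothetical dense sketch of normalize_total(target_sum=None) followed by log1p."""
    totals = counts.sum(axis=1, keepdims=True)  # total counts per spot (row)
    target = np.median(totals)  # default target: median spot total before normalization
    scaled = counts / totals * target  # every spot now sums to `target`
    return np.log1p(scaled)  # natural log of (1 + x)


# Two spots x two genes; the spot totals are 4 and 8, so the median target is 6
toy_counts = np.array([[1.0, 3.0],
                       [4.0, 4.0]])
toy_normed = normalize_and_log(toy_counts)
print(np.expm1(toy_normed).sum(axis=1))  # undo the log to inspect the rescaled spot totals
```

After normalization, deeply and shallowly sequenced spots have comparable totals. The log then compresses the influence of very highly expressed genes.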

#### Explore the data object

**adata.uns, adata.obsm**

Display the high- and low-resolution images of the slide stored in `adata.uns`. Then scatter-plot the spot coordinates from `adata.obsm`, colored by `total_counts` from `adata.obs` to show the counts detected per spot.

In [None]:
library_id = "V1_Breast_Cancer_Block_A_Section_1"
spatial_uns = adata.uns["spatial"][library_id]
print(spatial_uns.keys())
print(spatial_uns["metadata"])

_, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(spatial_uns["images"]["hires"])
axes[0].set_title("High Resolution Image")

axes[1].imshow(spatial_uns["images"]["lowres"])
axes[1].set_title("Low Resolution Image")

coords = adata.obsm["spatial"]
axes[2].scatter(
    x=coords[:, 0],
    y=coords[:, 1].max() - coords[:, 1],  # invert y-axis to match image orientation
    c=adata.obs["total_counts"],
    cmap="viridis",
    s=8,
)
axes[2].set_title("Spatial Plot of Total Counts")

for ax in axes:
    ax.set_aspect("equal")


**adata.var**

Each row in `adata.var` corresponds to a single gene. The important columns in `adata.var` for this notebook include:

* `gene_ids`: The Ensembl ID for the gene.
* `feature_types`: The type of feature. This is "Gene Expression" for every row in this data.
* `genome`: The reference genome used.
* `n_cells`: A column added by `sc.pp.filter_genes`, showing how many spots each gene was detected in.
* `total_counts`: A quality control metric showing the total number of counts for each gene across all spots.
* `mean_counts`: The average count of each gene across all spots.
* `log1p_mean_counts`: The log-transformed version of `mean_counts`.

In [None]:
print(adata.var.shape)
adata.var.head(2)

**adata.obs**

Each row in `adata.obs` corresponds to an observation, or spot, on the tissue slide. The important columns in `adata.obs` for this notebook include:

* `in_tissue`: A binary column (0 or 1) indicating whether the spot is located over the actual tissue section (1) or in the background (0).
* `array_row` and `array_col`: These are the integer row and column coordinates of the spot on the physical grid of the Visium slide.
* `total_counts`: This is the total number of gene transcripts (UMIs) detected in that specific spot. It's a measure of the sequencing depth or "library size" for that spot. 
* `n_genes_by_counts`: The number of unique genes that were detected in that spot. A very low number might indicate a low-quality spot.
* `log1p_total_counts`: The log-transformed version of total_counts.

In [None]:
print(adata.obs.shape)
adata.obs.head(2)

Construct a graph using Squidpy's `spatial_neighbors` function, which connects each spot to its nearest neighbors. With `coord_type="grid"` and `n_neighs=6`, each spot is linked to the (up to) six adjacent spots of the hexagonal Visium grid. The hue on the scatter plot corresponds to the total counts detected in each spot.

In [None]:
sq.gr.spatial_neighbors(adata, coord_type="grid", n_neighs=6)
for k, v in adata.obsp.items():
    print(f"{k}: {v.shape}")

_, axes = plt.subplots(figsize=(6, 6))
sq.pl.spatial_scatter(
    adata,
    library_id="spatial",  # Use the key for the image
    color="total_counts",  # Color spots by total gene counts as an example
    shape=None,
    connectivity_key="spatial_connectivities",  # Tell squidpy to draw the graph
    title="Spatial Graph Overlay on Tissue",
    ax=axes,
)
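The six-neighbor pattern behind `coord_type="grid", n_neighs=6` can be illustrated on Visium-style `(array_row, array_col)` coordinates. This sketch assumes the usual Visium offset layout: spots in the same row sit two columns apart, and adjacent rows are shifted by one column. It illustrates the neighbor pattern only. It is not how Squidpy builds the graph internally, and the helper `hex_neighbors` is hypothetical.

```python
# Hypothetical illustration (not Squidpy's implementation): the six hexagonal
# neighbors of a spot in Visium-style (array_row, array_col) coordinates, assuming
# spots in one row sit two columns apart and adjacent rows are offset by one column.
HEX_OFFSETS = [(0, -2), (0, 2), (-1, -1), (-1, 1), (1, -1), (1, 1)]


def hex_neighbors(spot, spots):
    """Return the sorted neighbors of `spot` that exist in the set `spots`."""
    row, col = spot
    candidates = ((row + dr, col + dc) for dr, dc in HEX_OFFSETS)
    return sorted(c for c in candidates if c in spots)


# Toy 5 x 9 patch of a hexagonal grid: only positions with (row + col) even hold a spot
toy_spots = {(r, c) for r in range(5) for c in range(9) if (r + c) % 2 == 0}

print(len(hex_neighbors((2, 4), toy_spots)))  # interior spot: full set of neighbors
print(hex_neighbors((0, 0), toy_spots))  # corner spot: only two neighbors on the patch
```

Spots at the border of the tissue or the capture area have fewer than six neighbors. As a result, the spatial graph's node degrees are capped at six but not always equal to six.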

**Observations**: 
* This slide contains 3,798 spots, and 19,690 genes remain after filtering.
* Total counts seem to correlate with cell density. The areas with the highest total_counts (the bright yellow spots) visually align with the densest, most purple regions in the H&E image. Conversely, the blue/darker areas with lower counts correspond to less dense regions (like stroma or connective tissue).
* The clear spatial pattern of high-count and low-count regions shows that the tissue is heterogeneous. Our GNN should leverage this heterogeneity to learn whether, and where, distinct gene expression programs occur.

---
### 4. Data Pre-processing <a class="anchor" id="model"></a>

...

Critical data for the GNN will be:
* `adata.X`: a sparse n (number of spots) x m (number of genes) matrix containing the gene expression at each spot.
* `adata.obsp['spatial_connectivities']`: the graph adjacency matrix.

Visualize the gene expression matrix

In [None]:
# Mean log-normalized expression of each gene across all spots
mean_expr = np.asarray(adata.X.mean(axis=0)).ravel()
order = np.argsort(mean_expr)[::-1]  # gene indices sorted from highest to lowest mean
N_GENES = 20
top_gene_idx = order[:N_GENES]  # most highly expressed genes
bottom_gene_idx = order[-N_GENES:]  # least expressed genes that survived filtering

In [None]:
X_dense = adata.X.toarray()

# Random sample of spots (without replacement) for the zoomed-in panel
rng = np.random.default_rng(0)
random_sample_of_spots = rng.choice(X_dense.shape[0], size=40, replace=False)
gene_idx = np.concatenate([top_gene_idx, bottom_gene_idx])
sample_array = X_dense[np.ix_(random_sample_of_spots, gene_idx)]

_, axes = plt.subplots(1, 2, figsize=(20, 10))

# imshow renders the full spots x genes matrix far faster than sns.heatmap
axes[0].imshow(X_dense, cmap="RdBu_r", aspect="auto", interpolation="nearest")
axes[0].set_xticks([])
axes[0].set_yticks([])

sns.heatmap(
    sample_array,
    cmap="RdBu_r",
    xticklabels=adata.var_names[gene_idx],
    yticklabels=False,
    cbar=False,
    square=True,
    ax=axes[1],
)

axes[0].set_title("Gene Expression Matrix")
axes[0].set_xlabel("Genes")
axes[0].set_ylabel("Spots")

axes[1].set_title("Highest and Lowest Expressed Genes")
axes[1].set_xlabel("Genes")
axes[1].set_ylabel(f"{len(random_sample_of_spots)} Random Spots");

Some genes have relatively high expression across all spots in the slide, visible as vertical lines in the heatmap above. Others are barely expressed, although every retained gene is detected in at least 10 spots.

Convert the spatial connectivities from their native sparse format (CSR, compressed sparse row) to a COOrdinate (COO) sparse matrix. This format has three key components:

1. `.row`: the row index of every non-zero element.
2. `.col`: the column index of every non-zero element.
3. `.data`: the value stored at each (row, col) position.
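On a toy three-spot chain, the same conversion looks like this. The sketch uses SciPy and NumPy only, and the matrix is hypothetical. Each undirected edge appears twice in the result, once per direction, which is how PyTorch Geometric represents undirected graphs in `edge_index`.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy symmetric adjacency for three spots in a chain: 0 - 1 - 2
toy_adj = csr_matrix(
    np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=np.float32)
)

toy_coo = toy_adj.tocoo()
# Stack row and column indices into the 2 x num_edges layout PyG expects
toy_edge_index = np.vstack([toy_coo.row, toy_coo.col])
print(toy_edge_index)
print(toy_coo.data)  # one weight per stored (row, col) entry
```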

In [None]:
adata.obsp["spatial_connectivities"]

In [None]:
adata.obsp["spatial_connectivities"].tocoo()

In [None]:
edge_index_coo = adata.obsp["spatial_connectivities"].tocoo()
edge_index = torch.tensor(
    np.vstack([edge_index_coo.row, edge_index_coo.col]), dtype=torch.long
)

edge_index

In this `edge_index` representation, each node is connected to at most 6 other nodes. Every edge is stored in both directions (i.e., node_a -> node_b and node_b -> node_a).

In [None]:
# Degree of each spot = number of edges leaving it. Edges are stored in both
# directions, so counting only the source row avoids double counting.
degrees = np.bincount(edge_index[0].numpy(), minlength=adata.n_obs)

(
    pd.Series(degrees)
    .value_counts()  # number of spots with each degree, most common first
    .rename_axis("Number of Edges")
    .reset_index(name="Counts")
    .iloc[:6]
)

PyTorch Geometric models require a `Data` object that holds the graph structure. Convert the Scanpy AnnData object into a compatible structure: a COO-format `edge_index` tensor for the edges and a dense torch tensor of node features (the gene expression matrix).

In [None]:
# Edges - we will use the spatial connectivities computed by Squidpy
edge_index_coo = adata.obsp["spatial_connectivities"].tocoo()
edge_index = torch.tensor(
    np.vstack([edge_index_coo.row, edge_index_coo.col]), dtype=torch.long
)

# Nodes - we will use the gene expression matrix as node features
x = torch.tensor(adata.X.toarray(), dtype=torch.float)

# Create the PyG Data object
data = Data(x=x, edge_index=edge_index)
print("--- PyTorch Geometric Data Object ---")
print(data)

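# Split edges into training, validation, and test sets (seeded for reproducibility).
# Note: recent PyG releases deprecate train_test_split_edges in favor of
# torch_geometric.transforms.RandomLinkSplit.
torch.manual_seed(42)
data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.1)
print(data)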
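From our reading of the PyG source, `train_test_split_edges` works in three steps:

1. It keeps one copy of each undirected edge (`row < col`) and shuffles the edges.
2. It assigns the first `floor(val_ratio * E)` edges to validation and the next `floor(test_ratio * E)` to test.
3. It stores the remaining training edges in both directions.

It also samples negative validation and test edges, which this sketch omits. The hypothetical NumPy helper `split_undirected_edges` below mirrors the positive-edge logic on a toy ring graph. This explains why `data.train_pos_edge_index` (both directions) is the edge set used for message passing during training.

```python
import numpy as np


def split_undirected_edges(edge_index, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Hypothetical NumPy helper: split a symmetric 2 x E edge array into train/val/test."""
    row, col = edge_index
    keep = row < col  # keep one copy of each undirected edge
    edges = np.stack([row[keep], col[keep]])
    perm = np.random.default_rng(seed).permutation(edges.shape[1])
    n_val = int(np.floor(val_ratio * perm.size))
    n_test = int(np.floor(test_ratio * perm.size))
    val = edges[:, perm[:n_val]]
    test = edges[:, perm[n_val:n_val + n_test]]
    train = edges[:, perm[n_val + n_test:]]
    # Message passing uses the training edges, so store them in both directions
    train = np.concatenate([train, train[::-1]], axis=1)
    return train, val, test


# Toy 5-node ring: 5 undirected edges stored as 10 directed entries
toy_src = np.arange(5)
toy_dst = (toy_src + 1) % 5
toy_edges = np.vstack([np.concatenate([toy_src, toy_dst]), np.concatenate([toy_dst, toy_src])])

toy_train, toy_val, toy_test = split_undirected_edges(toy_edges, val_ratio=0.2, test_ratio=0.2)
print(toy_train.shape, toy_val.shape, toy_test.shape)
```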

---
### 5. Build the Model <a class="anchor" id="model"></a>

We will build a Graph Autoencoder (GAE). This model has two parts:

  - An Encoder: This is our GCN. Its job is to compress the high-dimensional
    gene expression of each node into a low-dimensional embedding.

  - A Decoder: It tries to reconstruct the original graph structure (the edges)
    from inner products of the learned embeddings.

The architecture follows the (non-variational) graph auto-encoder from the paper [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308). We aim to learn meaningful latent embeddings of each spot's gene expression that reflect its location relative to its neighbors in the sample.

The encoder is composed of Graph Convolutional Network (GCN) layers from [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/pdf/1609.02907) ([blog](https://tkipf.github.io/graph-convolutional-networks/)).
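For reference, each GCN layer computes $H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)$. Here $\tilde{A} = A + I$ is the adjacency matrix with self-loops and $\tilde{D}$ is its degree matrix. By default, PyG's `GCNConv` adds self-loops, applies this symmetric normalization, and also learns a bias. The dense NumPy sketch below reproduces one propagation step on a hypothetical three-spot chain, without bias or activation. It is for intuition only, since `GCNConv` works on sparse `edge_index` tensors.

```python
import numpy as np


def gcn_layer(adj: np.ndarray, h: np.ndarray, w: np.ndarray) -> np.ndarray:
    """One dense GCN propagation step, D^-1/2 (A + I) D^-1/2 H W, without bias or activation."""
    a_tilde = adj + np.eye(adj.shape[0])  # add self-loops
    deg = a_tilde.sum(axis=1)  # node degrees including the self-loop
    d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    a_hat = d_inv_sqrt @ a_tilde @ d_inv_sqrt  # symmetric normalization
    return a_hat @ h @ w  # aggregate normalized neighbor features, then project with W


# Hypothetical chain of three spots (0 - 1 - 2), each with two "gene" features
toy_adj = np.array([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]], dtype=float)
toy_h = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])
toy_w = np.eye(2)  # identity weights isolate the neighborhood aggregation
toy_out = gcn_layer(toy_adj, toy_h, toy_w)
print(toy_out.round(3))
```

Each output row mixes a spot's own features with its neighbors' features. This is why two stacked layers give every embedding information from spots up to two hops away.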

In [None]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        # First GCN layer: maps input features to an intermediate dimension
        self.conv1 = GCNConv(in_channels, out_channels * 2)
        # Second GCN layer: maps intermediate dimension to the final embedding dimension
        self.conv2 = GCNConv(out_channels * 2, out_channels)

    def forward(self, x, edge_index):
        # Apply GCN layers with ReLU activation
        x = F.relu(self.conv1(x, edge_index))

        # The final output is the node embedding
        x = self.conv2(x, edge_index)

        return x


class GAEModel(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GAEModel, self).__init__()
        # The Encoder is our GCN model defined above
        self.encoder = GCNEncoder(in_channels, out_channels)

    def encode(self, x, edge_index):
        # Pass data through the encoder to get latent embeddings (z)
        return self.encoder(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        # For a given set of positive and negative edges, predict the likelihood of existence
        # using an inner product decoder.
        pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1)
        neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1)
        return pos_logits, neg_logits

    def recon_loss(self, z, pos_edge_index):
        # Sample negative edges (pairs of nodes that are not connected)
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=z.size(0),
            num_neg_samples=pos_edge_index.size(1),  # Match number of positive edges
        )

        # Get predictions (logits) for positive and negative edges
        pos_logits, neg_logits = self.decode(z, pos_edge_index, neg_edge_index)

        # Create labels: 1s for positive edges, 0s for negative edges
        pos_labels = torch.ones_like(pos_logits)
        neg_labels = torch.zeros_like(neg_logits)

        # Concatenate and compute binary cross-entropy loss
        logits = torch.cat([pos_logits, neg_logits], dim=0)
        labels = torch.cat([pos_labels, neg_labels], dim=0)

        return F.binary_cross_entropy_with_logits(logits, labels)
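The `decode` and `recon_loss` methods score a spot pair by the inner product of its embeddings, with edge probability $\sigma(z_i^\top z_j)$. They then apply binary cross-entropy, using label 1 for real edges and 0 for sampled negatives. The NumPy sketch below reproduces that computation for one positive and one negative pair. It uses the numerically stable logits form of binary cross-entropy, the quantity `F.binary_cross_entropy_with_logits` averages. The embeddings and pairs are made up for illustration.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy from raw logits, in the numerically stable form."""
    return np.mean(np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits))))


# Hypothetical 2-D embeddings for four spots
toy_z = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
toy_pos = np.array([[0], [1]])  # spots 0 and 1 are spatial neighbors (label 1)
toy_neg = np.array([[0], [3]])  # spots 0 and 3 are a sampled non-edge (label 0)

# Inner-product decoder: logit_ij = z_i . z_j
toy_pos_logits = (toy_z[toy_pos[0]] * toy_z[toy_pos[1]]).sum(axis=1)
toy_neg_logits = (toy_z[toy_neg[0]] * toy_z[toy_neg[1]]).sum(axis=1)

toy_logits = np.concatenate([toy_pos_logits, toy_neg_logits])
toy_labels = np.concatenate([np.ones_like(toy_pos_logits), np.zeros_like(toy_neg_logits)])
toy_loss = bce_with_logits(toy_logits, toy_labels)
print(toy_logits, toy_loss)
```

Minimizing this loss raises the inner product for real neighbors and lowers it for sampled non-neighbors.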

Define model parameters

In [None]:
in_channels = data.num_features  # Number of genes
out_channels = 32  # Desired size of the embedding for each spot

In [None]:
model = GAEModel(in_channels, out_channels)

# move model, data to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

Train the model

In [None]:
def train():
    """Train the model for one epoch.
    
    Returns:
        float: The reconstruction loss on the training edges.
    """
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    # calculate the loss on the training edges
    loss = model.recon_loss(z, data.train_pos_edge_index)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(pos_edge_index):
    """Evaluate the model on the validation edges.
    
    Args:
        pos_edge_index (torch.Tensor): The positive edge indices for validation.
    Returns:
        float: The reconstruction loss on the validation edges.
    """
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    # calculate the loss on the validation edges
    loss = model.recon_loss(z, pos_edge_index)
    return float(loss)


NUM_EPOCHS = 200
SAVE_INTERVAL = 1  # snapshot embeddings every SAVE_INTERVAL epochs; raise to render fewer animation frames

pbar = tqdm(range(1, NUM_EPOCHS + 1))
saved_embeddings = []
train_losses = []
val_losses = []
for epoch in pbar:
    # Snapshot full-graph embeddings *before* this epoch's update,
    # so the epoch-1 snapshot is the untrained (naive) model
    if epoch % SAVE_INTERVAL == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            full_z = model.encode(data.x, edge_index.to(device)).cpu().numpy()
        saved_embeddings.append({"epoch": epoch, "embeddings": full_z})

    train_loss = train()
    val_loss = test(data.val_pos_edge_index)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Update the progress bar's description with the current loss
    pbar.set_description(
        f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
    )

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.title("Training and Validation Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Reconstruction Loss")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# This function will be called for each frame of the animation
def update(frame):
    # Clear the previous plots
    ax1.clear()
    ax2.clear()

    # Get the data for the current frame
    epoch_data = saved_embeddings[frame]
    epoch = epoch_data['epoch']
    embeddings = epoch_data['embeddings']
    
    # --- Perform clustering for the current frame ---
    adata.obsm['GCN_temp_embeddings'] = embeddings
    sc.pp.neighbors(adata, use_rep="GCN_temp_embeddings")
    sc.tl.leiden(adata, resolution=0.5, key_added="GCN_temp_clusters", random_state=42)
    
    # --- Plot 1: Spatial Clustering ---
    sq.pl.spatial_scatter(
        adata,
        color="GCN_temp_clusters",
        shape=None,
        ax=ax1,
        title=None,
        legend_loc=None
    )
    ax1.set_title(f"Spatial Neighborhoods\nEpoch: {epoch}")
    ax1.invert_yaxis()

    # --- Plot 2: UMAP of Embeddings ---
    sc.tl.umap(adata, min_dist=0.5, random_state=42)
    sc.pl.umap(
        adata,
        color="GCN_temp_clusters",
        ax=ax2,
        show=False,
        title=None,
        legend_loc='on data'
    )
    ax2.set_title(f"UMAP of Embeddings\nEpoch: {epoch}")

    # Set a single title for the entire figure for that frame
    fig.suptitle(f"GCN Training Convergence at Epoch {epoch}", fontsize=16)


# Create the animation object
# The interval is the delay between frames in milliseconds
ani = FuncAnimation(fig, update, frames=len(saved_embeddings), interval=500, repeat=False)

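# Display the animation in the notebook (to_html5_video requires ffmpeg).
# Rendering can take a long time: each frame recomputes the neighbor graph,
# Leiden clusters, and UMAP, so saving fewer embedding snapshots speeds it up.
HTML(ani.to_html5_video())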

In [None]:
model.eval()
with torch.no_grad():
    output = model.encode(data.x, edge_index.to(device))
    final_embeddings = output.cpu().numpy()
    
# Add the learned embeddings back to our original AnnData object for easy use
adata.obsm["GCN_embeddings"] = final_embeddings

final_embeddings.shape

**Observations**:
- In simple terms, this GNN learns a spatially aware latent representation of the gene expression. Its GCN layers are built on the principle that local neighborhoods matter. Its reconstruction loss works contrastively: binary cross-entropy on positive (neighboring) and sampled negative (non-neighboring) spot pairs pushes up the inner products of neighboring embeddings and pushes down those of non-neighboring embeddings.

---
### 6. Clustering & Visualization <a class="anchor" id="model"></a>

##### Perform Clustering on the GCN Embeddings

Build the nearest-neighbor graph in the embedding space using the learned embeddings stored in `adata.obsm['GCN_embeddings']`

In [None]:
sc.pp.neighbors(adata, use_rep="GCN_embeddings", n_neighbors=15)

Run the Leiden clustering algorithm, a community detection method that finds groups of nodes that are more densely connected to each other than to the rest of the graph. The graph here is the nearest-neighbor graph built in embedding space, not the spatial graph.

In [None]:
sc.tl.leiden(adata, resolution=0.5, key_added="GCN_leiden_clusters")
adata.obs['GCN_leiden_clusters'].value_counts()
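Leiden searches for a partition that maximizes a quality function. To our understanding, Scanpy's default `leidenalg` backend optimizes a resolution-weighted modularity (`RBConfigurationVertexPartition`), controlled by `resolution=0.5` above. The hypothetical `modularity` helper below only scores two fixed partitions of an unweighted toy graph. It does not implement Leiden's node-moving and refinement phases, and it ignores the edge weights Scanpy's neighbor graph carries.

```python
import numpy as np


def modularity(adj: np.ndarray, labels: np.ndarray, resolution: float = 1.0) -> float:
    """Resolution-weighted modularity of a partition of an undirected graph."""
    k = adj.sum(axis=1)  # (weighted) degree of each node
    two_m = k.sum()  # twice the total edge weight
    same_community = labels[:, None] == labels[None, :]  # delta(c_i, c_j)
    expected = resolution * np.outer(k, k) / two_m  # null-model edge weight
    return float(((adj - expected) * same_community).sum() / two_m)


# Toy graph: two triangles {0, 1, 2} and {3, 4, 5} joined by the single edge 2 - 3
toy_adj = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    toy_adj[i, j] = toy_adj[j, i] = 1.0

split_labels = np.array([0, 0, 0, 1, 1, 1])  # one community per triangle
merged_labels = np.zeros(6, dtype=int)  # all nodes in a single community
print(modularity(toy_adj, split_labels), modularity(toy_adj, merged_labels))
```

The two-triangle split scores higher than lumping every node together. Lowering `resolution` shrinks the null-model penalty and therefore favors fewer, larger communities.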

##### Visualize the Cellular Neighborhoods

Use Squidpy's plotting function to color each spot on the tissue image according to its assigned cluster ID. This reveals the spatial organization of the cellular neighborhoods identified by the GNN. In the same panel, use UMAP to visualize how the embeddings cluster in 2D space, colored by their Leiden assignments.

Compare the trained model's clusters with the total-counts plot. The naive (untrained) model corresponds to the epoch-1 frame of the animation above.

In [None]:
_, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].scatter(
    x=adata.obsm["spatial"][:, 0],
    y=adata.obsm["spatial"][:, 1] * -1
    - min(adata.obsm["spatial"][:, 1] * -1),  # invert y-axis
    c=adata.obs["total_counts"],
    cmap="viridis",
    s=8,
)
axes[0].set_title("Spatial Plot of Total Counts")
axes[0].set_aspect("equal")
axes[1].set_aspect("equal")
axes[2].set_aspect("equal")

sq.pl.spatial_scatter(
    adata,
    library_id="spatial",
    color="GCN_leiden_clusters",
    shape=None,
    title="Cellular Neighborhoods Identified by GCN Clustering",
    ax=axes[1]
)
axes[1].legend([])
sc.tl.umap(adata, min_dist=0.5)
sc.pl.umap(
    adata,
    color="GCN_leiden_clusters",
    title="UMAP of GCN Embeddings",
    ax=axes[2]
)

Identify the genes that are most significantly up-regulated in each cluster. A t-test compares each cluster against all other spots to find differentially expressed genes.

In [None]:
sc.tl.rank_genes_groups(
    adata, 
    groupby='GCN_leiden_clusters', 
    method='t-test', 
    key_added='marker_genes'
)

# Visualize top marker genes for each cluster
pd.DataFrame(adata.uns['marker_genes']['names']).head(5)
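With `method='t-test'`, Scanpy compares each cluster against all remaining spots (the default `reference='rest'`). To our understanding, it uses Welch's unequal-variance t-test and ranks genes by the resulting score. For a single gene, the comparison can be sketched with SciPy on simulated values. The data below are synthetic, and Scanpy itself computes the statistic from group means and variances for all genes at once.

```python
import numpy as np
from scipy import stats

# Hypothetical log-expression of one gene: 30 spots in a cluster vs. 200 other spots
toy_rng = np.random.default_rng(0)
in_cluster = toy_rng.normal(loc=2.0, scale=0.5, size=30)
rest = toy_rng.normal(loc=1.0, scale=0.8, size=200)

# Welch's t statistic by hand: difference in means over the unpooled standard error
se = np.sqrt(in_cluster.var(ddof=1) / in_cluster.size + rest.var(ddof=1) / rest.size)
t_manual = (in_cluster.mean() - rest.mean()) / se

# Same test via SciPy; equal_var=False selects Welch's unequal-variance t-test
t_scipy, p_value = stats.ttest_ind(in_cluster, rest, equal_var=False)
print(f"t = {t_manual:.2f}, p = {p_value:.2e}")
```

A large positive score means the gene is consistently higher inside the cluster than in the rest of the tissue. That is what places it at the top of the cluster's column in the table above.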

Create a dot plot of the top 4 marker genes for each cluster

In [None]:
sc.pl.rank_genes_groups_dotplot(
    adata, 
    n_genes=4, 
    key='marker_genes', 
    groupby='GCN_leiden_clusters'
)

In [None]:
# Sort Leiden labels numerically ("0", "1", ...) so panels appear in cluster order
clusters = sorted(adata.obs["GCN_leiden_clusters"].unique(), key=int)
sq.pl.spatial_scatter(
    adata,
    library_id="spatial",
    color="GCN_leiden_clusters",
    shape=None,
    title="Cellular Neighborhoods Identified by GCN Clustering",
)
plt.gca().invert_yaxis()

_, axes = plt.subplots(1, len(clusters), figsize=(30, 5), sharey=True)
for idx, cluster in enumerate(clusters):
    # Top-ranked marker gene for this cluster (first entry of the t-test ranking)
    top_marker_gene = adata.uns["marker_genes"]["names"][cluster][0]

    sq.pl.spatial_scatter(
        adata,
        library_id="spatial",
        color=top_marker_gene,
        shape=None,
        title=f"Cluster: {cluster}\n{top_marker_gene}",
        ax=axes[idx],
        colorbar=False,
    )
    axes[idx].set_ylabel("")
    axes[idx].set_xlabel("")
    axes[idx].set_aspect("equal")

plt.tight_layout()

**Observations**
* The GNN successfully segmented the tissue into distinct, spatially coherent clusters. These appear to correspond to localized biological niches, each with its own top marker gene.

* The embeddings appear well separated in the UMAP of the latent space, suggesting that the GCN learned distinct feature representations for each neighborhood.

* The clusters likely reflect biological structure: the GCN clusters appear in locations similar to the features in the Total Counts plot and to regions of localized gene expression. Key takeaways:

  * **Clusters 0 and 8**: These clusters are marked by an adaptive immune signature. CD74 is essential for antigen presentation, and IGLC2 encodes an immunoglobulin (lambda constant region). They likely pinpoint regions dense with B cells and plasma cells, possibly forming an organized structure to fight the tumor.

  * **Cluster 1**: This cluster possibly represents a highly proliferative and aggressive cancer region, given that MALAT1 has been associated with metastasis.


---
### 7. Summary <a class="anchor" id="summary"></a>

...

The primary findings include:
* ...

Next Steps:
* Investigate the effect of different model architectures on the results. How do deeper networks affect the embeddings?