# 01 — Explore Dataset

In this notebook we:

- Load the configured Visium dataset (default: DLPFC-151673)  
- Inspect AnnData structure (obs, var, obsm)  
- Visualize the tissue and spot layout  
- Plot a few random genes  
- Inspect expression distributions  
- Get an intuitive feel for spatial patterns  

This is an exploratory notebook, not for final results.

In [None]:
import os
import warnings
from pathlib import Path

# Dask Future Warning fix:
# dask consults this option when `dask.dataframe` is first imported, so it is
# set here, before scanpy/squidpy get a chance to pull dask in. Setting it
# after those imports (as before) is too late to influence the import.
import dask
dask.config.set({"dataframe.query-planning": True})

# Install the warning filters before the heavy imports too, so that warnings
# emitted while those packages are being imported are covered as well.
# Silence the specific pandas/squidpy categorical warning
warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")
# Silence the specialized AnnData implicit modification warnings
warnings.filterwarnings("ignore", message="Transforming to str index")

import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import squidpy as sq

# Calibrate project root
def find_project_root(start, marker="data"):
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Parameters
    ----------
    start : str or Path
        Directory from which the upward search begins.
    marker : str
        Name of the entry whose presence identifies the project root.

    Returns
    -------
    Path or None
        The first directory that contains ``marker``, checking ``start`` itself
        and then each parent in turn; ``None`` when no ancestor contains it.
    """
    start = Path(start)
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return None


project_root = find_project_root(Path.cwd())
if project_root is None:
    # The search used to chdir('..') step by step and, when nothing matched,
    # silently left the kernel sitting at the filesystem root. Keep the working
    # directory untouched instead and say what went wrong.
    warnings.warn(
        "Could not locate a 'data' directory in the current directory or any "
        "of its parents; the working directory was left unchanged.",
        stacklevel=2,
    )
else:
    os.chdir(project_root)

from src.data.data_loader import SpatialDataset
from src.visualization.plots import SpatialPlotter

# Session management: obtain the session for the 'default' profile and record
# that this notebook has started.
from src.utils.session import SessionManager

session = SessionManager.get_or_create_session(profile="default")
session.log(
    "Starting notebook 01: Data exploration",
    notebook="01_explore",
)

def save_plot_to_session(filename):
    """Write the active matplotlib figure into the session's plot location.

    The destination is resolved by ``session.get_plot_path(filename)``. The
    figure is saved at 150 dpi with a tight bounding box, the save is logged
    on the session under the ``01_explore`` notebook tag, and the resolved
    path is returned.
    """
    target = session.get_plot_path(filename)
    export_options = {"dpi": 150, "bbox_inches": "tight"}
    plt.savefig(target, **export_options)
    session.log("Saved {}".format(filename), notebook="01_explore")
    return target

## Load dataset

In [None]:
# Resolve the dataset location: the session config takes precedence, with the
# DLPFC-151673 sample as the fallback.
default_dataset_path = "data/DLPFC-151673"
dataset_path = session.config.get("dataset_path", default_dataset_path)

# Build the dataset wrapper, load it, and keep a handle on the AnnData object.
dataset = SpatialDataset(dataset_path)
dataset.load()
adata = dataset.adata

print("Loaded dataset from: {}".format(dataset_path))
adata

## Inspect AnnData structure

We look at:
- `.obs` (spots metadata)
- `.var` (genes)
- `.obsm["spatial"]` (coordinates)

In [None]:
# Peek at the first rows of the per-spot metadata table.
adata.obs.head(n=5)

In [None]:
# Peek at the first rows of the per-gene annotation table.
adata.var.head(n=5)

In [None]:
# Show the first five entries of the spatial coordinates.
spatial_coords = adata.obsm["spatial"]
spatial_coords[:5]

## Plot tissue and spot layout

In [None]:
import json
from matplotlib.image import imread

# Locate the hi-res tissue image and the scale-factor file for the overlay.
spatial_dir = Path(dataset_path) / "spatial"
tissue_img_path = spatial_dir / "tissue_hires_image.png"
scale_json_path = spatial_dir / "scalefactors_json.json"

# The hi-res scale factor is applied to the spot coordinates below so that
# they are drawn on top of the hi-res image.
scale_factors = json.loads(scale_json_path.read_text())
scale = scale_factors["tissue_hires_scalef"]
tissue_img = imread(str(tissue_img_path))

coords = adata.obsm["spatial"]
sample_name = Path(dataset_path).name

# Per-spot total counts are stored on the AnnData object and used for colouring.
adata.obs["total_counts"] = np.array(adata.X.sum(axis=1)).ravel()

spot_x = coords[:, 0] * scale
spot_y = coords[:, 1] * scale

fig, ax = plt.subplots(figsize=(7, 6))
ax.imshow(tissue_img, alpha=0.4)
sc_plot = ax.scatter(
    spot_x,
    spot_y,
    c=adata.obs["total_counts"].to_numpy(),
    cmap="magma",
    s=6,
    alpha=0.85,
    edgecolors="none",
)
ax.set_title(f"Tissue & Spots -- {sample_name}", fontsize=14, fontweight="bold")
ax.set_axis_off()
fig.colorbar(sc_plot, ax=ax, shrink=0.65, label="Total UMI counts")

fig.tight_layout()
save_plot_to_session("nb01_tissue_spots.png")
plt.show()
plt.close()

## Plot a few random genes

We pick a few random genes and visualize their spatial patterns.

In [None]:
plotter = SpatialPlotter(adata, dataset.filter_bank)

# Fixed seed so that Restart & Run All shows the same genes and writes the
# same nb01_gene_diag_<id>.png files on every run. Change the seed to look at
# a different random selection.
RANDOM_SEED = 0
n_genes_to_show = 5

rng = np.random.default_rng(RANDOM_SEED)
random_gene_ids = rng.choice(
    adata.n_vars,
    # Never ask for more genes than the dataset has; sampling without
    # replacement would raise otherwise.
    size=min(n_genes_to_show, adata.n_vars),
    replace=False,
)

random_gene_ids

In [None]:
# Print and render the full diagnostic figure for each sampled gene, saving
# each one to the session plot location under a name keyed by the gene index.
gene_names = adata.var.index
for gid in random_gene_ids:
    gene_name = gene_names[gid]
    print("\nGene id: {}, name: {}".format(gid, gene_name))
    plot_path = session.get_plot_path("nb01_gene_diag_{}.png".format(gid))
    plotter.full_gene_diagnostic_plot(
        gid,
        save=True,
        path=plot_path,
    )

## Expression distribution overview

We inspect:
- total counts per spot
- total counts per gene
- number of detected genes per spot

In [None]:
# Total counts per spot: sum the expression matrix across genes (axis=1).
spot_counts = np.array(adata.X.sum(axis=1)).ravel()

fig_spot, ax_spot = plt.subplots(figsize=(6, 4))
ax_spot.hist(spot_counts, bins=50)
ax_spot.set_xlabel("Total counts per spot")
ax_spot.set_ylabel("Frequency")
ax_spot.set_title("Distribution of total counts per spot")
save_plot_to_session("nb01_dist_counts_per_spot.png")
plt.show()
plt.close()

In [None]:
# Total counts per gene: sum the expression matrix across spots (axis=0).
gene_counts = np.array(adata.X.sum(axis=0)).ravel()

fig_gene, ax_gene = plt.subplots(figsize=(6, 4))
ax_gene.hist(gene_counts, bins=50)
ax_gene.set_xlabel("Total counts per gene")
ax_gene.set_ylabel("Frequency")
ax_gene.set_title("Distribution of total counts per gene")
save_plot_to_session("nb01_dist_counts_per_gene.png")
plt.show()
plt.close()

In [None]:
# Number of detected genes per spot: count the non-zero entries in each row.
# `.A1` exists only on np.matrix results, so it breaks when adata.X is a dense
# ndarray or a SciPy sparse *array*. np.asarray(...).ravel() yields the same
# 1-D vector in every case and matches how the other count vectors in this
# notebook are built.
detected_genes_per_spot = np.asarray((adata.X > 0).sum(axis=1)).ravel()

plt.figure(figsize=(6, 4))
plt.hist(detected_genes_per_spot, bins=50)
plt.xlabel("Detected genes per spot")
plt.ylabel("Frequency")
plt.title("Distribution of detected genes per spot")
save_plot_to_session("nb01_dist_genes_per_spot.png")
plt.show()
plt.close()

## Optional: plot known marker genes (if available)

In [None]:
marker_name = "MOBP"  # change if not present

# Handle the missing-marker case first, then plot when the gene is present.
if marker_name not in adata.var.index:
    print("Marker {} not found in this dataset.".format(marker_name))
else:
    # Positional index of the first gene whose name matches the marker.
    gid = np.flatnonzero(adata.var.index == marker_name)[0]
    print("Marker {} found at index {}".format(marker_name, gid))
    plot_path = session.get_plot_path("nb01_marker_{}.png".format(marker_name))
    plotter.full_gene_diagnostic_plot(gid, save=True, path=plot_path)