In [None]:
# !pip install google-colab
!pip install opencv-python
!pip install -U albumentations
from google.colab import drive
drive.mount('/content/drive')


Collecting opencv-python
  Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (67.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: opencv-python
Successfully installed opencv-python-4.12.0.88
Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.1/43.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting albucore==0.0.24 (from albumentations)
  Downloading albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting opencv-python-headless>=4.9.0.80 (from albumentations)
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Collecting stringzilla>=3.

In [None]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/mri_data_png.zip"   # adjust path if different
extract_path = "/content/mri_data_png"

os.makedirs(extract_path, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Extracted to:", extract_path)


Extracted to: /content/mri_data_png


In [None]:
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False, norm="instance"):
        super(Block, self).__init__()

        conv = (
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down else
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        )

        # Choose normalization (default: InstanceNorm for small-batch stability)
        if norm == "instance":
            norm_layer = nn.InstanceNorm2d(out_channels, affine=True)
        elif norm == "batch":
            norm_layer = nn.BatchNorm2d(out_channels)
        elif norm is None:
            norm_layer = nn.Identity()
        else:
            raise ValueError(f"Unknown norm='{norm}'. Use 'instance', 'batch', or None.")

        act_layer = nn.ReLU(inplace=True) if act == "relu" else nn.LeakyReLU(0.2, inplace=True)

        self.conv = nn.Sequential(conv, norm_layer, act_layer)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout2d(0.5)  # keep Dropout (pix2pix uses it in some up blocks)

        self.down = down  # (kept for clarity; not used inside forward)

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x


# your Block is the same one you just finalized
# class Block(...):  # as you defined above
#     ...

import torch
import torch.nn as nn

# uses your finalized Block (InstanceNorm + Dropout2d etc.)
class Generator(nn.Module):
    def __init__(self, in_ch=1, out_ch=6, features=64):
        super().__init__()
        # Encoder: 256->128->64->32->16->8->4->2  (initial_down + 6 downs)
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_ch, features, 4, 2, 1, padding_mode="reflect"),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.down1 = Block(features,       features * 2, down=True, act="leaky")   # 128 -> 64
        self.down2 = Block(features * 2,   features * 4, down=True, act="leaky")   # 64  -> 32
        self.down3 = Block(features * 4,   features * 8, down=True, act="leaky")   # 32  -> 16
        self.down4 = Block(features * 8,   features * 8, down=True, act="leaky")   # 16  -> 8
        self.down5 = Block(features * 8,   features * 8, down=True, act="leaky")   # 8   -> 4
        self.down6 = Block(features * 8,   features * 8, down=True, act="leaky")   # 4   -> 2

        # Bottleneck: 2x2 -> 1x1
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1),  # 2 -> 1
            nn.ReLU(inplace=True),
        )

        # Decoder: 1->2(cat d6)->4(cat d5)->8(cat d4)->16(cat d3)->32(cat d2)->64(cat d1)->128(cat d0)->256
        self.up1 = Block(features * 8,       features * 8, down=False, act="relu", use_dropout=True)        # 1 -> 2
        self.up2 = Block(features * 8 * 2,   features * 8, down=False, act="relu", use_dropout=True)        # 2 -> 4
        self.up3 = Block(features * 8 * 2,   features * 8, down=False, act="relu", use_dropout=True)        # 4 -> 8
        self.up4 = Block(features * 8 * 2,   features * 8, down=False, act="relu")                          # 8 -> 16
        self.up5 = Block(features * 8 * 2,   features * 4, down=False, act="relu")                          # 16 -> 32
        self.up6 = Block(features * 4 * 2,   features * 2, down=False, act="relu")                          # 32 -> 64
        self.up7 = Block(features * 2 * 2,   features,     down=False, act="relu")                          # 64 -> 128

        # Final: 128 -> 256, logits for 6 classes (NO Tanh)
        self.final_up = nn.ConvTranspose2d(features * 2, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        d1 = self.initial_down(x)   # 128
        d2 = self.down1(d1)         # 64
        d3 = self.down2(d2)         # 32
        d4 = self.down3(d3)         # 16
        d5 = self.down4(d4)         # 8
        d6 = self.down5(d5)         # 4
        d7 = self.down6(d6)         # 2

        b  = self.bottleneck(d7)    # 1

        u1 = self.up1(b)                            # 2
        u2 = self.up2(torch.cat([u1, d7], 1))       # 4
        u3 = self.up3(torch.cat([u2, d6], 1))       # 8
        u4 = self.up4(torch.cat([u3, d5], 1))       # 16
        u5 = self.up5(torch.cat([u4, d4], 1))       # 32
        u6 = self.up6(torch.cat([u5, d3], 1))       # 64
        u7 = self.up7(torch.cat([u6, d2], 1))       # 128

        logits = self.final_up(torch.cat([u7, d1], 1))  # 256
        return logits


In [None]:
G = Generator(in_ch=1, out_ch=6, features=64)

# Dummy input: batch of 2 grayscale PNG slices (N=2, C=1, H=256, W=256)
x = torch.randn(2, 1, 256, 256)

# Forward pass
y = G(x)

print("Input shape :", x.shape)
print("Output shape:", y.shape)

Input shape : torch.Size([2, 1, 256, 256])
Output shape: torch.Size([2, 6, 256, 256])


In [None]:
import torch
import torch.nn as nn

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels, affine=True),  # swapped from BatchNorm2d
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """
    PatchGAN discriminator for MRI->mask at 256x256.

    Expects either:
      - x: (N, in_ch_x, H, W)
      - y: (N, in_ch_y, H, W)   # one-hot or softmax probs
    OR
      - x: (N, in_ch_x, H, W)
      - y: (N, H, W)            # integer class indices in [0..in_ch_y-1]

    Returns: (N, 1, h, w) patch scores.
    """
    def __init__(self, in_ch_x=1, in_ch_y=6, features=(64, 128, 256, 512)):
        super().__init__()
        self.in_ch_x = in_ch_x
        self.in_ch_y = in_ch_y

        in_pair = in_ch_x + in_ch_y

        # First layer: no norm (pix2pix convention)
        self.initial = nn.Sequential(
            nn.Conv2d(in_pair, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        blocks = []
        in_c = features[0]
        for f in features[1:]:
            stride = 1 if f == features[-1] else 2
            blocks.append(nn.Sequential(
                nn.Conv2d(in_c, f, kernel_size=4, stride=stride, padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(f, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            in_c = f

        # Final 1-channel conv -> patch score map
        blocks.append(nn.Conv2d(in_c, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))

        self.model = nn.Sequential(*blocks)

    def _ensure_mask_channels(self, y):
        """
        Convert (N,H,W) index mask to (N,C,H,W) one-hot if needed.
        If y already has shape (N,C,H,W), return as-is.
        """
        if y.dim() == 3:
            # y is class indices
            if not (y.dtype == torch.long or y.dtype == torch.int64):
                y = y.long()
            y = F.one_hot(y, num_classes=self.in_ch_y).permute(0, 3, 1, 2).float()
        elif y.dim() == 4:
            # y is (N,C,H,W) already
            # ensure channel count matches config
            if y.size(1) != self.in_ch_y:
                raise ValueError(f"Expected mask with {self.in_ch_y} channels, got {y.size(1)}.")
            # if it's probs that's fine; if it's one-hot int, cast to float
            if not y.is_floating_point():
                y = y.float()
        else:
            raise ValueError(f"Mask y must be (N,H,W) or (N,C,H,W), got shape {tuple(y.shape)}.")
        return y

    def forward(self, x, y):
        """
        x: (N, in_ch_x, H, W)
        y: (N, in_ch_y, H, W)  or  (N, H, W) indices
        """
        y = self._ensure_mask_channels(y)
        pair = torch.cat([x, y], dim=1)        # (N, in_ch_x+in_ch_y, H, W)
        h = self.initial(pair)
        out = self.model(h)
        return out
# quick sanity test for your shapes
if __name__ == "__main__":
    N, H, W = 2, 256, 256
    D = Discriminator(in_ch_x=1, in_ch_y=6)

    x = torch.randn(N, 1, H, W)               # grayscale
    y_idx = torch.randint(0, 6, (N, H, W))    # indices
    out_real = D(x, y_idx)
    print("D(x, y_idx) ->", out_real.shape)

    y_probs = torch.softmax(torch.randn(N, 6, H, W), dim=1)
    out_fake = D(x, y_probs)
    print("D(x, y_probs) ->", out_fake.shape)

D(x, y_idx) -> torch.Size([2, 1, 30, 30])
D(x, y_probs) -> torch.Size([2, 1, 30, 30])


In [None]:
import os, glob, random, cv2, numpy as np, torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ---------- discover & split ----------
def collect_ids(root="mri_data", img_dir="images", mask_dir="masks"):
    imgs  = {os.path.splitext(os.path.basename(p))[0]: p
             for p in glob.glob(os.path.join(root, img_dir, "*.png"))}
    masks = {os.path.splitext(os.path.basename(p))[0]: p
             for p in glob.glob(os.path.join(root, mask_dir, "*.png"))}
    ids = sorted(list(set(imgs.keys()) & set(masks.keys())))
    if not ids:
        raise RuntimeError("No matching image/mask basenames found.")
    return ids, imgs, masks

def split_ids(ids, val_ratio=0.1, seed=42):
    random.Random(seed).shuffle(ids)
    n_val = max(1, int(len(ids) * val_ratio))
    return ids[n_val:], ids[:n_val]  # train_ids, val_ids

# ---------- dataset ----------
class SpinePNG(Dataset):
    def __init__(self, ids, img_map, mask_map):
        self.ids = ids
        self.img_map = img_map
        self.mask_map = mask_map
        self.tf = A.Compose([
            A.Resize(256, 256),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=(0.5,), std=(0.5,)),  # assumes image scaled to [0,1]
            ToTensorV2()
        ])

    def __len__(self): return len(self.ids)

    def __getitem__(self, i):
        bid = self.ids[i]
        img = cv2.imread(self.img_map[bid], cv2.IMREAD_UNCHANGED)  # 8/16-bit ok
        msk = cv2.imread(self.mask_map[bid], cv2.IMREAD_UNCHANGED) # uint8 indices 0..5

        # ensure grayscale
        if img.ndim == 3:  # BGR -> gray
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # scale image to [0,1]
        if img.dtype == np.uint16:
            img = img.astype(np.float32) / 65535.0
        else:
            img = img.astype(np.float32) / 255.0

        # albumentations expects HxW
        aug = self.tf(image=img, mask=msk.astype(np.int64))
        x = aug["image"]            # (1,256,256) because ToTensorV2 adds C for gray
        if x.ndim == 2:             # safety: add channel if needed
            x = x.unsqueeze(0)
        y = aug["mask"].long()      # (256,256) class indices
        return x, y

# ---------- build loaders ----------
def make_loaders(root="mri_data", batch_size=4, val_ratio=0.1, num_workers=4, seed=42):
    ids, img_map, mask_map = collect_ids(root)
    train_ids, val_ids = split_ids(ids, val_ratio=val_ratio, seed=seed)
    train_ds = SpinePNG(train_ids, img_map, mask_map)
    val_ds   = SpinePNG(val_ids,   img_map, mask_map)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=max(1, num_workers//2), pin_memory=True)
    return train_loader, val_loader

train_loader, val_loader = make_loaders(root="/content/mri_data_png/data", batch_size=4, val_ratio=0.1)

for xb, yb in train_loader:
    print(xb.shape, yb.shape)  # expect (N,1,256,256) and (N,256,256)
    break
print(len(train_loader.dataset))  # should be ~231
print(len(val_loader.dataset))    # should be ~26



torch.Size([4, 1, 256, 256]) torch.Size([4, 256, 256])
3182
353


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

NUM_CLASSES = 20
bce = nn.BCEWithLogitsLoss()
ce  = nn.CrossEntropyLoss()

def dice_loss_from_logits(logits, target_idx, eps=1e-6):
    p = F.softmax(logits, dim=1)                                   # (N,20,H,W)
    t = F.one_hot(target_idx, NUM_CLASSES).permute(0,3,1,2).float()# (N,20,H,W)
    num = 2 * (p * t).sum((0,2,3))
    den = (p*p).sum((0,2,3)) + (t*t).sum((0,2,3)) + eps
    return 1 - (num/den).mean()

def train(gen, disc, train_loader, val_loader, device="cuda", epochs=20, lr=2e-4, lambda_seg=20.0,
          save_each_epoch=False):
    gen.to(device); disc.to(device)

    opt_g = torch.optim.Adam(gen.parameters(),  lr=lr, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

    amp_enabled = device.startswith("cuda") and torch.cuda.is_available()
    scaler_g = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    scaler_d = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    for epoch in range(1, epochs+1):
        gen.train(); disc.train()
        pbar = tqdm(train_loader, desc=f"epoch {epoch}/{epochs}", ncols=100, leave=False)

        for x, y_idx in pbar:
            x = x.to(device, non_blocking=True)
            y_idx = y_idx.to(device, non_blocking=True)

            # ---- D step ----
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                with torch.no_grad():
                    logits_fake = gen(x)
                    probs_fake  = F.softmax(logits_fake, dim=1)       # (N,20,H,W)
                y_real = F.one_hot(y_idx, NUM_CLASSES).permute(0,3,1,2).float()
                d_real = disc(x, y_real)
                d_fake = disc(x, probs_fake.detach())
                loss_d = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

            opt_d.zero_grad(set_to_none=True)
            scaler_d.scale(loss_d).backward()
            scaler_d.step(opt_d); scaler_d.update()

            # ---- G step ----
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                logits = gen(x)
                probs  = F.softmax(logits, dim=1)
                d_fake_for_g = disc(x, probs)
                gan_loss = bce(d_fake_for_g, torch.ones_like(d_fake_for_g))
                seg_loss = ce(logits, y_idx) + dice_loss_from_logits(logits, y_idx)
                loss_g = gan_loss + lambda_seg * seg_loss

            opt_g.zero_grad(set_to_none=True)
            scaler_g.scale(loss_g).backward()
            scaler_g.step(opt_g); scaler_g.update()

            pbar.set_postfix(D=f"{loss_d.item():.3f}",
                              G=f"{loss_g.item():.3f}",
                              CE=f"{ce(logits, y_idx).item():.3f}")

        # ---- validation: CE + Dice ----
        gen.eval()
        ce_sum, dice_sum, n = 0.0, 0.0, 0
        with torch.inference_mode():
            for x, y_idx in val_loader:
                x = x.to(device); y_idx = y_idx.to(device)
                logits = gen(x)
                ce_sum   += ce(logits, y_idx).item()
                dice_sum += (1.0 - dice_loss_from_logits(logits, y_idx).item())
                n += 1

        avg_ce   = ce_sum / max(n, 1)
        avg_dice = dice_sum / max(n, 1)
        print(f"[epoch {epoch}] val CE: {avg_ce:.3f} | val Dice: {avg_dice:.3f}")

        # optional: save once per epoch
        if save_each_epoch:
            torch.save(gen.state_dict(), f"gen_epoch{epoch}.pth")

    # return AFTER all epochs
    return gen


In [None]:
gen  = Generator(in_ch=1, out_ch=NUM_CLASSES, features=64)
disc = Discriminator(in_ch_x=1, in_ch_y=NUM_CLASSES)

model = train(gen, disc, train_loader, val_loader, device="cpu", epochs=20, lr=2e-4, lambda_seg=20.0)


  scaler_g = torch.cuda.amp.GradScaler(enabled=amp_enabled)
  scaler_d = torch.cuda.amp.GradScaler(enabled=amp_enabled)
  with torch.cuda.amp.autocast(enabled=amp_enabled):
  with torch.cuda.amp.autocast(enabled=amp_enabled):


[epoch 1] val CE: 0.213 | val Dice: 0.468




[epoch 2] val CE: 0.181 | val Dice: 0.588




[epoch 3] val CE: 0.134 | val Dice: 0.668




[epoch 4] val CE: 0.199 | val Dice: 0.666


epoch 5/20:  64%|████████████▊       | 509/796 [10:43<05:58,  1.25s/it, CE=0.056, D=0.044, G=12.732]

In [None]:
import os, glob, csv
import numpy as np
import cv2
from collections import Counter

MASK_DIR = "/content/mri_data_png/data/masks"
NUM_CLASSES = 6            # <-- set your current class count

# --- aggregate stats ---
global_counts = Counter()
bad_files = []            # files containing labels >= NUM_CLASSES or < 0
rgb_like = []             # files that load as 3-channel (likely color masks)
per_file_stats = []       # (path, min, max, uniques_truncated)

mask_paths = sorted(glob.glob(os.path.join(MASK_DIR, "*.png")))
if not mask_paths:
    raise RuntimeError(f"No PNG masks found under {MASK_DIR}")

for p in mask_paths:
    m = cv2.imread(p, cv2.IMREAD_UNCHANGED)

    if m is None:
        print(f"[WARN] Could not read: {p}")
        continue

    # If mask is accidentally RGB, keep note; we’ll take single channel for auditing
    if m.ndim == 3:
        rgb_like.append(p)
        m = m[..., 0]  # take first channel just to inspect values

    # ensure integer type
    if not np.issubdtype(m.dtype, np.integer):
        # round if float; then cast
        m = np.rint(m).astype(np.int64)
    else:
        m = m.astype(np.int64)

    # collect unique values for this file
    u, c = np.unique(m, return_counts=True)
    global_counts.update(dict(zip(u.tolist(), c.tolist())))

    # min/max and check bounds
    min_v, max_v = int(u.min()), int(u.max())
    per_file_stats.append((p, min_v, max_v, u[:20].tolist()))  # only show first 20 unique values

    if (min_v < 0) or (max_v >= NUM_CLASSES):
        bad_files.append((p, min_v, max_v))

# --- print summary ---
total_pixels = sum(global_counts.values())
sorted_vals = sorted(global_counts.items(), key=lambda kv: kv[0])

print(f"Scanned {len(mask_paths)} mask files")
print(f"Total pixels counted: {total_pixels:,}")
print("\nGlobal label histogram (value: count, percent):")
for v, cnt in sorted_vals:
    pct = 100.0 * cnt / (total_pixels + 1e-9)
    print(f"  {int(v):>3}: {cnt:>10}  ({pct:5.2f}%)")

print("\nMin/Max over all files:")
all_vals = [v for v,_ in sorted_vals]
print(f"  global min: {min(all_vals)}")
print(f"  global max: {max(all_vals)}")

if rgb_like:
    print(f"\n[WARN] {len(rgb_like)} masks loaded as 3-channel (likely color). Examples:")
    for p in rgb_like[:5]:
        print("  ", p)
    print("→ Convert these to single-channel index masks (uint8 indices 0..NUM_CLASSES-1).")

if bad_files:
    print(f"\n[PROBLEM] {len(bad_files)} files contain labels outside [0..{NUM_CLASSES-1}]. Examples:")
    for p, mn, mx in bad_files[:10]:
        print(f"  {p}  (min={mn}, max={mx})")
else:
    print(f"\nAll files are within [0..{NUM_CLASSES-1}].")


Scanned 3535 mask files
Total pixels counted: 955,040,358

Global label histogram (value: count, percent):
    0:  892493471  (93.45%)
    1:    7766168  ( 0.81%)
    2:    8144924  ( 0.85%)
    3:    7945922  ( 0.83%)
    4:    7153831  ( 0.75%)
    5:    6286986  ( 0.66%)
    6:    4966257  ( 0.52%)
    7:    2323574  ( 0.24%)
    8:     576605  ( 0.06%)
    9:      68376  ( 0.01%)
   10:    8390497  ( 0.88%)
   11:    1413412  ( 0.15%)
   12:    1656320  ( 0.17%)
   13:    1663399  ( 0.17%)
   14:    1492928  ( 0.16%)
   15:    1184436  ( 0.12%)
   16:     877683  ( 0.09%)
   17:     480197  ( 0.05%)
   18:     127329  ( 0.01%)
   19:      28043  ( 0.00%)

Min/Max over all files:
  global min: 0
  global max: 19

[PROBLEM] 3285 files contain labels outside [0..5]. Examples:
  /content/mri_data_png/data/masks/100_01.png  (min=0, max=8)
  /content/mri_data_png/data/masks/100_02.png  (min=0, max=11)
  /content/mri_data_png/data/masks/100_03.png  (min=0, max=18)
  /content/mri_data_png/

In [None]:
# !pip install spatialdata-io
!wget https://s3.embl.de/spatialdata/spatialdata-sandbox/merfish.zip

--2025-09-05 07:22:14--  https://s3.embl.de/spatialdata/spatialdata-sandbox/merfish.zip
Resolving s3.embl.de (s3.embl.de)... 194.94.45.80
Connecting to s3.embl.de (s3.embl.de)|194.94.45.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 53533526 (51M) [application/zip]
Saving to: ‘merfish.zip’


2025-09-05 07:22:17 (19.4 MB/s) - ‘merfish.zip’ saved [53533526/53533526]



In [None]:
import zipfile
import os

zip_path = "/content/merfish.zip"   # adjust path if different
extract_path = "/content/merfish"

os.makedirs(extract_path, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Extracted to:", extract_path)

Extracted to: /content/merfish


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import spatialdata as sd           # read_zarr
import scanpy as sc                # transcriptomics core
import squidpy as sq               # spatial graph/stats
from shapely.geometry import shape # centroid fallback with GeoJSON dicts
import pyarrow.parquet as pq       # read shapes/points parquet

ZARR_DIR = Path("/content/merfish/data.zarr")  # <-- change to absolute path if needed
assert ZARR_DIR.exists(), f"{ZARR_DIR} not found"

# -------------------------
# 1) Load SpatialData Zarr
# -------------------------
sdata = sd.read_zarr(str(ZARR_DIR))

# Pick the main table -> AnnData (cells x genes)
# Most SpatialData stores name it "table"
if "table" in sdata.tables:
    adata = sdata.tables["table"]
else:
    # take the first table if not named "table"
    adata = next(iter(sdata.tables.values()))

print(adata)
print("obs columns:", list(adata.obs.columns)[:10])
print("var columns:", list(adata.var.columns)[:10])
print("obsm keys:", list(adata.obsm.keys()))

# ---------------------------------------------------
# 2) Get per-cell spatial coordinates for visualization
#    Priority:
#    (A) adata.obsm['spatial'] if present
#    (B) centroids from shapes/cells/shapes.parquet
#    (C) centroids from points if points contain cell_id
# ---------------------------------------------------
coords = None

# (A) common convention: 'spatial'
if "spatial" in adata.obsm and adata.obsm["spatial"] is not None:
    XY = np.asarray(adata.obsm["spatial"])
    coords = pd.DataFrame({"cell_id": adata.obs_names.astype(str),
                           "x": XY[:,0].astype(float),
                           "y": XY[:,1].astype(float)})
    print("Using coordinates from adata.obsm['spatial'].")

# (B) cell polygons -> centroids
if coords is None:
    shapes_cells = ZARR_DIR/"shapes"/"cells"/"shapes.parquet"
    if shapes_cells.exists():
        df_shapes = pq.read_table(shapes_cells).to_pandas()
        # heuristics to find an id column
        id_col = None
        for k in ["cell_id","id","label","name","_index","index"]:
            if k in df_shapes.columns:
                id_col = k; break
        if id_col is None:
            df_shapes = df_shapes.reset_index().rename(columns={"index":"cell_id"})
            id_col = "cell_id"

        # If geometry is present in WKB/GeoJSON-like dicts
        if "geometry" in df_shapes.columns:
            # geometry may be dict (GeoJSON) or bytes (WKB). Handle both:
            try:
                cent = []
                for g in df_shapes["geometry"]:
                    if isinstance(g, dict):
                        c = shape(g).centroid
                    elif isinstance(g, (bytes, bytearray, memoryview)):
                        from shapely import wkb
                        c = wkb.loads(bytes(g)).centroid
                    else:
                        # unknown geometry encoding; skip
                        c = None
                    cent.append((float(c.x), float(c.y)) if c is not None else (np.nan, np.nan))
                cent = np.array(cent)
                coords = pd.DataFrame({"cell_id": df_shapes[id_col].astype(str),
                                       "x": cent[:,0], "y": cent[:,1]})
                coords = coords.dropna()
                print("Using coordinates from shapes/cells centroids.")
            except Exception as e:
                print("Could not parse shapes geometry:", e)

# (C) points with cell_id -> mean per cell_id
if coords is None:
    points_dir = ZARR_DIR/"points"/"single_molecule"/"points.parquet"
    if points_dir.exists():
        pts = pq.read_table(points_dir).to_pandas()
        cols = {c.lower(): c for c in pts.columns}
        if {"cell_id","x","y"}.issubset(cols.keys()):
            cell_col, x_col, y_col = cols["cell_id"], cols["x"], cols["y"]
            coords = (pts.groupby(cell_col)[[x_col,y_col]]
                        .mean().reset_index()
                        .rename(columns={cell_col:"cell_id", x_col:"x", y_col:"y"}))
            print("Using coordinates from points (mean x,y per cell_id).")
        else:
            print("Points found but missing 'cell_id' -> cannot derive per-cell coords from points.")

# Save coords csv if available (useful for debugging or plotting elsewhere)
if coords is not None:
    coords.to_csv("coords.csv", index=False)
    print("coords.csv saved:", coords.shape)
else:
    print("WARNING: No per-cell coordinates found. Spatial plots will be skipped.")

# ---------------------------------------
# 3) Export counts.csv (cells x genes)
# ---------------------------------------
counts = pd.DataFrame(adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X,
                      index=adata.obs_names, columns=adata.var_names)
counts.reset_index().rename(columns={"index":"cell_id"}).to_csv("counts.csv", index=False)
print("counts.csv saved:", counts.shape)

# -------------------------------------------------------
# 4) Minimal analysis: QC -> normalize/log -> PCA -> UMAP/Leiden
# -------------------------------------------------------
sc.pp.calculate_qc_metrics(adata, inplace=True)  # n_genes_by_counts, total_counts, etc.

# Simple filtering suggestions (optional)
# sc.pp.filter_cells(adata, min_genes=100)
# sc.pp.filter_genes(adata, min_cells=3)

sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=True)
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, n_comps=50)
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=30)
sc.tl.umap(adata)
sc.tl.leiden(adata, key_added="cluster")

# UMAP colored by cluster
sc.pl.umap(adata, color=["cluster"], legend_loc="on data", show=False)
plt.savefig("umap_clusters.png", dpi=150, bbox_inches="tight"); plt.close()
print("Saved: umap_clusters.png")

# -------------------------------------------------------
# 5) Spatial analysis (only if we have per-cell coords)
# -------------------------------------------------------
if coords is not None:
    # align order to adata.obs
    coords_indexed = coords.set_index("cell_id").reindex(adata.obs_names)
    # store as generic coordinates
    adata.obsm["spatial"] = coords_indexed[["x","y"]].values

    # Build spatial neighbor graph (generic 2D coordinates)
    sq.gr.spatial_neighbors(adata, coord_type="generic")
    # Neighborhood enrichment between clusters
    sq.gr.nhood_enrichment(adata, cluster_key="cluster")

    # Quick spatial scatter (matplotlib)
    fig, ax = plt.subplots()
    sc = ax.scatter(adata.obsm["spatial"][:,0], adata.obsm["spatial"][:,1],
                    c=pd.Categorical(adata.obs["cluster"]).codes, s=6)
    ax.set_title("Spatial clusters (scatter)")
    ax.set_aspect("equal")
    ax.invert_yaxis()  # image-like coords
    plt.savefig("spatial_clusters.png", dpi=200, bbox_inches="tight"); plt.close()
    print("Saved: spatial_clusters.png")

    # Neighborhood enrichment heatmap
    sq.pl.nhood_enrichment(adata, cluster_key="cluster", figsize=(5,4), show=False)
    plt.savefig("nhood_enrichment.png", dpi=150, bbox_inches="tight"); plt.close()
    print("Saved: nhood_enrichment.png")
else:
    print("Skipped spatial plots: no coords.")

# -------------------------------------------------------
# 6) Simple marker discovery per cluster (Δ-mean quick pass)
# -------------------------------------------------------
clusters = adata.obs["cluster"].astype(str)
X = (adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X)
import numpy as np
markers = []
for c in sorted(clusters.unique()):
    mask = (clusters == c).values
    mu_in  = np.asarray(X[mask,:]).mean(axis=0)
    mu_out = np.asarray(X[~mask,:]).mean(axis=0)
    delta = mu_in - mu_out
    top_idx = np.argsort(delta)[::-1][:10]
    for rank, j in enumerate(top_idx, start=1):
        markers.append({"cluster": c, "rank": rank, "gene": adata.var_names[j], "delta_mean": float(delta[j])})
pd.DataFrame(markers).to_csv("cluster_markers_top10.csv", index=False)
print("Saved: cluster_markers_top10.csv")


  from skimage.io import imread
  from squidpy.im._io import _assert_dims_present, _infer_dimensions, _lazy_load_image
  from squidpy.read._utils import _load_image, _read_counts
  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


AnnData object with n_obs × n_vars = 2389 × 268
    obs: 'cell_id', 'region'
    uns: 'spatialdata_attrs'
obs columns: ['cell_id', 'region']
var columns: []
obsm keys: []
Using coordinates from shapes/cells centroids.
coords.csv saved: (2389, 3)
counts.csv saved: (2389, 268)


IndexError: Positions outside range of features.

In [None]:
# (recommended) conda create -n sdata python=3.10 -y && conda activate sdata
!python -m pip install -U pip
!python -m pip install spatialdata scanpy squidpy zarr pyarrow shapely geopandas matplotlib


Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2
Collecting squidpy
  Downloading squidpy-1.6.5-py3-none-any.whl.metadata (9.0 kB)
Collecting docrep>=0.3.1 (from squidpy)
  Downloading docrep-0.3.2.tar.gz (33 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting matplotlib-scalebar>=0.8.0 (from squidpy)
  Downloading matplotlib_scalebar-0.9.0-py3-none-any.whl.metadata (18 kB)
Collecting omnipath>=1.0.7 (from squidpy)
  Downloading omnipath-1.0.12-py3-none-any.whl.metadata (7.0 kB)
Collecting validators>=0.18.2 (from squidpy)
  Downloading validators-0.35.0-py3-none-any.whl.me