# Lab 3 — CNNs for Microscopy Image Analysis
**Models:** U-Net (semantic segmentation) + StarDist (instance segmentation)

**Objective:** 
- Test two very different CNNs on microscopy images.
- Observe strengths, weaknesses, and visualize results.


In [None]:
# 1) Imports and setup
import os
from pathlib import Path
import json
import zipfile
import requests
from io import BytesIO

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import cv2

import torch
import torchvision.transforms as T
import segmentation_models_pytorch as smp

from stardist.models import StarDist2D
from csbdeep.utils import normalize
from skimage import measure

# Paths — every input/output folder hangs off the current working directory.
ROOT = Path.cwd()
DATA_DIR = ROOT / "data"
RAW_DIR = DATA_DIR / "LD4_images"
OUTPUTS = ROOT / "outputs"
IMGS_OUT = OUTPUTS / "images"
METRICS_DIR = OUTPUTS / "metrics"
MODELS_DIR = ROOT / "models"

# Create the whole folder tree up front; exist_ok makes re-running the cell safe.
for folder in (DATA_DIR, RAW_DIR, OUTPUTS, IMGS_OUT, METRICS_DIR, MODELS_DIR):
    folder.mkdir(parents=True, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)


SyntaxError: invalid syntax (3565984541.py, line 2)

In [None]:
# 2) Download LD4 images (Dropbox)
dropbox_url = "https://www.dropbox.com/s/qscq5qa5v5nbwxi/LD4_images.zip?dl=1"

# Only download when the target folder is empty, so re-running the cell is cheap.
if not any(RAW_DIR.iterdir()):
    print("Downloading LD4 images...")
    # The archive is read fully into memory below (r.content), so streaming buys
    # nothing; a timeout stops the notebook hanging forever on a stalled connection.
    r = requests.get(dropbox_url, timeout=120)
    r.raise_for_status()
    # Context manager closes the ZipFile once extraction is done.
    with zipfile.ZipFile(BytesIO(r.content)) as z:
        z.extractall(RAW_DIR)
    print("Extracted to", RAW_DIR)
else:
    print("LD4 images already exist.")


In [None]:
# 3) Load images
# Accepted raster formats. Besides .png/.jpg/.tif this includes the .jpeg and
# .tiff spellings, which the original list missed (microscopy data is often .tiff).
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def _is_image_file(path):
    """Return True for regular files with an image suffix, skipping archive metadata."""
    # Zips created on macOS may contain "__MACOSX/" folders and "._name.png"
    # resource-fork stubs that carry an image suffix but are not decodable images.
    if "__MACOSX" in path.parts or path.name.startswith("._"):
        return False
    return path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES


img_paths = sorted(p for p in RAW_DIR.rglob("*") if _is_image_file(p))
print(f"Found {len(img_paths)} images")

def load_image(path):
    """Load an image file and return it as an RGB NumPy array.

    The image is converted to PIL "RGB" mode, so the result is (H, W, 3) uint8.
    The file handle is closed before returning, which matters when looping over
    many files.

    NOTE(review): converting 16-bit grayscale TIFFs to "RGB" may saturate values
    above 255; confirm the LD4 images are 8-bit.
    """
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))

def show_image(im, title=None):
    """Display ``im`` in its own 5x5-inch matplotlib figure with axes hidden.

    Args:
        im: any array ``imshow`` accepts.
        title: optional caption; nothing is drawn when it is falsy.
    """
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot()
    ax.imshow(im)
    if title:
        ax.set_title(title)
    ax.axis('off')
    plt.show()

# Preview the first discovered image (skipped when the folder had no images).
if img_paths:
    sample = load_image(img_paths[0])
    show_image(sample, "Sample Image")


In [None]:
# 4) Preprocessing functions
def rescale_uint8(img):
    """Linearly stretch ``img`` so its global min maps to 0 and its max to 255.

    All channels share one min/max. A constant image returns all zeros. Output
    values are truncated (not rounded) to uint8.

    Args:
        img: NumPy array of any numeric dtype.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    # Compute in float64: subtracting/dividing in the input's own integer dtype
    # wraps around for signed types (e.g. int8: 127 - (-128) overflows to -1).
    # For unsigned inputs the results are identical to the previous version.
    arr = np.asarray(img, dtype=np.float64)
    imin, imax = arr.min(), arr.max()
    if imax == imin:
        return np.zeros_like(img, dtype=np.uint8)
    return ((arr - imin) / (imax - imin) * 255).astype(np.uint8)


In [None]:
# 5) Build U-Net
ENCODER = "resnet34"          # backbone used by the U-Net encoder
ENCODER_WEIGHTS = "imagenet"  # pretrained weights requested for the encoder
NUM_CLASSES = 1               # one output channel -> binary foreground mask

def build_unet():
    """Construct a segmentation_models_pytorch U-Net for binary segmentation.

    The network takes 3-channel input and returns raw logits
    (``activation=None``); the sigmoid is applied by the caller.
    """
    unet_config = {
        "encoder_name": ENCODER,
        "encoder_weights": ENCODER_WEIGHTS,
        "in_channels": 3,
        "classes": NUM_CLASSES,
        "activation": None,
    }
    return smp.Unet(**unet_config)

# Build once, move to the selected device and switch to inference mode
# (Module.to()/eval() both return the module itself, so they can be chained).
unet = build_unet().to(DEVICE).eval()
# Input normalisation matching the encoder's pretraining statistics.
preprocess_input = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# NOTE(review): only the encoder receives pretrained weights and no training step
# appears in this notebook, so the decoder/head are randomly initialised — the
# U-Net masks should not be read as real segmentations until the model is trained.
print("U-Net ready")


In [None]:
# U-Net prediction
def unet_predict_image(image_cv, model, thresh=0.5, divisor=32):
    """Run the U-Net on one RGB image and return (mask, probability map).

    smp's U-Net (5-stage encoder such as resnet34) only accepts inputs whose
    height and width are divisible by 32. The image is therefore reflect-padded
    at the bottom/right up to the next multiple of ``divisor``. Both outputs are
    then cropped back to the original size, so they stay pixel-aligned with
    ``image_cv``.

    Args:
        image_cv: (H, W, 3) RGB array (0-255 range, as produced by rescale_uint8).
        model: segmentation model returning logits of shape (1, 1, H', W').
        thresh: probability threshold for the binary mask.
        divisor: spatial size multiple required by the network (default 32).

    Returns:
        mask: bool array (H, W), True where probability > ``thresh``.
        prob: float array (H, W) of sigmoid probabilities.
    """
    h, w = image_cv.shape[:2]
    pad_h = -h % divisor
    pad_w = -w % divisor
    if pad_h or pad_w:
        # Reflect padding avoids hard black borders that could create spurious edges.
        pad_spec = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (image_cv.ndim - 2)
        image_cv = np.pad(image_cv, pad_spec, mode="reflect")
    x = preprocess_input(image_cv.astype('float32'))
    x = np.transpose(x, (2, 0, 1))[None, ...]  # HWC -> NCHW with batch of 1
    x = torch.from_numpy(x).float().to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        prob = torch.sigmoid(logits).cpu().numpy()[0, 0]
    prob = prob[:h, :w]  # drop the padded border
    mask = prob > thresh
    return mask, prob


In [None]:
# 6) Load StarDist
# "2D_versatile_fluo" is the registered name of StarDist's pretrained
# fluorescence-nuclei model; "2D_versatile_fluorescent" is not a registered name.
try:
    stardist_model = StarDist2D.from_pretrained("2D_versatile_fluo")
except Exception as err:  # e.g. no network access to fetch the weights
    print(f"Could not load pretrained StarDist model: {err!r}")
    stardist_model = None

if stardist_model is not None:
    print("StarDist loaded")
else:
    # from_pretrained may also return None (instead of raising) when it cannot
    # resolve the model, so this branch covers both failure modes.
    # StarDist2D takes a Config2D object; n_rays/grid are config fields, not
    # constructor keywords, so they must be passed through Config2D.
    from stardist.models import Config2D

    demo_config = Config2D(n_rays=32, grid=(1, 1))
    stardist_model = StarDist2D(demo_config, name="stardist_demo", basedir=str(MODELS_DIR))
    # An untrained network's predictions are meaningless: treat downstream
    # StarDist results as a smoke test only when this fallback is used.
    print("StarDist demo model created (untrained)")


In [None]:
# StarDist prediction
def stardist_predict_image(image_cv, model):
    """Segment object instances in an RGB image with a StarDist2D model.

    The image is collapsed to one grayscale channel and percentile-normalised
    (1st to 99.8th percentile) with csbdeep's ``normalize``. It is then passed to
    ``model.predict_instances``, whose first return value (the label image) is
    returned; the details dict is discarded.
    """
    grayscale = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY).astype(np.float32)
    prediction = model.predict_instances(normalize(grayscale, 1, 99.8))
    return prediction[0]


In [None]:
# 7) Overlay helper
def overlay_mask_on_image(image_cv, mask, color=(255,0,0), alpha=0.4):
    """Return a copy of ``image_cv`` with masked pixels alpha-blended toward ``color``.

    Pixels where ``mask > 0`` become ``(1 - alpha) * pixel + alpha * color``
    (truncated to uint8); all other pixels and the input array are left untouched.
    """
    selected = mask > 0
    blended = image_cv.copy()
    tint = np.asarray(color, dtype=np.uint8).reshape(1, 1, 3)
    mixed = blended[selected].astype(int) * (1 - alpha) + tint * alpha
    blended[selected] = mixed.astype(np.uint8)
    return blended


In [None]:
# 8) Run inference on all images
# For each image: contrast-stretch, run U-Net then StarDist, save both overlays
# (U-Net in red, StarDist in green) and record a per-image summary row.
results_summary = []

for img_path in img_paths:
    frame = rescale_uint8(load_image(img_path))

    # U-Net: binary foreground mask
    unet_mask, _ = unet_predict_image(frame, unet)
    unet_vis = overlay_mask_on_image(frame, unet_mask, color=(255,0,0))

    # StarDist: instance label image; any non-zero label is foreground
    instance_labels = stardist_predict_image(frame, stardist_model)
    stardist_vis = overlay_mask_on_image(frame, instance_labels > 0, color=(0,255,0))

    stem = img_path.stem
    Image.fromarray(unet_vis).save(IMGS_OUT / f"{stem}_unet_overlay.png")
    Image.fromarray(stardist_vis).save(IMGS_OUT / f"{stem}_stardist_overlay.png")

    results_summary.append(dict(
        image=img_path.name,
        unet_pixels=int(unet_mask.sum()),
        stardist_instances=int(instance_labels.max()),
    ))

pd.DataFrame(results_summary).to_csv(METRICS_DIR / "inference_summary.csv", index=False)
print("Inference complete. Results saved to outputs folder.")


In [None]:
# 9) Show example results
# Side-by-side view of the first (up to) three images and their saved overlays.
# Note: the overlays were drawn on the contrast-stretched image (rescale_uint8),
# while the left panel shows the image exactly as loaded.
_panel_titles = ("Original", "U-Net Overlay", "StarDist Overlay")

for img_path in img_paths[:3]:
    im = load_image(img_path)
    # Context managers close the PNG file handles once the pixels are copied.
    with Image.open(IMGS_OUT / f"{img_path.stem}_unet_overlay.png") as f:
        u_overlay = np.array(f)
    with Image.open(IMGS_OUT / f"{img_path.stem}_stardist_overlay.png") as f:
        sd_overlay = np.array(f)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, panel, title in zip(axes, (im, u_overlay, sd_overlay), _panel_titles):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    plt.show()


In [None]:
# 10) Robustness tests (optional)
def add_noise(img, sigma=20, rng=None):
    """Add zero-mean Gaussian noise to ``img`` and saturate to the uint8 range.

    Args:
        img: image array (values expected in 0-255).
        sigma: standard deviation of the noise, in intensity units.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When
            None, NumPy's global random state is used (the original behavior).

    Returns:
        uint8 array with the same shape as ``img``.
    """
    if rng is None:
        noise = np.random.normal(0, sigma, img.shape)
    else:
        noise = rng.normal(0, sigma, img.shape)
    noise = noise.astype(np.float32)
    return np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)

def add_blur(img, ksize=5):
    """Blur ``img`` with an OpenCV Gaussian kernel of size ``ksize`` x ``ksize``.

    sigmaX is 0, so OpenCV derives the sigma from the kernel size; ``ksize``
    must be a positive odd integer.
    """
    kernel_shape = (ksize,) * 2
    return cv2.GaussianBlur(img, kernel_shape, 0)

def change_brightness(img,factor=0.6):
    """Multiply intensities by ``factor`` (in float32) and saturate to uint8.

    A factor below 1 darkens the image; above 1 brightens it, clipping at 255.
    """
    scaled = img.astype(np.float32) * factor
    return scaled.clip(0, 255).astype(np.uint8)

robust_summary = []

# Only the first three images, to keep this optional section quick.
for img_path in img_paths[:3]:
    base = rescale_uint8(load_image(img_path))
    # Built in this order (noise first) so the global RNG is consumed as before.
    perturbed = {
        "noise": add_noise(base),
        "blur": add_blur(base),
        "dark": change_brightness(base, 0.6),
    }

    for variant_name, variant_img in perturbed.items():
        unet_mask, _ = unet_predict_image(variant_img, unet)
        instance_labels = stardist_predict_image(variant_img, stardist_model)
        robust_summary.append(dict(
            image=img_path.name,
            variant=variant_name,
            unet_pixels=int(unet_mask.sum()),
            stardist_instances=int(instance_labels.max()),
        ))

pd.DataFrame(robust_summary).to_csv(METRICS_DIR / "robustness_summary.csv", index=False)
print("Robustness summary saved.")


# 11) Optional Questions

**1. Difficulties and solutions:**  
- Installing StarDist and TensorFlow: solved by pinning mutually compatible versions.  
- Image size mismatch for U-Net: the ResNet-34 U-Net requires image height and width divisible by 32; solved by resizing (or padding) the image before prediction.  

**2. Best CNN architectures for microscopy:**  
- **U-Net**: Excellent for semantic segmentation due to skip connections and localization.  
- **StarDist**: Excellent for instance segmentation of star-convex objects such as cell nuclei; it separates touching or densely crowded nuclei well.  

**3. Own question:**  
- *How do these networks behave on noisy images?*  
- **Answer:** U-Net is robust to small noise but may merge/split regions; StarDist fails to detect faint or blurred objects.
