In [16]:
import os
import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.loader import DataLoader

import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import resize
from skimage.measure import label, regionprops
from scipy.ndimage import median_filter
from skimage.filters import gaussian
from skimage import exposure
from scipy.stats import pearsonr

from scipy.optimize import linear_sum_assignment

from pathlib import Path
desktop_path = Path.home() / "Desktop"


from torch.utils.data import Dataset, DataLoader

import torch
from torch import nn
from torch_geometric.nn import GCNConv

from config import CONFIG
import random

from tifffile import imread
from skimage.transform import resize
from scipy.ndimage import median_filter
from skimage.filters import gaussian
from skimage import exposure
from skimage.measure import label, regionprops

In [18]:
# Dataset class and loader
class GraphSequenceDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        return self.sequences[idx]





def find_closest_node(props, manual_xy):
    # props: regionprops for the frame
    # manual_xy: (x, y) in resized frame coordinates
    dists = [np.linalg.norm(np.array(p.centroid[::-1]) - np.array(manual_xy)) for p in props]
    return np.argmin(dists)

# denoise helper
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    medianed = median_filter(frame, size=median_size)
    smoothed = gaussian(medianed, sigma=gaussian_sigma)
    return smoothed

# loading function
lost_nodes = {}  # node_id: {'coords': (x, y), 'last_seen': frame_idx}

def build_graph_from_frame(
    frame,
    intensity_thresh=CONFIG["intensity_thresh"],  # Threshold for intensity to consider as signal, custom to Cellnet dataset
    max_nodes=CONFIG["max_nodes"],
    source_idx=None,
    prev_coords=None,
    prev_ids=None,
    next_id_start=0,
    lost_nodes=None,
    frame_idx=None,
    lost_ttl=CONFIG["lost_ttl"]  # how many frames to keep lost nodes
):
    # Threshold to create binary mask
    binary_mask = frame > intensity_thresh

    # Label connected regions (potential signal sources)
    labeled = label(binary_mask)

    # Extract region properties
    props = regionprops(labeled, intensity_image=frame)

    node_features = []
    coords = []
    intensities = []

    for i, p in enumerate(props[:max_nodes]):
        y, x = p.centroid
        intensity = p.mean_intensity

        # Extract additional biological/morphological properties
        area = p.area
        eccentricity = p.eccentricity
        solidity = p.solidity
        perimeter = p.perimeter

        # Build node feature vector
        node_feat = [
            (x / frame.shape[1]) * 2 - 1,
            (y / frame.shape[0]) * 2 - 1,
            intensity,
            area / (frame.shape[0] * frame.shape[1]),
            eccentricity,
            solidity,
            perimeter / (frame.shape[0] + frame.shape[1])
        ]
        # Add source indicator
        if source_idx is not None and i == source_idx:
            node_feat.append(1.0)
        else:
            node_feat.append(0.0)
        node_features.append(node_feat)
        coords.append((x, y))
        intensities.append(intensity)

    # --- Node ID assignment (after collecting all nodes) ---
   
    # Node tracking threshold
    node_tracking_dist_thresh = CONFIG["node_tracking_dist_thresh"]
    spatial_edge_thresh = CONFIG["spatial_edge_thresh"]
   
    ids = []
    used_lost_ids = set()
    if prev_coords is not None and prev_ids is not None and len(prev_coords) > 0 and len(coords) > 0:
        cost_matrix = np.zeros((len(coords), len(prev_coords)))
        for i, (x, y) in enumerate(coords):
            for j, (px, py) in enumerate(prev_coords):
                cost_matrix[i, j] = np.linalg.norm(np.array([x, y]) - np.array([px, py]))
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        assigned_prev = set()
        for i in range(len(coords)):
            if i in row_ind:
                j = col_ind[np.where(row_ind == i)[0][0]]
                if cost_matrix[i, j] < node_tracking_dist_thresh:  # 8 is a heuristic for "maximum allowed movement" for tracking
                    ids.append(prev_ids[j])
                    assigned_prev.add(j)
                else:
                    ids.append(next_id_start)
                    next_id_start += 1
            else:
                ids.append(next_id_start)
                next_id_start += 1

        # --- LOST NODE MANAGEMENT ---
        if lost_nodes is not None and frame_idx is not None:
            for j, prev_id in enumerate(prev_ids):
                if j not in assigned_prev:
                    lost_nodes[prev_id] = {'coords': prev_coords[j], 'last_seen': frame_idx - 1}

            # Remove lost nodes that have been lost for too long
            to_remove = [nid for nid, info in lost_nodes.items() if frame_idx - info['last_seen'] > lost_ttl]
            for nid in to_remove:
                del lost_nodes[nid]

            # Try to match new nodes to lost nodes
            for i, (x, y) in enumerate(coords):
                if ids[i] >= next_id_start - max_nodes:  # only for new nodes
                    for lost_id, info in lost_nodes.items():
                        dist = np.linalg.norm(np.array([x, y]) - np.array(info['coords']))
                        if dist < node_tracking_dist_thresh and lost_id not in used_lost_ids:
                            ids[i] = lost_id
                            used_lost_ids.add(lost_id)
                            lost_nodes[lost_id]['last_seen'] = frame_idx
                            break

    else:
        # First frame or no previous nodes
        ids = [next_id_start + k for k in range(len(coords))]
        next_id_start += len(coords)

    edge_index = []
    edge_attr = []

    # Connect nodes within spatial threshold
    for i, (xi, yi) in enumerate(coords):
        for j, (xj, yj) in enumerate(coords):
            if i != j and (xi - xj)**2 + (yi - yj)**2 < spatial_edge_thresh:  # threshold in pixels (try 16²=256 for 16 pixels, or 20²=400 for 20 pixels)
                edge_index.append([i, j])

                # --- Edge Features ---
                dist = np.sqrt((xi - xj)**2 + (yi - yj)**2) / np.sqrt(frame.shape[0]**2 + frame.shape[1]**2)
                delta_intensity = abs(intensities[i] - intensities[j])
                angle = np.arctan2(yj - yi, xj - xi) / np.pi

                edge_attr.append([dist, delta_intensity, angle])

    # Skip empty graphs
    if not node_features:
        return None, None, None, next_id_start

    # Convert to PyTorch Geometric Data object
    x = torch.tensor(node_features, dtype=torch.float)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    if edge_attr.ndim == 1:
        edge_attr = edge_attr.unsqueeze(1)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.node_ids = torch.tensor(ids, dtype=torch.long)
    data.y = torch.tensor(intensities, dtype=torch.float).unsqueeze(-1)

    return data, coords, ids, next_id_start

def load_tif_as_graph_sequences(path, window=5, frame_skip=1, denoise_sigma=1.0, manual_source=None):
    """
    Loads a video as a list of sequences of graphs.
    Each sequence is a list of graphs of length 'window'.
    Tracks node identities across frames for temporal GNNs.
    """
    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f, gaussian_sigma=denoise_sigma) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window, frame_skip):  # -window instead of -window+1
        graph_seq = []
        prev_coords, prev_ids, next_id_start = None, None, 0
        lost_nodes = {}
        frame_indices = []
        for j in range(window + 1):  # build window+1 graphs to allow shifting
            idx = i + j
            temp_props = regionprops(label(norm_stack[idx] > CONFIG["intensity_thresh"]), intensity_image=norm_stack[idx])
            if not temp_props:
                print(f"Frame {idx} skipped (no active nodes)")
                break
            if manual_source and idx == manual_source['frame']:
                source_idx = find_closest_node(temp_props, manual_source['xy'])
            else:
                source_idx = None
            g, coords, ids, next_id_start = build_graph_from_frame(
                norm_stack[idx],
                source_idx=source_idx,
                prev_coords=prev_coords,
                prev_ids=prev_ids,
                next_id_start=next_id_start,
                lost_nodes=lost_nodes,
                frame_idx=idx,
                lost_ttl=CONFIG["lost_ttl"]
            )
            if g is None:
                break
            g.frame_idx = idx
            g.source_file = path
            prev_coords, prev_ids = coords, ids
            graph_seq.append(g)
            frame_indices.append(idx)
        if len(graph_seq) == window + 1:
            # Shift targets: for each t, set g.y = next_g.y for matching node_ids
            for t in range(window):
                curr_g = graph_seq[t]
                next_g = graph_seq[t + 1]
                # Map node ids to indices
                curr_ids = curr_g.node_ids.cpu().numpy()
                next_ids = next_g.node_ids.cpu().numpy()
                next_id_to_idx = {nid: i for i, nid in enumerate(next_ids)}
                # For each node in curr_g, if it exists in next_g, set target to next_g.y, else set to nan
                new_y = []
                for nid in curr_ids:
                    if nid in next_id_to_idx:
                        new_y.append(next_g.y[next_id_to_idx[nid]])
                    else:
                        new_y.append(torch.tensor([float('nan')], dtype=torch.float))
                curr_g.y = torch.stack(new_y)
            # Only keep the first window graphs (targets are next-frame)
            sequences.append(graph_seq[:window])
    return sequences
file_paths = [
    'C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif',
    'C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif', #✔ Clicked on Frame 77: (1.16, 2.96); ✔ Clicked on Frame 77: (15.57, 10.86)
    'C:/Users/Platypus/Documents/CellNet/Figure8.tif',
    'C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif'
]
test_video_path = 'C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif'



#sources where signal starts
manual_sources = {
    'Real_Time_CS_Experiment-1093.tif': {'frame': 5, 'xy': (3.79, 3.66)}, #altenatively "Clicked on Frame 5: (24.44, 10.58)"
    'Flow prior to chemical stimulation_Figure6C.tif': { 'frame': 77, 'xy': (1.16, 2.96)}, #alt ✔ Clicked on Frame 77: (15.57, 10.86)
    'Figure8.tif': {'frame': 68, 'xy': (43.83, 0.89)}, 
    '5uM_per_litre_Figure6_ChemicalStimulation.tif':{'frame': 37, 'xy': (10.03, 5.46)},# alt Frame 76: (2.82, 3.38);  
    'Cell Knocked_Figure7.tif': {'frame': 0, 'xy': (54.50, 9.34)} #Clicked on Frame 0: (54.50, 9.34);alt ✔ Clicked on Frame 0: (53.66, 3.52)
    }
test_sequences = load_tif_as_graph_sequences(test_video_path, window=5, manual_source=manual_sources.get(os.path.basename(test_video_path)))
test_dataset = GraphSequenceDataset(test_sequences)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=lambda x: x[0])


class GraphSequenceDataset(Dataset):
    def __init__(self, sequences):
        """
        sequences: list of list of PyG Data objects (each inner list is a sequence of graphs)
        """
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Returns a list of Data objects (the sequence)
        return self.sequences[idx]

def graph_sequence_collate(batch):
    # batch: list of sequences (each is a list of Data objects)
    # Output: list of length window, each element is a list of Data objects for that time step
    # For batch_size=1, this just returns [sequence]
    # For batch_size>1, this stacks sequences by time step
    batch_size = len(batch)
    seq_len = len(batch[0])
    collated = []
    for t in range(seq_len):
        collated.append([batch[b][t] for b in range(batch_size)])
    return collated



In [19]:
def naive_baseline(graph_seq):
    # Predict each node’s intensity as it was in the previous frame
    pred = []
    target = []
    for t in range(1, len(graph_seq)):
        prev = graph_seq[t-1]
        curr = graph_seq[t]

        # Match node_ids between frames
        id_to_idx_prev = {nid.item(): i for i, nid in enumerate(prev.node_ids)}
        id_to_idx_curr = {nid.item(): i for i, nid in enumerate(curr.node_ids)}

        common_ids = set(id_to_idx_prev.keys()).intersection(id_to_idx_curr.keys())
        if not common_ids:
            continue

        pred_t = []
        target_t = []
        for nid in common_ids:
            pred_t.append(prev.y[id_to_idx_prev[nid]].item())
            target_t.append(curr.y[id_to_idx_curr[nid]].item())

        pred.append(pred_t)
        target.append(target_t)

    pred_flat = np.concatenate(pred)
    target_flat = np.concatenate(target)
    loss = np.mean(np.abs(pred_flat - target_flat))
    corr, _ = pearsonr(pred_flat, target_flat)
    return loss, corr


In [20]:
class LSTMIntensityModel(nn.Module):
    def __init__(self, input_dim=1, hidden_dim=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: [batch, seq_len, 1]
        out, _ = self.lstm(x)
        return self.out  # [batch, seq_len, 1]


In [21]:
def prepare_lstm_data(graph_seq):
    # Align nodes by ID across frames
    node_id_lists = [g.node_ids.cpu().numpy() for g in graph_seq]
    common_ids = set(node_id_lists[0])
    for ids in node_id_lists[1:]:
        common_ids &= set(ids)
    common_ids = sorted(list(common_ids))
    if not common_ids:
        return None, None

    seq_data = []
    for g in graph_seq:
        id_to_idx = {nid.item(): i for i, nid in enumerate(g.node_ids)}
        intensities = [g.y[id_to_idx[nid]].item() for nid in common_ids]
        seq_data.append(intensities)

    seq_data = torch.tensor(seq_data).unsqueeze(-1).float()  # [seq_len, num_nodes, 1]
    seq_data = seq_data.permute(1, 0, 2)  # [num_nodes, seq_len, 1]
    return seq_data[:, :-1, :], seq_data[:, 1:, :]  # X, Y


In [22]:
def train_lstm_model(graph_seqs, epochs=20):
    model = LSTMIntensityModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.L1Loss()

    for epoch in range(epochs):
        losses = []
        for graph_seq in graph_seqs:
            model.train()
            out = prepare_lstm_data(graph_seq)
            if out is None:
                continue
            x, y = out
            pred = model(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f"[LSTM] Epoch {epoch+1}: MAE = {np.mean(losses):.4f}")
    return model


In [23]:
def evaluate_lstm_model(model, graph_seqs):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for graph_seq in graph_seqs:
            out = prepare_lstm_data(graph_seq)
            if out is None:
                continue
            x, y = out
            pred = model(x)
            preds.append(pred.flatten().numpy())
            trues.append(y.flatten().numpy())
    pred_flat = np.concatenate(preds)
    true_flat = np.concatenate(trues)
    loss = np.mean(np.abs(pred_flat - true_flat))
    corr, _ = pearsonr(pred_flat, true_flat)
    return loss, corr


In [27]:
# Naive baseline



all_sequences = []
for path in file_paths:
    if os.path.exists(path):
        video_name = os.path.basename(path)
        manual_source = manual_sources.get(video_name)
        seqs = load_tif_as_graph_sequences(path, window=5, manual_source=manual_source)
        all_sequences.extend(seqs)
print(f"Total loaded sequences: {len(all_sequences)}")

train_split = int(0.7 * len(all_sequences))
val_split = int(0.15 * len(all_sequences))

train_sequences = all_sequences[:train_split]
val_sequences = all_sequences[train_split:train_split + val_split]
test_sequences = all_sequences[train_split + val_split:]

train_dataset = GraphSequenceDataset(train_sequences)
test_dataset = GraphSequenceDataset(test_sequences)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=lambda x: x[0])

loss_naive, corr_naive = naive_baseline(test_loader.dataset[0])
print(f"Naive Baseline - MAE: {loss_naive:.4f} | Pearson: {corr_naive:.4f}")

# LSTM baseline
lstm_model = train_lstm_model(train_loader.dataset[:20])  # train on small subset
loss_lstm, corr_lstm = evaluate_lstm_model(lstm_model, test_loader.dataset[:20])
print(f"LSTM Baseline - MAE: {loss_lstm:.4f} | Pearson: {corr_lstm:.4f}")


Total loaded sequences: 4105
Naive Baseline - MAE: nan | Pearson: nan


AttributeError: 'Linear' object has no attribute 'size'