In [1]:
#Cell 1: Imports

%matplotlib notebook

import os
import numpy as np
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.loader import DataLoader

import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import resize
from skimage.measure import label, regionprops
from scipy.ndimage import median_filter
from skimage.filters import gaussian
from skimage import exposure
from scipy.stats import pearsonr

from scipy.optimize import linear_sum_assignment

from pathlib import Path
desktop_path = Path.home() / "Desktop"


from torch.utils.data import Dataset, DataLoader

import torch
from torch import nn
from torch_geometric.nn import GCNConv

from config import CONFIG
import random


seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
import cv2
import numpy as np

# Initialize background subtractor globally (once)
backSub = cv2.createBackgroundSubtractorKNN(history=500, dist2Threshold=400.0, detectShadows=True)

def apply_background_subtraction(frame):
    # Convert to uint8 if needed
    if frame.dtype != np.uint8:
        frame = (255 * (frame - frame.min()) / (frame.max() - frame.min())).astype(np.uint8)
    fg_mask = backSub.apply(frame)
    # Optionally, threshold to get binary mask
    _, fg_mask = cv2.threshold(fg_mask, 30, 255, cv2.THRESH_BINARY)
    return fg_mask

In [3]:
lost_nodes = {}  # node_id: {'coords': (x, y), 'last_seen': frame_idx}

def build_graph_from_frame(
    frame,
    intensity_thresh=CONFIG["intensity_thresh"],  # Threshold for intensity to consider as signal, custom to Cellnet dataset
    max_nodes=CONFIG["max_nodes"],
    source_idx=None,
    prev_coords=None,
    prev_ids=None,
    next_id_start=0,
    lost_nodes=None,
    frame_idx=None,
    lost_ttl=CONFIG["lost_ttl"]  # how many frames to keep lost nodes
):
    # Threshold to create binary mask
    binary_mask = frame > intensity_thresh

    # Label connected regions (potential signal sources)
    labeled = label(binary_mask)

    # Extract region properties
    props = regionprops(labeled, intensity_image=frame)

    node_features = []
    coords = []
    intensities = []

    for i, p in enumerate(props[:max_nodes]):
        y, x = p.centroid
        intensity = p.mean_intensity
        area = p.area
        eccentricity = p.eccentricity
        solidity = p.solidity
        perimeter = p.perimeter
        bbox = p.bbox  # (min_row, min_col, max_row, max_col)
        bbox_width = bbox[3] - bbox[1]
        bbox_height = bbox[2] - bbox[0]
        aspect_ratio = bbox_width / bbox_height if bbox_height > 0 else 0
        convex_area = p.convex_area
        major_axis_length = p.major_axis_length
        minor_axis_length = p.minor_axis_length

        # Build node feature vector
        node_feat = [
        (x / frame.shape[1]) * 2 - 1,
        (y / frame.shape[0]) * 2 - 1,
        intensity,
        area / (frame.shape[0] * frame.shape[1]),
        eccentricity,
        solidity,
        perimeter / (frame.shape[0] + frame.shape[1]),
        bbox_width / frame.shape[1],
        bbox_height / frame.shape[0],
        aspect_ratio,
        convex_area / (frame.shape[0] * frame.shape[1]),
        major_axis_length / frame.shape[1],
        minor_axis_length / frame.shape[0],
       # source indicator (0 or 1),        
        ]
        # Add source indicator
        if source_idx is not None and i == source_idx:
            node_feat.append(1.0)
        else:
            node_feat.append(0.0)
        node_features.append(node_feat)
        coords.append((x, y))
        intensities.append(intensity)

    # --- Node ID assignment (after collecting all nodes) ---
   
    # Node tracking threshold
    node_tracking_dist_thresh = CONFIG["node_tracking_dist_thresh"]
    spatial_edge_thresh = CONFIG["spatial_edge_thresh"]
   
    ids = []
    used_lost_ids = set()
    if prev_coords is not None and prev_ids is not None and len(prev_coords) > 0 and len(coords) > 0:
        cost_matrix = np.zeros((len(coords), len(prev_coords)))
        for i, (x, y) in enumerate(coords):
            for j, (px, py) in enumerate(prev_coords):
                cost_matrix[i, j] = np.linalg.norm(np.array([x, y]) - np.array([px, py]))
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        assigned_prev = set()
        for i in range(len(coords)):
            if i in row_ind:
                j = col_ind[np.where(row_ind == i)[0][0]]
                if cost_matrix[i, j] < node_tracking_dist_thresh:  # 8 is a heuristic for "maximum allowed movement" for tracking
                    ids.append(prev_ids[j])
                    assigned_prev.add(j)
                else:
                    ids.append(next_id_start)
                    next_id_start += 1
            else:
                ids.append(next_id_start)
                next_id_start += 1

        # --- LOST NODE MANAGEMENT ---
        if lost_nodes is not None and frame_idx is not None:
            for j, prev_id in enumerate(prev_ids):
                if j not in assigned_prev:
                    lost_nodes[prev_id] = {'coords': prev_coords[j], 'last_seen': frame_idx - 1}

            # Remove lost nodes that have been lost for too long
            to_remove = [nid for nid, info in lost_nodes.items() if frame_idx - info['last_seen'] > lost_ttl]
            for nid in to_remove:
                del lost_nodes[nid]

            # Try to match new nodes to lost nodes
            for i, (x, y) in enumerate(coords):
                if ids[i] >= next_id_start - max_nodes:  # only for new nodes
                    for lost_id, info in lost_nodes.items():
                        dist = np.linalg.norm(np.array([x, y]) - np.array(info['coords']))
                        if dist < node_tracking_dist_thresh and lost_id not in used_lost_ids:
                            ids[i] = lost_id
                            used_lost_ids.add(lost_id)
                            lost_nodes[lost_id]['last_seen'] = frame_idx
                            break

    else:
        # First frame or no previous nodes
        ids = [next_id_start + k for k in range(len(coords))]
        next_id_start += len(coords)

    edge_index = []
    edge_attr = []

    # Connect nodes within spatial threshold
    for i, (xi, yi) in enumerate(coords):
        for j, (xj, yj) in enumerate(coords):
            if i != j and (xi - xj)**2 + (yi - yj)**2 < spatial_edge_thresh:  # threshold in pixels (try 16²=256 for 16 pixels, or 20²=400 for 20 pixels)
                edge_index.append([i, j])

                # --- Edge Features ---
                dist = np.sqrt((xi - xj)**2 + (yi - yj)**2) / np.sqrt(frame.shape[0]**2 + frame.shape[1]**2)
                delta_intensity = abs(intensities[i] - intensities[j])
                angle = np.arctan2(yj - yi, xj - xi) / np.pi

                edge_attr.append([dist, delta_intensity, angle])

    # Skip empty graphs
    if not node_features:
        return None, None, None, next_id_start

    # Convert to PyTorch Geometric Data object
    x = torch.tensor(node_features, dtype=torch.float)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    if edge_attr.ndim == 1:
        edge_attr = edge_attr.unsqueeze(1)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    data.node_ids = torch.tensor(ids, dtype=torch.long)
    data.y = torch.tensor(intensities, dtype=torch.float).unsqueeze(-1)

    return data, coords, ids, next_id_start

In [4]:
#Cell 3: Denoising Functions
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    medianed = median_filter(frame, size=median_size)
    smoothed = gaussian(medianed, sigma=gaussian_sigma)
    return smoothed

def apply_clahe(frame, clip_limit=0.01):
    return exposure.equalize_adapthist(frame, clip_limit=clip_limit)  

In [5]:
#Cell 4: Load Graph Sequence


def find_closest_node(props, manual_xy):
    # props: regionprops for the frame
    # manual_xy: (x, y) in resized frame coordinates
    dists = [np.linalg.norm(np.array(p.centroid[::-1]) - np.array(manual_xy)) for p in props]
    return np.argmin(dists)


def load_tif_as_graph_sequences(path, window=5, frame_skip=1, denoise_sigma=1.0, manual_source=None):
    """
    Loads a video as a list of sequences of graphs.
    Each sequence is a list of graphs of length 'window'.
    Tracks node identities across frames for temporal GNNs.
    """
    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f, gaussian_sigma=denoise_sigma) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window, frame_skip):  # -window instead of -window+1
        graph_seq = []
        prev_coords, prev_ids, next_id_start = None, None, 0
        lost_nodes = {}
        frame_indices = []
        for j in range(window + 1):  # build window+1 graphs to allow shifting
            idx = i + j
            frame = norm_stack[idx]
            fg_mask = apply_background_subtraction(frame)
            temp_props = regionprops(label(fg_mask > 0), intensity_image=frame)
            if not temp_props:
                print(f"Frame {idx} skipped (no active nodes)")
                break
            if manual_source and idx == manual_source['frame']:
                source_idx = find_closest_node(temp_props, manual_source['xy'])
            else:
                source_idx = None
            g, coords, ids, next_id_start = build_graph_from_frame(
                norm_stack[idx],
                source_idx=source_idx,
                prev_coords=prev_coords,
                prev_ids=prev_ids,
                next_id_start=next_id_start,
                lost_nodes=lost_nodes,
                frame_idx=idx,
                lost_ttl=CONFIG["lost_ttl"]
            )
            if g is None:
                break
            g.frame_idx = idx
            g.source_file = path
            prev_coords, prev_ids = coords, ids
            graph_seq.append(g)
            frame_indices.append(idx)
        if len(graph_seq) == window + 1:
            # Shift targets: for each t, set g.y = next_g.y for matching node_ids
            for t in range(window):
                curr_g = graph_seq[t]
                next_g = graph_seq[t + 1]
                # Map node ids to indices
                curr_ids = curr_g.node_ids.cpu().numpy()
                next_ids = next_g.node_ids.cpu().numpy()
                next_id_to_idx = {nid: i for i, nid in enumerate(next_ids)}
                # For each node in curr_g, if it exists in next_g, set target to next_g.y, else set to nan
                new_y = []
                for nid in curr_ids:
                    if nid in next_id_to_idx:
                        new_y.append(next_g.y[next_id_to_idx[nid]])
                    else:
                        new_y.append(torch.tensor([float('nan')], dtype=torch.float))
                curr_g.y = torch.stack(new_y)
            # Only keep the first window graphs (targets are next-frame)
            sequences.append(graph_seq[:window])
    return sequences

In [6]:
# Cell 5: GCRN Definition


class GCRNCell(nn.Module):
    def __init__(self, node_feat_dim, hidden_dim=128, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.gcns = nn.ModuleList([GCNConv(node_feat_dim if i == 0 else hidden_dim, hidden_dim) for i in range(num_layers)])
        self.grus = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_layers)])


    def forward(self, x, edge_index, h):
        for i in range(self.num_layers):
            x_gcn = self.gcns[i](x, edge_index)
            h = self.grus[i](x_gcn, h)
            x = h  # Feed hidden state to next layer
        return h

class GCRN(nn.Module):
    def __init__(self, node_feat_dim, hidden_dim, out_dim=1, num_layers=2):
        super().__init__()
        self.cell = GCRNCell(node_feat_dim, hidden_dim, num_layers=num_layers)
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, graph_seq):
        h_dict = {}
        outputs = []
        for g in graph_seq:
            node_ids = g.node_ids.cpu().numpy()
            num_nodes = g.x.shape[0]
            device = g.x.device
            h = torch.zeros(num_nodes, self.cell.gcns[0].out_channels, device=device)
            for i, nid in enumerate(node_ids):
                if nid in h_dict:
                    h[i] = h_dict[nid].to(device)
            h_new = self.cell(g.x, g.edge_index, h)
            for i, nid in enumerate(node_ids):
                h_dict[nid] = h_new[i].detach().cpu()
            outputs.append(self.out_proj(h_new))
        return outputs

# node_feat_dim = 14  # Number of node features (from build_graph_from_frame)
# hidden_dim = 128   # Hidden dimension for GCRN
# out_dim = 1        # Output dimension (e.g., regression target per node)

# #When creating the model, specify num_layers (e.g., num_layers=2 or 3)
# model = GCRN(node_feat_dim=node_feat_dim, hidden_dim=hidden_dim, out_dim=out_dim, num_layers=2).to(device)

In [7]:
def align_by_node_id(preds, graph_seq):
    # preds: list of [num_nodes, out_dim] tensors
    # graph_seq: list of Data objects
    node_id_lists = [g.node_ids.cpu().numpy() for g in graph_seq]
    common_ids = set(node_id_lists[0])
    for ids in node_id_lists[1:]:
        common_ids &= set(ids)
    if not common_ids:
        return None, None  # No common nodes to compare
    common_ids = sorted(list(common_ids))
    pred_aligned = []
    target_aligned = []
    for t, g in enumerate(graph_seq):
        id_to_idx = {nid: i for i, nid in enumerate(g.node_ids.cpu().numpy())}
        idxs = [id_to_idx[nid] for nid in common_ids]
        pred_aligned.append(preds[t][idxs])
        target_aligned.append(g.y[idxs])
    pred_tensor = torch.stack(pred_aligned, dim=0).squeeze(-1)
    target_tensor = torch.stack(target_aligned, dim=0).to(pred_tensor.device).squeeze(-1)
    return pred_tensor, target_tensor

def train_gnn(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for graph_seq in loader:
        graph_seq = [g.to(device) for g in graph_seq]
        optimizer.zero_grad()
        preds = model(graph_seq)
        pred_tensor, target_tensor = align_by_node_id(preds, graph_seq)
        if pred_tensor is None:
            continue  # skip if no common nodes
        mask = ~torch.isnan(target_tensor)
        if mask.sum() == 0:
            continue
        # Apply mask to both pred_tensor and target_tensor
        loss = loss_fn(pred_tensor[mask], target_tensor[mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(loader))

def evaluate_gnn(model, loader, loss_fn, device, return_preds=False):
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for graph_seq in loader:
            graph_seq = [g.to(device) for g in graph_seq]
            preds = model(graph_seq)
            pred_tensor, target_tensor = align_by_node_id(preds, graph_seq)
            if pred_tensor is None:
                continue
            mask = ~torch.isnan(target_tensor)
            if mask.sum() == 0:
                continue
            # Apply mask to both pred_tensor and target_tensor
            loss = loss_fn(pred_tensor[mask], target_tensor[mask])
            total_loss += loss.item()
            if return_preds:
                all_preds.append(pred_tensor[mask].cpu().numpy())
                all_targets.append(target_tensor[mask].cpu().numpy())
    if return_preds:
        return total_loss / max(1, len(loader)), all_preds, all_targets
    return total_loss / max(1, len(loader))

In [8]:
# Cell 7: Dataset Preparation



class GraphSequenceDataset(Dataset):
    def __init__(self, sequences):
        """
        sequences: list of list of PyG Data objects (each inner list is a sequence of graphs)
        """
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Returns a list of Data objects (the sequence)
        return self.sequences[idx]

def graph_sequence_collate(batch):
    # batch: list of sequences (each is a list of Data objects)
    # Output: list of length window, each element is a list of Data objects for that time step
    # For batch_size=1, this just returns [sequence]
    # For batch_size>1, this stacks sequences by time step
    batch_size = len(batch)
    seq_len = len(batch[0])
    collated = []
    for t in range(seq_len):
        collated.append([batch[b][t] for b in range(batch_size)])
    return collated

file_paths = [
    'C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif',
    'C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif', #✔ Clicked on Frame 77: (1.16, 2.96); ✔ Clicked on Frame 77: (15.57, 10.86)
    'C:/Users/Platypus/Documents/CellNet/Figure8.tif',
    'C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif'
]
test_video_path = 'C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif'



#sources where signal starts
manual_sources = {
    'Real_Time_CS_Experiment-1093.tif': {'frame': 5, 'xy': (3.79, 3.66)}, #altenatively "Clicked on Frame 5: (24.44, 10.58)"
    'Flow prior to chemical stimulation_Figure6C.tif': { 'frame': 77, 'xy': (1.16, 2.96)}, #alt ✔ Clicked on Frame 77: (15.57, 10.86)
    'Figure8.tif': {'frame': 68, 'xy': (43.83, 0.89)}, 
    '5uM_per_litre_Figure6_ChemicalStimulation.tif':{'frame': 37, 'xy': (10.03, 5.46)},# alt Frame 76: (2.82, 3.38);  
    'Cell Knocked_Figure7.tif': {'frame': 0, 'xy': (54.50, 9.34)} #Clicked on Frame 0: (54.50, 9.34);alt ✔ Clicked on Frame 0: (53.66, 3.52)
    }


# Load all sequences of graphs from specified files
all_sequences = []
for path in file_paths:
    if os.path.exists(path):
        video_name = os.path.basename(path)
        manual_source = manual_sources.get(video_name)
        seqs = load_tif_as_graph_sequences(path, window=5, manual_source=manual_source)
        all_sequences.extend(seqs)
print(f"Total loaded sequences: {len(all_sequences)}")

# Split into train, validation, and test sets
train_split = int(0.7 * len(all_sequences))
val_split = int(0.15 * len(all_sequences))

train_sequences = all_sequences[:train_split]
val_sequences = all_sequences[train_split:train_split + val_split]
test_sequences = all_sequences[train_split + val_split:]

# Usage:
train_dataset = GraphSequenceDataset(train_sequences)
val_dataset = GraphSequenceDataset(val_sequences)
test_dataset = GraphSequenceDataset(test_sequences)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=lambda x: x[0])
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=lambda x: x[0])


Frame 6 skipped (no active nodes)
Frame 6 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 24 skipped (no active nodes)
Frame 40 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 61 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 71 skipped (no active nodes)
Frame 74 skipped (no active nodes)
Frame 74 skipped (no active nodes)
Frame 74 skipped (no active nodes)
Frame 79 skipped (no active nodes)
Frame 79 skipped (no active nodes)
Frame 79 skipped (no active nodes)
Frame 79 skipped (no active nodes)
Frame 79 skipped (no a

In [9]:

stack = imread(file_paths[0])
stack_denoised = np.array([denoise_gaussian_median(f, gaussian_sigma=1.0) for f in stack])
stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
norm_stack = norm_stack * 2 - 1
idx = 24  # Change this to the frame index you want to visualize

frame = norm_stack[idx]
fg_mask = apply_background_subtraction(frame)
plt.figure()
plt.subplot(1,2,1)
plt.title("Original")
plt.imshow(frame, cmap='gray')
plt.subplot(1,2,2)
plt.title("Foreground Mask")
plt.imshow(fg_mask, cmap='gray')
plt.show()

<IPython.core.display.Javascript object>

In [10]:
# Cell 8: Model Setup and Training

# Set up device, model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Adjust these dimensions as needed for your data:
node_feat_dim = 14  # Number of node features (from build_graph_from_frame)
hidden_dim = 128   # Hidden dimension for GCRN
out_dim = 1        # Output dimension (e.g., regression target per node)

model = GCRN(node_feat_dim=node_feat_dim, hidden_dim=hidden_dim, out_dim=out_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss_fn = nn.L1Loss()

best_val_loss = float('inf')
patience = 5
patience_counter = 0

import csv

train_losses = []
val_losses = []

for epoch in range(100):    
    loss = train_gnn(model, train_loader, optimizer, loss_fn, device)
    val_loss = evaluate_gnn(model, val_loader, loss_fn, device)
    train_losses.append(loss)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1} - Train Loss: {loss:.4f} | Val Loss: {val_loss:.4f}")

    print(f"Epoch {epoch+1} - Train Loss: {loss:.4f} | Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "best_gcrn.pt")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
# Save logs to CSV
with open("train_log.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "val_loss"])
    for i, (tr, vl) in enumerate(zip(train_losses, val_losses)):
        writer.writerow([i+1, tr, vl])
        
model.load_state_dict(torch.load("best_gcrn.pt"))

IndexError: index 0 is out of bounds for dimension 0 with size 0

In [None]:
#Cell 9: Final Evaluation

# Load test sequences (list of sequences, each is a list of Data objects)
test_sequences = load_tif_as_graph_sequences(test_video_path, window=5, manual_source=manual_sources.get(os.path.basename(test_video_path)))
test_dataset = GraphSequenceDataset(test_sequences)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=lambda x: x[0])

test_loss, all_preds, all_targets = evaluate_gnn(model, test_loader, loss_fn, device, return_preds=True)

# Flatten predictions and targets for correlation/statistics
pred_flat = np.concatenate([p.flatten() for p in all_preds])
true_flat = np.concatenate([t.flatten() for t in all_targets])
corr, _ = pearsonr(pred_flat, true_flat)  # Calculate Pearson correlation coefficient
print(f"Test Loss: {test_loss:.4f} | Pearson Correlation: {corr:.4f}")

In [None]:
#Cell 10: Visualization - Scatter and Histogram

plt.figure(figsize=(6, 6))
plt.scatter(true_flat, pred_flat, alpha=0.5, edgecolor='k')
plt.plot([true_flat.min(), true_flat.max()], [true_flat.min(), true_flat.max()], 'r--')
plt.xlabel('True Intensity')
plt.ylabel('Predicted Intensity')
plt.title('Scatter: Predicted vs True Node Intensities')
plt.grid(True)
plt.tight_layout()
plt.savefig(desktop_path / "scatter_pred_vs_true.png", bbox_inches='tight', dpi=300)
plt.close()

plt.figure(figsize=(6, 4))
plt.hist(pred_flat - true_flat, bins=30, color='gray', edgecolor='black')
plt.axvline(0, color='red', linestyle='--')
plt.xlabel('Prediction Error')
plt.ylabel('Node Count')
plt.title('Histogram of Prediction Errors')
plt.tight_layout()
plt.savefig(desktop_path / "hist_prediction_error.png", bbox_inches='tight', dpi=300)
plt.close()


In [None]:
# Cell 11: Line Plot Visualization

model.eval()
with torch.no_grad():
    for graph_seq in test_loader:
        # graph_seq is a list of Data objects (the sequence)
        graph_seq = [g.to(device) for g in graph_seq]
        preds = model(graph_seq)  # list of [num_nodes, out_dim], one per time step
        preds = [p.squeeze(-1).cpu().numpy() for p in preds]
        trues = [g.y.squeeze(-1).cpu().numpy() for g in graph_seq]
        seq_len = len(preds)
        for t in range(seq_len):
            plt.figure(figsize=(8, 4))
            plt.plot(preds[t], 'ro-', label='Predicted')
            plt.plot(trues[t], 'bo-', label='Ground Truth')
            plt.title(f'Node Intensities: Prediction vs Ground Truth (t={t})')
            plt.xlabel('Node Index')
            plt.ylabel('Intensity')
            plt.legend()
            plt.savefig(desktop_path / f"intensities_pred_vs_gt_t{t}.png", bbox_inches='tight', dpi=300)
            plt.close()
        break  # Remove this if you want to plot for all sequences in the test set

In [None]:
# Cell 13: Overlay Prediction on Input Image

graph_seq = test_loader.dataset[0]  # First sequence in test set
with torch.no_grad():
    preds = model([g.to(device) for g in graph_seq])
    preds = [p.squeeze(-1).cpu().numpy() for p in preds]

for t, g in enumerate(graph_seq):
    img_path = g.source_file
    frame_idx = g.frame_idx
    stack = imread(img_path)
    frame = stack[frame_idx]
    frame_resized = resize(frame, (64, 64))
    coords = g.x.cpu().numpy()[:, :2]
    x_img = ((coords[:, 0] + 1) / 2) * frame_resized.shape[1]
    y_img = ((coords[:, 1] + 1) / 2) * frame_resized.shape[0]
    plt.figure(figsize=(8, 8))
    plt.imshow(frame_resized, cmap='gray')
    plt.scatter(x_img, y_img, c=preds[t], cmap='hot', s=80, edgecolor='k')
    plt.title(f'Predicted Signal Propagation on Image (t={t})')
    plt.colorbar(label='Predicted Intensity')
    plt.axis('off')
    plt.savefig(desktop_path / f"pred_on_img_t{t}.png", bbox_inches='tight', dpi=300)
    plt.close()