In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.loader import DataLoader

from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import resize
from skimage.measure import label, regionprops
from scipy.ndimage import median_filter
from skimage.filters import gaussian
from skimage import exposure
from sklearn.model_selection import train_test_split

# Set config
CONFIG = {
    "resize_shape": (512, 512),
    "intensity_thresh": 0.01,
    "lost_ttl": 3,
    "max_nodes": 100,
    "batch_size": 1,
    "lr": 1e-3,
    "num_epochs": 100,
    "patience": 10,
    "save_dir": os.path.expanduser("~/Desktop/GNN_Predictions")
}

os.makedirs(CONFIG["save_dir"], exist_ok=True)

# Set seeds for reproducibility
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    medianed = median_filter(frame, size=median_size)
    smoothed = gaussian(medianed, sigma=gaussian_sigma)
    return smoothed

def apply_clahe(frame, clip_limit=0.01):
    return exposure.equalize_adapthist(frame, clip_limit=clip_limit)


In [3]:
def build_graph_from_frame(frame, intensity_thresh=0.01, max_nodes=100, source_idx=None):
    binary_mask = frame > intensity_thresh
    labeled = label(binary_mask)
    props = regionprops(labeled, intensity_image=frame)

    node_features = []
    node_ids = []
    for i, p in enumerate(props[:max_nodes]):
        x, y = p.centroid[::-1]  # x = width, y = height
        intensity = p.mean_intensity
        node_features.append([x, y, intensity])
        node_ids.append(i)

    if len(node_features) == 0:
        return None
    elif len(node_features) == 1:
        return None  # GCNConv needs at least 2 nodes

    # More than 1 node
    idx = torch.arange(len(node_features))
    comb = torch.combinations(idx, r=2)
    if comb.size(0) == 0:
        return None  # No edges

    edge_index = comb.T
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    # Final validation
    if not (edge_index.ndim == 2 and edge_index.shape[0] == 2 and edge_index.shape[1] > 0):
        return None

    x = torch.tensor(node_features, dtype=torch.float)  # Features: [x, y, intensity]
    y = torch.tensor([f[2] for f in node_features], dtype=torch.float).view(-1, 1)  # Use intensity as label
    positions = torch.tensor([[f[0], f[1]] for f in node_features], dtype=torch.float)

    g = Data(x=x, edge_index=edge_index, y=y, pos=positions)
    g.node_ids = torch.tensor(node_ids, dtype=torch.long)
    return g

from scipy.spatial.distance import cdist



def build_temporal_graph_sequence(path):
    max_nodes = CONFIG["max_nodes"]
    intensity_thresh = CONFIG["intensity_thresh"]

    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f, gaussian_sigma=1.0) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    graphs = []
    node_positions_per_frame = []

    for frame in norm_stack:
        binary_mask = frame > intensity_thresh
        labeled = label(binary_mask)
        props = regionprops(labeled, intensity_image=frame)

        node_features = []
        node_positions = []

        for p in props[:max_nodes]:
            x, y = p.centroid[::-1]
            intensity = p.mean_intensity
            node_features.append([x, y, intensity])
            node_positions.append([x, y])

        if len(node_features) > 1:
            graphs.append(np.array(node_features))
            node_positions_per_frame.append(np.array(node_positions))

    aligned_graphs = []

    for t in range(len(graphs) - 1):
        current = graphs[t]
        next_frame = graphs[t + 1]
        pos_current = node_positions_per_frame[t]
        pos_next = node_positions_per_frame[t + 1]

        dist_matrix = cdist(pos_current, pos_next)
        match = dist_matrix.argmin(axis=1)

        matched_intensities = []
        for i, j in enumerate(match):
            if j < len(next_frame):
                matched_intensities.append(next_frame[j][2])
            else:
                matched_intensities.append(0.0)

        x = torch.tensor(current, dtype=torch.float)
        y = torch.tensor(matched_intensities, dtype=torch.float).view(-1, 1)

        idx = torch.arange(x.shape[0])
        comb = torch.combinations(idx, r=2)
        edge_index = comb.T
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        pos = torch.tensor(pos_current, dtype=torch.float)
        g = Data(x=x, edge_index=edge_index, y=y, pos=pos)
        aligned_graphs.append(g)

    return aligned_graphs



In [5]:
class SimpleGraphDataset(Dataset):
    def __init__(self, graphs):
        # Filter out invalid graphs
        self.graphs = [
            g for g in graphs
            if hasattr(g, "edge_index")
            and g.edge_index is not None
            and g.edge_index.ndim == 2
            and g.edge_index.shape[0] == 2
            and g.edge_index.shape[1] > 0
        ]
        print(f"Using {len(self.graphs)} valid graphs after filtering.")

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]  # Return the full Data object


In [None]:
#Simple GCN Model


import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class StrongerGNN(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, 64)
        self.conv2 = GCNConv(64, 64)
        self.conv3 = GCNConv(64, 32)
        self.dropout = nn.Dropout(0.3)
        self.lin = nn.Linear(32, 1)

    def forward(self, x, edge_index, batch=None):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv3(x, edge_index))
        return self.lin(x)


In [None]:
def visualize_gnn_predictions(preds, targets, positions, shape=(128, 128), num_examples=5):
    os.makedirs(CONFIG["save_dir"], exist_ok=True)

    for i in range(min(num_examples, len(positions))):  # Use only valid batches
        pred_img = np.zeros(shape)
        target_img = np.zeros(shape)
        count_img = np.zeros(shape)

        try:
            for (x, y), pred, true in zip(positions[i], preds[i], targets[i]):
                x = int(np.clip(x, 0, shape[1] - 1))  # width
                y = int(np.clip(y, 0, shape[0] - 1))  # height
                pred_img[y, x] += pred
                target_img[y, x] += true
                count_img[y, x] += 1
        except Exception as e:
            print(f" Skipped batch {i} due to error: {e}")
            continue

        # Avoid division by zero
        count_img[count_img == 0] = 1
        pred_img /= count_img
        target_img /= count_img
        diff_img = np.abs(pred_img - target_img)

        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(target_img, cmap="viridis")
        axs[0].set_title("True Activation at t+1")
        axs[1].imshow(pred_img, cmap="viridis")
        axs[1].set_title("Predicted Activation at t+1")
        axs[2].imshow(diff_img, cmap="hot")
        axs[2].set_title("Prediction Error")
        axs[1].scatter(
        [int(p[0]) for p in positions[i]],
        [int(p[1]) for p in positions[i]],
        color='white', s=10, alpha=0.5
        )
        for ax in axs:
            ax.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(CONFIG["save_dir"], f"GNN_Predictions_{i}.png"))
        plt.close()


In [None]:



def train(model, loader, optimizer, loss_fn_main, loss_fn_alt):
    model.train()
    total_main = 0
    total_alt = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss_main = loss_fn_main(out, data.y)
        loss_alt = loss_fn_alt(out, data.y)
        loss_main.backward()
        optimizer.step()
        total_main += loss_main.item()
        total_alt += loss_alt.item()
    return total_main / len(loader), total_alt / len(loader)

def evaluate(model, loader, loss_fn_main, loss_fn_alt, return_preds=False):
    model.eval()
    total_main = 0
    total_alt = 0
    preds = []
    targets = []
    positions_all = []

    with torch.no_grad():
        for i, data in enumerate(loader):
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)  # [num_nodes, 1]
            y = data.y.view(-1, 1)                            # [num_nodes, 1]

            loss_main = loss_fn_main(out, y)
            loss_alt = loss_fn_alt(out, y)

            total_main += loss_main.item()
            total_alt += loss_alt.item()

            if return_preds:
                preds.append(out.cpu().numpy())
                targets.append(y.cpu().numpy())

                # Safely collect node positions if available
                if hasattr(data, 'pos') and data.pos is not None:
                    positions_all.append(data.pos.cpu().numpy())
                else:
                    print(f" Skipping positions for batch {i} (data.pos is None)")

    if return_preds:
        preds = [p.squeeze() for p in preds]
        targets = [t.squeeze() for t in targets]
        all_preds_flat = np.concatenate(preds)
        all_targets_flat = np.concatenate(targets)
        mae = mean_absolute_error(all_targets_flat, all_preds_flat)
        corr = np.corrcoef(all_targets_flat, all_preds_flat)[0, 1]

        return total_main / len(loader), total_alt / len(loader), mae, corr, preds, targets, positions_all

    else:
        return total_main / len(loader), total_alt / len(loader)





In [9]:
if __name__ == "__main__":
    file_paths = [
        "C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif",
        "C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif",
        "C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif",
        "C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif", #✔ Clicked on Frame 77: (1.16, 2.96); ✔ Clicked on Frame 77: (15.57, 10.86)
        "C:/Users/Platypus/Documents/CellNet/Figure8.tif",
        "C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif",
        "C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif"]

    all_graphs = []
    for path in file_paths:
       # gs = [g for g in build_temporal_graph_sequence(path) if g is not None]
        gs = build_temporal_graph_sequence(path)

        all_graphs.extend(gs)
    
    print(f"Loaded {len(all_graphs)} valid graphs across {len(file_paths)} videos.")


    train_val, test = train_test_split(all_graphs, test_size=0.15, random_state=42)
    train_graphs, val_graphs = train_test_split(train_val, test_size=0.176, random_state=42)

    train_loader = DataLoader(SimpleGraphDataset(train_graphs), batch_size=1, shuffle=True)
    val_loader = DataLoader(SimpleGraphDataset(val_graphs), batch_size=1)
    test_loader = DataLoader(SimpleGraphDataset(test), batch_size=1)

    model = StrongerGNN(input_dim=train_graphs[0].x.shape[1]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    loss_fn_main = nn.L1Loss()
    loss_fn_alt = nn.SmoothL1Loss()  # Huber


    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(CONFIG["num_epochs"]):
        train_main, train_alt = train(model, train_loader, optimizer, loss_fn_main, loss_fn_alt)
        val_main, val_alt = evaluate(model, val_loader, loss_fn_main, loss_fn_alt)

        print(f"Epoch {epoch+1:03} | Train L1: {train_main:.4f} | Huber: {train_alt:.4f} || Val L1: {val_main:.4f} | Huber: {val_alt:.4f}")

        if val_main < best_val_loss:
            best_val_loss = val_main
            torch.save(model.state_dict(), "best_baseline_gnn.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= CONFIG["patience"]:
                print("Early stopping.")
                break


    model.load_state_dict(torch.load("best_baseline_gnn.pt"))
    test_main, test_alt, test_mae, test_corr, preds, targets, positions = evaluate(
    model, test_loader, loss_fn_main, loss_fn_alt, return_preds=True
)

visualize_gnn_predictions(preds, targets, positions, shape=CONFIG["resize_shape"])


print("\n✅ Final GNN Test Metrics:")
print(f"  - Test Loss (L1):       {test_main:.4f}")
print(f"  - Test Loss (Huber):    {test_alt:.4f}")
print(f"  - MAE:                  {test_mae:.4f}")
print(f"  - Pearson Correlation:  {test_corr:.4f}")



  
    # Love u, 
    # cmok
  

Loaded 1849 valid graphs across 7 videos.
Using 1294 valid graphs after filtering.
Using 277 valid graphs after filtering.
Using 278 valid graphs after filtering.
Epoch 001 | Train L1: 0.9070 | Huber: 0.6121 || Val L1: 0.1219 | Huber: 0.0176
Epoch 002 | Train L1: 0.1534 | Huber: 0.0237 || Val L1: 0.1175 | Huber: 0.0122
Epoch 003 | Train L1: 0.1274 | Huber: 0.0166 || Val L1: 0.1177 | Huber: 0.0163
Epoch 004 | Train L1: 0.1160 | Huber: 0.0151 || Val L1: 0.1118 | Huber: 0.0134
Epoch 005 | Train L1: 0.1106 | Huber: 0.0141 || Val L1: 0.1128 | Huber: 0.0119
Epoch 006 | Train L1: 0.1084 | Huber: 0.0136 || Val L1: 0.1093 | Huber: 0.0137
Epoch 007 | Train L1: 0.1062 | Huber: 0.0135 || Val L1: 0.1084 | Huber: 0.0130
Epoch 008 | Train L1: 0.1057 | Huber: 0.0134 || Val L1: 0.1140 | Huber: 0.0163
Epoch 009 | Train L1: 0.1045 | Huber: 0.0131 || Val L1: 0.1106 | Huber: 0.0153
Epoch 010 | Train L1: 0.1042 | Huber: 0.0131 || Val L1: 0.1107 | Huber: 0.0150
Epoch 011 | Train L1: 0.1036 | Huber: 0.0127 ||

  model.load_state_dict(torch.load("best_baseline_gnn.pt"))



✅ Final GNN Test Metrics:
  - Test Loss (L1):       0.1087
  - Test Loss (Huber):    0.0116
  - MAE:                  0.1033
  - Pearson Correlation:  0.4105
