In [19]:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.stats import pearsonr
from tifffile import imread
from skimage.transform import resize
from scipy.ndimage import median_filter
from skimage.filters import gaussian
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error

import random
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity as ssim
import torch.nn.functional as F

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# ----------------------------
# CONFIG (if not imported from global scope)
# ----------------------------
CONFIG = {
    "resize_shape": (128, 128),
    "intensity_thresh": 0.1,
    "lost_ttl": 3,
    "save_vis_dir": os.path.join(os.path.expanduser("~"), "Desktop", "LSTM_Predictions")
}

# Create visualization directory
os.makedirs(CONFIG["save_vis_dir"], exist_ok=True)


In [20]:
# ----------------------------
# Preprocessing
# ----------------------------
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    medianed = median_filter(frame, size=median_size)
    smoothed = gaussian(medianed, sigma=gaussian_sigma)
    return smoothed

# ----------------------------
# Sequence Loader (no graph tracking)
# ----------------------------
def load_tif_node_sequences(path, window=5, frame_skip=1, denoise_sigma=1.0):
    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f, gaussian_sigma=denoise_sigma) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    # Flatten each frame to a 1D vector (simulate node features as pixels for LSTM)
    h, w = CONFIG["resize_shape"]
    flattened_stack = norm_stack.reshape(len(norm_stack), -1)  # [T, H*W]

    sequences = []
    for i in range(0, len(flattened_stack) - window, frame_skip):
        seq = flattened_stack[i:i+window]       # [window, H*W]
        target = flattened_stack[i+window]      # [H*W]
        sequences.append((seq, target))
    return sequences

In [21]:
# ----------------------------
# LSTM Model
# ----------------------------
class CalciumLSTMRegressor(nn.Module):
    def __init__(self, input_size, hidden_size=64, num_layers=2, dropout=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False
        )

        self.batch_norm = nn.BatchNorm1d(hidden_size)

        self.fc1 = nn.Linear(hidden_size, 64)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.output = nn.Linear(32, input_size)  # predict next-frame full vector
        self._initialize_weights()

    def _initialize_weights(self):
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'lstm' in name and len(param.shape) >= 2:
                    nn.init.orthogonal_(param)
                elif len(param.shape) >= 2:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.normal_(param, mean=0.0, std=0.1)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x):
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.device)
        lstm_out, _ = self.lstm(x, (h0, c0))
        out = lstm_out[:, -1, :]  # last timestep
        out = self.batch_norm(out)
        out = self.relu1(self.fc1(out))
        out = self.dropout1(out)
        out = self.relu2(self.fc2(out))
        out = self.dropout2(out)
        return self.output(out)

In [22]:
# ----------------------------
# Dataset Wrapper
# ----------------------------
class NodeSequenceDataset(Dataset):
    def __init__(self, graph_sequences):
        self.sequences = []
        for graph_seq in graph_sequences:
            node_id_sets = [set(g.node_ids.cpu().numpy()) for g in graph_seq]
            common_ids = set.intersection(*node_id_sets)
            for nid in common_ids:
                seq_feats = []
                for g in graph_seq[:-1]:  # all but last frame
                    idx = (g.node_ids == nid).nonzero(as_tuple=True)[0].item()
                    seq_feats.append(g.x[idx].unsqueeze(0))
                seq_feats = torch.cat(seq_feats, dim=0)
                target_g = graph_seq[-1]
                target_idx = (target_g.node_ids == nid).nonzero(as_tuple=True)[0].item()
                target = target_g.y[target_idx]  # predict next-frame intensity
                self.sequences.append((seq_feats, target))

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        x, y = self.sequences[idx]
        return x, y

class FrameSequenceDataset(Dataset):
    def __init__(self, sequences):
        self.data = sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq, target = self.data[idx]
        return torch.tensor(seq, dtype=torch.float32), torch.tensor(target, dtype=torch.float32)


In [36]:

# ----------------------------
# Train / Eval Functions
# ----------------------------
def train_lstm(model, loader, optimizer, loss_fn_main, loss_fn_alt, device):
    model.train()
    total_main = 0
    total_alt = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss_main = loss_fn_main(out.squeeze(), y.squeeze())
        loss_alt = loss_fn_alt(out.squeeze(), y.squeeze())
        loss_main.backward()
        optimizer.step()
        total_main += loss_main.item()
        total_alt += loss_alt.item()
    return total_main / len(loader), total_alt / len(loader)

# def evaluate_lstm(model, loader, loss_fn_main, loss_fn_alt, device):
#     model.eval()
#     total_main = 0
#     total_alt = 0
#     with torch.no_grad():
#         for x, y in loader:
#             x, y = x.to(device), y.to(device)
#             out = model(x)
#             loss_main = loss_fn_main(out.squeeze(), y.squeeze())
#             loss_alt = loss_fn_alt(out.squeeze(), y.squeeze())
#             total_main += loss_main.item()
#             total_alt += loss_alt.item()
#     return total_main / len(loader), total_alt / len(loader)

def evaluate_lstm(model, loader, loss_fn_main, loss_fn_alt, device, return_preds=False):
    model.eval()
    total_main = 0
    total_alt = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss_main = loss_fn_main(out.squeeze(), y.squeeze())
            loss_alt = loss_fn_alt(out.squeeze(), y.squeeze())
            total_main += loss_main.item()
            total_alt += loss_alt.item()
            if return_preds:
                all_preds.append(out.squeeze().cpu().numpy())
                all_targets.append(y.squeeze().cpu().numpy())
    if return_preds:
        return total_main / len(loader), total_alt / len(loader), all_preds, all_targets
    else:
        return total_main / len(loader), total_alt / len(loader)





In [37]:
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
import os

def visualize_lstm_predictions(preds, targets, shape=(128, 128), idx=0, save_dir=None):
    """
    Visualize and optionally save one predicted vs ground truth frame from LSTM output.
    """
    pred_frame = preds[idx].reshape(shape)
    target_frame = targets[idx].reshape(shape)
    diff_frame = np.abs(pred_frame - target_frame)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(pred_frame, cmap='viridis')
    axs[0].set_title("Predicted")
    axs[1].imshow(target_frame, cmap='viridis')
    axs[1].set_title("Ground Truth")
    axs[2].imshow(diff_frame, cmap='hot')
    axs[2].set_title("Difference")

    for ax in axs:
        ax.axis('off')

    plt.tight_layout()

    # Save if save_dir is given
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"prediction_{idx:02d}.png")
        plt.savefig(save_path, bbox_inches='tight')
        print(f"✅ Saved visualization to: {save_path}")

    plt.close(fig)  # Close to avoid displaying during batch runs


In [42]:
if __name__ == "__main__":
    from sklearn.model_selection import train_test_split

    file_paths = [
        'C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif',
        'C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif'
    ]

    all_sequences = []
    for path in file_paths:
        seqs = load_tif_node_sequences(path, window=5)
        all_sequences.extend(seqs)

    train_val, test_seq = train_test_split(all_sequences, test_size=0.15, random_state=42)
    train_seq, val_seq = train_test_split(train_val, test_size=0.176, random_state=42)

    train_ds = FrameSequenceDataset(train_seq)
    val_ds = FrameSequenceDataset(val_seq)
    test_ds = FrameSequenceDataset(test_seq)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    input_size = CONFIG["resize_shape"][0] * CONFIG["resize_shape"][1]
    model = CalciumLSTMRegressor(input_size=input_size).to("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.L1Loss()
    loss_fn_alt = nn.SmoothL1Loss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    best_val_loss = float('inf')
    patience = 5
    patience_counter = 0
    train_losses = []
    val_losses = []
   
    # Example: main loss = L1Loss, alt loss = HuberLoss
    loss_fn_main = torch.nn.L1Loss()
    loss_fn_alt = torch.nn.SmoothL1Loss()

    train_main_loss, train_alt_loss = train_lstm(model, train_loader, optimizer, loss_fn_main, loss_fn_alt, device)
    val_main_loss, val_alt_loss = evaluate_lstm(model, val_loader, loss_fn_main, loss_fn_alt, device)


    for epoch in range(100):
        train_main_loss, train_alt_loss = train_lstm(model, train_loader, optimizer, loss_fn_main, loss_fn_alt, device)
        val_main_loss, val_alt_loss = evaluate_lstm(model, val_loader, loss_fn_main, loss_fn_alt, device)
        
        train_losses.append(train_main_loss)
        val_losses.append(val_main_loss)
        
        print(f"Epoch {epoch+1} - Train Main Loss: {train_main_loss:.4f} | Val Main Loss: {val_main_loss:.4f}")

        if val_main_loss < best_val_loss:
            best_val_loss = val_main_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_lstm.pt")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break


    model.load_state_dict(torch.load("best_lstm.pt"))
    test_main_loss, test_alt_loss, preds, targets =evaluate_lstm(model, test_loader, loss_fn_main, loss_fn_alt, device, return_preds=True)

    # Flatten batched preds/targets into list of 1D arrays
    flat_preds = [p for batch in preds for p in batch]
    flat_targets = [t for batch in targets for t in batch]

    for i in range(3):  # save first 3 examples
        visualize_lstm_predictions(
            flat_preds, 
            flat_targets, 
            shape=CONFIG["resize_shape"], 
            idx=i, 
            save_dir=CONFIG["save_vis_dir"]
    )


    # print(f"\n✅ Final Test Loss: {test_loss:.4f}")
    print(f"\n✅ Final Test Losses:")
    print(f"   Main Loss (L1):       {test_main_loss:.4f}")
    print(f"   Alt Loss  (Huber):    {test_alt_loss:.4f}")


Epoch 1 - Train Main Loss: 0.4002 | Val Main Loss: 0.0151
Epoch 2 - Train Main Loss: 0.2191 | Val Main Loss: 0.1337
Epoch 3 - Train Main Loss: 0.2084 | Val Main Loss: 0.1153
Epoch 4 - Train Main Loss: 0.1903 | Val Main Loss: 0.0651
Epoch 5 - Train Main Loss: 0.1854 | Val Main Loss: 0.0585
Epoch 6 - Train Main Loss: 0.1774 | Val Main Loss: 0.0613
Early stopping triggered.


  model.load_state_dict(torch.load("best_lstm.pt"))


✅ Saved visualization to: C:\Users\Platypus\Desktop\LSTM_Predictions\prediction_00.png
✅ Saved visualization to: C:\Users\Platypus\Desktop\LSTM_Predictions\prediction_01.png
✅ Saved visualization to: C:\Users\Platypus\Desktop\LSTM_Predictions\prediction_02.png

✅ Final Test Losses:
   Main Loss (L1):       0.0152
   Alt Loss  (Huber):    0.0005


In [45]:
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr

def compute_lstm_metrics(preds, targets):
    pred_flat = np.concatenate(preds).flatten()
    target_flat = np.concatenate(targets).flatten()

    mae = mean_absolute_error(target_flat, pred_flat)
    pearson_corr, _ = pearsonr(pred_flat, target_flat)

    return mae, pearson_corr

# After model.load_state_dict(...)
test_main_loss, test_alt_loss, preds, targets= evaluate_lstm(model, test_loader, loss_fn_main, loss_fn_alt, device, return_preds=True)
mae, corr = compute_lstm_metrics(preds, targets)

print(f"\n✅ Final LSTM Test Metrics:")
print(f"  - Test Loss (L1): {test_loss:.4f}")
print(f"  - MAE: {mae:.4f}")
print(f"  - Pearson Correlation: {corr:.4f}")



✅ Final LSTM Test Metrics:
  - Test Loss (L1): 0.0093
  - MAE: 0.0148
  - Pearson Correlation: 0.5723
