In [61]:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.stats import pearsonr
from tifffile import imread
from skimage.transform import resize
from scipy.ndimage import median_filter
from skimage.filters import gaussian
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error

import random


from skimage.metrics import structural_similarity as ssim
import torch.nn.functional as F
from sklearn.model_selection import train_test_split




def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Config
CONFIG = {
    "resize_shape": (64, 64),
    "window": 10,
    "future_steps": 1,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "save_vis_dir": os.path.join(os.path.expanduser("~"), "Desktop", "LSTM_Predictions")
}
os.makedirs(CONFIG["save_vis_dir"], exist_ok=True)

In [62]:
# Preprocessing
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    return gaussian(median_filter(frame, size=median_size), sigma=gaussian_sigma)

def load_tif_sequences(path, window=10, future_steps=1):
    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window - future_steps + 1):
        seq = norm_stack[i:i+window]
        target = norm_stack[i+window:i+window+future_steps]
        sequences.append((seq, target))
    return sequences

In [63]:
# Dataset
class PatchSequenceDataset(Dataset):
    def __init__(self, sequences, shape=(64, 64)):
        self.sequences = sequences
        self.shape = shape
        self.pos_enc = self._generate_positional_encodings(shape)

    def _generate_positional_encodings(self, shape):
        H, W = shape
        grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
        pos = torch.stack((grid_y, grid_x), dim=-1).float()
        pos = pos / torch.tensor([H - 1, W - 1])
        return pos

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, target = self.sequences[idx]
        T = seq.shape[0]
        F = target.shape[0]
        H, W = self.shape
        pos_enc = self.pos_enc.permute(2, 0, 1).unsqueeze(0).repeat(T, 1, 1, 1)
        seq = torch.tensor(seq.reshape(T, H, W), dtype=torch.float32).unsqueeze(1)
        seq_with_pos = torch.cat([seq, pos_enc], dim=1)
        return seq_with_pos, torch.tensor(target.reshape(F, H, W), dtype=torch.float32)

In [64]:
# Model
class LSTM_baseline(nn.Module):
    def __init__(self, patch_size=4, image_size=(64, 64), hidden_dim=256, lstm_layers=2, dropout=0.1, future_steps=1):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.future_steps = future_steps
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)
        self.input_dim = patch_size * patch_size + 2 * patch_size * patch_size

        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )

        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.input_dim * future_steps)
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        B, T, C, H, W = x.size()
        P = self.patch_size

        x = x.view(B, T, C, H, W)
        patches = x.unfold(3, P, P).unfold(4, P, P)
        patches = patches.permute(0, 1, 3, 4, 2, 5, 6).contiguous()
        patches = patches.view(B, T, -1, C * P * P)

        outputs = []
        for p in range(patches.size(2)):
            patch_seq = patches[:, :, p, :]
            encoded = self.encoder(patch_seq.view(-1, self.input_dim)).view(B, T, -1)
            lstm_out, _ = self.lstm(encoded)
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            decoded = self.decoder(attn_out[:, -1, :])
            decoded = decoded.view(B, self.future_steps, self.input_dim)
            outputs.append(decoded)

        outputs = torch.stack(outputs, dim=2)
        outputs = outputs.view(B, self.future_steps, H // P, W // P, C, P, P)
        outputs = outputs[:, :, :, :, 0, :, :]
        outputs = outputs.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, self.future_steps, H, W)
        return outputs

# Utility functions
def compute_metrics(preds, targets):
    preds = np.array(preds)
    targets = np.array(targets)
    mae = mean_absolute_error(targets.flatten(), preds.flatten())
    corr, _ = pearsonr(targets.flatten(), preds.flatten())
    return mae, corr

def visualize_preds(preds, targets, shape=(64, 64), save_dir=None, prefix="LSTM_baseline"):
    preds = np.array(preds)
    targets = np.array(targets)
    num_samples = min(3, len(preds))
    F = preds.shape[1]

    for i in range(num_samples):
        for t in range(F):
            try:
                pred = preds[i, t]
                target = targets[i, t]
                diff = np.abs(pred - target)

                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                axes[0].imshow(pred, cmap='viridis')
                axes[0].set_title("Predicted Frame")
                axes[1].imshow(target, cmap='viridis')
                axes[1].set_title("Ground Truth")
                axes[2].imshow(diff, cmap='hot')
                axes[2].set_title("Abs Error")
                for ax in axes:
                    ax.axis('off')
                plt.tight_layout()

                if save_dir:
                    filename = os.path.join(save_dir, f"{prefix}_pred_{i}_t{t}.png")
                    plt.savefig(filename)
                    print(f"✅ Saved: {filename}")
                else:
                    plt.show()

                plt.close()
            except Exception as e:
                print(f"❌ Error saving frame {i}, time {t}: {e}")


In [65]:
# Training and Evaluation

def train_lstm_baseline(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate_lstm_baseline(model, loader, loss_fn, device, return_preds=False):
    model.eval()
    total_loss = 0
    preds, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            if return_preds:
                preds.extend(out.cpu().numpy())
                targets.extend(y.cpu().numpy())
    if return_preds:
        return total_loss / len(loader), preds, targets
    return total_loss / len(loader)

# Utility functions

def compute_metrics(preds, targets):
    preds = np.array(preds)
    targets = np.array(targets)
    mae = mean_absolute_error(targets.flatten(), preds.flatten())
    corr, _ = pearsonr(targets.flatten(), preds.flatten())
    return mae, corr

def visualize_preds(preds, targets, shape=(64, 64), save_dir=None, prefix="LSTM_baseline"):
    preds = np.array(preds)
    targets = np.array(targets)
    num_samples = min(3, len(preds))
    F = preds.shape[1]

    for i in range(num_samples):
        for t in range(F):
            try:
                pred = preds[i, t]
                target = targets[i, t]
                diff = np.abs(pred - target)

                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                axes[0].imshow(pred, cmap='viridis')
                axes[0].set_title("Predicted Frame")
                axes[1].imshow(target, cmap='viridis')
                axes[1].set_title("Ground Truth")
                axes[2].imshow(diff, cmap='hot')
                axes[2].set_title("Abs Error")
                for ax in axes:
                    ax.axis('off')
                plt.tight_layout()

                if save_dir:
                    filename = os.path.join(save_dir, f"{prefix}_pred_{i}_t{t}.png")
                    plt.savefig(filename)
                    print(f"✅ Saved: {filename}")
                else:
                    plt.show()

                plt.close()
            except Exception as e:
                print(f"❌ Error saving frame {i}, time {t}: {e}")



In [66]:
# Main

if __name__ == "__main__":
    file_paths = [
        "C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif",
        "C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif",
        "C:/Users/Platypus/Documents/CellNet/Figure8.tif",
        "C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif",
        "C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif"
    ]

    all_seqs = []
    for path in file_paths:
        all_seqs.extend(load_tif_sequences(path, window=CONFIG["window"]))

    train_val, test = train_test_split(all_seqs, test_size=0.15, random_state=42)
    train_seqs, val_seqs = train_test_split(train_val, test_size=0.2, random_state=42)

    train_loader = DataLoader(PatchSequenceDataset(train_seqs, shape=CONFIG["resize_shape"]), batch_size=8, shuffle=True)
    val_loader = DataLoader(PatchSequenceDataset(val_seqs, shape=CONFIG["resize_shape"]), batch_size=8)
    test_loader = DataLoader(PatchSequenceDataset(test, shape=CONFIG["resize_shape"]), batch_size=8)

    model = LSTM_baseline(image_size=CONFIG["resize_shape"]).to(CONFIG["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.L1Loss()

    best_val = float('inf')
    patience = 5
    counter = 0

    for epoch in range(50):
        train_loss = train_lstm_baseline(model, train_loader, optimizer, loss_fn, CONFIG["device"])
        val_loss = evaluate_lstm_baseline(model, val_loader, loss_fn, CONFIG["device"])
        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best_lstm_patch.pt")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break

    model.load_state_dict(torch.load("best_lstm_patch.pt"))
    test_loss, preds, targets = evaluate_lstm_baseline(model, test_loader, loss_fn, CONFIG["device"], return_preds=True)

    mae, corr = compute_metrics(preds, targets)

    print(f"\n✅ Final LSTM_baseline Patch Model:")
    print(f"  - L1 Loss: {test_loss:.4f}")
    print(f"  - MAE:     {mae:.4f}")
    print(f"  - Corr:    {corr:.4f}")

    pred_var = np.var(np.stack(preds), axis=0).mean()
    print(f"🔍 Mean variance across predicted pixels: {pred_var:.6f}")

    for i in range(3):
        try:
            visualize_preds(preds, targets, shape=CONFIG["resize_shape"], save_dir=CONFIG["save_vis_dir"])
        except Exception as e:
            print(f"❌ Error saving frame {i}: {e}")


Epoch 1 | Train Loss: 0.0494 | Val Loss: 0.0277
Epoch 2 | Train Loss: 0.0241 | Val Loss: 0.0176
Epoch 3 | Train Loss: 0.0186 | Val Loss: 0.0329
Epoch 4 | Train Loss: 0.0174 | Val Loss: 0.0160
Epoch 5 | Train Loss: 0.0149 | Val Loss: 0.0147
Epoch 6 | Train Loss: 0.0136 | Val Loss: 0.0174
Epoch 7 | Train Loss: 0.0140 | Val Loss: 0.0310
Epoch 8 | Train Loss: 0.0154 | Val Loss: 0.0105
Epoch 9 | Train Loss: 0.0118 | Val Loss: 0.0121
Epoch 10 | Train Loss: 0.0100 | Val Loss: 0.0081
Epoch 11 | Train Loss: 0.0121 | Val Loss: 0.0140
Epoch 12 | Train Loss: 0.0098 | Val Loss: 0.0074
Epoch 13 | Train Loss: 0.0095 | Val Loss: 0.0092
Epoch 14 | Train Loss: 0.0086 | Val Loss: 0.0071
Epoch 15 | Train Loss: 0.0087 | Val Loss: 0.0088
Epoch 16 | Train Loss: 0.0085 | Val Loss: 0.0065
Epoch 17 | Train Loss: 0.0082 | Val Loss: 0.0074
Epoch 18 | Train Loss: 0.0079 | Val Loss: 0.0060
Epoch 19 | Train Loss: 0.0100 | Val Loss: 0.0078
Epoch 20 | Train Loss: 0.0074 | Val Loss: 0.0094
Epoch 21 | Train Loss: 0.0074

  model.load_state_dict(torch.load("best_lstm_patch.pt"))



✅ Final LSTM_baseline Patch Model:
  - L1 Loss: 0.0059
  - MAE:     0.0059
  - Corr:    0.9840
🔍 Mean variance across predicted pixels: 0.003277
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_2_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_2_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_2_t0.png
