In [19]:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.stats import pearsonr
from tifffile import imread
from skimage.transform import resize
from scipy.ndimage import median_filter
from skimage.filters import gaussian
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error

import random


from skimage.metrics import structural_similarity as ssim
import torch.nn.functional as F
from sklearn.model_selection import train_test_split




def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Config
CONFIG = {
    "resize_shape": (64, 64),
    "window": 10,
    "future_steps": 5,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "save_vis_dir": os.path.join(os.path.expanduser("~"), "Desktop", "LSTM_Predictions")
}
os.makedirs(CONFIG["save_vis_dir"], exist_ok=True)

In [20]:
# Preprocessing
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    return gaussian(median_filter(frame, size=median_size), sigma=gaussian_sigma)

def load_tif_sequences(path, window=10, future_steps=1):
    stack = imread(path)
    stack_denoised = np.array([denoise_gaussian_median(f) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window - future_steps + 1):
        seq = norm_stack[i:i+window]
        target = norm_stack[i+window:i+window+future_steps]
        sequences.append((seq, target))
    return sequences

In [21]:
test_file = "C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif"
test_stack = imread(test_file)
half_len = len(test_stack) // 2
test_stack = test_stack[:half_len]

# Save the original load_tif_sequences, but modify it to accept a stack
def load_tif_sequences_from_stack(stack, window=10, future_steps=5):
    stack_denoised = np.array([denoise_gaussian_median(f) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window - future_steps + 1):
        seq = norm_stack[i:i+window]
        target = norm_stack[i+window:i+window+future_steps]
        sequences.append((seq, target))
    return sequences

# Load test sequences from first half of specific file
test_seqs = load_tif_sequences_from_stack(test_stack, window=CONFIG["window"])

# Load training sequences from the other files
train_files = [
    "C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif",
    "C:/Users/Platypus/Documents/CellNet/Figure8.tif",
    "C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif",
    "C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif"
]

train_val_seqs = []
for path in train_files:
    train_val_seqs.extend(load_tif_sequences(path, window=CONFIG["window"]))

train_seqs, val_seqs = train_test_split(train_val_seqs, test_size=0.2, random_state=42)


In [22]:
# Dataset
class PatchSequenceDataset(Dataset):
    def __init__(self, sequences, shape=(64, 64)):
        self.sequences = sequences
        self.shape = shape
        self.pos_enc = self._generate_positional_encodings(shape)

    def _generate_positional_encodings(self, shape):
        H, W = shape
        grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
        pos = torch.stack((grid_y, grid_x), dim=-1).float()
        pos = pos / torch.tensor([H - 1, W - 1])
        return pos

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, target = self.sequences[idx]
        T = seq.shape[0]
        F = target.shape[0]
        H, W = self.shape
        pos_enc = self.pos_enc.permute(2, 0, 1).unsqueeze(0).repeat(T, 1, 1, 1)
        seq = torch.tensor(seq.reshape(T, H, W), dtype=torch.float32).unsqueeze(1)
        seq_with_pos = torch.cat([seq, pos_enc], dim=1)
        return seq_with_pos, torch.tensor(target.reshape(F, H, W), dtype=torch.float32)

In [23]:
# Model
class LSTM_baseline(nn.Module):
    def __init__(self, patch_size=4, image_size=(64, 64), hidden_dim=256, lstm_layers=2, dropout=0.1, future_steps=5):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.future_steps = future_steps
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)
        self.input_dim = patch_size * patch_size + 2 * patch_size * patch_size

        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )

        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.input_dim * future_steps)
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        B, T, C, H, W = x.size()
        P = self.patch_size

        x = x.view(B, T, C, H, W)
        patches = x.unfold(3, P, P).unfold(4, P, P)
        patches = patches.permute(0, 1, 3, 4, 2, 5, 6).contiguous()
        patches = patches.view(B, T, -1, C * P * P)

        outputs = []
        for p in range(patches.size(2)):
            patch_seq = patches[:, :, p, :]
            encoded = self.encoder(patch_seq.view(-1, self.input_dim)).view(B, T, -1)
            lstm_out, _ = self.lstm(encoded)
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            decoded = self.decoder(attn_out[:, -1, :])
            decoded = decoded.view(B, self.future_steps, self.input_dim)
            outputs.append(decoded)

        outputs = torch.stack(outputs, dim=2)
        outputs = outputs.view(B, self.future_steps, H // P, W // P, C, P, P)
        outputs = outputs[:, :, :, :, 0, :, :]
        outputs = outputs.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, self.future_steps, H, W)
        return outputs

# Utility functions
def compute_metrics(preds, targets):
    preds = np.array(preds)
    targets = np.array(targets)
    mae = mean_absolute_error(targets.flatten(), preds.flatten())
    corr, _ = pearsonr(targets.flatten(), preds.flatten())
    return mae, corr

def visualize_preds(preds, targets, shape=(64, 64), save_dir=None, prefix="LSTM_baseline"):
    preds = np.array(preds)
    targets = np.array(targets)
    num_samples = min(3, len(preds))
    F = preds.shape[1]

    for i in range(num_samples):
        for t in range(F):
            try:
                pred = preds[i, t]
                target = targets[i, t]
                diff = np.abs(pred - target)

                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                axes[0].imshow(pred, cmap='viridis')
                axes[0].set_title("Predicted Frame")
                axes[1].imshow(target, cmap='viridis')
                axes[1].set_title("Ground Truth")
                axes[2].imshow(diff, cmap='hot')
                axes[2].set_title("Abs Error")
                for ax in axes:
                    ax.axis('off')
                plt.tight_layout()

                if save_dir:
                    filename = os.path.join(save_dir, f"{prefix}_pred_{i}_t{t}.png")
                    plt.savefig(filename)
                    print(f"✅ Saved: {filename}")
                else:
                    plt.show()

                plt.close()
            except Exception as e:
                print(f"❌ Error saving frame {i}, time {t}: {e}")


In [24]:
# Training and Evaluation

def train_lstm_baseline(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate_lstm_baseline(model, loader, loss_fn, device, return_preds=False):
    model.eval()
    total_loss = 0
    preds, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            if return_preds:
                preds.extend(out.cpu().numpy())
                targets.extend(y.cpu().numpy())
    if return_preds:
        return total_loss / len(loader), preds, targets
    return total_loss / len(loader)

# Utility functions

def compute_metrics(preds, targets):
    preds = np.array(preds)
    targets = np.array(targets)
    mae = mean_absolute_error(targets.flatten(), preds.flatten())
    corr, _ = pearsonr(targets.flatten(), preds.flatten())
    return mae, corr

def visualize_preds(preds, targets, shape=(64, 64), save_dir=None, prefix="LSTM_baseline"):
    preds = np.array(preds)
    targets = np.array(targets)
    num_samples = min(3, len(preds))
    F = preds.shape[1]

    for i in range(num_samples):
        for t in range(F):
            try:
                pred = preds[i, t]
                target = targets[i, t]
                diff = np.abs(pred - target)

                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                axes[0].imshow(pred, cmap='viridis')
                axes[0].set_title("Predicted Frame")
                axes[1].imshow(target, cmap='viridis')
                axes[1].set_title("Ground Truth")
                axes[2].imshow(diff, cmap='hot')
                axes[2].set_title("Abs Error")
                for ax in axes:
                    ax.axis('off')
                plt.tight_layout()

                if save_dir:
                    filename = os.path.join(save_dir, f"{prefix}_pred_{i}_t{t}.png")
                    plt.savefig(filename)
                    print(f"✅ Saved: {filename}")
                else:
                    plt.show()

                plt.close()
            except Exception as e:
                print(f"❌ Error saving frame {i}, time {t}: {e}")



In [25]:
# Main

if __name__ == "__main__":
    test_file = "C:/Users/Platypus/Documents/CellNet/Flow prior to chemical stimulation_Figure6C.tif"
    test_stack = imread(test_file)
    half_len = len(test_stack) // 2
    test_stack = test_stack[:half_len]
    test_seqs = load_tif_sequences_from_stack(test_stack, window=CONFIG["window"])

    # Files for training and validation (excluding the test file)
    train_files = [
        "C:/Users/Platypus/Documents/CellNet/Real_Time_CS_Experiment-1093.tif",
        "C:/Users/Platypus/Documents/CellNet/Figure8.tif",
        "C:/Users/Platypus/Documents/CellNet/5uM_per_litre_Figure6_ChemicalStimulation.tif",
        "C:/Users/Platypus/Documents/CellNet/Cell Knocked_Figure7.tif"
    ]

    # Load train+val sequences using original function
    train_val_seqs = []
    for path in train_files:
        train_val_seqs.extend(load_tif_sequences(path, window=CONFIG["window"]))

    # Split train and val
    train_seqs, val_seqs = train_test_split(train_val_seqs, test_size=0.2, random_state=42)

    # Create loaders
    train_loader = DataLoader(PatchSequenceDataset(train_seqs, shape=CONFIG["resize_shape"]), batch_size=8, shuffle=True)
    val_loader = DataLoader(PatchSequenceDataset(val_seqs, shape=CONFIG["resize_shape"]), batch_size=8)
    test_loader = DataLoader(PatchSequenceDataset(test_seqs, shape=CONFIG["resize_shape"]), batch_size=8)



    model = LSTM_baseline(image_size=CONFIG["resize_shape"]).to(CONFIG["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.L1Loss()

    best_val = float('inf')
    patience = 5
    counter = 0

    for epoch in range(50):
        train_loss = train_lstm_baseline(model, train_loader, optimizer, loss_fn, CONFIG["device"])
        val_loss = evaluate_lstm_baseline(model, val_loader, loss_fn, CONFIG["device"])
        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best_lstm_patch.pt")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break

    model.load_state_dict(torch.load("best_lstm_patch.pt"))
    test_loss, preds, targets = evaluate_lstm_baseline(model, test_loader, loss_fn, CONFIG["device"], return_preds=True)

    mae, corr = compute_metrics(preds, targets)

    print(f"\n✅ Final LSTM_baseline Patch Model:")
    print(f"  - L1 Loss: {test_loss:.4f}")
    print(f"  - MAE:     {mae:.4f}")
    print(f"  - Corr:    {corr:.4f}")

    pred_var = np.var(np.stack(preds), axis=0).mean()
    print(f"🔍 Mean variance across predicted pixels: {pred_var:.6f}")

    for i in range(3):
        try:
            visualize_preds(preds, targets, shape=CONFIG["resize_shape"], save_dir=CONFIG["save_vis_dir"])
        except Exception as e:
            print(f"❌ Error saving frame {i}: {e}")


  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)


Epoch 1 | Train Loss: 0.0822 | Val Loss: 0.0199
Epoch 2 | Train Loss: 0.0232 | Val Loss: 0.0270
Epoch 3 | Train Loss: 0.0213 | Val Loss: 0.0221
Epoch 4 | Train Loss: 0.0198 | Val Loss: 0.0198
Epoch 5 | Train Loss: 0.0167 | Val Loss: 0.0217
Epoch 6 | Train Loss: 0.0178 | Val Loss: 0.0139
Epoch 7 | Train Loss: 0.0138 | Val Loss: 0.0158
Epoch 8 | Train Loss: 0.0129 | Val Loss: 0.0118
Epoch 9 | Train Loss: 0.0123 | Val Loss: 0.0123
Epoch 10 | Train Loss: 0.0131 | Val Loss: 0.0139
Epoch 11 | Train Loss: 0.0106 | Val Loss: 0.0092
Epoch 12 | Train Loss: 0.0111 | Val Loss: 0.0114
Epoch 13 | Train Loss: 0.0107 | Val Loss: 0.0096
Epoch 14 | Train Loss: 0.0099 | Val Loss: 0.0098
Epoch 15 | Train Loss: 0.0102 | Val Loss: 0.0096
Epoch 16 | Train Loss: 0.0092 | Val Loss: 0.0081
Epoch 17 | Train Loss: 0.0086 | Val Loss: 0.0087
Epoch 18 | Train Loss: 0.0091 | Val Loss: 0.0103
Epoch 19 | Train Loss: 0.0087 | Val Loss: 0.0107
Epoch 20 | Train Loss: 0.0088 | Val Loss: 0.0109
Epoch 21 | Train Loss: 0.0085

  model.load_state_dict(torch.load("best_lstm_patch.pt"))



✅ Final LSTM_baseline Patch Model:
  - L1 Loss: 0.0076
  - MAE:     0.0076
  - Corr:    0.8750
🔍 Mean variance across predicted pixels: 0.000597
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t1.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t2.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t3.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_0_t4.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t0.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t1.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t2.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t3.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LSTM_baseline_pred_1_t4.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\LS

In [26]:
import os
import matplotlib.pyplot as plt
import numpy as np

def stitch_lstm_predictions(image_paths, save_path, figsize=(16, 4)):
    """
    Safely stitches available images horizontally. Skips missing ones.
    """
    imgs = []
    valid_paths = []

    for p in image_paths:
        if os.path.exists(p):
            imgs.append(plt.imread(p))
            valid_paths.append(p)
        else:
            print(f"⚠️ Missing file: {p}")

    if not imgs:
        print("❌ No valid images to stitch.")
        return

    fig, axes = plt.subplots(1, len(imgs), figsize=figsize)
    if len(imgs) == 1:
        axes = [axes]  # Make it iterable if only one subplot

    for ax, img, path in zip(axes, imgs, valid_paths):
        ax.imshow(img)
        ax.axis("off")
        title = os.path.basename(path).split("_t")[-1].replace(".png", "")
        ax.set_title(f"t+{title}", fontsize=10)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✅ Saved stitched image to: {save_path}")


In [27]:
import os

existing_files = sorted([
    f for f in os.listdir(CONFIG["save_vis_dir"]) if f.startswith("LSTM_baseline_pred_0")
])
print("✅ Existing LSTM frame files:\n", existing_files)


✅ Existing LSTM frame files:
 ['LSTM_baseline_pred_0_t0.png', 'LSTM_baseline_pred_0_t1.png', 'LSTM_baseline_pred_0_t2.png', 'LSTM_baseline_pred_0_t3.png', 'LSTM_baseline_pred_0_t4.png']


In [28]:
prefix = "LSTM_baseline_pred_0_t"
image_paths = [os.path.join(CONFIG["save_vis_dir"], f"{prefix}{t}.png") for t in range(5)]
save_path = os.path.join(CONFIG["save_vis_dir"], "stitched_LSTM_pred_0.png")
stitch_lstm_predictions(image_paths, save_path)


✅ Saved stitched image to: C:\Users\Platypus\Desktop\LSTM_Predictions\stitched_LSTM_pred_0.png


In [29]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
import numpy as np
import os
import torch

# Matplotlib visual style (clean, white background, consistent colorbars)
mpl.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.labelcolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'text.color': 'black',
    'savefig.facecolor': 'white',
    'savefig.edgecolor': 'white'
})

# Normalize image for display [0, 1] range
def normalize_for_display(img):
    img_min = np.min(img)
    img_max = np.max(img)
    if img_max - img_min < 1e-6:
        return np.zeros_like(img)
    return (img - img_min) / (img_max - img_min)

# Run after model is trained and test_ds is available
model.eval()
with torch.no_grad():
    for i in range(3):  # For 3 test samples
        x, y = test_loader.dataset[i]
        x = x.unsqueeze(0).to(CONFIG["device"])
        pred = model(x).cpu().numpy()[0]   # shape: (5, H, W)
        true = y.numpy()                   # shape: (5, H, W)

        preds_vis = [normalize_for_display(pred[t]) for t in range(5)]
        trues_vis = [normalize_for_display(true[t]) for t in range(5)]
        errors_vis = [normalize_for_display(np.abs(pred[t] - true[t])) for t in range(5)]

        fig = plt.figure(figsize=(12, 15))
        spec = gridspec.GridSpec(nrows=5, ncols=6, width_ratios=[1, 0.05, 1, 0.05, 1, 0.05], wspace=0.4, hspace=0.3)

        ims = []
        for t in range(5):
            # Predicted
            ax_pred = fig.add_subplot(spec[t, 0])
            im_pred = ax_pred.imshow(preds_vis[t], cmap='inferno', vmin=0, vmax=1)
            ax_pred.set_title(f"Predicted t+{t+1}")
            ax_pred.axis('off')
            if t == 0: ims.append(im_pred)

            # Ground Truth
            ax_gt = fig.add_subplot(spec[t, 2])
            im_gt = ax_gt.imshow(trues_vis[t], cmap='inferno', vmin=0, vmax=1)
            ax_gt.set_title(f"Ground Truth t+{t+1}")
            ax_gt.axis('off')
            if t == 0: ims.append(im_gt)

            # Absolute Error
            ax_err = fig.add_subplot(spec[t, 4])
            im_err = ax_err.imshow(errors_vis[t], cmap='hot', vmin=0, vmax=1)
            ax_err.set_title("Abs Error")
            ax_err.axis('off')
            if t == 0: ims.append(im_err)

        # Shared colorbars
        cbar_pred = fig.add_subplot(spec[:, 1])
        fig.colorbar(ims[0], cax=cbar_pred)
        cbar_pred.set_title("Pred", fontsize=10)

        cbar_gt = fig.add_subplot(spec[:, 3])
        fig.colorbar(ims[1], cax=cbar_gt)
        cbar_gt.set_title("Truth", fontsize=10)

        cbar_err = fig.add_subplot(spec[:, 5])
        fig.colorbar(ims[2], cax=cbar_err)
        cbar_err.set_title("Error", fontsize=10)

        # Save
        save_path = os.path.join(CONFIG["save_vis_dir"], f"stacked_sample{i}_with_colorbars.png")
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.2, dpi=300)
        plt.close()

        print(f"✅ Saved: {save_path}")


✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\stacked_sample0_with_colorbars.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\stacked_sample1_with_colorbars.png
✅ Saved: C:\Users\Platypus\Desktop\LSTM_Predictions\stacked_sample2_with_colorbars.png


In [None]:
# Imports
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
from scipy.ndimage import median_filter
from skimage.filters import gaussian
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr

# Config
CONFIG = {
    "resize_shape": (64, 64),
    "window": 10,
    "future_steps": 5,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "save_vis_dir": os.path.expanduser("~/Desktop/LSTM_Predictions")
}
os.makedirs(CONFIG["save_vis_dir"], exist_ok=True)

# Preprocessing
def denoise_gaussian_median(frame, median_size=3, gaussian_sigma=1.0):
    return gaussian(median_filter(frame, size=median_size), sigma=gaussian_sigma)

def load_tif_sequences_from_stack(stack, window=10, future_steps=1):
    stack_denoised = np.array([denoise_gaussian_median(f) for f in stack])
    stack_resized = np.array([resize(f, CONFIG["resize_shape"]) for f in stack_denoised])
    norm_stack = (stack_resized - stack_resized.min()) / (stack_resized.max() - stack_resized.min())
    norm_stack = norm_stack * 2 - 1

    sequences = []
    for i in range(0, len(norm_stack) - window - future_steps + 1):
        seq = norm_stack[i:i+window]
        target = norm_stack[i+window:i+window+future_steps]
        sequences.append((seq, target))
    return sequences

# Dataset
class PatchSequenceDataset(Dataset):
    def __init__(self, sequences, shape=(64, 64)):
        self.sequences = sequences
        self.shape = shape
        self.pos_enc = self._generate_positional_encodings(shape)

    def _generate_positional_encodings(self, shape):
        H, W = shape
        grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
        pos = torch.stack((grid_y, grid_x), dim=-1).float()
        pos = pos / torch.tensor([H - 1, W - 1])
        return pos

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, target = self.sequences[idx]
        T = seq.shape[0]
        F = target.shape[0]
        H, W = self.shape
        pos_enc = self.pos_enc.permute(2, 0, 1).unsqueeze(0).repeat(T, 1, 1, 1)
        seq = torch.tensor(seq.reshape(T, H, W), dtype=torch.float32).unsqueeze(1)
        seq_with_pos = torch.cat([seq, pos_enc], dim=1)
        return seq_with_pos, torch.tensor(target.reshape(F, H, W), dtype=torch.float32)

# Model
class LSTM_baseline(nn.Module):
    def __init__(self, patch_size=4, image_size=(64, 64), hidden_dim=256, lstm_layers=2, dropout=0.1, future_steps=5):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.future_steps = future_steps
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)
        self.input_dim = patch_size * patch_size + 2 * patch_size * patch_size

        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )

        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)

        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.input_dim * future_steps)
        )

    def forward(self, x):
        B, T, C, H, W = x.size()
        P = self.patch_size

        x = x.view(B, T, C, H, W)
        patches = x.unfold(3, P, P).unfold(4, P, P)
        patches = patches.permute(0, 1, 3, 4, 2, 5, 6).contiguous()
        patches = patches.view(B, T, -1, C * P * P)

        outputs = []
        for p in range(patches.size(2)):
            patch_seq = patches[:, :, p, :]
            encoded = self.encoder(patch_seq.view(-1, self.input_dim)).view(B, T, -1)
            lstm_out, _ = self.lstm(encoded)
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            decoded = self.decoder(attn_out[:, -1, :])
            decoded = decoded.view(B, self.future_steps, self.input_dim)
            outputs.append(decoded)

        outputs = torch.stack(outputs, dim=2)
        outputs = outputs.view(B, self.future_steps, H // P, W // P, C, P, P)
        outputs = outputs[:, :, :, :, 0, :, :]
        outputs = outputs.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, self.future_steps, H, W)
        return outputs

# Evaluation
def evaluate_lstm_baseline(model, loader, loss_fn, device, return_preds=False):
    model.eval()
    total_loss = 0
    preds, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            if return_preds:
                preds.extend(out.cpu().numpy())
                targets.extend(y.cpu().numpy())
    if return_preds:
        return total_loss / len(loader), preds, targets
    return total_loss / len(loader)




  model.load_state_dict(torch.load("best_lstm_patch.pt"))


✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t0_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t1_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t2_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t3_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t4_triple.png


In [9]:
# Fix: normalize predictions from [-1, 1] to [0, 1]
def normalize_for_display(img):
    img = np.clip((img + 1) / 2, 0, 1)  # Safe normalization
    return img

def plot_lstm_frame_with_colorbars(pred, true, save_path, t_idx=0):
    pred = normalize_for_display(pred)
    true = normalize_for_display(true)
    error = np.abs(pred - true)

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    vmin, vmax = 0, 1

    im0 = axes[0].imshow(pred, cmap="inferno", vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Predicted t+{t_idx+1}")
    axes[0].axis("off")
    plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

    im1 = axes[1].imshow(true, cmap="inferno", vmin=vmin, vmax=vmax)
    axes[1].set_title(f"Ground Truth t+{t_idx+1}")
    axes[1].axis("off")
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

    im2 = axes[2].imshow(error, cmap="hot", vmin=0, vmax=1)
    axes[2].set_title("Abs Error")
    axes[2].axis("off")
    plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"✅ Saved: {save_path}")


In [10]:
sample_idx = 0
for t in range(5):
    pred_frame = preds[sample_idx][t]
    true_frame = targets[sample_idx][t]
    out_path = os.path.join(CONFIG["save_vis_dir"], f"sample{sample_idx}_t{t}_triple.png")
    plot_lstm_frame_with_colorbars(pred_frame, true_frame, out_path, t_idx=t)


✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t0_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t1_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t2_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t3_triple.png
✅ Saved: C:\Users\Platypus/Desktop/LSTM_Predictions\sample0_t4_triple.png
