**CNN Model and Training**

**DATASET LOADER:** Same loader as before; it loads the train, validation, and test splits. Set `DATASET_DIR` to the folder that contains the split `.npy` files.

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

# === Configuration ===
# Folder containing the X_{split}.npy / Y_{split}.npy files -- adjust if needed.
DATASET_DIR = r"C:\Users\Ali\Desktop\798 Project\Splitting 30 Frames"

class DendriteDataset(Dataset):
    """Input-parameter / target-frame pairs loaded from ``X_{split}.npy`` and ``Y_{split}.npy``.

    Both arrays are loaded eagerly and stored as float32 tensors.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        dataset_dir: Folder holding the ``.npy`` files. When ``None`` (the
            default) the module-level ``DATASET_DIR`` is read at construction
            time, exactly as before.

    Raises:
        ValueError: if ``split`` is not one of the three names, or if the X
            and Y arrays contain a different number of samples.
    """

    SPLITS = ("train", "val", "test")

    def __init__(self, split="train", dataset_dir=None):
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if split not in self.SPLITS:
            raise ValueError("Split must be 'train', 'val', or 'test'.")
        root = DATASET_DIR if dataset_dir is None else dataset_dir

        # as_tensor avoids an extra full copy when the arrays are already float32
        # (the target stacks can be large); other dtypes are converted.
        self.X = torch.as_tensor(np.load(os.path.join(root, f"X_{split}.npy")), dtype=torch.float32)
        self.Y = torch.as_tensor(np.load(os.path.join(root, f"Y_{split}.npy")), dtype=torch.float32)

        if len(self.X) != len(self.Y):
            raise ValueError(
                f"X_{split}.npy has {len(self.X)} samples but Y_{split}.npy has {len(self.Y)}"
            )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # For the 30-frame split: x is (5,), y is (30, 250, 250).
        return self.X[idx], self.Y[idx]

# === Usage example ===
if __name__ == "__main__":
    # Build all three splits; this also checks that every X_*/Y_* file exists.
    train_dataset = DendriteDataset(split="train")
    val_dataset = DendriteDataset(split="val")
    test_dataset = DendriteDataset(split="test")

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    # Quick sanity check on a single training batch.
    inputs, targets = next(iter(train_loader))
    print(f"Input batch shape: {inputs.shape}")   # expected (batch_size, 5)
    print(f"Target batch shape: {targets.shape}") # expected (batch_size, 30, 250, 250) for the 30-frame split


Input batch shape: torch.Size([8, 5])
Target batch shape: torch.Size([8, 30, 250, 250])


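**CNN Class:** a 3-layer MLP projects the 5 input parameters onto a coarse 30×64×64 grid, followed by 5 convolutional layers and bilinear upsampling to 30×250×250.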

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by forward(); previously only imported in a later cell


class DeeperCNN(nn.Module):
    """MLP + CNN decoder that maps a parameter vector to a stack of frames.

    A 3-layer MLP projects the input onto a coarse ``(n_frames, coarse_size,
    coarse_size)`` grid. Five 3x3 conv + ReLU layers refine it, and bilinear
    upsampling produces the ``out_size`` output. The defaults reproduce the
    original 30-frame model; parameter names (``fc1..fc3``, ``conv1..conv5``)
    are unchanged, so existing checkpoints still load.

    Args:
        in_features: Length of the input vector (default 5).
        n_frames: Number of output frames / channels (default 30).
        coarse_size: Side of the square grid produced by the MLP (default 64).
        out_size: ``(H, W)`` of the upsampled output (default ``(250, 250)``).
    """

    def __init__(self, in_features=5, n_frames=30, coarse_size=64, out_size=(250, 250)):
        super().__init__()
        self.n_frames = n_frames
        self.coarse_size = coarse_size

        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 512)
        self.fc3 = nn.Linear(512, n_frames * coarse_size * coarse_size)

        self.conv1 = nn.Conv2d(n_frames, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(32, n_frames, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(size=out_size, mode='bilinear', align_corners=True)

    def forward(self, x, capture_features=False):
        """Decode ``x`` of shape (batch, in_features) to (batch, n_frames, *out_size).

        When ``capture_features`` is True, also returns a list of six detached
        tensors. The first is the coarse MLP grid; the others are the outputs
        of conv1..conv5 after ReLU, all before upsampling.
        """
        features = []

        # Global MLP projection, reshaped to a coarse per-frame grid.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x).view(-1, self.n_frames, self.coarse_size, self.coarse_size)
        if capture_features:
            # detach() before clone() so the copy is not tracked by autograd.
            features.append(x.detach().clone())

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.relu(conv(x))
            if capture_features:
                features.append(x.detach().clone())

        x = self.upsample(x)
        return (x, features) if capture_features else x



**Training:** runs for 200 epochs and saves the checkpoint with the lowest validation MSE.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os

# === Configuration ===
DATASET_DIR = r"C:\Users\Ali\Desktop\798 Project\Splitting 30 Frames"
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
EPOCHS = 200
# NOTE: the evaluation cells load "...\best_model_30frames_200Epochs.pth";
# rename/move this file accordingly after training.
SAVE_MODEL_PATH = "best_model_30frames.pth"
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the model and loaders so the initial weights and the
# shuffled batch order are the same on every run.
torch.manual_seed(SEED)


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in train mode and updated after
    every batch. Otherwise it runs in eval mode with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = criterion(model(inputs), targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # The loss is a batch mean; weight it by batch size so a smaller
            # final batch contributes proportionally.
            total_loss += loss.item() * inputs.size(0)
    return total_loss / len(loader.dataset)


# === Load datasets ===
train_dataset = DendriteDataset(split="train")
val_dataset = DendriteDataset(split="val")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# === Initialize model, loss, optimizer, scheduler ===
model = DeeperCNN().to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate every 50 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# === Training loop: keep the checkpoint with the lowest validation loss ===
best_val_loss = float('inf')
best_epoch = None

for epoch in range(EPOCHS):
    train_loss = run_epoch(model, train_loader, criterion, DEVICE, optimizer)
    val_loss = run_epoch(model, val_loader, criterion, DEVICE)

    print(f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {train_loss:.6f} - Val Loss: {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(f"✅ New best model saved at epoch {epoch+1}!")

    # Once per epoch, after all optimizer steps for that epoch.
    scheduler.step()

print("✅ Training complete!")
print(f"Best validation MSE {best_val_loss:.6f} at epoch {best_epoch} -> {SAVE_MODEL_PATH}")


**EVALUATION:** Computes the MSE on the test set and saves ground-truth vs. prediction plots of the final frame for a few samples.


In [14]:
# evaluation_cnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F  # ✅ Add this line
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os


# === Configurations ===
DATASET_DIR = r"C:\Users\Ali\Desktop\798 Project\Splitting 30 Frames"
BATCH_SIZE = 8
MODEL_PATH = r"C:\Users\Ali\Desktop\798 Project\visualization_samples_30frames_CNN_DEEP\best_model_30frames_200Epochs.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_PLOTS_DIR = "evaluation_plots"
N_PLOT_BATCHES = 3    # number of leading batches to save plots from
N_PLOT_PER_BATCH = 3  # samples plotted per batch

# === Load test dataset ===
test_dataset = DendriteDataset(split="test")
# shuffle=False so batch_idx * BATCH_SIZE + i is the dataset index of a sample.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# === Load model ===
model = DeeperCNN().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# === Define loss function ===
criterion = nn.MSELoss()

# === Evaluation ===
test_loss = 0.0
os.makedirs(SAVE_PLOTS_DIR, exist_ok=True)

with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        # The loss is a batch mean; weight by batch size for a per-sample average.
        test_loss += loss.item() * inputs.size(0)

        if batch_idx >= N_PLOT_BATCHES:
            continue

        for i in range(min(inputs.size(0), N_PLOT_PER_BATCH)):
            sample_id = batch_idx * BATCH_SIZE + i
            # Compare the last frame of each stack.
            pred = outputs[i, -1].cpu().numpy()
            true = targets[i, -1].cpu().numpy()

            # Shared colour limits: without them each panel is auto-scaled
            # independently and the side-by-side comparison is misleading.
            vmin = min(true.min(), pred.min())
            vmax = max(true.max(), pred.max())

            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            for ax, img, title in zip(axes, (true, pred), ("Ground Truth", "Prediction")):
                ax.imshow(img, cmap="plasma", origin="lower", vmin=vmin, vmax=vmax)
                ax.set_title(title)
            fig.suptitle(f"Sample {sample_id} (final frame)")
            fig.savefig(os.path.join(SAVE_PLOTS_DIR, f"sample_{sample_id}.png"))
            plt.close(fig)

test_loss /= len(test_loader.dataset)
print(f"✅ Final Test MSE Loss: {test_loss:.6f}")


✅ Final Test MSE Loss: 0.017701


**VISUALIZATION:**

In [15]:
# visualize_predictions_30frames.py

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# === Configurations ===
DATASET_DIR = r"C:\Users\Ali\Desktop\798 Project\Splitting 30 Frames"
MODEL_PATH = r"C:\Users\Ali\Desktop\798 Project\visualization_samples_30frames_CNN_DEEP\best_model_30frames_200Epochs.pth"  # Update if you saved with a different name
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_DIR = "visualization_samples_30frames_deepCNN"
N_SAMPLES = 10  # number of random test samples to plot
SEED = 0        # makes the random sample selection reproducible
frames_to_plot = [0, 5, 10, 20, 29]  # 0-based frame indices (30-frame model)

# === Prepare directories ===
os.makedirs(SAVE_DIR, exist_ok=True)

# === Load dataset and model ===
test_dataset = DendriteDataset(split="test")

model = DeeperCNN().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# Pick a reproducible random subset of test indices. Files are named by the
# dataset index so they match the evaluation cell's sample_{index}.png naming.
generator = torch.Generator().manual_seed(SEED)
sample_ids = torch.randperm(len(test_dataset), generator=generator)[:N_SAMPLES].tolist()

# === Visualize selected samples ===
with torch.no_grad():
    for sample_id in sample_ids:
        inputs, targets = test_dataset[sample_id]
        outputs = model(inputs.unsqueeze(0).to(DEVICE))

        input_params = inputs.numpy()
        pred = outputs.squeeze(0).cpu().numpy()
        true = targets.numpy()
        # Skip requested frames that do not exist in this stack.
        frames = [t for t in frames_to_plot if t < true.shape[0]]

        # squeeze=False keeps axs 2-D even when only one frame is plotted.
        fig, axs = plt.subplots(len(frames), 2, figsize=(8, 2.5 * len(frames)), squeeze=False)
        for row, t in enumerate(frames):
            for col, (stack, label) in enumerate(((true, "Ground Truth"), (pred, "Prediction"))):
                ax = axs[row, col]
                ax.imshow(stack[t], cmap="plasma", origin="lower", vmin=0, vmax=1)
                ax.set_title(f"{label} (t={t})")
                ax.axis('off')

        fig.suptitle(f"Sample {sample_id} | Input: {input_params}", fontsize=12)
        fig.tight_layout()
        fig.savefig(os.path.join(SAVE_DIR, f"sample_{sample_id}.png"))
        plt.close(fig)

print(f"✅ Visualization saved in {SAVE_DIR}/")


✅ Visualization saved in visualization_samples_30frames_deepCNN/


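**VIDEOS:** Renders side-by-side ground-truth vs. prediction GIFs of the full frame sequence for the first few test samples.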

In [None]:

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
from tqdm import tqdm
from torch.utils.data import DataLoader

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# === Config ===
SAVE_DIR = "growth_videos"
N_VIDEOS = 5  # number of leading test samples to render
GIF_FPS = 5
os.makedirs(SAVE_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Update the MODEL_PATH to your actual model path
MODEL_PATH = r"C:\Users\Ali\Desktop\798 Project\visualization_samples_30frames_CNN_DEEP\best_model_30frames_200Epochs.pth"

# === Load model and dataset ===
model = DeeperCNN().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

test_dataset = DendriteDataset(split="test")
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


def render_frame(true_frame, pred_frame, t):
    """Render one ground-truth / prediction panel pair and return it as an (H, W, 3) uint8 array."""
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    canvas = FigureCanvas(fig)
    for ax, img, label in zip(axs, (true_frame, pred_frame), ("Ground Truth", "Prediction")):
        ax.imshow(img, cmap='plasma', origin='lower', vmin=0, vmax=1)
        ax.set_title(f"{label} t={t}")
        ax.axis('off')
    canvas.draw()
    # np.asarray(buffer_rgba()) already has the (H, W, 4) pixel shape, so no
    # manual reshape from get_width_height() (which can disagree on HiDPI).
    # .copy() detaches the frame from the figure's buffer before it is closed.
    rgb = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)
    return rgb


def save_gif(path, frames, fps):
    """Write ``frames`` as a GIF at ``fps``, across imageio versions."""
    try:
        # Older imageio GIF writer accepts fps directly.
        imageio.mimsave(path, frames, fps=fps)
    except TypeError:
        # imageio >= 2.28 (pillow plugin) rejects fps; it takes a per-frame
        # duration in milliseconds instead.
        imageio.mimsave(path, frames, duration=1000 / fps)


# === Loop through a few samples ===
with torch.no_grad():
    for idx, (inputs, targets) in enumerate(tqdm(test_loader, total=min(N_VIDEOS, len(test_loader)))):
        if idx >= N_VIDEOS:
            break

        outputs = model(inputs.to(DEVICE))
        true = targets.squeeze(0).numpy()
        pred = outputs.squeeze(0).cpu().numpy()

        frames = [render_frame(true[t], pred[t], t) for t in range(true.shape[0])]

        gif_path = os.path.join(SAVE_DIR, f"sample_{idx}_fixed.gif")
        save_gif(gif_path, frames, GIF_FPS)
        print(f"🎥 Saved: {gif_path}")
