In [3]:
import torch
import torch.nn as nn

class DeeperCNN(nn.Module):
    """MLP-to-CNN decoder mapping a small parameter vector to a stack of 2-D frames.

    Pipeline (as implemented in ``forward``):
      1. Three fully-connected layers (in_features -> 128 -> 512 ->
         n_frames * grid * grid), ReLU after the first two, reshaped to
         (batch, n_frames, grid, grid).
      2. Five 3x3 same-padding convolutions
         (n_frames -> 64 -> 64 -> 32 -> 32 -> n_frames), each followed by ReLU.
      3. Bilinear upsampling (align_corners=True) to ``out_size``.

    The defaults reproduce the original hard-coded architecture (5 inputs,
    30 frames, 64x64 latent grid, 250x250 output), so existing checkpoints
    still load. The attribute names fc1..fc3 / conv1..conv5 form the
    state_dict keys and must not be renamed.

    Args:
        in_features: length of the input parameter vector.
        n_frames: number of output channels ("frames").
        grid: side length of the square latent grid produced by ``fc3``.
        out_size: (H, W) spatial size after upsampling.
    """

    def __init__(self, in_features=5, n_frames=30, grid=64, out_size=(250, 250)):
        super().__init__()
        # Plain ints (not parameters/buffers): they do not enter the state_dict.
        self.n_frames = n_frames
        self.grid = grid

        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 512)
        self.fc3 = nn.Linear(512, n_frames * grid * grid)

        self.conv1 = nn.Conv2d(n_frames, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(32, n_frames, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(size=out_size, mode='bilinear', align_corners=True)

    def forward(self, x, capture_features=False):
        """Decode ``x`` into upsampled frames.

        Args:
            x: input tensor whose last dimension is ``in_features``
               (typically (batch, in_features); an unbatched 1-D vector
               yields a batch of one).
            capture_features: if True, also collect detached copies of the
                intermediate maps: the reshaped ``fc3`` output followed by
                the output of each of the five conv+ReLU stages (6 tensors,
                all before upsampling).

        Returns:
            The upsampled tensor of shape (batch, n_frames, *out_size), or
            ``(output, features)`` when ``capture_features`` is True.
        """
        features = []

        def _capture(t):
            if capture_features:
                # detach() first so the copy is never recorded by autograd.
                features.append(t.detach().clone())

        # Global MLP projection to (batch, n_frames * grid * grid).
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        # -1 infers the batch dimension (also covers an unbatched 1-D input).
        x = x.view(-1, self.n_frames, self.grid, self.grid)
        _capture(x)

        # Convolutional refinement; every stage is conv -> ReLU.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = torch.relu(conv(x))
            _capture(x)

        x = self.upsample(x)
        return (x, features) if capture_features else x



In [None]:
# === Imports ===
import torch
import torch.nn.functional as F        # ← add this line
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, IntSlider
from IPython.display import display

# === Model Import ===
# NOTE(review): DeeperCNN is not imported here; it must already be defined in
# the kernel (the class cell earlier in this notebook has to run first).

# === Configuration ===
import os  # only needed for the optional MODEL_PATH override below

# Checkpoint location. Set the DEEPERCNN_MODEL_PATH environment variable to
# point elsewhere; the default is the original author's machine-specific path.
MODEL_PATH = os.environ.get(
    "DEEPERCNN_MODEL_PATH",
    r"C:\Users\Ali\Desktop\798 Project\visualization_samples_30frames_CNN_DEEP\best_model_30frames_200Epochs.pth",
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load Model ===
# The architecture must match the checkpoint, so the default constructor is used.
model = DeeperCNN().to(DEVICE)
# The loaded object goes straight into load_state_dict, so it should be a plain
# state_dict; weights_only=True refuses arbitrary pickled objects.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# === Metrics Function ===
def compute_metrics(final_frame):
    solid_fraction = final_frame.mean()
    gx, gy = np.gradient(final_frame)
    grad_mag = np.sqrt(gx**2 + gy**2)
    branchiness = grad_mag.sum()
    return solid_fraction, branchiness

# === Inference + Plotting Function ===
def predict_and_plot(ΔT0, c, N, θ, r0):
    inp = torch.tensor([[ΔT0, c, N, θ, r0]], dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        pred = model(inp)                # [1, 30, H, W]
        final_frame = pred[0, -1].cpu().numpy()

    sf, br = compute_metrics(final_frame)

    plt.figure(figsize=(5, 5))
    plt.imshow(final_frame, cmap='plasma', origin='lower', vmin=0, vmax=1)
    plt.title(
        f"ΔT₀={ΔT0:.2f}, c={c:.3f}, N={N}, θ={int(θ)}°, r₀={r0:.3f}\n"
        f"Solid Fraction = {sf:.3f}, Branchiness = {br:.1f}"
    )
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# === Interactive Sliders ===
# NOTE(review): several ranges/defaults extend beyond the commented-out
# param_grid in this cell (presumably the grid the training data came from —
# confirm). Examples: ΔT₀ down to -1 vs. -0.8, N 1–10 vs. 4–8, θ up to 90 vs. 45,
# and the r₀ default of 0.2 vs. a maximum of 0.15. Predictions there are extrapolations.
# The trailing ';' stops the cell from echoing interact()'s return value
# (the predict_and_plot function repr) below the widget.
interact(
    predict_and_plot,
    ΔT0=FloatSlider(min=-1, max=0, step=0.01, value=-0.3, description='ΔT₀'),
    c=FloatSlider(min=0.01, max=0.1, step=0.005, value=0.03, description='c'),
    N=IntSlider(min=1, max=10, step=1, value=5, description='N'),
    θ=FloatSlider(min=0, max=90, step=1, value=45, description='θ'),
    r0=FloatSlider(min=0.05, max=0.2, step=0.005, value=0.2, description='r₀'),
);

# Reference only: parameter grid, presumably the one the training data were generated on.
# param_grid = {
#     "dT0": [-0.2, -0.4, -0.6, -0.8],
#     "c": [0.005, 0.02, 0.05],
#     "N": [4, 6, 8],
#     "theta_deg": [0, 15, 30, 45],
#     "seed_radius": [0.08, 0.1, 0.15],
# }


interactive(children=(FloatSlider(value=-0.3, description='ΔT₀', max=0.0, min=-1.0, step=0.01), FloatSlider(va…

<function __main__.predict_and_plot(ΔT0, c, N, θ, r0)>