# Combined Trainer

In [None]:
import pkg_resources

def is_package_installed(package_name):
    """Return True if a distribution named ``package_name`` is installed, else False.

    Uses the standard-library ``importlib.metadata`` API instead of the deprecated
    ``pkg_resources`` module.

    Args:
        package_name: distribution name as published (e.g. ``'numpy'``).

    Returns:
        bool: whether the distribution's metadata can be found.
    """
    from importlib import metadata

    try:
        metadata.distribution(package_name)
    except metadata.PackageNotFoundError:
        return False
    return True


# Example usage
print(is_package_installed('fluid_simulation'))  # True if fluid_simulation is installed, False otherwise

In [None]:
from torch.optim.lr_scheduler import StepLR

## Setup

In [None]:
def train_model(model, dataloader, num_epochs=10, learning_rate=0.001, device=None, model_type="gnn"):
    """Train ``model`` with Adam, an MSE loss and a step-decay learning-rate schedule.

    Args:
        model: torch module to train; moved to ``device`` in place.
        dataloader: iterable of batches. For ``model_type="cnn"`` each batch is an
            ``(inputs, targets)`` pair. For ``"gnn"`` each batch is a graph batch exposing
            ``x``, ``edge_index``, ``edge_attr``, ``edge_distance`` and ``y``.
        num_epochs: number of passes over ``dataloader``.
        learning_rate: initial Adam learning rate.
        device: target device; defaults to CUDA when available, else CPU.
        model_type: ``"cnn"`` or ``"gnn"``; selects how each batch is unpacked.

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: if ``model_type`` is not ``"cnn"`` or ``"gnn"``.
    """
    # Fail fast instead of hitting an undefined ``loss`` inside the loop.
    if model_type not in ("cnn", "gnn"):
        raise ValueError(f"model_type must be 'cnn' or 'gnn', got {model_type!r}")

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Multiply the learning rate by 0.92 every 4 epochs.
    scheduler = StepLR(optimizer, step_size=4, gamma=0.92)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Module.to() moves parameters in place, so the optimizer above still holds the live tensors.
    model = model.to(device)
    print(f'Starting training: {next(model.parameters()).device}')

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0
        n_batches = 0
        for batch in dataloader:
            if model_type == "cnn":
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
            else:  # "gnn"
                batch = batch.to(device)
                outputs = model(batch.x, batch.edge_index, batch.edge_attr, batch.edge_distance).squeeze()
                targets = batch.y  # regression targets
            # MSELoss convention is (prediction, target).
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

        # Step the scheduler once per epoch.
        scheduler.step()

        # Average over the batches actually seen; also works for loaders without __len__
        # and avoids dividing by zero for an empty loader.
        epoch_loss = running_loss / max(n_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

    print('Finished training')
    return model

In [None]:
import os
import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

In [None]:
import pandas as pd
import torch
import torch.nn
import os

In [None]:
from fluid_simulation.gnn_torch_only import GridGNNWithAngles
from fluid_simulation.utils import create_grid_graph_with_angles

In [None]:
width, height = 10, 10  # Grid dimensions for a quick smoke test
data = create_grid_graph_with_angles(width, height)
num_nodes = data['num_nodes']
edge_index = data['edge_index']
edge_attr = data['edge_attr']
edge_distance = data['edge_distance']

print(edge_index.shape)


def _build_and_probe_gnn(hidden_channels, num_layers, use_target_node_feat, in_channels=8, out_channels=8):
    """Build a GridGNNWithAngles on the demo graph and run one forward pass on random features.

    Prints the model, its trainable-parameter count and the output shape.

    Returns:
        (model, x, out): the model, the random node features ``[num_nodes, in_channels]``
        and the forward-pass output.
    """
    # Random node features for demonstration: [num_nodes, in_channels].
    x = torch.randn(num_nodes, in_channels)
    model = GridGNNWithAngles(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=out_channels,  # output channels per node
        num_layers=num_layers,
        use_angle=True,
        use_target_node_feat=use_target_node_feat,
    )
    print(model)
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Shape check only, so skip building an autograd graph.
    with torch.no_grad():
        out = model(x, edge_index, edge_attr, edge_distance)
    print(out.shape)  # expected [num_nodes, out_channels]
    return model, x, out


# 3 message-passing layers, hidden width 32, target-node features enabled.
model, x, out = _build_and_probe_gnn(hidden_channels=32, num_layers=3, use_target_node_feat=True)

print('--smaller net--')
# 1 message-passing layer, hidden width 64, target-node features disabled.
model, x, out = _build_and_probe_gnn(hidden_channels=64, num_layers=1, use_target_node_feat=False)

In [None]:
csv_file = '../../../data/combined_data_with_deltas.csv'
# One unit of 51,200 rows = two timesteps (2 x 160 x 160 cells).
timestep_n_rows = 51_200  # actually this is 2 timesteps
n_steps = 40
chunk_rows = timestep_n_rows * n_steps

# First chunk: the leading rows of the file.
df = pd.read_csv(csv_file, nrows=chunk_rows)

# Second chunk, starting 125 units into the data. An int skiprows also counts the header
# line: lines 0..N-1 are skipped, line N (a data row) becomes df2's header, and the data
# itself starts at data row N. The real column names are copied from df afterwards.
df2 = pd.read_csv(csv_file, nrows=chunk_rows, skiprows=timestep_n_rows * 125)
# df3 = pd.read_csv(csv_file, nrows=timestep_n_rows * n_steps, skiprows=timestep_n_rows * 125 * 2)
# df3 = pd.read_csv(csv_file, nrows=timestep_n_rows * n_steps, skiprows=timestep_n_rows * 125 * 2)

In [None]:
# df2 was read with an int skiprows, so its header came from a data row; reuse df's names.
df2.columns = df.columns
# df3.columns = df.columns
print(df.shape)
print(df.shape[0] // timestep_n_rows)  # number of 51,200-row units in the first chunk
print(df.simulation_id.unique())
# df2.simulation_id.value_counts()
# Both chunks start at index 0; ignore_index avoids duplicate index labels in the combined frame.
df = pd.concat([df, df2], ignore_index=True)  # , df3])

In [None]:
# Get all float columns
float_columns = df.select_dtypes(include=['float64', 'float32']).columns

# Round only the float columns to 10 decimal places
df[float_columns] = df[float_columns].round(decimals=10)

In [None]:
df.head(5)  # first five rows of the combined, rounded frame

In [None]:
# Prefer the Apple-Silicon GPU (MPS) when available, otherwise use the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
# Ask PyTorch to run ops that MPS does not implement on the CPU.
# NOTE(review): PyTorch appears to read PYTORCH_ENABLE_MPS_FALLBACK when torch is first
# imported, and earlier cells already import torch, so this likely has no effect in the
# current session. Set it before the first `import torch` (or in the shell) to be sure.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [None]:
from fluid_simulation.utils import prepare_data

In [None]:
from fluid_simulation.datasets_v2 import GridDatasetGNN, GridDatasetCNN 
from fluid_simulation.models_v2 import CNN, GridGNNWithAngles

In [None]:
import pandas as pd

def calculate_deltas(df):
    """Add a ``delta_<base>`` column for every ``<base>_next`` / ``<base>`` column pair.

    For each column whose name ends with ``'_next'``, the base name is that name with
    only the trailing ``'_next'`` removed. If the base column exists, this writes
    ``df['delta_<base>'] = df['<base>_next'] - df['<base>']``.

    Args:
        df: DataFrame; modified in place.

    Returns:
        The same DataFrame, for chaining.
    """
    suffix = '_next'
    # Snapshot the names so the delta columns added below are not revisited.
    columns = list(df.columns)
    existing = set(columns)

    for next_col in columns:
        if not next_col.endswith(suffix):
            continue
        # Strip only the trailing suffix; str.replace would also remove an earlier
        # '_next' inside the name (e.g. 'a_next_b_next' -> 'a_b').
        base_col = next_col[:-len(suffix)]
        if base_col in existing:
            df[f'delta_{base_col}'] = df[next_col] - df[base_col]

    return df

# Example usage:
# Assuming you have a DataFrame named 'df'
# df = calculate_deltas(df)

def prepare_data(df, target_pattern="_next", input_pattern_filter="_next", input_pattern_filter_2=None):
    """Select model input/target columns and (re)write a ``border`` indicator column.

    Args:
        df: simulation DataFrame with at least ``row`` and ``col`` columns. A float
            ``border`` column (1.0 on the outermost row/col, else 0.0) is written in place.
        target_pattern: substring identifying target columns. A column is skipped if
            removing this substring yields a metadata column name.
        input_pattern_filter: columns containing this substring are excluded from the
            inputs; ``None`` disables this filter.
        input_pattern_filter_2: optional second substring to exclude from the inputs.

    Returns:
        ``(df, input_cols, target_cols)``. ``input_cols`` always ends with ``'border'``,
        listed exactly once.
    """
    metadata_cols = {'simulation_id', 'timestep', 'row', 'col', 'iter', 'time', 'pressure', 'pressure_next'}

    excluded_patterns = []
    if input_pattern_filter is not None:
        excluded_patterns.append(input_pattern_filter)
    if input_pattern_filter_2:
        excluded_patterns.append(input_pattern_filter_2)

    # 'border' is excluded here and appended explicitly, so a frame that already has a
    # border column (e.g. when the cell is re-run) does not list it twice.
    input_cols = [
        col for col in df.columns
        if col not in metadata_cols
        and col != 'border'
        and not any(pattern in col for pattern in excluded_patterns)
    ] + ['border']
    target_cols = [
        col for col in df.columns
        if target_pattern in col and col.replace(target_pattern, '') not in metadata_cols
    ]

    print(f"Input columns: {input_cols}")
    print(f"Target columns: {target_cols}")

    row_max = df['row'].max()
    col_max = df['col'].max()
    on_edge = df['row'].isin([0, row_max]) | df['col'].isin([0, col_max])
    # Assign a plain array so the result does not depend on index alignment.
    df['border'] = on_edge.to_numpy(dtype=float)

    return df, input_cols, target_cols

In [None]:
# When True, the models regress per-cell deltas (x_next - x) instead of raw next-step values.
use_deltas = True

if use_deltas:
    df = calculate_deltas(df)
    df, input_cols, target_cols = prepare_data(df, target_pattern="delta_", input_pattern_filter_2="delta")
else:
    df, input_cols, target_cols = prepare_data(df, target_pattern="_next", input_pattern_filter_2="delta")

# Jupyter only displays a bare expression when it is the last statement of the cell.
target_cols

In [None]:
# Remove every target column whose name mentions 'is_fluid'.
target_cols = list(filter(lambda name: 'is_fluid' not in name, target_cols))
target_cols

In [None]:
df.head(5)  # preview after delta/border columns were added

## GNN

In [None]:
import torch
import torch.nn as nn
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
import pandas as pd
import math
import numpy as np
from sklearn.model_selection import train_test_split

### Assert GNN Works

In [None]:
# precomputed_edge_index, precomputed_edge_attr = create_efficient_grid_graph_with_direction_onehot(160, 160)

# Instantiate the graph dataset over the 160x160 grid.
dataset_gnn = GridDatasetGNN(
    df=df, feature_cols=input_cols, target_cols=target_cols, height=160, width=160)

# Receptive field grows by one hop per message-passing layer:
#   Layer 1: node B aggregates information from its neighbours A and C.
#   Layer 2: B also sees D indirectly, because C aggregated from B and D in layer 1.
gnn_channels = 16
# NOTE(review): per the original note, hidden_channels is used for both x_i and x_j,
# so the effective message width is 2 * gnn_channels. Confirm against GridGNNWithAngles.
model_gnn = GridGNNWithAngles(in_channels=len(input_cols), hidden_channels=gnn_channels,
                              out_channels=len(target_cols), num_layers=2, use_angle=True, use_target_node_feat=False)

model_gnn = model_gnn.to(device)

# torch_geometric's loader collates graphs into one disconnected batch graph.
from torch_geometric.loader import DataLoader

batch_size = 12
loader_gnn = DataLoader(dataset_gnn, batch_size=batch_size, shuffle=True)

# Example: DataBatch(x=[307200, 6], edge_index=[2, 2434608], edge_attr=[2434608], y=[307200, 5], batch=[307200], ptr=[13])

# Smoke test: push one batch through the untrained model.
#   batch.x: [batch_size * num_nodes, in_features]
#   batch.y: [batch_size * num_nodes, target_features]
batch = next(iter(loader_gnn)).to(device)
print(batch)
# No gradients needed here; avoids building an autograd graph for ~300k nodes.
with torch.no_grad():
    output = model_gnn(batch.x, batch.edge_index, batch.edge_attr, batch.edge_distance)
print(output.shape)

In [None]:
# Display the GNN architecture that will be trained below.
model_gnn

### Model Training

#### IMPORTANT: Every simulation ID must contribute the same number of rows (a whole number of 160×160 grid snapshots)

In [None]:
n_epochs = 10  # number of GNN training epochs

In [None]:
# NOTE(review): this displays `model`, the small demo GridGNNWithAngles built in the
# smoke-test cell, not `model_gnn`. Confirm which one was intended here.
model

In [None]:
sum(param.numel() for param in filter(lambda t: t.requires_grad, model_gnn.parameters()))  # trainable parameters

In [None]:
print(device)
# Train the GNN; train_model returns the same module object once training is done.
gnn_trained = train_model(
    model_gnn,
    loader_gnn,
    model_type="gnn",
    device=device,
    learning_rate=0.001,
    num_epochs=n_epochs,
)

In [None]:
# Save the whole module (pickled class + weights). Derive the epoch count from n_epochs
# instead of hard-coding it.
# NOTE(review): the '_next' suffix is kept for existing checkpoint names even though the
# model may have been trained on delta targets (use_deltas=True).
torch.save(model_gnn, f'gnn_{n_epochs}-epoch-{gnn_channels}-channels_next.pt')

## CNN

### Assert CNN Works

In [None]:
# Dense-grid dataset for the CNN baseline.
# NOTE(review): row_max/col_max=None presumably lets GridDatasetCNN infer the grid extent from df; confirm.
cnn_dataset = GridDatasetCNN(df, input_cols, target_cols, row_max=None, col_max=None)

In [None]:
# Instantiate the dataset
model_cnn = CNN(in_channels=len(input_cols), hidden_channels=32, out_channels=len(target_cols), num_layers=3)
model_cnn = model_cnn.to(device)

print(next(model_cnn.parameters()).device)
print(sum(p.numel() for p in model_cnn.parameters() if p.requires_grad))

In [None]:
from torch.utils.data import Dataset, DataLoader

batch_size = 8
loader_cnn = DataLoader(cnn_dataset, batch_size=batch_size, shuffle=True)

# Smoke test: push one batch through the untrained CNN.
# Presumably inputs/targets are dense grids [batch_size, channels, H, W]; confirm against GridDatasetCNN.
inputs, targets = next(iter(loader_cnn))
inputs = inputs.to(device)
targets = targets.to(device)

print(inputs.shape)
# Shape check only, so no autograd graph is needed.
with torch.no_grad():
    output = model_cnn(inputs)
print(output.shape)

### Model Training

In [None]:
n_epochs = 15
# Train the CNN baseline; train_model returns the same module object once training is done.
cnn_trained = train_model(
    model_cnn,
    loader_cnn,
    model_type="cnn",
    device=device,
    learning_rate=0.001,
    num_epochs=n_epochs,
)

In [None]:
# Shape of the first sample's inputs (cnn_dataset[i] yields an (inputs, targets) pair).
cnn_dataset[0][0].shape

In [None]:
# Timestep 0, left-most column, rows below index 80.
first_step_left_edge = (df.timestep == 0) & (df.col == 0) & (df.row > 80)
df[first_step_left_edge]

## Simple Predictions

In [None]:
# Report where each trained model's parameters live (CNN first, then GNN).
for net in (model_cnn, model_gnn):
    print(next(net.parameters()).device)

# Aliases used by the prediction cells.
cnn_model, gnn_model = model_cnn, model_gnn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch_geometric.data import Data as GeoData, Batch
from torch_geometric.nn import NNConv
import matplotlib.pyplot as plt

# Set device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")



# -------------------------------
# Initialize Models
# -------------------------------

# Define feature channels
input_cols = ['u', 'v', 'density', 'is_fluid', 'border']
target_cols = ['u_next', 'v_next', 'density_next']

cnn_input_channels = len(input_cols)  # 6
cnn_output_channels = len(target_cols)  # 5

gnn_input_channels = cnn_input_channels
gnn_output_channels = cnn_output_channels
edge_attr_dim = 4  # As per one-hot encoding


# Load pre-trained weights if available
# cnn_model.load_state_dict(torch.load('path_to_cnn_model.pth'))
# gnn_model.load_state_dict(torch.load('path_to_gnn_model.pth'))

cnn_model.eval()
gnn_model.eval()

# -------------------------------
# Prepare Dummy DataFrame (Replace with Actual Data)
# -------------------------------

H, W = 160, 160  # Grid size
num_nodes = H * W

# Example DataFrame structure
data_dict = {
    'simulation_id': np.repeat([0], num_nodes),
    'timestep': np.repeat([0], num_nodes),
    'row': np.tile(np.arange(H), W),
    'col': np.repeat(np.arange(W), H),
    'u': np.zeros(num_nodes),
    'v': np.zeros(num_nodes),
    'density': np.ones(num_nodes) * 0.1,
    #'pressure': np.zeros(num_nodes),
    'is_fluid': np.zeros(num_nodes),
    'border': np.zeros(num_nodes),
    'u_next': np.zeros(num_nodes),
    'v_next': np.zeros(num_nodes),
    'density_next': np.ones(num_nodes) * 0.1,
    #'pressure_next': np.zeros(num_nodes),
    'is_fluid_next': np.ones(num_nodes)
}

data_dict['border'] = ((data_dict['row'] == 0) | (data_dict['col'] == 0) | 
                       (data_dict['row'] == H-1) | (data_dict['col'] == W-1)).astype(int)

data_dict['is_fluid'] = ((data_dict['row'] != 0) | (data_dict['col'] != 0) | 
                       (data_dict['row'] != H-1) | (data_dict['col'] != W-1)).astype(int)

# Set device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# -------------------------------
# Initialize Models
# -------------------------------

input_cols = ['u', 'v', 'density', 'is_fluid', 'border']
target_cols = ['u_next', 'v_next', 'density_next']

cnn_input_channels = len(input_cols)      # 6
cnn_output_channels = len(target_cols)    # 5

gnn_input_channels = cnn_input_channels
gnn_output_channels = cnn_output_channels
edge_attr_dim = 4  # As per one-hot encoding


# Load pre-trained weights if available
# cnn_model.load_state_dict(torch.load('path_to_cnn_model.pth'))
# gnn_model.load_state_dict(torch.load('path_to_gnn_model.pth'))

cnn_model.eval()
gnn_model.eval()

# -------------------------------
# Prepare Dummy DataFrame (Replace with Actual Data)
# -------------------------------

df_simulation = pd.DataFrame(data_dict)

# -------------------------------
# Initialize Data (Without Dataset Class)
# -------------------------------

# Create edge_index and edge_attr
data = create_grid_graph_with_angles(H, W)
num_nodes = data['num_nodes']
edge_index = data['edge_index']
edge_attr = data['edge_attr']
edge_distance = data['edge_distance']

# Extract node features and targets
features = df_simulation[input_cols].values.astype(np.float32).flatten()
features = torch.tensor(features, dtype=torch.float).view(1, len(input_cols), H, W).to(device)  # [1, C, H, W]

# Modify 'u' velocity in specific region
center_row = H // 2
row_start = max(center_row - 3, 0)
row_end = min(center_row + 3, H)
cols_to_modify = [0, 1]

features_np = features.cpu().numpy().copy()
features_np[0, input_cols.index('u'), row_start:row_end, cols_to_modify] = 5.0
features_np[0, input_cols.index('density'), row_start:row_end, cols_to_modify] = 0.8
features = torch.tensor(features_np).to(device)

# Initialize time
current_time = torch.tensor([[0.0]]).float().to(device)  # [1, 1]

density_index = target_cols.index('density_next')
u_index = target_cols.index('u_next')

# -------------------------------
# Define Prediction Functions
# -------------------------------

def create_border_mask(H, W, device):
    mask = torch.zeros(1, H, W, device=device)
    mask[:, 0, :] = 1  # Top border
    mask[:, -1, :] = 1  # Bottom border
    mask[:, :, 0] = 1  # Left border
    mask[:, :, -1] = 1  # Right border
    return mask

def predict_with_cnn(cnn, features):
    with torch.no_grad():
        cnn_output = cnn(features)  # [1, C', H, W]
        
        # Create correct border mask
        border_mask = create_border_mask(H, W, device)
        
        # Append border information to CNN output
        is_fluid = torch.ones((1, 1, H, W)).to(device)
        if use_deltas:
            is_fluid = is_fluid * 0
        cnn_output_with_border = torch.cat([cnn_output, is_fluid, border_mask.unsqueeze(1)], dim=1)  # [1, C'+1, H, W]
    
    return cnn_output_with_border

def predict_with_gnn(gnn, features, edge_index, edge_attr, edge_distance, device, H, W):
    with torch.no_grad():
        # Flatten features for GNN
        x = features[0].permute(1, 2, 0).reshape(-1, features.shape[1]).to(device)  # [num_nodes, C]
        
        # Create GeoData object
        edge_index = edge_index.to(device)
        edge_attr = edge_attr.to(device)
        data = GeoData(x=x, edge_index=edge_index, edge_attr=edge_attr, edge_distance=edge_distance)
        batch = Batch.from_data_list([data]).to(device)  # Batch size of 1
        # GNN Prediction
        gnn_output = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.edge_distance)  # [num_nodes, C']
        
        # Reshape to grid format
        gnn_output_grid = gnn_output.reshape(H, W, -1).permute(2, 0, 1).unsqueeze(0)  # [1, C', H, W]
        border_mask = create_border_mask(H, W, device)
        
        # Append border information to CNN output
        is_fluid = torch.ones((1, 1, H, W)).to(device)
        if use_deltas:
            is_fluid = is_fluid * 0
        gnn_output_with_border = torch.cat([gnn_output_grid, is_fluid, border_mask.unsqueeze(1)], dim=1)  # [1, C'+1, H, W]
        gnn_output_with_border = torch.clip(gnn_output_with_border, -20, 20)
        # gnn_output_with_border[:, 2, :, :] = torch.clip(gnn_output_with_border[:, 2, :, :], -1, 3)
    return gnn_output_with_border

In [None]:
import time
# -------------------------------
# Prediction Loop
# -------------------------------

num_steps = 50
cnn_predictions = []
gnn_predictions = []
debug = False

# Number of leading channels the models actually predict (u, v, density).
n_targets = len(target_cols)


def _to_numpy(tensor):
    """Detached CPU numpy copy (never aliases a tensor that is modified later)."""
    return tensor.detach().cpu().numpy().copy()


def _advance_state(prev_features, model_output):
    """Build the next rollout state from one model step, then pin the inflow u to 5.

    With use_deltas, the first n_targets output channels are increments. They are added
    to the previous state, and the static channels (is_fluid, border) are carried over.
    Otherwise the model output, which already carries those static channels, is the next state.
    """
    if use_deltas:
        next_features = prev_features.clone()
        next_features[:, :n_targets] += model_output[:, :n_targets]
    else:
        next_features = model_output.clone()
    # Maintain U velocity at 5 in the inflow region.
    next_features[0, u_index, row_start:row_end, cols_to_modify] = 5.0
    return next_features


start = time.perf_counter()

cnn_features = features
gnn_features = features

for step in range(num_steps):
    if step % 5 == 0:
        print(f"Step {step+1}/{num_steps}")

    # CNN / GNN predictions from the current state of each rollout.
    cnn_output = predict_with_cnn(cnn_model, cnn_features)
    gnn_output = predict_with_gnn(gnn_model, gnn_features, edge_index, edge_attr, edge_distance, device, H, W)

    # Update Time
    new_time = current_time + 1.0

    if debug:
        print(f"  Current Features Shape: {cnn_features.shape}")
        print(f"  CNN Output Shape: {cnn_output.shape}")
        print(f"  GNN Output Shape: {gnn_output.shape}")
        print(f"  Current Time: {current_time.item()}, Next Time: {new_time.item()}")
        print(cnn_features.mean())

    cnn_next = _advance_state(cnn_features, cnn_output)
    gnn_next = _advance_state(gnn_features, gnn_output)

    # Record states: accumulated absolute states in delta mode, raw model outputs otherwise.
    if use_deltas:
        cnn_predictions.append(_to_numpy(cnn_next))
        gnn_predictions.append(_to_numpy(gnn_next))
    else:
        cnn_predictions.append(_to_numpy(cnn_output))
        gnn_predictions.append(_to_numpy(gnn_output))

    cnn_features, gnn_features = cnn_next, gnn_next
    current_time = new_time

# Convert predictions to numpy arrays: [num_steps, 1, C, H, W]
cnn_predictions = np.array(cnn_predictions)
gnn_predictions = np.array(gnn_predictions)

print(f"Simulation complete. {time.perf_counter() - start} s")
print(f"CNN Predictions Shape: {cnn_predictions.shape}")
print(f"GNN Predictions Shape: {gnn_predictions.shape}")

# Example of accessing predictions: top-left 5x5 corner of channel 2 (density) at step 2.
print("\nExample of accessing predictions:")
print("CNN prediction for step 2, channel 2:")
print(cnn_predictions[1, 0, 2, :5, :5])
print("\nGNN prediction for step 2, channel 2:")
print(gnn_predictions[1, 0, 2, :5, :5])

In [None]:
# All channels of the GNN rollout at step index 5 for the cell at (row 20, col 20).
gnn_predictions[5, 0, :, 20, 20]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# ... [previous simulation code remains the same] ...

# -------------------------------
# Visualization Functions
# -------------------------------

def create_static_plots(predictions, model_name, num_steps, density_index):
    """Plot the density channel of the first ``num_steps`` steps, one stacked row per step.

    ``predictions[step]`` is indexed as ``[0, channel, row, col]``.
    """
    plt.figure(figsize=(15, 3 * num_steps))
    for position in range(1, num_steps + 1):
        density_map = predictions[position - 1][0, density_index, :, :]
        plt.subplot(num_steps, 1, position)
        plt.imshow(density_map, cmap='viridis', aspect='auto')
        plt.title(f"{model_name} Step {position}: Density")
        plt.colorbar()
    plt.tight_layout()
    # Title goes above the tight layout, leaving headroom at the top.
    plt.suptitle(f"{model_name} Predictions: Density over Time", fontsize=16)
    plt.subplots_adjust(top=0.95)
    plt.show()

def create_animation(predictions, model_name, index, index_label):
    """Animate one channel of a rollout and save it as ./gifs/<model>_<label>_evolution.gif.

    Args:
        predictions: sequence of per-step arrays indexed as ``[0, channel, row, col]``.
        model_name: label used in the title and output file name.
        index: channel to animate.
        index_label: human-readable channel name for the title and file name.

    Raises:
        ValueError: if ``predictions`` is empty.
    """
    frames = [np.asarray(pred[0, index, :, :]) for pred in predictions]
    if not frames:
        raise ValueError("predictions is empty; nothing to animate")

    # One colour scale for all frames, so colours are comparable over time and the
    # colorbar is valid for every frame, not just the last one.
    vmin = min(float(np.nanmin(frame)) for frame in frames)
    vmax = max(float(np.nanmax(frame)) for frame in frames)

    fig, ax = plt.subplots()
    # Static first frame underneath, so the figure is never blank.
    ax.imshow(frames[0], cmap='viridis', vmin=vmin, vmax=vmax)
    ims = [[ax.imshow(frame, animated=True, cmap='viridis', vmin=vmin, vmax=vmax)] for frame in frames]

    ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True, repeat_delay=1000)
    plt.colorbar(ims[-1][0])
    plt.title(f"{model_name} {index_label} Evolution")

    # Save the animation; ani.save does not create missing directories.
    os.makedirs('./gifs', exist_ok=True)
    ani.save(f'./gifs/{model_name}_{index_label}_evolution.gif', writer='pillow')

    plt.show()
# -------------------------------
# Visualization
# -------------------------------
# Channel positions of density and u in the prediction arrays.
density_index = target_cols.index('density_next')
u_index = target_cols.index('u_next')

# Static Plots
# create_static_plots(cnn_predictions, "CNN", num_steps, density_index)
# create_static_plots(gnn_predictions, "GNN", num_steps, density_index)

# Animations: for each channel (density, then u), the CNN rollout first, then the GNN rollout.
for channel_idx, channel_label in ((density_index, 'density'), (u_index, 'u')):
    for rollout, rollout_name in ((cnn_predictions, "CNN"), (gnn_predictions, "GNN")):
        create_animation(rollout, rollout_name, channel_idx, channel_label)

# # Comparison Plot (Static)
# plt.figure(figsize=(20, 4 * num_steps))
# for step in range(3):
#     # CNN
#     plt.subplot(num_steps, 2, 2*step + 1)
#     plt.imshow(cnn_predictions[step][0, density_index, :, :], cmap='viridis', aspect='auto')
#     plt.title(f"CNN Step {step+1}: Density")
#     plt.colorbar()
    
#     # GNN
#     plt.subplot(num_steps, 2, 2*step + 2)
#     plt.imshow(gnn_predictions[step][0, density_index, :, :], cmap='viridis', aspect='auto')
#     plt.title(f"GNN Step {step+1}: Density")
#     plt.colorbar()

# plt.tight_layout()
# plt.suptitle("Comparison: CNN vs GNN Predictions of Density over Time", fontsize=16)
# plt.subplots_adjust(top=0.95)
# plt.show()

# Print density statistics for (up to) the first two rollout steps; shorter rollouts
# print fewer steps instead of raising IndexError.
print("\nStatistics:")
n_stat_steps = min(2, len(cnn_predictions), len(gnn_predictions))
for step in range(n_stat_steps):
    cnn_density = cnn_predictions[step][0, density_index, :, :]
    gnn_density = gnn_predictions[step][0, density_index, :, :]
    print(f"\nStep {step+1}:")
    print(f"  CNN - Min: {cnn_density.min():.4f}, Max: {cnn_density.max():.4f}, Mean: {cnn_density.mean():.4f}")
    print(f"  GNN - Min: {gnn_density.min():.4f}, Max: {gnn_density.max():.4f}, Mean: {gnn_density.mean():.4f}")

In [None]:
# Density map of the first recorded CNN rollout step.
cnn_predictions[0][0, density_index, :, :]