In [12]:
import torch
from torch import nn
import time
import numpy as np
import onnx
import onnxruntime as ort
import os

In [13]:
# ================================
# 1. Flag/Wind Sequence Model (CNN encoder -> LSTM -> transposed-CNN decoder)
# ================================
class FlagWindNet(nn.Module):
    """CNN + LSTM network over sequences of flag grids and wind samples.

    Each timestep's flag points are laid out as a ``flag_h x flag_w`` image
    with 3 channels, compressed by a small CNN, concatenated with an MLP
    encoding of the wind points, passed through a 2-layer LSTM over the
    sequence axis, and decoded back to a ``flag_h x flag_w x 3`` grid with
    transposed convolutions.

    Args:
        flag_h: Grid height. Must be a positive multiple of 8 because the
            encoder halves the resolution three times and the decoder doubles
            it three times.
        flag_w: Grid width. Same constraint as ``flag_h``.
        num_wind_points: Number of 3-component wind samples per timestep.
        hidden_dim: Size of the flag latent vector and of the LSTM hidden state.

    Raises:
        ValueError: If ``flag_h`` or ``flag_w`` is not a positive multiple of 8.
    """

    def __init__(self, flag_h=32, flag_w=32, num_wind_points=8, hidden_dim=128):
        super().__init__()

        if flag_h <= 0 or flag_w <= 0 or flag_h % 8 != 0 or flag_w % 8 != 0:
            raise ValueError(
                "flag_h and flag_w must be positive multiples of 8, "
                f"got flag_h={flag_h}, flag_w={flag_w}"
            )

        # Grid dimensions (default 32x32 = 1024 points)
        self.grid_h = flag_h
        self.grid_w = flag_w

        # Spatial size of the CNN bottleneck: three 2x poolings => /8.
        # For the default 32x32 grid this is 4x4, i.e. 64 * 4 * 4 = 1024 features.
        self.feat_h = flag_h // 8
        self.feat_w = flag_w // 8
        feat_dim = 64 * self.feat_h * self.feat_w

        # 1. CNN Flag Encoder
        # Input: (Batch, 3, H, W) -> Output: (Batch, hidden_dim)
        # NOTE: layer order/indices are kept stable so saved state dicts still load.
        self.flag_encoder = nn.Sequential(
            # Conv Block 1: 3 -> 16, spatial H x W -> H/2 x W/2
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            # Conv Block 2: 16 -> 32, spatial -> H/4 x W/4
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Conv Block 3: 32 -> 64, spatial -> H/8 x W/8
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Flatten final feature map and project to the latent dimension
            nn.Flatten(),
            nn.Linear(feat_dim, hidden_dim)
        )

        # 2. Wind Encoder (MLP)
        self.wind_flat_dim = num_wind_points * 3
        self.wind_encoder = nn.Sequential(
            nn.Linear(self.wind_flat_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4)
        )

        # LSTM input size = Flag Latent + Wind Latent
        self.lstm_input_dim = hidden_dim + (hidden_dim // 4)

        # 3. Temporal Processing
        self.lstm = nn.LSTM(
            input_size=self.lstm_input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True
        )

        # ==========================================
        # 4. DECODER (H/8 x W/8 -> H x W)
        # ==========================================

        # Step A: Project LSTM hidden state back to the bottleneck feature map
        # size (64 channels, H/8, W/8).
        self.decoder_projection = nn.Linear(hidden_dim, feat_dim)

        # Step B: Transpose convolutions to upsample.
        # Unflattening happens manually in forward().
        # kernel=4, stride=2, padding=1 exactly doubles each spatial dimension.
        self.decoder_cnn = nn.Sequential(
            # Block 1: 64 -> 32 channels, spatial x2
            nn.ConvTranspose2d(in_channels=64, out_channels=32,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU(),

            # Block 2: 32 -> 16 channels, spatial x2
            nn.ConvTranspose2d(in_channels=32, out_channels=16,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU(),

            # Block 3: 16 -> 3 channels, spatial x2 (back to H x W)
            nn.ConvTranspose2d(in_channels=16, out_channels=3,
                               kernel_size=4, stride=2, padding=1)

            # No final activation (regression output can be positive or negative)
        )

    def forward(self, x_flag_seq, x_wind_seq):
        """Run the model over a batch of sequences.

        Args:
            x_flag_seq: Tensor of shape ``(B, L, N, 3)`` with
                ``N == grid_h * grid_w`` (default ``(B, L, 1024, 3)``).
            x_wind_seq: Tensor of shape ``(B, L, num_wind_points, 3)``
                (default ``(B, L, 8, 3)``).

        Returns:
            Tensor with the same shape as ``x_flag_seq``.
        """
        B, L, N, D = x_flag_seq.shape

        # --- 1. PREPARE INPUTS ---
        # Fold time into batch and lay points out as a channels-first image:
        # (B, L, N, 3) -> (B*L, H, W, 3) -> (B*L, 3, H, W).
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # time-major batch, are accepted.
        x_flag_img = x_flag_seq.reshape(B * L, self.grid_h, self.grid_w, 3).permute(0, 3, 1, 2)
        x_wind_flat = x_wind_seq.reshape(B * L, -1)

        # --- 2. ENCODE ---
        flag_latent = self.flag_encoder(x_flag_img)    # (B*L, hidden_dim)
        wind_latent = self.wind_encoder(x_wind_flat)   # (B*L, hidden_dim//4)

        # --- 3. LSTM ---
        combined = torch.cat([flag_latent, wind_latent], dim=1)  # (B*L, H + H/4)

        # Restore the time axis for the LSTM: (B, L, input_size)
        lstm_in = combined.reshape(B, L, -1)
        lstm_out, _ = self.lstm(lstm_in)  # (B, L, hidden_dim)

        # --- 4. DECODE ---
        # Fold time into batch again: (B*L, hidden_dim)
        lstm_out_flat = lstm_out.reshape(B * L, -1)

        # Project back to the bottleneck size and unflatten to a feature map:
        # (B*L, 64 * H/8 * W/8) -> (B*L, 64, H/8, W/8)
        decoder_input = self.decoder_projection(lstm_out_flat)
        decoder_input_map = decoder_input.reshape(B * L, 64, self.feat_h, self.feat_w)

        # Transpose convs: (B*L, 3, H, W)
        spatial_output = self.decoder_cnn(decoder_input_map)

        # --- 5. FORMAT OUTPUT ---
        # (B*L, 3, H, W) -> (B*L, H, W, 3) -> (B, L, N, 3)
        output_seq = spatial_output.permute(0, 2, 3, 1).reshape(B, L, N, D)

        return output_seq

# # ================================
# # 2. Sequential Model
# # ================================
# class SequentialFlagWindNet(nn.Module):
#     def __init__(self, num_flag_points=1024,
#                  num_wind_points=8, hidden_dim=128):
#         super().__init__()

#         # Dimensions
#         self.flag_flat_dim = num_flag_points * 3
#         self.wind_flat_dim = num_wind_points * 3

#         # 1. Input Compression
#         # FIXED: Added commas between layers
#         self.flag_encoder = nn.Sequential(
#             nn.Linear(self.flag_flat_dim, hidden_dim * 2),
#             nn.ReLU(),
#             nn.Linear(hidden_dim * 2, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim // 2),
#             nn.ReLU(),
#             nn.Linear(hidden_dim // 2, hidden_dim // 4)
#         )

#         # FIXED: Input to 2nd layer matches output of 1st layer (hidden_dim // 4)
#         self.wind_encoder = nn.Sequential(
#             nn.Linear(self.wind_flat_dim, hidden_dim // 4),
#             nn.ReLU(),
#             nn.Linear(hidden_dim // 4, hidden_dim // 8)
#         )

#         # LSTM input size = (H/4) + (H/8) = 3H/8
#         self.lstm_input_dim = (hidden_dim // 4) + (hidden_dim // 8)

#         # 2. Temporal Processing
#         self.lstm = nn.LSTM(
#             input_size=self.lstm_input_dim,
#             hidden_size=hidden_dim,
#             num_layers=2,
#             batch_first=True
#         )

#         # 3. Decoder
#         # FIXED: The decoder must map the LSTM Output (hidden_dim)
#         # back to the full flag size (1024 * 3).
#         # We cannot reuse 'FlagWindSpatialFusionNet' here because that net
#         # expects raw 3D coordinates, but we have a compressed LSTM vector.
#         self.decoder = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim * 2),
#             nn.ReLU(),
#             nn.Linear(hidden_dim * 2, self.flag_flat_dim) # Output: 3072
#         )

#     def forward(self, x_flag_seq, x_wind_seq):
#         # x_flag_seq: (B, L, 1024, 3)
#         # x_wind_seq: (B, L, 8, 3)
#         B, L, N, D = x_flag_seq.shape

#         # Flatten
#         flag_flat = x_flag_seq.reshape(B*L, -1)
#         wind_flat = x_wind_seq.reshape(B*L, -1)

#         # Encode
#         flag_enc = self.flag_encoder(flag_flat) # (B*L, H/4)
#         wind_enc = self.wind_encoder(wind_flat) # (B*L, H/8)

#         # Combine
#         combined = torch.cat([flag_enc, wind_enc], dim=1) # (B*L, 3H/8)

#         # Reshape for LSTM
#         lstm_in = combined.reshape(B, L, -1)

#         # LSTM Pass
#         lstm_out, _ = self.lstm(lstm_in) # (B, L, H)

#         # Decode from LSTM output
#         lstm_out_flat = lstm_out.reshape(B*L, -1)

#         # FIXED: Decoder uses the LSTM output
#         disp_flat = self.decoder(lstm_out_flat)

#         # Reshape back to sequence
#         return disp_flat.reshape(B, L, N, D)

In [14]:
# # ================================
# # 2. Model
# # ================================
# class FlagWindNet(nn.Module):
#     def __init__(self, hidden_size=512):
#         super().__init__()

#         # Encode per-flag-point (3 → hidden_size)
#         self.flag_encoder = nn.Sequential(
#             nn.Linear(3, hidden_size),
#             nn.ReLU(),
#             nn.Linear(hidden_size, hidden_size),
#             nn.ReLU()
#         )

#         # Encode per-wind-point (3 → hidden_size)
#         self.wind_encoder = nn.Sequential(
#             nn.Linear(3, hidden_size),
#             nn.ReLU(),
#             nn.Linear(hidden_size, hidden_size),
#             nn.ReLU()
#         )

#         # Combine wind into a global embedding (mean pooling)
#         self.wind_pool = nn.AdaptiveAvgPool1d(1)

#         # Final decoder: per-flag-point + wind → next (x,y,z)
#         self.decoder = nn.Sequential(
#             nn.Linear(hidden_size * 2, hidden_size),
#             nn.ReLU(),
#             nn.Linear(hidden_size, 3)   # predict next xyz
#         )

#     def forward(self, flag, wind):
#         """
#         flag: (B, 1024, 3)
#         wind: (B, 8, 3)
#         """
#         B, N, _ = flag.shape

#         # Encode flag points → (B, N, hidden)
#         flag_feat = self.flag_encoder(flag)  # applies per point

#         # Encode wind points → (B, 8, hidden)
#         wind_feat = self.wind_encoder(wind)

#         # Pool wind to global embedding → (B, hidden)
#         wind_global = wind_feat.mean(dim=1)

#         # Broadcast wind_global to match each flag point → (B, N, hidden)
#         wind_broadcast = wind_global.unsqueeze(1).expand(-1, N, -1)

#         # Concatenate flag and wind features
#         fused = torch.cat([flag_feat, wind_broadcast], dim=-1)  # (B, N, 2*hidden)

#         # Decode per-flag-point → (B, N, 3)
#         out = self.decoder(fused)

#         return out

In [15]:
# # ---- Load the full model object ----
# with torch.serialization.safe_globals([FlagWindNet]):
#     model = torch.load("flagwind_model.pth", map_location="cpu", weights_only=False)

# model.eval()

In [16]:
# ================================
# 2. LOAD AND EXPORT
# ================================
# Loads trained FlagWindNet weights from MODEL_PATH and exports the model to
# ONNX at EXPORT_NAME with dynamic batch and sequence-length axes.

# --- Configuration ---
# !! SET THIS: Path to your trained model
MODEL_PATH = "./model/flagwind_model_weights.pth"
EXPORT_NAME = "flag_wind_net.onnx"  # Output file name (single source of truth)
NUM_VERTICES = 1024  # Must match your flag's vertex count (grid_h * grid_w)
NUM_WIND_POINTS = 8  # Must match the num_wind_points the model was built with
TRACE_SEQ_LEN = 5    # Sequence length of the example inputs; dynamic after export
OPSET_VERSION = 12

# --- Load Model ---
print(f"Loading model from {MODEL_PATH}...")
if not os.path.exists(MODEL_PATH):
    print(f"!! ERROR: Model file not found at: {MODEL_PATH}")
    print("Please make sure the file is in the same directory or provide the full path.")
else:
    # Instantiate with default parameters (must match the training configuration)
    model = FlagWindNet()

    # The checkpoint is a plain state dict, so restrict unpickling to tensors
    # and primitive containers instead of arbitrary Python objects.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    # IMPORTANT: Set model to evaluation mode before tracing
    model.eval()
    print("Model loaded successfully.")

    # Fail early with a readable message if the configured vertex count does
    # not fit the model's grid; otherwise the export dies inside a reshape.
    expected_vertices = model.grid_h * model.grid_w
    if NUM_VERTICES != expected_vertices:
        raise ValueError(
            f"NUM_VERTICES={NUM_VERTICES} does not match the model grid "
            f"{model.grid_h}x{model.grid_w} = {expected_vertices}"
        )

    # --- 1. Define 4D Dummy Inputs ---
    # Shape: (Batch=1, Seq_Len, Points, Channels=3)
    dummy_flag_input = torch.randn(1, TRACE_SEQ_LEN, NUM_VERTICES, 3)

    # Shape: (Batch=1, Seq_Len, Wind_Points, Channels=3)
    dummy_wind_input = torch.randn(1, TRACE_SEQ_LEN, NUM_WIND_POINTS, 3)

    # --- 2. Run the Export ---
    # These names are what a consumer of the ONNX file uses to bind tensors.
    input_names = ["flag_input", "wind_input"]
    output_names = ["displacement_output"]

    torch.onnx.export(
        model,
        (dummy_flag_input, dummy_wind_input),
        EXPORT_NAME,
        input_names=input_names,
        output_names=output_names,
        opset_version=OPSET_VERSION,
        # IMPORTANT: Allow Batch (0) and Sequence (1) dimensions to change size
        dynamic_axes={
            "flag_input": {0: "batch_size", 1: "sequence_length"},
            "wind_input": {0: "batch_size", 1: "sequence_length"},
            "displacement_output": {0: "batch_size", 1: "sequence_length"}
        },
        verbose=False
    )

    print(f"✅ Successfully exported model to {EXPORT_NAME}")

Loading model from ./model/flagwind_model_weights.pth...
Model loaded successfully.
✅ Successfully exported model to flag_wind_net.onnx
