In [20]:
import torch
from torch import nn
import time
import numpy as np
import onnx
import onnxruntime as ort
import os

In [21]:
# ================================
# 2. Model
# ================================
class FlagWindNet(nn.Module):
    """Per-flag-point MLP conditioned on a pooled wind embedding.

    Each flag point and each wind point is encoded independently by its own
    two-layer MLP. The wind features are mean-pooled over the wind points into
    a single global vector, which is concatenated onto every flag point's
    features and decoded back to 3 values per point (the decoder is designed
    to predict the next xyz of each flag point).

    Args:
        hidden_size: width of the encoder/decoder hidden layers. Must match the
            value used when any loaded weights were trained.
    """

    def __init__(self, hidden_size=512):
        super().__init__()

        # Encode per-flag-point (3 → hidden_size)
        self.flag_encoder = nn.Sequential(
            nn.Linear(3, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # Encode per-wind-point (3 → hidden_size)
        self.wind_encoder = nn.Sequential(
            nn.Linear(3, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # Wind pooling is a plain mean over the wind-point axis in forward(),
        # so no pooling module is needed here. (A parameter-free
        # AdaptiveAvgPool1d used to be declared but was never called; dropping
        # it does not change the state_dict.)

        # Final decoder: per-flag-point + wind → next (x,y,z)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 3)   # predict next xyz
        )

    def forward(self, flag, wind):
        """Run the network on a batch of flags and wind samples.

        Args:
            flag: (B, N, 3) flag point coordinates. Any N works; the exported
                ONNX model fixes it to the traced vertex count.
            wind: (B, M, 3) wind sample points. Any M works; the exported
                ONNX model fixes it to the traced wind point count.

        Returns:
            (B, N, 3) tensor, one 3-vector per flag point.

        Raises:
            ValueError: if ``flag`` is not 3-dimensional.
        """
        # Unpacking enforces a 3-D (batched) flag tensor; only N is needed.
        _, num_points, _ = flag.shape

        # Encode flag points → (B, N, hidden); Linear applies per point.
        flag_feat = self.flag_encoder(flag)

        # Encode wind points → (B, M, hidden), then mean-pool → (B, hidden).
        wind_global = self.wind_encoder(wind).mean(dim=1)

        # Broadcast the global wind vector to every flag point → (B, N, hidden).
        # expand() creates a view, so no per-point copy is materialized.
        wind_broadcast = wind_global.unsqueeze(1).expand(-1, num_points, -1)

        # Concatenate flag and wind features → (B, N, 2*hidden), then decode.
        fused = torch.cat([flag_feat, wind_broadcast], dim=-1)
        return self.decoder(fused)

In [22]:
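# # ---- Alternative (unused): load a fully pickled model object ----
# # The export cell below loads a state_dict instead; this is kept only for reference.
# # NOTE: weights_only=False unpickles arbitrary objects and can execute code — only use with trusted files.
# # NOTE(review): safe_globals only allowlists classes for weights_only=True loads; it has no effect with weights_only=False.
# with torch.serialization.safe_globals([FlagWindNet]):
#     model = torch.load("flagwind_model.pth", map_location="cpu", weights_only=False)

# model.eval()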

In [23]:
# ================================
# 3. LOAD AND EXPORT
# ================================

# --- Configuration ---
# !! SET THIS: Path to your trained model
MODEL_PATH = "./model/flagwind_model_weights.pth"
EXPORT_NAME = "./model/flagwind_model.onnx"  # Output file name
NUM_VERTICES = 1024    # Must match your flag's vertex count
NUM_WIND_POINTS = 8    # Number of wind sample points fed to the model
HIDDEN_SIZE = 512      # Must match the hidden_size the weights were trained with

# --- Load Model ---
print(f"Loading model from {MODEL_PATH}...")
if not os.path.exists(MODEL_PATH):
    # Fail loudly: merely printing lets the cell "succeed" while skipping the
    # export, so a Restart & Run All would silently produce no ONNX file.
    raise FileNotFoundError(
        f"!! ERROR: Model file not found at: {MODEL_PATH}. "
        "Please make sure the file is in the same directory or provide the full path."
    )

model = FlagWindNet(hidden_size=HIDDEN_SIZE)

# A state_dict only contains tensors, so weights_only=True is sufficient and
# refuses to execute arbitrary pickled code from the checkpoint file.
state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# IMPORTANT: Set model to evaluation mode before tracing.
model.eval()
print("Model loaded successfully.")

# --- Create Dummy Inputs ---
# Tracing records the graph for these exact shapes. No dynamic_axes are
# declared, so the exported model only accepts batch size 1, NUM_VERTICES flag
# points and NUM_WIND_POINTS wind points.
dummy_flag_input = torch.randn(1, NUM_VERTICES, 3)
dummy_wind_input = torch.randn(1, NUM_WIND_POINTS, 3)

# --- Define Input/Output Names ---
# These names will be used in your C# script to feed data to the model.
# NOTE(review): FlagWindNet's decoder is commented as predicting the *next xyz*
# per point, not a displacement. The output name is left unchanged because the
# C# side presumably depends on it — confirm which one the consumer expects.
input_names = ["flag_input", "wind_input"]
output_names = ["displacement_output"]

# --- Export ---
print(f"Exporting model to {EXPORT_NAME}...")
torch.onnx.export(
    model,
    (dummy_flag_input, dummy_wind_input),  # Dummy inputs tuple
    EXPORT_NAME,
    input_names=input_names,
    output_names=output_names,
    opset_version=12,  # A good, stable version
    verbose=False,
)

# --- Verify ---
# Structurally validate the ONNX graph, then check that ONNX Runtime reproduces
# the PyTorch output on the same inputs (catches silent export mismatches).
onnx.checker.check_model(onnx.load(EXPORT_NAME))

with torch.no_grad():
    torch_output = model(dummy_flag_input, dummy_wind_input).numpy()

session = ort.InferenceSession(EXPORT_NAME, providers=["CPUExecutionProvider"])
(onnx_output,) = session.run(
    output_names,
    {
        input_names[0]: dummy_flag_input.numpy(),
        input_names[1]: dummy_wind_input.numpy(),
    },
)
np.testing.assert_allclose(onnx_output, torch_output, rtol=1e-3, atol=1e-4)

# Report shapes taken from the real tensors rather than hard-coded strings.
print("="*30)
print(f"✅ Successfully exported model to {EXPORT_NAME}")
print(f"  Input 1 ({input_names[0]}): shape {tuple(dummy_flag_input.shape)}")
print(f"  Input 2 ({input_names[1]}): shape {tuple(dummy_wind_input.shape)}")
print(f"  Output ({output_names[0]}): shape {tuple(onnx_output.shape)}")
print("  ONNX Runtime output matches PyTorch output.")
print("="*30)

Loading model from ./model/flagwind_model_weights.pth...
Model loaded successfully.
Exporting model to ./model/flagwind_model.onnx...
✅ Successfully exported model to ./model/flagwind_model.onnx
  Input 1 (flag_input): shape (1, 1024, 3)
  Input 2 (wind_input): shape (1, 8, 3)
  Output (displacement_output): shape (1, 1024, 3)
