In [None]:
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models

# --- Configuration ----------------------------------------------------------
# Number of output classes the checkpoint was trained with. Set this to match
# your model (e.g. 5, or 1000 for the stock ImageNet head); it must agree with
# the classifier weights stored in PTH_FILE.
NUM_CLASSES = 2

# Example-input shape used for export: (batch, channels, height, width).
# 3-channel 224x224 images are the standard EfficientNet-B0 input.
INPUT_SHAPE = (1, 3, 224, 224)

# Checkpoint to read and ONNX file to write.
PTH_FILE = 'efficientnet_b0_full.pth'
ONNX_FILE = 'efficientnet_b0.onnx'

# 1. Load the PyTorch Model Architecture
# Build an untrained torchvision EfficientNet-B0 (weights=None: nothing is
# downloaded); the trained parameters are copied in from PTH_FILE afterwards.
model = models.efficientnet_b0(weights=None)

# The stock head's final Linear layer sits at classifier[1]. Replace it with
# one sized to NUM_CLASSES so its parameter shapes match a fine-tuned
# checkpoint. The input width is read from the existing layer instead of
# being hard-coded.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=NUM_CLASSES)

# 2. Load the Trained Weights
#
# A .pth file may hold any of these layouts, depending on how it was saved:
#   * a bare state_dict                 torch.save(model.state_dict(), path)
#   * a training-checkpoint dict that nests the weights under 'state_dict',
#     'model_state_dict' or 'model' (often next to optimizer state)
#   * a whole pickled nn.Module         torch.save(model, path)
# Every layout is normalised to a plain state_dict and copied into the
# architecture built in step 1. load_state_dict's strict key/shape check
# therefore still catches a mismatch with NUM_CLASSES.

# Dict keys under which training checkpoints commonly nest the model weights.
_CHECKPOINT_WEIGHT_KEYS = ('state_dict', 'model_state_dict', 'model')


def _load_checkpoint(path):
    """Load a checkpoint onto the CPU, preferring the safe weights-only loader.

    Args:
        path: Path of the .pth file to read.

    Returns:
        Whatever object the file contains (state_dict, checkpoint dict or
        nn.Module).

    Raises:
        Any error raised by ``torch.load`` (missing file, corrupt archive, ...).
    """
    cpu = torch.device('cpu')
    try:
        # weights_only=True (the default from PyTorch 2.6) only unpickles
        # tensors and plain containers. That covers bare state_dicts and most
        # training checkpoints.
        return torch.load(path, map_location=cpu, weights_only=True)
    except pickle.UnpicklingError:
        # Anything else, typically a whole pickled nn.Module, needs the
        # unrestricted unpickler.
        # SECURITY: weights_only=False executes arbitrary code embedded in the
        # file. Only use this path for checkpoints you produced or trust.
        print(f"{path} is not a weights-only file; re-loading it with "
              "weights_only=False. Only do this for files you trust.")
        return torch.load(path, map_location=cpu, weights_only=False)


def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    Args:
        checkpoint: The object returned by ``torch.load``: an ``nn.Module``,
            a state_dict, or a dict wrapping one of those under a key listed
            in ``_CHECKPOINT_WEIGHT_KEYS``.

    Returns:
        A mapping from parameter/buffer name to tensor. A dict that matches no
        wrapper key is returned unchanged and treated as a bare state_dict.

    Raises:
        TypeError: If ``checkpoint`` is neither an ``nn.Module`` nor a dict.
    """
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint content of type {type(checkpoint).__name__}; "
            "expected an nn.Module or a state_dict."
        )
    for key in _CHECKPOINT_WEIGHT_KEYS:
        nested = checkpoint.get(key)
        # Only descend into values that can actually contain weights, so
        # metadata stored under e.g. 'model' (such as a name string) is ignored.
        if isinstance(nested, (dict, nn.Module)):
            return _extract_state_dict(nested)
    return checkpoint


# Resolve the weights up front. A failure inside torch.load now surfaces as
# its own error instead of an unrelated NameError.
checkpoint = _load_checkpoint(PTH_FILE)
state_dict = _extract_state_dict(checkpoint)
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    # Raised for missing/unexpected keys or size mismatches, e.g. when
    # NUM_CLASSES differs from the head the checkpoint was trained with.
    print(f"Could not load weights: {e}. Check the exact structure of your {PTH_FILE}.")
    raise

# Set model to evaluation mode so Dropout and BatchNorm use inference
# behaviour in the exported graph.
model.eval()

# 3. Create a Dummy Input
# Random example tensor passed to the exporter; only its shape and dtype
# matter. The batch size of 1 does not fix the exported batch dimension,
# because axis 0 is declared dynamic below.
dummy_input = torch.randn(INPUT_SHAPE, requires_grad=True)

# 4. Export to ONNX
torch.onnx.export(
    model,                      # The PyTorch model
    dummy_input,                # Example model input
    ONNX_FILE,                  # Output file name
    export_params=True,         # Embed the trained weights in the file
    opset_version=13,           # ONNX opset version (13 is widely supported)
    do_constant_folding=True,   # Pre-compute constant sub-expressions
    input_names=['input'],      # Name of the graph input
    output_names=['output'],    # Name of the graph output
    dynamic_axes={              # Allow any batch size at inference time
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

# The message previously contained mojibake ("âœ…", a UTF-8 check mark
# misdecoded as cp1252); restore the intended character.
print(f"✅ PyTorch model successfully exported to {ONNX_FILE}")