In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys

sys.path.append(os.getcwd())

from training.policy_network.model import ChessPolicyNet as Model, INPUT_CHANNELS

MODEL_PATH = "models/policy_network/BetaChess.pt"
ONNX_PATH = "models/policy_network/BetaChess.onnx"


def export():
    try:
        import onnx
    except ImportError:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])
        import onnx

    print(f"Inicjalizacja modelu ResNet (BetaChess)...")
    
    model = Model() 
    
    print(f"Wczytywanie wag z {MODEL_PATH}...")
    try:
        model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
    except Exception as e:
        print(f"Błąd ładowania wag: {e}")
        return

    model.eval()

    print(f"Tworzenie dummy_input o wymiarach [1, {INPUT_CHANNELS}, 8, 8]...")
    dummy_input = torch.randn(1, INPUT_CHANNELS, 8, 8, requires_grad=True)

    print(f"Eksportowanie do {ONNX_PATH}...")
    torch.onnx.export(model, 
        dummy_input, 
        ONNX_PATH, 
        export_params=True,        
        opset_version=11,          
        do_constant_folding=True,
        input_names=['input'], 
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    
    print(f"SUKCES! Wyeksportowano model do {ONNX_PATH}")

if __name__ == "__main__":
    export()

Inicjalizacja modelu ResNet (BetaChess)...
Wczytywanie wag z models/policy_network/BetaChess.pt...
Tworzenie dummy_input o wymiarach [1, 69, 8, 8]...
Eksportowanie do models/policy_network/BetaChess.onnx...


  model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))


SUKCES! Wyeksportowano model do models/policy_network/BetaChess.onnx
