In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import os

# 避免TensorFlow導入問題，先用純NumPy實作
class FashionMNISTNet(nn.Module):
    def __init__(self):
        super(FashionMNISTNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

def extract_weights_only():
    """只提取權重，避免TensorFlow導入問題"""
    print("開始提取PyTorch權重...")
    
    # 載入PyTorch模型
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pytorch_model = FashionMNISTNet()
    
    try:
        pytorch_model.load_state_dict(torch.load('fashion_mnist_pytorch.pth', map_location=device))
        print("✅ PyTorch模型載入成功")
    except FileNotFoundError:
        print("❌ 找不到 fashion_mnist_pytorch.pth，請先執行訓練")
        return False
    
    pytorch_model.eval()
    
    # 提取權重
    weights_dict = {}
    layer_idx = 0
    
    for name, param in pytorch_model.named_parameters():
        if 'fc' in name:
            if 'weight' in name:
                # 轉置權重矩陣
                weights_dict[f'layer_{layer_idx+1}_weights'] = param.detach().cpu().numpy().T
            elif 'bias' in name:
                weights_dict[f'layer_{layer_idx+1}_bias'] = param.detach().cpu().numpy()
                layer_idx += 1
    
    print(f"提取了 {len(weights_dict)} 個權重參數")
    
    # 確保model資料夾存在
    os.makedirs('model', exist_ok=True)
    
    # 儲存權重
    np.savez('model/fashion_mnist.npz', **weights_dict)
    
    # 手動建立架構JSON
    model_config = {
        "name": "sequential",
        "layers": [
            {
                "class_name": "Flatten",
                "config": {
                    "name": "flatten",
                    "trainable": True,
                    "dtype": "float32",
                    "data_format": "channels_last"
                }
            },
            {
                "class_name": "Dense",
                "config": {
                    "name": "dense",
                    "trainable": True,
                    "dtype": "float32",
                    "units": 256,
                    "activation": "relu",
                    "use_bias": True
                }
            },
            {
                "class_name": "Dense",
                "config": {
                    "name": "dense_1",
                    "trainable": True,
                    "dtype": "float32",
                    "units": 128,
                    "activation": "relu",
                    "use_bias": True
                }
            },
            {
                "class_name": "Dense",
                "config": {
                    "name": "dense_2",
                    "trainable": True,
                    "dtype": "float32",
                    "units": 64,
                    "activation": "relu",
                    "use_bias": True
                }
            },
            {
                "class_name": "Dense",
                "config": {
                    "name": "dense_3",
                    "trainable": True,
                    "dtype": "float32",
                    "units": 10,
                    "activation": "softmax",
                    "use_bias": True
                }
            }
        ]
    }
    
    # 儲存架構
    with open('model/fashion_mnist.json', 'w') as f:
        json.dump(model_config, f, indent=2)
    
    print("✅ 模型檔案已儲存")
    print("- 架構檔案: model/fashion_mnist.json")
    print("- 權重檔案: model/fashion_mnist.npz")
    
    return True

if __name__ == "__main__":
    success = extract_weights_only()
    if success:
        print("轉換完成！")


開始提取PyTorch權重...


  pytorch_model.load_state_dict(torch.load('fashion_mnist_pytorch.pth', map_location=device))


✅ PyTorch模型載入成功
提取了 8 個權重參數
✅ 模型檔案已儲存
- 架構檔案: model/fashion_mnist.json
- 權重檔案: model/fashion_mnist.npz
轉換完成！
