In [71]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from nn_train import *
from main_nn_module import *
from torch_dataset import *
from utils.split import *

In [72]:
# Pick the best available accelerator instead of hard-coding "mps", which only
# exists on Apple-silicon builds of PyTorch and fails everywhere else.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [73]:
def calculate_model_size(model=None, input_shape=(1, 1, 10, 16)):
    """Count trainable parameters of a model and probe its output shape.

    Args:
        model: The ``nn.Module`` to inspect. When ``None`` (the default), a
            fresh ``MainNnModule()`` is created, matching the original behavior.
        input_shape: Shape of the random dummy input passed through the model.
            Defaults to ``(1, 1, 10, 16)`` (batch of 1, single channel, 10x16).

    Returns:
        tuple: ``(total_params, params_in_k, output_shape)``:
            ``total_params`` is the number of elements across parameters with
            ``requires_grad=True``; ``params_in_k`` is that count divided by
            1000; ``output_shape`` is the ``torch.Size`` of the forward output.
    """
    if model is None:
        model = MainNnModule()

    # Only trainable parameters are counted (frozen ones are excluded).
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_in_k = total_params / 1000

    # Create the dummy input on the same device as the model's parameters, so a
    # model already moved to e.g. "mps" can be probed without a device mismatch.
    first_param = next(model.parameters(), None)
    probe_device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(*input_shape, device=probe_device)

    # Probe in eval mode without autograd. Dropout/batch-norm then behave
    # deterministically, batch-norm does not reject a batch of 1, its running
    # statistics are not updated, and no autograd graph is built.
    # The caller's train/eval flag is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(dummy_input)
    finally:
        model.train(was_training)

    return total_params, params_in_k, output.shape


In [74]:
# Experiment configuration.
SEED = 42
BATCH_SIZE = 8
TEST_SIZE = 0.2

# Seed torch's global RNG so weight initialisation and training-batch shuffling
# are reproducible under Restart & Run All.
# NOTE(review): the randomness source inside `stratified_split` is not visible
# here — confirm whether it needs its own seed (e.g. numpy / random_state).
torch.manual_seed(SEED)

model = MainNnModule().to(device)
# Use torch.optim explicitly instead of relying on a bare `optim` name leaking
# in through one of the star imports at the top of the notebook.
optimizer = torch.optim.AdamW(model.parameters())
loss_fn = nn.CrossEntropyLoss()

dataset = JQDataset("./datas/dataset.csv")
train_set, test_set = stratified_split(dataset, test_size=TEST_SIZE)

train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps test passes comparable.
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)


In [75]:
# Run one training pass followed by one evaluation pass per epoch.
# `train` / `test` are presumably provided by the `nn_train` star import — verify.
epochs = 64
for epoch_idx in range(1, epochs + 1):
    header = f"Epoch {epoch_idx}\n-------------------------------"
    print(header)
    train(train_dataloader, model, loss_fn, optimizer, device)
    test(test_dataloader, model, loss_fn, device)
print("Done!")

Epoch 1
-------------------------------
loss: 1.208029  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.110210 

Epoch 2
-------------------------------
loss: 1.122810  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.113658 

Epoch 3
-------------------------------
loss: 1.119207  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.134094 

Epoch 4
-------------------------------
loss: 1.106864  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.096844 

Epoch 5
-------------------------------
loss: 1.064164  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.113546 

Epoch 6
-------------------------------
loss: 1.171954  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.106164 

Epoch 7
-------------------------------
loss: 1.041249  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.108452 

Epoch 8
-------------------------------
loss: 1.088234  [    8/   47]
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.105810 

Epoch 9
----------------

In [76]:
# Report the trainable-parameter count (raw and in thousands) and the probed
# output shape; one joined print emits the same three lines as separate prints.
total_params, params_in_k, output_shape = calculate_model_size()

print("\n".join([
    f"模型总参数量: {total_params}",
    f"以k为单位的参数量: {params_in_k:.2f}k",
    f"输出形状: {output_shape}",
]))


模型总参数量: 1031
以k为单位的参数量: 1.03k
输出形状: torch.Size([1, 3])


In [77]:
# Persist the entire trained module object (architecture + weights) via pickle.
# NOTE(review): PyTorch recommends saving `model.state_dict()` instead. A full-object
# pickle is tied to the class's import path at load time. The parameters also stay on
# `device` (the model was moved there with `.to(device)`), so loading on another
# machine may need `map_location`.
MODEL_PATH = './model.pt'
torch.save(obj=model, f=MODEL_PATH)