In [64]:
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F

import numpy as np

import data_process as dp

from pointnet_model import get_model

In [65]:
# 读取训练数据
train_data = pd.read_csv('../dataset/new_data.csv')

# 查看最后一列
print(train_data.iloc[:, -1].value_counts())

1    5382
8    4842
2    4210
3    3996
4    3540
9    3536
5    3420
6    3270
7    3253
0    3225
Name: 0, dtype: int64


In [66]:
#随机抽样30%的数据作为测试集
test_data = train_data.sample(frac=0.2)
train_data = train_data.drop(test_data.index)

print(train_data.shape, test_data.shape)

(30939, 64) (7735, 64)


In [67]:
X, y = dp.data_to_points_cloud(train_data)
Xval, yval = dp.data_to_points_cloud(test_data)

print(X.shape, y.shape, Xval.shape, yval.shape)

ngpu= 1
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

X = X.to(device)
y = y.to(device)
Xval = Xval.to(device)
yval = yval.to(device)


torch.Size([30939, 3, 21]) torch.Size([30939, 10]) torch.Size([7735, 3, 21]) torch.Size([7735, 10])


In [68]:
model = get_model(num_classes=10, global_feat=True, feature_transform=False, channel=3)

model.cuda(device=device)

# 定义损失函数和优化器
loss = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 0.00004)


In [69]:
import visualization as vs

num_epochs = 15
num_samples = X.shape[0]
batch_size = 32



for epoch in range(num_epochs):
    train_loss = 0.0
    for i in range(0, num_samples, batch_size):
        input = X[i:i+batch_size]
        label = y[i:i+batch_size]
 
        # 前向传播
        model.train()
        output, trans, trans_feat = model(input)

        l = loss(output, label)
        
        # 反向传播和优化
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        train_loss += l.item()

    train_loss /= num_samples
    with torch.no_grad():
        model.eval()
        y_pred, _, _ = model(Xval)
        val_loss = loss(y_pred, yval).item()


    # 打印每个 epoch 的损失
    print(f'Epoch {epoch+1}/{num_epochs}, train_Loss: {train_loss}, val_Loss: {val_loss}')
    
    if val_loss < 0.0001:
        break
    


    # 保存模型
torch.save(model.state_dict(), '../model/model.pth')

Epoch 1/15, train_Loss: 0.001559455398365529, val_Loss: 0.016350211575627327
Epoch 2/15, train_Loss: 0.0005244472388621451, val_Loss: 0.006980341859161854
Epoch 3/15, train_Loss: 0.0002874046826259589, val_Loss: 0.004425222519785166
Epoch 4/15, train_Loss: 0.00019122346951713385, val_Loss: 0.0034321697894483805
Epoch 5/15, train_Loss: 0.00014101149283791547, val_Loss: 0.002921453444287181
Epoch 6/15, train_Loss: 0.00011093991424315285, val_Loss: 0.0026759773027151823
Epoch 7/15, train_Loss: 9.180742476376066e-05, val_Loss: 0.0021394421346485615
Epoch 8/15, train_Loss: 7.267963175638555e-05, val_Loss: 0.0018935413099825382
Epoch 9/15, train_Loss: 6.73047179232176e-05, val_Loss: 0.0020153019577264786
Epoch 10/15, train_Loss: 5.669809154291311e-05, val_Loss: 0.0018949421355500817
Epoch 11/15, train_Loss: 4.363173060569787e-05, val_Loss: 0.0016323844902217388
Epoch 12/15, train_Loss: 4.5574795905201765e-05, val_Loss: 0.0020605160389095545
Epoch 13/15, train_Loss: 4.294910059504344e-05, val

In [70]:
with torch.no_grad():
    model.eval()
    pre_val, _, _ = model(Xval)

head = 10


print(pre_val.shape)
print(pre_val[:head])
print(yval[:head])

torch.Size([7735, 10])
tensor([[1.6068e-04, 1.1693e-04, 2.4951e-04, 4.5942e-05, 3.5200e-04, 5.0592e-04,
         4.9363e-04, 9.9744e-01, 3.9095e-04, 2.4023e-04],
        [2.1474e-04, 1.8425e-04, 2.4904e-04, 4.7587e-04, 1.0751e-03, 9.9551e-01,
         7.7911e-04, 5.6548e-04, 8.1226e-04, 1.3761e-04],
        [1.1942e-04, 9.9778e-01, 3.1622e-04, 9.1208e-05, 6.5179e-05, 4.7252e-05,
         1.7632e-04, 2.6944e-04, 5.4860e-04, 5.8389e-04],
        [1.4974e-04, 1.0408e-04, 2.0337e-04, 1.6371e-04, 7.3019e-04, 9.9525e-01,
         1.6189e-03, 9.3804e-04, 7.8248e-04, 5.8457e-05],
        [6.4716e-04, 9.9208e-01, 1.0478e-03, 3.7852e-04, 4.0733e-04, 2.0984e-04,
         1.3566e-03, 7.6592e-04, 1.7157e-03, 1.3864e-03],
        [9.5761e-05, 3.4759e-04, 2.2387e-04, 7.0074e-05, 1.0803e-04, 1.5795e-04,
         1.4418e-04, 2.8645e-04, 9.9839e-01, 1.7222e-04],
        [5.2977e-05, 9.6094e-04, 9.9583e-01, 8.3404e-04, 3.7202e-04, 3.0782e-04,
         1.7379e-04, 8.8922e-04, 2.7889e-04, 2.9557e-04],
    