In [7]:
import pandas
import data_process
import tools
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


In [8]:
class RNN(nn.Module):
    """Text classifier: token embedding -> single-layer nn.RNN -> linear head.

    Args:
        num_embeddings: vocabulary size (rows of the embedding table).
        embedding_dim: size of each token embedding.
        hidden_size: number of features in the RNN hidden state.
        device: only stored on ``self.device``; the module does NOT move
            itself there, so callers must call ``.to(device)`` explicitly.
        num_classes: number of output logits (default 2, binary classification).
        padding_idx: optional padding-token id forwarded to ``nn.Embedding``;
            that row is initialised to zero and receives no gradient.
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_size, device,
                 num_classes=2, padding_idx=None):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(num_embeddings, embedding_dim,
                                      padding_idx=padding_idx)
        self.rnn = nn.RNN(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, lengths=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: integer token ids of shape (batch, seq_len).
            lengths: optional per-sample sequence lengths, shape (batch,).
                When given, the RNN output at step ``lengths[i] - 1`` is used
                for sample i instead of the final step. This matters when
                sequences are right-padded: otherwise the classifier reads the
                state after the padding. Lengths <= 0 fall back to step 0.
                When omitted, the final time step is used for every sample.
        """
        embedded = self.embedding(x)      # (batch, seq_len, embedding_dim)
        out, _ = self.rnn(embedded)       # (batch, seq_len, hidden_size)
        if lengths is None:
            last = out[:, -1, :]
        else:
            lengths = torch.as_tensor(lengths, device=out.device, dtype=torch.long)
            # Build a (batch, 1, hidden) index that selects step lengths-1 per row.
            step = (lengths - 1).clamp(min=0).view(-1, 1, 1)
            step = step.expand(-1, 1, out.size(2))
            last = out.gather(1, step).squeeze(1)
        return self.fc(last)

## 包装数据：构建词表、划分训练集/测试集并封装为 DataLoader

In [9]:
# Build the vocabulary, split rows by position (first TRAIN_SIZE rows train,
# the rest test), and wrap both splits in DataLoaders.
# NOTE(review): a positional split assumes motionClassify.csv is not ordered by
# label — verify, or shuffle rows before splitting.
# NOTE(review): the vocab is built from all rows, including test rows — consider
# building it from the training slice only, if gen_dataset handles unknown tokens.
TRAIN_SIZE = 40000
Batch_size = 64

data = pandas.read_csv('./motionClassify.csv')
vocab = data_process.gen_vocab(data)
data_train = data_process.gen_dataset(data.iloc[:TRAIN_SIZE], vocab)
data_test = data_process.gen_dataset(data.iloc[TRAIN_SIZE:], vocab)
train_iter = torch.utils.data.DataLoader(data_train, Batch_size, shuffle=True)
# Evaluation order does not affect accuracy. Keeping it fixed makes test batches
# reproducible and avoids consuming global RNG state.
test_iter = torch.utils.data.DataLoader(data_test, Batch_size, shuffle=False)

In [10]:
# Training configuration.
# NOTE(review): with plain SGD at lr=0.05 the logged loss stays near 0.693 (= ln 2)
# and test accuracy is about 0.50, i.e. the model is at chance level. Options:
# - try Adam, gradient clipping, or a GRU/LSTM;
# - check whether batches are padded, since the model reads the last time step.
SEED = 0
torch.manual_seed(SEED)  # makes weight init and DataLoader shuffling repeatable

lr = 0.05
criterion = torch.nn.CrossEntropyLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# RNN only stores `device`; move the parameters explicitly so they sit on the
# same device the batches are (presumably) moved to by tools.train.
net1 = RNN(num_embeddings=len(vocab), embedding_dim=256, hidden_size=256, device=device).to(device)
optimizer1 = torch.optim.SGD(net1.parameters(), lr)


In [11]:
# Train net1 on the training loader. The log below shows one pass of 625
# batches (40000 rows / batch size 64), with the loss printed every 100 batches.
tools.train(
    net1,
    train_iter,
    device,
    optimizer1,
    criterion,
)

 16%|█▌        | 100/625 [02:48<15:01,  1.72s/it]

batch100,loss = 0.681061327457428


 32%|███▏      | 200/625 [14:15<1:01:03,  8.62s/it]

batch200,loss = 0.6917207837104797


 48%|████▊     | 300/625 [20:43<09:09,  1.69s/it]  

batch300,loss = 0.6934199929237366


 64%|██████▍   | 400/625 [26:59<05:44,  1.53s/it]  

batch400,loss = 0.6930481195449829


 80%|████████  | 500/625 [30:45<03:32,  1.70s/it]

batch500,loss = 0.6930347681045532


 96%|█████████▌| 600/625 [35:53<00:38,  1.55s/it]

batch600,loss = 0.6938478946685791


100%|██████████| 625/625 [36:31<00:00,  3.51s/it]


In [12]:
# Evaluate net1 on the held-out rows; tools.test prints the accuracy.
tools.test(
    net1,
    test_iter,
    device,
)

100%|██████████| 157/157 [00:12<00:00, 12.75it/s]

accuracy = 0.49959999322891235



