In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
from torchaudio.datasets import SPEECHCOMMANDS
import os


class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset: str = None):
        super().__init__("./", download=True)

        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [os.path.normpath(os.path.join(self._path, line.strip())) for line in fileobj]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]
            # 筛选代表数字的音频文件
        digits = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
        self._walker = [w for w in self._walker if any(digit in w for digit in digits)]
        self.labels = []
    def collect_labels(self):
        # 收集所有唯一的标签
        for _, _, label, _, _ in self:
            if label not in self.labels:
                self.labels.append(label)
        self.labels.sort()

# # Create training and testing split of the data. We do not use validation in this tutorial.
# train_set = SubsetSC("training")
# test_set = SubsetSC("testing")

# waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]

In [6]:
import torch
from torch.utils.data import DataLoader
import torchaudio.transforms as T
from torch.utils.data.dataset import random_split

sample_rate=16000
new_sample_rate = 8000
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # 初始化特征和标签列表
    inputs, labels = [], []

    # 定义一个临时列表来存储所有波形的长度
    lengths = []

    # 提取波形和标签
    for waveform, _, label, _, _ in batch:
        lengths.append(waveform.size(1))
        inputs.append(waveform.squeeze(0)) 
        labels.append(label)

    # 找到最大的波形长度
    max_len = max(lengths)

    # Pad输入波形到相同的长度
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=0)

    # 将标签转换为张量
    label_to_index = {label: index for index, label in enumerate(sorted(set(labels)))}
    labels_indices = torch.tensor([label_to_index[label] for label in labels])

    return inputs_padded, labels_indices

# all_data = SubsetSC(subset=None)  # 加载所有数据
# total_size = len(all_data)
# train_size = int(total_size * 0.8)
# test_size = total_size - train_size
# train_dataset, test_dataset = random_split(all_data, [train_size, test_size])

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
# train_dataset, test_dataset = random_split(all_data, [train_size, test_size])

# 使用SubsetSC类和collate_fn来创建DataLoader
if device == "cuda":
    num_workers = 16
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False
train_set = SubsetSC(subset='training')
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory)
val_set = SubsetSC(subset='validation')
val_loader = DataLoader(val_set, batch_size=64, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory)
test_set = SubsetSC(subset='testing')
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory)

train_set.collect_labels()
print("Training set labels:", train_set.labels)

Training set labels: ['eight', 'five', 'four', 'nine', 'one', 'seven', 'six', 'three', 'two', 'zero']


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Perceptron(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Perceptron, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        out = self.fc(x)
        return out


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


In [30]:
class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)

In [43]:
# 参数设置
input_size =  16000
hidden_size = 64
num_classes = 10 
learning_rate = 0.01
num_epochs = 10
best_accuracy = 0.0
# 模型、损失函数和优化器
model = Perceptron(input_size, num_classes)
# model = MLP(input_size, hidden_size, num_classes)
model.to(device)

print(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
log_interval = 20
n_epoch = 2

pbar_update = 1 / (len(train_loader) + len(test_loader))
losses = []

# The transform needs to live on the same device as the model and the data.
transform = transform.to(device)
with tqdm(total=num_epochs) as pbar:     
    for epoch in range(1, num_epochs + 1):
        # 训练过程...
        model.train()  # 确保模型处于训练模式
        for batch_idx, (features, labels) in enumerate(train_loader):
            k=len(features)
            labels=labels.to(device)
            features = features.view(features.size(0), -1)
            features=features.to(device)
            outputs = model(features)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(f"Train Epoch: {epoch} [{batch_idx * k}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

            # update progress bar
            pbar.update(pbar_update)
            # record loss
            losses.append(loss.item())

        # 验证过程
        model.eval()  # 设置模型为评估模式
        correct = 0
        total = 0
        with torch.no_grad():
            for features, labels in val_loader:
                features=features.to(device)
                labels=labels.to(device)
                features = features.view(features.size(0), -1)
                outputs = model(features)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.4f}')
        
        # 检查是否为最佳模型，并保存
        if accuracy > best_accuracy:
            print(f"Found better model at epoch {epoch+1} with accuracy {accuracy:.4f}. Saving model...")
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'best_model_accuracy.pt')  # 保存最佳模型的权重
        scheduler.step()


cuda


  0%|          | 0.0018115942028985507/10 [00:00<20:53, 125.36s/it]



  0%|          | 0.038043478260869575/10 [00:04<21:06, 127.13s/it] 



  1%|          | 0.07427536231884056/10 [00:09<21:16, 128.62s/it] 



  1%|          | 0.11050724637681146/10 [00:14<22:16, 135.13s/it]



  1%|▏         | 0.1467391304347825/10 [00:18<21:34, 131.38s/it] 



  2%|▏         | 0.1829710144927537/10 [00:23<21:33, 131.71s/it] 



  2%|▏         | 0.21920289855072486/10 [00:28<20:28, 125.57s/it]



  3%|▎         | 0.255434782608696/10 [00:33<21:23, 131.71s/it]  



  3%|▎         | 0.2916666666666672/10 [00:37<20:48, 128.63s/it] 



  3%|▎         | 0.32789855072463836/10 [00:42<21:07, 131.05s/it]



  4%|▎         | 0.36413043478260954/10 [00:47<20:39, 128.60s/it]



  4%|▍         | 0.4003623188405807/10 [00:51<20:31, 128.30s/it] 



  4%|▍         | 0.4365942028985519/10 [00:56<20:16, 127.16s/it] 



  5%|▍         | 0.47282608695652306/10 [01:01<20:11, 127.15s/it]



  5%|▌         | 0.5090579710144942/10 [01:05<21:10, 133.82s/it] 



  5%|▌         | 0.5452898550724654/10 [01:10<19:53, 126.23s/it]



  6%|▌         | 0.5815217391304366/10 [01:15<20:16, 129.18s/it]



  6%|▌         | 0.6177536231884078/10 [01:20<20:38, 131.97s/it]



  7%|▋         | 0.6539855072463789/10 [01:24<21:08, 135.68s/it]



  7%|▋         | 0.6902173913043501/10 [01:29<20:27, 131.90s/it]



  7%|▋         | 0.7264492753623213/10 [01:34<19:58, 129.21s/it]



  8%|▊         | 0.7626811594202925/10 [01:39<19:50, 128.88s/it]



  8%|▊         | 0.7989130434782636/10 [01:43<19:26, 126.73s/it]



  8%|▊         | 0.8351449275362348/10 [01:48<19:34, 128.19s/it]



  9%|▊         | 0.871376811594206/10 [01:53<19:11, 126.14s/it] 



  9%|▉         | 0.8822463768115973/10 [01:54<18:41, 122.97s/it]

Epoch [2/10], Accuracy: 0.1312
Found better model at epoch 2 with accuracy 0.1312. Saving model...


  9%|▉         | 0.8840579710144959/10 [02:07<5:40:46, 2242.99s/it]



  9%|▉         | 0.9202898550724671/10 [02:12<19:15, 127.23s/it]   



 10%|▉         | 0.9565217391304383/10 [02:16<19:08, 127.00s/it]



 10%|▉         | 0.9927536231884094/10 [02:21<20:04, 133.67s/it]



 10%|█         | 1.0289855072463787/10 [02:26<19:34, 130.90s/it]



 11%|█         | 1.0652173913043477/10 [02:30<19:41, 132.20s/it]



 11%|█         | 1.1014492753623166/10 [02:35<19:45, 133.27s/it]



 11%|█▏        | 1.1376811594202856/10 [02:40<18:55, 128.13s/it]



 12%|█▏        | 1.1739130434782545/10 [02:45<18:24, 125.17s/it]



 12%|█▏        | 1.2101449275362235/10 [02:49<19:57, 136.24s/it]



 12%|█▏        | 1.2463768115941924/10 [02:54<19:09, 131.34s/it]



 13%|█▎        | 1.2826086956521614/10 [02:59<18:57, 130.45s/it]



 13%|█▎        | 1.3188405797101304/10 [03:03<17:22, 120.12s/it]



 14%|█▎        | 1.3550724637680993/10 [03:08<18:49, 130.62s/it]



 14%|█▍        | 1.3913043478260683/10 [03:13<18:13, 127.04s/it]



 14%|█▍        | 1.4275362318840372/10 [03:18<18:05, 126.62s/it]



 15%|█▍        | 1.4637681159420062/10 [03:22<18:07, 127.41s/it]



 15%|█▍        | 1.4999999999999751/10 [03:27<18:12, 128.58s/it]



 15%|█▌        | 1.536231884057944/10 [03:32<18:30, 131.24s/it] 



 16%|█▌        | 1.572463768115913/10 [03:36<18:16, 130.09s/it] 



 16%|█▌        | 1.608695652173882/10 [03:41<17:39, 126.22s/it] 



 16%|█▋        | 1.644927536231851/10 [03:46<20:48, 149.37s/it] 



 17%|█▋        | 1.68115942028982/10 [03:51<18:13, 131.46s/it]  



 17%|█▋        | 1.7173913043477889/10 [03:55<17:29, 126.74s/it]



 18%|█▊        | 1.7536231884057578/10 [04:00<18:13, 132.56s/it]



 18%|█▊        | 1.7644927536231485/10 [04:01<16:44, 122.03s/it]

Epoch [3/10], Accuracy: 0.1389
Found better model at epoch 3 with accuracy 0.1389. Saving model...


 18%|█▊        | 1.766304347826047/10 [04:15<5:13:44, 2286.26s/it]



 18%|█▊        | 1.802536231884016/10 [04:19<18:06, 132.53s/it]    



 18%|█▊        | 1.8387681159419849/10 [04:24<17:00, 125.06s/it]



 19%|█▊        | 1.8749999999999538/10 [04:29<18:02, 133.28s/it]



 19%|█▉        | 1.9112318840579228/10 [04:33<16:57, 125.84s/it]



 19%|█▉        | 1.9474637681158917/10 [04:38<17:14, 128.42s/it]



 20%|█▉        | 1.9836956521738607/10 [04:43<17:42, 132.57s/it]



 20%|██        | 2.0199275362318296/10 [04:47<17:45, 133.56s/it]



 21%|██        | 2.0561594202897986/10 [04:52<16:38, 125.66s/it]



 21%|██        | 2.0923913043477675/10 [04:57<16:27, 124.94s/it]



 21%|██▏       | 2.1286231884057365/10 [05:01<17:09, 130.84s/it]



 22%|██▏       | 2.1648550724637055/10 [05:06<16:21, 125.27s/it]



 22%|██▏       | 2.2010869565216744/10 [05:10<16:21, 125.90s/it]



 22%|██▏       | 2.2373188405796434/10 [05:15<16:45, 129.52s/it]



 23%|██▎       | 2.2735507246376123/10 [05:20<16:25, 127.49s/it]



 23%|██▎       | 2.3097826086955813/10 [05:24<16:42, 130.35s/it]



 23%|██▎       | 2.3460144927535502/10 [05:29<16:42, 130.99s/it]



 24%|██▍       | 2.382246376811519/10 [05:34<16:32, 130.24s/it] 



 24%|██▍       | 2.418478260869488/10 [05:39<15:55, 125.96s/it] 



 25%|██▍       | 2.454710144927457/10 [05:43<15:51, 126.14s/it] 



 25%|██▍       | 2.490942028985426/10 [05:48<15:59, 127.82s/it] 



 25%|██▌       | 2.527173913043395/10 [05:53<16:22, 131.54s/it] 



 26%|██▌       | 2.563405797101364/10 [05:57<15:36, 125.89s/it] 



 26%|██▌       | 2.599637681159333/10 [06:02<16:15, 131.84s/it] 



 26%|██▋       | 2.635869565217302/10 [06:07<15:58, 130.12s/it] 



 26%|██▋       | 2.6467391304346926/10 [06:08<15:16, 124.70s/it]

Epoch [4/10], Accuracy: 0.1318


 26%|██▋       | 2.648550724637591/10 [06:21<4:39:29, 2281.08s/it]



 27%|██▋       | 2.68478260869556/10 [06:26<16:46, 137.54s/it]     



 27%|██▋       | 2.721014492753529/10 [06:31<16:28, 135.76s/it] 



 28%|██▊       | 2.757246376811498/10 [06:36<15:44, 130.47s/it] 



 28%|██▊       | 2.793478260869467/10 [06:40<15:34, 129.63s/it] 



 28%|██▊       | 2.8297101449274358/10 [06:45<16:18, 136.45s/it]



 29%|██▊       | 2.8659420289854047/10 [06:50<15:22, 129.25s/it]



 29%|██▉       | 2.9021739130433737/10 [06:54<15:47, 133.54s/it]



 29%|██▉       | 2.9384057971013426/10 [06:59<15:04, 128.05s/it]



 30%|██▉       | 2.9746376811593116/10 [07:04<15:42, 134.16s/it]



 30%|███       | 3.0108695652172806/10 [07:08<15:02, 129.15s/it]



 30%|███       | 3.0471014492752495/10 [07:13<14:10, 122.25s/it]



 31%|███       | 3.0833333333332185/10 [07:18<14:39, 127.18s/it]



 31%|███       | 3.1195652173911874/10 [07:23<14:27, 126.06s/it]



 32%|███▏      | 3.1557971014491564/10 [07:27<15:21, 134.58s/it]



 32%|███▏      | 3.1920289855071253/10 [07:32<15:13, 134.23s/it]



 32%|███▏      | 3.2282608695650943/10 [07:37<14:24, 127.67s/it]



 33%|███▎      | 3.2644927536230632/10 [07:41<14:42, 131.03s/it]



 33%|███▎      | 3.300724637681032/10 [07:46<14:35, 130.62s/it] 



 33%|███▎      | 3.336956521739001/10 [07:51<14:50, 133.62s/it] 



 34%|███▎      | 3.37318840579697/10 [07:56<14:20, 129.79s/it]  



 34%|███▍      | 3.409420289854939/10 [08:00<14:22, 130.88s/it] 



 34%|███▍      | 3.445652173912908/10 [08:05<13:45, 126.02s/it] 



 35%|███▍      | 3.481884057970877/10 [08:10<14:15, 131.22s/it] 



 35%|███▌      | 3.518115942028846/10 [08:14<14:02, 130.05s/it] 



 35%|███▌      | 3.5289855072462366/10 [08:16<14:49, 137.44s/it]

Epoch [5/10], Accuracy: 0.1219


 35%|███▌      | 3.530797101449135/10 [08:29<4:04:04, 2263.67s/it]



 36%|███▌      | 3.567028985507104/10 [08:34<14:09, 132.06s/it]    



 36%|███▌      | 3.603260869565073/10 [08:38<13:21, 125.25s/it] 



 36%|███▋      | 3.639492753623042/10 [08:43<13:33, 127.94s/it] 



 37%|███▋      | 3.675724637681011/10 [08:48<13:34, 128.74s/it] 



 37%|███▋      | 3.71195652173898/10 [08:52<14:33, 138.93s/it]  



 37%|███▋      | 3.748188405796949/10 [08:57<13:07, 125.93s/it] 



 38%|███▊      | 3.7844202898549177/10 [09:02<13:51, 133.78s/it]



 38%|███▊      | 3.8206521739128867/10 [09:07<14:27, 140.35s/it]



 39%|███▊      | 3.8568840579708557/10 [09:11<13:07, 128.22s/it]



 39%|███▉      | 3.8931159420288246/10 [09:16<13:28, 132.46s/it]



 39%|███▉      | 3.9293478260867936/10 [09:20<12:47, 126.51s/it]



 40%|███▉      | 3.9655797101447625/10 [09:25<12:42, 126.29s/it]



 40%|████      | 4.0018115942027315/10 [09:30<13:00, 130.05s/it]



 40%|████      | 4.0380434782607/10 [09:35<12:21, 124.32s/it]   



 41%|████      | 4.074275362318669/10 [09:39<12:34, 127.33s/it] 



 41%|████      | 4.110507246376638/10 [09:44<12:44, 129.82s/it] 



 41%|████▏     | 4.146739130434607/10 [09:49<12:59, 133.20s/it] 



 42%|████▏     | 4.182971014492576/10 [09:53<12:23, 127.90s/it] 



 42%|████▏     | 4.219202898550545/10 [09:58<12:22, 128.40s/it] 



 43%|████▎     | 4.255434782608514/10 [10:03<13:15, 138.47s/it] 



 43%|████▎     | 4.291666666666483/10 [10:07<12:23, 130.28s/it] 



 43%|████▎     | 4.327898550724452/10 [10:12<12:27, 131.76s/it] 



 44%|████▎     | 4.364130434782421/10 [10:17<12:08, 129.34s/it] 



 44%|████▍     | 4.40036231884039/10 [10:22<12:19, 132.07s/it]  



 44%|████▍     | 4.411231884057781/10 [10:23<11:21, 121.88s/it]

Epoch [6/10], Accuracy: 0.1309


 44%|████▍     | 4.413043478260679/10 [10:36<3:29:25, 2249.15s/it]



 44%|████▍     | 4.449275362318648/10 [10:41<12:01, 130.00s/it]   



 45%|████▍     | 4.485507246376617/10 [10:45<11:43, 127.53s/it] 



 45%|████▌     | 4.521739130434586/10 [10:50<12:06, 132.62s/it] 



 46%|████▌     | 4.557971014492555/10 [10:55<11:22, 125.39s/it] 



 46%|████▌     | 4.594202898550524/10 [10:59<11:24, 126.65s/it] 



 46%|████▋     | 4.630434782608493/10 [11:04<11:05, 123.92s/it] 



 47%|████▋     | 4.666666666666462/10 [11:09<11:18, 127.15s/it] 



 47%|████▋     | 4.702898550724431/10 [11:13<11:47, 133.59s/it] 



 47%|████▋     | 4.7391304347824/10 [11:18<11:38, 132.68s/it]   



 48%|████▊     | 4.775362318840369/10 [11:23<11:17, 129.62s/it] 



 48%|████▊     | 4.811594202898338/10 [11:28<10:59, 127.03s/it] 



 48%|████▊     | 4.847826086956307/10 [11:32<10:57, 127.53s/it] 



 49%|████▉     | 4.8840579710142755/10 [11:37<11:39, 136.71s/it]



 49%|████▉     | 4.9202898550722445/10 [11:42<11:22, 134.38s/it]



 50%|████▉     | 4.956521739130213/10 [11:46<10:43, 127.62s/it] 



 50%|████▉     | 4.992753623188182/10 [11:51<10:58, 131.57s/it] 



 50%|█████     | 5.028985507246151/10 [11:56<10:52, 131.17s/it] 



 51%|█████     | 5.06521739130412/10 [12:01<10:33, 128.28s/it]  



 51%|█████     | 5.101449275362089/10 [12:05<10:22, 127.10s/it] 



 51%|█████▏    | 5.137681159420058/10 [12:10<11:10, 137.98s/it] 



 52%|█████▏    | 5.173913043478027/10 [12:15<10:39, 132.41s/it] 



 52%|█████▏    | 5.210144927535996/10 [12:20<10:03, 125.96s/it] 



 52%|█████▏    | 5.246376811593965/10 [12:24<10:09, 128.11s/it] 



 53%|█████▎    | 5.282608695651934/10 [12:29<09:51, 125.47s/it] 



 53%|█████▎    | 5.293478260869325/10 [12:30<09:40, 123.44s/it] 

Epoch [7/10], Accuracy: 0.1268


 53%|█████▎    | 5.295289855072223/10 [12:44<2:58:18, 2273.93s/it]



 53%|█████▎    | 5.331521739130192/10 [12:48<09:58, 128.25s/it]   



 54%|█████▎    | 5.367753623188161/10 [12:53<09:53, 128.21s/it] 



 54%|█████▍    | 5.40398550724613/10 [12:57<09:45, 127.30s/it]  



 54%|█████▍    | 5.440217391304099/10 [13:02<10:36, 139.50s/it] 



 55%|█████▍    | 5.476449275362068/10 [13:07<09:47, 129.93s/it] 



 55%|█████▌    | 5.512681159420037/10 [13:12<09:32, 127.50s/it] 



 55%|█████▌    | 5.548913043478006/10 [13:16<09:34, 129.11s/it] 



 56%|█████▌    | 5.585144927535975/10 [13:21<09:26, 128.34s/it] 



 56%|█████▌    | 5.621376811593944/10 [13:26<09:19, 127.86s/it] 



 57%|█████▋    | 5.657608695651913/10 [13:31<09:30, 131.31s/it] 



 57%|█████▋    | 5.693840579709882/10 [13:35<08:39, 120.70s/it] 



 57%|█████▋    | 5.730072463767851/10 [13:40<09:22, 131.76s/it] 



 58%|█████▊    | 5.76630434782582/10 [13:45<09:32, 135.29s/it]  



 58%|█████▊    | 5.8025362318837885/10 [13:49<08:42, 124.50s/it]



 58%|█████▊    | 5.8387681159417575/10 [13:54<08:58, 129.52s/it]



 59%|█████▊    | 5.874999999999726/10 [13:59<08:40, 126.19s/it] 



 59%|█████▉    | 5.911231884057695/10 [14:03<08:54, 130.75s/it] 



 59%|█████▉    | 5.947463768115664/10 [14:08<08:43, 129.12s/it] 



 60%|█████▉    | 5.983695652173633/10 [14:13<08:20, 124.55s/it] 



 60%|██████    | 6.019927536231602/10 [14:17<08:39, 130.61s/it] 



 61%|██████    | 6.056159420289571/10 [14:22<08:12, 125.00s/it] 



 61%|██████    | 6.09239130434754/10 [14:27<08:13, 126.24s/it]  



 61%|██████▏   | 6.128623188405509/10 [14:31<08:18, 128.71s/it] 



 62%|██████▏   | 6.164855072463478/10 [14:36<08:02, 125.76s/it] 



 62%|██████▏   | 6.175724637680869/10 [14:37<07:31, 118.00s/it] 

Epoch [8/10], Accuracy: 0.1134


 62%|██████▏   | 6.177536231883767/10 [14:50<2:23:23, 2250.66s/it]



 62%|██████▏   | 6.213768115941736/10 [14:55<08:26, 133.83s/it]   



 62%|██████▏   | 6.249999999999705/10 [15:00<08:04, 129.19s/it] 



 63%|██████▎   | 6.286231884057674/10 [15:04<07:57, 128.55s/it] 



 63%|██████▎   | 6.322463768115643/10 [15:09<07:58, 130.22s/it] 



 64%|██████▎   | 6.358695652173612/10 [15:14<08:11, 134.91s/it] 



 64%|██████▍   | 6.394927536231581/10 [15:19<07:41, 127.91s/it] 



 64%|██████▍   | 6.43115942028955/10 [15:24<07:42, 129.57s/it]  



 65%|██████▍   | 6.467391304347519/10 [15:28<07:42, 130.99s/it] 



 65%|██████▌   | 6.503623188405488/10 [15:33<07:19, 125.84s/it] 



 65%|██████▌   | 6.539855072463457/10 [15:38<07:33, 131.09s/it] 



 66%|██████▌   | 6.576086956521426/10 [15:42<07:14, 126.89s/it] 



 66%|██████▌   | 6.612318840579395/10 [15:47<06:50, 121.22s/it] 



 66%|██████▋   | 6.648550724637364/10 [15:51<07:02, 126.19s/it] 



 67%|██████▋   | 6.684782608695333/10 [15:56<06:55, 125.35s/it] 



 67%|██████▋   | 6.7210144927533015/10 [16:01<06:59, 128.00s/it]



 68%|██████▊   | 6.7572463768112705/10 [16:05<07:03, 130.66s/it]



 68%|██████▊   | 6.7934782608692394/10 [16:10<06:56, 129.78s/it]



 68%|██████▊   | 6.829710144927208/10 [16:15<06:39, 126.02s/it] 



 69%|██████▊   | 6.865942028985177/10 [16:20<06:47, 130.07s/it] 



 69%|██████▉   | 6.902173913043146/10 [16:24<06:28, 125.37s/it] 



 69%|██████▉   | 6.938405797101115/10 [16:29<06:56, 135.88s/it] 



 70%|██████▉   | 6.974637681159084/10 [16:34<06:17, 124.84s/it] 



 70%|███████   | 7.010869565217053/10 [16:38<06:41, 134.19s/it] 



 70%|███████   | 7.047101449275022/10 [16:43<06:20, 128.82s/it] 



 71%|███████   | 7.057971014492413/10 [16:44<06:11, 126.25s/it] 

Epoch [9/10], Accuracy: 0.1257


 71%|███████   | 7.059782608695311/10 [16:57<1:50:16, 2250.26s/it]



 71%|███████   | 7.09601449275328/10 [17:02<06:35, 136.07s/it]    



 71%|███████▏  | 7.132246376811249/10 [17:07<06:01, 126.18s/it] 



 72%|███████▏  | 7.168478260869218/10 [17:12<06:05, 128.92s/it] 



 72%|███████▏  | 7.204710144927187/10 [17:16<06:01, 129.38s/it] 



 72%|███████▏  | 7.240942028985156/10 [17:21<05:59, 130.24s/it] 



 73%|███████▎  | 7.277173913043125/10 [17:26<06:03, 133.58s/it] 



 73%|███████▎  | 7.313405797101094/10 [17:30<05:46, 128.85s/it] 



 73%|███████▎  | 7.349637681159063/10 [17:35<05:42, 129.28s/it] 



 74%|███████▍  | 7.385869565217032/10 [17:40<05:57, 136.87s/it] 



 74%|███████▍  | 7.422101449275001/10 [17:45<05:34, 129.80s/it] 



 75%|███████▍  | 7.45833333333297/10 [17:49<05:47, 136.80s/it]  



 75%|███████▍  | 7.494565217390939/10 [17:54<05:14, 125.40s/it] 



 75%|███████▌  | 7.530797101448908/10 [17:59<05:21, 130.24s/it] 



 76%|███████▌  | 7.567028985506877/10 [18:03<05:17, 130.63s/it] 



 76%|███████▌  | 7.603260869564846/10 [18:08<05:22, 134.45s/it] 



 76%|███████▋  | 7.6394927536228145/10 [18:13<05:06, 129.67s/it]



 77%|███████▋  | 7.6757246376807835/10 [18:17<04:53, 126.37s/it]



 77%|███████▋  | 7.7119565217387525/10 [18:22<05:04, 133.02s/it]



 77%|███████▋  | 7.748188405796721/10 [18:27<04:47, 127.90s/it] 



 78%|███████▊  | 7.78442028985469/10 [18:32<04:53, 132.35s/it]  



 78%|███████▊  | 7.820652173912659/10 [18:36<04:34, 126.05s/it] 



 79%|███████▊  | 7.856884057970628/10 [18:41<04:39, 130.33s/it] 



 79%|███████▉  | 7.893115942028597/10 [18:46<04:26, 126.69s/it] 



 79%|███████▉  | 7.929347826086566/10 [18:50<04:24, 127.50s/it] 



 79%|███████▉  | 7.940217391303957/10 [18:52<04:15, 124.06s/it] 

Epoch [10/10], Accuracy: 0.1296


 79%|███████▉  | 7.942028985506855/10 [19:05<1:20:07, 2336.03s/it]



 80%|███████▉  | 7.978260869564824/10 [19:10<04:38, 137.61s/it]   



 80%|████████  | 8.014492753622793/10 [19:15<04:19, 130.86s/it] 



 81%|████████  | 8.050724637680762/10 [19:19<04:15, 130.92s/it]



 81%|████████  | 8.086956521738731/10 [19:24<04:05, 128.54s/it]



 81%|████████  | 8.1231884057967/10 [19:29<04:00, 128.27s/it]  



 82%|████████▏ | 8.159420289854669/10 [19:33<04:08, 134.96s/it]



 82%|████████▏ | 8.195652173912638/10 [19:38<04:00, 133.26s/it]



 82%|████████▏ | 8.231884057970607/10 [19:43<03:50, 130.45s/it]



 83%|████████▎ | 8.268115942028576/10 [19:48<03:41, 128.16s/it]



 83%|████████▎ | 8.304347826086545/10 [19:53<03:48, 135.04s/it]



 83%|████████▎ | 8.340579710144514/10 [19:57<03:29, 126.12s/it]



 84%|████████▍ | 8.376811594202483/10 [20:02<03:25, 126.35s/it]



 84%|████████▍ | 8.413043478260452/10 [20:06<03:15, 123.50s/it]



 84%|████████▍ | 8.44927536231842/10 [20:11<03:22, 130.54s/it] 



 85%|████████▍ | 8.48550724637639/10 [20:16<03:14, 128.14s/it] 



 85%|████████▌ | 8.521739130434359/10 [20:21<03:23, 137.65s/it]



 86%|████████▌ | 8.557971014492328/10 [20:25<03:05, 128.77s/it]



 86%|████████▌ | 8.594202898550297/10 [20:30<02:56, 125.82s/it]



 86%|████████▋ | 8.630434782608265/10 [20:35<03:06, 136.10s/it]



 87%|████████▋ | 8.666666666666234/10 [20:40<02:55, 131.76s/it]



 87%|████████▋ | 8.702898550724203/10 [20:44<02:46, 128.01s/it]



 87%|████████▋ | 8.739130434782172/10 [20:49<02:39, 126.51s/it]



 88%|████████▊ | 8.775362318840141/10 [20:54<02:47, 136.90s/it]



 88%|████████▊ | 8.81159420289811/10 [20:58<02:28, 124.60s/it] 



 88%|████████▊ | 8.822463768115501/10 [21:13<02:49, 144.32s/it]

Epoch [11/10], Accuracy: 0.1238





In [45]:
# 测试模型
model.load_state_dict(torch.load('best_model_accuracy.pt'))

model.eval()  # 设置模型为评估模式
model.to(device)
with torch.no_grad():
    correct = 0
    total = 0
    for features, labels in test_loader:
        features=features.to(device)
        labels=labels.to(device)
        features = features.view(features.size(0), -1)
        outputs = model(features)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    print(f'Accuracy of the model on the test set: {100 * correct / total} %')


Accuracy of the model on the test set: 9.544679814950085 %
