In [10]:
import os
import sys
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import iris_dataloader


# 初始化神经网络模型

class NN(nn.Module):
    """Three-layer fully connected classifier.

    Args:
        in_dim: number of input features per sample.
        hidden_dim1: width of the first hidden layer.
        hidden_dim2: width of the second hidden layer.
        out_dim: number of output scores (one per class).
    """

    # Store the input, hidden and output layer sizes in the model.
    def __init__(self, in_dim, hidden_dim1, hidden_dim2, out_dim) -> None:
        super().__init__()
        self.layer1 = nn.Linear(in_dim, hidden_dim1)
        self.layer2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.layer3 = nn.Linear(hidden_dim2, out_dim)

    def forward(self, x):
        """Forward pass: map a batch of feature vectors to per-class scores.

        The output of ``layer3`` is returned with no final activation
        (``main`` feeds it to ``nn.CrossEntropyLoss``, which expects logits).
        """
        # A ReLU after each hidden layer: without a non-linearity the three
        # Linear layers compose into one affine map and the extra depth is
        # useless. ReLU has no parameters, so state_dict keys are unchanged.
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.layer3(x)

    # 定义计算环境


# Compute device: the first CUDA GPU when available, otherwise the CPU.
# ("cuda" is the device type; a misspelled type string makes torch.device raise.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the full dataset; iris_dataloader comes from the local dataloader module.
# NOTE(review): the TypeError "Could not convert [...] to numeric" recorded with
# this cell shows every row of iris.txt concatenated into one string, which
# suggests the file is being parsed as a single text column inside
# iris_dataloader (separator / header handling?) — confirm in dataloader.py.
custom_dataset = iris_dataloader("D:\\Pycharm\\PyCharm 2024.3.1.1\\Pytorch实战\\iris.txt")

# Split into training / validation / test sets (70% / 20% / remainder).
dataset_len = len(custom_dataset)
train_size = int(dataset_len * 0.7)
val_size = int(dataset_len * 0.2)
# The test split takes whatever is left so the three sizes always sum to the
# dataset length; random_split raises when integer lengths do not add up.
test_size = dataset_len - train_size - val_size

# random_split randomly partitions the dataset into non-overlapping subsets
# with the given lengths.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(custom_dataset,
                                                                         [train_size, val_size, test_size])

# shuffle=True reshuffles the training samples at the start of every epoch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print("训练集的大小", train_size, "验证集的大小", val_size, "测试集的大小", test_size)


# 定义一个推理函数，来计算并返回准确率

def infer(model, dataset, device):
    """Return the classification accuracy of *model* over *dataset*.

    Args:
        model: network mapping a feature batch to scores shaped
            (batch, num_classes); dimension 1 is taken as the class axis.
        dataset: iterable of ``(features, labels)`` batches, e.g. a DataLoader.
            Labels may be shaped (batch,) or (batch, 1); they are flattened.
        device: device that features and labels are moved to.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, as a
        float in [0, 1]; 0.0 when *dataset* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Evaluation only: no gradients are tracked and parameters are not changed.
    with torch.no_grad():
        for features, labels in dataset:
            outputs = model(features.to(device))
            # Dimension 0 is the batch, dimension 1 holds one score per class;
            # the index of the highest score is the predicted class.
            predicted = outputs.argmax(dim=1)
            # Flatten labels to (batch,) so the comparison is element-wise; a
            # (batch, 1) label would broadcast against (batch,) predictions
            # into a (batch, batch) matrix and over-count matches.
            labels = labels.to(device).reshape(-1)
            correct += (predicted == labels).sum().item()
            # Count samples, not batches: len() of a DataLoader is its batch count.
            total += labels.numel()

    return correct / total if total else 0.0


def _train_one_epoch(model, loader, loss_f, optimizer, device, epoch, epochs):
    """Run one optimisation pass over *loader*.

    Returns:
        ``(last_loss, train_acc)``: the loss of the final batch as a float
        (NaN if *loader* is empty) and the fraction of training samples the
        model classified correctly during the pass.
    """
    model.train()
    correct = 0
    seen = 0
    last_loss = float("nan")

    train_bar = tqdm(loader, file=sys.stdout, ncols=100)
    for data, label in train_bar:
        # Drop the trailing size-1 dimension of the labels (presumably
        # (batch, 1) -> (batch,)), the class-index shape CrossEntropyLoss expects.
        label = label.squeeze(-1).to(device)
        data = data.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(data)
        loss = loss_f(outputs, label)
        loss.backward()
        optimizer.step()

        # Accumulate correct predictions over the whole epoch.
        correct += (outputs.argmax(dim=1) == label).sum().item()
        seen += data.shape[0]
        last_loss = loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, last_loss)

    train_acc = correct / seen if seen else 0.0
    return last_loss, train_acc


def main(lr=0.005, epochs=20):
    """Train ``NN`` on the training split, validate each epoch, then test.

    The model's state_dict is written to ``<cwd>/results/weights/nn.pth``
    after every epoch (overwriting the previous file).

    Args:
        lr: Adam learning rate.
        epochs: number of passes over the training set.
    """
    model = NN(4, 12, 6, 3).to(device)
    loss_f = nn.CrossEntropyLoss()

    # Optimise only parameters that require gradients.
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(pg, lr=lr)

    # Weight directory under the current working directory.
    save_path = os.path.join(os.getcwd(), "results/weights")
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(epochs):
        loss, train_acc = _train_one_epoch(model, train_loader, loss_f, optimizer,
                                           device, epoch, epochs)
        val_acc = infer(model, val_loader, device)
        print("train epoch[{}/{}] loss:{:.3f} train_acc:{:.3f} val_acc:{:.3f}".format(
            epoch + 1, epochs, loss, train_acc, val_acc))
        torch.save(model.state_dict(), os.path.join(save_path, "nn.pth"))
    print("Finished Training!")

    test_acc = infer(model, test_loader, device)
    print("test_acc", test_acc)


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()


TypeError: Could not convert ['1 5.1 3.5 1.4 0.2 "setosa"2 4.9 3 1.4 0.2 "setosa"3 4.7 3.2 1.3 0.2 "setosa"4 4.6 3.1 1.5 0.2 "setosa"5 5 3.6 1.4 0.2 "setosa"6 5.4 3.9 1.7 0.4 "setosa"7 4.6 3.4 1.4 0.3 "setosa"8 5 3.4 1.5 0.2 "setosa"9 4.4 2.9 1.4 0.2 "setosa"10 4.9 3.1 1.5 0.1 "setosa"11 5.4 3.7 1.5 0.2 "setosa"12 4.8 3.4 1.6 0.2 "setosa"13 4.8 3 1.4 0.1 "setosa"14 4.3 3 1.1 0.1 "setosa"15 5.8 4 1.2 0.2 "setosa"16 5.7 4.4 1.5 0.4 "setosa"17 5.4 3.9 1.3 0.4 "setosa"18 5.1 3.5 1.4 0.3 "setosa"19 5.7 3.8 1.7 0.3 "setosa"20 5.1 3.8 1.5 0.3 "setosa"21 5.4 3.4 1.7 0.2 "setosa"22 5.1 3.7 1.5 0.4 "setosa"23 4.6 3.6 1 0.2 "setosa"24 5.1 3.3 1.7 0.5 "setosa"25 4.8 3.4 1.9 0.2 "setosa"26 5 3 1.6 0.2 "setosa"27 5 3.4 1.6 0.4 "setosa"28 5.2 3.5 1.5 0.2 "setosa"29 5.2 3.4 1.4 0.2 "setosa"30 4.7 3.2 1.6 0.2 "setosa"31 4.8 3.1 1.6 0.2 "setosa"32 5.4 3.4 1.5 0.4 "setosa"33 5.2 4.1 1.5 0.1 "setosa"34 5.5 4.2 1.4 0.2 "setosa"35 4.9 3.1 1.5 0.2 "setosa"36 5 3.2 1.2 0.2 "setosa"37 5.5 3.5 1.3 0.2 "setosa"38 4.9 3.6 1.4 0.1 "setosa"39 4.4 3 1.3 0.2 "setosa"40 5.1 3.4 1.5 0.2 "setosa"41 5 3.5 1.3 0.3 "setosa"42 4.5 2.3 1.3 0.3 "setosa"43 4.4 3.2 1.3 0.2 "setosa"44 5 3.5 1.6 0.6 "setosa"45 5.1 3.8 1.9 0.4 "setosa"46 4.8 3 1.4 0.3 "setosa"47 5.1 3.8 1.6 0.2 "setosa"48 4.6 3.2 1.4 0.2 "setosa"49 5.3 3.7 1.5 0.2 "setosa"50 5 3.3 1.4 0.2 "setosa"51 7 3.2 4.7 1.4 "versicolor"52 6.4 3.2 4.5 1.5 "versicolor"53 6.9 3.1 4.9 1.5 "versicolor"54 5.5 2.3 4 1.3 "versicolor"55 6.5 2.8 4.6 1.5 "versicolor"56 5.7 2.8 4.5 1.3 "versicolor"57 6.3 3.3 4.7 1.6 "versicolor"58 4.9 2.4 3.3 1 "versicolor"59 6.6 2.9 4.6 1.3 "versicolor"60 5.2 2.7 3.9 1.4 "versicolor"61 5 2 3.5 1 "versicolor"62 5.9 3 4.2 1.5 "versicolor"63 6 2.2 4 1 "versicolor"64 6.1 2.9 4.7 1.4 "versicolor"65 5.6 2.9 3.6 1.3 "versicolor"66 6.7 3.1 4.4 1.4 "versicolor"67 5.6 3 4.5 1.5 "versicolor"68 5.8 2.7 4.1 1 "versicolor"69 6.2 2.2 4.5 1.5 "versicolor"70 5.6 2.5 3.9 1.1 "versicolor"71 5.9 3.2 4.8 1.8 "versicolor"72 6.1 2.8 4 1.3 "versicolor"73 
6.3 2.5 4.9 1.5 "versicolor"74 6.1 2.8 4.7 1.2 "versicolor"75 6.4 2.9 4.3 1.3 "versicolor"76 6.6 3 4.4 1.4 "versicolor"77 6.8 2.8 4.8 1.4 "versicolor"78 6.7 3 5 1.7 "versicolor"79 6 2.9 4.5 1.5 "versicolor"80 5.7 2.6 3.5 1 "versicolor"81 5.5 2.4 3.8 1.1 "versicolor"82 5.5 2.4 3.7 1 "versicolor"83 5.8 2.7 3.9 1.2 "versicolor"84 6 2.7 5.1 1.6 "versicolor"85 5.4 3 4.5 1.5 "versicolor"86 6 3.4 4.5 1.6 "versicolor"87 6.7 3.1 4.7 1.5 "versicolor"88 6.3 2.3 4.4 1.3 "versicolor"89 5.6 3 4.1 1.3 "versicolor"90 5.5 2.5 4 1.3 "versicolor"91 5.5 2.6 4.4 1.2 "versicolor"92 6.1 3 4.6 1.4 "versicolor"93 5.8 2.6 4 1.2 "versicolor"94 5 2.3 3.3 1 "versicolor"95 5.6 2.7 4.2 1.3 "versicolor"96 5.7 3 4.2 1.2 "versicolor"97 5.7 2.9 4.2 1.3 "versicolor"98 6.2 2.9 4.3 1.3 "versicolor"99 5.1 2.5 3 1.1 "versicolor"100 5.7 2.8 4.1 1.3 "versicolor"101 6.3 3.3 6 2.5 "virginica"102 5.8 2.7 5.1 1.9 "virginica"103 7.1 3 5.9 2.1 "virginica"104 6.3 2.9 5.6 1.8 "virginica"105 6.5 3 5.8 2.2 "virginica"106 7.6 3 6.6 2.1 "virginica"107 4.9 2.5 4.5 1.7 "virginica"108 7.3 2.9 6.3 1.8 "virginica"109 6.7 2.5 5.8 1.8 "virginica"110 7.2 3.6 6.1 2.5 "virginica"111 6.5 3.2 5.1 2 "virginica"112 6.4 2.7 5.3 1.9 "virginica"113 6.8 3 5.5 2.1 "virginica"114 5.7 2.5 5 2 "virginica"115 5.8 2.8 5.1 2.4 "virginica"116 6.4 3.2 5.3 2.3 "virginica"117 6.5 3 5.5 1.8 "virginica"118 7.7 3.8 6.7 2.2 "virginica"119 7.7 2.6 6.9 2.3 "virginica"120 6 2.2 5 1.5 "virginica"121 6.9 3.2 5.7 2.3 "virginica"122 5.6 2.8 4.9 2 "virginica"123 7.7 2.8 6.7 2 "virginica"124 6.3 2.7 4.9 1.8 "virginica"125 6.7 3.3 5.7 2.1 "virginica"126 7.2 3.2 6 1.8 "virginica"127 6.2 2.8 4.8 1.8 "virginica"128 6.1 3 4.9 1.8 "virginica"129 6.4 2.8 5.6 2.1 "virginica"130 7.2 3 5.8 1.6 "virginica"131 7.4 2.8 6.1 1.9 "virginica"132 7.9 3.8 6.4 2 "virginica"133 6.4 2.8 5.6 2.2 "virginica"134 6.3 2.8 5.1 1.5 "virginica"135 6.1 2.6 5.6 1.4 "virginica"136 7.7 3 6.1 2.3 "virginica"137 6.3 3.4 5.6 2.4 "virginica"138 6.4 3.1 5.5 1.8 "virginica"139 6 3 4.8 1.8 
"virginica"140 6.9 3.1 5.4 2.1 "virginica"141 6.7 3.1 5.6 2.4 "virginica"142 6.9 3.1 5.1 2.3 "virginica"143 5.8 2.7 5.1 1.9 "virginica"144 6.8 3.2 5.9 2.3 "virginica"145 6.7 3.3 5.7 2.5 "virginica"146 6.7 3 5.2 2.3 "virginica"147 6.3 2.5 5 1.9 "virginica"148 6.5 3 5.2 2 "virginica"149 6.2 3.4 5.4 2.3 "virginica"150 5.9 3 5.1 1.8 "virginica"'] to numeric