From fbab0824d49342a46580fee2c81aa29b0204cd8d Mon Sep 17 00:00:00 2001 From: Elaina <1463967532@qq.com> Date: Tue, 11 Nov 2025 12:53:31 +0000 Subject: [PATCH 1/6] =?UTF-8?q?=E7=A7=BB=E6=A4=8Dvisual=E5=8F=AF=E8=A7=86?= =?UTF-8?q?=E5=8C=96,=E6=9C=89numpy=E5=86=B2=E7=AA=81=EF=BC=8C=E6=89=93?= =?UTF-8?q?=E7=AE=97=E6=9B=B4=E6=8D=A2=E5=AE=B9=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/visual.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/visual.py b/src/visual.py index e69de29..029b110 100644 --- a/src/visual.py +++ b/src/visual.py @@ -0,0 +1,68 @@ +# visual.py +import matplotlib.pyplot as plt +import os,torch +class SaveAndVisual: + """模型管理类,负责模型保存和训练可视化""" + + def __init__(self, model_dir='models', loss_img_path='loss_curve.png'): + self.model_dir = model_dir + self.loss_img_path = loss_img_path + self.epoch_losses = [] # 存储每个epoch的损失 + self.epoch_indices = [] # 存储epoch索引 + self._init_visualization() + self._init_model_dir() + + def _init_model_dir(self): + """初始化模型保存目录""" + os.makedirs(self.model_dir, exist_ok=True) + + def _init_visualization(self): + """初始化可视化环境""" + plt.ion() # 开启交互模式 + self.fig, self.ax = plt.subplots(figsize=(10, 6)) + self.ax.set_xlabel("Epoch") + self.ax.set_ylabel("Loss") + self.ax.set_title("Training Loss (per Epoch)") + self.line, = self.ax.plot([], [], label="Epoch Loss") + self.ax.legend() + + def loadModel(self, model:torch.nn.Module, optimizer, device): + """加载已保存的模型""" + model_path = os.path.join(self.model_dir, 'best_transformer.pth') + if os.path.exists(model_path): + checkpoint = torch.load(model_path, map_location=device, weights_only=True) + model.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + print(f"发现现有模型,其损失为: {checkpoint['loss']:.4f}") + return checkpoint['loss'] + return float('inf') + + def saveModel(self, model, optimizer, epoch, loss): + """保存模型检查点""" + model_path = os.path.join(self.model_dir, 'best_transformer.pth') + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss, + }, model_path) + print(f" 保存最佳模型(损失:{loss:.4f})") + + def updateVisualization(self, epoch, loss): + """更新训练损失可视化""" + self.epoch_losses.append(loss) + self.epoch_indices.append(epoch) + + # 更新图像数据 + self.line.set_data(self.epoch_indices, self.epoch_losses) + self.ax.relim() # 重新计算坐标轴范围 + self.ax.autoscale_view() # 自动调整视图 + plt.draw() + plt.pause(0.01) + + def finalizeVisualization(self): + """训练结束后保存并显示最终图像""" + plt.ioff() # 关闭交互模式 + self.ax.set_title("Training Loss (Final)") + plt.savefig(self.loss_img_path) + plt.show() From 5f4d1b9dc74c84cf11ac636cd169aeb8e107b6a3 Mon Sep 17 00:00:00 2001 From: Elaina <1463967532@qq.com> Date: Tue, 11 Nov 2025 21:25:29 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E5=BA=95=E5=B1=82?= =?UTF-8?q?=E9=95=9C=E5=83=8F=E4=B8=BA=E6=97=A0ros=20=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E7=9A=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/cpu/Dockerfile | 25 ++++---------------- .devcontainer/cpu/build.bash | 2 +- .devcontainer/gpu/Dockerfile | 46 ++++-------------------------------- 3 files changed, 9 insertions(+), 64 deletions(-) diff --git a/.devcontainer/cpu/Dockerfile b/.devcontainer/cpu/Dockerfile index f44b0e6..731b4c8 100644 --- a/.devcontainer/cpu/Dockerfile +++ b/.devcontainer/cpu/Dockerfile @@ -1,23 +1,6 @@ -#FROM elainasuki/ros:ros2-humble-full-v3 -FROM elainasuki/ros:ros2-humble-full-0614 - -ARG USERNAME=Elaina - -ARG USER_GID=$USER_UID - -ARG GROUP_NAME=wheel -RUN apt-get update \ - && apt-get install -y sudo vim nautilus -USER ${USERNAME} -#安装前置依赖 +FROM ubuntu:22.04 +RUN apt-get update\ + && apt-get install -y git python3 python3-pip RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -# 安装yolo -RUN pip install imutils ultralytics openvino -#安装其他依赖 - -RUN pip install jupyter d2l==0.17.6 - -USER root -RUN echo "source /opt/ros/humble/setup.bash" >> /home/${USERNAME}/.bashrc -ENV PYTHONPATH="/home/Elaina/yolo:${PYTHONPATH}" +RUN pip install pandas openpyxl matplotlib \ No newline at end of file diff --git a/.devcontainer/cpu/build.bash b/.devcontainer/cpu/build.bash index 9416aa9..9d362a6 100755 --- a/.devcontainer/cpu/build.bash +++ b/.devcontainer/cpu/build.bash @@ -9,7 +9,7 @@ echo "脚本目录: $SCRIPT_DIR" echo "父目录: $PARENT_DIR" # 设置默认 tag -TAG="pytorch" +TAG="pytorch_cpu" # 从外部传入的 IMAGE_REPO(格式:ghcr.io/user/repo 或 docker.io/user/repo) IMAGE_REPO="elainasuki/other" diff --git a/.devcontainer/gpu/Dockerfile b/.devcontainer/gpu/Dockerfile index 9bfbd20..c8d96a8 100644 --- a/.devcontainer/gpu/Dockerfile +++ b/.devcontainer/gpu/Dockerfile @@ -1,44 +1,6 @@ -# 基础镜像:ROS 2 Humble(官方桌面版,基于Ubuntu 22.04) -FROM elainasuki/ros:ros2-humble-full-0614 +FROM pytorch/pytorch:2.9.0-cuda13.0-cudnn9-runtime -# 修复未定义变量 -ARG USERNAME=Elaina -ARG USER_UID=1000 -ARG USER_GID=$USER_UID +RUN pip3 install pandas openpyxl matplotlib -# 设置非交互模式+pip缓存目录(避免占用根目录空间) -ENV DEBIAN_FRONTEND=noninteractive -ENV PIP_CACHE_DIR=/tmp/pip-cache - -# 1. 精简基础依赖(仅保留必需项)+ 清理缓存 -RUN apt-get update && apt-get install -y --no-install-recommends \ - wget \ - ca-certificates \ - python3-pip \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* - -# 2. 仅安装CUDA运行时(极致精简) -RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ - && dpkg -i cuda-keyring_1.1-1_all.deb \ - && apt-get update \ - && apt-get install -y --no-install-recommends \ - cuda-runtime-12-1 \ - libcudnn8=8.9.2.26-1+cuda12.1 \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ - && rm -f cuda-keyring_1.1-1_all.deb - -# 3. 配置环境变量(修复未定义问题) -ENV PATH=/usr/local/cuda-12.1/bin:${PATH} -ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH:-/usr/local/lib} - -# 4. 优化PyTorch安装(节省空间) -RUN pip3 install --no-cache-dir --upgrade pip \ - && pip3 install --no-cache-dir --ignore-installed sympy \ - && pip3 install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 \ - && rm -rf $PIP_CACHE_DIR /tmp/* /var/tmp/* -#安装pandas 和openpyxl -RUN pip3 install pandas openpyxl -# 5. 配置ROS 2环境 -RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc \ No newline at end of file +RUN apt-get update &&\ + apt-get install -y git \ No newline at end of file From 897223ea11ee8ae1cd0b0285c9bcd807d8b39b1a Mon Sep 17 00:00:00 2001 From: Elaina <1463967532@qq.com> Date: Tue, 11 Nov 2025 14:06:51 +0000 Subject: [PATCH 3/6] =?UTF-8?q?=E6=9B=B4=E6=94=B9docker=20compose=20?= =?UTF-8?q?=E5=B9=B6=E6=B7=BB=E5=8A=A0gui=20=E6=B5=8B=E8=AF=95=E4=BE=8B?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/devcontainer.json | 2 +- .devcontainer/docker-compose.yml | 34 +++++++++----------------------- .vscode/settings.json | 3 +++ src/visual.py | 9 +++++++++ 4 files changed, 22 insertions(+), 26 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index f137411..74ff08c 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -2,7 +2,7 @@ "name": "information-task-container", "dockerComposeFile": "docker-compose.yml", "service": "gpu-service", - "workspaceFolder": "/home/Elaina/pytorch", + "workspaceFolder": "/pytorch", "shutdownAction": "stopCompose", "customizations": { "vscode": { diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 7345a13..9159444 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,10 +1,7 @@ version: '3' services: pytorch-service: - # build: - # context: .. - # dockerfile: .devcontainer/Dockerfile - image: elainasuki/other:pytorch + image: elainasuki/other:pytorch_cpu container_name: information-cpu-container environment: - DISPLAY=${DISPLAY} @@ -14,21 +11,13 @@ services: - TERM=xterm-256color volumes: - /tmp/.X11-unix:/tmp/.X11-unix - - ./..:/home/Elaina/pytorch - - /dev:/dev - network_mode: host - pid: "host" # 添加 pid 命名空间共享 - ipc: "host" # 添加 ipc 命名空间共享 - privileged: true + - ./..:/pytorch/ + entrypoint: [ '/bin/bash' ] stdin_open: true tty: true - user: "Elaina" # runtime: "nvidia" - working_dir: "/home/Elaina/pytorch" # 指定默认工作目录 + working_dir: "/pytorch" # 指定默认工作目录 gpu-service: - # build: - # context: .. - # dockerfile: .devcontainer/Dockerfile image: elainasuki/other:pytorch_gpu container_name: information-task-container environment: @@ -39,14 +28,9 @@ services: - TERM=xterm-256color volumes: - /tmp/.X11-unix:/tmp/.X11-unix - - ./..:/home/Elaina/pytorch - - /dev:/dev - network_mode: host - pid: "host" # 添加 pid 命名空间共享 - ipc: "host" # 添加 ipc 命名空间共享 - privileged: true - stdin_open: true - tty: true - user: "Elaina" + - ./..:/pytorch runtime: "nvidia" - working_dir: "/home/Elaina/pytorch" # 指定默认工作目录 + tty: true + stdin_open: true + entrypoint: [ '/bin/bash' ] + working_dir: "/pytorch" # 指定默认工作目录 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8aef7b1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "standard" +} \ No newline at end of file diff --git a/src/visual.py b/src/visual.py index 029b110..305f04a 100644 --- a/src/visual.py +++ b/src/visual.py @@ -66,3 +66,12 @@ def finalizeVisualization(self): self.ax.set_title("Training Loss (Final)") plt.savefig(self.loss_img_path) plt.show() +if __name__ == "__main__": + plt.figure(figsize=(6, 4)) + plt.plot([1, 2, 3, 4], [1, 4, 9, 16], 'r-', label='test') + plt.xlabel('X') + plt.ylabel('Y') + plt.title('GUI') + plt.legend() + plt.grid(True) + plt.show() # 弹出窗口显示测试图 \ No newline at end of file From c498bd13c2a06964449e97da51d447ae3ab3bea4 Mon Sep 17 00:00:00 2001 From: Elaina <1463967532@qq.com> Date: Tue, 11 Nov 2025 15:25:46 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9gitignore=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0pytorch=E6=94=AF=E6=8C=81device=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=A4=9A=E5=B1=82=E6=84=9F=E7=9F=A5=E6=9C=BA=E5=92=8C?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 ++++ data/.~task2.xlsx | Bin 165 -> 0 bytes src/load_data.py | 10 +++++----- src/model.py | 19 +++++++++++++++++++ src/train.py | 31 +++++++++++++++++++++++++++++++ src/visual.py | 7 +++++-- 6 files changed, 64 insertions(+), 7 deletions(-) delete mode 100644 data/.~task2.xlsx create mode 100644 src/train.py diff --git a/.gitignore b/.gitignore index e69de29..f34d590 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1,4 @@ +.~* +model/ +*.png +*.pyc \ No newline at end of file diff --git a/data/.~task2.xlsx b/data/.~task2.xlsx deleted file mode 100644 index 65e0b3ec5dfb4f507292e35df66c14f8d1e3ed11..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 165 gcmZQe%}LD6OH?2curZ`Eu;Q#;t diff --git a/src/load_data.py b/src/load_data.py index 674d557..5dc69d2 100644 --- a/src/load_data.py +++ b/src/load_data.py @@ -17,13 +17,13 @@ def LoadTask2Data(file_path)->tuple[torch.Tensor, torch.Tensor]: outputs_tensor = torch.tensor(outputs, dtype=torch.float32) return inputs_tensor, outputs_tensor class DataBatcher: - def __init__(self, file_path: str, val_ratio: float = 0.2, batch_size: int = 32, shuffle: bool = True): + def __init__(self, file_path: str, val_ratio: float = 0.2, batch_size: int = 32, shuffle: bool = True,device: torch.device = torch.device('cpu')): """ 数据批处理工具类,支持划分验证集和生成批次数据 """ self.batch_size = batch_size self.shuffle = shuffle - + self.device = device # 加载原始数据 self.inputs, self.outputs = LoadTask2Data(file_path) @@ -81,16 +81,16 @@ def _splitAndCreateBatches(self, val_ratio: float): def getTrainBatches(self): """获取训练集批次张量 (输入, 输出)""" - return self.train_inputs, self.train_outputs + return self.train_inputs.to(self.device), self.train_outputs.to(self.device) def getValBatches(self): """获取验证集批次张量 (输入, 输出)""" - return self.val_inputs, self.val_outputs + return self.val_inputs.to(self.device), self.val_outputs.to(self.device) if __name__ == "__main__": # 测试代码 - file_path = os.path.join("/home/Elaina/pytorch/data", "task2.xlsx") + file_path = os.path.join("/pytorch/data", "task2.xlsx") # 初始化批处理工具 batcher = DataBatcher(file_path, val_ratio=0.2, batch_size=16) diff --git a/src/model.py b/src/model.py index e69de29..011383b 100644 --- a/src/model.py +++ b/src/model.py @@ -0,0 +1,19 @@ +import torch.nn as nn +import os,torch +class MLPModel(nn.Module): + """简单的多层感知机模型""" + + def __init__(self, input_size, hidden_size, output_size, device=torch.device('cpu')): + super().__init__() + self.linear1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(hidden_size, output_size) + self.device=device + self.to(device) # 关键:初始化时就把模型移到目标设备 + def forward(self, x:torch.Tensor): + assert isinstance(x, torch.Tensor), "输入必须是torch.Tensor类型" + x = x.to(self.device) # 确保输入数据在正确的设备上 + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x \ No newline at end of file diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..82f52b6 --- /dev/null +++ b/src/train.py @@ -0,0 +1,31 @@ +from visual import * +from model import * +from load_data import * +def train(): + device=torch.device('cuda' if torch.cuda.is_available() else'cpu') + model=MLPModel(input_size=8,hidden_size=24,output_size=2,device=device) + data_path='/pytorch/data/task2.xlsx' + model_path='/pytorch/model/' + batcher = DataBatcher(file_path=data_path, val_ratio=0.2, batch_size=16,device=device) + train_inputs, train_outputs = batcher.getTrainBatches() + val_inputs, val_outputs = batcher.getValBatches() + #平方损失 + loss_fn=torch.nn.MSELoss() + optimizer=torch.optim.Adam(model.parameters(),lr=0.001) + + visual=SaveAndVisual(model_dir=model_path, loss_img_path='loss_curve.png') + num_epoch=600 + for epoch in range(num_epoch): + model.train() + inputs=train_inputs + targets=train_outputs + optimizer.zero_grad() + outputs=model(inputs) + loss=loss_fn(outputs,targets) + loss.backward() + optimizer.step() + visual.updateVisualization(epoch,loss.item()) + visual.finalizeVisualization() + +if __name__ == "__main__": + train() \ No newline at end of file diff --git a/src/visual.py b/src/visual.py index 305f04a..c97eb98 100644 --- a/src/visual.py +++ b/src/visual.py @@ -11,7 +11,7 @@ def __init__(self, model_dir='models', loss_img_path='loss_curve.png'): self.epoch_indices = [] # 存储epoch索引 self._init_visualization() self._init_model_dir() - + self.loop_count=0 def _init_model_dir(self): """初始化模型保存目录""" os.makedirs(self.model_dir, exist_ok=True) @@ -52,7 +52,10 @@ def updateVisualization(self, epoch, loss): """更新训练损失可视化""" self.epoch_losses.append(loss) self.epoch_indices.append(epoch) - + self.loop_count+=1 + if(self.loop_count%10==0): + self.loop_count=0 + print(f" 训练损失(第{epoch+1}轮):{loss:.4f}") # 更新图像数据 self.line.set_data(self.epoch_indices, self.epoch_losses) self.ax.relim() # 重新计算坐标轴范围 From 7a7348af9803eb27e6f92a3278b5a4d0ed089feb Mon Sep 17 00:00:00 2001 From: Elaina <1463967532@qq.com> Date: Tue, 11 Nov 2025 15:39:29 +0000 Subject: [PATCH 5/6] =?UTF-8?q?fix=20=E4=BF=AE=E5=A4=8D=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=B2=A1=E6=9C=89=E4=BF=9D=E5=AD=98=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/train.py | 3 ++- src/visual.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/train.py b/src/train.py index 82f52b6..d5eee0b 100644 --- a/src/train.py +++ b/src/train.py @@ -13,8 +13,9 @@ def train(): loss_fn=torch.nn.MSELoss() optimizer=torch.optim.Adam(model.parameters(),lr=0.001) - visual=SaveAndVisual(model_dir=model_path, loss_img_path='loss_curve.png') + visual=SaveAndVisual(model_dir=model_path, loss_img_path=model_path+'loss_curve.png') num_epoch=600 + visual.loadModel(model,optimizer,device) for epoch in range(num_epoch): model.train() inputs=train_inputs diff --git a/src/visual.py b/src/visual.py index c97eb98..3071260 100644 --- a/src/visual.py +++ b/src/visual.py @@ -25,15 +25,18 @@ def _init_visualization(self): self.ax.set_title("Training Loss (per Epoch)") self.line, = self.ax.plot([], [], label="Epoch Loss") self.ax.legend() - + self.min_loss=float('inf') def loadModel(self, model:torch.nn.Module, optimizer, device): """加载已保存的模型""" model_path = os.path.join(self.model_dir, 'best_transformer.pth') + self.model=model + self.optimizer=optimizer if os.path.exists(model_path): checkpoint = torch.load(model_path, map_location=device, weights_only=True) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) print(f"发现现有模型,其损失为: {checkpoint['loss']:.4f}") + self.min_loss=checkpoint['loss'] return checkpoint['loss'] return float('inf') @@ -46,7 +49,7 @@ def saveModel(self, model, optimizer, epoch, loss): 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, model_path) - print(f" 保存最佳模型(损失:{loss:.4f})") + # print(f" 保存最佳模型(损失:{loss:.4f})") def updateVisualization(self, epoch, loss): """更新训练损失可视化""" @@ -60,6 +63,9 @@ def updateVisualization(self, epoch, loss): self.line.set_data(self.epoch_indices, self.epoch_losses) self.ax.relim() # 重新计算坐标轴范围 self.ax.autoscale_view() # 自动调整视图 + if(loss Date: Tue, 11 Nov 2025 15:44:05 +0000 Subject: [PATCH 6/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=8E=A8=E7=90=86.mlp?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8F=AF=E8=83=BD=E8=BF=87=E4=BA=8E=E7=AE=80?= =?UTF-8?q?=E5=8D=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infer.py | 25 +++++++++++++++++++++++++ src/model.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 src/infer.py diff --git a/src/infer.py b/src/infer.py new file mode 100644 index 0000000..2bc4aec --- /dev/null +++ b/src/infer.py @@ -0,0 +1,25 @@ +from visual import * +from model import * +from load_data import * +def infer(): + #初始化参数 + model_path='/pytorch/model/best_transformer.pth' + device=torch.device('cuda' if torch.cuda.is_available() else'cpu') + #加载模型 + model=MLPModel(input_size=8,hidden_size=24,output_size=2,device=device) + model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True)['model_state_dict']) + #加载数据 + data_path='/pytorch/data/task2.xlsx' + batcher = DataBatcher(file_path=data_path, val_ratio=0.2, batch_size=16,device=device) + val_inputs, val_outputs = batcher.getValBatches() + #进行推理 + output=model(val_inputs) + #取第1batch的结果比较 + print("真实值:", val_outputs[0]) + print("预测值:", output[0]) + # 计算均方误差 + loss_fn=torch.nn.MSELoss() + loss=loss_fn(output,val_outputs) + print(f"验证集均方误差: {loss.item():.4f}") +if __name__ == "__main__": + infer() \ No newline at end of file diff --git a/src/model.py b/src/model.py index 011383b..baf186a 100644 --- a/src/model.py +++ b/src/model.py @@ -10,7 +10,7 @@ def __init__(self, input_size, hidden_size, output_size, device=torch.device('cp self.linear2 = nn.Linear(hidden_size, output_size) self.device=device self.to(device) # 关键:初始化时就把模型移到目标设备 - def forward(self, x:torch.Tensor): + def forward(self, x:torch.Tensor)->torch.Tensor: assert isinstance(x, torch.Tensor), "输入必须是torch.Tensor类型" x = x.to(self.device) # 确保输入数据在正确的设备上 x = self.linear1(x)