In [1]:
from torch import nn
from PIL import Image
import torch
import os
import torchvision.models as models
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Shared preprocessing: resize every frame to 400x400, then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((400, 400)),
        transforms.ToTensor(),
    ]
)
class ImgSeq(Dataset):
    """Dataset of image sequences stored as one folder of frames per sample.

    Expected layout::

        root_dir/
            injFail/<sequence folder>/<frame images>      -> label 0
            injSuccess/<sequence folder>/<frame images>   -> label 1

    Attributes:
        root_dir: dataset root passed to the constructor.
        transform: callable applied to every grayscale PIL frame, or None.
        data: one list of frame paths per sequence, sorted by file name.
        labels: class index of each sequence (0 = injFail, 1 = injSuccess).
        folder_path: path of each sequence folder, aligned with ``data``.
    """

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and index every sequence folder with its label.

        Args:
            root_dir: directory containing the ``injFail`` and ``injSuccess``
                sub-directories.
            transform: optional callable turning a PIL image into a tensor.

        Raises:
            OSError: if ``root_dir`` or one of the two class directories
                cannot be listed.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []  # full frame paths, one list per sequence
        self.labels = []  # label of each sequence
        self.folder_path = []  # sequence folder paths, same order as data
        for index, sub_path_name in enumerate(['injFail', 'injSuccess']):
            dir_path = os.path.join(root_dir, sub_path_name)
            # os.listdir() returns entries in arbitrary order; sorting makes the
            # sample order (and therefore the reported predictions) reproducible.
            for folder_name in sorted(os.listdir(dir_path)):
                path = os.path.join(dir_path, folder_name)
                if not os.path.isdir(path):
                    continue  # stray files next to the sequence folders
                # Sorting by name keeps the frames in temporal order. Only regular
                # files are kept so nested directories (e.g. the
                # .ipynb_checkpoints folder Jupyter creates) never reach Image.open.
                frame_paths = [
                    os.path.join(path, name) for name in sorted(os.listdir(path))
                ]
                frame_paths = [p for p in frame_paths if os.path.isfile(p)]
                self.folder_path.append(path)
                self.data.append(frame_paths)
                self.labels.append(index)  # 0 = failure, 1 = success

    def __len__(self):
        """Return the number of indexed sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sequence.

        Args:
            idx: position of the sequence in ``data``.

        Returns:
            A ``(images, label)`` tuple where ``images`` is the stack of the
            transformed frames (sequence axis first) and ``label`` is the
            integer class index.
        """
        to_tensor = self.transform
        if to_tensor is None:
            # The constructor default is transform=None; fall back to a plain
            # tensor conversion instead of failing on "NoneType is not callable".
            to_tensor = transforms.ToTensor()
        frames = []
        for img_path in self.data[idx]:
            # The context manager closes the underlying file as soon as the
            # grayscale copy has been produced.
            with Image.open(img_path) as img:
                gray = img.convert('L')
            frames.append(to_tensor(gray))
        images = torch.stack(frames)
        label = self.labels[idx]
        return images, label
### 所有网络模块的定义
# 定义CNN模块，用于特征提取
class SqCNN(nn.Module):
    """Per-frame feature extractor built on a single-channel ResNet-18 backbone.

    Every frame of every sequence is pushed through the ResNet-18 layers up to
    (and including) the global pooling layer; the final classification layer is
    dropped so the network yields one feature vector per frame.

    Args:
        pretrained: forwarded to ``models.resnet18``. Keep the default ``True``
            for transfer learning; pass ``False`` when the weights are about to
            be overwritten by a checkpoint, which avoids downloading the
            ImageNet weights.
    """

    def __init__(self, pretrained=True):
        super(SqCNN, self).__init__()
        # ResNet-18 backbone (the model actually built here is resnet18).
        resnet = models.resnet18(pretrained=pretrained)
        # Replace the stem with a 1-input-channel convolution for grayscale
        # frames. This layer is freshly initialised; its shape must not change
        # or previously saved checkpoints will no longer load.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2), bias=False)
        # Transfer learning: keep every child module except the last one (the
        # fully connected classifier) and reuse the rest as a feature extractor.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        # Dropout on the extracted features to reduce over-fitting.
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Extract features for every frame.

        Args:
            x: tensor of shape ``[batch_size, seq_len, channels, height, width]``
                (e.g. 4 x 8 x 1 x 400 x 400).

        Returns:
            Tensor of shape ``[batch_size, seq_len, features]``.
        """
        batch_size, seq_len, c, h, w = x.size()
        # Fold the sequence axis into the batch axis so the 2-D CNN sees plain
        # images. reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(batch_size * seq_len, c, h, w)
        x = self.features(x)
        # Restore the batch / sequence axes and flatten the rest per frame.
        x = x.reshape(batch_size, seq_len, -1)
        x = self.dropout(x)
        return x
# 定义LSTM与注意力机制模块
class LSTMWithAttention(nn.Module):
    """LSTM over a feature sequence, pooled with learned attention, then classified.

    Args:
        input_dim: size of each per-step input feature vector.
        hidden_dim: size of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        # Sequence model; inputs are laid out as (batch, step, feature).
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Regularisation applied to the LSTM outputs.
        self.dropout = nn.Dropout(p=0.3)
        # Scores each time step's hidden state with a single scalar.
        self.attention_layer = nn.Linear(in_features=hidden_dim, out_features=1)
        # Maps the pooled context vector to class logits.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Return class logits for a batch of feature sequences.

        Args:
            x: tensor laid out as (batch, step, input_dim).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        hidden_states = self.dropout(self.lstm(x)[0])
        # Normalise the per-step scores across the step axis.
        step_weights = self.attention_layer(hidden_states).softmax(dim=1)
        # Attention-weighted sum of the hidden states over the step axis.
        pooled = (hidden_states * step_weights).sum(dim=1)
        return self.fc(pooled)
# 整体模型结构 串联CNN和LSTM两个模型
class CNNLSTM(nn.Module):
    """End-to-end classifier: per-frame CNN features fed to an attention LSTM."""

    def __init__(self):
        super().__init__()
        # Frame-level feature extractor.
        self.cnn = SqCNN()
        # Temporal model and two-class head on top of 512-dimensional frame features.
        self.lstm_attention = LSTMWithAttention(
            input_dim=512,
            hidden_dim=128,
            num_layers=1,
            num_classes=2,
        )

    def forward(self, x):
        """Run the CNN on every frame, then classify the resulting feature sequence."""
        return self.lstm_attention(self.cnn(x))
# 加载模型函数
def load_model(model_path):
    """Build a ``CNNLSTM`` and restore its parameters from a checkpoint.

    Args:
        model_path: path of a file written with ``torch.save(model.state_dict(), ...)``.

    Returns:
        The model in evaluation mode, with its parameters on the CPU.
    """
    model = CNNLSTM()
    # map_location="cpu" lets a checkpoint that was saved from a CUDA model be
    # loaded on a machine without a GPU; without it torch.load tries to restore
    # the tensors onto the device they were saved from and fails.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout
    return model
# 预测函数，并计算准确率
def predict(model, data_loader, target_device=None):
    """Run inference over a loader and measure accuracy.

    Args:
        model: module mapping a batch of inputs to per-class scores of shape
            (batch, num_classes).
        data_loader: iterable yielding ``(images, labels)`` tensor pairs.
        target_device: device to run on. Defaults to the notebook-level
            ``device`` when omitted.

    Returns:
        A ``(results, accuracy)`` tuple: ``results`` is the list of predicted
        class indices in loader order, ``accuracy`` is the fraction of correct
        predictions (``0.0`` when the loader yields no samples).
    """
    if target_device is None:
        target_device = device  # notebook-level default chosen at start-up
    model = model.to(target_device)
    model.eval()  # inference mode: disables dropout
    correct = 0
    total = 0
    results = []
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in data_loader:
            images = images.to(target_device)
            labels = labels.to(target_device)
            outputs = model(images)
            _, predicted = outputs.max(1)  # index of the highest score per sample
            results.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    return results, accuracy


cuda


In [2]:
# Public entry points: change these two values to evaluate another checkpoint or dataset.
folder_name = 'injection-dataset_test'  # parent directory of the test sequences
model = load_model("model.pth")  # path of the trained checkpoint to load



In [3]:
# Build the test dataset; shuffle stays off so predictions line up with folder_path.
test_dataset = ImgSeq(root_dir=folder_name, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)
# Run inference and compute the accuracy against the folder-derived labels.
predictions, accuracy = predict(model, test_loader)
# Report one line per sequence folder, followed by the overall accuracy.
print('以下是预测结果')
for seq_folder, prediction in zip(test_dataset.folder_path, predictions):
    verdict = "Success" if prediction == 1 else "Fail"
    print(f'{seq_folder} : {verdict}')
print(f'Accuracy: {accuracy * 100:.2f}%')

以下是预测结果
injection-dataset_test\injFail\MyVideo_33 : Fail
injection-dataset_test\injFail\MyVideo_34 : Fail
injection-dataset_test\injFail\MyVideo_35 : Fail
injection-dataset_test\injFail\MyVideo_36 : Fail
injection-dataset_test\injFail\MyVideo_37 : Fail
injection-dataset_test\injFail\MyVideo_38 : Fail
injection-dataset_test\injFail\MyVideo_39 : Fail
injection-dataset_test\injFail\MyVideo_40 : Fail
injection-dataset_test\injSuccess\MyVideo_33 : Success
injection-dataset_test\injSuccess\MyVideo_34 : Success
injection-dataset_test\injSuccess\MyVideo_35 : Success
injection-dataset_test\injSuccess\MyVideo_36 : Success
injection-dataset_test\injSuccess\MyVideo_37 : Success
injection-dataset_test\injSuccess\MyVideo_38 : Success
injection-dataset_test\injSuccess\MyVideo_39 : Success
injection-dataset_test\injSuccess\MyVideo_40 : Success
Accuracy: 100.00%
