In [1]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.models import resnet50, ResNet50_Weights
from transformers import ViTConfig, ViTModel

In [2]:
class NsfwDetector(nn.Module):
    """Binary image classifier that fuses CNN (ResNet-50) and ViT features.

    The ResNet-50 trunk (minus its avgpool/fc head) yields a 2048-channel
    feature map. That map is projected to 768 channels and globally averaged.
    The ViT contributes the first-token ([CLS]) embedding of its last hidden
    state. The two 768-d vectors are concatenated and classified by an MLP head.

    Args:
        num_classes: number of output logits (default 2).
        pretrained: if True, initialise ResNet-50 with torchvision's default
            ImageNet weights and the ViT from 'google/vit-base-patch16-224'.
            If False, both branches start from random initialisation, which is
            useful when a full state_dict is loaded right afterwards.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super(NsfwDetector, self).__init__()
        # --- CNN branch: ResNet-50 for local features ---
        # Honour `pretrained`; previously ImageNet weights were always downloaded.
        self.cnn_backbone = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        # Drop the trailing avgpool and fc layers so the spatial feature map is kept.
        # NOTE: cnn_features reuses cnn_backbone's sub-modules, so their weights appear
        # under both prefixes in state_dict(); kept as-is for checkpoint compatibility.
        self.cnn_features = nn.Sequential(*list(self.cnn_backbone.children())[:-2])
        # --- Transformer branch: ViT for global features ---
        if pretrained:
            self.transformer = ViTModel.from_pretrained('google/vit-base-patch16-224')
        else:
            # from_pretrained(None) cannot build a model. The default ViTConfig describes
            # ViT-Base, patch 16, 224px, which is the architecture of the checkpoint above.
            self.transformer = ViTModel(ViTConfig())
        # --- Fusion: project the 2048-channel ResNet map to the ViT width (768) ---
        self.cnn_feature_adapter = nn.Sequential(
            nn.Conv2d(2048, 768, kernel_size=1),
            nn.BatchNorm2d(768),
            nn.ReLU()
        )
        # --- Classification head over the concatenated 768 + 768 features ---
        self.classifier = nn.Sequential(
            nn.Linear(768 * 2, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: normalised image batch, presumably (batch, 3, 224, 224), since the
               ViT branch is configured for 224px inputs. TODO confirm for callers.
        """
        # CNN branch: (B, 2048, h, w) -> (B, 768, h, w) -> global average pool -> (B, 768)
        cnn_feats = self.cnn_feature_adapter(self.cnn_features(x))
        cnn_feats = torch.mean(cnn_feats, dim=[2, 3])
        # ViT branch: first-token ([CLS]) embedding of the last hidden state -> (B, 768)
        trans_feats = self.transformer(pixel_values=x).last_hidden_state[:, 0, :]
        # Fuse by concatenation -> (B, 1536), then classify.
        combined = torch.cat((cnn_feats, trans_feats), dim=1)
        return self.classifier(combined)


In [3]:
class ImagePathDataset(Dataset):
    """Dataset over an explicit list of image file paths, used for batched inference.

    Each item is ``(image, path)``. The image is loaded, converted to RGB and
    passed through ``transform`` when one is given. The path is returned
    unchanged so predictions can be matched back to their files.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager closes the file handle (PIL may keep multi-frame files
        # open otherwise). convert() returns a fully loaded, independent copy.
        with Image.open(img_path) as src:
            img = src.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, img_path

In [4]:
class NsfwDetectorPipeline:
    """Inference wrapper around NsfwDetector for single images and lists of files.

    Every prediction is a dict with these keys:
        'is_nsfw'    -- Python bool, True when confidence > threshold
        'confidence' -- softmax probability of class index 1 (the NSFW class)
        'class'      -- 'nsfw' or 'normal'
    predict_batch additionally puts 'file_path' first in each dict.
    """

    def __init__(self, model_path=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Build the detector and its preprocessing transform.

        Args:
            model_path: optional path to a state_dict (e.g. one saved by
                train_model). If falsy, the freshly constructed model is used.
            device: torch device string. The default ('cuda' if available,
                else 'cpu') is evaluated once, when the class is defined.
        """
        self.device = device
        self.model = NsfwDetector(num_classes=2).to(device)
        if model_path:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code, so only load trusted checkpoints. Newer torch
            # versions accept weights_only=True to restrict this.
            self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.eval()
        # Same resize/normalisation as train_model, without the random flip.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _build_result(nsfw_prob, threshold):
        """Turn an NSFW probability into the result dict shared by all predict methods."""
        # float() guarantees plain Python float/bool values (JSON-serialisable),
        # even when the probability arrives as a numpy/torch scalar.
        nsfw_prob = float(nsfw_prob)
        is_nsfw = nsfw_prob > threshold
        return {
            'is_nsfw': is_nsfw,
            'confidence': nsfw_prob,
            'class': 'nsfw' if is_nsfw else 'normal'
        }

    def predict(self, image_path, threshold=0.5):
        """
        Predict whether the image stored at `image_path` is NSFW.

        Args:
            image_path: path to an image file readable by PIL.
            threshold: probability above which the image is labelled 'nsfw'.
        Returns:
            dict: prediction result (see the class docstring).
        """
        # The context manager closes the file once predict2 has converted the image.
        with Image.open(image_path) as img:
            return self.predict2(img, threshold)

    def predict2(self, image: "Image.Image", threshold=0.5):
        """
        Predict whether an already-loaded PIL image is NSFW.

        Args:
            image: PIL image in any mode; it is converted to RGB first.
            threshold: probability above which the image is labelled 'nsfw'.
        Returns:
            dict: prediction result (see the class docstring).
        """
        # RGBA / palette / grayscale inputs would otherwise yield a tensor whose
        # channel count does not match the 3-channel Normalize.
        rgb = image.convert('RGB')
        img_tensor = self.transform(rgb).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = torch.softmax(self.model(img_tensor), dim=1)
        return self._build_result(probs[0, 1].item(), threshold)

    def predict_batch(self, image_paths, threshold=0.5, batch_size=16):
        """
        Predict NSFW status for many image files in mini-batches.

        Args:
            image_paths: list of image file paths.
            threshold: probability above which an image is labelled 'nsfw'.
            batch_size: number of images per forward pass.
        Returns:
            list: one result dict per path, in input order, each with 'file_path' first.
        """
        dataset = ImagePathDataset(image_paths, transform=self.transform)
        # num_workers=0 keeps loading in-process. Worker processes may be unable to
        # import classes defined in a notebook on spawn-based platforms (e.g. Windows).
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        start = time.perf_counter()
        results = []
        with torch.no_grad():
            for batch_images, batch_paths in dataloader:
                outputs = self.model(batch_images.to(self.device))
                # .tolist() yields Python floats, so 'is_nsfw' is a bool rather than numpy.bool_.
                nsfw_probs = torch.softmax(outputs, dim=1)[:, 1].tolist()
                for path, prob in zip(batch_paths, nsfw_probs):
                    results.append({'file_path': path, **self._build_result(prob, threshold)})
        end = time.perf_counter()
        print(f"predict_batch time: {end - start}")

        return results

In [5]:
class NsfwDataset(Dataset):
    """Labelled image dataset read from a class-per-subdirectory tree.

    A thin delegate over torchvision's ImageFolder. Labels are the indices of
    the sorted subdirectory names (exposed as ``classes``). ``transform`` is
    applied by ImageFolder itself when an item is fetched.

    NOTE(review): NsfwDetectorPipeline reads class index 1 as the NSFW
    probability, so the NSFW folder must sort second (e.g. 'normal' < 'nsfw').
    Confirm this against the actual directory names.
    """

    def __init__(self, root_dir, transform=None):
        folder = ImageFolder(root_dir, transform=transform)
        self.dataset = folder
        self.classes = folder.classes
        self.transform = transform

    def __getitem__(self, idx):
        return self.dataset.__getitem__(idx)

    def __len__(self):
        return self.dataset.__len__()

In [6]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the mean batch loss (0.0 if it yields no batches)."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of `model` over `loader` (0.0 if it yields no samples)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


def train_model(train_dir, val_dir, model_save_path, epochs=10, batch_size=32):
    """
    Train an NsfwDetector and save the state_dict with the best validation accuracy.

    Args:
        train_dir: root of the training images, one subdirectory per class
            (read via NsfwDataset / ImageFolder).
        val_dir: root of the validation images with the same layout. This should
            be a held-out split, not train_dir.
        model_save_path: where the best state_dict is written. It is overwritten
            each time validation accuracy improves.
        epochs: number of passes over the training set.
        batch_size: mini-batch size for both loaders.

    Returns:
        The best validation accuracy reached, or None if no epoch ran.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Augmentation is for training only; validation must be deterministic,
    # otherwise the "best" checkpoint is chosen on randomly flipped images.
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = NsfwDataset(train_dir, transform=train_transform)
    val_dataset = NsfwDataset(val_dir, transform=val_transform)

    # The classifier head contains BatchNorm1d, which raises in train mode on a
    # batch of size 1. Drop a trailing singleton batch instead of crashing.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              drop_last=len(train_dataset) % batch_size == 1)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = NsfwDetector(num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    # None (rather than 0.0) so the first epoch is always saved, even at 0% accuracy.
    best_val_acc = None

    for epoch in range(epochs):
        mean_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_acc = _evaluate_accuracy(model, val_loader, device)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}, Val Acc: {val_acc:.4f}')

        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print(f'Best model saved with val acc: {val_acc:.4f}')

    print('Training complete.')
    return best_val_acc

In [7]:
# Set to True to (re)train. This overwrites 'nsfw_detector.pth', which the next cell loads.
RUN_TRAINING = False
if RUN_TRAINING:
    train_model(
        train_dir=r'D:\data\pics',
        # NOTE(review): validation points at the training directory, so "Val Acc"
        # measures fit on the training data, not generalisation. Use a held-out split.
        val_dir=r'D:\data\pics',
        model_save_path='nsfw_detector.pth',
        epochs=10,
        batch_size=32
    )

In [8]:
# Fall back to CPU when CUDA is unavailable instead of failing on a hard-coded 'cuda'.
detector = NsfwDetectorPipeline(
    model_path='nsfw_detector.pth',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Load one sample as 3-channel RGB. PNGs may be RGBA or palette images, which the
# pipeline's 3-channel Normalize cannot handle. The `with` block closes the file.
with Image.open(r'D:\data\pics\nsfw\p_4.png') as sample_file:
    image = sample_file.convert('RGB')

In [10]:
# Classify the already-loaded in-memory image at the default 0.5 threshold.
res = detector.predict2(image, threshold=0.5)
print(res)

{'is_nsfw': True, 'confidence': 0.687592089176178, 'class': 'nsfw'}


In [13]:
root_path = r'D:\data\picsall'
# Keep only image files, using the same extension list ImageFolder applies during
# training. Sub-directories or stray files (e.g. Thumbs.db) would make Image.open fail.
imgs = sorted(
    os.path.join(root_path, name)
    for name in os.listdir(root_path)
    if name.lower().endswith(IMG_EXTENSIONS)
    and os.path.isfile(os.path.join(root_path, name))
)
# NOTE(review): repeating the list 10x presumably just enlarges the workload for
# predict_batch's timing printout.
imgs = imgs * 10
res = detector.predict_batch(imgs)

predict_batch time: 12.32150912284851
