<a href="https://colab.research.google.com/github/ClaretWheel1481/ImageGuard-dev/blob/main/ImageGuard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 导入所需库

In [None]:
!pip install efficientnet_pytorch
import os
import requests
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import urlparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from PIL import ImageFile, Image
from tqdm import tqdm
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet

# Allow PIL to decode image files whose data is truncated instead of raising an error
# while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True



# 多线程下载数据集图片

In [None]:
# Create a directory (used for the download targets).
def create_directory(path):
    """Create ``path`` including any missing parents; do nothing if it already exists.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check, which could race with
    another thread or process creating the same directory between the check and the call.

    Args:
        path: directory path to create.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
def download_image(url, save_dir, timeout=10):
    """Download one image URL into ``save_dir`` and print whether it succeeded.

    Errors are printed rather than raised, so a single bad URL never aborts a batch run.

    Args:
        url: HTTP(S) URL of the image.
        save_dir: existing directory the file is written into.
        timeout: seconds passed to ``requests.get`` (default 10).
    """
    import hashlib

    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        img_name = os.path.basename(urlparse(url).path)
        if not img_name:
            # A URL path ending in "/" has no basename; without a fallback the code would try
            # to open the directory itself for writing. Derive a stable name from the URL.
            img_name = hashlib.sha256(url.encode('utf-8')).hexdigest()[:16]
        # NOTE(review): different URLs sharing the same basename overwrite each other here.
        img_path = os.path.join(save_dir, img_name)
        with open(img_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded {img_name}")
    except Exception as e:
        # Broad catch on purpose: this is the top-level body of a worker task.
        print(f"Failed to download {url}: {e}")

In [None]:
def download_dataset(urls_file, save_dir, max_workers=128):
    """Download every URL listed in ``urls_file`` into ``save_dir`` using a thread pool.

    Args:
        urls_file: UTF-8 text file with one image URL per line; blank lines are ignored.
        save_dir: destination directory, created if missing.
        max_workers: number of download threads (default 128).
    """
    create_directory(save_dir)
    with open(urls_file, 'r', encoding='utf-8') as f:
        # Strip stray whitespace/'\r' and skip empty lines, which would otherwise be submitted
        # as invalid URLs. dict.fromkeys de-duplicates while keeping file order, so two threads
        # never write the same target file concurrently.
        urls = list(dict.fromkeys(line.strip() for line in f if line.strip()))
    # Leaving the ``with`` block waits until every submitted download has finished.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for url in urls:
            executor.submit(download_image, url, save_dir)

In [None]:
# Categories whose images are downloaded. Each one needs a URL list on Google Drive at
# .../data/train/<category>/urls_<category>.txt; images are saved under dataset/<category>/.
# Other categories in this project: 'politics', 'violence'.
datasets_list = ['neutral', 'porn']

for dataset in datasets_list:
    save_dir = f'dataset/{dataset}'
    urls_file = f'drive/MyDrive/DeepLearning - Bachelor/数据集/data/train/{dataset}/urls_{dataset}.txt'
    download_dataset(urls_file=urls_file, save_dir=save_dir)

Failed to download http://blogs.r.ftdata.co.uk/photo-diary/files/2013/09/White-House-Blog-pic.jpg: HTTPConnectionPool(host='blogs.r.ftdata.co.uk', port=80): Max retries exceeded with url: /photo-diary/files/2013/09/White-House-Blog-pic.jpg (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x7db34e689b50>: Failed to resolve 'blogs.r.ftdata.co.uk' ([Errno -5] No address associated with hostname)"))
Failed to download http://blog.hostthetoast.com/wp-content/uploads/2014/08/Bacon_Basil_Jalapeno_Corn_Dip_4.jpg: 404 Client Error: Not Found for url: http://blog.hostthetoast.com/wp-content/uploads/2014/08/Bacon_Basil_Jalapeno_Corn_Dip_4.jpg
Failed to download http://blog.hostthetoast.com/wp-content/uploads/2014/08/Chicken_Waffle_Nuggets.jpg: 404 Client Error: Not Found for url: http://blog.hostthetoast.com/wp-content/uploads/2014/08/Chicken_Waffle_Nuggets.jpg
Failed to download http://asset-3.soupcdn.com/asset/13796/0386_31bd_520.jpeg: 410 Client Error: Gone for url:

# 加载数据集

In [None]:
# ImageFolder subclass that skips images which cannot be loaded.
class CustomImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that substitutes the next loadable sample for a broken one.

    If loading index ``i`` raises, or yields a ``None`` sample, indices ``i+1, i+2, ...``
    (wrapping around) are tried instead. A broken file is therefore replaced by a neighbour,
    so that neighbour may appear twice in an epoch.
    """

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index`` or the next loadable index after it.

        Raises:
            RuntimeError: if no sample in the whole dataset can be loaded. Without this bound
                the loop would spin forever.
        """
        num_samples = len(self.imgs)
        for _ in range(num_samples):
            try:
                sample, target = super().__getitem__(index)
            except Exception as e:
                print(f"Failed to load image {self.imgs[index][0]}: {e}")
            else:
                if sample is not None:
                    return sample, target
            # Advance on both failure paths. Previously a None sample retried the same
            # index forever.
            index = (index + 1) % num_samples
        raise RuntimeError(f"No loadable image found in {self.root}")

In [None]:
# Image loader passed to CustomImageFolder.
def pil_loader(path):
    """Open the image at ``path`` and return it converted to RGB.

    Any failure (missing file, undecodable data, ...) is printed and ``None`` is returned
    instead of raising.
    """
    try:
        with open(path, 'rb') as handle:
            rgb_image = Image.open(handle).convert('RGB')
    except Exception as err:
        print(f"Failed to load image {path}: {err}")
        return None
    return rgb_image

In [None]:
# Preprocessing pipelines, keyed by split.
# (An earlier experiment used Resize/RandomResizedCrop at 1099 pixels with horizontal flips
# and no normalization.)

# Per-channel ImageNet mean/std used by transforms.Normalize.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

data_transforms = dict(
    # Training: resize, random crop covering 60-100% of the area rescaled to 224x224, random
    # horizontal flip, rotation within +/-15 degrees, brightness/contrast jitter, then
    # tensor conversion and normalization.
    train=transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
    # Validation: deterministic resize and 224x224 center crop, then tensor + normalize.
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
)

In [None]:
# Build the datasets: training images from the downloaded dataset/<class>/ folders, validation
# images from the class folders under the Drive validation directory. CustomImageFolder skips
# files that fail to load.
train_dataset = CustomImageFolder('dataset', transform=data_transforms['train'], loader=pil_loader)
val_dataset = CustomImageFolder('drive/MyDrive/DeepLearning - Bachelor/数据集/data/validation', transform=data_transforms['val'], loader=pil_loader)

# ImageFolder numbers classes from the sorted sub-folder names of each root independently.
# The two splits must therefore contain the same class folders. Otherwise a validation label
# index refers to a different class than the same index in training, and the reported
# accuracy is meaningless.
if val_dataset.classes != train_dataset.classes:
    raise ValueError(
        f"Class folders differ between splits: train={train_dataset.classes}, "
        f"val={val_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = train_dataset.classes
print(class_names)

# 构建模型

In [None]:
# 构建模型(ResNet50)
# class ImageGuard(nn.Module):
#     def __init__(self):
#         super(ImageGuard, self).__init__()

#         # RESNET50
#         self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(self.base_model.fc.in_features, 512),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(512, 2),
#             # TODO: 目前只有两种分类，不使用Softmax
#             # nn.Softmax(dim=1)
#         )

#     def forward(self, x):
#         return self.base_model(x)

# Model definition (EfficientNet backbone).
class ImageGuard(nn.Module):
    """Pretrained EfficientNet-B3 with its final ``_fc`` layer replaced by a custom head.

    The head is Linear(in_features, 1024) -> SiLU -> Dropout(0.3) -> Linear(1024, num_classes).

    Args:
        num_classes: number of output logits (default 4). Set it to the actual number of
            classes in the training data.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        backbone = EfficientNet.from_pretrained('efficientnet-b3')
        feature_dim = backbone._fc.in_features

        # Kept as a positional Sequential so state_dict keys stay
        # base_model._fc.0.* / base_model._fc.3.* for saved checkpoints.
        backbone._fc = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.SiLU(),  # SiLU is PyTorch's implementation of Swish, EfficientNet's activation
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )
        self.base_model = backbone

    def forward(self, x):
        """Return the class logits produced by the backbone and custom head for batch ``x``."""
        return self.base_model(x)

In [None]:
# Size the classifier head from the class folders actually found in the training set, rather
# than ImageGuard's default of 4 outputs, which does not match the classes being downloaded.
model = ImageGuard(num_classes=len(class_names))

In [None]:
# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()

# Adam optimizer with weight decay applied to all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-4)

# Step decay: StepLR multiplies the learning rate by gamma=0.1 after every 5 calls to
# scheduler.step().
# NOTE(review): this only has an effect if the training loop calls scheduler.step() (once per
# epoch); confirm it is handed to / stepped by the trainer.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

# 训练模型

In [None]:
# Training loop.
def train_model(model, criterion, optimizer, num_epochs, scheduler=None):
    """Train ``model`` on the notebook-level ``train_loader``, validating after each epoch.

    Reads the globals ``train_loader`` and ``device`` and calls ``validate_model``.

    Args:
        model: network to optimize (expected to already live on ``device``).
        criterion: loss function applied as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        scheduler: optional learning-rate scheduler. When given, ``scheduler.step()`` is
            called once at the end of each epoch, after all optimizer steps of that epoch.

    Returns:
        The trained model (the same object that was passed in).
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        seen = 0

        loop = tqdm(train_loader, total=len(train_loader))
        loop.set_description(f'Epoch [{epoch + 1}/{num_epochs}]')

        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device)

            with torch.set_grad_enabled(True):
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            seen += batch_size
            # Mean loss over the samples processed so far this epoch. Dividing by the full
            # dataset size instead would start near 0 and only reach the mean at epoch end.
            loop.set_postfix(loss=running_loss / seen)

        validate_model(model)
        if scheduler is not None:
            scheduler.step()

    return model

In [None]:
# Evaluate performance.
def validate_model(model, loader=None):
    """Print and return the top-1 accuracy of ``model`` on a validation loader.

    Args:
        model: classifier whose output is argmax-ed over dim 1 to get predictions.
        loader: iterable of ``(images, labels)`` batches; defaults to the notebook-level
            ``val_loader``.

    Returns:
        Fraction of correctly classified samples as a Python float, or ``nan`` if the loader
        yields no samples (previously this raised ZeroDivisionError).
    """
    if loader is None:
        loader = val_loader

    # Move batches to wherever the model's weights live (CPU for a parameter-less model).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(model_device)
            labels = labels.to(model_device)
            preds = torch.argmax(model(images), dim=1)
            # .item() keeps the counters as Python ints, so the printout is a plain number
            # rather than a tensor repr.
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        print("correct: n/a (no validation samples)")
        return float('nan')

    accuracy = correct / total
    print(f"correct: {accuracy:.4f}")
    return accuracy

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"训练模式:{device}")
model = model.to(device)

In [None]:
# Train, then save the learned weights (state_dict only) under model/.
print("开始训练")
print('-' * 20)
model = train_model(model, criterion, optimizer, num_epochs=4)

checkpoint_dir = 'model'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
torch.save(model.state_dict(), f'{checkpoint_dir}/image_guard_v3.pth')
print("Model saved.")