<a href="https://colab.research.google.com/github/ClaretWheel1481/ImageGuard-dev/blob/main/ImageGuard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 导入所需库

In [None]:
import os
import requests
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import urlparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from PIL import ImageFile, Image
from tqdm import tqdm
from torch.utils.data import DataLoader
ImageFile.LOAD_TRUNCATED_IMAGES = True

# 多线程下载数据集图片

In [None]:
# 创建文件夹
def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

In [None]:
def download_image(url, save_dir):
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        img_name = os.path.basename(urlparse(url).path)
        img_path = os.path.join(save_dir, img_name)
        with open(img_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded {img_name}")
    except Exception as e:
        print(f"Failed to download {url}: {e}")

In [None]:
def download_dataset(urls_file, save_dir):
    create_directory(save_dir)
    with open(urls_file, 'r') as f:
        urls = f.read().splitlines()
    with ThreadPoolExecutor(max_workers=128) as executor:
        for url in urls:
            executor.submit(download_image, url, save_dir)

In [21]:
def download_all_datasets():
    datasets = ['neutral', 'porn']
    for dataset in datasets:
        urls_file = f'drive/MyDrive/数据集/data/train/{dataset}/urls_{dataset}.txt'
        save_dir = f'dataset/{dataset}'
        download_dataset(urls_file, save_dir)
        # TODO: 检测URL是否可正常访问，不可用则在txt中移除

In [6]:
download_all_datasets()

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Downloaded tvnjofsceub01.jpg
Downloaded tug16q90vnb31.jpg
Downloaded tsrhakcg5j931.jpg
Downloaded tvwlzq6pw9241.jpg
Downloaded tugawl7lcbv11.jpg
Downloaded tmmpw7mkte431.jpgDownloaded ttykf4ca6jt11.jpg
Downloaded tv34y472bfm31.jpgDownloaded tvf5cm4jg1x31.jpg
Downloaded triwko0n0kb31.jpg
Downloaded trlasm4frdv21.jpg

Downloaded twcz84oziij31.jpg
Downloaded tw5ksru9dkyy.jpg

Downloaded tvpznsipgwz31.jpg
Downloaded trf195ld58z01.jpg
Downloaded tu4oncpo8rjz.jpg
Downloaded trxf3vuierr31.jpg
Downloaded tvwaeh3r40x11.jpg
Downloaded tvtaox4yiz831.jpg
Downloaded tt6nihzobln31.jpg
Downloaded tvynqgjm52541.jpg
Downloaded tx62rajv6kp31.jpg
Downloaded twedllu29qq31.jpg
Downloaded tv9t3bxgzbl01.jpg
Downloaded tscrzwdq6yn31.jpgFailed to download https://i.redd.it/tx9ig8maiqv01.jpg: 404 Client Error: Not Found for url: https://i.redd.it/tx9ig8maiqv01.jpg

Failed to download https://i.redd.it/txc4ab7tdqb31.jpg: 404 Client Error: Not Found for url: https://i.redd

# 加载数据集

In [7]:
# 自定义ImageFolder类以跳过无效图像
class CustomImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        while True:
            try:
                sample, target = super(CustomImageFolder, self).__getitem__(index)
                if sample is not None:
                    return sample, target
            except Exception as e:
                print(f"Failed to load image {self.imgs[index][0]}: {e}")
                index = (index + 1) % len(self.imgs)

In [8]:
# 加载图片数据方法
def pil_loader(path):
    try:
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
    except Exception as e:
        print(f"Failed to load image {path}: {e}")
        return None

In [22]:
# TODO: 检测所有图像是否可正常打开

In [9]:
# 数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(299),
        transforms.RandomResizedCrop(299, scale=(0.8, 1.0), ratio=(3/4, 4/3)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
    ]),
}

In [10]:
# 加载数据集
train_dataset = CustomImageFolder('dataset', transform=data_transforms['train'], loader=pil_loader)
val_dataset = CustomImageFolder('drive/MyDrive/数据集/data/validation', transform=data_transforms['val'], loader=pil_loader)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = train_dataset.classes
print(class_names)

['neutral', 'porn']


# 构建模型

In [14]:
# 构建模型
class ImageGuard(nn.Module):
    def __init__(self):
        super(ImageGuard, self).__init__()

        # RESNET50
        self.base_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.base_model.fc = nn.Sequential(
            nn.Linear(self.base_model.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 2),
            # TODO: 目前只有两种分类，不使用Softmax
            # nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.base_model(x)

In [15]:
model = ImageGuard()

In [16]:
# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 优化器
# TODO: Adam优化器相对更稳定，RMSprop需要控制学习率，未来调优
optimizer = optim.Adam(model.parameters(), lr=0.002)
# optimizer = optim.RMSprop(model.parameters(), lr=0.002, weight_decay=1e-5)

# 训练模型

In [17]:
# 训练模型
def train_model(model, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        loop = tqdm(enumerate(train_loader),total=len(train_loader))

        for step, data in loop:
            images,labels = data
            images = images.to(device)
            labels = labels.to(device)

            with torch.set_grad_enabled(True):
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss += loss.item() * images.size(0)

            loop.set_description(f'Epoch [{epoch + 1}/{num_epochs}]')
            loop.set_postfix(loss=running_loss / dataset_sizes['train'])

        validate_model(model)

    return model

In [18]:
# 评估性能
def validate_model(model):
    model.eval()
    corrects = 0

    for data in val_loader:
        images,labels = data
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(images)
            preds = torch.argmax(outputs, 1)
            corrects += torch.sum(preds == labels.data)

    print(f"correct: {corrects/len(val_dataset)}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"训练模式:{device}")
model = model.to(device)

In [20]:
# 训练
print("开始训练")
print("-" * 20)
model = train_model(model, criterion, optimizer, num_epochs=3)
if not os.path.exists('model'):
    os.makedirs('model')
torch.save(model.state_dict(), 'model/image_guard_v1.pth')
print("Model saved.")

开始训练
--------------------


Epoch [1/3]:  40%|████      | 141/351 [13:07<17:58,  5.14s/it, loss=0.0484]

Failed to load image dataset/neutral/Untitled-1.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/Untitled-1.jpg'>
Failed to load image dataset/neutral/Untitled-1.jpg: Unexpected type <class 'NoneType'>


Epoch [1/3]:  63%|██████▎   | 222/351 [20:37<12:21,  5.75s/it, loss=0.0716]

Failed to load image dataset/neutral/lk9wdklq6dz.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/lk9wdklq6dz.jpg'>
Failed to load image dataset/neutral/lk9wdklq6dz.jpg: Unexpected type <class 'NoneType'>


Epoch [1/3]:  67%|██████▋   | 235/351 [21:51<11:16,  5.83s/it, loss=0.0737]

Failed to load image dataset/neutral/wetkoala.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/wetkoala.jpg'>
Failed to load image dataset/neutral/wetkoala.jpg: Unexpected type <class 'NoneType'>


Epoch [1/3]:  85%|████████▌ | 300/351 [27:57<04:49,  5.67s/it, loss=0.087]

Failed to load image dataset/neutral/atacama04.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/atacama04.jpg'>
Failed to load image dataset/neutral/atacama04.jpg: Unexpected type <class 'NoneType'>


Epoch [1/3]: 100%|██████████| 351/351 [32:37<00:00,  5.58s/it, loss=0.0981]


correct: 0.5


Epoch [2/3]:  34%|███▎      | 118/351 [11:05<21:37,  5.57s/it, loss=0.0248]

Failed to load image dataset/neutral/atacama04.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/atacama04.jpg'>
Failed to load image dataset/neutral/atacama04.jpg: Unexpected type <class 'NoneType'>
Failed to load image dataset/neutral/lk9wdklq6dz.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/lk9wdklq6dz.jpg'>
Failed to load image dataset/neutral/lk9wdklq6dz.jpg: Unexpected type <class 'NoneType'>


Epoch [2/3]:  72%|███████▏  | 252/351 [23:31<10:18,  6.24s/it, loss=0.0545]

Failed to load image dataset/neutral/wetkoala.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/wetkoala.jpg'>
Failed to load image dataset/neutral/wetkoala.jpg: Unexpected type <class 'NoneType'>


Epoch [2/3]:  83%|████████▎ | 293/351 [27:14<05:20,  5.52s/it, loss=0.0617]

Failed to load image dataset/neutral/Untitled-1.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/Untitled-1.jpg'>
Failed to load image dataset/neutral/Untitled-1.jpg: Unexpected type <class 'NoneType'>


Epoch [2/3]: 100%|██████████| 351/351 [32:28<00:00,  5.55s/it, loss=0.0747]


correct: 0.5


Epoch [3/3]:  19%|█▉        | 68/351 [06:25<26:06,  5.54s/it, loss=0.0156]

Failed to load image dataset/neutral/Untitled-1.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/Untitled-1.jpg'>
Failed to load image dataset/neutral/Untitled-1.jpg: Unexpected type <class 'NoneType'>


Epoch [3/3]:  32%|███▏      | 114/351 [10:36<20:27,  5.18s/it, loss=0.025]

Failed to load image dataset/neutral/lk9wdklq6dz.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/lk9wdklq6dz.jpg'>
Failed to load image dataset/neutral/lk9wdklq6dz.jpg: Unexpected type <class 'NoneType'>


Epoch [3/3]:  57%|█████▋    | 199/351 [18:23<13:32,  5.34s/it, loss=0.0419]

Failed to load image dataset/neutral/wetkoala.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/wetkoala.jpg'>
Failed to load image dataset/neutral/wetkoala.jpg: Unexpected type <class 'NoneType'>


Epoch [3/3]:  70%|███████   | 246/351 [22:41<09:03,  5.18s/it, loss=0.0527]

Failed to load image dataset/neutral/atacama04.jpg: cannot identify image file <_io.BufferedReader name='dataset/neutral/atacama04.jpg'>
Failed to load image dataset/neutral/atacama04.jpg: Unexpected type <class 'NoneType'>


Epoch [3/3]: 100%|██████████| 351/351 [32:20<00:00,  5.53s/it, loss=0.0728]


correct: 0.5
Model saved.
