In [None]:
rm -rf ./data/.ipynb_checkpoints


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import os
import requests
from duckduckgo_search import DDGS
from PIL import Image
import glob
import shutil
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # ✅ 允许加载损坏图片

# Device configuration: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stream a single image URL to disk (private helper for download_images).
def _save_image(image_url, file_path):
    """Stream ``image_url`` into ``file_path``.

    Returns:
        bool: True when the response was HTTP 200 with an image content type
        and the body was written completely; False when the response was
        rejected before anything was written.

    Raises:
        Exception: whatever ``requests`` or the file system raises. A
        partially written file is removed before the exception propagates.
    """
    # The context manager releases the streamed connection on every path.
    with requests.get(image_url, stream=True, timeout=5) as response:
        content_type = response.headers.get('content-type', '')
        if response.status_code != 200 or "image" not in content_type:
            return False

        complete = False
        try:
            with open(file_path, "wb") as file:
                for chunk in response.iter_content(1024):
                    file.write(chunk)
            complete = True
        finally:
            # Never leave a truncated download behind for the dataset loader.
            if not complete and os.path.exists(file_path):
                os.remove(file_path)
    return True


# Download images (one folder per class)
def download_images(person_name, save_dir="./data/faces", num_images=20):
    """Download image search results for ``person_name`` into a class folder.

    The images are written to ``<save_dir>/<person_name with spaces replaced
    by underscores>/<result index>.jpg``.

    Args:
        person_name: Search query; also used to derive the folder name.
        save_dir: Root directory that holds one sub-folder per class.
        num_images: Maximum number of search results requested.

    Raises:
        FileNotFoundError: If not a single image could be downloaded.
    """
    class_dir = os.path.join(save_dir, person_name.replace(" ", "_"))
    os.makedirs(class_dir, exist_ok=True)

    # Materialise the results while the search session is still open, so the
    # loop below never reads from a closed session.
    with DDGS() as ddgs:
        results = list(ddgs.images(person_name, max_results=num_images))

    downloaded = 0
    for i, result in enumerate(results):
        image_url = result.get("image") or ""
        if not image_url.startswith('http'):
            continue

        file_path = os.path.join(class_dir, f"{i}.jpg")
        try:
            saved = _save_image(image_url, file_path)
        except Exception as exc:
            # Best-effort per URL: one bad host must not abort the whole
            # download, but the failure is reported instead of hidden.
            print(f"⚠️ Skipping {image_url}: {exc}")
            continue
        if not saved:
            continue

        # After downloading, treat very small files as broken.
        if os.path.getsize(file_path) < 10240:  # below 10 KB is probably not a usable image
            print(f"⚠️ 删除损坏的图片: {file_path}")
            os.remove(file_path)
        else:
            downloaded += 1

    if downloaded == 0:
        raise FileNotFoundError("No images downloaded. Check your network or DuckDuckGo settings.")

    print(f"✅ 下载 {downloaded} 张图片到 {class_dir}")


# Make sure the directory layout is what `ImageFolder` expects
def fix_dataset_structure(data_root="./data/faces"):
    """Normalise ``data_root`` into a one-folder-per-class layout.

    - Creates ``data_root`` if it does not exist yet.
    - Deletes ``data_root/.ipynb_checkpoints`` if present.
    - If ``data_root`` has no visible class folder, creates
      ``data_root/default`` and moves every ``.jpg`` / ``.jpeg`` / ``.png``
      file (case-insensitive) found directly in ``data_root`` into it.

    Args:
        data_root: Root directory of the image dataset.
    """
    # A fresh checkout has no data directory yet; listing it would crash.
    os.makedirs(data_root, exist_ok=True)

    # Remove the Jupyter checkpoint folder.
    checkpoints = os.path.join(data_root, ".ipynb_checkpoints")
    if os.path.exists(checkpoints):
        shutil.rmtree(checkpoints)
        print("删除 .ipynb_checkpoints")

    valid_folders = [
        d for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d)) and not d.startswith('.')
    ]

    # No class folder at all: collect loose images into a "default" class.
    if not valid_folders:
        print("⚠️ 没有类别文件夹，检查是否有图片需要移动...")
        default_path = os.path.join(data_root, "default")
        os.makedirs(default_path, exist_ok=True)

        for file in os.listdir(data_root):
            source = os.path.join(data_root, file)
            # Only regular files; match the extension regardless of case.
            if os.path.isfile(source) and file.lower().endswith((".jpg", ".jpeg", ".png")):
                shutil.move(source, os.path.join(default_path, file))

        valid_folders.append("default")

    print(f"数据集类别文件夹: {valid_folders}")

# Preprocess images
def preprocess_images(image_dir):
    """Write a 64x64 RGB copy of every image file in ``image_dir``.

    Each readable image ``<name>.<ext>`` gets a sibling ``<name>_processed.jpg``.
    Files whose name already ends in ``_processed`` are skipped so the
    function can be re-run. Files that cannot be opened as images are
    reported and deleted.

    Args:
        image_dir: Directory whose direct children are processed.
    """
    # Only resize here. Saving a tensor normalised to [-1, 1] back through
    # ToPILImage would corrupt the pixel values.
    resize = transforms.Resize((64, 64))

    for file in glob.glob(image_dir + "/*.*"):  # every file with an extension
        if not os.path.isfile(file):
            continue
        stem, _ = os.path.splitext(file)
        if stem.endswith("_processed"):
            continue

        try:
            with Image.open(file) as img:
                resized = resize(img.convert("RGB"))  # force 3-channel RGB
        except OSError:
            print(f"❌ 跳过损坏的图片: {file}")
            os.remove(file)  # drop the unreadable image
            continue

        # Saving happens outside the try block so that a write failure never
        # causes the source image to be deleted.
        resized.save(stem + "_processed.jpg")

    print("✅ All images preprocessed!")

# Training dataset loading: module-level settings.
data_root = "./data/faces"  # root directory passed to ImageFolder below
batch_size = 128
image_size = 64  # used for both Resize and CenterCrop below

# Per-image pipeline applied when a sample is loaded: resize, centre-crop to
# image_size, convert to a tensor, then normalise each of the three channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def is_valid_file(path):
    """Return True when ``path`` ends in .jpg, .jpeg or .png (case-insensitive).

    The leading dot is part of the match, so a name such as ``notajpg`` is
    rejected.
    """
    return path.lower().endswith(('.jpg', '.jpeg', '.png'))

# Repair the directory structure before handing it to ImageFolder.
fix_dataset_structure(data_root)

# Build the dataset and its shuffled batch loader from data_root, keeping
# only files accepted by is_valid_file.
# NOTE(review): this runs before the download_images(...) call further down,
# so images downloaded in the same run are presumably not part of the
# dataset created here — confirm.
dataset = dsets.ImageFolder(root=data_root, transform=transform, is_valid_file=is_valid_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("✅ 数据集加载成功！类别:", dataset.classes)

# DCGAN generator and discriminator (both sized for 64x64 images)
class Generator(nn.Module):
    """Map a latent vector of shape (N, nz, 1, 1) to an image (N, nc, 64, 64).

    Args:
        nz: Number of latent channels.
        ngf: Base number of generator feature maps.
        nc: Number of output image channels.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.main = nn.Sequential(
            # (nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (nc) x 64 x 64; this last upsampling step is what brings the
            # output up to the 64x64 size of the training images.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()  # outputs in [-1, 1]
        )

    def forward(self, input):
        """Return the generated images for the latent batch ``input``."""
        return self.main(input)


class Discriminator(nn.Module):
    """Score images of shape (N, nc, 64, 64) with one probability per image.

    Args:
        nc: Number of input image channels.
        ndf: Base number of discriminator feature maps.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        self.main = nn.Sequential(
            # (nc) x 64 x 64 -> (ndf) x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1: a single real/fake score per image
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return scores of shape (N, 1, 1, 1) in [0, 1] for ``input``."""
        return self.main(input)

# 7️⃣ Train the DCGAN
def train_dcgan(epochs=100):
    """Train a generator/discriminator pair on the module-level ``dataloader``.

    Each batch performs one discriminator update (real batch labelled 1,
    generated batch labelled 0) followed by one generator update (generated
    batch labelled 1), using binary cross-entropy.

    Args:
        epochs: Number of passes over ``dataloader``.

    Returns:
        tuple: ``(netG, netD)``, the trained generator and discriminator.
    """
    nz = 100  # latent vector size fed to the generator
    netG = Generator(nz, 64, 3).to(device)
    netD = Discriminator(3, 64).to(device)

    criterion = nn.BCELoss()
    optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_batch, _ in dataloader:
            real_data = real_batch.to(device)
            noise = torch.randn(real_data.size(0), nz, 1, 1, device=device)
            fake_data = netG(noise)

            # Discriminator step. Targets are built with ones_like/zeros_like
            # so they match whatever shape the discriminator produces.
            netD.zero_grad()
            real_scores = netD(real_data).view(-1)
            loss_real = criterion(real_scores, torch.ones_like(real_scores))
            # detach(): the discriminator update must not reach the generator.
            fake_scores = netD(fake_data.detach()).view(-1)
            loss_fake = criterion(fake_scores, torch.zeros_like(fake_scores))
            (loss_real + loss_fake).backward()
            optimizerD.step()

            # Generator step: try to make the updated discriminator output 1.
            netG.zero_grad()
            gen_scores = netD(fake_data).view(-1)
            loss_g = criterion(gen_scores, torch.ones_like(gen_scores))
            loss_g.backward()
            optimizerG.step()

        print(f"Epoch [{epoch}/{epochs}] completed.")

    return netG, netD

# Example: download, then train
person_name = "Trump"
download_images(person_name, num_images=50)

# Rebuild the dataset after the download so the new class folder is part of
# the data that train_dcgan reads through the global `dataloader`.
dataset = dsets.ImageFolder(root=data_root, transform=transform, is_valid_file=is_valid_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Keep the trained networks at notebook scope so later cells can save them.
netG, netD = train_dcgan(epochs=50)




✅ 数据集类别文件夹: ['default', 'Elon_Musk']
✅ 数据集加载成功！类别: ['Elon_Musk', 'default']
✅ 下载 49 张图片到 ./data/faces/Trump
Epoch [0/50] completed.
Epoch [1/50] completed.
Epoch [2/50] completed.
Epoch [3/50] completed.
Epoch [4/50] completed.
Epoch [5/50] completed.
Epoch [6/50] completed.
Epoch [7/50] completed.
Epoch [8/50] completed.
Epoch [9/50] completed.
Epoch [10/50] completed.
Epoch [11/50] completed.
Epoch [12/50] completed.
Epoch [13/50] completed.
Epoch [14/50] completed.
Epoch [15/50] completed.
Epoch [16/50] completed.
Epoch [17/50] completed.
Epoch [18/50] completed.
Epoch [19/50] completed.
Epoch [20/50] completed.
Epoch [21/50] completed.
Epoch [22/50] completed.
Epoch [23/50] completed.
Epoch [24/50] completed.
Epoch [25/50] completed.
Epoch [26/50] completed.
Epoch [27/50] completed.
Epoch [28/50] completed.
Epoch [29/50] completed.
Epoch [30/50] completed.
Epoch [31/50] completed.
Epoch [32/50] completed.
Epoch [33/50] completed.
Epoch [34/50] completed.
Epoch [35/50] completed.
Ep

In [None]:
# Persist the generator's state_dict to generator1.pth.
# `netG` must already be bound at notebook scope when this cell runs.
torch.save(netG.state_dict(), "generator1.pth")  # saved at training time


In [None]:
import torch
import torchvision.utils as vutils

# Load the trained generator weights.
netG = Generator(100, 64, 3).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# weights_only=True restricts unpickling to tensor data, which is all a
# state_dict needs.
state_dict = torch.load("generator1.pth", map_location=device, weights_only=True)
netG.load_state_dict(state_dict)

# Generate fake faces from 64 latent samples; no gradients are needed here.
noise = torch.randn(64, 100, 1, 1, device=device)
with torch.no_grad():
    fake_images = netG(noise).cpu()

# Save the batch as one image grid, rescaled to a displayable range.
vutils.save_image(fake_images, "generated_faces2.png", normalize=True)


  netG.load_state_dict(torch.load("generator1.pth"))  # 假设你保存了模型
