<a href="https://colab.research.google.com/github/FionaXuDesign/FionaXuDesign/blob/main/Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Print the Python interpreter version of this runtime (IPython `!` shell escape).
!python --version

Python 3.11.11


In [None]:
# Quote every requirement: an unquoted `>=` is parsed by the shell as an output
# redirect, which silently drops the version pins and creates files like "=2.0.0".
!pip install "torch>=2.0.0" "torchvision>=0.15.0" "numpy>=1.21.6" "Pillow>=9.4.0"

In [None]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

# ---------------------- 数据预处理 ----------------------
class FFT:
    """Turn an image into a 2-channel tensor: grayscale + log-magnitude spectrum.

    The image is converted to grayscale and resized to
    ``image_size x image_size``. The result is a float32 tensor of shape
    ``(2, image_size, image_size)``:

    * channel 0: grayscale intensities, min-max scaled to roughly [0, 1];
    * channel 1: ``log1p(|fftshift(fft2(gray))|)`` (zero frequency moved to the
      centre), min-max scaled the same way.

    Inputs that are not PIL images are first passed through ``ToPILImage``.
    On any failure the error is printed and ``None`` is returned so callers
    can discard the sample.
    """

    def __init__(self, image_size):
        self.image_size = image_size

    @staticmethod
    def _minmax(arr):
        """Scale ``arr`` into [0, 1); the 1e-8 avoids division by zero for constant input."""
        lowest = arr.min()
        return (arr - lowest) / (arr.max() - lowest + 1e-8)

    def __call__(self, img):
        try:
            pil_img = img if isinstance(img, Image.Image) else transforms.ToPILImage()(img)
            side = (self.image_size, self.image_size)
            gray = np.array(pil_img.convert('L').resize(side), dtype=np.float32)

            centred_spectrum = np.fft.fftshift(np.fft.fft2(gray))
            log_magnitude = np.log1p(np.abs(centred_spectrum))

            channels = np.stack([self._minmax(gray), self._minmax(log_magnitude)], axis=0)
            return torch.tensor(channels, dtype=torch.float32)
        except Exception as e:
            print(f"error: {str(e)}")
            return None

# ---------------------- 损坏数据集处理 ----------------------
class SafeImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that returns ``None`` instead of raising for unreadable samples.

    Workaround for a dataset that contains corrupt files; once the dataset is
    fixed this class can be replaced by plain ``ImageFolder``.

    ``__getitem__`` returns ``(image, target)`` like ``ImageFolder``, applying
    both ``transform`` and ``target_transform``. If loading or transforming
    fails, the offending path and error are printed and ``None`` is returned,
    so the sample can be filtered out (for example by ``FilteredDataset``).
    An out-of-range ``index`` still raises ``IndexError``.
    """

    def __getitem__(self, index):
        # Looked up outside the try so an out-of-range index raises IndexError
        # as the Dataset protocol expects, instead of being reported as a bad file.
        path, target = self.samples[index]
        try:
            img = self.loader(path)
            if self.transform is not None:
                img = self.transform(img)
            # Honour target_transform as the ImageFolder base class does.
            if self.target_transform is not None:
                target = self.target_transform(target)
        except Exception as e:  # best effort: any decode/transform failure marks the sample invalid
            print(f"error file: {path} - {str(e)}")
            return None
        return img, target

class FilteredDataset(Dataset):
    """Expose only the samples of another dataset that were fetched successfully.

    Every item of ``original_dataset`` is fetched once at construction time,
    including whatever transform that dataset applies. An index is kept when
    the item is not ``None`` and its first element (the input) is not ``None``.

    Args:
        original_dataset: Indexable object supporting ``len()``. Its items are
            ``None`` or sequences whose element 0 is the input, e.g.
            ``(input, target)`` pairs.

    Attributes:
        valid_indices: Indices into ``original_dataset`` that passed the check.

    Raises:
        RuntimeError: If no item passes the check.
    """

    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

        def _usable(item):
            return not (item is None or item[0] is None)

        self.valid_indices = [
            position
            for position in range(len(original_dataset))
            if _usable(original_dataset[position])
        ]

        if not self.valid_indices:
            raise RuntimeError("empty")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, i):
        source_position = self.valid_indices[i]
        return self.original_dataset[source_position]

# ---------------------- 数据增强 ----------------------
def Build(image_size=128):
    """Create the training and validation preprocessing pipelines.

    Both pipelines end with ``FFT(image_size)``, which converts the image into
    a 2-channel float tensor (normalized grayscale + normalized log-magnitude
    spectrum).

    Training: resize so the shorter side is ``int(1.2 * image_size)``, then a
    random square crop, a random horizontal flip and a random rotation of up
    to +/-15 degrees.

    Validation: the same resize followed by a deterministic centre crop. Both
    splits therefore see the same framing and aspect ratio; only the
    randomness differs.

    No ``transforms.Lambda(lambda x: x)`` steps are used. They are no-ops, and
    a lambda cannot be pickled, which breaks ``DataLoader(num_workers > 0)``
    under the ``spawn`` start method (Windows/macOS).

    Args:
        image_size: Side length of the square crop fed to the model.

    Returns:
        ``(train_transforms, val_transforms)``, both ``transforms.Compose``.
    """
    resize_size = int(image_size * 1.2)

    train_transforms = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        FFT(image_size),
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(image_size),
        FFT(image_size),
    ])

    return train_transforms, val_transforms

# ---------------------- 模型 ----------------------
class CNN(nn.Module):
    """Three-stage convolutional classifier for 2-channel square image tensors.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)``, with channels going 2 -> 32 -> 64 -> 128. The stages are
    followed by ``Dropout2d(0.3)`` and an MLP head (256 hidden units,
    ``Dropout(0.5)``).

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input the model will receive.
            Defaults to 128 (flattened feature size ``128 * 16 * 16``).

    Raises:
        ValueError: If ``image_size`` is below 8, because the three pooling
            stages would leave no spatial positions.
    """

    def __init__(self, num_classes, image_size=128):
        super().__init__()
        if image_size < 8:
            raise ValueError(f"image_size must be >= 8, got {image_size}")

        self.conv_layers = nn.Sequential(
            nn.Conv2d(2, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3)
        )

        # Padded 3x3 convs preserve the spatial size. Each MaxPool2d(2) floors it
        # by half, so after three pools each side is image_size // 8.
        feature_side = image_size // 8
        self.fc = nn.Sequential(
            nn.Linear(128 * feature_side * feature_side, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Map ``(batch, 2, image_size, image_size)`` input to ``(batch, num_classes)`` logits."""
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# ---------------------- 训练流程 ----------------------
def _split_indices(valid_indices, train_fraction, seed):
    """Shuffle ``valid_indices`` with a fixed seed and split them into ``(train, val)`` lists."""
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(valid_indices), generator=generator).tolist()
    n_train = int(train_fraction * len(valid_indices))
    train_idx = [valid_indices[i] for i in order[:n_train]]
    val_idx = [valid_indices[i] for i in order[n_train:]]
    return train_idx, val_idx


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, otherwise evaluate.

    Returns:
        ``(mean_loss, accuracy)`` averaged over ``len(loader.dataset)`` samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_correct = 0.0, 0
    # Gradients are only tracked in training mode.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns the batch mean; re-weight by batch size for a per-sample mean.
            total_loss += loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = max(len(loader.dataset), 1)  # guard against an empty split
    return total_loss / n_samples, n_correct / n_samples


def main():
    """Train the FFT-feature CNN on an image-folder dataset and keep the best weights.

    Images are read from ``DATA_DIR`` (one sub-directory per class). Samples
    that fail to load are dropped, and the rest are split 80/20 into train and
    validation with a fixed seed. Training uses the augmenting pipeline and
    validation the deterministic one. The weights with the best validation
    accuracy are saved to ``best_model.pth``. Training stops early after
    ``PATIENCE`` epochs without improvement.

    Raises:
        FileNotFoundError: If ``DATA_DIR`` does not exist.
        ValueError: If too few valid samples remain to form both splits.
    """
    from torch.utils.data import Subset

    # Configuration
    DATA_DIR = "dataset/513/"  # replace with your own dataset path
    IMAGE_SIZE = 128  # must match CNN's default input size (16x16 feature map)
    BATCH_SIZE = 32
    EPOCHS = 10  # if 50 epochs are too slow, 10 or 20 are enough
    PATIENCE = 7  # early-stopping patience, in epochs
    TRAIN_FRACTION = 0.8  # 8:2 train/validation split
    SEED = 42  # makes the train/validation split reproducible
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.isdir(DATA_DIR):
        raise FileNotFoundError(
            f"Dataset directory not found: {DATA_DIR!r} (set DATA_DIR to your dataset path)")

    train_trans, val_trans = Build(IMAGE_SIZE)

    # Validity is probed with the deterministic validation pipeline. The
    # surviving indices are shared by two views of the same folder, so the
    # training split is augmented and the validation split is not.
    val_base = SafeImageFolder(DATA_DIR, transform=val_trans)
    valid_indices = FilteredDataset(val_base).valid_indices
    train_idx, val_idx = _split_indices(valid_indices, TRAIN_FRACTION, SEED)
    if not train_idx or not val_idx:
        raise ValueError(
            f"Need at least one training and one validation sample, "
            f"got {len(train_idx)} train / {len(val_idx)} val")

    train_set = Subset(SafeImageFolder(DATA_DIR, transform=train_trans), train_idx)
    val_set = Subset(val_base, val_idx)

    train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, BATCH_SIZE, shuffle=False)

    model = CNN(num_classes=len(val_base.classes)).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # -inf guarantees the first epoch writes a checkpoint even at 0% accuracy.
    best_val_acc = float("-inf")
    epochs_without_improvement = 0

    for epoch in range(EPOCHS):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, DEVICE, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, DEVICE)

        # Report before the early-stopping check so the final epoch is logged too.
        print(f"Epoch {epoch+1}/{EPOCHS} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= PATIENCE:
                print("early stop")
                break

# Script entry point. Inside a Jupyter/Colab cell `__name__` is also "__main__",
# so executing this cell starts training immediately.
if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: 'dataset/513/'