In [8]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

# Start the wall-clock timer; the elapsed time is reported after training.
start_time = time.time()

# Select the compute device. Apple-silicon MPS is preferred (the original
# target of this script), then CUDA, and finally the CPU so the script still
# runs on machines without an accelerator instead of failing on first use.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the FPN module
class FPN(nn.Module):
    """Pooled 1x1-convolution head on top of a feature-extracting backbone.

    ``forward`` runs ``base_model.extract_features``, average-pools the result
    to a 1x1 spatial size, then applies a 1x1 convolution followed by a ReLU.

    Args:
        base_model: Backbone object exposing ``extract_features(x)``. Its
            output is passed to ``nn.Conv2d``, so it must carry
            ``in_channels`` channels.
        in_channels: Channel count of the backbone features. Defaults to
            1280, the value this head was previously hard-coded to.
        out_channels: Number of channels produced by the 1x1 convolution.
            Defaults to 256, the previously hard-coded value.

    NOTE(review): the output is non-negative because of the final ReLU, yet
    the training code in this file feeds it to ``CrossEntropyLoss`` as class
    scores -- confirm that this is intended.
    """

    def __init__(self, base_model, in_channels=1280, out_channels=256):
        super().__init__()
        self.base_model = base_model
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the pooled, projected and rectified backbone features.

        Because of the (1, 1) adaptive pooling, the last two dimensions of
        the result are both 1.
        """
        x = self.base_model.extract_features(x)
        x = self.pool(x)
        x = self.conv(x)
        x = self.relu(x)
        return x

# Data preprocessing and loading.
# Every image is resized to 224x224, converted to a tensor, and normalized
# per channel with the mean / std listed below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_mean, std=_channel_std),
    ]
)

# Load the combined dataset from the ./MO_106/ image folder.
dataset = ImageFolder(transform=data_transforms, root='./MO_106/')

# Split the dataset: 70% training, 20% validation, and whatever remains
# (after integer truncation) goes to the test split.
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.2 * total_size)
test_size = total_size - (train_size + val_size)

# A fixed generator seed keeps the split identical from run to run.
_split_lengths = [train_size, val_size, test_size]
_split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, _split_lengths, generator=_split_rng
)

# Training hyper-parameters.
num_epochs = 10
batch_size = 64
lr = 0.0001

# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader, test_dataloader = (
    DataLoader(held_out_split, shuffle=False, batch_size=batch_size)
    for held_out_split in (val_dataset, test_dataset)
)

# Load the pretrained EfficientNet-B0 as the base model and move it to the
# selected device.
base_model = EfficientNet.from_pretrained('efficientnet-b0').to(device)

# Build the FPN model around the base model (also placed on the device).
fpn_model = FPN(base_model).to(device)

# Loss function and optimizer; Adam updates every parameter of fpn_model,
# which includes the wrapped base model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=fpn_model.parameters(), lr=lr)

# Per-epoch metric buffers: one zero-initialised slot per epoch for the
# training and validation loss / accuracy. Each name gets its own array.
train_loss, train_accuracy, val_loss, val_accuracy = (
    np.zeros(num_epochs) for _ in range(4)
)

# Train for `num_epochs` epochs and validate after each one. The per-epoch
# results are written into train_loss / train_accuracy / val_loss /
# val_accuracy at index `epoch`.
for epoch in range(num_epochs):
    # ---- Training phase ----
    fpn_model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # Flatten the model output to (batch, features) so it can be used as
        # per-class scores by the loss.
        outputs = fpn_model(images)
        outputs = outputs.view(outputs.size(0), -1)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `criterion` returns the mean loss of the batch. Weight it by the
        # batch size so the epoch figure is a true per-sample average even
        # when the last batch is smaller than the others.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count

        _, predicted = torch.max(outputs.detach(), 1)
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += batch_count

    # ---- Validation phase ----
    fpn_model.eval()
    val_loss_epoch = 0.0
    correct_predictions_val = 0
    total_predictions_val = 0

    with torch.no_grad():
        for images, labels in val_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = fpn_model(images)
            outputs = outputs.view(outputs.size(0), -1)
            loss = criterion(outputs, labels)

            batch_count = labels.size(0)
            val_loss_epoch += loss.item() * batch_count

            _, predicted = torch.max(outputs, 1)
            correct_predictions_val += (predicted == labels).sum().item()
            total_predictions_val += batch_count

    # Calculate average epoch loss and accuracy. The sums above are
    # per-sample, so divide by the number of samples seen; max(..., 1)
    # avoids a ZeroDivisionError when a split is empty.
    train_loss[epoch] = running_loss / max(total_predictions, 1)
    train_accuracy[epoch] = correct_predictions / max(total_predictions, 1)
    val_loss[epoch] = val_loss_epoch / max(total_predictions_val, 1)
    val_accuracy[epoch] = correct_predictions_val / max(total_predictions_val, 1)

    # Print training and validation statistics
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss[epoch]:.4f}, Train Accuracy: {train_accuracy[epoch]:.4f}, Val Loss: {val_loss[epoch]:.4f}, Val Accuracy: {val_accuracy[epoch]:.4f}')

# Stop the timer. `start_time` was taken at the top of the script, so the
# figure also covers data loading and model construction.
end_time = time.time()
total_time = end_time - start_time
print(f"Training took {total_time:.2f} seconds.")

# Bundle the trained weights with the per-epoch curves and the elapsed time
# so that they can be saved together.
model_metrics = dict(
    model_state_dict=fpn_model.state_dict(),
    train_loss=train_loss,
    train_accuracy=train_accuracy,
    val_loss=val_loss,
    val_accuracy=val_accuracy,
    total_time=total_time,
)

# Test phase: evaluate the trained model on the held-out test split.
fpn_model.eval()
test_running_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # Flatten the model output to (batch, features) before the loss.
        outputs = fpn_model(images)
        outputs = outputs.view(outputs.size(0), -1)
        loss = criterion(outputs, labels)

        # `criterion` returns the batch mean; weight by batch size so a
        # smaller final batch does not skew the average.
        batch_count = labels.size(0)
        test_running_loss += loss.item() * batch_count

        _, predicted = torch.max(outputs, 1)
        total += batch_count
        correct += (predicted == labels).sum().item()

# Per-sample average loss and accuracy in percent; max(..., 1) avoids a
# ZeroDivisionError when the test split is empty.
test_loss = test_running_loss / max(total, 1)
test_accuracy = 100 * correct / max(total, 1)

# Report the result -- previously it was computed and then discarded.
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

# Save the model parameters and metrics to a file
ResultPath = "./results/"
results_path = ResultPath + f"efficientnetb0_fpn_epoch{num_epochs}_lr{lr}_bs{batch_size}.pt"

# torch.save does not create missing parent directories; make sure the
# output folder exists so a finished training run is not lost at save time.
os.makedirs(ResultPath, exist_ok=True)
torch.save(model_metrics, results_path)



Loaded pretrained weights for efficientnet-b0
Epoch [1/1], Train Loss: 3.5295, Train Accuracy: 32.73%, Val Loss: 2.8787, Val Accuracy: 45.19%
Training finished.
Test Loss: 2.8215, Test Accuracy: 45.67%
