In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import MemoryEfficientSwish
from efficientnet_pytorch.fpn import FPN

# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transforms for data preprocessing
transform = transforms.Compose([
    transforms.Resize((456, 456)),  # Resize input images to match EfficientNet-B5 input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize input images
])

# Load the combined dataset
dataset = ImageFolder('./MO_106/', transform=transform)

# Split the dataset into training and validation sets
train_ratio = 0.8  # 80% for training, 20% for validation
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Training loop
num_epochs = 10
batch_size = 10
lr = 0.001

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Initialize the EfficientNet-B5 model
model = EfficientNet.from_pretrained('efficientnet-b5')

# Initialize the FPN model
fpn = FPN(in_channels_list=[40, 112, 320, 1280], out_channels=256, top_blocks=None)

# Replace the classifier with a new one for your specific number of classes
num_classes = len(dataset.classes)
model._fc = nn.Linear(fpn.out_channels, num_classes)

# Combine the EfficientNet backbone and the FPN
model.backbone = model
model.fpn = fpn

# Move the model to the device
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

        _, predicted = torch.max(outputs.data, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)

    # Calculate average epoch loss and accuracy
    running_loss /= len(train_dataset)
    correct_predictions /= total_predictions

    # Print training statistics
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss: {running_loss:.4f}, Train Accuracy: {correct_predictions:.4f}')


ModuleNotFoundError: No module named 'efficientnet_pytorch.fpn'