In [1]:
import torch
from torchvision import datasets, transforms, models
import time

# Define the modified AlexNet model
class ModifiedAlexNet(torch.nn.Module):
    """AlexNet variant for single-channel images and 10 output classes.

    The torchvision AlexNet is built with randomly initialised weights, its
    first convolution is swapped for one that accepts 1 input channel, and its
    classifier is replaced by a Linear/ReLU head ending in 10 outputs (the
    replacement head contains no dropout layers).

    ``forward`` takes a batch shaped ``(N, 1, H, W)`` and returns
    log-probabilities shaped ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Build AlexNet from its defaults (no pretrained weights, 1000 classes).
        # Not passing `pretrained=` keeps this call valid on torchvision
        # releases that deprecate that keyword in favour of `weights=`.
        self.alexnet = models.alexnet()
        # Replace the first convolution so the network accepts 1-channel input
        # (same kernel/stride/padding as the layer it replaces).
        self.alexnet.features[0] = torch.nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
        self.alexnet.classifier = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 * 6, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(4096, 10)  # 10 output classes (MNIST digits)
        )

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1-channel images."""
        x = self.alexnet.features(x)
        # AlexNet's adaptive pooling always yields a 6x6 map, which is what the
        # first Linear layer expects. For 224x224 inputs the feature map is
        # already 6x6, so this is an identity there; for other input sizes it
        # prevents a shape error when flattening.
        x = self.alexnet.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.alexnet.classifier(x)
        return torch.nn.functional.log_softmax(x, dim=1)

# Training function adapted to the modified AlexNet model
def train_model(workers, epochs, batch_size, target_loss, threshold):
    """Train a fresh ModifiedAlexNet on MNIST and time the training loop.

    Training stops early as soon as the mean loss of an epoch is at or below
    ``target_loss``, or differs from the previous epoch's mean loss by less
    than ``threshold``.

    Args:
        workers: Value passed as ``num_workers`` to the DataLoader.
        epochs: Maximum number of passes over the training set.
        batch_size: Mini-batch size; a trailing incomplete batch is dropped.
        target_loss: Mean epoch loss at which training is considered done.
        threshold: Smallest epoch-to-epoch change in mean loss that still
            counts as progress; a smaller change stops training.

    Returns:
        Tuple ``(avg_loss, total_time)``: the mean loss of the last completed
        epoch and the seconds spent in the training loop (dataset and model
        construction are not included). If ``epochs`` is zero or negative no
        epoch runs and ``avg_loss`` is ``float('inf')``.

    Raises:
        ValueError: If ``batch_size`` is so large that no full batch exists.
    """
    # Seed before the model is built so every configuration starts from the
    # same initial weights.
    torch.manual_seed(123)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # AlexNet expects 224x224 inputs
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)

    # With drop_last=True an oversized batch_size leaves zero batches, which
    # would otherwise surface later as a ZeroDivisionError.
    batches_per_epoch = len(train_loader)
    if batches_per_epoch == 0:
        raise ValueError(
            f"batch_size={batch_size} is larger than the training set; no full batch available"
        )

    model = ModifiedAlexNet()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The model already returns log-probabilities; CrossEntropyLoss applies a
    # log-softmax internally, which leaves log-probabilities unchanged, so the
    # resulting loss is the ordinary negative log-likelihood.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments during a long run.
    start_time = time.perf_counter()
    avg_loss = float('inf')   # defined even when the epoch loop never runs
    last_loss = float('inf')
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / batches_per_epoch
        print(f"Epoch {epoch+1} - Loss: {avg_loss}")

        # Stop on reaching the target or when the loss has plateaued.
        if avg_loss <= target_loss or abs(last_loss - avg_loss) < threshold:
            break

        last_loss = avg_loss

    total_time = time.perf_counter() - start_time

    return avg_loss, total_time

# Hyperparameter grid search function
def grid_search(target_loss=0.0001, threshold=0.00001,
                worker_counts=range(1, 9), batch_sizes=(32, 64, 128), epochs=100):
    """Find the fastest (workers, batch size) combination that reaches a target loss.

    Every combination is trained with ``train_model``; a run only qualifies if
    the loss it returns is at or below ``target_loss``. Among qualifying runs
    the one with the shortest reported training time wins. Results are printed;
    nothing is returned.

    Args:
        target_loss: Loss a run must reach to be considered.
        threshold: Plateau threshold forwarded to ``train_model``.
        worker_counts: Iterable of DataLoader worker counts to try.
        batch_sizes: Iterable of batch sizes to try for each worker count.
        epochs: Maximum number of epochs for each run.
    """
    # Materialise once: batch_sizes is iterated again for every worker count,
    # so a one-shot iterator would be exhausted after the first pass.
    batch_sizes = tuple(batch_sizes)

    best_time = float('inf')
    best_hyperparams = None

    for workers in worker_counts:
        for batch_size in batch_sizes:
            print(f"Workers: {workers}, Epochs: {epochs}, Batch Size: {batch_size}")

            loss, time_taken = train_model(workers, epochs, batch_size, target_loss, threshold)

            # Runs that stopped on a plateau or ran out of epochs without
            # reaching the target are excluded from the comparison.
            if loss <= target_loss and time_taken < best_time:
                best_time = time_taken
                best_hyperparams = {
                    'Workers': workers,
                    'Epochs': epochs,
                    'Batch Size': batch_size
                }

            print(f"Loss: {loss}, Time Taken: {time_taken} seconds\n")

    if best_hyperparams is None:
        # Avoid reporting an empty dict together with a time of "inf seconds".
        print(f"No configuration reached the target loss {target_loss}.")
        return

    print("Best Hyperparameters:")
    print(best_hyperparams)
    print(f"Time Taken to Reach Target Loss {target_loss}: {best_time} seconds")

# Run the hyperparameter search; a run only counts if its loss reaches 0.0001.
grid_search(target_loss=0.0001)

Workers: 1, Epochs: 100, Batch Size: 32




Epoch 1 - Loss: 0.6164121730191323
Epoch 2 - Loss: 0.06207950351567318
Epoch 3 - Loss: 0.041385338415042494
Epoch 4 - Loss: 0.030764536195838202
Epoch 5 - Loss: 0.02393719928660236
Epoch 6 - Loss: 0.018890861949896982
Epoch 7 - Loss: 0.014299899279247135
Epoch 8 - Loss: 0.012308043715290842
Epoch 9 - Loss: 0.009279171468408821
Epoch 10 - Loss: 0.007601210348690438
Epoch 11 - Loss: 0.006208370193472062
Epoch 12 - Loss: 0.004149542587747495
Epoch 13 - Loss: 0.004549469416832699
Epoch 14 - Loss: 0.002552504439832448
Epoch 15 - Loss: 0.001315340267587133
Epoch 16 - Loss: 0.001518116956266528
Epoch 17 - Loss: 0.002639307639389749
Epoch 18 - Loss: 0.0033121159193518356
Epoch 19 - Loss: 0.003224943468127361
Epoch 20 - Loss: 0.0025440578684798264
Epoch 21 - Loss: 0.0038233062551118876
Epoch 22 - Loss: 0.0016017602504160323
Epoch 23 - Loss: 0.0002771036236012988
Epoch 24 - Loss: 0.00014229738826560922
Epoch 25 - Loss: 8.309920018693688e-05
Loss: 8.309920018693688e-05, Time Taken: 2340.651429176