# In [None]:  (Jupyter cell prompt left over from the notebook export)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from model import ResNet18, MoCo
from data_preparation import download_mini_imagenet, get_mini_imagenet_data, get_cifar100_data
from training_finetuning import train_moco, finetune_and_validate


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "./data"  # directory handed to every dataset helper below
log_dir = "./runs/experiment"  # parent directory of the TensorBoard runs

# Download and preprocess the data.
download_mini_imagenet(data_dir)
# No data loaders are built here: the hyperparameter search constructs its own
# loaders for every batch size it tries, and the final training run rebuilds
# them again for the selected batch size, so loaders created at this point
# were always discarded before being used.

# Hyperparameter grid. The same learning rate is used for both the MoCo
# optimizer and the ResNet-18 fine-tuning optimizer of a trial.
learning_rates = [0.001, 0.01]
batch_sizes = [64, 128]
best_params = None
# Start below any possible accuracy so that the first evaluated configuration
# is always recorded. Starting at 0 left best_params as None whenever every
# run scored exactly 0, which then broke the unpacking of best_params later.
best_accuracy = float("-inf")

# One loss instance is enough for the whole search; the original code built a
# fresh nn.CrossEntropyLoss() for every single call.
criterion = nn.CrossEntropyLoss()

# Hyperparameter tuning: exhaustive search over (lr, batch_size).
for lr in learning_rates:
    for batch_size in batch_sizes:
        # Rebuild the data loaders with this trial's batch size.
        mini_imagenet_loader = get_mini_imagenet_data(data_dir, batch_size=batch_size)
        cifar100_train_loader, cifar100_val_loader = get_cifar100_data(data_dir, batch_size=batch_size)

        # Fresh model and optimizer so trials do not share weights.
        moco_model = MoCo(base_encoder=ResNet18).to(device)
        moco_optimizer = optim.Adam(moco_model.parameters(), lr=lr)

        # Train the MoCo model (epochs 1..4, kept short for the search).
        for epoch in range(1, 5):
            train_moco(
                moco_model, mini_imagenet_loader, criterion, moco_optimizer,
                device, epoch, log_interval=10, writer=None,
            )

        # Fine-tune and validate the ResNet-18 model.
        # NOTE(review): this model is constructed independently of moco_model;
        # nothing visible here copies the MoCo-trained weights into it, so the
        # MoCo training above may not influence the accuracy used for
        # selection. Confirm what ResNet18(pretrained=True) actually loads.
        resnet18_model = ResNet18(pretrained=True).to(device)
        resnet18_optimizer = optim.Adam(resnet18_model.parameters(), lr=lr)

        for epoch in range(1, 5):  # epochs 1..4, kept short for the search
            finetune_and_validate(
                resnet18_model, cifar100_train_loader, cifar100_val_loader,
                criterion, resnet18_optimizer, device, epoch,
                log_interval=10, writer=None, phase='finetune',
            )
            accuracy = finetune_and_validate(
                resnet18_model, cifar100_train_loader, cifar100_val_loader,
                criterion, resnet18_optimizer, device, epoch,
                log_interval=10, writer=None, phase='validate',
            )

            # Validation runs after every epoch, so a configuration is ranked
            # by its best epoch rather than by its final one. The strict
            # comparison keeps the earliest configuration on ties.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_params = (lr, batch_size)

# Retrain with the best hyperparameters and log the runs to TensorBoard.
if best_params is None:
    # Without this check the unpacking below fails with an opaque
    # "cannot unpack non-iterable NoneType object" TypeError.
    raise RuntimeError(
        "Hyperparameter search did not select any configuration; "
        "no validation accuracy exceeded the initial best_accuracy."
    )
best_lr, best_batch_size = best_params
mini_imagenet_loader = get_mini_imagenet_data(data_dir, batch_size=best_batch_size)
cifar100_train_loader, cifar100_val_loader = get_cifar100_data(data_dir, batch_size=best_batch_size)

# One loss instance is shared by every call below instead of building a new
# nn.CrossEntropyLoss() per call.
criterion = nn.CrossEntropyLoss()

moco_model = MoCo(base_encoder=ResNet18).to(device)
moco_optimizer = optim.Adam(moco_model.parameters(), lr=best_lr)
moco_writer = SummaryWriter(log_dir + "/moco")

# Train the MoCo model for the full run (epochs 1..9). The writer is closed in
# a finally clause so the event file is closed even if training raises.
try:
    for epoch in range(1, 10):
        train_moco(
            moco_model, mini_imagenet_loader, criterion, moco_optimizer,
            device, epoch, log_interval=10, writer=moco_writer,
        )
finally:
    moco_writer.close()

# NOTE(review): this model is constructed independently of moco_model; nothing
# visible here copies the MoCo-trained weights into it. Confirm what
# ResNet18(pretrained=True) actually loads.
resnet18_model = ResNet18(pretrained=True).to(device)
resnet18_optimizer = optim.Adam(resnet18_model.parameters(), lr=best_lr)
resnet_writer = SummaryWriter(log_dir + "/resnet18")

# Fine-tune and validate the ResNet-18 model for the full run (epochs 1..9),
# running one 'finetune' call followed by one 'validate' call per epoch.
try:
    for epoch in range(1, 10):
        finetune_and_validate(
            resnet18_model, cifar100_train_loader, cifar100_val_loader,
            criterion, resnet18_optimizer, device, epoch,
            log_interval=10, writer=resnet_writer, phase='finetune',
        )
        finetune_and_validate(
            resnet18_model, cifar100_train_loader, cifar100_val_loader,
            criterion, resnet18_optimizer, device, epoch,
            log_interval=10, writer=resnet_writer, phase='validate',
        )
finally:
    resnet_writer.close()
