In [1]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
from d2l import trainer

# 定义预训练模型的权重

In [2]:
# Pick the default pretrained ResNet-50 weights by referencing the weights
# enum directly rather than looking it up by model name.
pretrain_model = torchvision.models.ResNet50_Weights.DEFAULT

# 创建Flower数据集

In [3]:
# Use the preprocessing transforms bundled with the pretrained weights and
# apply them to every image loaded from the flowers directory.
transform = pretrain_model.transforms()
flower_dataset = torchvision.datasets.ImageFolder(
    root="../data/flowers/",
    transform=transform,
)

In [4]:
# Display the class labels of the ImageFolder dataset (shown as the cell output).
flower_dataset.classes

['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

In [5]:
# Fraction of the images held out for validation; the remainder is used for training.
validation_rate = 0.2
# Seed the split so the train/validation partition is identical on every run
# (otherwise a kernel restart silently reshuffles which images are held out).
split_generator = torch.Generator().manual_seed(42)
tr_dataset, va_dataset = torch.utils.data.random_split(
    flower_dataset,
    [1 - validation_rate, validation_rate],
    generator=split_generator,
)
print(
    f"training dataset size: {len(tr_dataset)}, validation dataset size: {len(va_dataset)}"
)
# Only the training batches are shuffled; the validation loader keeps a fixed order.
tr_dataloader = torch.utils.data.DataLoader(tr_dataset, batch_size=32, shuffle=True)
va_dataloader = torch.utils.data.DataLoader(va_dataset, batch_size=32, shuffle=False)

training dataset size: 3454, validation dataset size: 863


# 模型微调

In [6]:
# Build ResNet-50 initialised with the pretrained weights selected earlier.
model = torchvision.models.resnet50(weights=pretrain_model)

# Replace the pretrained fully connected layer with a new one sized for the
# flower classes, and re-initialise its weights.
model.fc = nn.Linear(model.fc.in_features, len(flower_dataset.classes))
nn.init.xavier_uniform_(model.fc.weight)

# The pretrained parameters use a learning rate about one order of magnitude
# lower than the newly initialised head (see the parameter groups below).
learning_rate = 1e-3
epoch = 10
# NOTE: the device list below assumes at least `num_device` CUDA devices exist.
num_device = 2
devices = [torch.device(f"cuda:{i}") for i in range(num_device)]

# Wrap the model for multi-GPU data parallelism and move it to the first device.
parallel_model = nn.DataParallel(model, device_ids=devices)
parallel_model.to(device=devices[0])
loss = nn.CrossEntropyLoss()

# Parameter group 1: every parameter except the new fully connected head.
params_feature = [
    params
    for name, params in model.named_parameters()
    if name not in ["fc.weight", "fc.bias"]
]
# Different learning rates per group: the backbone uses the base rate, the new
# head uses a rate 10x higher.
optimizer = torch.optim.SGD(
    [
        {"params": params_feature},
        {"params": model.fc.parameters(), "lr": learning_rate * 10},
    ],
    lr=learning_rate,
    weight_decay=0.001,
)

writer = SummaryWriter("logs/flower_finetune")

try:
    trainer.nn_train(
        tr_dataloader,
        va_dataloader,
        parallel_model,
        loss,
        optimizer,
        devices=devices,
        epoch=epoch,
        writer=writer,
    )
finally:
    # Flush and close the TensorBoard event file even if training raises;
    # close() is a no-op if the writer has already been closed.
    writer.close()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/yangyansheng/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:17<00:00, 5.71MB/s]


epoch 10: train_loss: 0.3028, train_acc:  0.9102, test_acc:  0.9340
