In [1]:
#------------------------------------------------#
#   进行训练前需要利用这个文件生成cls_train.txt
#------------------------------------------------#
import os

In [2]:
def write_cls_annotations(datasets_path, output_path='cls_train.txt'):
    """Write one ``<class_id>;<absolute image path>`` line per file found under
    each class sub-directory of ``datasets_path``.

    Args:
        datasets_path: Directory whose immediate sub-directories are the classes.
        output_path:   Annotation file to (over)write.

    Returns:
        The number of class directories that were indexed.

    Raises:
        OSError: if ``datasets_path`` cannot be listed or ``output_path``
            cannot be written.
    """
    root = os.path.abspath(datasets_path)

    # Drop non-directory entries (stray files such as .DS_Store) *before*
    # numbering, so class ids are contiguous 0..N-1. Enumerating the raw
    # listing would burn an id on every skipped entry.
    class_names = sorted(
        name for name in os.listdir(datasets_path)
        if os.path.isdir(os.path.join(datasets_path, name))
    )

    # `with` guarantees the file is flushed and closed even if a listing fails.
    with open(output_path, 'w') as list_file:
        for cls_id, type_name in enumerate(class_names):
            # os.listdir returns entries in arbitrary order; sort them so the
            # generated file is identical from run to run.
            for photo_name in sorted(os.listdir(os.path.join(datasets_path, type_name))):
                list_file.write("%d;%s\n" % (cls_id, os.path.join(root, type_name, photo_name)))

    return len(class_names)


if __name__ == "__main__":
    #---------------------#
    #   Path of the training set
    #---------------------#
    datasets_path   = "/root/autodl-nas/data/facenet/datasets"

    write_cls_annotations(datasets_path, 'cls_train.txt')

In [None]:
# Training criterion.
loss            = triplet_loss()

#----------------------#
#   Loss logging: only the rank-0 process creates a LossHistory,
#   every other rank keeps None.
#----------------------#
loss_history    = LossHistory(save_dir, model, input_shape=input_shape) if local_rank == 0 else None

#------------------------------------------------------------------#
#   torch 1.2 does not support amp; torch 1.7.1 or newer is
#   recommended for correct fp16 training. On torch 1.2 the import
#   below is therefore reported as "could not be resolve".
#------------------------------------------------------------------#
scaler = None
if fp16:
    # Imported lazily so that the module is only required when fp16 is on.
    from torch.cuda.amp import GradScaler
    scaler = GradScaler()

# Put the network into training mode; model_train is the handle that gets
# wrapped for (multi-)GPU execution below.
model_train     = model.train()

#----------------------------#
#   Synchronised BatchNorm across GPUs
#----------------------------#
if sync_bn:
    if ngpus_per_node > 1 and distributed:
        model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
    else:
        print("Sync_bn is not support in one gpu or not distributed.")

if Cuda and distributed:
    #----------------------------#
    #   Multi-GPU distributed execution
    #----------------------------#
    model_train = model_train.cuda(local_rank)
    model_train = torch.nn.parallel.DistributedDataParallel(
        model_train,
        device_ids=[local_rank],
        find_unused_parameters=True,
    )
elif Cuda:
    # Non-distributed GPU execution: wrap in DataParallel.
    model_train = torch.nn.DataParallel(model)
    cudnn.benchmark = True
    model_train = model_train.cuda()

In [None]:
#---------------------------------#
#   LFW evaluation loader (only built when lfw_eval_flag is set)
#---------------------------------#
if lfw_eval_flag:
    LFW_loader = torch.utils.data.DataLoader(
        LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, image_size=input_shape),
        batch_size=32, shuffle=False)
else:
    LFW_loader = None

#-------------------------------------------------------#
#   1% of the annotation lines are held out for validation,
#   the remaining 99% are used for training.
#-------------------------------------------------------#
val_split = 0.01
with open(annotation_path,"r") as f:
    lines = f.readlines()

# Shuffle with a private, fixed-seed generator. RandomState(10101) yields the
# same permutation as `np.random.seed(10101); np.random.shuffle(lines)`, but it
# does not touch the global NumPy RNG, so a seed set earlier in the notebook is
# not discarded by a trailing `np.random.seed(None)`.
np.random.RandomState(10101).shuffle(lines)

num_val = int(len(lines)*val_split)
num_train = len(lines) - num_val

# Print the effective configuration before training starts.
show_config(
    num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape,
    Init_Epoch = Init_Epoch, Epoch = Epoch, batch_size = batch_size,
    Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type,
    save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
)

# The loaders below are built with batch_size // 3, so the configured
# batch size has to be divisible by 3.
if batch_size % 3 != 0:
    raise ValueError("Batch_size must be the multiple of 3.")

#-------------------------------------------------------------------#
#   Scale the learning rates with the batch size (relative to the
#   nominal batch size nbs) and clamp them to per-optimizer limits.
#-------------------------------------------------------------------#
nbs             = 64
lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
lr_limit_min    = 3e-4 if optimizer_type == 'adam' else 5e-4
Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)

#---------------------------------------#
#   Select the optimizer by optimizer_type.
#   The builders are lazy so that only the requested optimizer is
#   constructed; building both eagerly would let SGD's argument
#   validation (nesterov=True) fail even when 'adam' is selected.
#   An unknown optimizer_type still raises KeyError.
#---------------------------------------#
optimizer_builders = {
    'adam'  : lambda: optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay),
    'sgd'   : lambda: optim.SGD(model.parameters(), Init_lr_fit, momentum=momentum, nesterov=True, weight_decay = weight_decay),
}
optimizer = optimizer_builders[optimizer_type]()

#---------------------------------------#
#   Learning-rate schedule function
#---------------------------------------#
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, Epoch)

#---------------------------------------#
#   Number of steps per epoch
#---------------------------------------#
epoch_step      = num_train // batch_size
epoch_step_val  = num_val // batch_size

if epoch_step == 0 or epoch_step_val == 0:
    raise ValueError("数据集过小，无法继续进行训练，请扩充数据集。")

#---------------------------------------#
#   Build the datasets and data loaders.
#---------------------------------------#
train_dataset   = FacenetDataset(input_shape, lines[:num_train], num_classes, random = True)
val_dataset     = FacenetDataset(input_shape, lines[num_train:], num_classes, random = False)

# Keep the per-process batch size in its own variable instead of overwriting
# `batch_size`: re-running this cell must not shrink the batch size again.
if distributed:
    train_sampler       = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
    val_sampler         = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,)
    device_batch_size   = batch_size // ngpus_per_node
    # Shuffling is delegated to the sampler in distributed mode.
    shuffle             = False
else:
    train_sampler       = None
    val_sampler         = None
    device_batch_size   = batch_size
    shuffle             = True

loader_batch_size = device_batch_size // 3

gen             = DataLoader(train_dataset, shuffle=shuffle, batch_size=loader_batch_size, num_workers=num_workers, pin_memory=True,
                        drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
# NOTE(review): the validation loader reuses the training `shuffle` flag, so it
# is shuffled when not distributed - confirm this is intended.
gen_val         = DataLoader(val_dataset, shuffle=shuffle, batch_size=loader_batch_size, num_workers=num_workers, pin_memory=True,
                        drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)

for epoch in range(Init_Epoch, Epoch):
    if distributed:
        # Give the distributed sampler the current epoch before iterating.
        train_sampler.set_epoch(epoch)

    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

    fit_one_epoch(model_train, model, loss_history, loss, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, Cuda, LFW_loader, loader_batch_size, lfw_eval_flag, fp16, scaler, save_period, save_dir, local_rank)

# Only rank 0 owns a LossHistory (and therefore a writer to close).
if local_rank == 0:
    loss_history.writer.close()