# 1. Baseline

## Libraries

In [None]:
import os 
from pathlib import Path
import tqdm 
from easydict import EasyDict
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import transforms 

from data.ms1m import get_train_loader
from data.lfw import LFW
from backbone.arcfacenet import SEResNet_IR
from margin.ArcMarginProduct import ArcMarginProduct

from util.utils import save_checkpoint ,test


## Configuration

In [None]:
# Experiment configuration shared by the data-loader, model and training cells.
conf = EasyDict()

# Dataset locations.
conf.train_root = "./dataset/MS1M"
# NOTE(review): this path is absolute while the other dataset paths are
# relative ("./dataset/..."); confirm that is intentional.
conf.lfw_root = "/dataset/lfw_aligned_112"
# The LFW dataset cell reads `conf.lfw_test_root`; expose the same directory
# under that key so the lookup does not raise AttributeError.
conf.lfw_test_root = conf.lfw_root
conf.lfw_file_list = "./dataset/lfw_pair.txt"

# Runtime settings read by the LFW DataLoader (`num_workers`) and by the
# model / margin cells (`device`). Adjust `num_workers` to the host machine.
conf.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
conf.num_workers = 4

# Model and margin-head settings.
conf.mode = "se_ir"
conf.depth = 50
conf.margin_type = "Arcface"
conf.feature_dim = 512
# NOTE(review): key is spelled `scale_siz` (likely meant `scale_size`); it is
# not read anywhere in this notebook, so the name is kept as-is. Verify against
# whatever consumes `conf`.
conf.scale_siz = 32.0

# Optimisation schedule.
conf.batch_size = 96
conf.lr = 0.01
# Epochs at which the learning rate is divided by 10. With total_epoch = 10
# the last milestone (12) is never reached.
conf.milestones = [8, 10, 12]
conf.total_epoch = 10

# Checkpoints go to <save_folder>/<mode>_<depth>.
conf.save_folder = "./saved"
conf.save_dir = os.path.join(conf.save_folder, conf.mode + "_" + str(conf.depth))
                            

## Data Loader

In [None]:
# Preprocessing pipeline handed to the LFW dataset below: convert the image to
# a tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
    ]
)

# The training loader is built entirely from `conf`; it also reports the number
# of identity classes, which sizes the margin head later on.
trainloader, class_num = get_train_loader(conf)

In [None]:
import torch.utils
import torch.utils.data


# LFW verification set used for per-epoch evaluation.
# Requires `conf.lfw_test_root`, `conf.lfw_file_list` and `conf.num_workers`
# to be present in the configuration.
lfw_dataset = LFW(
    conf.lfw_test_root,
    conf.lfw_file_list,
    transform=transform,
)
lfwloader = torch.utils.data.DataLoader(
    lfw_dataset,
    batch_size=128,
    num_workers=conf.num_workers,
)

## Model

In [None]:
# Backbone built from the configured depth, feature dimension and mode,
# then placed on the configured device.
model = SEResNet_IR(
    conf.depth,
    feature_dim=conf.feature_dim,
    mode=conf.mode,
).to(conf.device)

# Margin head mapping `feature_dim`-sized features to `class_num` outputs,
# placed on the same device as the backbone.
margin = ArcMarginProduct(
    conf.feature_dim,
    class_num,
).to(conf.device)

In [None]:
# Classification criterion applied to the margin head's output in the training loop.
crt = torch.nn.CrossEntropyLoss()

In [None]:
# SGD with Nesterov momentum over two parameter groups (backbone and margin
# head), both using the same weight decay and the base learning rate from conf.
optimizer = optim.SGD(
    [
        # `parameters` must be *called*: passing the bound method itself is not
        # an iterable of tensors and makes the optimizer constructor fail.
        {'params': model.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ],
    lr=conf.lr,
    momentum=0.9,
    nesterov=True,
)

In [None]:
def lr():
    """Divide the learning rate of every optimizer parameter group by 10.

    Mutates the module-level ``optimizer`` in place, then prints it once so
    the updated rates are visible in the cell output.
    """
    for group in optimizer.param_groups:
        group['lr'] /= 10

    # Print once after all groups are updated, not once per group.
    print(optimizer)

## Train

In [None]:
# Train for `conf.total_epoch` epochs. After each epoch, evaluate on LFW and
# write a checkpoint, flagging it as best when LFW accuracy improves.
best_acc = 0
for epoch in range(1, conf.total_epoch + 1):
    model.train()
    print('epochs{}/{}'.format(epoch, conf.total_epoch))

    # Step the learning rate down at the start of every milestone epoch.
    if epoch in conf.milestones:
        lr()

    # `tqdm` is imported as a module, so the progress-bar class is `tqdm.tqdm`.
    for data in tqdm.tqdm(trainloader):
        # Move the batch to the device the backbone and margin head live on.
        img, label = data[0].to(conf.device), data[1].to(conf.device)
        optimizer.zero_grad()

        logits = model(img)
        output = margin(logits, label)
        total_loss = crt(output, label)
        total_loss.backward()
        optimizer.step()

    model.eval()

    lfw_acc = test(conf, model, lfw_dataset, lfwloader)

    # `total_loss` is the loss of the epoch's final batch, not an epoch average.
    print("\nLFW :{:.4f} | train_loss :{:.4f}\n".format(lfw_acc, total_loss.item()))

    is_best = lfw_acc > best_acc
    best_acc = max(lfw_acc, best_acc)

    save_checkpoint(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'margin_state_dict': margin.state_dict(),
            'best_acc': best_acc,
        },
        is_best,
        checkpoint=conf.save_dir,
    )