In [2]:
import os
import torch.utils.data
from torch.nn import DataParallel
from datetime import datetime


from backbone.cbam import CBAMResNet, CBAMResNet_ae
from margin.ArcMarginProduct import ArcMarginProduct
from margin.MultiMarginProduct import MultiMarginProduct
from margin.CosineMarginProduct import CosineMarginProduct
from margin.InnerProduct import InnerProduct
from utils.logging import init_log
from dataset.casia_webface import CASIAWebFace
from dataset.lfw import LFW
from torch.optim import lr_scheduler
import torch.optim as optim

import time
from eval_lfw import evaluation_10_fold, getFeatureFromTorch
import numpy as np
import torchvision.transforms as transforms
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


In [3]:
import torch
from PIL import Image
import matplotlib.pyplot as plt

# Converters between PIL images and tensors, used for visual inspection.
# `loader` wraps ToTensor in a Compose; `unloader` converts a tensor back to a
# PIL image (it is what `imshow` uses before plotting).
loader = transforms.Compose([transforms.ToTensor()])
unloader = transforms.ToPILImage()

In [4]:
# ---------------------------------------------------------------- data paths
train_root = '../CASIA-WebFace'                        # CASIA-WebFace image root
train_file_list = 'names_2000.txt'                     # training image list
lfw_test_root = '../Siamese_lfw_pytorch-master/lfw'    # LFW image root
lfw_file_list = '../pairs.txt'                         # LFW verification pairs

# ----------------------------------------------------------- model / training
backbone = 'Res50_IR'
margin_type = 'Softmax'
feature_dim = 512
scale_size = 32
# NOTE(review): the DataLoaders in later cells hard-code their own batch sizes
# (512 and 96); this value is not read by any visible cell.
batch_size = 64
total_epoch = 40

# ------------------------------------------ checkpointing / eval (iterations)
save_freq = 300
test_freq = 300
resume = False
net_path = ''
margin_path = ''
save_dir = './model'
model_pre = ''

In [5]:
class printer(nn.Module):
    """Debugging identity layer: prints the shape of its input and passes it through unchanged."""

    def forward(self, input):
        print(input.shape)
        return input

class Flatten(nn.Module):
    """Collapse all non-batch dimensions into one: (B, d1, d2, ...) -> (B, d1*d2*...)."""

    def forward(self, input):
        n_rows = input.size(0)
        return input.view(n_rows, -1)


class UnFlatten(nn.Module):
    """Inverse of Flatten for the VAE decoder: reshape (B, size*5*5) into (B, size, 5, 5)."""

    def forward(self, input, size=256):
        target_shape = (input.size(0), size, 5, 5)
        return input.view(*target_shape)

class VAE(nn.Module):
    """Convolutional variational auto-encoder for images.

    Layout (from the layer definitions below):
      * encoder: four Conv2d(kernel=4, stride=2) + ReLU stages
        (image_channels -> 32 -> 64 -> 128 -> 256) and a Flatten. For 112x112
        inputs the spatial size goes 112 -> 55 -> 26 -> 12 -> 5, so the flat
        encoding ``h`` has 256 * 5 * 5 = 6400 features.
      * fc1 / fc2: map ``h`` to the posterior mean ``mu`` and log-variance ``logvar``.
      * fc3: maps the sampled latent ``z`` back to ``h_dim`` features.
      * decoder: UnFlatten to (256, 5, 5), then four ConvTranspose2d stages back
        to ``image_channels`` x 112 x 112, finished with a Sigmoid.

    Args:
        image_channels: number of input / output image channels.
        h_dim: size of the flattened encoder output. UnFlatten reshapes to
            (256, 5, 5), so this must be 6400; the default of 1024 does not fit
            that reshape and is kept only for signature compatibility.
        z_dim: latent dimensionality.

    Every submodule is moved to the module-level ``device`` during construction.
    """

    def __init__(self, image_channels=3, h_dim=1024, z_dim=128):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            Flatten()
        ).to(device)

        self.fc1 = nn.Linear(h_dim, z_dim).to(device)  # posterior mean
        self.fc2 = nn.Linear(h_dim, z_dim).to(device)  # posterior log-variance
        self.fc3 = nn.Linear(z_dim, h_dim).to(device)  # latent -> decoder input

        # Mirror of the encoder. output_padding=1 on the third stage restores the
        # odd 55-pixel size (26 -> 55) so the final stage lands on 112 again.
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=4, stride=2),
            nn.Sigmoid()
        ).to(device)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = logvar.mul(0.5).exp_()
        # randn_like samples directly on std's device and dtype instead of
        # allocating on the host and copying on every forward pass.
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Return ``(z, mu, logvar)`` for the flat encoding ``h``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def representation(self, x):
        """Return a sampled latent ``z`` (not the mean) for input images ``x``."""
        return self.bottleneck(self.encoder(x))[0]

    def forward(self, x):
        """Encode, sample a latent, and decode it.

        Returns:
            ``(recon, mu, logvar, h)``: the reconstruction decoded from the sampled
            latent, the posterior parameters, and the raw flat encoder output.
        """
        h = self.encoder(x)                 # (B, h_dim)
        z, mu, logvar = self.bottleneck(h)  # each (B, z_dim)
        # Decode from the latent, not from `h`: decoding `h` directly skips the
        # bottleneck, so reconstructions would never depend on mu/logvar and the
        # KL term could not shape the latent space.
        # NOTE(review): checkpoints trained before this fix decoded `h`; their
        # reconstructions will differ until the VAE is retrained.
        return self.decoder(self.fc3(z)), mu, logvar, h

In [6]:
device = torch.device('cuda')
#device = torch.device('cpu')

# Per-run directory for the VAE phase. It gets its own name so the configured
# `save_dir` root is NOT overwritten: the classifier-training cell joins its own
# timestamped directory onto `save_dir`, and overwriting the root here made that
# directory nest inside this one (and nest deeper on every re-run).
# The 'VAE_' prefix keeps the two runs' directories distinct even if both cells
# execute within the same second.
vae_run_dir = os.path.join(save_dir, model_pre + 'VAE_' + backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
if os.path.exists(vae_run_dir):
    raise NameError('model dir exists!')
os.makedirs(vae_run_dir)
logging = init_log(vae_run_dir)
_print = logging.info

# dataset loader: resize to 112x112 (the size the VAE encoder/decoder are built
# for); ToTensor scales pixels to [0.0, 1.0] and no normalization is applied,
# which matches the decoder's Sigmoid output range.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((112, 112)),
    transforms.ToTensor()
])

# training dataset used to fit the VAE
vae_batch_size = 512
trainset = CASIAWebFace(train_root, train_file_list, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=vae_batch_size,
                                          shuffle=True, num_workers=4, drop_last=False)

# h_dim must equal 256 * 5 * 5, the flattened encoder output for 112x112 images.
vae = VAE(image_channels=3, h_dim=6400, z_dim=32).to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

dataset size:  96968 / 2000


In [7]:
def adjust_learning_rate(optimizer, factor=0.5):
    """Scale the learning rate of every parameter group of ``optimizer`` in place.

    Args:
        optimizer: a ``torch.optim.Optimizer`` (anything exposing ``param_groups``).
        factor: multiplicative decay applied to each group's ``lr``. The default
            of 0.5 halves the rate, matching the original behaviour.
    """
    for group in optimizer.param_groups:
        group['lr'] *= factor

In [8]:
def loss_fn(recon_x, x, mu, logvar):
    """VAE objective: summed squared reconstruction error plus a KL term.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target images.
        mu: posterior mean from the encoder.
        logvar: posterior log-variance from the encoder.

    Returns:
        ``(total, recon, kld)`` scalar tensors with ``total = recon + kld``.
        ``recon`` is the squared error summed over every element of the batch.
        Despite the historical local name ``BCE``, it is MSE, not cross-entropy.
    """
    # `size_average=False` is deprecated; reduction='sum' is its exact replacement.
    BCE = F.mse_loss(recon_x, x, reduction='sum')
    # KL(q(z|x) || N(0, I)) per element: -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2).
    # NOTE(review): this is averaged over batch and latent dims while the
    # reconstruction term is summed, so the KL term is weighted far less than in
    # the usual VAE objective (which sums both). Kept to preserve training behaviour.
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD, BCE, KLD

In [9]:
def imshow(tensor, title=None):
    """Plot one image tensor with matplotlib.

    The tensor is copied to the CPU, so the caller's tensor is never modified.
    A leading dimension of size 1 is squeezed away, and the result is converted
    with the module-level ``unloader`` before being drawn. ``title`` is applied
    only when given.
    """
    picture = unloader(tensor.cpu().clone().squeeze(0))
    plt.imshow(picture)
    if title is not None:
        plt.title(title)
    # short pause so interactive backends refresh the figure
    plt.pause(0.001)

In [10]:
# Checkpoint path written by the (commented-out) VAE training loop and read by
# the checkpoint-restore cell.
vae_file_name = "vae"

In [10]:
# epochs_ae = 50
# max_mse = 9999

# for epoch in range(epochs_ae):
#     vae.train()
#     loss_epoch = []
#     if (epoch+1)%10 == 0:
#         adjust_learning_rate(vae_optimizer)
        
#     for idx, (images, _) in enumerate(trainloader):
        
#         recon_images, mu, logvar, h = vae(images.to(device))
#         #print(images[0])
        
        
#         #print(recon_images)
#         loss, bce, kld = loss_fn(recon_images, images.to(device), mu, logvar)
#         vae_optimizer.zero_grad()
#         loss.backward()
#         vae_optimizer.step()
#         loss_epoch.append(bce.cpu().detach().numpy()/512)
#         if (idx)%100 == 0:
#             imshow(images[0].cpu().detach(), title=None)
#             imshow(recon_images[0].cpu().detach(), title=None)
#             to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.6f}".format(epoch+1, 
#                                 epochs_ae, loss.item()/512, bce.item()/512, kld.item()/512)
#             print(to_print)
        
            
            
#     cur_mse = np.mean(np.array(loss_epoch))
#     print("Epoch[{}/{}] Loss: {:.6f}".format(epoch+1, 
#                                 epochs_ae, cur_mse))
#     if cur_mse < max_mse:
#             max_mse = cur_mse
#             state = {
#                 'net': vae.state_dict(),
#                 'optimizer': vae_optimizer.state_dict(),
#                 'epoch': epoch
#             }
#             torch.save(state, vae_file_name)
#             print('\n------------ Save best model ------------\n')

In [11]:
# Restore the pretrained VAE weights and optimizer state saved by the VAE
# training loop. map_location places every tensor on `device`, so a checkpoint
# written on a different device (another GPU index, or CPU) still loads.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(vae_file_name, map_location=device)
save_epoch = checkpoint['epoch']
print("last saved model is in epoch {}".format(save_epoch))
vae.load_state_dict(checkpoint['net'])
vae_optimizer.load_state_dict(checkpoint['optimizer'])
vae = vae.to(device)
# NOTE(review): the VAE is left in training mode before being handed to
# CBAMResNet_ae; whether that wrapper expects train or eval mode is not visible here.
vae.train()

last saved model is in epoch 49


VAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): ReLU()
    (8): Flatten()
  )
  (fc1): Linear(in_features=6400, out_features=32, bias=True)
  (fc2): Linear(in_features=6400, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=6400, bias=True)
  (decoder): Sequential(
    (0): UnFlatten()
    (1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2))
    (2): ReLU()
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
    (4): ReLU()
    (5): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1))
    (6): ReLU()
    (7): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2))
    (8): Sigmoid()
  )
)

In [None]:
device = torch.device('cuda')
#device = torch.device('cpu')

# log init
# This run's directory gets its own name instead of re-assigning `save_dir`:
# re-assigning joined each new timestamp onto the previous run directory, so
# every re-execution nested one level deeper.
run_dir = os.path.join(save_dir, model_pre + backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
if os.path.exists(run_dir):
    raise NameError('model dir exists!')
os.makedirs(run_dir)
logging = init_log(run_dir)
_print = logging.info

# dataset loader: resize to 112x112; ToTensor scales pixels to [0.0, 1.0]
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((112, 112)),
    transforms.ToTensor()
])
# training dataset
trainset = CASIAWebFace(train_root, train_file_list, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=96,
                                          shuffle=True, num_workers=4, drop_last=False)
# test dataset
lfwdataset = LFW(lfw_test_root, lfw_file_list, transform=transform)
lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=96,
                                        shuffle=False, num_workers=4, drop_last=False)

# The network trained here is always CBAMResNet_ae: a 50-layer 'ir' ResNet that
# wraps the pretrained `vae`. A per-backbone if/elif used to build a network that
# was immediately overwritten by this line (several of its branches also referenced
# an undefined `args`), so the configured name is only checked and reported.
if backbone != 'Res50_IR':
    print(backbone, ' is ignored: CBAMResNet_ae always builds a 50-layer IR backbone')
net = CBAMResNet_ae(device, vae, 50, feature_dim=feature_dim, mode='ir').to(device)

if margin_type == 'ArcFace':
    margin = ArcMarginProduct(feature_dim, trainset.class_nums, s=scale_size)
elif margin_type == 'MultiMargin':
    margin = MultiMarginProduct(feature_dim, trainset.class_nums, s=scale_size)
elif margin_type == 'CosFace':
    margin = CosineMarginProduct(feature_dim, trainset.class_nums, s=scale_size)
elif margin_type == 'Softmax':
    margin = InnerProduct(feature_dim, trainset.class_nums)
else:
    # Stop here: continuing would crash later on an undefined `margin`, or
    # silently reuse a stale one left over from a previous run of this cell.
    raise ValueError('{} is not available!'.format(margin_type))

if resume:
    print('resume the model parameters from: ', net_path, margin_path)
    net.load_state_dict(torch.load(net_path, map_location=device)['net_state_dict'])
    margin.load_state_dict(torch.load(margin_path, map_location=device)['net_state_dict'])

# define optimizers for different layer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer_ft = optim.SGD([
    {'params': net.parameters(), 'weight_decay': 5e-4},
    {'params': margin.parameters(), 'weight_decay': 5e-4}
], lr=0.1, momentum=0.9, nesterov=True)

# Milestones count completed epochs: the LR is multiplied by 0.2 once epochs
# 8, 13, 18, ... have finished.
exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[8, 13, 18, 23, 28, 33, 38], gamma=0.2)

net = net.to(device)
margin = margin.to(device)

# getFeatureFromTorch writes its features here; make sure the directory exists.
lfw_result_path = './result/cur_lfw_result.mat'
os.makedirs(os.path.dirname(lfw_result_path), exist_ok=True)

best_lfw_acc = 0.0
best_lfw_iters = 0

total_iters = 0

for epoch in range(1, total_epoch + 1):
    # train model
    print('Train Epoch: {}/{} ...'.format(epoch, total_epoch))
    net.train()

    since = time.time()
    iters_since_log = 0
    for data in trainloader:
        img, label = data[0].to(device), data[1].to(device)
        optimizer_ft.zero_grad()

        raw_logits = net(img)
        output = margin(raw_logits, label)
        total_loss = criterion(output, label)
        total_loss.backward()
        optimizer_ft.step()

        total_iters += 1
        iters_since_log += 1
        # print train information
        if total_iters % 100 == 0:
            # training accuracy on the current batch
            predict = output.detach().argmax(dim=1)
            total = label.size(0)
            correct = (predict == label).sum().item()
            # Divide by the iterations actually elapsed: the window restarts at
            # each epoch, so it can hold fewer than 100 iterations.
            time_cur = (time.time() - since) / iters_since_log
            since = time.time()
            iters_since_log = 0

            print("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, total_loss.item(), correct/total, time_cur, exp_lr_scheduler.get_last_lr()[0]))

        # save model
        if total_iters % save_freq == 0:
            msg = 'Saving checkpoint: {}'.format(total_iters)
            print(msg)
            os.makedirs(run_dir, exist_ok=True)
            torch.save({
                'iters': total_iters,
                'net_state_dict': net.state_dict()},
                os.path.join(run_dir, 'Iter_%06d_net.ckpt' % total_iters))
            torch.save({
                'iters': total_iters,
                'net_state_dict': margin.state_dict()},
                os.path.join(run_dir, 'Iter_%06d_margin.ckpt' % total_iters))

        # test accuracy
        if total_iters % test_freq == 0:
            with torch.no_grad():
                # test model on lfw
                net.eval()
                getFeatureFromTorch(lfw_result_path, net, device, lfwdataset, lfwloader)
                lfw_accs = evaluation_10_fold(lfw_result_path)
                lfw_acc = np.mean(lfw_accs) * 100
                print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc))
                if best_lfw_acc <= lfw_acc:
                    best_lfw_acc = lfw_acc
                    best_lfw_iters = total_iters
            # Only an evaluation switches the net out of training mode.
            net.train()

    # Since PyTorch 1.1 the scheduler must step after optimizer.step(); stepping
    # at the start of the epoch skipped the initial LR and triggered a warning.
    exp_lr_scheduler.step()

print('Finally Best Accuracy: LFW: {:.4f} in iters: {}'.format(best_lfw_acc, best_lfw_iters))
print('finishing training')

In [None]:
# IPython magic: execute dataset/casia_webface.py as a script from this kernel.
# NOTE(review): presumably a standalone sanity check of the CASIAWebFace loader -
# confirm. IPython's %run copies the script's top-level names into the notebook
# namespace by default, which can shadow variables defined above.
%run dataset/casia_webface.py