In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [42]:
def weigths_init(m):
    """Initialise the parameters of a convolutional or fully connected layer.

    Intended to be handed to ``nn.Module.apply`` so that it is called once
    for every sub-module. Modules that are neither ``nn.Conv2d`` nor
    ``nn.Linear`` are left untouched.

    The (misspelled) name is kept as-is because the network classes in this
    notebook refer to it.

    Args:
        m: The module to initialise in place.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Use the in-place initialiser (trailing underscore); the
        # un-suffixed ``nn.init.xavier_uniform`` alias is deprecated.
        nn.init.xavier_uniform_(m.weight)
        # Layers created with ``bias=False`` expose ``bias`` as None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.1)

In [43]:
def compute_accuracy(prob_cls, gt_cls, threshold=0.6):
    """Compute classification accuracy over the samples with a label >= 0.

    Args:
        prob_cls: Tensor of predicted probabilities; singleton dimensions
            are squeezed away before use.
        gt_cls: Tensor of ground-truth labels. Entries below 0 are excluded
            from the accuracy computation.
        threshold: A probability greater than or equal to this value counts
            as a prediction of class 1. Defaults to 0.6.

    Returns:
        A scalar tensor holding the fraction of retained samples whose
        thresholded prediction equals the ground-truth label.
    """
    prob_cls = torch.squeeze(prob_cls)
    gt_cls = torch.squeeze(gt_cls)

    # torch.ge compares element-wise (>=) and yields a boolean mask.
    mask = torch.ge(gt_cls, 0)
    valid_gt_cls = torch.masked_select(gt_cls, mask)
    valid_prob_cls = torch.masked_select(prob_cls, mask)

    # Compute the predicted accuracy.
    size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
    prob_ones = torch.ge(valid_prob_cls, threshold).float()
    right_ones = torch.eq(prob_ones, valid_gt_cls).float()

    # The original multiplied by ``float(1,0)``, which raises TypeError;
    # the intended factor of 1.0 is a no-op, so it is dropped.
    return torch.div(torch.sum(right_ones), float(size))

In [44]:
class LossFn:
    """Weighted classification, box-regression and landmark-regression losses.

    Each loss method first selects the samples that are relevant for that
    task based on the ground-truth label, then applies the underlying
    criterion and scales the result by the task's factor.

    Args:
        cls_factor: Multiplier applied to the classification loss.
        box_factor: Multiplier applied to the bounding-box loss.
        landmark_factor: Multiplier applied to the landmark loss.
    """

    def __init__(self, cls_factor, box_factor, landmark_factor=1):
        # Per-task weights.
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        # Underlying criteria.
        self.loss_cls = nn.BCELoss()
        self.loss_box = nn.MSELoss()
        self.loss_landmark = nn.MSELoss()

    def cls_loss(self, gt_label, pred_label):
        """Binary cross-entropy over the samples whose label is >= 0.

        Args:
            gt_label: Ground-truth labels; entries below 0 are ignored.
            pred_label: Predicted probabilities, aligned with ``gt_label``
                after singleton dimensions are squeezed.

        Returns:
            The BCE loss of the retained samples times ``cls_factor``.
        """
        pred_label = torch.squeeze(pred_label)
        gt_label = torch.squeeze(gt_label)

        # Only labels 0 and 1 take part in binary classification.
        mask = torch.ge(gt_label, 0)
        valid_gt_label = torch.masked_select(gt_label, mask)
        valid_pred_label = torch.masked_select(pred_label, mask)
        # Fixed: the original referenced the undefined name ``vaild_gt_label``.
        return self.loss_cls(valid_pred_label, valid_gt_label) * self.cls_factor

    def box_loss(self, gt_label, gt_offset, pred_offset):
        """Mean-squared error of box offsets over samples with a non-zero label.

        Args:
            gt_label: Ground-truth labels; rows with label 0 are excluded.
            gt_offset: Ground-truth offsets, one row per sample.
            pred_offset: Predicted offsets, one row per sample.

        Returns:
            The MSE loss of the retained rows times ``box_factor``.
        """
        pred_offset = torch.squeeze(pred_offset)
        gt_offset = torch.squeeze(gt_offset)
        gt_label = torch.squeeze(gt_label)

        # Keep every sample whose label differs from 0.
        # Fixed: the original referenced the undefined name ``gt_labels``.
        mask = torch.ne(gt_label, 0)

        # Only the selected rows contribute to the loss.
        valid_gt_offset = gt_offset[mask, :]
        valid_pred_offset = pred_offset[mask, :]
        return self.loss_box(valid_pred_offset, valid_gt_offset) * self.box_factor

    def landmark_loss(self, gt_label, gt_landmark, pred_landmark):
        """Mean-squared error of landmarks over samples labelled -2.

        Args:
            gt_label: Ground-truth labels; only rows equal to -2 are used.
            gt_landmark: Ground-truth landmark coordinates, one row per sample.
            pred_landmark: Predicted landmark coordinates, one row per sample.

        Returns:
            The MSE loss of the retained rows times ``land_factor``.
        """
        pred_landmark = torch.squeeze(pred_landmark)
        gt_landmark = torch.squeeze(gt_landmark)
        gt_label = torch.squeeze(gt_label)

        # Label -2 marks the landmark samples (per the original comment,
        # only CelebA data is used for landmark regression).
        mask = torch.eq(gt_label, -2)

        valid_gt_landmark = gt_landmark[mask, :]
        valid_pred_landmark = pred_landmark[mask, :]
        # Fixed: the original called ``self.loss``, which is never defined.
        # Arguments are passed as (prediction, target) like the other losses.
        return self.loss_landmark(valid_pred_landmark, valid_gt_landmark) * self.land_factor




In [50]:
class PNet(nn.Module):
    """Fully convolutional proposal network.

    A three-convolution backbone feeds 1x1 convolution heads. ``forward``
    returns the classification map and the box-offset map; the landmark
    head ``conv4_3`` is created but not used by ``forward``.

    NOTE(review): the classification head emits 2 channels, whereas
    ``LossFn.cls_loss`` masks predictions with a per-sample label mask.
    Confirm that the shapes line up in the training loop.

    Args:
        is_train: Stored on the instance; not read inside this class.
        use_cuda: Stored on the instance; not read inside this class.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super(PNet, self).__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Backbone.
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(10, 16, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1),
            nn.PReLU()
        )
        # Detection (classification) head.
        self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1, stride=1)
        # Bounding-box regression head.
        self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)
        # Landmark head (10 channels); unused in forward().
        self.conv4_3 = nn.Conv2d(32, 10, kernel_size=1, stride=1)

        self.apply(weigths_init)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor with 3 channels.

        Returns:
            A ``(label, offset)`` tuple: the sigmoid-activated output of the
            classification head and the raw output of the box head.
        """
        x = self.pre_layer(x)
        # torch.sigmoid replaces the deprecated F.sigmoid and matches the
        # call used by RNet and ONet.
        label = torch.sigmoid(self.conv4_1(x))
        offset = self.conv4_2(x)

        return label, offset

In [49]:

class RNet(nn.Module):
    """Refinement network: convolutional backbone plus fully connected heads.

    ``forward`` returns the classification output and the box regression
    output; the landmark head ``conv5_3`` is created but not used by
    ``forward``.

    NOTE(review): ``conv4`` expects a 64x3x3 feature map. By the layer
    arithmetic of ``pre_layer`` a 26x26 input produces 3x3 features, while
    a 24x24 input produces 2x2. Confirm the intended crop size before
    changing the ``Linear`` size, since that would alter checkpoint shapes.

    Args:
        is_train: Stored on the instance; not read inside this class.
        use_cuda: Stored on the instance; not read inside this class.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super(RNet, self).__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Backbone.
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 28, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(28, 48, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 64, kernel_size=2, stride=1),
            nn.PReLU()
        )

        self.conv4 = nn.Linear(64 * 3 * 3, 128)
        self.prelu4 = nn.PReLU()
        # Face classification.
        self.conv5_1 = nn.Linear(128, 2)
        # Bounding box regression.
        self.conv5_2 = nn.Linear(128, 4)
        # Landmark localization; unused in forward().
        self.conv5_3 = nn.Linear(128, 10)

        self.apply(weigths_init)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor with 3 channels.

        Returns:
            A ``(det, box)`` tuple: the sigmoid-activated classification
            output and the raw box regression output.

        Raises:
            RuntimeError: If the flattened feature size does not match the
                input size of ``conv4``.
        """
        x = self.pre_layer(x)
        # Flatten per sample. The original ``view(-1, 64*3*3)`` could
        # silently regroup elements across the batch when the feature map
        # is not 3x3; keeping the batch dimension fixed makes such a
        # mismatch fail loudly in the Linear layer instead.
        x = x.view(x.size(0), -1)
        x = self.conv4(x)
        x = self.prelu4(x)
        # Detection.
        det = torch.sigmoid(self.conv5_1(x))
        box = self.conv5_2(x)

        return det, box
    

In [47]:
class ONet(nn.Module):
    """Output network: convolutional backbone plus fully connected heads.

    ``forward`` returns the classification output, the box regression
    output and the landmark regression output.

    NOTE(review): ``conv5`` expects a 128x3x3 feature map. By the layer
    arithmetic of ``pre_layer`` a 56x56 input produces 3x3 features, while
    a 48x48 input produces 2x2. Confirm the intended crop size before
    changing the ``Linear`` size, since that would alter checkpoint shapes.

    Args:
        is_train: Stored on the instance; not read inside this class.
        use_cuda: Stored on the instance; not read inside this class.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super(ONet, self).__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Backbone. Fixed: the keyword was misspelled ``stirde`` on three
        # Conv2d layers, which raises TypeError at construction.
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=1, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.PReLU()
        )
        self.conv5 = nn.Linear(128 * 3 * 3, 256)
        self.prelu5 = nn.PReLU()
        # Face classification.
        self.conv6_1 = nn.Linear(256, 2)
        # Bounding box regression.
        self.conv6_2 = nn.Linear(256, 4)
        # Landmark localization. Fixed: was ``nn.Lineasr``.
        self.conv6_3 = nn.Linear(256, 10)

        self.apply(weigths_init)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor with 3 channels.

        Returns:
            A ``(det, box, landmark)`` tuple: the sigmoid-activated
            classification output, the raw box regression output and the
            raw landmark regression output.

        Raises:
            RuntimeError: If the flattened feature size does not match the
                input size of ``conv5``.
        """
        x = self.pre_layer(x)
        # Fixed: the original called a bare ``view(...)`` (NameError).
        # Flattening per sample keeps the batch dimension intact, so a
        # feature-size mismatch fails in the Linear layer rather than
        # silently regrouping samples.
        x = x.view(x.size(0), -1)
        x = self.conv5(x)
        x = self.prelu5(x)
        # Detection.
        det = torch.sigmoid(self.conv6_1(x))
        box = self.conv6_2(x)
        landmark = self.conv6_3(x)

        return det, box, landmark
