# **Export Blendshape Deformation Model**

In [3]:
import torch
from torch import nn
import json
import numpy as np


class BlendShape(nn.Module):
    """ Deform selected head-model landmarks with blendshape weights.

    The module loads a neutral face and a fixed list of blendshapes from a
    json file, keeps them as buffers, and in ``forward`` adds the weighted
    sum of the blendshape points to the neutral face at a small set of
    landmark ids (used for a PnP solve, per the original notes).

    Attributes:
      data (dict): raw content of the json file; must provide the
        'default' and 'blend_shapes' entries
      bs_list (list): blendshape keys read from data['blend_shapes'], in
        the order they are stacked along axis 0 of bs_tensor
      face (torch.tensor): neutral face points, one row per point key of
        data['default'] (presumably x/y/z per row -- confirm with the data)
      bs_tensor (torch.tensor): blendshape points, indexed as
        [blendshape, point, ...]
      landmark_ids (torch.tensor): point indices picked out in forward
      lock_mask (torch.tensor): per-landmark multiplier (0. or 1.) applied
        to the blendshape contribution when lock_eyes_nose is set
      lock_eyes_nose (bool): whether forward applies lock_mask
    """
    def __init__(self, blendshapes='./data/bs_points_a.json'):
        """ Load the blendshape json and register all tensors as buffers.

          Args:
            blendshapes (string): path of the json file holding the
              'default' face and the 'blend_shapes' entries
        """
        super().__init__()

        with open(blendshapes) as json_file:
            self.data = json.load(json_file)

        self.bs_list = ['BS.Mesh', *(f'BS.Mesh{idx}' for idx in range(1, 51))]

        # Point order is taken from the neutral face and reused for every
        # blendshape so that row i always refers to the same point.
        neutral = self.data['default']
        point_names = list(neutral)
        neutral_points = np.array([neutral[name] for name in point_names])
        self.register_buffer('face', torch.tensor(neutral_points))

        shapes = self.data['blend_shapes']
        shape_points = np.array([
            [shapes[shape_name][name] for name in point_names]
            for shape_name in self.bs_list
        ])
        self.register_buffer('bs_tensor', torch.tensor(shape_points))

        self.register_buffer('landmark_ids',
                             torch.tensor([18, 2, 24, 33, 36, 42]))

        # One multiplier per entry of landmark_ids (same order).
        self.register_buffer('lock_mask',
                             torch.tensor([0., 1., 0., 0., 1., 1.]))

        self.lock_eyes_nose = True

    def forward(self, y_hat):
        """ Apply predicted blendshape weights to the landmark points.

          Args:
            y_hat (torch.tensor): predicted blendshape values; indexed as
              y_hat[:, None, None], so presumably 1-D with one weight per
              entry of bs_list -- confirm against the caller

          Returns:
            torch.tensor: landmark face points deformed by blend shapes
        """
        landmark_shapes = self.bs_tensor.index_select(1, self.landmark_ids)
        landmark_face = self.face.index_select(0, self.landmark_ids)

        # Scale every blendshape by its weight, then collapse the
        # blendshape axis into a single contribution per landmark.
        blend = (landmark_shapes * y_hat[:, None, None]).sum(dim=0)

        if self.lock_eyes_nose:
            # Landmarks whose mask entry is 0. keep their neutral position.
            blend = blend * self.lock_mask[:, None]

        return landmark_face + blend



In [4]:

# Build the module with its default json path, move it to the GPU, compile
# it with TorchScript and write the compiled module to disk.
bs_jit = torch.jit.script(BlendShape().to('cuda'))
torch.jit.save(bs_jit, 'blendshape_model.ptc')
