In [197]:
import cv2
import json
import numpy as np
import sys
# Make the parent directory (which holds the `spiga` package and `assets/`)
# importable; guard so re-running the cell does not append it repeatedly.
if '../' not in sys.path:
    sys.path.append('../')
from spiga.models.spiga import SPIGA
import torch.nn.functional as F

# Load image and bbox
IMAGE_PATH = "../assets/colab/image_sportsfan.jpg"
BBOX_PATH = '../assets/colab/bbox_sportsfan.json'

image = cv2.imread(IMAGE_PATH)
# cv2.imread signals a missing/unreadable file by returning None, not by raising.
if image is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
with open(BBOX_PATH) as jsonfile:
    bbox = json.load(jsonfile)['bbox']

In [198]:
from spiga.models.spiga import SPIGA

# Default-configured SPIGA instance.
# NOTE(review): this `spiga` variable is reassigned (with explicit sizes) in a later
# cell before it is ever used, and it shares its name with the `spiga` package,
# which is easy to confuse -- consider removing this cell.
spiga = SPIGA()

In [199]:
import os
import pkg_resources
import copy
import torch
import numpy as np

# Paths
# Default directory for SPIGA weight files, resolved inside the installed `spiga` package.
# NOTE(review): pkg_resources is deprecated in recent setuptools releases;
# importlib.resources is the standard-library replacement.
weights_path_dft = pkg_resources.resource_filename('spiga', 'models/weights')

import spiga.inference.pretreatment as pretreat
from spiga.models.spiga import SPIGA
from spiga.inference.config import ModelConfig


class SPIGAFramework:
    """Inference wrapper around the SPIGA landmark / head-pose network.

    Builds the pretreatment transforms and the SPIGA model from a ``ModelConfig``,
    optionally loads the 3D face model and camera intrinsics, and exposes
    ``inference(image, bboxes)``, which returns a dict with ``'landmarks'`` (mapped
    back to the image frame) and/or ``'headpose'`` entries.

    NOTE(review): weight loading and ``.cuda()`` for the model are commented out in
    ``__init__``, so the model keeps the parameters ``SPIGA(...)`` initialises it
    with (presumably random) and stays on the CPU, while ``_data2device`` puts the
    3D model / camera matrix on the GPU when CUDA is available. Re-enable those
    lines before trusting ``inference`` results.
    """

    def __init__(self, model_cfg: "ModelConfig", gpus=None, load3DM=True):
        """
        Args:
            model_cfg: configuration with dataset sizes, weight paths and
                3D-model parameters.
            gpus: CUDA device indices; only ``gpus[0]`` is used. Defaults to ``[0]``.
            load3DM: if True, load the 3D model and camera matrix. They are
                required by ``pretreat`` / ``inference``.
        """
        # Parameters
        self.model_cfg = model_cfg
        # Copy so neither a caller's list nor a shared default is aliased.
        self.gpus = [0] if gpus is None else list(gpus)

        # Pretreatment initialization
        self.transforms = pretreat.get_transformers(self.model_cfg)

        # SPIGA model
        self.model_inputs = ['image', "model3d", "cam_matrix"]
        self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
                           num_edges=model_cfg.dataset.num_edges)

        # Load weights and set model
        weights_path = self.model_cfg.model_weights_path
        if weights_path is None:
            weights_path = weights_path_dft

        # if self.model_cfg.load_model_url:
        #     model_state_dict = torch.hub.load_state_dict_from_url(self.model_cfg.model_weights_url,
        #                                                           model_dir=weights_path,
        #                                                           file_name=self.model_cfg.model_weights)
        # else:
        #     weights_file = os.path.join(weights_path, self.model_cfg.model_weights)
        #     model_state_dict = torch.load(weights_file)

        # self.model.load_state_dict(model_state_dict)
        # self.model = self.model.cuda(self.gpus[0])
        self.model.eval()
        print('SPIGA model loaded!')

        # Load 3D model and camera intrinsic matrix
        if load3DM:
            loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
                                             ftmap_size=model_cfg.ftmap_size,
                                             focal_ratio=model_cfg.focal_ratio,
                                             totensor=True)
            params_3DM = self._data2device(loader_3DM())
            self.model3d = params_3DM['model3d']
            self.cam_matrix = params_3DM['cam_matrix']

    def inference(self, image, bboxes):
        """Run pretreatment, the network forward pass and postreatment.

        Args:
            image: input image, deep-copied once per bbox by ``pretreat``.
            bboxes: sequence of bboxes as ``[x, y, w, h]`` (the layout postreatment
                indexes with ``[:, 0:2]`` / ``[:, 2:4]``).

        Returns:
            dict with ``'landmarks'`` and/or ``'headpose'`` lists.
        """
        batch_crops, crop_bboxes = self.pretreat(image, bboxes)
        # Pure inference: no autograd graph is needed (outputs are detached later).
        with torch.no_grad():
            outputs = self.net_forward(batch_crops)
        features = self.postreatment(outputs, crop_bboxes, bboxes)
        return features

    def pretreat(self, image, bboxes):
        """Crop/normalise ``image`` around each bbox and batch the model inputs.

        Requires the object to have been built with ``load3DM=True``.

        Returns:
            ``([batch_images, batch_model3D, batch_cam_matrix], crop_bboxes)``.
            ``batch_images`` stays on the CPU; the 3D tensors live wherever
            ``_data2device`` placed them.
        """
        crop_bboxes = []
        crop_images = []
        for bbox in bboxes:
            sample = {'image': copy.deepcopy(image),
                      'bbox': copy.deepcopy(bbox)}
            sample_crop = self.transforms(sample)
            crop_bboxes.append(sample_crop['bbox'])
            crop_images.append(sample_crop['image'])

        # Images to tensor and device
        batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
        # batch_images = self._data2device(batch_images)
        # Batch 3D model and camera intrinsic matrix
        batch_model3D = self.model3d.unsqueeze(0).repeat(len(bboxes), 1, 1)
        batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(len(bboxes), 1, 1)

        # SPIGA inputs
        model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
        return model_inputs, crop_bboxes

    def net_forward(self, inputs):
        """Call the SPIGA model on the list of model inputs."""
        outputs = self.model(inputs)
        return outputs

    def postreatment(self, output, crop_bboxes, bboxes):
        """Convert raw network outputs into plain-Python features.

        Landmarks are scaled by ``image_size``, normalised to the crop bbox and then
        mapped into the original bbox, i.e. back to image coordinates.
        """
        features = {}
        crop_bboxes = np.array(crop_bboxes)
        bboxes = np.array(bboxes)

        if 'Landmarks' in output.keys():
            # Use the output of the last refinement step.
            landmarks = output['Landmarks'][-1].cpu().detach().numpy()
            # Put the landmark axis first so the per-bbox offsets below broadcast.
            landmarks = landmarks.transpose((1, 0, 2))
            landmarks = landmarks*self.model_cfg.image_size
            landmarks_norm = (landmarks - crop_bboxes[:, 0:2]) / crop_bboxes[:, 2:4]
            landmarks_out = (landmarks_norm * bboxes[:, 2:4]) + bboxes[:, 0:2]
            landmarks_out = landmarks_out.transpose((1, 0, 2))
            features['landmarks'] = landmarks_out.tolist()

        # Pose output
        if 'Pose' in output.keys():
            pose = output['Pose'].cpu().detach().numpy()
            features['headpose'] = pose.tolist()

        return features

    def select_inputs(self, batch):
        """Pick ``self.model_inputs`` from a batch dict, cast to float and move to device."""
        inputs = []
        for ft_name in self.model_inputs:
            data = batch[ft_name]
            inputs.append(self._data2device(data.type(torch.float)))
        return inputs

    def _target_device(self):
        """Return ``cuda:<gpus[0]>`` when CUDA is usable, otherwise the CPU."""
        if self.gpus and torch.cuda.is_available():
            return torch.device('cuda', self.gpus[0])
        return torch.device('cpu')

    def _data2device(self, data):
        """Recursively move tensors (possibly nested in lists/tuples/dicts) to the device.

        Lists and dicts are updated in place and returned. Tuples are rebuilt.
        Any other value is treated as a tensor.
        """
        if isinstance(data, list):
            for data_id, v_data in enumerate(data):
                data[data_id] = self._data2device(v_data)
            return data
        if isinstance(data, tuple):
            return tuple(self._data2device(v_data) for v_data in data)
        if isinstance(data, dict):
            # Re-assigning existing keys while iterating is safe (dict size is unchanged).
            for k, v in data.items():
                data[k] = self._data2device(v)
            return data
        with torch.no_grad():
            return data.to(self._target_device(), non_blocking=True)


In [200]:
from spiga.inference.config import ModelConfig

# Build the inference wrapper for the WFLW configuration.
dataset = 'wflw'
processor = SPIGAFramework(model_cfg=ModelConfig(dataset))
# Full pipeline (pretreat -> forward -> postreatment) in a single call:
# features = processor.inference(image, [bbox])

SPIGA model loaded!


In [201]:
# Run only the pretreatment stage and unpack the three SPIGA inputs.
model_inputs, crop_bboxes = processor.pretreat(image, [bbox])
batch_images, batch_model3D, batch_cam_matrix = model_inputs

# Batch duplication is done further down on `imgs` only. Note that
# `.expand(4, -1, -1, -1)` on the 3-D model3D / cam_matrix tensors would prepend a
# new dimension instead of replacing the batch dimension.

In [202]:
# Device of the cropped image batch. pretreat() leaves it on the CPU because its
# _data2device call is commented out.
batch_images.device

device(type='cpu')

In [203]:
# Shapes of the three SPIGA inputs, followed by the first crop-frame bbox.
print(batch_images.shape, batch_model3D.shape, batch_cam_matrix.shape,
      crop_bboxes[0], sep='\n')

torch.Size([1, 3, 256, 256])
torch.Size([1, 98, 3])
torch.Size([1, 3, 3])
[ 48.          50.2230171  160.         155.55396581]


In [204]:
# Stand-alone SPIGA network for stepping through the forward pass by hand.
# Size it from the processor's config so it matches SPIGAFramework's model, and
# switch it to eval mode as SPIGAFramework does (a fresh nn.Module starts in
# training mode).
cfg_dataset = processor.model_cfg.dataset
spiga = SPIGA(num_landmarks=cfg_dataset.num_landmarks,
              num_edges=cfg_dataset.num_edges)
spiga.eval()
# pts_proj, features = model.backbone_forward(model_inputs)
imgs = model_inputs[0]
model3d = model_inputs[1]
cam_matrix = model_inputs[2]

N_DUP = 4  # number of copies of the single crop used to fake a batch
print(imgs.shape)
# duplicate it to make batch data (expand returns a view, no copy)
imgs = imgs.expand(N_DUP, -1, -1, -1)
print(imgs.shape)

# HourGlass Forward
features = spiga.visual_cnn(imgs)

torch.Size([1, 3, 256, 256])
torch.Size([4, 3, 256, 256])


In [205]:
# Shape of the first 'VisualField' entry produced by the hourglass backbone.
features['VisualField'][0].shape

torch.Size([4, 256, 64, 64])

In [206]:
import spiga.models.gnn.pose_proj as pproj

# Pose head: flatten the last HGcore output and regress the pose vector.
pose_raw = features['HGcore'][-1]
# NOTE: the second dim here is a channel count, not the number of landmarks.
batch_size, n_pose_ch, _, _ = pose_raw.shape
pose = spiga.pose_fc(pose_raw.reshape(batch_size, n_pose_ch))
features['Pose'] = pose.clone()

# First 3 values are Euler angles, the rest is the translation.
euler = pose[:, 0:3]
trl = pose[:, 3:]
rot = pproj.euler_to_rotation_matrix(euler)

# Keep the 3D model / intrinsics on the same device as the predicted pose
# (the processor may have put them on the GPU while `spiga` runs elsewhere).
model3d = model3d.to(pose.device)
cam_matrix = cam_matrix.to(pose.device)

pts_proj = pproj.projectPoints(model3d, rot, trl, cam_matrix)
# Normalise projected points by the feature-map resolution (-> roughly [0, 1]).
pts_proj = pts_proj / spiga.visual_res

In [207]:
# Projected landmarks of the first batch element, then the feature keys so far.
print(pts_proj[0].shape, features.keys(), sep='\n')

torch.Size([98, 2])
dict_keys(['VisualField', 'HGcore', 'Pose'])


In [208]:
# Feature map used for landmark cropping: the last 'VisualField' entry.
# NOTE(review): the earlier shape check used index [0]; the two only coincide if
# the list has a single entry -- confirm which one is intended.
visual_field = features['VisualField'][-1]
print(visual_field.shape)

torch.Size([4, 256, 64, 64])


In [209]:
# Reference result for the manual decomposition below, which re-derives the
# embedding for the first cascade step (step 0).
step = 0
embedded_ft = spiga.extract_embedded(pts_proj, visual_field, step)



In [210]:
# One embedded feature vector per landmark: batch x landmarks x embedding dim.
embedded_ft.shape

torch.Size([4, 98, 512])

In [211]:
step = 0
B, L, _ = pts_proj.shape  # pts_proj: BxLx2, coordinates in [0, 1]
# Shift each landmark by half a cell at `visual_res` resolution, flatten to (B*L)x2.
centers = (pts_proj + 0.5 / spiga.visual_res).reshape(B * L, 2)
# [0, 1] -> [-1, 1] (affine_grid's normalised coordinates): (B*L)x2x1
theta_trl = (centers * 2 - 1).unsqueeze(-1)
# Per-step 2x2 scale block tiled for every window: (B*L)x2x2
theta_s = spiga.theta_S[step].repeat(B * L, 1, 1)
# Full affine matrices [scale | translation]: (B*L)x2x3
theta = torch.cat((theta_s, theta_trl), -1)

In [212]:
# Generate crop grid
B, C, _, _ = visual_field.shape
grid = torch.nn.functional.affine_grid(theta, (B * L, C, spiga.kwindow, spiga.kwindow))
grid = grid.reshape(B, L, spiga.kwindow, spiga.kwindow, 2)
grid = grid.reshape(B, L, spiga.kwindow * spiga.kwindow, 2)

# Crop windows
crops = torch.nn.functional.grid_sample(visual_field, grid, padding_mode="border")  # BxCxLxK*K
crops = crops.transpose(1, 2)  # BxLxCxK*K
crop_tmp = crops.clone()
crops = crops.reshape(B * L, C, spiga.kwindow, spiga.kwindow)

crops.shape

torch.Size([392, 256, 7, 7])

In [217]:
# Restore the KxK window layout of the per-landmark crops: BxLxCxKxK.
crop_tmp = crop_tmp.reshape(B, L, C, spiga.kwindow, spiga.kwindow)
crop_tmp.shape

torch.Size([4, 98, 256, 7, 7])

In [215]:
# Feature map the windows were sampled from (BxCxHxW).
visual_field.shape

torch.Size([4, 256, 64, 64])

In [218]:
# Window of the first landmark in the first batch element: C x K x K.
# NOTE(review): in the saved output each channel is constant across the window,
# which may mean the window fell outside the feature map (border padding) --
# worth checking the range of pts_proj.
crop_tmp[0, 0]

tensor([[[3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762],
         [3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762],
         [3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762],
         ...,
         [3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762],
         [3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762],
         [3.5762, 3.5762, 3.5762,  ..., 3.5762, 3.5762, 3.5762]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.8879, 0.8879, 0.8879,  ..., 0.8879, 0.8879, 0.8879],
         [0.8879, 0.8879, 0.8879,  ..., 0.8879, 0.8879, 0.8879],
         [0.8879, 0.8879, 0.8879,  ..., 0.8879, 0.8879, 0.

In [176]:
# Flatten features
# Per-window convolution for the current `step`; the saved output shows it reduces
# each KxK window to 1x1, so the result is reshaped to BxLxCout.
# NOTE(review): execution count 176 is lower than the cells above; re-run the
# notebook in order so `crops`, `B`, `L` and `step` come from the cells above.
visual_ft = spiga.conv_window[step](crops)
print(visual_ft.shape)
_, Cout, _, _ = visual_ft.shape
visual_ft = visual_ft.reshape(B, L, Cout)
print(visual_ft.shape)

torch.Size([392, 512, 1, 1])
torch.Size([4, 98, 512])


In [124]:
# Params compute only once
# Run the GAT refinement cascade for the first N_REFINE_STEPS steps only (the
# exploratory loop previously `break`-ed after step 0). Work on a copy so the
# global `pts_proj` used by the manual cells is not modified, and re-running this
# cell does not compound the offsets.
N_REFINE_STEPS = 1
gat_prob = []
features['Landmarks'] = []
pts_refined = pts_proj
for step in range(min(N_REFINE_STEPS, spiga.steps)):
    # Features generation
    embedded_ft = spiga.extract_embedded(pts_refined, visual_field, step)

    # GAT inference
    offset, gat_prob = spiga.gcn[step](embedded_ft, gat_prob)
    offset = F.hardtanh(offset)

    # Update coordinates (new tensor; pts_proj is left untouched)
    pts_refined = pts_refined + spiga.offset_ratio[step] * offset
    features['Landmarks'].append(pts_refined.clone())

features['GATProb'] = gat_prob



In [133]:
# Feature keys after the refinement loop, and the landmarks of the first step.
print(features.keys())
features['Landmarks'][0].shape

dict_keys(['VisualField', 'HGcore', 'Pose', 'Landmarks', 'GATProb'])


torch.Size([4, 98, 2])

In [177]:
# Shape branch of the embedding for the current `step`: features from
# spiga.calculate_distances(pts_proj), encoded and added to the visual features.
# NOTE(review): `pts_proj`, `step` and `visual_ft` are globals from other cells;
# their values depend on execution order (see the out-of-order execution counts).
shape_ft = spiga.calculate_distances(pts_proj)
shape_ft = spiga.shape_encoder[step](shape_ft)
# Addition
embedded_ft = visual_ft + shape_ft
print(embedded_ft.shape)

torch.Size([4, 98, 512])


In [185]:
# Shape of the first tensor in the GAT probability output
# (presumably batch x heads x landmarks x landmarks -- TODO confirm against gcn).
gat_prob[0].shape

torch.Size([4, 4, 98, 98])

In [None]:
# GAT inference
# Manual version of one refinement step of the cascade loop: GAT offsets are
# clamped to [-1, 1] by hardtanh, scaled by offset_ratio and added to pts_proj.
# NOTE(review): this cell has not been executed. It rebinds the global `pts_proj`
# and appends to features['Landmarks'], so running it more than once compounds
# the update.
offset, gat_prob = spiga.gcn[step](embedded_ft, gat_prob)
offset = F.hardtanh(offset)

# Update coordinates
pts_proj = pts_proj + spiga.offset_ratio[step] * offset
features['Landmarks'].append(pts_proj.clone())

features['GATProb'] = gat_prob