# Head Estimation, Pose Conversion & Visualization

## 1. Load and Process Keypoints JSON

In [None]:
# Positions (in COCO's 17-keypoint order) kept for the head estimator:
# nose, left/right eye, left/right ear, left/right shoulder, left/right hip.
indices = [0, 1, 2, 3, 4, 5, 6, 11, 12]

def access_elements(data, indices):
    """Return a new list holding ``data[i]`` for each ``i`` in ``indices``, in order.

    Args:
        data: any indexable sequence.
        indices: iterable of positions into ``data``.

    Returns:
        list of the selected elements (raises IndexError on an out-of-range index).
    """
    picked = []
    for position in indices:
        picked.append(data[position])
    return picked

In [None]:
import json

# Load person_keypoints_train2017.json and, for every annotation whose selected
# keypoints (see `indices`) are all labelled, extract a (torso, head) pair of
# normalized coordinates.
#
# COCO person keypoint order (index: name):
#   0 nose, 1 left_eye, 2 right_eye, 3 left_ear, 4 right_ear,
#   5 left_shoulder, 6 right_shoulder, 7 left_elbow, 8 right_elbow,
#   9 left_wrist, 10 right_wrist, 11 left_hip, 12 right_hip,
#   13 left_knee, 14 right_knee, 15 left_ankle, 16 right_ankle

def extract_poses(file_path, scale=2):
    """Build (torso, head) training pairs from a COCO person-keypoints JSON file.

    For each annotation, the keypoints at positions ``indices`` (nose, eyes,
    ears, shoulders, hips) are selected. Annotations where any selected
    visibility flag is 0 are skipped. The remaining points are shifted so the
    minimum x/y is 0, divided by the larger of the box's width and height (so
    aspect ratio is kept), shifted by -0.5 and multiplied by ``scale``.

    Args:
        file_path: path to the COCO keypoints JSON (must have ``annotations``).
        scale: multiplier applied after normalization (default 2).

    Returns:
        list of tuples ``(torso, head)``. ``torso`` is the 4 shoulder/hip x's
        followed by their 4 y's. ``head`` is the 5 nose/eye/ear x's followed by
        their 5 y's.
    """
    with open(file_path, 'r') as f:
        data = json.load(f)

    poses_list = []

    for annotation in data['annotations']:
        keypoints = annotation['keypoints']
        # Keypoints are stored as flat (x, y, v) triplets; per the COCO format,
        # v == 0 means the keypoint is not labelled.
        visibility = access_elements(keypoints[2::3], indices)
        if 0 in visibility:
            continue

        keypoints_x = access_elements(keypoints[0::3], indices)
        keypoints_y = access_elements(keypoints[1::3], indices)

        min_x = min(keypoints_x)
        min_y = min(keypoints_y)
        width = max(keypoints_x) - min_x
        height = max(keypoints_y) - min_y

        # One factor for both axes preserves the pose's aspect ratio.
        extent = max(width, height)
        if extent == 0:
            # All selected keypoints coincide: there is no scale to normalize
            # by, and dividing would raise ZeroDivisionError.
            continue

        keypoints_x = [((x - min_x) / extent - 0.5) * scale for x in keypoints_x]
        keypoints_y = [((y - min_y) / extent - 0.5) * scale for y in keypoints_y]

        # The first 5 selected keypoints are the head (target); the remaining
        # 4 are shoulders and hips (input).
        poses_list.append((keypoints_x[5:] + keypoints_y[5:], keypoints_x[:5] + keypoints_y[:5]))

    return poses_list

# Source annotations and the poses extracted from them
file_path = 'person_keypoints_train2017.json'
extracted_poses = extract_poses(file_path)

# Show a sample and the total count
result_len = len(extracted_poses)
print(extracted_poses[:5])
print(f"Number of extracted poses: {result_len}")

# Fraction of poses written to train.json; the rest go to test.json.
# The split follows annotation order (no shuffling).
split = 0.8
split_index = int(result_len * split)

for out_name, subset in (("train.json", extracted_poses[:split_index]),
                         ("test.json", extracted_poses[split_index:])):
    with open(out_name, "w") as out_file:
        json.dump(subset, out_file)

## 2. Load Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset

# Report whether this PyTorch build includes the Apple MPS backend and whether
# an MPS device is usable in the current session.
print(f"MPS Support: {torch.backends.mps.is_built()}")
print(f"MPS Availability: {torch.backends.mps.is_available()}")

In [None]:
# Pick the compute backend: Apple MPS if available, else CUDA, else CPU.
if torch.backends.mps.is_available():
    backend = "mps"
elif torch.cuda.is_available():
    backend = "cuda"
else:
    backend = "cpu"
device = torch.device(backend)

print(f'Device: {device}')

## 3. Dataloader and Model Definition

### Dataloader

In [None]:
split = 0.8


class _KeypointJsonDataset(Dataset):
    """Shared base for datasets over a JSON list of ``[torso, head]`` records.

    Records are written by the export cell. ``torso`` is the 8 normalized
    shoulder/hip coordinates (x's then y's) used as model input. ``head`` is the
    10 normalized nose/eye/ear coordinates (x's then y's) used as the target.
    ``__getitem__`` returns them as a pair of tensors.
    """

    def __init__(self, data_dir, records):
        self.data_dir = data_dir
        self.data = records

    @staticmethod
    def _load(data_dir, file_name):
        """Read ``<data_dir>/<file_name>`` as JSON (``data_dir`` may be str or Path)."""
        with open(f"{data_dir}/{file_name}", 'r') as f:
            return json.load(f)

    @staticmethod
    def _cut_index(length, fraction):
        """Return the index separating the train part of ``train.json`` from the val part.

        ``fraction=None`` uses the module-level ``split`` value current at call time.
        """
        if fraction is None:
            fraction = split
        return int(length * fraction)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data[index]
        return torch.tensor(record[0]), torch.tensor(record[1])


class TrainDataset(_KeypointJsonDataset):
    """Leading ``fraction`` (default: module ``split``) of ``train.json``."""

    def __init__(self, data_dir, fraction=None):
        records = self._load(data_dir, "train.json")
        super().__init__(data_dir, records[:self._cut_index(len(records), fraction)])


class ValDataset(_KeypointJsonDataset):
    """Trailing part of ``train.json`` that ``TrainDataset`` leaves out."""

    def __init__(self, data_dir, fraction=None):
        records = self._load(data_dir, "train.json")
        super().__init__(data_dir, records[self._cut_index(len(records), fraction):])


class TestDataset(_KeypointJsonDataset):
    """All records of ``test.json``."""

    def __init__(self, data_dir):
        super().__init__(data_dir, self._load(data_dir, "test.json"))

In [None]:
import json

# Build the three splits from the JSON files in the working directory and
# report their sizes (train, val, test).
train_dataset = TrainDataset(data_dir='.')
val_dataset = ValDataset(data_dir='.')
test_dataset = TestDataset(data_dir='.')

for ds in (train_dataset, val_dataset, test_dataset):
    print(len(ds))

### Model

In [None]:
# Define the model
class HeadEstimator(nn.Module):
    """MLP that regresses head keypoints from torso keypoints.

    In this notebook, the input is a (batch, 8) tensor: normalized x's then y's
    of the left/right shoulder and left/right hip. The output is a (batch, 10)
    tensor: normalized x's then y's of the nose, left/right eye and
    left/right ear.

    Args:
        in_features: width of the input vector (default 8).
        out_features: width of the output vector (default 10).
        dropout: dropout probability after the first two hidden layers (default 0.2).

    The attribute names fc1..fc4 / bn1..bn3 determine the state_dict keys, so
    they must stay unchanged for saved checkpoints to keep loading.
    """

    def __init__(self, in_features=8, out_features=10, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 64)
        self.fc2 = nn.Linear(64, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, out_features)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(512)

        # Kaiming init for every Linear layer, zero biases. The nn.init
        # functions already run without autograd, so no `.data` is needed.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # Linear -> BatchNorm -> ReLU for three hidden layers; dropout after
        # the first two only.
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn3(self.fc3(x)))
        out = self.fc4(x)

        return out

### Training and Validation Functions

In [None]:
def train_model(trainloader, model, criterion, optimizer, scheduler, device):
    """Run one training epoch and return the mean of the per-batch losses.

    ``scheduler.step()`` is called after every batch, so a step-based scheduler
    counts batches, not epochs.

    Returns:
        float mean per-batch loss, or NaN if ``trainloader`` yields no batches
        (instead of failing on an unbound loop variable).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for inputs, targets in tqdm(trainloader, total=len(trainloader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        preds = model(inputs)
        batch_loss = criterion(preds, targets)

        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += batch_loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [None]:
def val_model(valloader, model, criterion, device):
    """Evaluate ``model`` on ``valloader`` without gradients.

    Returns:
        float mean of the per-batch losses, or NaN if the loader yields no
        batches (instead of failing on an unbound loop variable).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in tqdm(valloader, total=len(valloader)):
            inputs = inputs.to(device)
            targets = targets.to(device)

            preds = model(inputs)
            total_loss += criterion(preds, targets).item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

## 4. Training Model

In [None]:
batch_size = 64
learning_rate = 5e-4
epochs = 50

trainLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not affect learning. A fixed order keeps the per-epoch
# val loss reproducible: it is a mean of per-batch means, which depends on
# which samples share the final (smaller) batch.
valLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = HeadEstimator()
model = model.to(device)

def loss(preds, targets):
    """Head-keypoint training loss.

    ``preds``/``targets`` are (batch, 10): x's of nose, left/right eye,
    left/right ear, then the same 5 y's. The result is the MSE over all
    coordinates plus the mean squared Euclidean error of the eyes and of the
    ears, which gives eyes and ears extra weight.
    """
    mse = torch.mean((preds - targets) ** 2)

    # Columns 1:3 / 6:8 are the eye x's / y's.
    eye_dist = torch.mean((preds[:, 1:3] - targets[:, 1:3]) ** 2
                            + (preds[:, 6:8] - targets[:, 6:8]) ** 2)

    # Columns 3:5 / 8:10 are the ear x's / y's.
    ear_dist = torch.mean((preds[:, 3:5] - targets[:, 3:5]) ** 2
                            + (preds[:, 8:10] - targets[:, 8:10]) ** 2)

    return mse + eye_dist + ear_dist

criterion = loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# train_model steps the scheduler once per batch, so the LR halves every 1000 batches.
scheduler = StepLR(optimizer, step_size=1000, gamma=0.5)

history = {'train_loss':[], 'val_loss':[]}

print(model)

In [None]:
print("Training...")

for epoch in range(epochs):
    train_loss = train_model(trainLoader, model, criterion, optimizer, scheduler, device)
    val_loss = val_model(valLoader, model, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    for key, value in (('train_loss', train_loss), ('val_loss', val_loss)):
        history[key].append(value)

    # Checkpoint every 5th epoch and after the final one; each save
    # overwrites the same file.
    is_checkpoint_epoch = (epoch + 1) % 5 == 0
    is_final_epoch = epoch == epochs - 1
    if is_checkpoint_epoch or is_final_epoch:
        print("Saving checkpoint...")
        torch.save(model.state_dict(), f'./model.pth')

print('Finished Training')

In [None]:
# Plot per-epoch losses. Passing only the y-values uses x = 0..len-1, so the
# plot matches however many epochs were actually recorded (training may have
# been interrupted before `epochs`).
plt.plot(history['train_loss'], label='Train Loss', color='red')
plt.plot(history['val_loss'], label='Val Loss', color='blue')

plt.title('Loss history')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

# Call show(); referencing `plt.show` without parentheses does nothing.
plt.show()

## 5. Test Model

In [None]:
def test_model(testloader, model, criterion, device):
    """Evaluate ``model`` on ``testloader`` without gradients.

    Returns:
        float mean of the per-batch losses, or NaN if the loader yields no
        batches (instead of failing on an unbound loop variable).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, total=len(testloader)):
            inputs = inputs.to(device)
            targets = targets.to(device)

            preds = model(inputs)
            total_loss += criterion(preds, targets).item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [None]:
# Evaluate on the held-out test split in a fixed order.
testLoader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print("Testing...")
test_loss = test_model(testLoader, model, criterion, device)
print(f"Test Loss: {test_loss:.4f}")

## 6. Inference and Visualization Functions

In [None]:
def inference(model, device, x):
    """Run ``model`` on ``x`` in eval mode without gradient tracking.

    Puts ``model`` into eval mode as a side effect and moves ``x`` to ``device``.

    Returns:
        the model output sliced with ``[:10]``, i.e. at most its first 10 rows
        along dim 0. NOTE(review): that is the batch dimension; if the intent was
        "first 10 output features" it would be ``[:, :10]``. Confirm intent.
    """
    model.eval()
    with torch.no_grad():
        output = model(x.to(device))
    return output[:10]

In [None]:
def normalize_pose(pose, scale=2, *, verbose=True):
    """Normalize 2-D points with the same rule used to build the training data.

    Each point is shifted so the minimum x/y is 0, divided by the larger of the
    box's width and height (aspect ratio kept), shifted by -0.5 and multiplied
    by ``scale``.

    Args:
        pose: non-empty sequence of (x, y) points.
        scale: multiplier applied after normalization (default 2).
        verbose: print the normalized coordinates (default True, as before).

    Returns:
        ``(points, min_x, min_y, extent, extent)``. ``points`` is a list of
        (x, y) tuples and ``extent`` is the divisor's source value (the larger
        box side), returned twice so it can be passed to
        ``denormalize_pose`` as both width and height. If every point coincides,
        extent is 0 and the points are normalized with a divisor of 1 instead of
        raising ZeroDivisionError.
    """
    keypoints_x = [point[0] for point in pose]
    keypoints_y = [point[1] for point in pose]

    min_x, max_x = min(keypoints_x), max(keypoints_x)
    min_y, max_y = min(keypoints_y), max(keypoints_y)

    original_scale = max(max_x - min_x, max_y - min_y)
    # A zero-extent pose has nothing to scale; every point maps to -scale/2.
    divisor = original_scale if original_scale else 1

    keypoints_x = [((x - min_x) / divisor - 0.5) * scale for x in keypoints_x]
    keypoints_y = [((y - min_y) / divisor - 0.5) * scale for y in keypoints_y]

    if verbose:
        print('keypoints_x = "', keypoints_x, '"')
        print('keypoints_y = "', keypoints_y, '"\n')

    normalized = list(zip(keypoints_x, keypoints_y))

    return normalized, min_x, min_y, original_scale, original_scale

def denormalize_pose(pose, min_x, min_y, width, height, scale=2):
    """Invert ``normalize_pose``'s mapping, given the min corner and extents it returned.

    Returns a list of (x, y) tuples in the original coordinate frame.
    """
    pose = [((x / scale + 0.5) * width + min_x, (y / scale + 0.5) * height + min_y) for (x, y) in pose]

    return pose

def resize_pose(pose, resize_width=200, resize_height=200, padding=10, scale=2):
    """Normalize ``pose`` and map it into a box offset by ``padding``.

    The longer side of the pose spans ``resize_width``/``resize_height`` pixels
    starting at ``padding``. ``normalize_pose``'s debug printing is left on.
    """
    pose, _, _, _, _ = normalize_pose(pose, scale)
    pose = [((x / scale + 0.5) * resize_width + padding, (y / scale + 0.5) * resize_height + padding) for (x, y) in pose]

    return pose

In [None]:
'''
coco (original):
0 : nose
1 2 : eyes
3 4 : ears
5 7 9 : left arm
6 8 10: right arm
11 13 15: left leg
12 14 16: right leg

mpii (ours):
13 14 15 : left arm
12 11 10 : right arm
3 4 5 : left leg
2 1 0 : right leg
7 : neck
9 8 : head

coco (ours):
0 : nose
1 : neck
2 3 4 : right arm
5 6 7 : left arm
8 9 10 : right leg
11 12 13 : left leg
14 15 : eyes
16 17 : ears
'''

def convert_pose(model, device, pose):
    """Convert an MPII-ordered pose into the 18-point "coco (ours)" layout.

    The nose, eyes and ears do not exist in the MPII pose, so they are
    predicted by ``model`` from the shoulders and hips, then mapped back to
    image coordinates.

    Args:
        model: trained HeadEstimator.
        device: device to run inference on.
        pose: list of (x, y) points indexed as in the "mpii (ours)" table above.

    Returns:
        list of 18 (x, y) points in the "coco (ours)" order above.
    """
    # Torso joints in the order the network was trained on:
    # left shoulder, right shoulder, left hip, right hip.
    torso = [pose[13], pose[12], pose[3], pose[2]]

    # Extra point 2/3 of the way from MPII joint 8 toward joint 9. It only
    # enlarges the normalization box (presumably to approximate the head's
    # extent) and is not fed to the network.
    (neck_x, neck_y), (top_x, top_y) = pose[8], pose[9]
    anchor = ((neck_x + top_x * 2) / 3, (neck_y + top_y * 2) / 3)

    normalized, min_x, min_y, width, height = normalize_pose(torso + [anchor])

    # Network input: the 4 torso x's followed by the 4 torso y's.
    body = normalized[:4]
    features = [p[0] for p in body] + [p[1] for p in body]
    input_pose = torch.tensor([features])

    # Network output: 5 head x's followed by 5 head y's (nose, eyes, ears).
    head = inference(model, device, input_pose)[0].cpu().tolist()
    head_points = denormalize_pose(list(zip(head[:5], head[5:])), min_x, min_y, width, height)

    nose, *eyes_and_ears = head_points

    return [
        nose,                          # 0 nose
        pose[7],                       # 1 neck
        pose[12], pose[11], pose[10],  # 2-4 right arm
        pose[13], pose[14], pose[15],  # 5-7 left arm
        pose[2], pose[1], pose[0],     # 8-10 right leg
        pose[3], pose[4], pose[5],     # 11-13 left leg
        *eyes_and_ears,                # 14-15 eyes, 16-17 ears
    ]

In [None]:
import cv2
import matplotlib.pyplot as plt


# Skeleton for MPII-ordered poses (vis_pose_mpii). Each [a, b] is an edge
# between joint indices a and b. Joint 6 is not named in the table below; it
# joins both hips (2, 3) and the neck (7), so presumably it is the pelvis.
mpii_link_pairs = [[0, 1], [1, 2], [2, 6], 
              [3, 6], [3, 4], [4, 5], 
              [6, 7], [7,12], [11, 12], 
              [10, 11], [7, 13], [13, 14],
              [14, 15],[7, 8],[8, 9]]

# One color per entry of mpii_link_pairs. OpenCV reads tuples as BGR:
# right leg/arm (0, 0, 255) red, left leg/arm (0, 255, 0) green,
# spine and head (0, 255, 255) yellow.
mpii_link_color = [(0, 0, 255), (0, 0, 255), (0, 0, 255),
              (0, 255, 0), (0, 255, 0), (0, 255, 0),
              (0, 255, 255), (0, 0, 255), (0, 0, 255),
              (0, 0, 255), (0, 255, 0), (0, 255, 0),
              (0, 255, 0), (0, 255, 255), (0, 255, 255)]

# Per-joint marker colors, indexed by joint; 17 entries for joints 0-15
# (one spare).
mpii_point_color = [(255,0,0),(0,255,0),(0,0,255), 
               (128,0,0), (0,128,0), (0,0,128),
               (255, 255, 0),(0,255,255),(255, 0, 255),
               (128,128,0),(0, 128, 128),(128,0,128),
               (128,255,0),(128,128,128),(255,128,0),
               (255,0,128),(255,255,255)]


# Skeleton for the 18-point "coco (ours)" layout produced by convert_pose
# (vis_pose_coco): nose-neck, arms and legs hanging off the neck (1), and
# nose -> eye -> ear on each side.
coco_link_pairs = [[0, 1], [1, 2], [2, 3], 
              [3, 4], [1, 5], [5, 6], 
              [6, 7], [1, 8], [8, 9], 
              [9, 10], [1, 11], [11, 12],
              [12, 13],[0, 14],[14, 16], [0, 15], [15, 17]]

# One BGR color per entry of coco_link_pairs (17 links, 17 colors).
# NOTE(review): unlike the MPII table, these colors are not grouped by
# limb/side (e.g. the right elbow-wrist link is green). Confirm whether that
# is intended.
coco_link_color = [(0, 0, 255), (0, 0, 255), (0, 0, 255),
              (0, 255, 0), (0, 255, 0), (0, 255, 0),
              (0, 255, 255), (0, 0, 255), (0, 0, 255),
              (0, 0, 255), (0, 255, 0), (0, 255, 0),
              (0, 255, 0), (0, 255, 255), (0, 255, 255), (128, 128, 0), (128, 0, 128)]

# Per-point marker colors; 19 entries for the 18 points (one spare).
coco_point_color = [(255,0,0),(0,255,0),(0,0,255), 
               (128,0,0), (0,128,0), (0,0,128),
               (255, 255, 0),(0,255,255),(255, 0, 255),
               (128,128,0),(0, 128, 128),(128,0,128),
               (128,255,0),(128,128,128),(255,128,0),
               (255,0,128),(255,255,255), (128, 128, 0), (128, 0, 128)]

'''
mpii (ours):
2 1 0 : right leg
3 4 5 : left leg
12 11 10 : right arm
13 14 15 : left arm
7 : neck
9 8 : head

coco (ours):
0 : nose
1 : neck
2 3 4 : right arm
5 6 7 : left arm
8 9 10 : right leg
11 12 13 : left leg
14 15 : eyes
16 17 : ears
'''

def vis_pose(image_path, pose, link_pairs, link_color, point_color):
    """Draw ``pose`` on the image at ``image_path`` and show it in an OpenCV window.

    Args:
        image_path: path passed to ``cv2.imread``.
        pose: sequence of (x, y) points; coordinates are truncated with ``int``.
            A point equal to (0, 0) after truncation is treated as missing: it
            is neither drawn nor linked.
        link_pairs: [a, b] index pairs into ``pose`` to connect with lines.
        link_color: one color per entry of ``link_pairs``.
        point_color: color per point index (at least ``len(pose)`` entries).

    Blocks until a key is pressed (``cv2.waitKey(0)``), then closes all windows.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the error surfaces later inside cv2.line/cv2.circle.
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")

    # OpenCV drawing functions take integer pixel coordinates.
    pose = [(int(x), int(y)) for (x, y) in pose]
    missing = (0, 0)

    for idx, (start, end) in enumerate(link_pairs):
        if pose[start] != missing and pose[end] != missing:
            cv2.line(image, pose[start], pose[end], link_color[idx], 2)

    for idx, point in enumerate(pose):
        if point != missing:
            cv2.putText(image, str(idx), point, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            cv2.circle(image, point, 5, point_color[idx], thickness=-1)

    cv2.imshow("image", image)
    cv2.moveWindow("image", 0, 0)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def vis_pose_mpii(image_path, pose):
    """Show ``pose`` on ``image_path`` using the MPII skeleton and colors."""
    vis_pose(image_path, pose,
             link_pairs=mpii_link_pairs,
             link_color=mpii_link_color,
             point_color=mpii_point_color)


def vis_pose_coco(image_path, pose):
    """Show ``pose`` on ``image_path`` using the 18-point COCO-style skeleton and colors."""
    vis_pose(image_path, pose,
             link_pairs=coco_link_pairs,
             link_color=coco_link_color,
             point_color=coco_point_color)

In [None]:
import os

def vis_data(model, device, data, image_dir_path='./affordance_data/data'):
    """Parse one list-file line and show the MPII pose and its converted versions.

    ``data`` is split on single spaces. The first token is the image file name
    (joined onto ``image_dir_path``). The tokens after it are alternating x/y
    values, and the final token is dropped. That token is presumably a label or
    line terminator; confirm against the list-file format.

    Three windows are shown in turn: the MPII pose, the converted 18-point
    pose, and the converted pose after ``resize_pose``.
    """
    fields = data.split(' ')
    image_path = os.path.join(image_dir_path, fields[0])

    # SECURITY NOTE(review): eval() executes arbitrary expressions read from the
    # list file. If the tokens are plain numeric literals, ast.literal_eval (or
    # float) would be a safer drop-in. It is kept here to preserve parsing exactly.
    values = [eval(token) for token in fields[1:-1]]
    mpii_pose = [(values[i], values[i + 1]) for i in range(0, len(values), 2)]

    vis_pose_mpii(image_path, mpii_pose)

    coco_pose = convert_pose(model, device, mpii_pose)
    vis_pose_coco(image_path, coco_pose)
    vis_pose_coco(image_path, resize_pose(coco_pose))

## 7. Pose Visualization

In [None]:
# Load Checkpoint
model = HeadEstimator()
# The checkpoint is saved from whatever device training used (MPS/CUDA/CPU).
# map_location remaps its tensors onto the current `device`, so loading also
# works in a session where the original device is unavailable.
model.load_state_dict(torch.load('./model.pth', map_location=device))
model = model.to(device)

In [None]:
train_data_path = './affordance_data/trainlist.txt'
with open(train_data_path, 'r') as f:
    train_data = f.readlines()

# Visualize every sample from index 1000 onward; each window waits for a key press.
for sample in train_data[1000:]:
    vis_data(model, device, sample)