In [None]:
import os
import xml.etree.ElementTree as ET 

import torch
import torchvision
from torch import optim
from torchvision import transforms, utils
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
import torchvision.models
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
import skimage.io as skio
import skimage
from skimage.transform import resize
import scipy

from google.colab.patches import cv2_imshow
import pandas as pd

In [None]:
%%capture
if not os.path.exists('/content/ibug_300W_large_face_landmark_dataset'):
    !wget https://people.eecs.berkeley.edu/~zhecao/ibug_300W_large_face_landmark_dataset.zip
    !unzip 'ibug_300W_large_face_landmark_dataset.zip'    
    !rm -r 'ibug_300W_large_face_landmark_dataset.zip'

In [None]:
!wget https://people.eecs.berkeley.edu/~zhecao/ibug_300W_large_face_landmark_dataset.zip

In [None]:
!unzip 'ibug_300W_large_face_landmark_dataset.zip'  

In [None]:
!rm -r 'ibug_300W_large_face_landmark_dataset.zip'

In [None]:
# Parse the training annotations: one face bounding box and 68 landmarks per image.
tree = ET.parse('ibug_300W_large_face_landmark_dataset/labels_ibug_300W_train.xml')
root = tree.getroot()
root_dir = 'ibug_300W_large_face_landmark_dataset'

bboxes = [] # face bounding box used to crop the image
landmarks = [] # the facial keypoints/landmarks for the whole training dataset
img_filenames = [] # the image names for the whole dataset

# Each child of root[2] is one image entry. Its first child carries the box
# attributes, and that child's own children carry the 68 landmark coordinates.
for filename in root[2]:
    img_filenames.append(os.path.join(root_dir, filename.attrib['file']))
    box = filename[0].attrib
    # x, y for the top left corner of the box, w, h for box width and height
    bboxes.append([box['left'], box['top'], box['width'], box['height']])

    landmark = []
    for num in range(68):
        x_coordinate = int(filename[0][num].attrib['x'])
        y_coordinate = int(filename[0][num].attrib['y'])
        landmark.append([x_coordinate, y_coordinate])
    landmarks.append(landmark)

landmarks = np.array(landmarks).astype('float32')
bboxes = np.array(bboxes).astype('float32')

print(bboxes[5])

# Sanity check on sample 5: draw its box and landmarks, then show the raw crop.
img = skio.imread(img_filenames[5])
print(img.shape)
landmark = landmarks[5]
# np.int was removed in NumPy 1.24; the builtin int selects the same dtype.
bbox = bboxes[5].astype(int)
plt.imshow(img)
# Visit the four corners in order and return to the first one so the rectangle
# is closed. Image coordinates are used directly (y grows downwards).
box_x = [bbox[0], bbox[0] + bbox[2], bbox[0] + bbox[2], bbox[0], bbox[0]]
box_y = [bbox[1], bbox[1], bbox[1] + bbox[3], bbox[1] + bbox[3], bbox[1]]
plt.plot(box_x, box_y)
plt.plot(landmark[:, 0], landmark[:, 1], linestyle = "none", marker = ".", markersize = 3)
# Finish the annotated figure so the crop below is not drawn on top of it.
plt.show()

# crop() takes (top, left, height, width) on a channels-first tensor.
cropped_5 = torchvision.transforms.functional.crop(torch.from_numpy(img).permute((2, 0, 1)), bbox[1], bbox[0], bbox[3], bbox[2])
print(cropped_5.numpy().shape)
skio.imshow(cropped_5.permute((1, 2, 0)).numpy())

In [None]:
# Adjust the top-left corner and size of the bounding box to let it include as
# many feature points as possible: the corner moves up/left by 10% of the box
# size and the size grows to 120%.
# np.int was removed in NumPy 1.24; the builtin int selects the same dtype.
bboxes_adjusted = bboxes.copy().astype(int)
mask = bboxes_adjusted >= 0
# Move the corner first, while columns 2 and 3 still hold the original size.
bboxes_adjusted[:, 0] = (bboxes_adjusted[:, 0] - bboxes_adjusted[:, 2] * 0.1).astype(int)
bboxes_adjusted[:, 1] = (bboxes_adjusted[:, 1] - bboxes_adjusted[:, 3] * 0.1).astype(int)
bboxes_adjusted[:, 2] = (bboxes_adjusted[:, 2] * 1.2).astype(int)
bboxes_adjusted[:, 3] = (bboxes_adjusted[:, 3] * 1.2).astype(int)
# print(np.sum(bboxes_adjusted[:, 3] < 0))
# img = skio.imread(img_filenames[139])
# landmark = landmarks[139]
# bbox = bboxes[139].astype(np.int)
# print(bboxes[139])
# # plt.imshow(img)
# # plt.plot([bbox[0], bbox[0] + bbox[2], bbox[0] + bbox[2], bbox[0]], [bbox[1], bbox[1], bbox[1] + bbox[3], bbox[1] + bbox[3]])
# # plt.plot(landmark[:, 0], landmark[:, 1], linestyle = "none", marker = ".", markersize = 3)

# cropped_139 = torchvision.transforms.functional.crop(torch.from_numpy(img).permute((2, 0, 1)), bbox[1], bbox[0], bbox[3], bbox[2])
# print(cropped_139.permute((1, 2, 0)).numpy().shape)
# skio.imshow(cropped_139.permute((1, 2, 0)).numpy())
# plt.plot(landmark[:, 0] - bbox[0], landmark[:, 1] - bbox[1], linestyle = "none", marker = ".", markersize = 3)

In [None]:
# Express every landmark in the pixel grid of its face crop after the crop is
# resized to 224*224: subtract the box corner, divide by the box size, scale by 224.
landmarks_adjusted = landmarks.copy()
for axis in (0, 1):
    # Column `axis` holds the corner coordinate and column `axis + 2` the
    # matching box extent; [:, None] lets them broadcast over the landmarks
    # of each image.
    corner = bboxes_adjusted[:, axis][:, None]
    extent = bboxes_adjusted[:, axis + 2][:, None]
    landmarks_adjusted[:, :, axis] = (landmarks_adjusted[:, :, axis] - corner) / extent * 224

# images = []
# for i in range(len(img_filenames)):
#   if i % 100 == 0:
#     print(i)
#   img = skio.imread(img_filenames[i])
#   bbox = bboxes_adjusted[i]
#   if len(img.shape) == 3:
#     print("crop image with rgb channel")
#     img = torchvision.transforms.functional.crop(torch.from_numpy(img).permute((2, 0, 1)), 
#                                                         bbox[1], bbox[0], bbox[3], bbox[2]).permute((1, 2, 0))
#     print(img.shape)
#   if len(img.shape) == 2:
#     print("crop black and white image")
#     img = np.dstack((img, img, img))
#     img = torchvision.transforms.functional.crop(torch.from_numpy(img).permute((2, 0, 1)), 
#                                                         bbox[1], bbox[0], bbox[3], bbox[2]).permute((1, 2, 0))
#   images.append(resize(img, (224, 224), anti_aliasing=True))
  
#   if i <= 7:
#     plt.imshow(images[i])
#     plt.plot(landmarks_adjusted[i][:, 0], landmarks_adjusted[i][:, 1], linestyle = "none", marker = ".")
#     plt.show()
#   if i > 7:
#     break


In [None]:
# Sizes of the train/validation split: the first num_training samples are used
# for training and the last num_validation samples for validation.
num_data = landmarks_adjusted.shape[0]
num_training = 6000
num_validation = 666
# The two slices are taken from opposite ends of the data, so they only stay
# disjoint while their combined size fits into the dataset.
assert num_training + num_validation <= num_data, \
    f"split sizes ({num_training} + {num_validation}) exceed the {num_data} available samples"

class TrainingDataset(Dataset):
  """
  Face-landmark samples used for training.

  Serves the first `num_training` entries of the module-level `img_filenames`,
  `bboxes_adjusted` and `landmarks_adjusted`.
  """
  def __init__(self):
    """
    Take the training slice of the module-level annotations.

    The constructor accepts no arguments; images are only read from disk in
    `__getitem__`.
    """
    self.images_name = img_filenames[:num_training]
    self.bboxes_adjusted = bboxes_adjusted[:num_training]
    self.feature_points = np.array(landmarks_adjusted[:num_training]).astype(np.float32)

  def __len__(self):
    """Number of samples, taken from the first dimension of the landmark array."""
    return self.feature_points.shape[0]

  def __getitem__(self, idx):
    """
    Load one sample.

    Returns:
        dict with 'image', the face crop converted to grayscale and resized to
        224x224 (float32), and 'feature', the float32 landmark array stored for
        that sample.
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()

    pixels = skio.imread(self.images_name[idx])
    if len(pixels.shape) == 2:
      # Single-channel file: replicate it three times so the RGB path applies.
      pixels = np.dstack((pixels,) * 3)

    left, top, box_width, box_height = self.bboxes_adjusted[idx]
    # crop() works on channels-first tensors and takes (top, left, height, width).
    channels_first = torch.from_numpy(pixels).permute((2, 0, 1))
    face = torchvision.transforms.functional.crop(channels_first,
                                                  top, left, box_height, box_width)
    gray = skimage.color.rgb2gray(face.permute((1, 2, 0)).numpy())
    gray = resize(gray, (224, 224), anti_aliasing=True).astype(np.float32)

    return {'image': gray, 'feature': self.feature_points[idx].astype(np.float32)}

In [None]:
class ValidationDataset(Dataset):
  """
  Face-landmark samples used for validation.

  Serves the last `num_validation` entries of the module-level `img_filenames`,
  `bboxes_adjusted` and `landmarks_adjusted`.
  """

  def __init__(self):
    """
    Take the validation slice of the module-level annotations.

    The constructor accepts no arguments; images are only read from disk in
    `__getitem__`.
    """
    self.images_name = img_filenames[-num_validation:]
    self.bboxes_adjusted = bboxes_adjusted[-num_validation:]
    self.feature_points = np.array(landmarks_adjusted[-num_validation:]).astype(np.float32)

  def __len__(self):
    """Number of samples, taken from the first dimension of the landmark array."""
    return self.feature_points.shape[0]

  def __getitem__(self, idx):
    """
    Load one sample.

    Returns:
        dict with 'image', the face crop converted to grayscale and resized to
        224x224 (float32), and 'feature', the float32 landmark array stored for
        that sample.
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()

    pixels = skio.imread(self.images_name[idx])
    if len(pixels.shape) == 2:
      # Single-channel file: replicate it three times so the RGB path applies.
      pixels = np.dstack((pixels,) * 3)

    left, top, box_width, box_height = self.bboxes_adjusted[idx]
    # crop() works on channels-first tensors and takes (top, left, height, width).
    channels_first = torch.from_numpy(pixels).permute((2, 0, 1))
    face = torchvision.transforms.functional.crop(channels_first,
                                                  top, left, box_height, box_width)
    gray = skimage.color.rgb2gray(face.permute((1, 2, 0)).numpy())
    gray = resize(gray, (224, 224), anti_aliasing=True).astype(np.float32)

    return {'image': gray, 'feature': self.feature_points[idx].astype(np.float32)}

In [None]:
def training_dataloader(batch_size = 8):
    """
    Build the training dataset together with a shuffling dataloader over it.

    Args:
        batch_size: number of samples per batch.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = TrainingDataset()
    loader = DataLoader(dataset, batch_size = batch_size, shuffle=True)
    return dataset, loader

def validation_dataloader(batch_size = num_validation):
    """
    Build the validation dataset together with a dataloader over it.

    The loader is created without shuffling. The default batch size is the
    value `num_validation` had when this function was defined.

    Args:
        batch_size: number of samples per batch.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = ValidationDataset()
    loader = DataLoader(dataset, batch_size = batch_size)
    return dataset, loader

In [None]:
# Visualize training dataset and corresponding facial keypoints.
training_data, training_loader = training_dataloader()
validation_data, validation_loader = validation_dataloader()
for i, batch in enumerate(training_loader):
  images = batch["image"]
  features = batch["feature"]
  print(features.shape)
  print(images.shape)
  # Show at most the first four samples of the batch.
  for j in range(min(4, images.shape[0])):
    # The crops are single-channel, so ask for a gray colormap instead of
    # matplotlib's default one.
    plt.imshow(images[j].numpy(), cmap = "gray")
    plt.plot(features[j].numpy()[:, 0], features[j].numpy()[:, 1], linestyle = "none", marker = ".")
    plt.show()
  # One batch is enough for a visual check.
  break

In [None]:
def rotate(x, y, angles):
    """
    Rotate a batch of images and move the landmarks with them.

    For each angle in angles, all images in batch x are rotated by that angle
    and the landmarks are rotated about the point (width // 2, height // 2).

    Args:
        x: image tensor whose last two dimensions are (height, width).
        y: landmark tensor indexed as (batch, point, coordinate), with
            coordinate 0 = x and coordinate 1 = y in pixels of the image.
        angles: iterable of rotation angles in degrees.

    Returns:
        (x_rotated_lst, y_rotated_lst): two lists with one rotated image tensor
        and one rotated landmark tensor per angle. x and y are not modified.
    """
    # Read the image size from the batch itself instead of relying on
    # module-level `height`/`width` having been defined before the call.
    height, width = x.shape[-2], x.shape[-1]
    x_rotated_lst = []
    y_rotated_lst = []
    for angle in angles:
        theta = -angle / 180 * np.pi
        rotation_matrix = np.array([[np.cos(theta), np.sin(theta)],
            [-np.sin(theta), np.cos(theta)]]).astype(np.float32)
        x_rotated = torchvision.transforms.functional.affine(x, angle = angle,
            translate = [0, 0], scale = 1, shear = [0, 0], fill = -0.5)
        y_rotated = y.clone()
        # Move the rotation center to the origin, rotate, then move it back.
        y_rotated[:, :, 0] = y_rotated[:, :, 0] - width // 2
        y_rotated[:, :, 1] = y_rotated[:, :, 1] - height // 2
        y_rotated = torch.matmul(y_rotated, torch.from_numpy(rotation_matrix.T))
        y_rotated[:, :, 0] = (y_rotated[:, :, 0] + width // 2)
        y_rotated[:, :, 1] = (y_rotated[:, :, 1] + height // 2)
        x_rotated_lst.append(x_rotated)
        y_rotated_lst.append(y_rotated)
    return x_rotated_lst, y_rotated_lst

def shift_vertical(x, y, pixels):
    """
    Translate an image batch vertically and move the landmarks by the same amount.

    Every offset in pixels is applied to the whole batch x; the same offset is
    added to coordinate 1 of every landmark in y.

    Returns:
        Two lists (shifted image batches, shifted landmark tensors) with one
        entry per offset. x and y are not modified.
    """
    affine = torchvision.transforms.functional.affine
    shifted_images = []
    shifted_points = []
    for offset in pixels:
        shifted_images.append(affine(x, angle = 0, translate = [0, offset],
            scale = 1, shear = [0, 0], fill = -0.5))
        moved = y.clone()
        moved[:, :, 1] = moved[:, :, 1] + offset
        shifted_points.append(moved)
    return shifted_images, shifted_points

def shift_horizontal(x, y, pixels):
    """
    Translate an image batch horizontally and move the landmarks by the same amount.

    Every offset in pixels is applied to the whole batch x; the same offset is
    added to coordinate 0 of every landmark in y.

    Returns:
        Two lists (shifted image batches, shifted landmark tensors) with one
        entry per offset. x and y are not modified.
    """
    affine = torchvision.transforms.functional.affine
    shifted_images = []
    shifted_points = []
    for offset in pixels:
        shifted_images.append(affine(x, angle = 0, translate = [offset, 0],
            scale = 1, shear = [0, 0], fill = -0.5))
        moved = y.clone()
        moved[:, :, 0] = moved[:, :, 0] + offset
        shifted_points.append(moved)
    return shifted_images, shifted_points


In [None]:
class ResNet18(nn.Module):
  """
  Self-implemented ResNet18-style network for landmark regression.

  Takes single-channel images of shape (batch, 1, H, W) and returns a
  (batch, 136) tensor. The layout is a stem followed by four stages of two
  residual blocks each, with 64, 128, 256 and 512 channels. Each residual block
  is a 3x3 convolution followed by a 1x1 convolution (the canonical ResNet-18
  uses two 3x3 convolutions per block).
  """
  def __init__(self):
    super(ResNet18, self).__init__()
    # Stem: 7x7 convolution with stride 2.
    self.conv1 = nn.Conv2d(1, 64, 7, 2, (3, 3))
    self.bn1 = nn.BatchNorm2d(64)

    # Every 3x3 convolution uses padding 1 so that a block's output keeps the
    # spatial size of its shortcut branch and the two can be added.
    self.conv2_1 = nn.Conv2d(64, 64, 3, 1, (1, 1))
    self.bn2_1 = nn.BatchNorm2d(64)
    self.conv2_2 = nn.Conv2d(64, 64, 1)
    self.bn2_2 = nn.BatchNorm2d(64)
    self.conv2_3 = nn.Conv2d(64, 64, 3, 1, (1, 1))
    self.bn2_3 = nn.BatchNorm2d(64)
    self.conv2_4 = nn.Conv2d(64, 64, 1)
    self.bn2_4 = nn.BatchNorm2d(64)

    # Stages 3-5 halve the resolution in their first block. The *_id layers
    # project the shortcut to the new channel count and resolution.
    self.conv3_1 = nn.Conv2d(64, 128, 3, 2, (1, 1))
    self.bn3_1 = nn.BatchNorm2d(128)
    self.conv3_2 = nn.Conv2d(128, 128, 1)
    self.bn3_2 = nn.BatchNorm2d(128)
    self.conv3_3 = nn.Conv2d(128, 128, 3, 1, (1, 1))
    self.bn3_3 = nn.BatchNorm2d(128)
    self.conv3_4 = nn.Conv2d(128, 128, 1)
    self.bn3_4 = nn.BatchNorm2d(128)
    self.conv3_id = nn.Conv2d(64, 128, 1, 2)
    self.bn3_id = nn.BatchNorm2d(128)

    self.conv4_1 = nn.Conv2d(128, 256, 3, 2, (1, 1))
    self.bn4_1 = nn.BatchNorm2d(256)
    self.conv4_2 = nn.Conv2d(256, 256, 1)
    self.bn4_2 = nn.BatchNorm2d(256)
    self.conv4_3 = nn.Conv2d(256, 256, 3, 1, (1, 1))
    self.bn4_3 = nn.BatchNorm2d(256)
    self.conv4_4 = nn.Conv2d(256, 256, 1)
    self.bn4_4 = nn.BatchNorm2d(256)
    self.conv4_id = nn.Conv2d(128, 256, 1, 2)
    self.bn4_id = nn.BatchNorm2d(256)

    self.conv5_1 = nn.Conv2d(256, 512, 3, 2, (1, 1))
    self.bn5_1 = nn.BatchNorm2d(512)
    self.conv5_2 = nn.Conv2d(512, 512, 1)
    self.bn5_2 = nn.BatchNorm2d(512)
    self.conv5_3 = nn.Conv2d(512, 512, 3, 1, (1, 1))
    self.bn5_3 = nn.BatchNorm2d(512)
    self.conv5_4 = nn.Conv2d(512, 512, 1)
    self.bn5_4 = nn.BatchNorm2d(512)
    self.conv5_id = nn.Conv2d(256, 512, 1, 2)
    self.bn5_id = nn.BatchNorm2d(512)

    self.fc1 = nn.Linear(512, 136)

  def _residual_block(self, x, conv_a, bn_a, conv_b, bn_b, conv_id = None, bn_id = None):
    """Apply conv-bn-relu-conv-bn, add the (optionally projected) input, then relu."""
    identity = x if conv_id is None else bn_id(conv_id(x))
    x = F.relu(bn_a(conv_a(x)))
    x = bn_b(conv_b(x))
    return F.relu(x + identity)

  def forward(self, x):
    """
    Run the network.

    Args:
        x: float tensor of shape (batch, 1, H, W).

    Returns:
        Tensor of shape (batch, 136).
    """
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.max_pool2d(x, (3, 3), 2, 1)

    # conv2 block
    x = self._residual_block(x, self.conv2_1, self.bn2_1, self.conv2_2, self.bn2_2)
    x = self._residual_block(x, self.conv2_3, self.bn2_3, self.conv2_4, self.bn2_4)

    # conv3 block
    x = self._residual_block(x, self.conv3_1, self.bn3_1, self.conv3_2, self.bn3_2,
                             self.conv3_id, self.bn3_id)
    x = self._residual_block(x, self.conv3_3, self.bn3_3, self.conv3_4, self.bn3_4)

    # conv4 block
    x = self._residual_block(x, self.conv4_1, self.bn4_1, self.conv4_2, self.bn4_2,
                             self.conv4_id, self.bn4_id)
    x = self._residual_block(x, self.conv4_3, self.bn4_3, self.conv4_4, self.bn4_4)

    # conv5 block
    x = self._residual_block(x, self.conv5_1, self.bn5_1, self.conv5_2, self.bn5_2,
                             self.conv5_id, self.bn5_id)
    x = self._residual_block(x, self.conv5_3, self.bn5_3, self.conv5_4, self.bn5_4)

    x = F.adaptive_avg_pool2d(x, (1, 1))
    # (batch, 512, 1, 1) -> (batch, 512) so the linear layer sees a feature vector.
    x = torch.flatten(x, 1)
    x = self.fc1(x)

    return x

In [None]:
class ResNet(nn.Module):
  """
  Slightly modified ResNet18 from torchvision.models.

  The first convolution is replaced by one that takes a single input channel
  and the final fully connected layer by one with 136 outputs.
  """
  def __init__(self):
    super(ResNet, self).__init__()
    # Called without arguments the model is randomly initialised. This avoids
    # the `pretrained` keyword, which newer torchvision releases deprecate in
    # favour of `weights`, while still working on older releases.
    self.resnet18 = torchvision.models.resnet18()
    self.resnet18.conv1 = nn.Conv2d(1, 64, kernel_size = 7, stride = 2, padding = (3, 3))
    self.resnet18.fc = nn.Linear(512, 136)


  def forward(self, images):
    """Return the output of the wrapped network for a batch of images."""
    features = self.resnet18(images)
    return features

In [None]:
# Initialize hyperparameters. 
batch_size = 64
rotation_per_batch = 1
shift_vertical_per_batch = 1
shift_horizontal_per_batch = 1
np.random.seed(1234)

training_data, training_loader = training_dataloader(batch_size)
validation_data, validation_loader = validation_dataloader(batch_size)

# num_epoch = 10
training_losses = []
validation_losses = []


# Image size, read once from the first training sample.
height, width = training_data[0]["image"].shape[:2]

# One row of augmentation parameters per training batch. The same rows are
# looked up by batch index in every epoch.
num_batches = len(training_data) // batch_size + 1
# Rotation angles are drawn uniformly from [-15, 15) degrees.
rotation_angles = np.random.rand(num_batches, rotation_per_batch) * 30 - 15
# Shifts are integers from -10 to 9 (randint excludes the upper bound).
pixels_vertical = np.random.randint(-10, 10,
    (num_batches, shift_vertical_per_batch))
# Sized by the horizontal setting; the vertical one was used here by mistake.
pixels_horizontal = np.random.randint(-10, 10,
    (num_batches, shift_horizontal_per_batch))

model = ResNet().to("cuda")
optimizer = optim.Adam(model.parameters(), lr = 3e-4)
criterion = nn.MSELoss()


In [None]:
# Train on ResNet18 for 10 epochs.
# Each epoch records the mean training loss over its batches and the
# sample-weighted mean validation loss.
num_epoch = 10
for epoch in range(num_epoch):
    print(f"Start training epoch {epoch}")
    loss_epoch = []
    for i, batch in enumerate(training_loader):
        print("Start training batch ", i)
        x = batch["image"]
        y = batch["feature"]

        # Build the augmented copies of this batch: rotated, vertically
        # shifted and horizontally shifted, using the parameters drawn for
        # batch index i.
        angles = rotation_angles[i]
        x_rotated_lst, y_rotated_lst = rotate(x, y, angles)

        pixels = pixels_vertical[i]
        x_shifted_vertical_lst, y_shifted_vertical_lst = shift_vertical(x, y, pixels)

        pixels = pixels_horizontal[i]
        x_shifted_horizontal_lst, y_shifted_horizontal_lst = shift_horizontal(x, y, pixels)

        # Train on the augmented copies together with the original batch.
        x = torch.cat(tuple(x_rotated_lst + x_shifted_vertical_lst + x_shifted_horizontal_lst + [x]),
            axis = 0)
        # Targets are divided by 224, the side length of the resized crops.
        y = (torch.cat(tuple(y_rotated_lst + y_shifted_vertical_lst + y_shifted_horizontal_lst + [y]),
            axis = 0) / 224).to("cuda")
        # Insert the channel dimension: (N, H, W) -> (N, 1, H, W).
        x = x.unsqueeze(1)

        output = model(x.to("cuda"))
        loss = criterion(torch.flatten(y, 1), output).float()
        loss_epoch.append(loss.item())
        print(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            # Every 10th batch, plot ground truth (green) and prediction (red)
            # for the first four validation samples on the original images.
            model.eval()
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for j, val_batch in enumerate(validation_loader):
                    val_x = val_batch["image"]
                    val_y = val_batch["feature"]
                    val_output = model(val_x.unsqueeze(1).to("cuda")).cpu().numpy()
                    for k in range(4):
                        # The validation set is the tail of the dataset, so
                        # index the global lists from the end.
                        actual_index = -num_validation + batch_size * j + k
                        val_box = bboxes_adjusted[actual_index]
                        plt.imshow(skio.imread(img_filenames[actual_index]), cmap = "gray")
                        # Map crop coordinates back into the original image.
                        plt.plot(val_y[k].numpy()[:, 0] / 224 * val_box[2] + val_box[0],
                            val_y[k].numpy()[:, 1] / 224 * val_box[3] + val_box[1],
                            linestyle = "none", marker = ".", markersize = 12, color = 'g')
                        plt.plot(val_output[k][::2] * val_box[2] + val_box[0],
                            val_output[k][1::2] * val_box[3] + val_box[1],
                            linestyle = "none", marker = ".", markersize = 12, color = 'r')
                        plt.show()
                    # Only the first validation batch is visualized.
                    break
            model.train()
    training_losses.append(np.mean(loss_epoch))

    # Validation loss for this epoch, weighted by the size of each batch.
    model.eval()
    validation_loss_epoch = 0
    with torch.no_grad():
        for i, batch in enumerate(validation_loader):
            validation_x = batch["image"].unsqueeze(1).to("cuda")
            validation_y = batch["feature"].to("cuda") / 224
            validation_output = model(validation_x)
            validation_loss = criterion(torch.flatten(validation_y, 1),
                validation_output).float()
            validation_loss_epoch += validation_loss.item() * validation_x.shape[0]
    validation_loss_epoch = validation_loss_epoch / num_validation
    print(validation_loss_epoch)
    validation_losses.append(validation_loss_epoch)
    model.train()

In [None]:
# Optional: persist the trained model for later use.
# The whole module object is serialized here, not just its state_dict.
torch.save(model, f = "resnet18_no_pretrain.pt")

In [None]:
model.eval()
# Visualize prediction on validation set.
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for i, batch in enumerate(validation_loader):
    print("Start validating batch ", i)
    # Only the second batch (i == 1) is visualized; skip the forward pass
    # for the batch before it.
    if i != 1:
      continue
    x = batch["image"]
    y = batch["feature"]
    output = model(x.unsqueeze(1).to("cuda")).cpu().numpy()
    for j in range(4):
      # The validation set is the tail of the dataset, so index the global
      # lists from the end.
      actual_index = -num_validation + batch_size * i + j
      sample_box = bboxes_adjusted[actual_index]
      plt.imshow(skio.imread(img_filenames[actual_index]), cmap = "gray")
      # Ground truth in green, mapped from crop coordinates to the original image.
      plt.plot(y[j].numpy()[:, 0] / 224 * sample_box[2] + sample_box[0],
            y[j].numpy()[:, 1] / 224 * sample_box[3] + sample_box[1],
            linestyle = "none", marker = ".", markersize = 12, color = 'g')
      # Prediction in red. Both coordinates come from sample j; the y values
      # were previously read from row i, the batch index.
      plt.plot(output[j][::2] * sample_box[2] + sample_box[0],
            output[j][1::2] * sample_box[3] + sample_box[1],
            linestyle = "none", marker = ".", markersize = 12, color = 'r')
      plt.show()
    break

In [None]:
# Compare the per-epoch training and validation losses on a logarithmic y-axis.
for losses, label in ((training_losses, "Training loss"),
                      (validation_losses, "Validation loss")):
    plt.plot(np.arange(len(losses)), losses, label = label)
plt.yscale("log")
plt.legend()
plt.show()

In [None]:
# Test on additional self-uploaded images.
img_names = ["img_1.jpeg", "img_2.jpeg", "img_3.jpeg"]
# One hand-picked face box per image, as (left, top, width, height).
bounding_boxes = np.array([[33, 42, 130, 132], [40, 53, 113, 116], [360, 238, 358, 363]])
# Put the model in inference mode regardless of which cells ran before.
model.eval()
for i in range(len(img_names)):
  img = skio.imread(img_names[i])
  bbox = bounding_boxes[i]
  # crop() takes (top, left, height, width) on a channels-first tensor.
  x = torchvision.transforms.functional.crop(torch.from_numpy(img).permute((2, 0, 1)),
                                              bbox[1], bbox[0], bbox[3], bbox[2]).permute((1, 2, 0))
  # Hand scikit-image a NumPy array, as the dataset classes do.
  x = skimage.color.rgb2gray(x.numpy())
  x = torch.from_numpy(resize(x, (224, 224), anti_aliasing=True).astype(np.float32))
  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    output = model(x.unsqueeze(0).unsqueeze(0).to("cuda"))
  prediction = output[0].cpu().numpy()
  plt.imshow(img)
  # Even entries are x, odd entries are y; scale by the box size and offset
  # by its corner to draw on the original image.
  plt.plot(prediction[::2] * bbox[2] + bbox[0],
          prediction[1::2] * bbox[3] + bbox[1],
          linestyle = "none", marker = ".", markersize = 12, color = 'r')
  plt.show()

In [None]:
# Parse the test annotations: one face bounding box per image.
test_tree = ET.parse('ibug_300W_large_face_landmark_dataset/labels_ibug_300W_test_parsed.xml')
test_root = test_tree.getroot()
root_dir = 'ibug_300W_large_face_landmark_dataset'

test_bboxes = [] # face bounding box used to crop the image
test_img_filenames = [] # the image names for the whole dataset

# Each child of test_root[2] is one image entry whose first child carries the
# box attributes.
for filename in test_root[2]:
  test_img_filenames.append(os.path.join(root_dir, filename.attrib['file']))
  box = filename[0].attrib
  # x, y for the top left corner of the box, w, h for box width and height
  test_bboxes.append([box['left'], box['top'], box['width'], box['height']])

test_bboxes = np.array(test_bboxes).astype('float32')

# Enlarge the boxes the same way as for training: move the corner up/left by
# 10% of the box size and grow the size to 120%.
# np.int was removed in NumPy 1.24; the builtin int selects the same dtype.
test_bboxes_adjusted = test_bboxes.copy().astype(int)
# The float results are converted back to integers when stored in the int array.
# Move the corner first, while columns 2 and 3 still hold the original size.
test_bboxes_adjusted[:, 0] = test_bboxes_adjusted[:, 0] - test_bboxes_adjusted[:, 2] * 0.1
test_bboxes_adjusted[:, 1] = test_bboxes_adjusted[:, 1] - test_bboxes_adjusted[:, 3] * 0.1
test_bboxes_adjusted[:, 2] = test_bboxes_adjusted[:, 2] * 1.2
test_bboxes_adjusted[:, 3] = test_bboxes_adjusted[:, 3] * 1.2

In [None]:
# Quick check of the parsed test boxes: prints the shape of test_bboxes.
print(test_bboxes.shape)

In [None]:
# Number of test samples (one adjusted bounding box per test image).
num_test = len(test_bboxes_adjusted)
class TestDataset(Dataset):
  """
  Unlabelled face crops of the test set.

  Serves every entry of the module-level `test_img_filenames` and
  `test_bboxes_adjusted`.
  """

  def __init__(self):
    """
    Reference the module-level test annotations.

    The constructor accepts no arguments; images are only read from disk in
    `__getitem__`.
    """
    self.images_name = test_img_filenames
    self.bboxes_adjusted = test_bboxes_adjusted

  def __len__(self):
    """Number of samples, taken from the first dimension of the box array."""
    return self.bboxes_adjusted.shape[0]

  def __getitem__(self, idx):
    """
    Load one sample.

    Returns:
        dict with 'image', the face crop converted to grayscale and resized to
        224x224 (float32).
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()

    pixels = skio.imread(self.images_name[idx])
    if len(pixels.shape) == 2:
      # Single-channel file: replicate it three times so the RGB path applies.
      pixels = np.dstack((pixels,) * 3)

    left, top, box_width, box_height = self.bboxes_adjusted[idx]
    # crop() works on channels-first tensors and takes (top, left, height, width).
    channels_first = torch.from_numpy(pixels).permute((2, 0, 1))
    face = torchvision.transforms.functional.crop(channels_first,
                                                  top, left, box_height, box_width)
    gray = skimage.color.rgb2gray(face.permute((1, 2, 0)).numpy())
    gray = resize(gray, (224, 224), anti_aliasing=True).astype(np.float32)

    return {'image': gray}

In [None]:
def test_dataloader(batch_size = num_test):
    """
    Build the test dataset together with a dataloader over it.

    The loader is created without shuffling. The default batch size is the
    value `num_test` had when this function was defined.

    Args:
        batch_size: number of samples per batch.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = TestDataset()
    loader = DataLoader(dataset, batch_size = batch_size)
    return dataset, loader

In [None]:
# Visualize predictions on test dataset.
# Predictions are mapped back to original image coordinates and collected in
# `result`, one row per image with x and y interleaved.
model.eval()

test_batch_size = 64

test_data, test_loader = test_dataloader(test_batch_size)
result = torch.Tensor()

for i, batch in enumerate(test_loader):
  print("Start testing batch ", i)
  x = batch["image"]
  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    output = model(x.unsqueeze(1).to("cuda"))
  predictions = output.cpu().numpy()
  # Boxes of the samples in this batch. The loader does not shuffle, so batch i
  # starts at row i * test_batch_size; slicing clips the last, shorter batch.
  batch_boxes = test_bboxes_adjusted[i * test_batch_size:(i + 1) * test_batch_size]
  # Even columns are x, odd columns are y; scale by the box size and offset by
  # its corner. Indexing with [k] keeps a column vector for broadcasting.
  actual_x_index = predictions[:, ::2] * batch_boxes[:, [2]] + batch_boxes[:, [0]]
  actual_y_index = predictions[:, 1::2] * batch_boxes[:, [3]] + batch_boxes[:, [1]]
  if i == 1:
    # Show four sample predictions from the second batch.
    for j in range(40,44):
      actual_index = test_batch_size * i + j
      plt.imshow(skio.imread(test_img_filenames[actual_index]), cmap = "gray")
      plt.plot(actual_x_index[j], actual_y_index[j],
            linestyle = "none", marker = ".", markersize = 12, color = 'r')
      plt.show()
  # (2, B, P) -> (B, P, 2) -> (B, 2 * P): each row reads x0, y0, x1, y1, ...
  result = torch.cat((result,
                     torch.from_numpy(np.array([actual_x_index, actual_y_index])).permute(1, 2, 0).flatten(1)))
  print(result.shape)

In [None]:
# Collapse the predictions into a single 1-D tensor for the .csv export.
result = torch.flatten(result)

In [None]:
# Write the predictions to a .csv file for the Kaggle submission: one row per
# predicted value, with a running Id.
predicted_values = result.cpu().numpy()
row_ids = np.arange(len(result))
df = pd.DataFrame({"Id": row_ids, "Predicted": predicted_values})
df.to_csv("results_catherine_gai.csv", index = False)