## FaceSegLite: Model Building

### **Imports**

In [58]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import cv2
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Use the first CUDA GPU when torch reports one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device = {device}")

device = cpu


### **Dataset Loading**


First, we define utility functions to load the HuggingFace dataset files correctly.

In [8]:
def load_yolo_format(file_path):
  """
  Loads a segmentation label file in YOLO format.

  Each non-empty line is expected to hold a class index followed by
  whitespace-separated coordinates: ``class x1 y1 ... xn yn``.
  Blank lines (e.g. a trailing empty line) are skipped.

  Args:
    - file_path: path to the file to load.

  Returns:
    - objects: list of objects in the image. Each object is represented as a
      tuple of (class_index, [x1, y1, ..., xn, yn]) where class_index is an
      int and the coordinates are a flat list of floats.

  Raises:
    - ValueError: if a class index or a coordinate cannot be parsed as a number.
  """
  objects = []
  with open(file_path, 'r') as f:
    # Stream the file line by line instead of materializing it with readlines().
    for line in f:
      data = line.split()
      if not data:
        # A blank line carries no object; indexing data[0] would raise IndexError.
        continue
      class_index = int(data[0])
      coordinates = [float(value) for value in data[1:]]
      objects.append((class_index, coordinates))

  return objects

def found_directory(file_path, folders):
    """
    Finds the directory in which the file is located.

    The match is made on ids: the part of the file name before its first
    '_' must equal the part of the folder name before its first '--'
    (e.g. '19_Couple_Couple_19_12.txt' matches '19--Couple').

    Args:
        - file_path: name of the file to locate.
        - folders: list of folder names to look into.

    Returns:
        - folder: name of the first folder whose id matches the file's id.

    Raises:
        - ValueError: if no folder in `folders` matches the file's id.
    """
    # The file id does not depend on the folder, so compute it once.
    file_id = file_path.split('_')[0]
    for folder in folders:
        if folder.split('--')[0] == file_id:
            return folder
    raise ValueError(
        f"No folder matching id '{file_id}' (from file '{file_path}') found in {folders}"
    )

Define a custom torch Dataset.

In [54]:
class FslDataset(Dataset):
    """
    Dataset pairing each YOLO-format mask file with its image.

    Layout read under root_dir:
        masks/<name>.txt
        images/<id>--<...>/<name>.jpg
    where <id> is the part of <name> before its first underscore; the image
    sub-folder is resolved with found_directory.

    Args:
        - root_dir: directory containing the 'masks' and 'images' folders.
        - transform: optional callable applied to each sample dict.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.mask_folder = os.path.join(root_dir, 'masks')
        # Sorted so that indices are stable across runs.
        self.mask_files = sorted(os.listdir(self.mask_folder))

        self.image_folder = os.path.join(root_dir, 'images')
        # List the image sub-folders once rather than once per mask file.
        image_subfolders = os.listdir(self.image_folder)
        self.image_files = [
            os.path.join(
                self.image_folder,
                found_directory(mask_file, image_subfolders),
                # Swap only the extension: str.replace('txt', 'jpg') would
                # also rewrite a 'txt' occurring inside the file name.
                os.path.splitext(mask_file)[0] + '.jpg',
            )
            for mask_file in self.mask_files
        ]

    def __len__(self):
        """Returns the number of samples (one per mask file)."""
        return len(self.mask_files)

    def __getitem__(self, idx):
        """
        Loads the sample at position idx.

        Returns:
            - sample: dict with keys 'image' (result of cv2.imread) and
              'masks' (result of load_yolo_format), passed through
              `transform` when one was given.
        """
        # NOTE: cv2.imread returns None instead of raising when the file
        # cannot be read, so 'image' may be None for a missing/corrupt file.
        image = cv2.imread(self.image_files[idx])

        masks = load_yolo_format(os.path.join(self.mask_folder, self.mask_files[idx]))

        sample = {'image': image, 'masks': masks}

        if self.transform:
            sample = self.transform(sample)

        return sample


In [55]:
# Build the dataset; FslDataset reads the 'masks' and 'images' sub-folders of root_dir.
fsl_dataset = FslDataset(root_dir='../data/')

### **Dataset visualization**

In [22]:
def get_unique_color(index):
    """
    Generates a deterministic RGB color for a given index.

    Args:
        - index: integer used as the random seed (same index, same color).

    Returns:
        - color: tuple of (R, G, B) values, each a float in [0, 1).
    """
    # A local generator yields the same values as np.random.seed(index)
    # followed by np.random.rand(3), but leaves the global NumPy RNG state
    # untouched, so calling this does not affect other random code.
    rng = np.random.RandomState(index)
    return tuple(rng.rand(3))

def plot_image_with_mask(image, objects):
    """
    Plots an image with its segmentation masks overlaid.

    Args:
        - image: either the image as a BGR array (as returned by cv2.imread)
          or a path to an image file, which is then read with cv2.imread.
        - objects: list of objects in the image. Each object is represented as a
          tuple of (class_index, (x1, y1, ..., xn, yn)) with coordinates
          normalized by the image width (x) and height (y).

    Raises:
        - FileNotFoundError: if a path is given and the image cannot be read.
    """
    # Accept a path as well as an already-decoded array.
    if isinstance(image, (str, os.PathLike)):
        image_path = os.fspath(image)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")

    # OpenCV decodes to BGR while matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width = image.shape[0], image.shape[1]

    # Plot the image
    plt.imshow(image)

    # Plot the segmentation mask for each object with a unique color
    for index, (class_index, coordinates) in enumerate(objects):
        num_points = len(coordinates) // 2
        mask_points = np.array(coordinates).reshape((num_points, 2))
        mask_points *= np.array([width, height])  # Convert normalized coordinates to pixels
        mask_points = mask_points.astype(int)

        # Color depends on the object's position in the list, not on its class.
        color = get_unique_color(index)

        # Draw the polygon outline
        polygon = Polygon(mask_points, closed=True, edgecolor=color, facecolor='none', linewidth=2)
        plt.gca().add_patch(polygon)

        # Fill the polygon with the same color, semi-transparent
        plt.fill(mask_points[:, 0], mask_points[:, 1], color=color, alpha=0.3)

    plt.axis('off')
    plt.show()

In [None]:
dataset_path = "fsl_dataset/FSL_masks/19_Couple_Couple_19_12.txt"
image_1 = load_yolo_format(dataset_path)

image_path = "fsl_dataset/FSL_images/images/19--Couple/19_Couple_Couple_19_12.jpg"

# plot_image_with_mask calls cv2.cvtColor on its first argument, so it needs
# the decoded pixel array rather than the path string.
plot_image_with_mask(cv2.imread(image_path), image_1)