In [15]:
from covidxpert.menpo import extract_menpo_points

from glob import glob
import os
from multiprocessing import Pool, cpu_count
import shutil
import numpy as np

import menpo.io as mio
from menpo.image import Image
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [29]:
# Filename templates, each filled with a sample's base name via str.format.
base_path = "all_masks/images/{}.png"                    # source image
points_path = "all_masks/processed/all/{}.pts"           # extracted landmark points
processed_image_path = "all_masks/processed/all/{}.jpg"  # processed image

# Every mask file drives one extraction job (see the Pool cell below).
path_masks = glob('all_masks/masks/*')

# Output folders: "all" receives every processed sample, "train"/"test" the split.
for _output_dir in ("all_masks/processed/all",
                    "all_masks/processed/train",
                    "all_masks/processed/test"):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir  # keep the notebook namespace free of the loop variable

def job_extract_menpo_points(mask_path: str):
    """Run ``extract_menpo_points`` for a single mask file.

    Intended as the worker passed to ``Pool.map`` over ``path_masks``.

    The sample name is derived from the mask filename by dropping the
    extension and everything from ``'_mask'`` onward
    (``<name>_mask.<ext>`` -> ``<name>``). That name fills the module-level
    templates:

    - ``base_path`` gives the source image.
    - ``processed_image_path`` gives the processed-image output.
    - ``points_path`` gives the points output.

    The job is skipped when both outputs already exist, so the extraction
    cell can be re-run to resume an interrupted pass.

    Parameters
    ----------
    mask_path: str
        Path to one mask file (a file from ``all_masks/masks/``).
    """
    # splitext strips only the final extension, so names that contain dots
    # (e.g. "1.2.840_mask.png") keep their full stem instead of being cut
    # at the first '.' and pointing at the wrong image.
    stem = os.path.splitext(os.path.basename(mask_path))[0]
    base_name = stem.split('_mask')[0]
    image_path = base_path.format(base_name)
    point_path = points_path.format(base_name)
    proc_path = processed_image_path.format(base_name)
    # Require BOTH outputs before skipping: if an earlier run died after
    # writing only the image, the sample must be regenerated rather than
    # silently left without its points file.
    if os.path.exists(proc_path) and os.path.exists(point_path):
        return
    extract_menpo_points(mask_path=mask_path, image_path=image_path,
                         save_image_path=proc_path, save_points_path=point_path)
        
def split_dataset(test_perc: float = 0.2, random_state: int = 42,
                  path_all: str = "all_masks/processed"):
    """Split the processed samples into ``train`` and ``test`` folders.

    Every ``<name>.pts`` file in ``<path_all>/all`` defines one sample. The
    sample's ``<name>.jpg`` and ``<name>.pts`` are copied (with metadata,
    via ``shutil.copy2``) into ``<path_all>/train`` or ``<path_all>/test``.

    Parameters
    ----------
    test_perc: float
        Fraction of samples assigned to the test partition (passed to
        ``train_test_split`` as ``test_size``).
    random_state: int
        Seed for ``train_test_split``. Together with the sorted file list,
        this makes the split reproducible across machines.
    path_all: str
        Root directory holding the ``all``/``train``/``test`` sub-folders.

    Raises
    ------
    ValueError
        If ``<path_all>/all`` contains no ``.pts`` files.

    Notes
    -----
    Existing files in ``train``/``test`` are not removed. Re-running with a
    different ``test_perc`` or ``random_state`` can therefore leave stale
    samples in both partitions; clear those folders first when changing the
    split.
    """
    # glob() order is filesystem-dependent; sort so that random_state alone
    # determines which samples end up in each partition.
    sample_points = sorted(glob(f"{path_all}/all/*.pts"))
    if not sample_points:
        raise ValueError(
            f"No .pts files found in {path_all}/all; run the extraction step first."
        )
    train, test = train_test_split(sample_points, test_size=test_perc,
                                   random_state=random_state)
    for partition, partition_points in (("train", train), ("test", test)):
        os.makedirs(f"{path_all}/{partition}", exist_ok=True)
        for src_point_path in partition_points:
            # splitext removes only the trailing ".pts", keeping dotted names intact.
            base_name = os.path.splitext(os.path.basename(src_point_path))[0]
            src_image_path = f"{path_all}/all/{base_name}.jpg"

            dst_image_path = f"{path_all}/{partition}/{base_name}.jpg"
            dst_point_path = f"{path_all}/{partition}/{base_name}.pts"

            shutil.copy2(src_image_path, dst_image_path)
            shutil.copy2(src_point_path, dst_point_path)
        

In [25]:
# NOTE(review): Pool workers must be able to locate job_extract_menpo_points.
# That works with the 'fork' start method (Linux default); under 'spawn'
# (Windows/macOS default) a function defined in the notebook is typically not
# importable by workers. Confirm the platform, or move the worker into a module.
with Pool(processes=cpu_count()) as pool:
    # One extraction job per mask file; the workers' return values are unused.
    pool.map(func=job_extract_menpo_points, iterable=path_masks)
    # Stop accepting tasks and wait for the workers to exit before the
    # context manager's exit (which terminates the pool) runs.
    pool.close()
    pool.join()

In [30]:
split_dataset(test_perc=0.2, random_state=42)  # 80/20 split, fixed seed

In [35]:
# File counts per folder; split_dataset copies one .jpg and one .pts per sample.
split_counts = {
    partition: len(glob(f'all_masks/processed/{partition}/*'))
    for partition in ("all", "train", "test")
}
# Every processed file must land in exactly one partition. A mismatch means
# missing copies or stale files left over from an earlier split.
assert split_counts["train"] + split_counts["test"] == split_counts["all"], split_counts
split_counts["all"], split_counts["train"], split_counts["test"]

(1404, 1122, 282)