In [1]:
import os

In [2]:
def iter_files_with_ext(root_dir, ext):
    """
    Iterate recursively on all files in root dir, ending with <ext>.
    :param root_dir: Directory to list.
    :param ext: Extension of the files.
    :return: absolute path of all files.
    """
    for path, sub_dirs, files in os.walk(root_dir):
        for name in files:
            if not name.startswith(".") and name.endswith(ext):
                abs_path = os.path.abspath(os.path.join(path, name))
                if os.path.isfile(abs_path):
                    yield os.path.join(path, name)

In [3]:
train_files = list(iter_files_with_ext("../data/train/", ".jpg"))

In [5]:
def get_uid_rot(filename):
    basename = os.path.basename(filename)
    basename, _ = os.path.splitext(basename)
    uid, rot = basename.split("_")
    return uid, int(rot)

In [6]:
get_uid_rot(train_files[0])

('3c6acfceb552', 9)

In [12]:
uids = list(set([get_uid_rot(fn)[0] for fn in train_files]))

In [13]:
import numpy as np

In [14]:
permuted_idx = np.random.permutation(len(uids))

In [15]:
train_p = 0.8
n_train = int(train_p * len(uids))

In [16]:
train_idx = permuted_idx[:n_train]
val_idx = permuted_idx[n_train:]

In [17]:
train_uids = set([uids[idx] for idx in train_idx])
val_uids = set([uids[idx] for idx in val_idx])

In [18]:
val_uids.intersection(train_uids)

set()

In [19]:
train_filenames = [fn for fn in train_files if get_uid_rot(fn)[0] in train_uids]
val_filenames = [fn for fn in train_files if get_uid_rot(fn)[0] in val_uids]

In [20]:
len(train_filenames)

4064

In [21]:
len(val_filenames)

1024

In [22]:
import json

In [30]:
def get_basename(fn):
    basename = os.path.basename(fn)
    basename, _ = os.path.splitext(basename)
    return basename

In [31]:
get_basename(train_filenames[0])

'3c6acfceb552_09'

In [32]:
with open("../data/train.json", "w") as jfile:
    json.dump([get_basename(fn) for fn in train_filenames], jfile, indent=2)

In [33]:
with open("../data/val.json", "w") as jfile:
    json.dump([get_basename(fn) for fn in val_filenames], jfile, indent=2)

In [34]:
!head -10 ../data/train.json

[
  "3c6acfceb552_09",
  "eb91b1c659a0_06",
  "5df60cf7cab2_01",
  "cafee4122080_15",
  "00087a6bd4dc_13",
  "898339fab87a_04",
  "dd70a0a51e3b_02",
  "a56f923399ca_02",
  "d1a3af34e674_12",


In [35]:
!head -10 ../data/val.json

[
  "9dfaeb835626_02",
  "a7b9e343cf6b_13",
  "70b6a79565fe_02",
  "cf65b1c5e147_06",
  "6d375bc2ece1_03",
  "ed8472086df8_14",
  "0cdf5b5d0ce1_01",
  "1e89e1af42e7_07",
  "3c54e71fd2c9_04",
