# Prerequisites

In [1]:
import sys

sys.path.append("..")

In [2]:
%load_ext autoreload
%autoreload 2

import json
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

from dataset3d import BNSet, BNSetMasks, get_dloader_mask
from util3d import show_volume

In [3]:
data_dir = "../data/bugNIST_DATA"

# Train/Val split

In [4]:
dset = BNSetMasks(data_dir, "train")

train_size = 0.6
random_state = 1337

train_indices, valtest_indices = train_test_split(
    np.arange(len(dset)),
    train_size=train_size,
    stratify=dset.labels,
    random_state=random_state,
)
train_indices.size, valtest_indices.size, train_indices.size / (
    train_indices.size + valtest_indices.size
)

(5492, 3662, 0.5999563032554075)

In [5]:
valtest_split = valtest_indices.size // 2
val_indices, test_indices = (
    valtest_indices[:valtest_split],
    valtest_indices[valtest_split:],
)
val_indices.size, test_indices.size

(1831, 1831)

In [6]:
# np.unique(dset.labels[val_indices], return_counts=True)[1], np.unique(dset.labels[test_indices], return_counts=True)[1], np.unique(dset.labels[valtest_indices], return_counts=True)[1]

In [7]:
split = {
    "train": train_indices.tolist(),
    "val": val_indices.tolist(),
    "test": test_indices.tolist(),
}

In [11]:
with open(f'{data_dir}/single_bugs_split.json', 'w') as fp:
    json.dump(split, fp)

In [14]:
with open(f'{data_dir}/single_bugs_split.json', 'r') as fp:
    out = json.load(fp)

In [15]:
out

{'train': [2819,
  4455,
  3018,
  8475,
  6909,
  5716,
  5981,
  4693,
  6931,
  6421,
  4067,
  5446,
  7717,
  5760,
  994,
  11,
  5245,
  5631,
  5215,
  3713,
  5954,
  7892,
  3898,
  5867,
  6833,
  7604,
  8381,
  8627,
  8723,
  4104,
  5888,
  7402,
  8947,
  3322,
  7041,
  1840,
  4367,
  4069,
  263,
  6896,
  3553,
  144,
  6066,
  5121,
  450,
  4095,
  7915,
  2327,
  3299,
  58,
  1347,
  8391,
  8964,
  4180,
  1592,
  6698,
  5026,
  3216,
  7458,
  5566,
  1273,
  3820,
  7640,
  2452,
  8427,
  3057,
  2656,
  3486,
  7879,
  5258,
  2277,
  4168,
  5311,
  2100,
  8829,
  4821,
  6455,
  3099,
  3501,
  7238,
  3273,
  2417,
  3935,
  2198,
  4211,
  7906,
  3422,
  2054,
  486,
  2133,
  8056,
  561,
  8459,
  126,
  6212,
  7575,
  1653,
  540,
  7975,
  3289,
  7898,
  2677,
  204,
  4306,
  3402,
  7095,
  873,
  6305,
  7420,
  4130,
  7372,
  5693,
  875,
  1746,
  4922,
  8923,
  1920,
  5846,
  6607,
  3561,
  1725,
  5812,
  4471,
  8743,
  2577,
  1809