# Train/Validation/Test Split Construction

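When moving towards publication, a K-fold cross-validation at the lab level would be more rigorous. For now, I hold out two labs (Christine and Erwin) as the validation split and the three backup plates as the test split; the remaining plates form the training split. For my proof of concept (POC), I will also reuse these held-out cells for later analysis in the Application notebooks.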

In [28]:
import lance
import os

### Define paths

In [29]:
# Root folder under which every split dataset is written.
path_data_root = '/home/sam/SCI/cellenONE_project/datasets'

# Destination of each split (train, val, test), in that order.
path_train, path_val, path_test = (
    os.path.join(path_data_root, split_file)
    for split_file in (
        'train_SCI.lance',
        'val_ChristineErwin_SCI.lance',
        'test_Backup_SCI.lance',
    )
)

### Load data and set splitting rules

In [30]:
# Source Lance dataset; the train/val/test splits below are all filtered from it.
path_source = os.path.join(path_data_root, 'cropped_cells_SCI_1')
lds = lance.dataset(path_source)

# Splits are defined at plate level: rows are later selected by plate_name,
# so each plate must belong to exactly one of train / val / test.
plates_train = [
    'Plate_02_BOGDAN_',
    'Plate_03_BOGDAN_',
    'Plate_04_BOGDAN',
    'Plate_05_KARL',
    'Plate_06_KARL',
    'Plate_07_BOGDAN',
    'Plate_08_BOGDAN',
    'Plate_09_BOGDAN',
    'Plate_10_AKOS',
]
plates_val = [
    'Plate_13_CHRISTINE',
    'Plate_14_ERWIN'
]
plates_test = [
    'Plate_01_Backup',
    'Plate_11_Backup',
    'Plate_12_Backup'
]


def check_disjoint_splits(**splits):
    """Raise ValueError if any plate name appears more than once across *splits*.

    Args:
        **splits: split name -> iterable of plate names.

    Raises:
        ValueError: naming the first duplicated plate and the two splits
            (possibly the same split twice) it was found in.
    """
    owner = {}
    for split_name, plates in splits.items():
        for plate in plates:
            if plate in owner:
                raise ValueError(
                    f"Plate {plate!r} is listed in both the {owner[plate]!r} "
                    f"and {split_name!r} splits; splits must be disjoint."
                )
            owner[plate] = split_name


# Guard against data leakage: a plate edited into two lists would otherwise
# contribute the same cells to more than one split without any error.
check_disjoint_splits(train=plates_train, val=plates_val, test=plates_test)

def plate_filter(plates):
    """Return a Lance SQL filter selecting rows whose ``plate_name`` is in *plates*.

    Each name becomes a single-quoted SQL string literal with embedded single
    quotes doubled, so the filter stays well-formed for any name. The string is
    built with plain concatenation rather than nesting same-type quotes inside
    an f-string, which only parses on Python 3.12+ (PEP 701).

    Raises:
        ValueError: if *plates* is empty; an empty split is almost certainly a
            configuration mistake.
    """
    if not plates:
        raise ValueError("plate list is empty; cannot build a split filter")
    literals = ", ".join("'" + plate.replace("'", "''") + "'" for plate in plates)
    return f"plate_name IN ({literals})"


# One filtered batch stream per split.
# NOTE(review): these readers are presumably single-pass and get consumed by
# lance.write_dataset; re-run this cell before re-running the write cell.
lds_batchreader_train = lds.to_batches(filter=plate_filter(plates_train))
lds_batchreader_val = lds.to_batches(filter=plate_filter(plates_val))
lds_batchreader_test = lds.to_batches(filter=plate_filter(plates_test))

### Write each split to Lance format

In [31]:
# Materialise each filtered stream as its own Lance dataset, reusing the
# source dataset's schema for all three splits.
split_targets = (
    (lds_batchreader_train, path_train),
    (lds_batchreader_val, path_val),
    (lds_batchreader_test, path_test),
)
for batches, target_uri in split_targets:
    written = lance.write_dataset(batches, target_uri, schema=lds.schema)
# Last expression of the cell: lets the notebook display the final written dataset.
written

<lance.dataset.LanceDataset at 0x7f052bf39c50>

### Shape sanity check

In [32]:
# Re-open the freshly written splits from disk (train, val, test order).
lds_train, lds_val, lds_test = (
    lance.dataset(split_uri)
    for split_uri in (path_train, path_val, path_test)
)

In [33]:
import warnings

# Load only the plate_name column: enough to count rows per split and to see
# which plates actually ended up in each split.
train_df = lds_train.to_table(
    columns=['plate_name']
).to_pandas()

test_df = lds_test.to_table(
    columns=['plate_name']
).to_pandas()

val_df = lds_val.to_table(
    columns=['plate_name']
).to_pandas()

print(train_df.shape)
print(test_df.shape)
print(val_df.shape)

# Row counts alone cannot reveal a configured plate name that matched no rows
# (e.g. a typo or stray trailing underscore), which would silently shrink a
# split. Compare each split's configured plates with the plates present.
for split_name, split_df, expected_plates in (
    ('train', train_df, plates_train),
    ('val', val_df, plates_val),
    ('test', test_df, plates_test),
):
    present = set(split_df['plate_name'].unique())
    missing = sorted(set(expected_plates) - present)
    if missing:
        warnings.warn(
            f"{split_name} split contains no rows for configured plates: {missing}"
        )

(2413, 1)
(700, 1)
(529, 1)
