In [20]:
import os
import random
import shutil

In [21]:
# Dataset root (presumably Google Drive mounted in Colab — confirm before running elsewhere).
BASE_PATH = "/content/drive/MyDrive/Dataset"

# Source data: 2D images and 3D grids, each stored per class (binding / nonbinding).
_IMG_ROOT = f"{BASE_PATH}/images"
_GRID_ROOT = f"{BASE_PATH}/3d_grids"

IMG_BIND, IMG_NONBIND = f"{_IMG_ROOT}/binding", f"{_IMG_ROOT}/nonbinding"
GRID_BIND, GRID_NONBIND = f"{_GRID_ROOT}/binding", f"{_GRID_ROOT}/nonbinding"

In [22]:
# Binding grids are named "<pdb_id>_binding.npy" (the pattern copy_files() uses to
# rebuild the path). Strip that exact suffix instead of splitting on "_": the old
# split("_")[0] truncated IDs containing "_" and also counted unrelated .npy files,
# both of which produced IDs whose "<id>_binding.npy" does not exist.
BIND_GRID_SUFFIX = "_binding.npy"

pdb_ids = sorted(
    fname[: -len(BIND_GRID_SUFFIX)]
    for fname in os.listdir(GRID_BIND)
    if fname.endswith(BIND_GRID_SUFFIX) and len(fname) > len(BIND_GRID_SUFFIX)
)

print("Total PDB IDs:", len(pdb_ids))

Total PDB IDs: 59


In [23]:
# Shuffle a *copy* with a dedicated seeded RNG. The previous in-place
# random.shuffle(pdb_ids) reshuffled an already-shuffled list whenever this cell
# was re-run, silently producing a different split each time.
# random.Random(42) yields the same sequence as random.seed(42), so the split from
# a fresh kernel is unchanged; the global `random` state is left untouched.
SPLIT_SEED = 42
shuffled_ids = list(pdb_ids)
random.Random(SPLIT_SEED).shuffle(shuffled_ids)

n = len(shuffled_ids)
train_end, val_end = int(0.7 * n), int(0.85 * n)  # 70% / 15% / 15%

train_ids = shuffled_ids[:train_end]
val_ids   = shuffled_ids[train_end:val_end]
test_ids  = shuffled_ids[val_end:]

# Splits must be disjoint and together cover every ID (catches duplicate IDs too).
assert len(train_ids) + len(val_ids) + len(test_ids) == n
assert len(set(train_ids) | set(val_ids) | set(test_ids)) == n, "duplicate or overlapping PDB IDs"

print("Train:", len(train_ids))
print("Validation:", len(val_ids))
print("Test:", len(test_ids))


Train: 41
Validation: 9
Test: 9


In [24]:
# Build the destination tree: <split>/<2d|3d>/<binding|nonbinding>.
for split in ("train", "val", "test"):
    for modality in ("2d", "3d"):
        for label in ("binding", "nonbinding"):
            os.makedirs(f"{BASE_PATH}/{split}/{modality}/{label}", exist_ok=True)

print("Split folders created")


Split folders created


In [25]:
def copy_files(pdb_list, split):
    """Copy the 3D grids and 2D images of each PDB ID into one split's folders.

    For every ID, four files are copied, in this order:
    ``{GRID_BIND}/{id}_binding.npy``, ``{GRID_NONBIND}/{id}_nonbinding.npy``,
    ``{IMG_BIND}/{id}_binding.png`` and ``{IMG_NONBIND}/{id}_nonbinding.png``.
    They go to ``{BASE_PATH}/{split}/{3d|2d}/{binding|nonbinding}/``.

    Existing destination files with the same name are overwritten
    (``shutil.copy`` semantics). Other files already in those folders are left
    in place.

    Args:
        pdb_list: iterable of PDB ID strings.
        split: name of the split sub-folder, e.g. "train", "val" or "test".

    Raises:
        FileNotFoundError: if any source file is missing. All sources are checked
            before the first copy, so a missing file no longer leaves the split
            half-populated.
    """
    # (source directory, destination sub-path, filename suffix) per modality/class.
    sources = [
        (GRID_BIND, "3d/binding", "_binding.npy"),
        (GRID_NONBIND, "3d/nonbinding", "_nonbinding.npy"),
        (IMG_BIND, "2d/binding", "_binding.png"),
        (IMG_NONBIND, "2d/nonbinding", "_nonbinding.png"),
    ]

    jobs = [
        (f"{src_dir}/{pdb_id}{suffix}", f"{BASE_PATH}/{split}/{dst_sub}/")
        for pdb_id in pdb_list
        for src_dir, dst_sub, suffix in sources
    ]

    # Validate everything up front so a failure cannot leave a partial split.
    missing = [src for src, _ in jobs if not os.path.isfile(src)]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} source file(s) missing for split '{split}', e.g. {missing[:3]}"
        )

    for src, dst in jobs:
        shutil.copy(src, dst)


In [26]:
# Populate each split folder, in train -> val -> test order.
for split_name, split_ids in (("train", train_ids), ("val", val_ids), ("test", test_ids)):
    copy_files(split_ids, split_name)

print("✅ Dataset split completed successfully")


✅ Dataset split completed successfully


In [27]:
print("Train binding (3D):", len(os.listdir(f"{BASE_PATH}/train/3d/binding")))
print("Val binding (3D):", len(os.listdir(f"{BASE_PATH}/val/3d/binding")))
print("Test binding (3D):", len(os.listdir(f"{BASE_PATH}/test/3d/binding")))

# Every destination folder must hold exactly one file per ID in its split.
# More files than expected means leftovers from an earlier, different split
# (train/test leakage); fewer means the copy was incomplete.
expected_sizes = {"train": len(train_ids), "val": len(val_ids), "test": len(test_ids)}
for split_name, expected in expected_sizes.items():
    for sub in ("3d/binding", "3d/nonbinding", "2d/binding", "2d/nonbinding"):
        count = len(os.listdir(f"{BASE_PATH}/{split_name}/{sub}"))
        assert count == expected, f"{split_name}/{sub}: {count} files, expected {expected}"

Train binding (3D): 41
Val binding (3D): 9
Test binding (3D): 9
