In [3]:
# ===============================================================
# DFTpy:  Mount Drive / clone repo / build CSVs only
# ===============================================================

# ① Mount Drive -------------------------------------------------
import os
import shutil

from google.colab import drive

_mount_point = '/content/drive'
# Never delete the mount point while Drive is attached: a recursive delete
# would descend into the mounted Drive and remove the user's files (this is
# what an unconditional `rm -rf /content/drive` does when the cell is re-run).
# NOTE(review): the "MyDrive" check is a second safety net in case ismount()
# does not report the FUSE mount — confirm against the Colab runtime in use.
_drive_attached = (
    os.path.ismount(_mount_point)
    or os.path.isdir(os.path.join(_mount_point, 'MyDrive'))
)
if os.path.isdir(_mount_point) and not _drive_attached:
    # Stale local directory left behind by an earlier session: safe to drop so
    # that mounting does not fail on a non-empty mount point.
    shutil.rmtree(_mount_point, ignore_errors=True)

# force_remount=True makes a re-run of this cell re-attach Drive cleanly.
drive.mount(_mount_point, force_remount=True)

# ② Install deps ------------------------------------------------
# Each line is an IPython shell escape (`!`); `-q` keeps pip's output short.
# Upgrade pip itself before installing the packages below.
!pip -q install --upgrade pip
# PyTorch packages; no versions are pinned, so pip picks them at run time.
!pip -q install torch torchvision torchaudio
# Remaining Python packages requested by this cell (also unpinned).
!pip -q install numpy pandas pymatgen scikit-learn h5py tqdm joblib

# ③ Clone the DFT repository (skipped if already present) -------
import os, random
from pathlib import Path
import pandas as pd

# 切换到工作目录并克隆仓库（如已存在则跳过）
os.chdir('/content')
if not Path('DFT').exists():
    !git clone -q https://github.com/ChenHongBo0420/DFT.git

# ③ Collect every sample directory -------------------------------
base = Path('/content/DFT/database')
all_dirs, energy_dirs, dos_dirs = set(), set(), set()

# Files that must all be present for a sample to count for each task.
_energy_files = ('energy', 'forces', 'stress')
_dos_files = ('dos', 'VB_CB')

for poscar_path in base.rglob('POSCAR'):
    holder = poscar_path.parent
    # A match that sits inside a folder itself named "POSCAR" (any letter case)
    # belongs to the sample one level further up.
    sample_dir = holder.parent if holder.name.upper() == 'POSCAR' else holder
    sample_key = sample_dir.as_posix()
    all_dirs.add(sample_key)
    # Record the sample for a task only when its data files are complete.
    if all((sample_dir / fname).is_file() for fname in _energy_files):
        energy_dirs.add(sample_key)
    if all((sample_dir / fname).is_file() for fname in _dos_files):
        dos_dirs.add(sample_key)

# ④ Split into Train/Val/Test ------------------------------------

def split(lst, ratio=0.7, val_ratio=0.15, seed=None):
    """Shuffle ``lst`` and cut it into a training part and a validation part.

    Whatever is left after both parts is not returned; callers treat it as
    the test set.

    Args:
        lst: Iterable of samples. It is copied, never modified in place.
        ratio: Fraction of the samples assigned to training.
        val_ratio: Fraction of the samples assigned to validation.
        seed: When ``None`` the module-level ``random`` state is used, so the
            result differs from call to call. When given, the items are
            sorted first (so the outcome does not depend on the iteration
            order of the input, e.g. a ``set``) and shuffled with a private
            ``random.Random(seed)``, making the split reproducible. The items
            must then be mutually comparable.

    Returns:
        ``(train, val)`` — two lists with no element position in common.
    """
    if seed is None:
        items = list(lst)
        random.shuffle(items)
    else:
        items = sorted(lst)
        random.Random(seed).shuffle(items)

    n = len(items)
    # Both parts get at least one slot, even for very small inputs.
    n_tr = max(int(n * ratio), 1)
    n_val = max(int(n * val_ratio), 1)
    if n_tr + n_val >= n:
        # Too few samples for the requested fractions: shrink training so that
        # (when n >= 3) one sample is still left over for the test set.
        n_tr = max(n - 2, 1)
        n_val = 1
    return items[:n_tr], items[n_tr:n_tr + n_val]

# Shuffle and split the complete sample list once. This master split decides
# which directories are train / validation / test for every task.
tr_all, val_all = split(all_dirs)

# Derive the per-task lists by filtering the master split rather than
# re-shuffling energy_dirs / dos_dirs on their own. Independent shuffles could
# put a directory into Train_energy/Train_dos while the same directory is also
# part of the test set below (train/test leakage), and the per-task
# Train/Val/Test counts would not add up.
tr_en   = [d for d in tr_all  if d in energy_dirs]
val_en  = [d for d in val_all if d in energy_dirs]
tr_dos  = [d for d in tr_all  if d in dos_dirs]
val_dos = [d for d in val_all if d in dos_dirs]

# The test set is every sample not used for training or validation.
test_all = sorted(all_dirs - set(tr_all) - set(val_all))

# ⑤ Write the CSVs and print the sample counts --------------------
csv_dir = Path('/content/drive/MyDrive/DFT_CSVs')
csv_dir.mkdir(exist_ok=True)

# (file name, column header, rows) — written in this order, without an index.
_csv_jobs = [
    ('Train_all.csv',    'files',         tr_all),
    ('Val_all.csv',      'files',         val_all),
    ('Train_energy.csv', 'files',         tr_en),
    ('Val_energy.csv',   'files',         val_en),
    ('Train_dos.csv',    'files',         tr_dos),
    ('Val_dos.csv',      'files',         val_dos),
    ('predict.csv',      'file_loc_test', test_all),
]
for _csv_name, _column, _rows in _csv_jobs:
    pd.DataFrame({_column: _rows}).to_csv(csv_dir / _csv_name, index=False)

# Per-task test counts are the overlap of the test set with each task's set.
print(f"Samples for CHG:     Train={len(tr_all)}  Val={len(val_all)}  Test={len(test_all)}")
print(f"Samples for Energy:  Train={len(tr_en)}  Val={len(val_en)}  Test={len(set(test_all)&energy_dirs)}")
print(f"Samples for DOS:     Train={len(tr_dos)}  Val={len(val_dos)}  Test={len(set(test_all)&dos_dirs)}")

Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m173.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m181.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m37.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m73.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m76.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━