In [1]:
!pip install scikit-learn pandas

import os
import pandas as pd
from sklearn.model_selection import train_test_split
import pickle



In [2]:
# Adjust paths if needed
name_map = pd.read_csv("name_mapping.csv")
surv = pd.read_csv("survival_info.csv")

print("name_mapping columns:", name_map.columns.tolist())
print("survival_info columns:", surv.columns.tolist())

name_mapping columns: ['Grade', 'BraTS_2017_subject_ID', 'BraTS_2018_subject_ID', 'TCGA_TCIA_subject_ID', 'BraTS_2019_subject_ID', 'BraTS_2020_subject_ID']
survival_info columns: ['Brats20ID', 'Age', 'Survival_days', 'Extent_of_Resection']


In [4]:
# Make sure we have a consistent ID column
# name_mapping usually has something like 'BraTS_2020_subject_ID'
name_map = name_map.rename(columns={"BraTS_2020_subject_ID": "Brats20ID"})

# survival_info.csv usually already has 'Brats20ID'
# Merge on Brats20ID to get Grade + Age + Survival, etc.
meta = pd.merge(name_map, surv, on="Brats20ID", how="left")

print("Merged shape:", meta.shape)
print(meta.head())
print("\nGrade value counts:")
print(meta["Grade"].value_counts())

Merged shape: (369, 9)
  Grade BraTS_2017_subject_ID BraTS_2018_subject_ID TCGA_TCIA_subject_ID  \
0   HGG   Brats17_CBICA_AAB_1   Brats18_CBICA_AAB_1                  NaN   
1   HGG   Brats17_CBICA_AAG_1   Brats18_CBICA_AAG_1                  NaN   
2   HGG   Brats17_CBICA_AAL_1   Brats18_CBICA_AAL_1                  NaN   
3   HGG   Brats17_CBICA_AAP_1   Brats18_CBICA_AAP_1                  NaN   
4   HGG   Brats17_CBICA_ABB_1   Brats18_CBICA_ABB_1                  NaN   

  BraTS_2019_subject_ID             Brats20ID     Age Survival_days  \
0   BraTS19_CBICA_AAB_1  BraTS20_Training_001  60.463           289   
1   BraTS19_CBICA_AAG_1  BraTS20_Training_002  52.263           616   
2   BraTS19_CBICA_AAL_1  BraTS20_Training_003  54.301           464   
3   BraTS19_CBICA_AAP_1  BraTS20_Training_004  39.068           788   
4   BraTS19_CBICA_ABB_1  BraTS20_Training_005  68.493           465   

  Extent_of_Resection  
0                 GTR  
1                 GTR  
2                 GTR

We’ll do this in two stages:

First: 85% train+val, 15% test

Second: within the 85%, split into 70% train, 15% val
So overall ≈ 70/15/15.

In [5]:
# First: train_val (85%) vs test (15%), stratified by Grade
train_val_df, test_df = train_test_split(
    meta,
    test_size=0.15,
    random_state=42,
    stratify=meta["Grade"]
)

# Second: within train_val, split into train (70%) and val (15%)
# Proportion for val relative to train_val = 0.15 / 0.85 ≈ 0.1765
val_ratio_within_trainval = 0.15 / 0.85

train_df, val_df = train_test_split(
    train_val_df,
    test_size=val_ratio_within_trainval,
    random_state=42,
    stratify=train_val_df["Grade"]
)

print("Total:", len(meta))
print("Train:", len(train_df))
print("Val:  ", len(val_df))
print("Test: ", len(test_df))

print("\nGrade counts (Train):")
print(train_df["Grade"].value_counts())
print("\nGrade counts (Val):")
print(val_df["Grade"].value_counts())
print("\nGrade counts (Test):")
print(test_df["Grade"].value_counts())


Total: 369
Train: 257
Val:   56
Test:  56

Grade counts (Train):
Grade
HGG    204
LGG     53
Name: count, dtype: int64

Grade counts (Val):
Grade
HGG    45
LGG    11
Name: count, dtype: int64

Grade counts (Test):
Grade
HGG    44
LGG    12
Name: count, dtype: int64


These plain text files are handy for our own scripts:

In [6]:
train_ids = train_df["Brats20ID"].tolist()
val_ids   = val_df["Brats20ID"].tolist()
test_ids  = test_df["Brats20ID"].tolist()

# Save as simple text files (one ID per line)
with open("train_ids.txt", "w") as f:
    f.write("\n".join(train_ids))

with open("val_ids.txt", "w") as f:
    f.write("\n".join(val_ids))

with open("test_ids.txt", "w") as f:
    f.write("\n".join(test_ids))

print("Saved train_ids.txt, val_ids.txt, test_ids.txt")


Saved train_ids.txt, val_ids.txt, test_ids.txt


In [7]:
# Build the split structure for nnU-Net
# We'll define a single fold (fold 0) with our train/val split
split_dict = {
    "train": train_ids,
    "val":   val_ids
}

splits = [split_dict]  # nnU-Net expects a list of folds

# Path to nnU-Net dataset folder
# Adjust this to match your actual nnUNet_raw structure
# Example: nnUNet_raw/Dataset999_BraTS2020/
dataset_id = "Dataset001_BraTS2020"  # <- change to your DatasetXXX name
nnunet_raw_path = "/content/drive/MyDrive/MEDICAL/InfoTheo_dataset/nnUNet_raw"  # or your path if different

dataset_folder = os.path.join(nnunet_raw_path, dataset_id)
os.makedirs(dataset_folder, exist_ok=True)

splits_file = os.path.join(dataset_folder, "splits_final.pkl")

with open(splits_file, "wb") as f:
    pickle.dump(splits, f)

print("Saved splits_final.pkl to:", splits_file)


Saved splits_final.pkl to: /content/drive/MyDrive/MEDICAL/InfoTheo_dataset/nnUNet_raw/Dataset001_BraTS2020/splits_final.pkl
