In [2]:
import pandas as pd
import SimpleITK as sitk
from pathlib import Path
from tqdm.notebook import tqdm

In [3]:
def assign_folds(df, num_folds, random_state=None, strata=("probCOVID", "probSevere", "sex")):
    """Assign every row of ``df`` to one of ``num_folds`` stratified CV folds.

    Folds ``0 .. num_folds - 2`` are drawn one after another with
    ``groupby(strata).sample(frac=...)`` from the rows not yet assigned; all
    rows left over after that form the last fold, ``num_folds - 1``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to split. Must contain every column named in ``strata``.
        It is not modified.
    num_folds : int
        Number of folds, at least 1.
    random_state : int, optional
        Seed forwarded to every ``sample`` call; pass one for a reproducible split.
    strata : sequence of str
        Columns to stratify on (defaults to the columns this notebook uses).

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with an integer ``cv`` column, sorted by index.

    Raises
    ------
    ValueError
        If ``num_folds < 1``.
    """
    if num_folds < 1:
        raise ValueError(f"num_folds must be >= 1, got {num_folds}")
    # Work on a copy: otherwise num_folds == 1 writes "cv" straight into the
    # caller's frame (or a boolean-mask slice -> SettingWithCopyWarning).
    remaining = df.copy()
    orig_len = len(remaining)
    fold_size = orig_len // num_folds
    folds = []
    if orig_len > 0:  # an empty frame would divide by zero below
        for fold_id in range(num_folds - 1):
            # Fraction of the (nominally) still-unassigned rows that makes up one
            # fold. groupby.sample rounds per group, so actual fold sizes drift
            # slightly from fold_size; the formula is kept as-is so seeded
            # splits stay reproducible.
            fold_frac = fold_size / (orig_len - fold_id * fold_size)
            fold = remaining.groupby(list(strata)).sample(frac=fold_frac, random_state=random_state)
            remaining = remaining.drop(index=fold.index)
            fold["cv"] = fold_id
            folds.append(fold)
    remaining["cv"] = num_folds - 1
    folds.append(remaining)
    return pd.concat(folds).sort_index()

In [4]:
def balance(df, column="probSevere"):
    """Oversample the positive cases of ``column`` by duplicating those rows.

    The rows with ``df[column] == 1`` are appended ``r = int(len(df) / n_pos - 1)``
    extra times, so the positive class ends up with ``n_pos * (r + 1)`` rows.
    Duplicated rows keep their original index labels; the result is sorted by
    index.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing ``column``.
    column : str
        Binary label column to balance (defaults to ``"probSevere"``).

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If ``df`` has no rows with ``df[column] == 1`` (nothing to oversample).
    """
    positives = df[df[column] == 1]
    if positives.empty:
        raise ValueError(f"balance: no rows with {column} == 1 to oversample")
    # NOTE(review): n_pos * (r + 1) is about len(df), i.e. the positives end up
    # roughly equal to the original *total*, not to the negative count. Confirm
    # that this ratio is intended.
    repeats = int(len(df) / len(positives) - 1)
    return pd.concat([df] + [positives] * repeats).sort_index()

In [5]:
def get_age_sex(patient, root=None):
    """Return the ``("PatientAge", "PatientSex")`` metadata values of one scan.

    Reads ``<root>/data/mha/<patient>.mha``. Only the image header is parsed
    (``ImageFileReader.ReadImageInformation``); the voxel data is never loaded,
    unlike ``sitk.ReadImage``.

    Parameters
    ----------
    patient :
        Patient ID used as the file stem.
    root : path-like, optional
        Data root; defaults to the notebook-level ``dataroot``.

    Returns
    -------
    tuple of str
        ``(age, sex)`` exactly as stored in the file's metadata.
    """
    root = dataroot if root is None else Path(root)
    reader = sitk.ImageFileReader()
    reader.SetFileName(str(root / "data" / "mha" / f"{patient}.mha"))
    reader.ReadImageInformation()  # header only
    return reader.GetMetaData("PatientAge"), reader.GetMetaData("PatientSex")

In [6]:
# Local data root, i.e. ~/data/stoic/data under the current user's home.
dataroot = Path.home() / "data" / "stoic" / "data"

In [None]:
ref = pd.read_csv(dataroot / "metadata" / "reference.csv", index_col="PatientID")

# Slow: one .mha file is opened per patient.
age_sex = [get_age_sex(patient) for patient in tqdm(ref.index)]
ages = [age for age, _ in age_sex]
sexs = [sex for _, sex in age_sex]

ref["age"] = ages
ref["sex"] = sexs

# Cache the enriched reference table so later cells can reload it cheaply.
ref.to_csv("ref.csv", index=True)

In [6]:
ref = pd.read_csv("ref.csv").set_index("PatientID")

In [9]:
# Hold out ~10% of patients as a validation set, stratified by
# probCOVID / probSevere / age / sex.
val_ratio = 0.1
val_seed = 1055  # fixed so split_meta.csv can be regenerated identically
g = ref.groupby(["probCOVID", "probSevere", "age", "sex"])

# NOTE(review): groupby.sample rounds frac * group_size per group, so very small
# groups may contribute no validation rows; by default groupby also drops rows
# with NaN keys, so those always land in train.
val = g.sample(frac=val_ratio, replace=False, random_state=val_seed)
val["set"] = "val"

train = ref.drop(index=val.index)
train["set"] = "train"

split = pd.concat([train, val])
assert len(split) == len(ref)
assert not train.index.isin(val.index).any()  # train and val are disjoint
split.to_csv("split_meta.csv", index=True)

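## Balance probSevere

8-fold cross-validation on the training set (validation patients get `cv = -1`); then oversample the severe (`probSevere == 1`) cases among the COVID-positive training patients.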

In [30]:
split = pd.read_csv("split_meta.csv").set_index("PatientID")

In [52]:
# Training patients -> 8 stratified folds; validation patients get cv = -1.
# Fixed seed (same as the 5-fold cells) so split_meta8.csv is reproducible.
train8 = assign_folds(split[split.set == "train"], num_folds=8, random_state=1055)
val8 = split[split.set == "val"].copy()
val8["cv"] = -1
split8 = pd.concat([train8, val8]).sort_index()

split8.to_csv("split_meta8.csv", index=True)

In [59]:
# COVID-positive training patients with severe cases oversampled by balance();
# validation rows are kept unchanged.
is_train = split8["set"] == "train"
is_covid = split8["probCOVID"] == 1
trainb8 = balance(split8[is_train & is_covid])
valb8 = split8[split8["set"] == "val"]
balance8 = pd.concat([trainb8, valb8]).sort_index()
balance8.to_csv("split_meta_sevbalance8.csv", index=True)

8-fold cross-validation split over all COVID-positive patients, with no separate held-out set (for the final submission).

In [21]:
# All COVID-positive patients in 8 folds, severe cases oversampled.
# Fixed seed (same as the 5-fold cells) so split_sev_cv8.csv is reproducible.
cv8 = balance(assign_folds(ref[ref.probCOVID == 1], num_folds=8, random_state=1055))
cv8["set"] = "train"
cv8.to_csv("split_sev_cv8.csv", index=True)

8-fold cross-validation split that also includes the `probCOVID == 0` patients (these are not oversampled).

In [8]:
# No index_col on purpose: PatientID stays an ordinary column so this frame
# lines up with the reset_index()'d COVID-negative folds it is combined with.
cv8 = pd.read_csv("split_sev_cv8.csv", index_col=None)

In [18]:
# COVID-negative patients get their own 8 stratified folds (no oversampling).
# Fixed seed (same as the 5-fold cells) so split_all_cv8.csv is reproducible.
noinf8 = assign_folds(ref[ref.probCOVID == 0], num_folds=8, random_state=1055).reset_index()
noinf8["set"] = "train"
# PatientID is a regular column in both frames here, so the index is not written.
pd.concat([cv8, noinf8]).to_csv("split_all_cv8.csv", index=False)

## 5-fold

In [13]:
cv5 = (
    assign_folds(ref[ref.probCOVID == 1], num_folds=5, random_state=1055)
    .pipe(balance)
    .assign(set="train")
)
cv5.to_csv("split_sev_cv5.csv", index=True)

In [7]:
# 5-fold split over all patients: COVID-positive ones are oversampled on
# probSevere, COVID-negative ones are not.
cv5_inf1 = balance(assign_folds(ref[ref.probCOVID == 1], num_folds=5, random_state=1055))
cv5_inf0 = assign_folds(ref[ref.probCOVID == 0], num_folds=5, random_state=1055)
cv5 = pd.concat([cv5_inf0, cv5_inf1])
cv5["set"] = "train"
cv5.to_csv("sevbal_cv5.csv", index=True)

# Use fold 0 as the validation set.
cv5["set"] = ["val" if cv == 0 else "train" for cv in cv5["cv"]]
# Remove the oversampling duplicates from the validation set. balance() copies
# rows together with their index labels, and .loc[label] returns *every* row
# carrying that label, so selecting by unique labels cannot deduplicate; drop
# repeated labels explicitly instead.
cv5_train = cv5[cv5["set"] == "train"]
cv5_val = cv5[cv5["set"] == "val"]
cv5_val = cv5_val[~cv5_val.index.duplicated(keep="first")]
assert cv5_val.index.is_unique
cv5 = pd.concat([cv5_train, cv5_val])
cv5.to_csv("sevbal_cv5_val=0.csv", index=True)