In [1]:
import os
import numpy as np
import pandas as pd

In [6]:
# Load the aggregated statistics table from disk.
df_agg = pd.read_csv("RS_data_agg.csv")

# Turn the flat CSV header into a two-level column index:
#   * a name containing "ROI" is split on "_" into its parts,
#   * any other name keeps an empty string as its second level.
col_tuples = []
for col_name in df_agg.columns:
    if "ROI" in col_name:
        col_tuples.append(tuple(col_name.split('_')))
    else:
        col_tuples.append((col_name, ''))
df_agg.columns = pd.MultiIndex.from_tuples(col_tuples)
df_agg

Unnamed: 0_level_0,Diagnosis,ROI1-ROI1,ROI1-ROI1,ROI1-ROI2,ROI1-ROI2,ROI1-ROI3,ROI1-ROI3,ROI1-ROI4,ROI1-ROI4,ROI1-ROI5,...,ROI262-ROI263,ROI262-ROI263,ROI262-ROI264,ROI262-ROI264,ROI263-ROI263,ROI263-ROI263,ROI263-ROI264,ROI263-ROI264,ROI264-ROI264,ROI264-ROI264
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,var,mean,var,mean,var,mean,var,mean,...,mean,var,mean,var,mean,var,mean,var,mean,var
0,Control,1.0,0.0,0.111342,0.02778,0.132046,0.038377,0.097095,0.034371,0.06648,...,-0.011571,0.037473,-0.050117,0.027456,0.99932,0.00068,0.279532,0.050156,0.99932,0.00068
1,PTSD,1.0,0.0,0.105944,0.026793,0.13553,0.036369,0.095276,0.03433,0.06439,...,-0.023395,0.033395,-0.042671,0.025482,1.0,0.0,0.257095,0.047299,1.0,0.0


In [7]:
# Boolean mask over df_agg's columns: True where the top-level label contains "ROI".
top_level_labels = df_agg.columns.get_level_values(0).to_series()
col_filter = top_level_labels.str.contains("ROI")

In [12]:
# --- Configuration for the synthetic example dataset ---
random_seed = 0  # fixed seed so a fresh Restart & Run All regenerates the same file
rng = np.random.default_rng(random_seed)
age_range = (10, 100)  # ages are drawn from [low, high): the upper bound is exclusive
site_choices = ["A", "B", "C", "D", "E", "F"]
n_samples_control = 20
n_samples_ptsd = 10

# Running SubjectID offset. Each group gets the next contiguous ID range, so IDs
# stay unique even though the groups have different sizes (multiplying the group
# size by a group counter made the ranges overlap when sizes differ).
subj_id_offset = 0
diag_frames = []  # one synthetic frame per diagnosis; concatenated once after the loop
for diag, diag_df in df_agg.groupby("Diagnosis"):
    n_samples = n_samples_control if diag == 'Control' else n_samples_ptsd

    # Keep only the ROI columns, then separate the "mean" and "var" statistics.
    ROI_agg = diag_df[df_agg.columns[col_filter]]
    ROI_mean_df = ROI_agg.xs("mean", axis=1, level=1, drop_level=True)
    ROI_var_df = ROI_agg.xs("var", axis=1, level=1, drop_level=True)
    ROI_means = ROI_mean_df.to_numpy().flatten()
    ROI_vars = ROI_var_df.to_numpy().flatten()
    n_ROIs = len(ROI_means)

    # Draw every ROI column independently from a normal distribution;
    # `scale` is a standard deviation, hence the square root of the variance.
    ROI_samples = rng.normal(loc=ROI_means, scale=np.sqrt(ROI_vars), size=(n_samples, n_ROIs))

    # Take the labels from the same frame the means came from so that the
    # column names are aligned with the sampled values by construction.
    col_names = ROI_mean_df.columns
    diag_ex_df = pd.DataFrame(ROI_samples, columns=col_names)

    # Random, non-identifying metadata for each synthetic subject.
    subj_ids = rng.permutation(n_samples) + subj_id_offset
    ages = rng.integers(age_range[0], age_range[1], size=n_samples)
    sexes = rng.choice(["M", "F"], size=n_samples)
    sites = rng.choice(site_choices, size=n_samples)
    diag_ex_df.insert(0, 'Site', sites)
    diag_ex_df.insert(1, 'SubjectID', subj_ids)
    diag_ex_df.insert(2, 'Diagnosis', diag)
    diag_ex_df.insert(3, 'Age', ages)
    diag_ex_df.insert(4, 'Sex', sexes)

    diag_frames.append(diag_ex_df)
    subj_id_offset += n_samples

ex_df = pd.concat(diag_frames, ignore_index=True)

# Every synthetic subject must have its own ID across all diagnosis groups.
assert ex_df["SubjectID"].is_unique, "duplicate SubjectID values generated"

ex_df

Unnamed: 0,Site,SubjectID,Diagnosis,Age,Sex,ROI1-ROI1,ROI1-ROI2,ROI1-ROI3,ROI1-ROI4,ROI1-ROI5,...,ROI261-ROI261,ROI261-ROI262,ROI261-ROI263,ROI261-ROI264,ROI262-ROI262,ROI262-ROI263,ROI262-ROI264,ROI263-ROI263,ROI263-ROI264,ROI264-ROI264
0,A,4,Control,40,M,1.0,0.005119,0.260709,0.276597,0.625206,...,1.046411,0.27187,0.225481,0.178783,0.931993,0.074302,-0.105372,0.98825,0.339961,1.008516
1,B,18,Control,34,F,1.0,0.131317,0.143047,0.015091,0.033315,...,0.968696,0.106633,-0.119147,-0.054522,1.009274,-0.136024,-0.003818,1.013211,0.355602,0.989907
2,D,14,Control,83,F,1.0,0.249349,-0.05701,0.09402,0.330839,...,0.987045,0.147393,0.433067,-0.232037,1.008879,0.068258,0.220003,0.971478,0.295053,1.035334
3,A,17,Control,25,F,1.0,0.093199,0.041888,0.069637,0.626811,...,0.981276,0.152478,0.123645,-0.218406,1.022811,0.101622,-0.292032,0.978878,0.309185,0.975821
4,C,8,Control,12,F,1.0,0.26458,0.352261,0.102381,0.030955,...,1.041081,0.299346,-0.178397,-0.082434,1.017162,0.153038,0.088272,0.959921,0.337613,1.057162
5,F,1,Control,96,M,1.0,0.282019,-0.241174,0.307522,-0.142424,...,0.996413,0.37872,-0.056402,-0.020524,0.9995,-0.018768,0.141428,0.964183,-0.090229,1.023601
6,E,0,Control,13,M,1.0,-0.093097,0.055134,-0.091561,0.0764,...,1.017964,0.120304,-0.052222,-0.105,0.961371,0.409509,-0.1061,0.985815,0.182784,1.049073
7,F,19,Control,89,M,1.0,0.221442,-0.065511,0.2364,-0.178139,...,0.981786,0.136412,-0.104923,-0.014906,1.004423,0.04309,0.115828,0.996634,0.16667,1.007271
8,C,9,Control,85,M,1.0,-0.067811,0.185978,0.296588,0.071748,...,0.977913,0.850497,-0.082248,0.201185,0.994865,0.042586,-0.161116,0.985619,0.252947,1.031861
9,D,2,Control,56,M,1.0,-0.229742,-0.009312,-0.266079,0.118861,...,1.011041,0.022692,-0.080945,0.148103,1.006507,-0.190492,-0.102166,0.982757,0.356311,1.022845


In [13]:
# Write the synthetic example table to disk; the row index is not saved.
ex_df.to_csv(path_or_buf="RS_data_example.csv", index=False)