In [1]:
%%capture
# Install the notebook's dependencies into the kernel's environment.
# The %%capture cell magic above silences the installer output.
%pip install scikit-learn
# -U upgrades an already-installed copy; -f adds the Rockfish package index
# as an extra location for pip to search.
%pip install -U 'rockfish[labs]' -f 'https://docs.rockfish.ai/packages/index.html'

In [2]:
import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl
import pandas as pd
from sklearn.model_selection import train_test_split

Please replace `YOUR_API_KEY` with your assigned API key string. Note that the key should be entered without quotes.

For example, if the assigned API key is `abcd1234`, run the following:

```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```

If you do not have an API key, please reach out to support@rockfish.ai.


In [1]:
# Export the API key into this kernel's environment; the value is written
# unquoted, exactly as assigned.
%env ROCKFISH_API_KEY=YOUR_API_KEY
# Create the Rockfish connection from environment variables
# (presumably reads ROCKFISH_API_KEY set on the line above).
conn = rf.Connection.from_env()

In [4]:
# Download the example tabular dataset, fall_detection.csv, into the working
# directory. --no-clobber skips the download when the file is already present,
# so re-running this cell is harmless.
!wget --no-clobber https://docs.rockfish.ai/tutorials/fall_detection.csv

File ‘fall_detection.csv’ already there; not retrieving.



In [5]:
# Load the example CSV and cut it into two equally sized partitions:
# one for training the generator and one held out for evaluation.
# The rows are shuffled with a fixed seed so the split is repeatable.
df = pd.read_csv("fall_detection.csv")
train_split, test_split = (
    # discard the shuffled original row labels so each partition is indexed from 0
    partition.reset_index(drop=True)
    for partition in train_test_split(
        df, test_size=0.5, shuffle=True, random_state=1
    )
)

In [6]:
# Wrap the training split in a Rockfish dataset named "fall_detection_train".
train_dataset = rf.Dataset.from_pandas("fall_detection_train", train_split)
# Convert back to pandas so the cell displays the data that will be used for training.
train_dataset.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,70<80,F,Yes,Yes,Muscle weakness,No,Yes,41,97,99,...,No,No,No,No,No,No,No,High,No,No
1,60<70,F,Yes,Yes,Mental confusion,No,Yes,40,96,94,...,No,No,No,No,No,No,No,High,No,No
2,70<80,M,Yes,Yes,Loss of balance,No,No,38,96,81,...,No,No,No,No,No,No,No,High,No,No
3,30<40,F,No,No,Hypotension,No,No,48,98,78,...,No,No,No,No,No,No,No,Low,No,No
4,60<70,M,Yes,Yes,Muscle weakness,No,Yes,39,97,103,...,No,No,No,No,No,No,No,High,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1286,80<90,M,No,Yes,Loss of balance,No,No,39,97,77,...,No,No,No,No,No,No,No,High,No,No
1287,60<70,F,No,Yes,Loss of balance,No,Yes,41,97,71,...,No,No,No,No,No,No,No,High,Yes,No
1288,60<70,F,No,Yes,Loss of balance,No,No,40,96,78,...,No,No,No,No,No,No,No,High,No,No
1289,1<13,F,No,Yes,Slip,No,Yes,39,97,98,...,No,No,No,No,No,No,No,High,No,No


In [7]:
# Wrap the held-out split in a Rockfish dataset named "fall_detection_test".
# It is not passed to the training workflow; it is only used later when
# computing the DCR score.
test_dataset = rf.Dataset.from_pandas("fall_detection_test", test_split)
# Convert back to pandas so the cell displays the held-out data.
test_dataset.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,70<80,F,No,Yes,Muscle weakness,No,No,38,98,74,...,No,No,No,No,No,No,No,High,No,No
1,30<40,M,No,Yes,Loss of balance,No,No,39,97,89,...,No,No,No,No,No,No,No,High,No,No
2,20<30,F,No,No,Loss of balance,No,No,43,97,93,...,No,No,No,No,No,No,No,Moderate,No,No
3,40<50,F,No,Yes,Muscle weakness,No,Yes,41,95,88,...,No,No,No,No,No,No,No,High,No,No
4,30<40,F,No,No,Hypotension,No,No,49,98,76,...,No,No,No,No,No,No,No,Low,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1286,80<90,F,No,Yes,Muscle weakness,No,Yes,43,97,89,...,No,No,No,No,No,No,No,Moderate,No,No
1287,30<40,M,Yes,Yes,Hypotension,No,No,39,96,98,...,Yes,No,No,No,No,No,No,High,No,No
1288,20<30,F,No,Yes,Hypotension,No,No,42,95,79,...,No,No,No,No,No,No,No,High,No,No
1289,80<90,F,No,No,Slip,No,No,42,99,75,...,No,No,No,No,No,No,No,High,Yes,No


In [8]:
# Categorical columns are inferred here from the pandas "object" dtype;
# a hand-written list of column names can be used in place of this lookup.
categorical_fields = (
    train_dataset.to_pandas().select_dtypes(include=["object"]).columns
)

# Encoder metadata lists every categorical field first, followed by all
# remaining table columns, which are treated as continuous.
config = {
    "encoder": {
        "metadata": [
            *(
                {"field": name, "type": "categorical"}
                for name in categorical_fields
            ),
            *(
                {"field": name, "type": "continuous"}
                for name in train_dataset.table.column_names
                if name not in categorical_fields
            ),
        ],
    },
    # Train for 100 epochs and request as many records as the training split has.
    "tabular-gan": {
        "epochs": 100,
        "records": len(train_split),
    },
}

# Training action configured with the settings above.
train = ra.TrainTabGAN(config)

In [9]:
# Assemble the training workflow: the training dataset is the input (parent)
# of the train action.
builder = rf.WorkflowBuilder()
builder.add_dataset(train_dataset)
builder.add_action(train, parents=[train_dataset])
# Start the workflow over the connection; top-level await is available in notebook cells.
workflow = await builder.start(conn)

# Print the workflow ID so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: 59VZADHz6mAuTQdEdLU2tD


In [10]:
# Consume the workflow's progress stream until it is exhausted. The loop body
# is intentionally empty; .notebook() presumably renders the progress bar
# shown in the cell output.
async for progress in workflow.progress().notebook():
    pass

  0%|          | 0/100 [00:00<?, ?it/s]

In [11]:
# Fetch the first model (index 0) from the training workflow's models.
model = await workflow.models().nth(0)
# Bare expression so the notebook displays the model handle.
model

Model('6a90cb4d-4495-11ef-a6e0-625013d9c08b')

In [12]:
# Build the generation workflow: trained model -> generate -> save.
# The generate action is configured with the same config used for training,
# and the save action stores the result under the name "synthetic".
generate = ra.GenerateTabGAN(config)
save = ra.DatasetSave({"name": "synthetic"})
# NOTE: `builder` and `workflow` are rebound here; from this cell onward they
# refer to the generation workflow, not the training workflow.
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model])
builder.add_action(save, parents=[generate])
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1wy5fAYKIvsErJEYRqlrLl


In [13]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,20<30,F,Yes,Yes,Mental confusion,No,No,49,98,91,...,No,No,No,No,No,No,No,High,No,No
1,60<70,M,Yes,Yes,Loss of balance,No,Yes,52,94,84,...,No,No,No,No,No,No,No,High,No,No
2,13<20,M,Yes,No,Muscle weakness,No,No,38,98,85,...,No,No,No,No,No,No,No,Moderate,No,No
3,50<60,F,No,Yes,Hypotension,No,No,44,94,57,...,No,No,No,No,No,No,No,High,No,No
4,20<30,M,Yes,No,Mental confusion,No,No,44,94,79,...,No,No,No,No,No,No,No,High,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1286,60<70,F,No,Yes,Loss of balance,No,No,44,96,74,...,No,No,No,No,No,No,No,High,No,No
1287,70<80,M,No,Yes,Mental confusion,No,No,44,98,79,...,No,No,No,No,No,No,No,Low,No,No
1288,30<40,F,No,Yes,Loss of balance,No,Yes,38,97,78,...,No,No,No,No,No,No,No,High,No,No
1289,60<70,F,Yes,Yes,Loss of balance,No,No,37,98,93,...,No,No,No,No,No,No,No,High,No,No


### DCR Score

The Distance to Closest Record (DCR) score quantifies privacy risk by checking how similar the records in the synthetic
dataset are to the records in the source dataset.

It does so by measuring the similarity of the DCR distributions for two dataset pairs - (source, synthetic)
and (source, test). The more similar these two DCR distributions are, the more "private" the synthetic data.

Note that the test dataset should be sampled from the same distribution as the source dataset, and should not be used to
train your synthetic data generator.

The DCR score is a value between 0 and positive infinity. It can be interpreted using the following Likert scale for
quality:

1. Low: [0, 0.75)
2. Medium: [0.75, 1.0)
3. High: [1.0, positive infinity)


In [14]:
# Compute the DCR privacy score from three datasets: the data used for
# training, the held-out test data, and the synthetic data.
score = rl.metrics.distance_to_closest_record_score(
    train_dataset=train_dataset, test_dataset=test_dataset, syn=syn
)

In [15]:
# Display the DCR score; on the scale given earlier, values of 1.0 or more fall in the "High" band.
score

7.978583353694598