In [2]:
%%capture
%pip install -U 'rockfish[labs]' -f 'https://docs142.rockfish.ai/packages/index.html'

In [3]:
from pathlib import Path
from urllib.request import urlretrieve

import rockfish as rf
import rockfish.actions as ra
import rockfish.labs

Replace `YOUR_API_KEY` in the cell below with your assigned API key. Do not wrap the key in quotes.

For example, if your assigned API key is `abcd1234`, run:
```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```
If you do not have an API key, please contact support@rockfish.ai.

In [1]:
# Put the Rockfish API key into this kernel's environment (replace YOUR_API_KEY
# with your key, without quotes), then create the connection with
# rf.Connection.from_env(), which is expected to pick up ROCKFISH_API_KEY.
# NOTE: avoid saving or committing the notebook with a real key in this cell.
%env ROCKFISH_API_KEY=YOUR_API_KEY
conn = rf.Connection.from_env()

In [5]:
# Download the example tabular dataset (fall_detection.csv) unless a local copy
# already exists, mirroring `wget --no-clobber`. Done in pure Python instead of
# `!wget` so the cell also works where wget isn't installed (it is not shipped
# by default on macOS or Windows).
FALL_DETECTION_URL = "https://docs142.rockfish.ai/tutorials/fall_detection.csv"
FALL_DETECTION_CSV = Path("fall_detection.csv")
if FALL_DETECTION_CSV.exists():
    print(f"File '{FALL_DETECTION_CSV}' already there; not retrieving.")
else:
    urlretrieve(FALL_DETECTION_URL, FALL_DETECTION_CSV)

File ‘fall_detection.csv’ already there; not retrieving.



In [6]:
# Load fall_detection.csv into a Rockfish Dataset, then preview it as a
# pandas DataFrame.
dataset = rf.Dataset.from_csv(
    "fall_detection",
    "fall_detection.csv",
)
dataset.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,60<70,M,Yes,Yes,Slip,No,Yes,41,97,80,...,No,No,No,No,No,No,No,High,No,No
1,30<40,F,Yes,Yes,Loss of balance,No,No,41,96,78,...,No,No,No,No,No,No,No,High,No,No
2,60<70,M,Yes,Yes,Mental confusion,No,Yes,43,98,81,...,No,No,No,No,No,No,No,High,No,No
3,80<90,M,Yes,Yes,Mental confusion,No,Yes,40,99,82,...,No,No,No,No,No,No,No,High,No,No
4,60<70,M,Yes,Yes,Loss of balance,No,Yes,40,96,90,...,No,No,No,No,No,No,No,High,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2577,50<60,M,No,No,Muscle weakness,No,No,44,97,67,...,No,No,No,No,No,No,No,Moderate,No,No
2578,80<90,F,No,No,Hypotension,No,No,43,97,96,...,No,No,No,No,No,No,No,Moderate,No,No
2579,30<40,F,No,No,Muscle weakness,No,No,43,96,76,...,No,No,No,No,No,No,No,Moderate,No,No
2580,60<70,M,No,Yes,Loss of balance,No,No,44,99,101,...,No,No,No,No,No,No,No,Moderate,No,No


In [7]:
# Build the TabGAN training config from the source data.
# Columns with pandas dtype "object" (strings) are encoded as categorical and
# every other column as continuous. You can instead supply your own list of
# categorical column names as `categorical_fields`.
# Convert to pandas once and derive both field lists from the same frame, so
# every column gets exactly one metadata entry.
source_df = dataset.to_pandas()
categorical_fields = source_df.select_dtypes(include=["object"]).columns
config = {
    "encoder": {
        "metadata": [
            {"field": field, "type": "categorical"} for field in categorical_fields
        ]
        + [
            {"field": field, "type": "continuous"}
            for field in source_df.columns
            if field not in categorical_fields
        ],
    },
    "tabular-gan": {
        "epochs": 100,
        # Match the number of rows in the training data instead of hard-coding
        # the row count, so the config stays correct if the CSV changes.
        "records": len(source_df),
    },
}
# create train action
train = ra.TrainTabGAN(config)

In [8]:
builder = rf.WorkflowBuilder()
builder.add_dataset(dataset)
builder.add_action(train, parents=[dataset])
workflow = await builder.start(conn)

print(f"Workflow: {workflow.id()}")

Workflow: 1izmeQThaSbyKJqpmiIVkT


In [9]:
async for progress in workflow.progress().notebook():
    pass

  0%|          | 0/100 [00:00<?, ?it/s]

In [10]:
model = await workflow.models().nth(0)
model

Model('4ac4674a-4096-11ef-8c4e-8a07ae1c625c')

### Change the number of generated records

In [16]:
config["tabular-gan"].update({"records": 5000})  # update n_records
generate = ra.GenerateTabGAN(config)
save = ra.DatasetSave({"name": "synthetic"})
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model])
builder.add_action(save, parents=[generate])
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1JQF6IfIJJ4dBr0Z5dmVz3


In [17]:
async for log in workflow.logs():
    print(log)

2024-07-12T21:33:41Z dataset-save: INFO Saved dataset '7w5W9HGR04i1RBFMRoJEo9' with 5000 rows
2024-07-12T21:33:40Z generate-tab-gan: INFO Generating 5000 records


In [18]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,60<70,F,No,No,Mental confusion,No,No,39,97,95,...,No,No,No,No,No,No,No,High,No,No
1,70<80,M,Yes,Yes,Muscle weakness,No,No,54,97,76,...,No,No,No,No,No,No,No,High,No,No
2,70<80,M,Yes,Yes,Muscle weakness,No,No,37,96,72,...,No,No,No,No,No,No,No,High,No,No
3,20<30,F,Yes,Yes,Loss of balance,No,No,40,98,90,...,No,No,No,No,No,No,No,Low,No,No
4,80<90,M,Yes,Yes,Loss of balance,No,Yes,39,97,72,...,No,No,No,No,No,No,No,High,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,60<70,F,No,Yes,Hypotension,No,No,44,98,80,...,No,No,Yes,No,No,No,No,Moderate,No,No
4996,1<13,F,Yes,Yes,Mental confusion,No,Yes,43,97,105,...,No,No,No,No,No,No,No,High,No,No
4997,50<60,M,No,Yes,Muscle weakness,No,Yes,41,98,109,...,No,No,No,No,No,No,No,High,Yes,No
4998,40<50,F,Yes,Yes,Hypotension,No,No,39,96,91,...,No,No,No,No,No,No,No,High,No,No


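### Generate a large dataset
To generate a large dataset, we recommend using `SessionTarget`, which repeats the generate step until a target number of records is reached. See the [tabular data generation guide](https://docs142.rockfish.ai/data-gen.html#tabular-data) for details.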

In [13]:
record_target = ra.SessionTarget(target=20000)  # providing the target "records" value
save = ra.DatasetSave(name="target_synthetic", concat_tables=True)
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model, record_target])
builder.add_action(record_target, parents=[generate])
builder.add_action(save, parents=[generate])
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 4z3Ev4H6LsnmhWhOJBf8kM


In [14]:
async for log in workflow.logs():
    print(log)

2024-07-12T21:32:55Z generate-tab-gan: INFO Generating 5000 records
2024-07-12T21:32:56Z session-target: INFO new=5000 total=5000 needs=15000
2024-07-12T21:32:56Z dataset-save: INFO Saved dataset '4QOVnz5mrVMum3KwPeyjc3' with 5000 rows
2024-07-12T21:32:56Z generate-tab-gan: INFO Generating 5000 records
2024-07-12T21:32:56Z session-target: INFO new=5000 total=10000 needs=10000
2024-07-12T21:32:57Z generate-tab-gan: INFO Generating 5000 records
2024-07-12T21:32:57Z session-target: INFO new=5000 total=15000 needs=5000
2024-07-12T21:32:57Z generate-tab-gan: INFO Generating 5000 records
2024-07-12T21:32:58Z session-target: INFO new=5000 total=20000 needs=0


In [15]:
syn_large = None
async for sds in workflow.datasets():
    syn_large = await sds.to_local(conn)
syn_large.to_pandas()

Unnamed: 0,Age range of patient,Sex,Involvement of medication associated with fall risk,Whether a fall prevention protocol was implemented,Reason for incident,Whether a restraint prescription was given,Whether a physical therapy prescription was given,BBS Score,Body Temperature,Heart Rate,...,Chronic Heart Failure,Stroke,Frozen shoulder,Osteoarthritis of hip,Cerebellar Ataxia,Hemiparesis,GB Syndrome,Fall risk level,Ischemic Heart Disease,Cervical sopondylitis
0,70<80,F,No,Yes,Muscle weakness,No,No,44,97,107,...,No,No,No,No,No,No,No,High,No,No
1,60<70,M,Yes,Yes,Loss of balance,No,No,35,98,88,...,No,No,No,No,No,No,No,Moderate,No,No
2,< 1,F,Yes,No,Mental confusion,No,No,38,97,92,...,No,No,No,No,No,No,No,High,No,No
3,60<70,M,No,Yes,Hypotension,No,No,41,97,76,...,No,No,No,No,No,No,No,High,No,No
4,60<70,F,Yes,Yes,Mental confusion,No,Yes,38,98,70,...,No,No,No,No,Yes,No,No,High,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19995,40<50,F,No,Yes,Loss of balance,No,No,53,97,82,...,No,No,No,No,No,No,No,Moderate,No,No
19996,60<70,M,Yes,Yes,Mental confusion,No,Yes,44,98,71,...,No,No,No,Yes,No,No,No,High,No,No
19997,≥ 90,F,No,Yes,Loss of balance,No,No,39,97,91,...,No,No,No,No,No,No,No,Moderate,No,No
19998,< 1,F,Yes,Yes,Hypotension,No,Yes,44,98,109,...,Yes,No,No,No,No,No,No,Moderate,No,No
