In [None]:
%%capture
# Install/upgrade the Rockfish SDK with its "labs" extra (presumably what provides
# `rockfish.labs`, imported below), resolving packages via the --find-links page.
# %%capture hides all installer output, including errors; drop it to debug installs.
# NOTE(review): `-U` without a version pin installs whatever is latest; pin for reproducibility.
%pip install -U 'rockfish[labs]' -f 'https://docs142.rockfish.ai/packages/index.html'

In [None]:
from pathlib import Path
from urllib.request import urlopen

import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl

Replace `YOUR_API_KEY` below with the API key assigned to you. Do not wrap the key in quotes.

For example, if your assigned API key is `abcd1234`, run:
```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```
If you do not have an API key, please contact support@rockfish.ai.

In [None]:
# Set the API key for this kernel session: replace YOUR_API_KEY with your key (no quotes).
# Keep comments off the %env line itself; IPython hands the rest of that line to the magic.
# NOTE(review): the key is saved in the notebook source; clear it before sharing the notebook.
%env ROCKFISH_API_KEY=YOUR_API_KEY
# Create the client connection from environment variables (presumably reads ROCKFISH_API_KEY).
conn = rf.Connection.from_env()

In [None]:
# Download our example of tabular data (fall_detection.csv) into the working directory.
# This is a pure-Python download rather than `!wget`, so the cell does not depend on a
# wget binary being installed, and a failed download raises instead of passing silently.
FALL_DETECTION_URL = "https://docs142.rockfish.ai/tutorials/fall_detection.csv"
FALL_DETECTION_CSV = Path("fall_detection.csv")

# Skip the download if the file already exists (the old `wget --no-clobber` behaviour).
if not FALL_DETECTION_CSV.exists():
    with urlopen(FALL_DETECTION_URL) as response:
        # Read the whole (small) file before writing anything, so an interrupted
        # download never leaves a truncated CSV that later runs would skip over.
        FALL_DETECTION_CSV.write_bytes(response.read())

In [None]:
# Load the downloaded CSV as a Rockfish dataset ("fall_detection" is presumably
# the dataset's name) and preview it as a pandas DataFrame.
dataset = rf.Dataset.from_csv(
    "fall_detection",
    "fall_detection.csv",
)
dataset.to_pandas()

In [None]:
# Infer the categorical columns from the pandas dtypes. Text columns may be
# `object` or a dedicated pandas string dtype (the default on newer pandas), and
# `category` columns are categorical too, so all three are selected.
# You can also replace `categorical_fields` with a hand-written list of column names.
source_df = dataset.to_pandas()
categorical_fields = list(
    source_df.select_dtypes(include=["object", "string", "category"]).columns
)
categorical_set = set(categorical_fields)

config = {
    "encoder": {
        # One entry per column: the categorical ones first, then every remaining
        # column (in table order) as continuous.
        "metadata": [
            {"field": field, "type": "categorical"} for field in categorical_fields
        ]
        + [
            {"field": field, "type": "continuous"}
            for field in dataset.table.column_names
            if field not in categorical_set
        ],
    },
    "tabular-gan": {
        "epochs": 100,
        # NOTE(review): presumably the number of synthetic records to generate,
        # chosen to match the source row count — confirm.
        "records": 2582,
    },
}
# create train action
train = ra.TrainTabGAN(config)

In [None]:
# Build and launch the training workflow: the source dataset feeds the TrainTabGAN action.
builder = rf.WorkflowBuilder()
builder.add_dataset(dataset)
builder.add_action(train, parents=[dataset])
# NOTE(review): start() presumably returns once the workflow is submitted, not when
# training finishes; the next cell waits on progress — confirm.
workflow = await builder.start(conn)

print(f"Workflow: {workflow.id()}")

In [None]:
# Consume the workflow's progress stream to block until it is exhausted; `.notebook()`
# presumably renders the progress in the notebook, so the loop body does nothing.
async for progress in workflow.progress().notebook():
    pass

In [None]:
# Take the first (index 0) model produced by the training workflow and display it.
model = await workflow.models().nth(0)
model

In [None]:
# Build a second workflow: trained model -> GenerateTabGAN -> DatasetSave.
# Reuses the training `config` (presumably "records" sets how many rows are generated).
generate = ra.GenerateTabGAN(config)
# Save the generated records as a dataset named "synthetic".
save = ra.DatasetSave({"name": "synthetic"})
# NOTE: this rebinds `builder` and `workflow`; re-running earlier cells that use
# `workflow` (e.g. the progress cell) after this point targets the generation workflow.
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model])
builder.add_action(save, parents=[generate])
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

In [None]:
# Print the generation workflow's log entries as they arrive (the stream presumably
# ends when the workflow finishes, which also makes this cell a wait point).
async for log in workflow.logs():
    print(log)

In [None]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

### Evaluation

**1. categorical columns**

In [None]:
# Compare the 10 most frequent values of each categorical column in the source
# and synthetic data, side by side.
categorical_eval_columns = ["Age range of patient", "Sex"]
for column in categorical_eval_columns:
    source_counts = rf.metrics.count_all(dataset, column, nlargest=10)
    synthetic_counts = rf.metrics.count_all(syn, column, nlargest=10)
    rl.vis.plot_bar([source_counts, synthetic_counts], column, f"{column}_count")

**2. numerical columns**

In [None]:
# Overlay the kernel density estimates of each numerical column for the source
# and synthetic data.
numerical_eval_columns = ["BBS Score", "Body Temperature"]
for column in numerical_eval_columns:
    rl.vis.plot_kde([dataset, syn], column)

**3. correlation between two numerical columns**

In [None]:
# Plot how two numerical columns relate in the source vs. synthetic data
# (presumably systolic and diastolic blood pressure).
col1, col2 = "SBP", "DBP"
rl.vis.plot_correlation(
    [dataset, syn],
    col1,
    col2,
    alpha=0.5,
)

**4. correlation heatmap between several numerical columns**

In [None]:
# Correlation heatmaps over several numerical columns, source vs. synthetic.
# `annot`/`fmt` presumably label each cell with its coefficient to 2 decimals.
n_cols = [
    "Body Temperature",
    "SBP",
    "BBS Score",
    "DBP",
    "Heart Rate",
]
rl.vis.plot_correlation_heatmap(
    [dataset, syn],
    n_cols,
    annot=True,
    fmt=".2f",
)