In [1]:
%%capture
# Install/upgrade the Rockfish SDK with its `labs` extra, using the Rockfish
# packages page as an additional find-links source (-f).
# The %%capture cell magic on the first line hides this cell's output.
%pip install -U 'rockfish[labs]' -f 'https://docs142.rockfish.ai/packages/index.html'

In [2]:
import rockfish as rf
import rockfish.actions as ra

Replace `YOUR_API_KEY` in the cell below with your assigned API key string. Do not wrap the key in quotes.

For example, if your assigned API key is `abcd1234`, run the following:
```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```
If you do not have an API key, please reach out to support@rockfish.ai.

In [3]:
# Set the API key for this kernel session: replace YOUR_API_KEY with your own
# key (no quotes).
# NOTE: %env echoes the value into the cell output, so clear this cell's
# output before sharing or committing the notebook.
%env ROCKFISH_API_KEY=YOUR_API_KEY
# Create the connection object passed to every workflow in this notebook.
# NOTE(review): from_env() presumably reads ROCKFISH_API_KEY (and possibly
# other settings) from the environment - confirm against the SDK docs.
conn = rf.Connection.from_env()

env: ROCKFISH_API_KEY=YOUR_API_KEY


In [4]:
# Download the example tabular dataset: spotify-2023-short.csv.
# --no-clobber makes wget skip the download when the file already exists
# locally, so re-running this cell is safe.
!wget --no-clobber https://docs142.rockfish.ai/tutorials/spotify-2023-short.csv

File ‘spotify-2023-short.csv’ already there; not retrieving.



In [5]:
# Load the downloaded CSV into a Rockfish dataset.
# NOTE(review): the first argument ("Spotify") is presumably the dataset
# name - confirm against the SDK docs.
dataset = rf.Dataset.from_csv("Spotify", "spotify-2023-short.csv")
# Show the data as a pandas DataFrame to inspect the available columns.
dataset.to_pandas()

Unnamed: 0,released_year,released_month,released_day,in_spotify_playlists,bpm,key,mode
0,2023,7,14,553,125,B,Major
1,2023,3,23,1474,92,C#,Major
2,2023,6,30,1397,138,F,Major
3,2019,8,23,7858,170,A,Major
4,2023,5,18,3133,144,A,Minor
...,...,...,...,...,...,...,...
95,2023,5,12,2175,143,D#,Major
96,2023,3,17,2000,100,F#,Minor
97,2022,12,9,2839,143,F,Major
98,2011,1,1,20333,112,C#,Minor


In [6]:
# Columns the encoder should treat as categorical vs. continuous.
cat_fields = ["released_year", "released_month", "released_day", "key", "mode"]
con_fields = ["in_spotify_playlists", "bpm"]

config = {
    # One {"field", "type"} entry per column: all categorical columns first,
    # followed by the continuous ones.
    "encoder": {
        "metadata": [
            {"field": col, "type": kind}
            for kind, cols in (
                ("categorical", cat_fields),
                ("continuous", con_fields),
            )
            for col in cols
        ]
    },
    # Model settings for tabular mode. Every value here is very small
    # (a single epoch, size-1 GPT-2 settings).
    "rtf": {
        "mode": "tabular",
        "num_bootstrap": 2,
        "tabular": {
            "epochs": 1,
            "transformer": {
                "gpt2_config": {"layer": 1, "head": 1, "embed": 1},
            },
        },
    },
}

# Train action built from the config above; it is added to a workflow later.
train = ra.TrainTransformer(config)

In [7]:
# Assemble the training workflow: the dataset is the input (parent) of the
# train action.
builder = rf.WorkflowBuilder()
builder.add_dataset(dataset)
builder.add_action(train, parents=[dataset])
# Start the workflow over the connection. Top-level `await` is valid here
# because the code runs in a notebook cell.
workflow = await builder.start(conn)

# Print the workflow ID so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: 3RVkdfGHqzbEqAjRVdTF6A


In [8]:
# Iterate over the training workflow's log entries and print each one.
async for log in workflow.logs():
    print(log)

2024-07-12T21:40:36Z dataset-load: INFO Loading dataset '4uUy4WY379VS07O0hRqPtX' with 100 rows
2024-07-12T21:40:36Z train-transformer: INFO Start training...
2024-07-12T21:40:38Z train-transformer: INFO Epoch 1 completed.
2024-07-12T21:40:50Z train-transformer: INFO Training completed. The Model ID is 6725031c-4097-11ef-8c4e-8a07ae1c625c


In [9]:
# Take the last model reported by the training workflow; the generation
# workflows further down use it as their input.
model = await workflow.models().last()
# Display the model handle (its repr shows the model ID).
model

Model('6725031c-4097-11ef-8c4e-8a07ae1c625c')

### Update the generated records

In [13]:
# Add the number of records to generate to the "rtf" section. This mutates the
# `config` dict from the training cell in place.
config["rtf"].update({"records": 1000})  # update the generated records
# Generate action built from the (now updated) config.
generate = ra.GenerateTransformer(config)
# Save action configured with the name "SyntheticData".
save = ra.DatasetSave({"name": "SyntheticData"})
# Generation workflow: model -> generate -> save.
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model])
builder.add_action(save, parents=[generate])
# `workflow` is rebound here: from this point on it refers to the generation
# workflow, not the training workflow.
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1XQvv5y9xo9nHgseLCndwd


In [14]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,released_year,released_month,released_day,key,mode,in_spotify_playlists,bpm
0,2017,8,9,C#,Major,3587,174
1,2074,2,26,G,Major,4138,180
2,2024,7,25,D#,Minor,803,74
3,2020,9,1,D,Minor,4874,57
4,2023,0,36,F#,Major,7892,177
...,...,...,...,...,...,...,...
995,2029,2,31,B,Minor,9532,141
996,2021,8,21,C#,Minor,7802,190
997,2020,3,27,F#,Major,2882,60
998,2907,8,5,F,Minor,5611,150


### Generate large dataset
To generate a large dataset, we recommend using the `SessionTarget` action. See the [tabular data generation documentation](https://docs142.rockfish.ai/data-gen.html#tabular-data) for details.

In [15]:
# Target action asking for 5000 records in total.
record_target = ra.SessionTarget(target=5000)  # providing the target "records" value
# Save action configured with the name "target_synthetic" and concat_tables=True.
save = ra.DatasetSave(name="target_synthetic", concat_tables=True)
builder = rf.WorkflowBuilder()
builder.add_model(model)
# `generate` is the GenerateTransformer action created in the earlier
# generation cell, so that cell must have been run first.
# generate and record_target list each other as parents, which forms a loop
# between the two actions.
# NOTE(review): presumably the loop re-runs generation until the target is
# reached - confirm against the SessionTarget documentation.
builder.add_action(generate, parents=[model, record_target])
builder.add_action(record_target, parents=[generate])
builder.add_action(save, parents=[generate])
# `workflow` is rebound to this new workflow.
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1z9uifGipOxveNRcwACBfJ


In [16]:
# Iterate over the current workflow's log entries and print each one.
async for log in workflow.logs():
    print(log)

2024-07-12T21:40:59Z generate-transformer: INFO Starting download of Model 6725031c-4097-11ef-8c4e-8a07ae1c625c
2024-07-12T21:41:04Z generate-transformer: INFO Finished download of Model 6725031c-4097-11ef-8c4e-8a07ae1c625c
2024-07-12T21:41:04Z generate-transformer: INFO Start generating samples...
2024-07-12T21:41:07Z session-target: INFO Grouping on: ['released_year', 'released_month', 'released_day', 'key', 'mode', 'in_spotify_playlists', 'bpm']
2024-07-12T21:41:07Z session-target: INFO new=1000 total=1000 needs=4000
2024-07-12T21:41:07Z generate-transformer: INFO Finish generating samples...
2024-07-12T21:41:07Z generate-transformer: INFO Starting download of Model 6725031c-4097-11ef-8c4e-8a07ae1c625c
2024-07-12T21:41:07Z dataset-save: INFO Saved dataset '4x1QDYiRhpX7cSNdfQhsE4' with 1000 rows
2024-07-12T21:41:12Z generate-transformer: INFO Finished download of Model 6725031c-4097-11ef-8c4e-8a07ae1c625c
2024-07-12T21:41:12Z generate-transformer: INFO Start generating samples...
202

In [17]:
syn_large = None
async for sds in workflow.datasets():
    syn_large = await sds.to_local(conn)
syn_large.to_pandas()

Unnamed: 0,released_year,released_month,released_day,key,mode,in_spotify_playlists,bpm
0,2018,9,23,F,Minor,12483,172
1,2077,0,29,F,Minor,20933,120
2,2023,1,17,,Minor,10174,24
3,2073,5,26,B,Minor,7896,70
4,2020,4,21,D,Major,2946,168
...,...,...,...,...,...,...,...
4995,2076,19,3,G,Major,16846,112
4996,2027,11,1,G,Major,43467,95
4997,2023,10,33,G#,Major,5801,183
4998,2016,0,15,A#,Minor,13130,100
