In [1]:
%%capture
# Install (or upgrade to) the latest Rockfish SDK with its optional "labs" extras.
# `-f` adds https://packages.rockfish.ai as a pip find-links location; %%capture
# hides pip's verbose output.
# NOTE(review): the version is unpinned; consider pinning it for reproducibility.
%pip install -U 'rockfish[labs]' -f 'https://packages.rockfish.ai'

In [2]:
import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl

Replace `YOUR_API_KEY` in the cell below with your assigned API key string, without quotes.

For example, if your assigned API key is `abcd1234`, run the following:

```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```

If you do not have an API key, please reach out to support@rockfish.ai.


In [10]:
%env ROCKFISH_API_KEY=YOUR_API_KEY
conn = rf.Connection.from_env()
conn = rf.Connection.from_config("staging")

env: ROCKFISH_API_KEY=YOUR_API_KEY


In [4]:
# download our example of timeseries data: finance.csv
!wget --no-clobber https://docs.rockfish.ai/tutorials/finance.csv

--2025-02-06 18:34:08--  https://docs.rockfish.ai/tutorials/finance.csv
Resolving docs.rockfish.ai (docs.rockfish.ai)... 

I0000 00:00:1738895648.022114 19789770 fork_posix.cc:77] Other threads are currently calling into gRPC, skipping fork() handlers


2600:9000:2146:2200:3:1cb5:3480:93a1, 2600:9000:2146:8600:3:1cb5:3480:93a1, 2600:9000:2146:5000:3:1cb5:3480:93a1, ...
Connecting to docs.rockfish.ai (docs.rockfish.ai)|2600:9000:2146:2200:3:1cb5:3480:93a1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3444556 (3.3M) [text/csv]
Saving to: ‘finance.csv’


2025-02-06 18:34:08 (16.0 MB/s) - ‘finance.csv’ saved [3444556/3444556]



In [11]:
# Load the downloaded CSV as a Rockfish dataset named "finance" and preview it.
finance_csv = "finance.csv"
dataset = rf.Dataset.from_csv("finance", finance_csv)
dataset.to_pandas()

Unnamed: 0,customer,age,gender,merchant,category,amount,fraud,timestamp
0,C1093826151,4,M,M348934600,transportation,4.55,0,2023-01-01
1,C575345520,2,F,M348934600,transportation,76.67,0,2023-01-01
2,C1787537369,2,M,M1823072687,transportation,48.02,0,2023-01-01
3,C1732307957,5,F,M348934600,transportation,55.06,0,2023-01-01
4,C842799656,1,F,M348934600,transportation,25.62,0,2023-01-01
...,...,...,...,...,...,...,...,...
49995,C1971105040,3,M,M348934600,transportation,67.91,0,2023-01-20
49996,C51444479,3,M,M348934600,transportation,32.27,0,2023-01-20
49997,C1096642744,5,M,M1535107174,wellnessandbeauty,149.70,0,2023-01-20
49998,C1166683343,2,F,M1823072687,transportation,24.78,0,2023-01-20


Collect the valid merchant-category pairs present in the training dataset:


In [12]:
# Map each merchant to the list of categories it co-occurs with in the training
# data (categories listed in order of first appearance, no duplicates).
df = dataset.to_pandas()
merchant_to_category = {}
for merchant, category in zip(df["merchant"], df["category"]):
    seen_categories = merchant_to_category.setdefault(merchant, [])
    if category not in seen_categories:
        seen_categories.append(category)

These will be used to confirm that the synthetic dataset also has valid merchant-category pairs.


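### Join Dependent Fields

`merchant` and `category` are dependent: each merchant appears only with certain categories. Joining them into a single `merchant;category` field lets the model treat each observed pair as one categorical value.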


In [13]:
# Combine the dependent columns into one field; the cells below refer to the
# joined field as "merchant;category".
dependent_fields = ["merchant", "category"]
join_fields = ra.JoinFields(fields=dependent_fields)

### Train Model


In [14]:
# Encoder schema for the time-series GAN:
#   - `timestamp` is the time field,
#   - metadata: `age` (categorical) and `customer`, the session key (type="session"),
#   - measurements: the joined "merchant;category" field, `amount`, and `fraud`.
timegan_encoder = ra.TrainTimeGAN.DatasetConfig(
    timestamp=ra.TrainTimeGAN.TimestampConfig(field="timestamp"),
    metadata=[
        ra.TrainTimeGAN.FieldConfig(field="age", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="customer", type="session"),
    ],
    measurements=[
        ra.TrainTimeGAN.FieldConfig(field="merchant;category", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="amount", type="continuous"),
        ra.TrainTimeGAN.FieldConfig(field="fraud", type="categorical"),
    ],
)

# DoppelGANger training hyperparameters.
# NOTE(review): sample_len and batch_size are tutorial-specific values; confirm
# their meaning against the Rockfish TrainTimeGAN documentation before reuse.
timegan_training = ra.TrainTimeGAN.DGConfig(
    epoch=10,
    epoch_checkpoint_freq=5,
    sample_len=2,
    batch_size=1255,
)

config = ra.TrainTimeGAN.Config(encoder=timegan_encoder, doppelganger=timegan_training)
train = ra.TrainTimeGAN(config)

In [15]:
builder = rf.WorkflowBuilder()
builder.add_path(dataset, join_fields, train)
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1FQWRivF7EkRvt3wHqojH5


In [16]:
async for log in workflow.logs():
    print(log)

2025-02-07T02:34:51Z dataset-load: INFO Downloading dataset '580AroRT3pFgGjSxb6pffV'
2025-02-07T02:34:51Z dataset-load: INFO Downloaded dataset '580AroRT3pFgGjSxb6pffV' with 50000 rows
2025-02-07T02:34:57Z train-time-gan: INFO Starting DG training job
2025-02-07T02:34:58Z train-time-gan: INFO Epoch 1 completed.
2025-02-07T02:34:58Z train-time-gan: INFO Epoch 2 completed.
2025-02-07T02:34:58Z train-time-gan: INFO Epoch 3 completed.
2025-02-07T02:34:59Z train-time-gan: INFO Epoch 4 completed.
2025-02-07T02:34:59Z train-time-gan: INFO Epoch 5 completed.
2025-02-07T02:35:00Z train-time-gan: INFO Epoch 6 completed.
2025-02-07T02:35:00Z train-time-gan: INFO Epoch 7 completed.
2025-02-07T02:35:01Z train-time-gan: INFO Epoch 8 completed.
2025-02-07T02:35:01Z train-time-gan: INFO Epoch 9 completed.
2025-02-07T02:35:02Z train-time-gan: INFO Epoch 10 completed.
2025-02-07T02:35:05Z train-time-gan: INFO Training completed. Uploaded model 2357df37-e4fc-11ef-a569-c275ba02a948


### Generate Synthetic Data And Split Dependent Fields


In [17]:
# Fetch the last model produced by the training workflow and display its
# metadata (id, labels, creation time, size).
model = await workflow.models().last()
model

Model(id='2357df37-e4fc-11ef-a569-c275ba02a948', labels={'workflow_id': '1FQWRivF7EkRvt3wHqojH5'}, create_time=datetime.datetime(2025, 2, 7, 2, 35, 4, tzinfo=datetime.timezone.utc), size_bytes=35407360)

In [22]:
generate = ra.GenerateTimeGAN()

# SessionTarget sets how many sessions to generate. The user can set `target`
# to the desired session count. The default, None, generates the same number
# of sessions as the input dataset.
target = ra.SessionTarget(target=None)

# Undo the earlier join so `merchant` and `category` come back as separate columns.
split_field = ra.SplitField(field="merchant;category")

save = ra.DatasetSave(name="synthetic")

In [31]:
# Generation workflow. The parent links intentionally form a loop:
#   model -> generate -> split_field -> target -> generate
# Judging from the session-target logs ("new=... total=... needs=..."), the
# SessionTarget action reports how many more sessions are still needed and
# feeds back into `generate`, which presumably produces further batches until
# the target is reached.
# `save` consumes the output of `split_field`, so the saved dataset has
# `merchant` and `category` as separate columns.
# Rebinds `workflow`: later cells read logs/datasets from this generation run.
builder = rf.WorkflowBuilder()
builder.add_model(model)
builder.add_action(generate, parents=[model, target])
builder.add_action(split_field, parents=[generate])
builder.add_action(target, parents=[split_field])
builder.add_action(save, parents=[split_field])
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 4Kl8ix4lT847r2UN5mF1vq


In [32]:
async for log in workflow.logs():
    print(log)

2025-02-07T03:20:17Z session-target: INFO Grouping on: ['session_key']
2025-02-07T03:20:17Z session-target: INFO new=1000 total=1000 needs=2765
2025-02-07T03:20:17Z dataset-save: INFO using field 'session_key' to concatenate tables
2025-02-07T03:20:18Z session-target: INFO Grouping on: ['session_key']
2025-02-07T03:20:17Z dataset-save: INFO Saved dataset 'l5EsNDBZSyczcF8bu2kKp' with 3467 rows
2025-02-07T03:20:14Z generate-time-gan: INFO Downloading model with model_id='2357df37-e4fc-11ef-a569-c275ba02a948'...
2025-02-07T03:20:18Z session-target: INFO new=1000 total=2000 needs=1765
2025-02-07T03:20:16Z generate-time-gan: INFO Generating 1000 sessions...
2025-02-07T03:20:18Z dataset-save: INFO using field 'session_key' to concatenate tables
2025-02-07T03:20:17Z generate-time-gan: INFO Model found in cache
2025-02-07T03:20:17Z generate-time-gan: INFO Generating 1000 sessions...
2025-02-07T03:20:18Z generate-time-gan: INFO Model found in cache
2025-02-07T03:20:18Z generate-time-gan: INFO G

In [33]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,timestamp,amount,age,fraud,merchant,category,session_key
0,2023-01-03 08:36:38.355,1191.104658,4,0,M1823072687,transportation,1000.0
1,2023-01-21 13:24:46.500,1097.819749,4,0,M1842530320,tech,1000.0
2,2023-01-19 23:55:30.924,1064.472810,4,0,M1294758098,leisure,1001.0
3,2023-01-23 12:01:37.665,1274.412752,4,1,M1294758098,leisure,1001.0
4,2023-01-27 01:26:33.155,1112.339081,4,0,M1294758098,leisure,1001.0
...,...,...,...,...,...,...,...
13267,2023-01-01 20:07:26.623,1004.975746,3,0,M1313686961,contents,997.0
13268,2023-01-04 06:54:26.417,1891.922095,3,0,M348875670,hotelservices,998.0
13269,2023-01-08 04:58:02.557,1815.738493,3,1,M1400236507,home,998.0
13270,2023-01-02 17:10:48.277,447.251147,2,0,M1352454843,hotelservices,999.0


### Evaluate Synthetic Dataset


Check that the synthetic dataset contains only valid merchant-category pairs:


In [34]:
# Synthetic rows as a pandas DataFrame, used by the validation check.
syn_df = syn.to_pandas()

In [35]:
# Every (merchant, category) pair in the synthetic data must also occur in the
# training data. `.get(mer, ())` makes a merchant never seen in training count
# as an invalid pair (AssertionError) instead of crashing with a TypeError
# from `cat in None`. All violations are collected so a failure reports them together.
invalid_pairs = [
    (mer, cat)
    for mer, cat in zip(syn_df["merchant"], syn_df["category"])
    if cat not in merchant_to_category.get(mer, ())
]
assert not invalid_pairs, (
    f"{len(invalid_pairs)} invalid merchant-category pairs, e.g. {invalid_pairs[:5]}"
)