In [1]:
%%capture
%pip install -U 'rockfish[labs]' -f 'https://packages.rockfish.ai'

In [2]:
import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl

Replace `YOUR_API_KEY` with your assigned API key string. Note that the key must be written without quotes.

For example, if the assigned API key is `abcd1234`, you can do the following:

```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```

If you do not have an API key, please reach out to support@rockfish.ai.


In [3]:
# Export the API key into the kernel's environment. Write the key without
# quotes: everything after `=` is taken as the value.
%env ROCKFISH_API_KEY=YOUR_API_KEY
# Build the Rockfish connection from the environment set above.
conn = rf.Connection.from_env()

In [4]:
# Download the example time-series dataset, finance.csv.
# --no-clobber skips the download when the file already exists locally.
!wget --no-clobber https://docs.rockfish.ai/tutorials/finance.csv

File ‘finance.csv’ already there; not retrieving.



In [5]:
# Load the downloaded CSV file into a Rockfish dataset
# (the first argument is presumably the dataset name - confirm against the SDK).
dataset = rf.Dataset.from_csv("finance", "finance.csv")
# Display the dataset as a pandas DataFrame in the cell output.
dataset.to_pandas()

Unnamed: 0,customer,age,gender,merchant,category,amount,fraud,timestamp
0,C1093826151,4,M,M348934600,transportation,4.55,0,2023-01-01
1,C575345520,2,F,M348934600,transportation,76.67,0,2023-01-01
2,C1787537369,2,M,M1823072687,transportation,48.02,0,2023-01-01
3,C1732307957,5,F,M348934600,transportation,55.06,0,2023-01-01
4,C842799656,1,F,M348934600,transportation,25.62,0,2023-01-01
...,...,...,...,...,...,...,...,...
49995,C1971105040,3,M,M348934600,transportation,67.91,0,2023-01-20
49996,C51444479,3,M,M348934600,transportation,32.27,0,2023-01-20
49997,C1096642744,5,M,M1535107174,wellnessandbeauty,149.70,0,2023-01-20
49998,C1166683343,2,F,M1823072687,transportation,24.78,0,2023-01-20


Get the valid merchant-category pairs present in the training dataset:


In [6]:
df = dataset.to_pandas()

# Map every merchant to the list of distinct categories it occurs with in the
# training data; categories are kept in first-seen order.
merchant_to_category = {}
for merchant, category in zip(df["merchant"], df["category"]):
    seen_categories = merchant_to_category.setdefault(merchant, [])
    if category in seen_categories:
        continue
    seen_categories.append(category)

These will be used to confirm that the synthetic dataset also has valid merchant-category pairs.


### Join Dependent Fields


In [7]:
# Join the dependent `merchant` and `category` fields - presumably into the
# combined "merchant;category" field that the training config refers to
# (verify against the SDK).
join_fields = ra.JoinFields(fields=["merchant", "category"])

### Train Model


In [8]:
# Encoder pieces: the timestamp field, the metadata fields, and the
# measurement fields (which include the joined "merchant;category" field).
timestamp_config = ra.TrainTimeGAN.TimestampConfig(field="timestamp")
metadata_fields = [
    ra.TrainTimeGAN.FieldConfig(field="age", type="categorical"),
    ra.TrainTimeGAN.FieldConfig(field="customer", type="session"),
]
measurement_fields = [
    ra.TrainTimeGAN.FieldConfig(field="merchant;category", type="categorical"),
    ra.TrainTimeGAN.FieldConfig(field="amount", type="continuous"),
    ra.TrainTimeGAN.FieldConfig(field="fraud", type="categorical"),
]
encoder_config = ra.TrainTimeGAN.DatasetConfig(
    timestamp=timestamp_config,
    metadata=metadata_fields,
    measurements=measurement_fields,
)

# Model/training settings passed to the DoppelGANger part of the config.
dg_config = ra.TrainTimeGAN.DGConfig(epoch=10, sample_len=2, batch_size=1255)

# Combine both parts into the training config and create the train action.
config = ra.TrainTimeGAN.Config(encoder=encoder_config, doppelganger=dg_config)
train = ra.TrainTimeGAN(config)

In [9]:
# Build the training workflow as a single path:
# dataset -> join_fields -> train, then start it on the connection.
builder = rf.WorkflowBuilder()
builder.add_path(dataset, join_fields, train)
workflow = await builder.start(conn)
# Print the workflow id so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: z3fhZRfcp1ZJ14tjESOUG


In [10]:
# Iterate asynchronously over the training workflow's logs and print each entry.
async for log in workflow.logs():
    print(log)

2026-01-12T20:23:08.250001Z dataset-load: INFO Downloading dataset '4H4panNJpqBuWInZJY3ftu'
2026-01-12T20:23:08.724083Z dataset-load: INFO Downloaded dataset '4H4panNJpqBuWInZJY3ftu' with 50000 rows
2026-01-12T20:23:11.540177Z train-time-gan: INFO Starting DG training job
2026-01-12T20:23:12.422564Z train-time-gan: INFO Epoch 1 completed.
2026-01-12T20:23:12.955888Z train-time-gan: INFO Epoch 2 completed.
2026-01-12T20:23:13.470983Z train-time-gan: INFO Epoch 3 completed.
2026-01-12T20:23:14.004411Z train-time-gan: INFO Epoch 4 completed.
2026-01-12T20:23:14.526326Z train-time-gan: INFO Epoch 5 completed.
2026-01-12T20:23:15.050524Z train-time-gan: INFO Epoch 6 completed.
2026-01-12T20:23:15.595468Z train-time-gan: INFO Epoch 7 completed.
2026-01-12T20:23:16.125186Z train-time-gan: INFO Epoch 8 completed.
2026-01-12T20:23:16.673300Z train-time-gan: INFO Epoch 9 completed.
2026-01-12T20:23:17.204818Z train-time-gan: INFO Epoch 10 completed.
2026-01-12T20:23:19.310086Z train-time-gan: IN

### Generate Synthetic Data And Split Dependent Fields


In [11]:
# Take the last model from the training workflow's models.
model = await workflow.models().last()
# Bare expression: display the model as the cell output.
model

Model(id='886b91ed-eff4-11f0-ae9c-c6efc6b474c1', labels={'epoch': '10', 'is_checkpoint': 'false', 'job_id': '15J3OsYD5iZFZCnsp84TMO', 'model_type': 'time-gan', 'workflow_id': 'z3fhZRfcp1ZJ14tjESOUG'}, create_time=datetime.datetime(2026, 1, 12, 20, 23, 18, tzinfo=datetime.timezone.utc), size_bytes=18035200)

In [12]:
generate = ra.GenerateTimeGAN()
# The user can specify the target number of sessions to generate. The default
# is None, which means generate the same number of sessions as the input dataset.
target = ra.SessionTarget(
    target=None
)
# Split the joined "merchant;category" field used during training.
split_field = ra.SplitField(field="merchant;category")
# Save the result as a dataset named "synthetic".
save = ra.DatasetSave(name="synthetic")

In [13]:
# Assemble the generation workflow: model -> generate -> split_field -> save.
builder = rf.WorkflowBuilder()
builder.add_model(model)
# `generate` has both the model and `target` as parents, and `target` in turn
# consumes the output of `split_field`. NOTE(review): this feedback edge
# presumably lets generation repeat until the session target is met - confirm
# against the SDK documentation.
builder.add_action(generate, parents=[model, target])
builder.add_action(split_field, parents=[generate])
builder.add_action(target, parents=[split_field])
builder.add_action(save, parents=[split_field])
workflow = await builder.start(conn)
# Print the workflow id so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: 5i2jAId1Ppjr22xFkXnuhc


In [14]:
# Iterate asynchronously over the generation workflow's logs and print each entry.
async for log in workflow.logs():
    print(log)

2026-01-12T20:23:21.882689Z generate-time-gan: INFO Downloading model with model_id='886b91ed-eff4-11f0-ae9c-c6efc6b474c1'...
2026-01-12T20:23:22.900030Z generate-time-gan: INFO Model version: 3
2026-01-12T20:23:22.912910Z generate-time-gan: INFO Generating 1000 sessions...
2026-01-12T20:23:23.860697Z session-target: INFO Grouping on: ['session_key']
2026-01-12T20:23:23.889803Z session-target: INFO new=1000 total=1000 needs=2765
2026-01-12T20:23:23.936208Z generate-time-gan: INFO Model found in cache
2026-01-12T20:23:23.967926Z generate-time-gan: INFO Model version: 3
2026-01-12T20:23:23.980160Z generate-time-gan: INFO Generating 1000 sessions...
2026-01-12T20:23:24.845688Z session-target: INFO Grouping on: ['session_key']
2026-01-12T20:23:24.873506Z session-target: INFO new=1000 total=2000 needs=1765
2026-01-12T20:23:24.918455Z generate-time-gan: INFO Model found in cache
2026-01-12T20:23:24.948678Z generate-time-gan: INFO Model version: 3
2026-01-12T20:23:24.960517Z generate-time-gan

In [15]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,age,amount,fraud,timestamp,merchant,category,session_key
0,4,3037.43,1,2023-01-19 23:59:18,M1535107174,wellnessandbeauty,0.0
1,4,3124.83,0,2023-01-25 10:29:16,M1873032707,hotelservices,0.0
2,4,3369.08,0,2023-01-05 07:24:15,M1313686961,contents,1.0
3,4,3473.42,1,2023-01-15 20:05:50,M3697346,leisure,1.0
4,3,1260.59,1,2023-01-03 08:38:33,M1913465890,health,2.0
...,...,...,...,...,...,...,...
8668,3,3683.71,0,2024-01-09 18:30:36,M3697346,leisure,762.0
8669,5,1089.30,0,2023-01-08 18:24:46,M480139044,health,763.0
8670,5,1048.54,1,2023-01-18 07:55:45,M3697346,leisure,763.0
8671,3,832.63,0,2023-01-03 13:55:39,M547558035,fashion,764.0


### Evaluate Synthetic Dataset


Check if synthetic dataset has valid merchant-category pairs:


In [16]:
# Convert the synthetic dataset to a pandas DataFrame for the checks below.
syn_df = syn.to_pandas()

In [17]:
# Verify that every (merchant, category) pair in the synthetic data also
# occurs in the training data.
for mer, cat in zip(syn_df["merchant"], syn_df["category"]):
    # Default to an empty tuple so that a merchant absent from the training
    # data fails the assertion instead of raising TypeError on `cat in None`.
    assert cat in merchant_to_category.get(mer, ()), (
        f"invalid merchant-category pair in synthetic data: ({mer!r}, {cat!r})"
    )