In [1]:
%%capture
%pip install -U 'rockfish[labs]' -f 'https://packages.rockfish.ai'

In [2]:
from pathlib import Path
from urllib.request import urlretrieve

import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl

Replace `YOUR_API_KEY` in the cell below with your assigned API key string, without quotes.

For example, if your assigned API key is `abcd1234`, the cell should read:

```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```

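If you do not have an API key, please reach out to support@rockfish.ai. Once you paste your key into the cell, it is saved in the notebook, so avoid committing or sharing the notebook with your real key in it.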


In [3]:
# Set the Rockfish API key for this kernel, then build a connection from the environment.
# NOTE(review): the key is stored in this cell's source, and IPython's `%env NAME=value`
# echoes the assignment in the cell output -- clear/redact both before sharing the notebook.
%env ROCKFISH_API_KEY=YOUR_API_KEY
conn = rf.Connection.from_env()

In [4]:
# Download the example time-series dataset (finance.csv) unless it already exists locally,
# mirroring `wget --no-clobber`. Pure Python avoids depending on `wget` being installed
# (it is not available by default on macOS or Windows).
finance_csv = Path("finance.csv")
if not finance_csv.exists():
    urlretrieve("https://docs.rockfish.ai/tutorials/finance.csv", finance_csv)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0

I0000 00:00:1734042485.677230  376091 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers


100 3363k  100 3363k    0     0  9850k      0 --:--:-- --:--:-- --:--:-- 9835k


In [5]:
# Load finance.csv as a Rockfish dataset (presumably "finance" is the dataset name -- confirm
# against the rf.Dataset.from_csv docs) and display it as a pandas DataFrame.
dataset = rf.Dataset.from_csv("finance", "finance.csv")
dataset.to_pandas()

Unnamed: 0,customer,age,gender,merchant,category,amount,fraud,timestamp
0,C1093826151,4,M,M348934600,transportation,4.55,0,2023-01-01
1,C575345520,2,F,M348934600,transportation,76.67,0,2023-01-01
2,C1787537369,2,M,M1823072687,transportation,48.02,0,2023-01-01
3,C1732307957,5,F,M348934600,transportation,55.06,0,2023-01-01
4,C842799656,1,F,M348934600,transportation,25.62,0,2023-01-01
...,...,...,...,...,...,...,...,...
49995,C1971105040,3,M,M348934600,transportation,67.91,0,2023-01-20
49996,C51444479,3,M,M348934600,transportation,32.27,0,2023-01-20
49997,C1096642744,5,M,M1535107174,wellnessandbeauty,149.70,0,2023-01-20
49998,C1166683343,2,F,M1823072687,transportation,24.78,0,2023-01-20


Collect the valid merchant-category pairs present in the training dataset:


In [6]:
df = dataset.to_pandas()
# Map each merchant to the list of categories it co-occurs with in the training data.
# Categories keep their first-seen order and appear at most once per merchant.
merchant_to_category = {}
for merchant, category in zip(df["merchant"], df["category"]):
    seen_categories = merchant_to_category.setdefault(merchant, [])
    if category not in seen_categories:
        seen_categories.append(category)

These pairs are used at the end of the notebook to confirm that the synthetic dataset also contains only valid merchant-category pairs.


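### Join Dependent Fields

`merchant` and `category` are dependent: a given merchant only appears with certain categories. Joining them into a single `merchant;category` field lets the model treat each pair as one categorical value; the field is split back into separate columns after generation.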


In [7]:
# Combine the dependent `merchant` and `category` columns into one field; the training
# config below refers to the joined field as "merchant;category".
join_fields = ra.JoinFields(fields=["merchant", "category"])

### Train Model


In [8]:
# TimeGAN training configuration.
# - timestamp: the column holding each transaction's time.
# - metadata: attributes attached to a session; `customer` is typed "session", so presumably
#   each customer's transactions form one time series (TODO confirm against Rockfish docs).
# - measurements: per-event values, including the joined "merchant;category" field produced
#   by JoinFields, which is modeled as a single categorical value.
config = ra.TrainTimeGAN.Config(
    encoder=ra.TrainTimeGAN.DatasetConfig(
        timestamp=ra.TrainTimeGAN.TimestampConfig(field="timestamp"),
        metadata=[
            ra.TrainTimeGAN.FieldConfig(field="age", type="categorical"),
            ra.TrainTimeGAN.FieldConfig(field="customer", type="session"),
        ],
        measurements=[
            ra.TrainTimeGAN.FieldConfig(
                field="merchant;category", type="categorical"
            ),
            ra.TrainTimeGAN.FieldConfig(field="amount", type="continuous"),
            ra.TrainTimeGAN.FieldConfig(field="fraud", type="categorical"),
        ],
    ),
    # NOTE(review): epoch_checkpoint_freq, sample_len and batch_size are tutorial values with
    # no stated rationale here -- tune them for your own data.
    doppelganger=ra.TrainTimeGAN.DGConfig(
        epoch=10,
        epoch_checkpoint_freq=5,
        sample_len=2,
        batch_size=1255,
    ),
)
train = ra.TrainTimeGAN(config)

In [9]:
# Build and launch the training workflow: dataset -> join merchant/category -> train TimeGAN.
builder = rf.WorkflowBuilder()
builder.add_path(dataset, join_fields, train)
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1AK9n5JfvUXFQVN2MdSl3j


In [10]:
# Show a notebook progress display for the training workflow; the loop body is empty because
# the display itself is the point (presumably the iterator ends when training completes).
async for progress in workflow.progress().notebook():
    pass

I0000 00:00:1734042492.976953  376091 fork_posix.cc:75] Other threads are currently calling into gRPC, skipping fork() handlers


  0%|          | 0/10 [00:00<?, ?it/s]

### Generate Synthetic Data And Split Dependent Fields


In [11]:
# Fetch the last model produced by the training workflow; it is the input to generation below.
model = await workflow.models().last()
model

Model(id='6f1615e9-b8d8-11ef-a2cc-1633bdde6ce2', labels={'workflow_id': '1AK9n5JfvUXFQVN2MdSl3j'}, create_time=datetime.datetime(2024, 12, 12, 22, 28, 38, tzinfo=datetime.timezone.utc), size_bytes=35405824)

In [12]:
# Reuse the training config for generation, setting how many sessions to synthesize (the
# generation log below reports "Generating 500 sessions..."). Note: this mutates `config`
# in place.
config.doppelganger.sessions = 500
generate = ra.GenerateTimeGAN(config)
# Split the joined "merchant;category" field back into separate columns.
split_field = ra.SplitField(field="merchant;category")
# Save the generated data as a dataset named "synthetic".
save = ra.DatasetSave({"name": "synthetic"})

In [13]:
# Build and launch the generation workflow: model -> generate -> split joined field -> save.
builder = rf.WorkflowBuilder()
builder.add_path(model, generate, split_field, save)
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")

Workflow: 1m7VwZJB12KB5IY3uKJTpM


In [14]:
# Print the generation workflow's log lines as they are streamed back.
async for log in workflow.logs():
    print(log)

2024-12-12T22:29:06Z generate-time-gan: INFO Downloading model with model_id='6f1615e9-b8d8-11ef-a2cc-1633bdde6ce2'...
2024-12-12T22:29:08Z generate-time-gan: INFO Generating 500 sessions...
2024-12-12T22:29:19Z dataset-save: INFO using field 'session_key' to concatenate tables
2024-12-12T22:29:20Z dataset-save: INFO Saved dataset '3jlPXPBPrTkPRpD0qV7nDE' with 4344 rows


In [15]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,timestamp,amount,age,fraud,merchant,category,session_key
0,2023-01-20 00:00:16.741,2336.902944,5,0,M692898500,health,0.0
1,2023-02-09 09:31:54.538,2036.797551,5,1,M1400236507,home,0.0
2,2023-01-19 11:02:51.252,2020.322661,2,0,M50039827,health,1.0
3,2023-01-21 09:09:48.010,1716.888488,2,0,M2011752106,hotelservices,1.0
4,2023-01-23 03:27:00.818,1936.406535,2,0,M50039827,health,1.0
...,...,...,...,...,...,...,...
4339,2024-03-11 20:28:33.309,2934.981299,4,0,M1313686961,contents,499.0
4340,2024-03-31 23:37:34.833,2921.413665,4,0,M348875670,hotelservices,499.0
4341,2024-04-20 05:45:24.237,2935.723120,4,0,M1313686961,contents,499.0
4342,2024-05-10 05:48:29.639,2920.525429,4,0,M547558035,fashion,499.0


### Evaluate Synthetic Dataset


Check that the synthetic dataset contains only valid merchant-category pairs (i.e., pairs that also occur in the training data):


In [16]:
# Convert the downloaded synthetic dataset to pandas for validation.
syn_df = syn.to_pandas()

In [17]:
# Every (merchant, category) pair in the synthetic data must also occur in the training data.
# A merchant never seen in training has no valid categories (default `[]`), so it counts as
# invalid instead of raising TypeError on `cat in None`. Collect all offenders so a failure
# reports how many pairs are bad and shows a few examples.
invalid_pairs = [
    (mer, cat)
    for mer, cat in zip(syn_df["merchant"], syn_df["category"])
    if cat not in merchant_to_category.get(mer, [])
]
assert not invalid_pairs, (
    f"{len(invalid_pairs)} invalid merchant-category pairs, e.g. {invalid_pairs[:5]}"
)