In [1]:
%%capture
# Install (or upgrade, via -U) the Rockfish SDK with its `labs` extra; -f adds
# the Rockfish package index page as an extra place for pip to find packages.
# %%capture above silences pip's output for this cell.
%pip install -U 'rockfish[labs]' -f 'https://docs.rockfish.ai/packages/index.html'

In [2]:
import rockfish as rf
import rockfish.actions as ra
import rockfish.labs as rl

Please replace `YOUR_API_KEY` with your assigned API key string. Note that the key should be entered without quotes.

For example, if the assigned API key is `abcd1234`, you can do the following:

```python
%env ROCKFISH_API_KEY=abcd1234
conn = rf.Connection.from_env()
```

If you do not have an API key, please reach out to support@rockfish.ai.


In [11]:
# Expose the API key to this kernel through the ROCKFISH_API_KEY environment
# variable. Keep comments off the %env line itself so they cannot end up in
# the value.
%env ROCKFISH_API_KEY=YOUR_API_KEY
# Build the connection from the environment rather than passing the key in
# code (from_env() presumably reads the variable set above -- confirm in the
# Rockfish SDK docs).
conn = rf.Connection.from_env()

env: ROCKFISH_API_KEY=YOUR_API_KEY


In [4]:
# Download the example time-series dataset, finance.csv.
# --no-clobber keeps an already-downloaded finance.csv instead of fetching it
# again, so this cell is safe to re-run.
!wget --no-clobber https://docs.rockfish.ai/tutorials/finance.csv

--2024-07-19 01:07:19--  https://docs142.rockfish.ai/tutorials/finance.csv
Resolving docs142.rockfish.ai (docs142.rockfish.ai)... 65.8.161.27, 65.8.161.100, 65.8.161.81, ...
Connecting to docs142.rockfish.ai (docs142.rockfish.ai)|65.8.161.27|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3444556 (3.3M) [text/csv]
Saving to: ‘finance.csv’


2024-07-19 01:07:22 (1.36 MB/s) - ‘finance.csv’ saved [3444556/3444556]



In [5]:
# Load the downloaded CSV into a Rockfish dataset (the first argument,
# "finance", is presumably the dataset name -- confirm in the SDK docs).
dataset = rf.Dataset.from_csv("finance", "finance.csv")
# Last expression of the cell: display the data as a pandas DataFrame.
dataset.to_pandas()

Unnamed: 0,customer,age,gender,merchant,category,amount,fraud,timestamp
0,C1093826151,4,M,M348934600,transportation,4.55,0,2023-01-01
1,C575345520,2,F,M348934600,transportation,76.67,0,2023-01-01
2,C1787537369,2,M,M1823072687,transportation,48.02,0,2023-01-01
3,C1732307957,5,F,M348934600,transportation,55.06,0,2023-01-01
4,C842799656,1,F,M348934600,transportation,25.62,0,2023-01-01
...,...,...,...,...,...,...,...,...
49995,C1971105040,3,M,M348934600,transportation,67.91,0,2023-01-20
49996,C51444479,3,M,M348934600,transportation,32.27,0,2023-01-20
49997,C1096642744,5,M,M1535107174,wellnessandbeauty,149.70,0,2023-01-20
49998,C1166683343,2,F,M1823072687,transportation,24.78,0,2023-01-20


Get the valid merchant-category pairs present in the training dataset:


In [6]:
df = dataset.to_pandas()

# Map every merchant to the list of categories it occurs with in the training
# data. Categories are kept in first-seen order and never duplicated.
merchant_to_category = {}
for merchant_id, category in zip(df["merchant"], df["category"]):
    known_categories = merchant_to_category.setdefault(merchant_id, [])
    if category not in known_categories:
        known_categories.append(category)

These will be used to confirm that the synthetic dataset also has valid merchant-category pairs.


### Join Dependent Fields


In [7]:
# Join the dependent "merchant" and "category" columns into one field so the
# model sees each pair as a single value; the training config and the later
# SplitField step refer to that field as "merchant;category".
join_fields = ra.JoinFields(fields=["merchant", "category"])

### Train Model


In [8]:
# Per-session attributes: "customer" identifies the session, "age" is a
# categorical attribute attached to it.
metadata_fields = [
    ra.TrainTimeGAN.FieldConfig(field="age", type="categorical"),
    ra.TrainTimeGAN.FieldConfig(field="customer", type="session"),
]

# Per-record values. "merchant;category" is the combined field (presumably
# produced by the JoinFields step) and is treated as one categorical value.
measurement_fields = [
    ra.TrainTimeGAN.FieldConfig(field="merchant;category", type="categorical"),
    ra.TrainTimeGAN.FieldConfig(field="amount", type="continuous"),
    ra.TrainTimeGAN.FieldConfig(field="fraud", type="categorical"),
]

# Encoder section: how the dataset columns are interpreted.
encoder_config = ra.TrainTimeGAN.DatasetConfig(
    timestamp=ra.TrainTimeGAN.TimestampConfig(field="timestamp"),
    metadata=metadata_fields,
    measurements=measurement_fields,
)

# Model/training hyperparameters used for this tutorial run.
dg_config = ra.TrainTimeGAN.DGConfig(
    epoch=10,
    epoch_checkpoint_freq=5,
    sample_len=2,
    batch_size=1255,
)

config = ra.TrainTimeGAN.Config(encoder=encoder_config, doppelganger=dg_config)
train = ra.TrainTimeGAN(config)

In [9]:
# Assemble the training workflow as one path: dataset -> join_fields -> train.
builder = rf.WorkflowBuilder()
builder.add_path(dataset, join_fields, train)
# Top-level `await` works here because the notebook kernel runs cells inside
# an event loop.
workflow = await builder.start(conn)
# Print the workflow ID so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: 32aRIW3N7lZxFbQyXIYNXU


In [10]:
# Consume the workflow's progress stream until it ends. The loop body is
# intentionally empty: .notebook() presumably renders the progress display
# itself -- confirm in the SDK docs.
async for progress in workflow.progress().notebook():
    pass

CancelledError: 

### Generate Synthetic Data And Split Dependent Fields


In [None]:
# Take the last model reported by the training workflow. With
# epoch_checkpoint_freq set there may be several; presumably the last is the
# final checkpoint -- TODO confirm.
model = await workflow.models().last()
# Last expression of the cell: show the model reference.
model

Model('b682a9ff-3970-11ef-ad55-fe3a1ae943e1')

In [None]:
# Reuse the training config for generation, asking for 500 sessions.
# NOTE: this mutates `config` in place.
config.doppelganger.sessions = 500
generate = ra.GenerateTimeGAN(config)
# Undo the earlier join by splitting "merchant;category" (the counterpart of
# the JoinFields step).
split_field = ra.SplitField(field="merchant;category")
# Persist the generated data as a dataset named "synthetic".
save = ra.DatasetSave({"name": "synthetic"})

In [None]:
# Assemble the generation workflow: model -> generate -> split_field -> save.
# NOTE: `builder` and `workflow` are rebound here, so from this cell on
# `workflow` refers to the generation run, not the training run.
builder = rf.WorkflowBuilder()
builder.add_path(model, generate, split_field, save)
workflow = await builder.start(conn)
# Print the workflow ID so the run can be identified later.
print(f"Workflow: {workflow.id()}")

Workflow: 7HFFGptQ33tIXyO5ZBZWiI


In [None]:
# Print the generation workflow's log entries as they arrive. The loop ends
# when the log stream is exhausted (presumably when the workflow finishes).
async for log in workflow.logs():
    print(log)

2024-07-03T19:16:20Z generate-time-gan: INFO Downloading model with model_id='b682a9ff-3970-11ef-ad55-fe3a1ae943e1'...
2024-07-03T19:16:25Z generate-time-gan: INFO Generating 500 sessions...
2024-07-03T19:16:26Z dataset-save: INFO Saved dataset '1k7klM7r6sdaAaQaB6zr9m' with 4417 rows


In [None]:
syn = None
async for sds in workflow.datasets():
    syn = await sds.to_local(conn)
syn.to_pandas()

Unnamed: 0,timestamp,amount,age,fraud,session_key,merchant,category
0,2023-01-04 20:30:31.498,882.658654,5,1,0.0,M349281107,fashion
1,2023-01-07 15:41:35.952,497.092635,4,0,1.0,M692898500,health
2,2023-01-10 08:43:39.314,327.908537,4,0,1.0,M677738360,contents
3,2023-01-13 08:51:01.340,471.801285,4,0,1.0,M2122776122,home
4,2023-01-16 03:29:16.891,331.751762,4,0,1.0,M677738360,contents
...,...,...,...,...,...,...,...
4412,2024-09-11 19:50:20.016,2065.117430,2,0,497.0,M1053599405,health
4413,2024-09-27 21:15:21.388,2634.786640,2,0,497.0,M348934600,transportation
4414,2024-10-14 21:30:26.068,2114.590609,2,0,497.0,M97925176,wellnessandbeauty
4415,2023-01-21 23:11:45.710,7246.849439,3,0,498.0,M1198415165,wellnessandbeauty


### Evaluate Synthetic Dataset


Check if synthetic dataset has valid merchant-category pairs:


In [None]:
# Convert the synthetic dataset to a pandas DataFrame for the checks below.
syn_df = syn.to_pandas()

In [None]:
# Every synthetic (merchant, category) pair must also occur in the training
# data. A merchant never seen in training maps to an empty list, so the check
# fails with an AssertionError naming the offending pair rather than a
# TypeError from evaluating `cat in None`.
for mer, cat in zip(syn_df["merchant"], syn_df["category"]):
    assert cat in merchant_to_category.get(mer, []), (
        f"invalid merchant-category pair in synthetic data: ({mer!r}, {cat!r})"
    )