In [None]:
from f2ai.common.collecy_fn import classify_collet_fn
from f2ai.models.sequential import SimpleClassify
import torch
from torch import nn
from torch.utils.data import DataLoader
from f2ai.featurestore import FeatureStore
from f2ai.dataset import GroupFixednbrSampler



if __name__ == "__main__":
    # Train a SimpleClassify model on the `credit_scoring_v1` service of a local
    # f2ai feature-store project, printing the mean training loss of each epoch.
    from functools import partial

    PROJECT_URL = "file:///Users/xuyizhou/Desktop/xyz_warehouse/gitlab/f2ai-credit-scoring"
    SERVICE_NAME = "credit_scoring_v1"
    START, END = "2020-08-01", "2021-09-30"  # one window for both sampling and statistics
    NUM_EPOCHS = 10

    fs = FeatureStore(PROJECT_URL)
    service = fs.services[SERVICE_NAME]

    ds = fs.get_dataset(
        service=SERVICE_NAME,
        sampler=GroupFixednbrSampler(
            time_bucket="10 days",
            stride=1,
            group_ids=None,
            group_names=None,
            start=START,
            end=END,
        ),
    )

    # Categorical features: every feature the service uses, minus those returned by
    # `_get_feature_to_use(service, True)`. That call is hoisted out of the
    # comprehension so it runs once instead of once per feature.
    non_categorical = fs._get_feature_to_use(service, True)
    features_cat = [fea for fea in fs._get_feature_to_use(service) if fea not in non_categorical]

    # Unique values per categorical feature (passed to the collate fn as `cat_coder`)
    # and the number of distinct values for each one.
    cat_unique = fs.stats(
        service,
        fn="unique",
        group_key=[],
        start=START,
        end=END,
        features=features_cat,
    ).to_dict()
    cat_count = {key: len(values) for key, values in cat_unique.items()}

    # Per-feature [min, max] pairs, passed to the collate fn as `cont_scalar`.
    # NOTE(review): these two stats calls pass no `features=`, so they may also cover
    # categorical columns / labels, which would inflate `cont_nbr` below -- confirm.
    cont_scalar_max = fs.stats(service, fn="max", group_key=[], start=START, end=END).to_dict()
    cont_scalar_min = fs.stats(service, fn="min", group_key=[], start=START, end=END).to_dict()
    cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min}

    labels = fs._get_available_labels(service)  # computed once, not once per batch

    i_ds = ds.to_pytorch()
    # `i_ds` is consumed as an iterator (presumably an IterableDataset), so no
    # sampler is given (`sampler=None`). `batch_size` and `drop_last` are spelled
    # out explicitly. `partial` behaves like the equivalent lambda but, unlike a
    # lambda, can be pickled (given picklable arguments) if `num_workers > 0` is used.
    train_data_loader = DataLoader(
        i_ds,
        collate_fn=partial(
            classify_collet_fn,
            cat_coder=cat_unique,
            cont_scalar=cont_scalar,
            label=labels,
        ),
        batch_size=4,
        drop_last=False,
        sampler=None,
    )

    model = SimpleClassify(
        cont_nbr=len(cont_scalar_max), cat_nbr=len(cat_count), emd_dim=4, max_types=max(cat_count.values())
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # fixed learning rate
    # BCELoss expects predictions in [0, 1] and a float target shaped like the
    # prediction -- assumes SimpleClassify ends in a sigmoid; TODO confirm.
    loss_fn = nn.BCELoss()

    for epoch in range(NUM_EPOCHS):
        print(f"epoch: {epoch} begin")
        total_loss, n_batches = 0.0, 0
        for x, y in train_data_loader:
            pred_label = model(x)
            loss = loss_fn(pred_label, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # `.item()` yields a plain float, so the running sum holds no autograd graph.
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # With an empty epoch (e.g. if the dataset iterator cannot be restarted)
            # there is no loss to report; say so instead of reusing a stale value.
            print(f"epoch: {epoch} done, no batches")
        else:
            print(f"epoch: {epoch} done, mean loss: {total_loss / n_batches:.6f}")
