In [1]:
from msas_pytorch import msas
import torch

In [2]:
# Statistics callables handed to msas as its third positional argument in the
# cells below.

statistics_functions = [
    torch.mean,
    torch.std,
    torch.median,
    torch.max,
    torch.min,
]

In [3]:
# Simulate the "real" and "synthetic" datasets with random numbers.
#
# Both static and temporal data are generated, each in one variant that mixes
# discrete and continuous features and one variant with continuous features only:
#   * sample size = 100
#   * number of static features = 5, number of temporal features = 5
#   * number of time steps = 10
#   * discrete feature 1 has 10 categories, discrete feature 4 has 3 categories
#
# Every draw comes from one explicitly seeded generator, so re-running this cell
# yields the same tensors each time without touching torch's global RNG state.

NUM_SAMPLES = 100
NUM_TIME_STEPS = 10
NUM_FEATURES = 5
DISCRETE_FEATURE_CATEGORIES = {1: 10, 4: 3}  # feature index -> number of categories
SEED = 0

STATIC_SHAPE = (NUM_SAMPLES, NUM_FEATURES)
TEMPORAL_SHAPE = (NUM_SAMPLES, NUM_TIME_STEPS, NUM_FEATURES)


def simulate_features(shape, generator, discrete_categories=None):
    """Draw a random float tensor of ``shape`` whose last axis indexes features.

    Args:
        shape: Size of the tensor to create; the last entry is the number of
            features.
        generator: ``torch.Generator`` that all random draws are taken from.
        discrete_categories: Optional mapping ``feature index -> n``. Each
            listed feature is overwritten with integer category codes drawn
            uniformly from ``[0, n)``. ``None`` keeps every feature continuous.

    Returns:
        A tensor of standard-normal noise in which the listed features hold
        integer-valued category codes stored in the tensor's float dtype.
    """
    data = torch.randn(shape, generator=generator)
    for index, num_categories in (discrete_categories or {}).items():
        codes = torch.randint(
            0, num_categories, data.shape[:-1], generator=generator
        )
        # Cast explicitly: the codes are int64 while the container is float.
        data[..., index] = codes.to(data.dtype)
    return data


# Re-created on every run of the cell, so the cell is idempotent.
data_rng = torch.Generator().manual_seed(SEED)

# "Real" data.
static_with_discrete = simulate_features(
    STATIC_SHAPE, data_rng, DISCRETE_FEATURE_CATEGORIES
)
static_without_discrete = simulate_features(STATIC_SHAPE, data_rng)
temporal_with_discrete = simulate_features(
    TEMPORAL_SHAPE, data_rng, DISCRETE_FEATURE_CATEGORIES
)
temporal_without_discrete = simulate_features(TEMPORAL_SHAPE, data_rng)

# "Synthetic" data: independent draws with the same layout as the real data.
synthetic_static_with_discrete = simulate_features(
    STATIC_SHAPE, data_rng, DISCRETE_FEATURE_CATEGORIES
)
synthetic_static_without_discrete = simulate_features(STATIC_SHAPE, data_rng)
synthetic_temporal_with_discrete = simulate_features(
    TEMPORAL_SHAPE, data_rng, DISCRETE_FEATURE_CATEGORIES
)
synthetic_temporal_without_discrete = simulate_features(TEMPORAL_SHAPE, data_rng)

In [4]:
# Temporal data only, continuous features only: call msas once with mean
# reduction, once with sum reduction and once with no reduction.
reduction_cases = (
    ("mean reduction:", "mean"),
    ("sum reduction:", "sum"),
    ("no reduction:", None),
)

for label, reduction in reduction_cases:
    score = msas(
        temporal_without_discrete,
        synthetic_temporal_without_discrete,
        statistics_functions,
        reduction=reduction,
    )
    print(label, score)

mean reduction: tensor(0.8716)
sum reduction: tensor(21.7900)
no reduction: (tensor([[0.8500, 0.8700, 0.8600, 0.8500, 0.8600],
        [0.9000, 0.8800, 0.8800, 0.8800, 0.8600],
        [0.8400, 0.8500, 0.9300, 0.9300, 0.8500],
        [0.9200, 0.8800, 0.8500, 0.8100, 0.7800],
        [0.9200, 0.8900, 0.9000, 0.8800, 0.8700]]), None, None)


In [5]:
# Temporal data only, with discrete temporal features 1 and 4 (10 and 3
# categories): call msas once with mean reduction, once with sum reduction and
# once with no reduction.
reduction_cases = (
    ("mean reduction:", "mean"),
    ("sum reduction:", "sum"),
    ("no reduction:", None),
)

for label, reduction in reduction_cases:
    # Fresh index/category tensors are built for every call.
    temporal_indices = torch.LongTensor([1, 4])
    temporal_num_categories = torch.LongTensor([10, 3])
    score = msas(
        temporal_with_discrete,
        synthetic_temporal_with_discrete,
        statistics_functions,
        discrete_temporal_features_indices=temporal_indices,
        discrete_temporal_features_num_categories=temporal_num_categories,
        reduction=reduction,
    )
    print(label, score)

mean reduction: tensor(0.6623)
sum reduction: tensor(13.5695)
no reduction: (tensor([[0.8700, 0.9000, 0.9200],
        [0.9000, 0.9100, 0.8700],
        [0.9000, 0.8400, 0.9000],
        [0.9100, 0.9300, 0.9100],
        [0.8700, 0.9000, 0.8700]]), tensor([0.0692, 0.1004]), None)


In [6]:
# Static and temporal data, continuous features only: call msas once with mean
# reduction, once with sum reduction and once with no reduction.
reduction_cases = (
    ("mean reduction:", "mean"),
    ("sum reduction:", "sum"),
    ("no reduction:", None),
)

for label, reduction in reduction_cases:
    score = msas(
        temporal_without_discrete,
        synthetic_temporal_without_discrete,
        statistics_functions,
        real_static_data=static_without_discrete,
        synthetic_static_data=synthetic_static_without_discrete,
        reduction=reduction,
    )
    print(label, score)

mean reduction: tensor(0.8858)
sum reduction: tensor(13.1450)
no reduction: (tensor([[0.8500, 0.8700, 0.8600, 0.8500, 0.8600],
        [0.9000, 0.8800, 0.8800, 0.8800, 0.8600],
        [0.8400, 0.8500, 0.9300, 0.9300, 0.8500],
        [0.9200, 0.8800, 0.8500, 0.8100, 0.7800],
        [0.9200, 0.8900, 0.9000, 0.8800, 0.8700]]), None, tensor([0.9400, 0.9100, 0.8400, 0.8800, 0.9300]))


In [7]:
# Static and temporal data, both with discrete features 1 and 4 (10 and 3
# categories): call msas once with mean reduction, once with sum reduction and
# once with no reduction.
reduction_cases = (
    ("mean reduction:", "mean"),
    ("sum reduction:", "sum"),
    ("no reduction:", None),
)

for label, reduction in reduction_cases:
    # Fresh index/category tensors are built for every call, separately for
    # the temporal and the static arguments.
    temporal_indices = torch.LongTensor([1, 4])
    temporal_num_categories = torch.LongTensor([10, 3])
    static_indices = torch.LongTensor([1, 4])
    static_num_categories = torch.LongTensor([10, 3])
    score = msas(
        temporal_with_discrete,
        synthetic_temporal_with_discrete,
        statistics_functions,
        discrete_temporal_features_indices=temporal_indices,
        discrete_temporal_features_num_categories=temporal_num_categories,
        real_static_data=static_with_discrete,
        synthetic_static_data=synthetic_static_with_discrete,
        discrete_static_features_indices=static_indices,
        discrete_static_features_num_categories=static_num_categories,
        reduction=reduction,
    )
    print(label, score)

mean reduction: tensor(0.6196)
sum reduction: tensor(8.2269)
no reduction: (tensor([[0.8700, 0.9000, 0.9200],
        [0.9000, 0.9100, 0.8700],
        [0.9000, 0.8400, 0.9000],
        [0.9100, 0.9300, 0.9100],
        [0.8700, 0.9000, 0.8700]]), tensor([0.0692, 0.1004]), tensor([0.9300, 0.1017, 0.9400, 0.8500, 0.0626]))
