In [1]:
from msas_pytorch import msas
import torch
from scipy.spatial.distance import euclidean

In [2]:
# Summary-statistic functions passed to `msas` in every test cell below.
statistics_functions = [
    torch.mean,
    torch.std,
    torch.median,
    torch.max,
    torch.min,
]

In [3]:
# Simulated "real" and "synthetic" datasets built from random numbers:
# static and temporal, each with and without discrete features.
# - static tensors have shape (N_SAMPLES, N_FEATURES)
# - temporal tensors have shape (N_SAMPLES, N_TIME_STEPS, N_FEATURES)
# - sample size = 1000, 5 static / 5 temporal features, 10 time steps
# - discrete feature 1 has 10 categories, discrete feature 4 has 3 categories;
#   every other feature is a standard-normal continuous value

SEED = 0
N_SAMPLES = 1000
N_TIME_STEPS = 10
N_FEATURES = 5
# feature index -> number of categories, for the discrete features
DISCRETE_FEATURE_CATEGORIES = {1: 10, 4: 3}

# Seed torch's global RNG so Restart & Run All reproduces the same data.
torch.manual_seed(SEED)


def make_static_data(with_discrete: bool) -> torch.Tensor:
    """Return a (N_SAMPLES, N_FEATURES) float tensor of standard-normal values.

    When ``with_discrete`` is true, each column listed in
    DISCRETE_FEATURE_CATEGORIES is overwritten with integer category codes
    drawn uniformly from ``[0, num_categories)`` (stored as floats).
    """
    data = torch.randn(N_SAMPLES, N_FEATURES)
    if with_discrete:
        for index, num_categories in DISCRETE_FEATURE_CATEGORIES.items():
            data[:, index] = torch.randint(0, num_categories, (N_SAMPLES,))
    return data


def make_temporal_data(with_discrete: bool) -> torch.Tensor:
    """Return a (N_SAMPLES, N_TIME_STEPS, N_FEATURES) float tensor of standard-normal values.

    When ``with_discrete`` is true, each feature listed in
    DISCRETE_FEATURE_CATEGORIES is overwritten at every time step with integer
    category codes drawn uniformly from ``[0, num_categories)`` (stored as floats).
    """
    data = torch.randn(N_SAMPLES, N_TIME_STEPS, N_FEATURES)
    if with_discrete:
        for index, num_categories in DISCRETE_FEATURE_CATEGORIES.items():
            data[:, :, index] = torch.randint(
                0, num_categories, (N_SAMPLES, N_TIME_STEPS)
            )
    return data


# Keep this generation order: with a fixed seed, reordering these lines
# changes the values of every tensor generated after the moved one.
static_with_discrete = make_static_data(with_discrete=True)
static_without_discrete = make_static_data(with_discrete=False)
temporal_with_discrete = make_temporal_data(with_discrete=True)
temporal_without_discrete = make_temporal_data(with_discrete=False)

synthetic_static_with_discrete = make_static_data(with_discrete=True)
synthetic_static_without_discrete = make_static_data(with_discrete=False)
synthetic_temporal_with_discrete = make_temporal_data(with_discrete=True)
synthetic_temporal_without_discrete = make_temporal_data(with_discrete=False)

In [4]:
# Temporal data only (no static data, no discrete temporal features),
# evaluated under each reduction mode. A loop replaces three copy-pasted calls
# that differed only in `reduction`; the printed lines are unchanged.
for label, reduction in (("mean", "mean"), ("sum", "sum"), ("no", None)):
    print(
        f"{label} reduction:",
        msas(
            temporal_without_discrete,
            synthetic_temporal_without_discrete,
            statistics_functions,
            reduction=reduction,
        ),
    )

mean reduction: tensor(0.9606)
sum reduction: tensor(24.0140)
no reduction: (tensor([[0.9790, 0.9610, 0.9510, 0.9450, 0.9660],
        [0.9610, 0.9670, 0.9650, 0.9440, 0.9540],
        [0.9670, 0.9620, 0.9620, 0.9320, 0.9550],
        [0.9610, 0.9680, 0.9650, 0.9560, 0.9690],
        [0.9640, 0.9610, 0.9690, 0.9660, 0.9640]]), None, None)


In [5]:
# Temporal data only, now with discrete temporal features. Feature 1 has
# 10 categories and feature 4 has 3, matching how the data was generated.
# The run is repeated under each reduction mode.
discrete_temporal_indices = torch.tensor([1, 4], dtype=torch.long)
discrete_temporal_num_categories = torch.tensor([10, 3], dtype=torch.long)

for label, reduction in (("mean", "mean"), ("sum", "sum"), ("no", None)):
    print(
        f"{label} reduction:",
        msas(
            temporal_with_discrete,
            synthetic_temporal_with_discrete,
            statistics_functions,
            discrete_temporal_features_indices=discrete_temporal_indices,
            discrete_temporal_features_num_categories=discrete_temporal_num_categories,
            reduction=reduction,
        ),
    )

mean reduction: tensor(0.9603)
sum reduction: tensor(16.3217)
no reduction: (tensor([[0.9640, 0.9720, 0.9590],
        [0.9730, 0.9650, 0.9490],
        [0.9610, 0.9410, 0.9750],
        [0.9610, 0.9540, 0.9490],
        [0.9580, 0.9590, 0.9590]]), tensor([0.9604, 0.9623]), None)


In [6]:
# Static and temporal data, neither with discrete features, evaluated under
# each reduction mode.
# NOTE(review): in the recorded output below, the "sum" result (14.4270)
# equals the average of the temporal-only sum reported earlier (24.0140) and
# the sum of the five static per-feature scores (~4.84), not their total
# (~28.85). Confirm this is the intended "sum" semantics of msas.
for label, reduction in (("mean", "mean"), ("sum", "sum"), ("no", None)):
    print(
        f"{label} reduction:",
        msas(
            temporal_without_discrete,
            synthetic_temporal_without_discrete,
            statistics_functions,
            real_static_data=static_without_discrete,
            synthetic_static_data=synthetic_static_without_discrete,
            reduction=reduction,
        ),
    )

mean reduction: tensor(0.9643)
sum reduction: tensor(14.4270)
no reduction: (tensor([[0.9790, 0.9610, 0.9510, 0.9450, 0.9660],
        [0.9610, 0.9670, 0.9650, 0.9440, 0.9540],
        [0.9670, 0.9620, 0.9620, 0.9320, 0.9550],
        [0.9610, 0.9680, 0.9650, 0.9560, 0.9690],
        [0.9640, 0.9610, 0.9690, 0.9660, 0.9640]]), None, tensor([0.9560, 0.9670, 0.9740, 0.9760, 0.9670]))


In [7]:
# Static and temporal data, both with discrete features (index 1 has 10
# categories and index 4 has 3, in both the static and temporal tensors),
# evaluated under each reduction mode.
# NOTE(review): in the recorded output below, the "sum" result (10.5515)
# equals the average of the temporal-with-discrete sum reported earlier
# (16.3217) and the sum of the five static per-feature scores (~4.78), not
# their total. Confirm this is the intended "sum" semantics of msas.
discrete_temporal_indices = torch.tensor([1, 4], dtype=torch.long)
discrete_temporal_num_categories = torch.tensor([10, 3], dtype=torch.long)
discrete_static_indices = torch.tensor([1, 4], dtype=torch.long)
discrete_static_num_categories = torch.tensor([10, 3], dtype=torch.long)

for label, reduction in (("mean", "mean"), ("sum", "sum"), ("no", None)):
    print(
        f"{label} reduction:",
        msas(
            temporal_with_discrete,
            synthetic_temporal_with_discrete,
            statistics_functions,
            discrete_temporal_features_indices=discrete_temporal_indices,
            discrete_temporal_features_num_categories=discrete_temporal_num_categories,
            real_static_data=static_with_discrete,
            synthetic_static_data=synthetic_static_with_discrete,
            discrete_static_features_indices=discrete_static_indices,
            discrete_static_features_num_categories=discrete_static_num_categories,
            reduction=reduction,
        ),
    )

mean reduction: tensor(0.9583)
sum reduction: tensor(10.5515)
no reduction: (tensor([[0.9640, 0.9720, 0.9590],
        [0.9730, 0.9650, 0.9490],
        [0.9610, 0.9410, 0.9750],
        [0.9610, 0.9540, 0.9490],
        [0.9580, 0.9590, 0.9590]]), tensor([0.9604, 0.9623]), tensor([0.9420, 0.9640, 0.9550, 0.9740, 0.9463]))
