In [1]:
import torch

from EventStream.transformer import CondIndepModelForGenerativeSequenceModeling
from EventStream.transformer.config import StructuredTransformerConfig, StructuredEventProcessingMode
from EventStream.data.types import DataModality, PytorchBatch, TemporalityType
from EventStream.data.config import MeasurementConfig

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fixture vocabulary for a three-measurement toy dataset.

# Which measurements are produced under each generative mode.
TEST_DATA_TYPES_PER_GEN_MODE = dict(
    single_label_classification=["event_type"],
    multi_label_classification=["multi_label_col", "regression_col"],
    multivariate_regression=["regression_col"],
)

# Integer id assigned to each measurement.
TEST_DATA_TYPES_IDXMAP = dict(event_type=1, multi_label_col=2, regression_col=3)

# Number of vocabulary entries owned by each measurement.
TEST_VOCAB_SIZES_BY_DATA_TYPE = dict(event_type=2, multi_label_col=3, regression_col=4)

# First vocabulary index of each measurement; each offset equals the previous
# offset plus the previous measurement's vocabulary size (1, 1 + 2, 3 + 3).
TEST_VOCAB_OFFSETS_BY_DATA_TYPE = dict(event_type=1, multi_label_col=3, regression_col=6)

# Measurements grouped by dependency-graph level; the first level is empty.
TEST_MEASUREMENTS_PER_DEP_GRAPH_LEVEL = [
    [],
    ["event_type"],
    ["multi_label_col", "regression_col"],
]


# Baseline keyword arguments for StructuredTransformerConfig; individual cells
# layer their own overrides on top of these when building a config.
default_config_kwargs = {
    # Dependency-graph / attention-shape options are left unset.
    "dep_graph_attention_types": None,
    "dep_graph_window_size": None,
    "do_full_block_in_dep_graph_attention": None,
    "do_full_block_in_seq_attention": None,
    # Toy vocabulary description shared with the fixture batch.
    "measurements_per_generative_mode": TEST_DATA_TYPES_PER_GEN_MODE,
    "vocab_sizes_by_measurement": TEST_VOCAB_SIZES_BY_DATA_TYPE,
    "vocab_offsets_by_measurement": TEST_VOCAB_OFFSETS_BY_DATA_TYPE,
    "measurements_idxmap": TEST_DATA_TYPES_IDXMAP,
    "vocab_size": 10,
    # Deliberately tiny model dimensions.
    "hidden_size": 4,
    "num_hidden_layers": 5,
    "head_dim": None,
    "num_attention_heads": 2,  # Needs to divide hidden_size.
    "mean_log_inter_time": 0,
    "std_log_inter_time": 1,
    "use_cache": False,
    "measurements_per_dep_graph_level": None,
    # Per-measurement metadata for the two dynamic measurements.
    "measurement_configs": {
        "multi_label_col": MeasurementConfig(
            modality=DataModality.MULTI_LABEL_CLASSIFICATION,
            temporality=TemporalityType.DYNAMIC,
        ),
        "regression_col": MeasurementConfig(
            modality=DataModality.MULTIVARIATE_REGRESSION,
            temporality=TemporalityType.DYNAMIC,
            values_column="regression_val",
        ),
    },
}

In [7]:
# Hand-written fixture batch: 2 subjects x 4 events x 6 data elements per event.
# Tensors are built with `torch.tensor(..., dtype=...)` rather than the legacy
# typed constructors, and the random retrieval states come from a privately
# seeded generator so a fresh "Restart & Run All" reproduces the same inputs
# without touching the global RNG state.
BASE_BATCH = {
    "event_mask": torch.tensor(
        [[True, True, True, True], [False, True, True, True]], dtype=torch.bool
    ),
    "time_delta": torch.tensor([[0, 2, 5, 1], [0, 3, 2, 1]], dtype=torch.float32),
    "start_time": torch.tensor([1.0, 1412.0], dtype=torch.float32),
    "static_indices": torch.tensor([[1, 2, 3], [1, 3, 0]], dtype=torch.long),
    "static_measurement_indices": torch.tensor([[1, 2, 3], [1, 3, 0]], dtype=torch.long),
    # True exactly at the positions whose measurement index below is 3.
    "dynamic_values_mask": torch.tensor(
        [
            [
                [False, False, False, False, False, False],
                [False, False, False, False, False, False],
                [False, False, False, True, True, True],
                [False, False, False, False, True, True],
            ],
            [
                [False, False, False, False, False, False],
                [False, False, False, False, False, False],
                [False, False, False, False, False, True],
                [False, False, False, False, True, True],
            ],
        ],
        dtype=torch.bool,
    ),
    # Measurement id of each element (see TEST_DATA_TYPES_IDXMAP); 0 fills unused slots.
    "dynamic_measurement_indices": torch.tensor(
        [
            [
                [1, 0, 0, 0, 0, 0],
                [1, 2, 0, 0, 0, 0],
                [1, 2, 2, 3, 3, 3],
                [1, 2, 2, 2, 3, 3],
            ],
            [
                [1, 0, 0, 0, 0, 0],
                [1, 2, 0, 0, 0, 0],
                [1, 2, 2, 2, 2, 3],
                [1, 2, 2, 2, 3, 3],
            ],
        ],
        dtype=torch.long,
    ),
    # Vocabulary index of each element; 0 fills unused slots.
    "dynamic_indices": torch.tensor(
        [
            [
                [1, 0, 0, 0, 0, 0],
                [2, 5, 0, 0, 0, 0],
                [2, 4, 5, 7, 8, 9],
                [2, 4, 5, 5, 8, 9],
            ],
            [
                [1, 0, 0, 0, 0, 0],
                [2, 5, 0, 0, 0, 0],
                [2, 4, 5, 4, 4, 9],
                [2, 4, 5, 5, 8, 9],
            ],
        ],
        dtype=torch.long,
    ),
    # Numeric values; non-zero entries appear only where dynamic_values_mask is True.
    "dynamic_values": torch.tensor(
        [
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1.1, -1.1, 0.0],
                [0, 0, 0, 0, -3.1, 0.2],
            ],
            [
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1.4],
                [0, 0, 0, 0, -3.0, 1.2],
            ],
        ],
        dtype=torch.float32,
    ),
    # Intended layout: batch_size, n_chunks, n_neighbors, neighbor_len, hidden.
    # Presumably n_chunks * chunk_len equals the model sequence length -- TODO confirm.
    # NOTE: the key spelling ("retreived") is kept as-is because this dict is
    # unpacked into PytorchBatch(**BASE_BATCH) as keyword arguments.
    "retreived_hidden_states": torch.randn(
        2, 2, 2, 8, 4, generator=torch.Generator().manual_seed(0)
    ),
}

In [23]:
# Retrieval-augmented, conditionally independent configuration: the three
# cell-specific settings are spelled out first and the shared defaults are
# unpacked after them.
config = StructuredTransformerConfig(
    structured_event_processing_mode=StructuredEventProcessingMode.CONDITIONALLY_INDEPENDENT,
    retreival_augmented=True,
    chunked_cross_attention_chunk_len=2,
    **default_config_kwargs,
)

# Sanity check: the config must expose a non-None retrieval layer index.
assert config.retreival_layer_idx is not None

model = CondIndepModelForGenerativeSequenceModeling(config)

In [24]:
# Wrap the raw fixture tensors in a PytorchBatch; every key of BASE_BATCH is
# passed as a keyword argument.
batch = PytorchBatch(**BASE_BATCH)

# Run the first stage of the split forward pass. The result is handed to
# `model.second_half_forward` in the following cell.
first_half_output = model.first_half_forward(batch)

In [25]:
# Finish the split forward pass by giving the model both the original batch
# and the output of `first_half_forward`.
# NOTE(review): presumably this is where the batch's retrieved hidden states
# are consumed -- confirm against the model implementation.
second_half_output = model.second_half_forward(
    batch=batch,
    first_half_output=first_half_output,
)

  attn = F.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)


In [30]:
# Bare expression: the notebook shows the module's repr (its layer tree) as the cell output.
model

CondIndepModelForGenerativeSequenceModeling(
  (encoder): ConditionallyIndependentPointProcessTransformer(
    (input_layer): ConditionallyIndependentPointProcessInputLayer(
      (data_embedding_layer): DataEmbeddingLayer(
        (embed_layer): EmbeddingBag(10, 4, mode='sum', padding_idx=0)
      )
      (time_embedding_layer): TemporalPositionEncoding()
      (embedding_dropout): Dropout(p=0.1, inplace=False)
    )
    (h): ModuleList(
      (0-4): 5 x InnerBlock(
        (attn): InnerAttention(
          (attention): InnerSelfAttention(
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
            (k_proj): Linear(in_features=4, out_features=4, bias=False)
            (v_proj): Linear(in_features=4, out_features=4, bias=False)
            (q_proj): Linear(in_features=4, out_features=4, bias=False)
            (out_proj): Linear(in_features=4, out_features=4, bias=True)
          )
          (layer_norm): LayerNorm((