# Bottleneck Transformer Tutorial

This notebook demonstrates how to use the `BottleneckTransformer` model for multimodal healthcare data fusion in PyHealth.

**Overview:**
- Initialize BottleneckTransformer with multi-modality data
- Demonstrate modality-specific pre-fusion vs multimodal bottleneck fusion
- Highlight architecture hyperparameters `bottlenecks_n` and `fusion_startidx`
- Inspect forward passes and probability mappings

## 1. Environment Setup

In [None]:
import torch
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

## 2. Data Preparation
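We use PyHealth's `create_sample_dataset` to build a lightweight multimodal dataset from in-memory samples. Each sample carries three modalities: `conditions` and `procedures` are sequences of codes, and `labs` is a numeric vector. For real-world data, you can replace this with `MIMIC3Dataset`, `MIMIC4Dataset`, or `OMOPDataset`.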

In [None]:
from pyhealth.datasets import create_sample_dataset

samples = [
    {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "conditions": ["A", "B", "C"],
        "procedures": ["X", "Y"],
        "labs": [1.0, 2.0, 3.0],
        "label": 1,
    },
    {
        "patient_id": "patient-1",
        "visit_id": "visit-0",
        "conditions": ["D", "E"],
        "procedures": ["Y"],
        "labs": [4.0, 5.0, 6.0],
        "label": 0,
    },
]

input_schema = {
    "conditions": "sequence",
    "procedures": "sequence",
    "labs": "tensor",
}
output_schema = {"label": "binary"}

dataset = create_sample_dataset(
    samples=samples,
    input_schema=input_schema,
    output_schema=output_schema,
    dataset_name="test",
)

## 3. Dataloader Setup
We use PyHealth's `get_dataloader` utility, which collates the processed fields of each sample into batches.

In [None]:
from pyhealth.datasets import get_dataloader

train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
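
To make the batching step concrete, the cell below is a simplified sketch of dictionary-style collation: a list of per-sample dicts becomes one dict of per-field lists. The helper `collate_dicts` is hypothetical and written only for illustration. It is not PyHealth's actual collate function, which may also pad sequences and stack tensors.

```python
from typing import Any, Dict, List


def collate_dicts(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Turn a list of per-sample dicts into one dict of per-field lists.

    Hypothetical helper for illustration only; not part of PyHealth.
    """
    return {key: [sample[key] for sample in batch] for key in batch[0]}


# Two toy samples that mirror a subset of the fields defined above.
toy_batch = collate_dicts(
    [
        {"patient_id": "patient-0", "labs": [1.0, 2.0, 3.0], "label": 1},
        {"patient_id": "patient-1", "labs": [4.0, 5.0, 6.0], "label": 0},
    ]
)

print(toy_batch["patient_id"])
print(toy_batch["label"])
```

Because every key maps to a batch-sized collection, a batch laid out this way can be unpacked into a model call with `model(**batch)`.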

## 4. Initialize Bottleneck Transformer
The model builds a separate transformer path for each modality and restricts cross-modal attention so that information is exchanged only through a small set of shared bottleneck tokens.

- `fusion_startidx` sets the layer at which cross-modal attention through the bottleneck tokens begins. Lower values mean earlier fusion.
- `bottlenecks_n` sets the number of bottleneck tokens, and therefore the capacity of the cross-modal channel.
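
The interplay of the two hyperparameters can be sketched in a few lines of PyTorch. This is a minimal illustration of the general bottleneck-fusion idea, not the PyHealth implementation: the class `BottleneckFusionSketch` is hypothetical, and it assumes 0-indexed layers, a learned bottleneck shared by all modalities, and averaging of the per-modality bottleneck updates.

```python
import torch
import torch.nn as nn


class BottleneckFusionSketch(nn.Module):
    """Hypothetical sketch of bottleneck fusion; not PyHealth's model."""

    def __init__(self, dim=32, heads=4, num_layers=3,
                 bottlenecks_n=4, fusion_startidx=1, num_modalities=2):
        super().__init__()
        self.fusion_startidx = fusion_startidx
        # Learned bottleneck tokens shared by all modalities (assumption).
        self.bottleneck = nn.Parameter(torch.zeros(1, bottlenecks_n, dim))
        # One encoder layer per modality at every depth.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.TransformerEncoderLayer(
                    dim, heads, dim_feedforward=2 * dim, batch_first=True
                )
                for _ in range(num_modalities)
            ])
            for _ in range(num_layers)
        ])

    def forward(self, tokens):
        # tokens: list of (batch, seq_len_m, dim) tensors, one per modality.
        b = self.bottleneck.expand(tokens[0].size(0), -1, -1)
        for i, per_modality in enumerate(self.layers):
            if i < self.fusion_startidx:
                # Before fusion starts, each modality attends only to itself.
                tokens = [layer(x) for layer, x in zip(per_modality, tokens)]
                continue
            outs, updates = [], []
            for layer, x in zip(per_modality, tokens):
                # Each modality sees its own tokens plus the bottleneck tokens.
                y = layer(torch.cat([x, b], dim=1))
                outs.append(y[:, : x.size(1)])
                updates.append(y[:, x.size(1):])
            tokens = outs
            # Merge the per-modality bottleneck updates by averaging.
            b = torch.stack(updates).mean(dim=0)
        return tokens, b


sketch = BottleneckFusionSketch().eval()
cond = torch.randn(2, 3, 32)  # e.g. 3 condition tokens per patient
proc = torch.randn(2, 2, 32)  # e.g. 2 procedure tokens per patient
with torch.no_grad():
    outs, b = sketch([cond, proc])
print([tuple(o.shape) for o in outs], tuple(b.shape))
```

In this sketch the modalities never attend to each other directly. From layer `fusion_startidx` onward, anything they share has to pass through the `bottlenecks_n` tokens.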

In [None]:
from pyhealth.models import BottleneckTransformer

model = BottleneckTransformer(
    dataset=dataset,
    embedding_dim=128,
    bottlenecks_n=4,
    fusion_startidx=1,
    num_layers=3,
    heads=4
).to(device)

print("Model modalities:", model.feature_keys)
print(model)
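
For the configuration above (`num_layers=3`, `fusion_startidx=1`), the short calculation below shows which layers would fuse and why a bottleneck is cheaper than full cross-modal attention. It is a back-of-the-envelope sketch: both helper functions are hypothetical, layers are assumed to be 0-indexed, and the token counts are made up for illustration.

```python
def fusion_schedule(num_layers, fusion_startidx):
    """Label each layer, assuming 0-indexed layers that fuse from fusion_startidx on."""
    return [
        "fused" if i >= fusion_startidx else "unimodal"
        for i in range(num_layers)
    ]


def attention_pairs(token_counts, bottlenecks_n):
    """Query-key pairs per layer: full joint attention vs. bottleneck fusion."""
    # Full fusion: every token attends to every token of every modality.
    full = sum(token_counts) ** 2
    # Bottleneck fusion: each modality attends to itself plus the bottleneck.
    bottleneck = sum((n + bottlenecks_n) ** 2 for n in token_counts)
    return full, bottleneck


schedule = fusion_schedule(num_layers=3, fusion_startidx=1)
# Hypothetical sequence lengths for three modalities.
full, bottleneck = attention_pairs([50, 30, 20], bottlenecks_n=4)

print(schedule)
print(f"full: {full} pairs, bottleneck: {bottleneck} pairs")
```

The gap widens as sequences grow, because the full-attention cost is quadratic in the total token count while the bottleneck cost is quadratic only within each modality.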

## 5. Forward Pass
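Run one batch through the model and inspect the outputs. PyHealth models typically return a dict containing the training `loss`, the predicted probabilities `y_prob`, and the raw pre-activation scores `logit`, usually alongside the ground-truth labels `y_true`.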

In [None]:
data_batch = next(iter(train_loader))
outputs = model(**data_batch)

for k, v in outputs.items():
    try:
        print(f"{k}: {v.shape}")
    except AttributeError:
        print(f"{k}: {v}")

print("\nForward pass successful!")
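
For a binary label, the probabilities in `y_prob` are conventionally the sigmoid of `logit`. The sketch below illustrates that mapping on made-up logit values using only the standard library. It assumes `BottleneckTransformer` follows this common convention, and the 0.5 decision threshold is likewise an illustrative choice.

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Made-up logits for three patients (illustration only).
logits = [-2.0, 0.0, 2.0]
probs = [sigmoid(z) for z in logits]
# Threshold at 0.5 to get hard class predictions.
preds = [int(p >= 0.5) for p in probs]

for z, p, y in zip(logits, probs, preds):
    print(f"logit={z:+.1f}  prob={p:.3f}  pred={y}")
```

A logit of zero corresponds to a probability of exactly 0.5, so thresholding the probability at 0.5 is equivalent to thresholding the logit at zero.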