DUMMY DATA

In [109]:
import torch
import torch.nn as nn

batch_size = 5
feature_dim = 37
length_time = 50

data = torch.randn(batch_size, feature_dim, length_time)
time = torch.randint(0, 500, (1, length_time))
data_1 = data[0]


static = torch.randn(batch_size, 8)
static_1 = static[0]
    
sensor_mask = torch.zeros(batch_size, feature_dim, length_time)

sensor_mask[0, :10] = 1
sensor_mask[1, :20] = 1

SENSOR EMBEDDING LAYER

In [110]:


class SensorEmbeddingLayer(nn.Module):
    """Embedding layer for sensor features."""

    def __init__(self, num_sensor: int, time_length: int, dim_embedding: int):
        super(SensorEmbeddingLayer, self).__init__()
        self.num_sensor = num_sensor
        self.embedding_size = time_length
        self.dim_embedding = dim_embedding

        # Define the embedding layer
        self.sensor_embedding = nn.Linear(time_length, dim_embedding)

    def forward(self, sensor_matrix: torch.Tensor) -> torch.Tensor:
        """
        Apply sensor embedding to the input sensor matrix.
        
        Parameters:
        sensor_matrix (torch.Tensor): Input tensor of shape (num_sensor, embedding_size)
        
        Returns:
        torch.Tensor: Output tensor of shape (num_sensor, embedding_size, dim_embedding)
        """
        # Apply the embedding layer to each sensor in the matrix
        embedded_matrix = self.sensor_embedding(sensor_matrix)
        
        # Unsqueeze to add the time dimension
        embedded_matrix = embedded_matrix.unsqueeze(1)  # Shape: (num_sensor, 1, dim_embedding)

        # Expand along the time dimension
        embedded_matrix = embedded_matrix.expand(-1, self.embedding_size, -1)  # Shape: (num_sensor, time_length, dim_embedding)

        return embedded_matrix
        
        return embedded_matrix

# Example usage
num_sensor = 37
embedding_size = length_time
dim_embedding = 32

# Create the sensor embedding layer
sensor_embedding_layer = SensorEmbeddingLayer(num_sensor, embedding_size, dim_embedding)


# Get the embedded output
sensor_embedded_output = sensor_embedding_layer(data_1)

print("Embedded Output Shape:", sensor_embedded_output.shape)
print("Embedded Output:", sensor_embedded_output)

Embedded Output Shape: torch.Size([37, 50, 32])
Embedded Output: tensor([[[ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01],
         [ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01],
         [ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01],
         ...,
         [ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01],
         [ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01],
         [ 3.2309e-01, -1.1614e+00, -7.1861e-01,  ...,  3.6157e-01,
          -3.0024e-01, -3.0289e-01]],

        [[ 2.2563e-01, -4.8985e-01, -5.7224e-02,  ..., -3.8700e-01,
           8.5718e-01,  3.2414e-01],
         [ 2.2563e-01, -4.8985e-01, -5.7224e-02,  ..., -3.8700e-01,
           8.5718e-01,  3.2414e-01],
         [ 2.2563e-01, -4.8985e-01, -5.7224e-02,  ..., -3.8700e-01,
           

TIME EMBEDDING LAYER

In [111]:
import math
from typing import Any, Optional

import torch
from torch import nn
from transformers import BigBirdConfig, MambaConfig

class TimeEmbeddingLayer(nn.Module):
    """Embedding layer for time features."""

    def __init__(self, embedding_size: int, is_time_delta: bool = False):
        super().__init__()
        self.embedding_size = embedding_size
        self.is_time_delta = is_time_delta

        self.w = nn.Parameter(torch.empty(1, self.embedding_size))
        self.phi = nn.Parameter(torch.empty(1, self.embedding_size))

        nn.init.xavier_uniform_(self.w)
        nn.init.xavier_uniform_(self.phi)

    def forward(self, time_stamps: torch.Tensor) -> Any:
        """Apply time embedding to the input time stamps."""
        if self.is_time_delta:
            # If the time_stamps represent time deltas, we calculate the deltas.
            # This is equivalent to the difference between consecutive elements.
            time_stamps = torch.cat(
                (time_stamps[:, 0:1] * 0, time_stamps[:, 1:] - time_stamps[:, :-1]),
                dim=-1,
            )
        time_stamps = time_stamps.float()
        time_stamps_expanded = time_stamps.unsqueeze(-1)
        next_input = time_stamps_expanded * self.w + self.phi

        return torch.sin(next_input)
    
# Example usage

time_layer = TimeEmbeddingLayer(dim_embedding, is_time_delta=False)

# Get the embedded output
time_embedded_output = time_layer(time)

print("Time Embedded Output Shape:", time_embedded_output.shape)

Time Embedded Output Shape: torch.Size([1, 50, 32])


STATIC EMBEDDING

In [112]:
class StaticEmbeddings(nn.Module):
    """Embedding layer for static features."""

    def __init__(self, input_dim: int, embedding_dim: int):
        super(StaticEmbeddings, self).__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        # Define the embedding layer
        self.embedding_layer = nn.Linear(input_dim, embedding_dim)

    def forward(self, static_features: torch.Tensor) -> torch.Tensor:
        """
        Apply embedding to the input static features.
        
        Parameters:
        static_features (torch.Tensor): Input tensor of shape (batch_size, input_dim)
        
        Returns:
        torch.Tensor: Output tensor of shape (batch_size, embedding_dim)
        """
        # Apply the embedding layer
        embedded_features = self.embedding_layer(static_features)
        
        return embedded_features

# Example usage
input_dim = static.shape[1]
embedding_dim = 32  # Ensure this matches the desired embedding dimension

# Create the static embedding layer
static_embedding_layer = StaticEmbeddings(input_dim, embedding_dim)

# Get the embedded output
static_embedded_output = static_embedding_layer(static_1)

print("Static Embedded Output Shape:", static_embedded_output.shape)
print("Static Embedded Output:", static_embedded_output)

Static Embedded Output Shape: torch.Size([32])
Static Embedded Output: tensor([ 0.7477,  0.4308,  0.0703, -0.4131, -0.2980,  0.5652, -0.7332,  0.4662,
        -0.5202,  0.1866,  1.1226,  0.2759,  1.1025, -0.5448, -0.7484,  1.3939,
         0.1168,  0.7116,  1.4629, -0.7686, -0.2416, -0.3767,  1.1720,  0.5435,
         0.5205, -1.6172,  0.4486, -0.5138, -0.4031, -0.5404, -0.3838,  1.2161],
       grad_fn=<ViewBackward0>)


TOTAL EMBEDDING

In [None]:
class CombinedEmbeddings(nn.Module):
    """Combined embedding layer for sensor, time, and static features."""

    def __init__(self, num_sensor: int, time_lenght: int, dim_embedding: int, static_input_dim: int, is_time_delta: bool = False):
        super(CombinedEmbeddings, self).__init__()
        
        # Initialize sensor embedding layer
        self.sensor_embedding_layer = SensorEmbeddingLayer(num_sensor, time_lenght, dim_embedding)
        
        # Initialize time embedding layer
        self.time_embedding_layer = TimeEmbeddingLayer(dim_embedding, is_time_delta)
        
        # Initialize static embedding layer
        self.static_embedding_layer = StaticEmbeddings(static_input_dim, dim_embedding)

    def forward(self, sensor_matrix: torch.Tensor, time_stamps: torch.Tensor, static_features: torch.Tensor) -> torch.Tensor:
        """
        Apply combined embeddings to the input sensor matrix, time stamps, and static features.
        
        Parameters:
        sensor_matrix (torch.Tensor): Input tensor of shape (num_sensor, sensor_embedding_size)
        time_stamps (torch.Tensor): Input tensor of shape (1, time_embedding_size)
        static_features (torch.Tensor): Input tensor of shape (batch_size, static_input_dim)
        
        Returns:
        torch.Tensor: Concatenated output tensor of shape (batch_size, combined_embedding_dim)
        """
        # Apply sensor embedding
        sensor_embedded = self.sensor_embedding_layer(sensor_matrix) # output shape (num_sensor, time_lenght, dim_embedding)
        
        # Apply time embedding
        time_embedded = self.time_embedding_layer(time_stamps) # output shape (1, time_lenght, dim_embedding)
        print("Before expand",time_embedded.shape)
        time_embedded = time_embedded.expand(sensor_embedded.shape) # transform to shape (num_sensor, time_lenght, dim_embedding)
        print("After expand",time_embedded.shape)
        
        # Apply static embedding
        static_embedded = self.static_embedding_layer(static_features) # output shape (dim_embedding)
        static_embedded = static_embedded.unsqueeze(0).unsqueeze(0).expand(sensor_embedded.shape) # transform to shape (num_sensor, time_lenght, dim_embedding)
        
        # Combine all embeddings by adding them together
        combined_embedding = sensor_embedded + time_embedded + static_embedded
        
        return combined_embedding

# Example usage
combined_embedding_layer = CombinedEmbeddings(num_sensor, embedding_size, dim_embedding, input_dim, is_time_delta=False) 

# Get the combined embedded output
combined_embedded_output = combined_embedding_layer(data_1, time, static_1)

print("Combined Embedded Output Shape:", combined_embedded_output.shape)
print("Combined Embedded Output:", combined_embedded_output)

Combined Embedded Output Shape: torch.Size([37, 50, 32])
Combined Embedded Output: tensor([[[-1.9350e+00,  7.2320e-02, -3.9955e-01,  ...,  1.2593e+00,
           7.2569e-01,  2.6458e-01],
         [-5.2256e-01,  1.5626e-01, -1.4258e+00,  ...,  4.7073e-01,
           3.8697e-02, -1.5231e+00],
         [-1.9350e+00,  7.2320e-02, -3.9955e-01,  ...,  1.2593e+00,
           7.2569e-01,  2.6458e-01],
         ...,
         [-1.7623e+00,  1.1784e-01, -1.3839e+00,  ..., -9.0421e-02,
           1.0251e-01, -1.0962e+00],
         [-5.0239e-01, -1.6156e+00, -2.8773e-01,  ...,  1.6144e-01,
          -1.1364e+00, -1.8484e-01],
         [-1.4410e+00, -1.8103e+00, -1.2573e+00,  ...,  1.3058e+00,
           2.9971e-01,  2.8766e-01]],

        [[-1.8173e+00,  1.0718e+00, -5.9501e-01,  ...,  1.1334e+00,
           2.3187e+00, -6.2941e-01],
         [-4.0481e-01,  1.1557e+00, -1.6213e+00,  ...,  3.4479e-01,
           1.6317e+00, -2.4171e+00],
         [-1.8173e+00,  1.0718e+00, -5.9501e-01,  ...,  1.133

MAMBA MODEL

In [114]:
from mortality_part_preprocessing import MortalityDataset, PairedDataset, load_pad_separate
from torch.utils.data import DataLoader
import tqdm


train_batch_size = batch_size // 2  # we concatenate 2 batches together

train_collate_fn = PairedDataset.paired_collate_fn_truncate
val_test_collate_fn = MortalityDataset.non_pair_collate_fn_truncate

base_path = './P12data'

base_path_new = f"{base_path}/split_{1}"


train_pair, val_data, test_data = load_pad_separate(
    'physionet2012', base_path_new, 1
)

train_dataloader = DataLoader(train_pair, train_batch_size, shuffle=True, num_workers=16, collate_fn=train_collate_fn, pin_memory=True)
test_dataloader = DataLoader(test_data, batch_size, shuffle=True, num_workers=16, collate_fn=val_test_collate_fn, pin_memory=True)
val_dataloader = DataLoader(val_data, batch_size, shuffle=False, num_workers=16, collate_fn=val_test_collate_fn, pin_memory=True)

iterable_inner_dataloader = iter(train_dataloader) # make the train_dataloader iterable
test_batch = next(iterable_inner_dataloader) # iterate on the next object in a tuple
max_seq_length = test_batch[0].shape[2] # shape[2] = T
sensor_count = test_batch[0].shape[1] # shape[1] = F
static_size = test_batch[2].shape[1] # shape[1] = 8



Loading preprocessed datasets from ./processed_datasets
Loaded dataset from ./processed_datasets/physionet2012_1_pos.h5
Loaded dataset from ./processed_datasets/physionet2012_1_neg.h5
Loaded dataset from ./processed_datasets/physionet2012_1_val.h5
Loaded dataset from ./processed_datasets/physionet2012_1_test.h5


In [115]:
class MambaFoo(nn.Module):
    def __init__(self, num_sensor: int, time_lenght: int, dim_embedding: int, static_input_dim: int, is_time_delta: bool = False):
        super(MambaFoo, self).__init__()
        
        self.embedding = CombinedEmbeddings(
            num_sensor=num_sensor,
            time_lenght=time_lenght,
            dim_embedding=dim_embedding,
            static_input_dim=static_input_dim,
            is_time_delta=is_time_delta
        )

    def forward(self, sensor_matrix: torch.Tensor, time_stamps: torch.Tensor, static_features: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for MambaFoo.
        
        Parameters:
        - sensor_matrix: Input sensor data (batch_size, num_sensor, time_lenght)
        - time_stamps: Input time data (1, time_lenght)
        - static_features: Input static features (batch_size, static_input_dim)
        
        Returns:
        - torch.Tensor: Output embedding of shape (batch_size, num_sensor, time_lenght, dim_embedding)
        """
        # Compute combined embedding
        combined_embedding = self.embedding(sensor_matrix, time_stamps, static_features)
        
        # Optionally, add additional layers or computations here
        
        return combined_embedding


In [116]:
for batch in tqdm.tqdm(train_dataloader, total=len(train_dataloader)):
    data, times, static, labels, mask, delta = batch

    print(data.shape)

    for i in range(data.shape[0]):

        print(f"entering into patient {i}/{data.shape[0]}")
        
        patient = data[i,:,:]
        time_patient = time[i,:]
        static_patient = static[i,:]

        print(patient.shape)

        mambafoo = MambaFoo(
        num_sensor=patient.shape[0], 
        time_lenght=patient.shape[1], 
        dim_embedding=32, 
        static_input_dim=8, 
        is_time_delta=False
        )
        
        # Forward pass with sample inputs
        output = mambafoo(patient, time_patient, static_patient)

        print("MambaFoo Output Shape:", output.shape)
        print("MambaFoo Output:", output)

    "Continue this:"
    break

    




  0%|          | 0/2013 [00:00<?, ?it/s]

  0%|          | 0/2013 [00:00<?, ?it/s]

torch.Size([4, 37, 94])
entering into patient 0/4
torch.Size([37, 94])





RuntimeError: The expanded size of the tensor (94) must match the existing size (50) at non-singleton dimension 1.  Target sizes: [37, 94, 32].  Tensor sizes: [50, 32]