In [2]:
import torch
import numpy as np
from torch import nn

def masked_mean_pooling(datatensor, mask):
    """
    Computes the masked mean pooling of the input tensor along the specified dimension.

    This function calculates the average of the values in the input tensor `datatensor` 
    while ignoring the values at positions where the `mask` tensor is zero. It is 
    particularly useful for handling sequences of varying lengths where padding is 
    applied.

    Args:
        datatensor (torch.Tensor): The input tensor of shape (batch_size, lenght_time, num_features).
        mask (torch.Tensor): A binary mask tensor of shape (batch_size, sequence_length) 
                             where 1 indicates valid data points and 0 indicates padding.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, feature_dim) containing the masked 
                      mean pooled values for each sequence in the batch.
    """
 
 
    # eliminate all values learned from nonexistant timepoints
    mask_expanded = mask.unsqueeze(-1).expand(datatensor.size()).float() # Takes the mask tensor, adds an extra dimension at the end,
    # expands it to match the size of datatensor, and converts it to a floating-point tensor.
    data_summed = torch.sum(datatensor * mask_expanded, dim=1)

    # find out number of existing timepoints
    data_counts = mask_expanded.sum(1)
    data_counts = torch.clamp(data_counts, min=1e-9)  # put on min clamp

    # Calculate average:
    averaged = data_summed / (data_counts)

    return averaged

batch_size = 5
feature_dim = 10
length_time = 50

datatensor = torch.randn(batch_size, length_time, feature_dim)

print(datatensor.size())

mask = torch.zeros(batch_size, length_time)

mask[0, :10] = 1
mask[1, :20] = 1

print(mask.size())

print(masked_mean_pooling(datatensor, mask).size())
print(masked_mean_pooling(datatensor, mask))

torch.Size([5, 50, 10])
torch.Size([5, 50])
torch.Size([5, 10])
tensor([[-0.5688, -0.3338, -0.1661, -0.0249, -0.0462, -0.4219,  0.1485,  0.0539,
          0.3579,  0.0553],
        [-0.0784,  0.2849, -0.1677,  0.2220, -0.1155,  0.1715,  0.1597,  0.2521,
          0.0316,  0.0373],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000]])


In [None]:
class PositionalEncodingTF(nn.Module):
    """
    Based on the SEFT positional encoding implementation
    """

    def __init__(self, d_model, max_len=500):
        super(PositionalEncodingTF, self).__init__()
        self.max_len = max_len
        self.d_model = d_model
        self._num_timescales = d_model // 2

    def getPE(self, P_time):
        B = P_time.shape[1]

        P_time = P_time.float()

        # create a timescale of all times from 0-1
        timescales = self.max_len ** np.linspace(0, 1, self._num_timescales)

        # make a tensor to hold the time embeddings
        times = torch.Tensor(P_time.cpu()).unsqueeze(2)

        # scale the timepoints according to the 0-1 scale
        scaled_time = times / torch.Tensor(timescales[None, None, :])
        # Use a 32-D embedding to represent a single time point
        pe = torch.cat(
            [torch.sin(scaled_time), torch.cos(scaled_time)], axis=-1
        )  # T x B x d_model
        pe = pe.type(torch.FloatTensor)

        return pe

    def forward(self, P_time):
        pe = self.getPE(P_time)
        return pe
    
time = torch.randint(0, 500, (1, 500))
print(time.size())

pos_layer = PositionalEncodingTF(32, 500)
embeddings = pos_layer(time)
print(embeddings.size())

print(embeddings[0, 1, :])

torch.Size([1, 500])
torch.Size([1, 500, 32])
tensor([-0.8218,  0.8080,  0.2857, -0.0298, -0.7708, -0.9994,  0.8192, -0.9973,
        -0.0212,  0.8682,  0.9821,  0.7911,  0.5671,  0.3880,  0.2603,  0.1731,
         0.5698,  0.5892,  0.9583,  0.9996, -0.6371, -0.0343,  0.5735,  0.0738,
        -0.9998, -0.4962,  0.1886,  0.6117,  0.8236,  0.9216,  0.9655,  0.9849])


In [19]:
from x_transformers import Encoder


def masked_max_pooling(datatensor, mask):
    """
    Adapted from HuggingFace's Sentence Transformers:
    https://github.com/UKPLab/sentence-transformers/
    Calculate masked average for final dimension of tensor
    """
    # eliminate all values learned from nonexistant timepoints
    mask_expanded = mask.unsqueeze(-1).expand(datatensor.size()).float()

    datatensor[mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
    maxed = torch.max(datatensor, 1)[0]

    return maxed

class EncoderClassifierRegular(nn.Module):

    def __init__(
        self,
        device="cpu",
        pooling="mean",
        num_classes=2,
        sensors_count=37,
        static_count=8,
        layers=1,
        heads=1,
        dropout=0.2,
        attn_dropout=0.2,
        **kwargs
    ):
        super().__init__()

        self.pooling = pooling
        self.device = device
        self.sensors_count = sensors_count
        self.static_count = static_count

        self.sensor_axis_dim_in = 2 * self.sensors_count

        self.sensor_axis_dim = self.sensor_axis_dim_in
        if self.sensor_axis_dim % 2 != 0:
            self.sensor_axis_dim += 1

        self.static_out = self.static_count + 4

        self.attn_layers = Encoder(
            dim=self.sensor_axis_dim,
            depth=layers,
            heads=heads,
            attn_dropout=attn_dropout,
            ff_dropout=dropout,
        )

        #This embedding is used for the 37 time series of the pysionet
        self.sensor_embedding = nn.Linear(self.sensor_axis_dim_in, self.sensor_axis_dim)

        #Static is used for the rest of the constant variables, eg. Age.
        self.static_embedding = nn.Linear(self.static_count, self.static_out)
        self.nonlinear_merger = nn.Linear(
            self.sensor_axis_dim + self.static_out,
            self.sensor_axis_dim + self.static_out,
        )
        self.classifier = nn.Linear(
            self.sensor_axis_dim + self.static_out, num_classes
        )

        self.pos_encoder = PositionalEncodingTF(self.sensor_axis_dim)

    def forward(self, x, static, time, sensor_mask, **kwargs):

        x_time = torch.clone(x)  # (N, T)
        x_time = torch.permute(x_time, (0, 2, 1))  # (N, T)
        mask = (
            torch.count_nonzero(x_time, dim=2)
        ) > 0  # mask for sum of all sensors for each person/at each timepoint

        # add indication for missing sensor values
        x_sensor_mask = torch.clone(sensor_mask)  # (N, F, T)
        x_sensor_mask = torch.permute(x_sensor_mask, (0, 2, 1))  # (N, T, F)
        x_time = torch.cat([x_time, x_sensor_mask], axis=2)  # (N, T, 2F) #Binary

        # make sensor embeddings
        x_time = self.sensor_embedding(x_time)  # (N, T)

        # add positional encodings
        pe = self.pos_encoder(time).to(self.device)  # taken from RAINDROP, (N, T, pe)
        x_time = torch.add(x_time, pe)  # (N, T, F) (N, F)

        # run  attention
        x_time = self.attn_layers(x_time, mask=mask)

        if self.pooling == "mean":
            x_time = masked_mean_pooling(x_time, mask)
        elif self.pooling == "median":
            x_time = torch.median(x_time, dim=1)[0]
        elif self.pooling == "sum":
            x_time = torch.sum(x_time, dim=1)  # sum on time
        elif self.pooling == "max":
            x_time = masked_max_pooling(x_time, mask)

        # concatenate poolingated attented tensors
        static = self.static_embedding(static)
        x_merged = torch.cat((x_time, static), axis=1)

        nonlinear_merged = self.nonlinear_merger(x_merged).relu()

        # classify!
        return self.classifier(nonlinear_merged)
    
# Create dummy data
#x = torch.randn(batch_size, length_time, feature_dim)
x = torch.randn(batch_size, feature_dim, length_time)
time = torch.randint(0, 500, (1, 50))


static = torch.randn(batch_size, 8)
    
sensor_mask = torch.zeros(batch_size, feature_dim, length_time)

mask[0, :10] = 1
mask[1, :20] = 1


    # Create an instance of the EncoderClassifierRegular
model = EncoderClassifierRegular(device="cpu", pooling="mean", num_classes=2, sensors_count=feature_dim, static_count=8)

    # Run the model with the dummy data
output = model(x, static, time, sensor_mask)
print(output.size())
print(output)



torch.Size([5, 2])
tensor([[ 0.0454, -0.1720],
        [ 0.0908, -0.1668],
        [ 0.0839, -0.0448],
        [ 0.0705, -0.1464],
        [ 0.1508, -0.2261]], grad_fn=<AddmmBackward0>)
