In [5]:
import torch
import torchcde
import torch.nn as nn

class CDEFunc(nn.Module):
    def __init__(self, hidden_dim, input_channels):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim * input_channels)

    def forward(self, t, z):  # Adjusted to accept time 't'
        out = self.linear(z)
        return out.view(z.size(0), z.size(1), -1)


class CRTP_CDE_no_context(nn.Module):
    def __init__(self, input_channels, hidden_dim, num_activities, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_activities = num_activities

        self.initial = nn.Linear(input_channels, hidden_dim)
        self.func = CDEFunc(hidden_dim, input_channels)
        self.dropout = nn.Dropout(dropout)

        self.activity_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_activities)
        )
        self.rrt_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, data, ts):
        coeffs = torchcde.linear_interpolation_coeffs(data)
        X = torchcde.LinearInterpolation(coeffs)
        z0 = self.initial(X.evaluate(ts[0]))
        z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=ts)
        z_T = z_t[:, -1, :]  # final hidden state
        z_T = self.dropout(z_T)
        activity_logits = self.activity_head(z_T)
        rrt = self.rrt_head(z_T).squeeze(-1)
        return activity_logits, rrt


# Run test
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, sequence_len, input_dim, hidden_dim, num_activities = 2, 12, 6, 32, 10
    data = torch.randn(batch_size, sequence_len, input_dim)
    ts = torch.linspace(0, 1, sequence_len)

    model = CRTP_CDE_no_context(input_channels=input_dim, hidden_dim=hidden_dim, num_activities=num_activities)
    activity_logits, rrt = model(data, ts)

print("Activity logits (softmax probs for first instance):")
print(torch.softmax(activity_logits[0], dim=-1))
print("\nRemaining Runtime Predictions:")
print(rrt.detach())



Activity logits (softmax probs for first instance):
tensor([0.0643, 0.1190, 0.1084, 0.0970, 0.0879, 0.1257, 0.1183, 0.0857, 0.1062,
        0.0875], grad_fn=<SoftmaxBackward0>)

Remaining Runtime Predictions:
tensor([0.1469, 0.3285])
