In [36]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt


class Weibull:
    def __init__(self, b, k, eps=1e-8):
        self.b = tf.constant(b, dtype=tf.float32)
        self.k = tf.constant(k, dtype=tf.float32)
        self.eps = eps


    def prob(self, x):
        x = tf.clip_by_value(x, self.eps, tf.reduce_max(x))
        return self.b * self.k * tf.math.pow(x, self.k - 1) \
            * tf.math.exp(-self.b * tf.math.pow(x, self.k))


    def log_prob(self, x):
        x = tf.clip_by_value(x, self.eps, tf.reduce_max(x))
        return tf.math.log(self.b) + tf.math.log(self.k) \
            + (self.k - 1) * tf.math.log(x) - self.b * tf.math.pow(x, self.k)
    

    def log_survival(self, x):
        x = tf.clip_by_value(x, self.eps, tf.reduce_max(x))
        return -self.b * tf.math.pow(x, self.k)
    

    def sample(self, size=()):            
        u = tf.random.uniform(minval=0,
                              maxval=1,
                              shape=self.b.shape + size,
                              dtype=tf.float32)
        y = -1 / self.b * tf.math.log(1 - u)
        y = tf.clip_by_value(y, self.eps, tf.reduce_max(y))
        return tf.math.pow(y, 1 / self.k).numpy()
    

class Gamma:
    def __init__(self, alpha, beta, eps=1e-8):
        self.alpha = tf.constant(alpha, dtype=tf.float32)
        self.beta = tf.constant(beta, dtype=tf.float32)
        self.eps = eps


    def prob(self, x):
        x = tf.clip_by_value(x, self.eps, tf.reduce_max(x))
        return tf.math.pow(x, self.alpha - 1) \
            * tf.math.pow(self.beta, self.alpha) \
            * tf.math.exp(-self.beta * x) \
            / tf.math.exp(tf.math.lgamma(self.alpha))


    def log_prob(self, x):
        x = tf.clip_by_value(x, self.eps, tf.reduce_max(x))
        self.alpha = tf.clip_by_value(self.alpha, self.eps,
                                      tf.reduce_max(self.alpha))
        self.beta = tf.clip_by_value(self.beta, self.eps,
                                     tf.reduce_max(self.beta))
        return (self.alpha - 1) * tf.math.log(x) \
            + self.alpha * tf.math.log(self.beta) - self.beta * x \
            - tf.math.lgamma(self.alpha)

    
    def survival(self, x):
        rhs = self.beta * x
        rhs = tf.clip_by_value(rhs, self.eps, tf.reduce_max(rhs))
        self.alpha = tf.clip_by_value(self.alpha, self.eps,
                                      tf.reduce_max(self.alpha))
        self.beta = tf.clip_by_value(self.beta, self.eps,
                                     tf.reduce_max(self.beta))
        return 1 - tf.math.igamma(self.alpha, rhs) \
            / tf.math.exp(tf.math.lgamma(self.alpha))


    def log_survival(self, x):
        y = self.survival(x)
        y = tf.clip_by_value(y, self.eps, tf.reduce_max(y))
        return tf.math.log(y)


    def sample(self, size=()):            
        return tf.random.gamma(size, self.alpha, self.beta).numpy()


class Model:
    def __init__(self, context_size=32, dist=Weibull):
        self.context_size = context_size
        self.encoder = keras.layers.GRU(context_size, return_sequences=True)
        self.decoder = keras.layers.Dense(2, activation="softplus")
        self.optimizer = keras.optimizers.Adam(learning_rate=0.01)
        self.dist = dist
        self.dist_params = {"b": [], "k": []}


    def get_context(self, inter_times):
        tau = tf.expand_dims(inter_times, axis=-1)
        log_tau = tf.math.log(tf.clip_by_value(tau, 1e-8, tf.reduce_max(tau)))
        input = tf.concat([tau, log_tau], axis=-1)
        output = self.encoder(input)
        context = tf.pad(output[:, :-1, :], [[0, 0], [1, 0], [0, 0]])
        return context


    def get_inter_times_distribution(self, context):
        if self.dist == Weibull or self.dist == Gamma:
            params = self.decoder(context)
            b = params[..., 0]
            k = params[..., 1]
            self.dist_params["b"].append(b)
            self.dist_params["k"].append(k)
            return self.dist(b, k)
        else:
            assert False and "Distribution not supported"


    def nll_loss(self, inter_times, seq_lengths):
        context = self.get_context(inter_times)
        inter_times_dist = self.get_inter_times_distribution(context)

        log_pdf = inter_times_dist.log_prob(inter_times)
        log_surv = inter_times_dist.log_survival(inter_times)

        # construit un masque pour ne sélectionner que les éléments
        # nécessaires dans chaque liste
        mask = np.cumsum(np.ones_like(log_pdf), axis=-1) \
            <= np.expand_dims(seq_lengths, axis=-1)
        log_like = tf.reduce_sum(log_pdf * mask, axis=-1)
        
        # idx est une liste de la forme [(a1, b1), (a2, b2), ...]
        # gather_nd sélectionne les éléments correspondant à ces indices
        # (ligne et colonne)
        idx = list(zip(range(len(seq_lengths)), seq_lengths))
        log_surv_last = tf.gather_nd(log_surv, idx)
        log_like += log_surv_last

        return -log_like
    

    @property
    def weights(self):
        return self.encoder.trainable_weights + self.decoder.trainable_weights
        
    
    def fit(self, epochs, inter_times, seq_lengths, t_end):
        for epoch in range(epochs + 1):
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(self.nll_loss(inter_times,
                                                    seq_lengths)) / t_end
            grads = tape.gradient(loss, self.weights)
            self.optimizer.apply_gradients(zip(grads, self.weights))

            if epoch % 10 == 0:
                print(f"Loss at epoch {epoch}: {loss:.2f}")


    def sample(self, batch_size, t_end):
        inter_times = np.empty((batch_size, 0))
        next_context = tf.zeros(shape=(batch_size, 1, 32))
        generated = False

        while not generated:
            dist = self.get_inter_times_distribution(next_context)
            next_inter_times = dist.sample()
            inter_times = tf.concat([inter_times, next_inter_times], axis=-1)
            tau = tf.expand_dims(next_inter_times, axis=-1)
            log_tau = tf.math.log(
                tf.clip_by_value(tau, 1e-8, tf.reduce_max(tau)))
            input = tf.concat([tau, log_tau], axis=-1)
            next_context = self.encoder(input)

            generated = np.sum(inter_times, axis=-1).min() >= t_end

        return np.cumsum(inter_times, axis=-1)
    

    def next(self, inter_times, seq_lengths, num_samples):
        new = np.empty((inter_times.shape[0], 0))
        # on ne part pas d'un vecteur contexte nul
        idx = list(zip(range(len(seq_lengths)), seq_lengths))
        last_inter_times = tf.expand_dims(
            tf.gather_nd(inter_times, idx), axis=1)
        tau = tf.expand_dims(last_inter_times, axis=-1)
        log_tau = tf.math.log(
            tf.clip_by_value(tau, 1e-8, tf.reduce_max(tau)))
        input = tf.concat([tau, log_tau], axis=-1)
        print(input.shape)
        next_context = self.encoder(input)

        # pour chaque itération, on tire les nouveaux temps d'attente
        for _ in range(num_samples):
            dist = self.get_inter_times_distribution(next_context)
            next_inter_times = dist.sample()
            new = tf.concat([new, next_inter_times], axis=-1)
            tau = tf.expand_dims(next_inter_times, axis=-1)
            log_tau = tf.math.log(
                tf.clip_by_value(tau, 1e-8, tf.reduce_max(tau)))
            input = tf.concat([tau, log_tau], axis=-1)
            next_context = self.encoder(input)

        return np.cumsum(new, axis=-1)

In [37]:
b = 0.7
k = 2.3
w = Weibull(b, k)
inter_times = w.sample((100, 100))
arrival_times = np.cumsum(inter_times, axis=-1)
eps = 1e-8
t_end = np.min(arrival_times[..., -1]) - eps
seq_lengths = np.sum(arrival_times < t_end, axis=-1)

epochs = 10
model = Model(context_size=32)
model.fit(epochs, inter_times, seq_lengths, t_end)

Loss at epoch 0: 1.64
Loss at epoch 10: 0.68
(100, 1, 2)


array([[4.11580980e-01, 1.51200825e+00, 2.13835335e+00, 3.35456252e+00,
        3.91814297e+00, 5.78070694e+00, 6.17968574e+00, 7.22091392e+00,
        8.00644413e+00, 8.84785530e+00],
       [1.63302350e+00, 2.36698151e+00, 2.97481209e+00, 3.18545449e+00,
        3.67059273e+00, 5.71365494e+00, 6.02063951e+00, 6.98537692e+00,
        7.02490250e+00, 8.17051033e+00],
       [1.47099328e+00, 2.35758436e+00, 7.79037941e+00, 1.05844907e+01,
        1.11750525e+01, 1.20013640e+01, 1.33428822e+01, 1.44788488e+01,
        1.50298898e+01, 1.60761143e+01],
       [3.34782577e+00, 7.20162988e+00, 7.51563340e+00, 7.54314824e+00,
        8.10601046e+00, 9.33593061e+00, 1.03005629e+01, 1.26062914e+01,
        1.34750257e+01, 1.40898189e+01],
       [2.87698150e-01, 5.44338357e+00, 6.81887650e+00, 7.43805885e+00,
        8.68394732e+00, 8.96231180e+00, 1.02197780e+01, 1.43299388e+01,
        1.63265395e+01, 1.80044925e+01],
       [2.36351347e+00, 2.57432739e+00, 3.48320983e+00, 6.67129157e+00,
   

In [41]:
preds = model.next(inter_times, seq_lengths, 10)

(100, 1, 2)


In [43]:
inter_times[0].shape, preds[0].shape

((100,), (10,))

In [44]:
inter_times[0, -1], preds[0, 0]

(1.1883255, 1.3173764944076538)