In [1]:
import torch
import torch.nn as nn
import itertools
import numpy as np
import matplotlib.pyplot as plt

In [74]:
class MCModel(nn.Module):
    """
    Markov chain approximation (homogeneous case) of a drift-diffusion
    process on [0, a], used to evaluate first-passage-time likelihoods.

    State space: grid points 0..Nx (spacing ``dx = a / Nx``, so index 0 is
    x = 0 and index Nx is x = a) plus one extra absorbing state Nx + 1.
    In every time step of length ``dt`` an interior point i (1 <= i <= Nx-1)
    moves to i - 1 / i / i + 1 with probabilities p- / p0 / p+, chosen so the
    mean and second moment of the increment equal mu*dt and
    (mu*dt)**2 + sigma**2*dt. The boundary points 0 and Nx jump to the
    absorbing state with probability 1, so the chain can sit on a boundary
    point for at most one step: P(X[k] = boundary) is the probability of
    first reaching that boundary at step k.

    The drift ``mu`` is the only trainable parameter.
    """

    def __init__(
        self,
        mu_init: float,
        sigma: float,
        a: float,
        z: float,
        dt: float,
        Nx: int,
    ) -> None:
        """
        Parameters
        ----------
        mu_init : initial value of the trainable drift ``mu``.
        sigma : diffusion coefficient (constant, not trained).
        a : upper boundary; the lower boundary is 0.
        z : starting point, snapped to the nearest grid point.
        dt : time step.
        Nx : number of space steps on [0, a].

        Raises ValueError if ``z`` does not fall on a grid point in [0, a].
        """
        super().__init__()
        self.mu = nn.Parameter(torch.tensor([mu_init]))

        self.sigma = sigma  # diffusion coeff (constant)
        self.a, self.z = a, z  # upper boundary & starting point
        self.Nx = Nx  # num of space steps
        dx = a / Nx
        self.dt, self.dx = dt, dx

        # Python's round() rounds halves to even, so an off-grid z is only
        # approximated by the nearest grid point.
        self.idx_z = int(round(z / dx))
        if not 0 <= self.idx_z <= Nx:
            raise ValueError("z=%r lies outside [0, a=%r]" % (z, a))
        # One-hot row vector over all Nx + 2 states, placed at the start point.
        self.init_dist = torch.zeros((1, self.Nx + 2))
        self.init_dist[0, self.idx_z] = 1

    def _transition_matrix(self):
        """
        Build the sparse (Nx + 2) x (Nx + 2) transition matrix P, where
        P[i, j] is the probability of moving from state i to state j in one
        step. Its values depend on ``mu`` and so carry gradients.

        Raises AssertionError when the current ``mu`` yields a negative
        probability or p+ + p- >= 1 (the discretisation is then invalid).
        """
        m1 = self.mu * self.dt
        m2 = m1 ** 2 + self.sigma ** 2 * self.dt
        p_up = (m2 / self.dx ** 2 + m1 / self.dx) / 2
        p_down = (m2 / self.dx ** 2 - m1 / self.dx) / 2
        p_stay = 1 - p_up - p_down
        assert p_up.item() >= 0 and p_down.item() >= 0 and (p_up + p_down).item() < 1, (
            "p+=%.5f, p0=%.5f, p-=%.5f" % (p_up.item(), p_stay.item(), p_down.item())
        )

        inner = torch.arange(1, self.Nx)  # interior grid indices 1..Nx-1
        # First three entries: boundaries 0 and Nx, and the absorbing state
        # itself, all move to the absorbing state Nx + 1 with probability 1.
        rows = torch.cat((
            torch.tensor([0, self.Nx, self.Nx + 1]),
            inner.repeat_interleave(3),
        ))
        cols = torch.cat((
            torch.full((3,), self.Nx + 1, dtype=torch.long),
            torch.stack((inner - 1, inner, inner + 1), dim=1).reshape(-1),
        ))
        values = torch.cat((
            torch.ones(3, dtype=self.mu.dtype),
            torch.cat((p_down, p_stay, p_up)).repeat(len(inner)),
        ))
        return torch.sparse_coo_tensor(
            torch.stack((rows, cols)), values, size=(self.Nx + 2, self.Nx + 2)
        )

    def log_forward(self, T, s):
        """
        log P(X[T] = s) for the chain started at ``z``, computed by dynamic
        programming with per-step rescaling.

        Evaluates ``init_dist @ P^k @ e_s`` with ``k = round(T / dt)`` by
        applying P to the column vector e_s k times. After each step the
        vector is divided by its sum and the log of that sum is accumulated.
        The result is therefore assembled in log space and never materialised
        as a possibly-underflowing probability.

        Parameters
        ----------
        T : non-negative time, rounded to the nearest multiple of ``dt``.
        s : value in [0, a], rounded to the nearest grid point.

        Returns
        -------
        Tensor of shape (1, 1); ``-inf`` when the probability is exactly 0.

        Raises ValueError if ``T`` is negative or ``s`` lies outside [0, a].
        """
        n_steps = int(round(T / self.dt))
        idx_s = int(round(s / self.dx))
        if n_steps < 0:
            raise ValueError("T=%r must be non-negative" % (T,))
        if not 0 <= idx_s <= self.Nx:
            raise ValueError("s=%r lies outside [0, a=%r]" % (s, self.a))

        trans = self._transition_matrix()
        col = torch.zeros((self.Nx + 2, 1), dtype=trans.dtype)
        col[idx_s, 0] = 1.0  # e_s; after j steps col is P^j e_s / scale
        log_scale = torch.tensor(0.0)
        for _ in range(n_steps):
            col = torch.sparse.mm(trans, col)
            total = col.sum()
            if total.item() <= 0:
                # No state can reach s any more: the probability is exactly 0.
                return torch.full((1, 1), float("-inf"))
            log_scale = log_scale + torch.log(total)
            col = col / total
        return torch.log(self.init_dist @ col) + log_scale

    def forward(self, T, s):
        """
        Probability P(X[T] = s) of the chain started at ``z``.

        For ``s`` on a boundary (0 or a), this is the probability that the
        first passage happens at time ``T`` through that boundary. Returns a
        tensor of shape (1, 1). Prefer ``log_forward`` when the log is needed,
        since very small probabilities underflow here.
        """
        return torch.exp(self.log_forward(T, s))

    def loss_fun(self, data):
        """
        Average negative log-likelihood of ``data``.

        likelihood = product over k of P(X(Tk) = Ck), where ``data`` is a
        sized iterable of (Tk, Ck) pairs. Returns a 0-dim tensor.

        Raises ValueError if ``data`` is empty or a non-finite (inf / nan)
        log-likelihood term is encountered.
        """
        if len(data) == 0:
            raise ValueError("data must contain at least one (T, s) observation")
        neg_loglik = 0
        for Tk, Ck in data:
            neg_loglik = neg_loglik - self.log_forward(Tk, Ck)
            if not torch.isfinite(neg_loglik).all():
                raise ValueError("Non-finite log-likelihood (inf/nan) detected, computation is stopped.")
        return neg_loglik.squeeze() / len(data)


In [79]:
def train(model, data, num_epochs=100, lr=0.1, tol=1e-4, log_every=1):
    """
    Fit ``model.mu`` by full-batch gradient descent (``torch.optim.SGD``)
    on ``model.loss_fun(data)``.

    Parameters
    ----------
    model : nn.Module exposing a scalar-valued ``loss_fun(data)`` and a
        one-element trainable parameter ``mu`` (as ``MCModel`` does).
    data : passed through unchanged to ``model.loss_fun``.
    num_epochs : maximum number of optimisation steps.
    lr : SGD learning rate.
    tol : stop early once a step changes ``mu`` by less than ``tol``.
    log_every : print and record progress every ``log_every`` epochs; the
        first epoch is always recorded.

    Returns
    -------
    (loss_history, mu_history) : two lists of detached tensors. Entry i of
    both lists refers to the same parameter value, and the last entry is the
    loss evaluated at the final parameters.
    """
    loss_history, mu_history = [], []
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    epoch = 0  # keeps the final report valid when num_epochs == 0
    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()
        loss = model.loss_fun(data)
        loss.backward()
        if epoch == 1 or epoch % log_every == 0:
            print("Epoch %d: Loss: %.5f; Parameters: %.5f" % (epoch - 1, loss.item(), model.mu.item()))
            # Store detached copies: keeping `loss` would retain its autograd
            # graph, and keeping `model.mu` would alias the live parameter, so
            # every entry would end up showing the final value.
            loss_history.append(loss.detach())
            mu_history.append(model.mu.detach().clone())
        mu_prev = model.mu.detach().clone()
        optimizer.step()
        if (model.mu.detach() - mu_prev).abs().max().item() < tol:
            break
    # Evaluate the loss at the final parameters; no gradient is needed.
    with torch.no_grad():
        final_loss = model.loss_fun(data)
    print("Optimization ends.")
    print("Epoch %d: Loss: %.5f; Parameters: %.5f" % (epoch, final_loss.item(), model.mu.item()))
    loss_history.append(final_loss.detach())
    mu_history.append(model.mu.detach().clone())
    return loss_history, mu_history


In [80]:
data = np.load('../data.npy')
# Each row is one (T, s) observation, unpacked by MCModel.loss_fun as (Tk, Ck).
assert data.ndim == 2 and data.shape[1] == 2, "expected an (N, 2) array of (T, s) pairs"
data.shape

(100, 2)

In [81]:
# Drift-diffusion chain on [0, 4] with 20 space steps, started near 1.5.
ddm = MCModel(
    mu_init=1.0,
    sigma=1,
    a=4,
    z=1.5,
    dt=0.01,
    Nx=20,
)

In [82]:
# Keep the per-epoch histories so they can be inspected or plotted later.
loss_history, mu_history = train(ddm, data)

Epoch 0: Loss: 8.78511; Parameters: 1.00000
Epoch 1: Loss: 7.96819; Parameters: 0.68025
Epoch 2: Loss: 7.67520; Parameters: 0.48888
Epoch 3: Loss: 7.56954; Parameters: 0.37400
Epoch 4: Loss: 7.53138; Parameters: 0.30495
Epoch 5: Loss: 7.51755; Parameters: 0.26343
Epoch 6: Loss: 7.51257; Parameters: 0.23846
Epoch 7: Loss: 7.51077; Parameters: 0.22344
Epoch 8: Loss: 7.51012; Parameters: 0.21441
Epoch 9: Loss: 7.50989; Parameters: 0.20897
Epoch 10: Loss: 7.50978; Parameters: 0.20570
Epoch 11: Loss: 7.50976; Parameters: 0.20374
Epoch 12: Loss: 7.50974; Parameters: 0.20256
Epoch 13: Loss: 7.50975; Parameters: 0.20184
Epoch 14: Loss: 7.50976; Parameters: 0.20142
Epoch 15: Loss: 7.50976; Parameters: 0.20116
Epoch 16: Loss: 7.50974; Parameters: 0.20100
Optimization ends.
Epoch 17: Loss: 7.50974; Parameters: 0.20091
