In [1]:
import torch
from torch.nn import Parameter

from lafomo.gp.variational.models import VariationalLFM
from lafomo.gp.variational.trainer import P53ConstrainedTrainer
from lafomo.gp.variational.options import VariationalOptions
from lafomo.utilities import save, load
from lafomo.plot.variational_plotters import Plotter

from matplotlib import pyplot as plt

In [2]:
from lafomo.data_loaders.datasets import TranscriptomicDataset
from os import path
import numpy as np
class SingleCellKidney(TranscriptomicDataset):
    """
    scRNA-seq dataset on the human kidney.
    Accession number: GSE131685
    Parameters:
        calc_moments: bool=False whether to use the raw unspliced/spliced or moments
    """
    def __init__(self, data_dir='../data/',
                 raw_data_dir='/Volumes/ultra/genomics/scRNA-seq/GSE131685_RAW_kidney/velocyto',
                 calc_moments=False):
        super().__init__()
        data_path = path.join(data_dir, 'kidney1.pt')
        if path.exists(data_path):
            data = torch.load(data_path)
            self.m_observed = data['m_observed']
            self.data = data['data']
            self.gene_names = data['gene_names']
        else:
            import scvelo as scv
            kidney1 = path.join(raw_data_dir, 'kidney1.loom')
            data = scv.read_loom(kidney1)
            scv.pp.filter_and_normalize(data, min_shared_counts=20, n_top_genes=2000)
            if calc_moments:
                scv.pp.moments(data, n_neighbors=30, n_pcs=30)
                u = data.layers['Mu']
                s = data.layers['Ms']
            else:
                u = data.layers['unspliced'].toarray()
                s = data.layers['spliced'].toarray()

            self.loom = data
            self.gene_names = self.loom.var.index
            self.data = np.concatenate([u, s], axis=1)
            num_cells = self.data.shape[0]
            self.data = torch.tensor(self.data.swapaxes(0, 1).reshape(4000, 1, num_cells))
            self.m_observed = self.data.permute(1, 0, 2)

            self.data = list(self.data)
            torch.save({
                'data': self.data,
                'm_observed': self.m_observed,
                'gene_names': self.gene_names
            }, data_path)
dataset = SingleCellKidney(calc_moments=True)

In [3]:
print(dataset.m_observed.shape)
print(dataset.gene_names.shape)
print(dataset[0].shape)

torch.Size([1, 4000, 8164])
(2000,)
torch.Size([1, 8164])


In [4]:
from lafomo.data_loaders import LFMDataset
from sortedcontainers import SortedList
class RNAVelocityLFM(VariationalLFM):
    def __init__(self, num_genes, num_latents, t_inducing, dataset: LFMDataset, options: VariationalOptions, **kwargs):
        super().__init__(num_genes*2, num_latents, t_inducing, dataset, options, **kwargs)
        self.transcription_rate = Parameter(torch.rand((num_genes, 1), dtype=torch.float64))
        self.splicing_rate = Parameter(torch.rand((num_genes, 1), dtype=torch.float64))
        self.decay_rate = Parameter(0.1 + torch.rand((num_genes, 1), dtype=torch.float64))
        self.num_cells = dataset[0][0].shape[0]
        ### Initialise random time assignments
        self.time_assignments = Parameter(torch.rand(self.num_cells), requires_grad=False)

    def odefunc(self, t, h):
        """h is of shape (num_samples, num_outputs, 1)"""
        if (self.nfe % 5) == 0:
            print(t)
        self.nfe += 1
        num_samples = h.shape[0]
        num_outputs = h.shape[1]
        h = h.view(num_samples, num_outputs//2, 2)
        u = h[:, :, 0].unsqueeze(-1)
        s = h[:, :, 1].unsqueeze(-1)
        du = self.transcription_rate - self.splicing_rate * u
        ds = self.splicing_rate * u - self.decay_rate * s

        # q_f = self.get_latents(t.reshape(-1))
        # # Reparameterisation trick
        # f = q_f.rsample([self.num_samples])  # (S, I, t)
        # f = self.G(f)  # (S, num_outputs, t)

        h_t = torch.cat([du, ds], dim=1)
        return h_t

    def G(self, f):
        """
        Parameters:
            f: (I, T)
        """
        return f

    def predict_f(self, t_predict):
        # Sample from the latent distribution
        q_f = self.get_latents(t_predict.reshape(-1))
        f = q_f.sample([500])  # (S, I, t)
        # This is a hack to wrap the latent function with the nonlinearity. Note we use the same variance.
        f = torch.mean(self.G(f), dim=0)[0]
        return torch.distributions.multivariate_normal.MultivariateNormal(f, scale_tril=q_f.scale_tril)

In [5]:
from torch.utils.data import DataLoader
import numpy as np
from lafomo.utilities import is_cuda
from lafomo.gp.variational.trainer import Trainer
class EMTrainer(Trainer):
    """
    Expectation-Maximisation Trainer

    Parameters
    ----------
    model: .
    optimizer:
    dataset: Dataset where t_observed (T,), m_observed (J, T).
    inducing timepoints.
    give_output: whether the trainer should give the first output (y_0) as initial value to the model `forward()`
    """
    def __init__(self, model: RNAVelocityLFM, optimizer: torch.optim.Optimizer, dataset, batch_size=1, give_output=False):
        super().__init__(model, optimizer, dataset, batch_size, give_output)
        # Initialise trajectory
        self.timepoint_choices = torch.linspace(0, 1, 100, requires_grad=False)
        initial_value = self.initial_value(None)
        self.previous_trajectory = self.model(self.timepoint_choices, initial_value, rtol=1e-3, atol=1e-4)

    def initial_value(self, y):
        initial_value = torch.zeros((self.batch_size, 1), dtype=torch.float64)
        initial_value = initial_value.cuda() if is_cuda() else initial_value
        if self.give_output:
            initial_value = y[0]
        return initial_value.repeat(self.model.num_samples, 1, 1)  # Add batch dimension for sampling

    def e_step(self, y):
        num_outputs = self.model.num_outputs
        # sorted_times, sort_indices = torch.sort(self.model.time_assignments, dim=0)
        # trajectory = self.model(sorted_times, self.initial_value(None), rtol=1e-2, atol=1e-3)

        # optimizer = torch.optim.LBFGS([model.time_assignments])
        u = self.previous_trajectory[:num_outputs//2]  # (2000, 100, 1)
        s = self.previous_trajectory[num_outputs//2:]  # (2000, 100, 1)
        trajectory = self.previous_trajectory # (4000, 100, 1)
        # y shape (4000, 8164, 1)
        num_times = trajectory.shape[1]
        index_to_time = torch.linspace(0, 1, num_times)
        for cell in range(self.model.num_cells):
            cell_time = self.model.time_assignments[cell]
            minimum_residual = 1e9
            minimum_i = -1
            for i in range(num_times): # loop through all times
                time = index_to_time[i]
                residual = (trajectory[:, i] - y[:, cell]) ** 2
                residual = residual.sum()
                if residual < minimum_residual:
                    minimum_residual = residual
                    minimum_i = i
            model.time_assignments[cell] = index_to_time[minimum_i]
        #     def closure():
        #         optimizer.zero_grad()
        #         print(trajectory.shape, y.shape) # (4000, 100, 1)
        #         # sum up the residuals between the true u, s and the trajectory u, s
        #         # the model.time_assignments variable contains for each cell which time it is assigned
        #         t_i = model.time_assignments[cell]
        #         # trajectory[]
        #
        #         loss = (trajectory - y) ** 2
        #         loss = loss.sum()
        #         loss.backward()
        #         print(loss)
        #         return loss
        #     optimizer.step(closure)

    def single_epoch(self, rtol, atol):
        epoch_loss = 0
        epoch_ll = 0
        epoch_kl = 0
        for i, data in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            y = data.permute(0, 2, 1) # (O, C, 1)
            y = y.cuda() if is_cuda() else y
            ### E-step ###
            # assign timepoints $t_i$ to each cell by minimising its distance to the trajectory
            # self.e_step(y)
            print('estep done')
            ### M-step ###
            initial_value = self.initial_value(None)
            t_sorted, inv_indices = torch.unique(self.model.time_assignments, sorted=True, return_inverse=True)
            print(t_sorted, inv_indices)
            print(t_sorted.shape)
            output = self.model(t_sorted, initial_value, rtol=rtol, atol=atol)
            print('fiin')
            output = torch.squeeze(output)
            print(output.shape)
            # Calc loss and backprop gradients
            mult = 1
            if self.num_epochs <= 10:
                mult = self.num_epochs/10

            ll, kl = self.model.elbo(y, output, mult, data_index=i)
            total_loss = -ll + kl
            total_loss.backward()
            self.optimizer.step()
            epoch_loss += total_loss.item()
            epoch_ll += ll.item()
            epoch_kl += kl.item()
        return epoch_loss, epoch_ll, epoch_kl

    def train(self, epochs=20, report_interval=1, plot_interval=20, rtol=1e-5, atol=1e-6):
        losses = list()
        end_epoch = self.num_epochs+epochs
        plt.figure(figsize=(4, 2.3))

        for epoch in range(epochs):
            epoch_loss, epoch_ll, epoch_kl = self.single_epoch(rtol, atol)


            if (epoch % report_interval) == 0:
                print('Epoch %d/%d - Loss: %.2f (%.2f %.2f) λ: %.3f' % (
                    self.num_epochs + 1, end_epoch,
                    epoch_loss, -epoch_ll, epoch_kl,
                    self.model.kernel.lengthscale[0].item(),
                ), end='')
                self.print_extra()

            losses.append((-epoch_ll, epoch_kl))
            self.after_epoch()

            if (epoch % plot_interval) == 0:
                plt.plot(self.t_observed, output[0].cpu().detach().numpy(), label='epoch'+str(epoch))
            self.num_epochs += 1
        plt.legend()

        losses = np.array(losses)
        self.losses = np.concatenate([self.losses, losses], axis=0)

        return output

    def print_extra(self):
        print('')

    def after_epoch(self):
        pass


In [6]:
options = VariationalOptions(
    learn_inducing=False,
    num_samples=1,
    kernel_scale=False
)
num_cells = dataset[0].shape[1]
print(num_cells)
t_inducing = torch.linspace(0, 1, 10, dtype=torch.float64).reshape((-1, 1))
t_observed = torch.linspace(0, 12, num_cells).view(-1)
t_predict = torch.linspace(-1, 13, 80, dtype=torch.float64)
model = RNAVelocityLFM(2000, 1, t_inducing, dataset, options)

optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
trainer = EMTrainer(model, optimizer, dataset, batch_size=4000)
plotter = Plotter(model, dataset.gene_names)

8164
tensor(0.)
torch.Size([50, 4000, 1])
tensor(1.0000e-06, dtype=torch.float64)
torch.Size([50, 4000, 1])
tensor(2.0000e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(3.0000e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(8.0000e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(8.8889e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(1.0000e-04, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(1.0000e-04, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0003, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0004, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0009, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0010, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0011, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0011, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0031, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0041, grad_fn=

  self.inducing_inputs = Parameter(torch.tensor(t_inducing), requires_grad=options.learn_inducing)
  self.initial_conditions = Parameter(torch.tensor(torch.zeros(self.num_outputs, 1)), requires_grad=True)


### Outputs prior to training:

In [7]:
rtol = 1e-3
atol = rtol/10

model_kwargs = {
    'rtol': rtol, 'atol': atol
}

# plotter.plot_outputs(t_predict, replicate=0, t_scatter=t_observed,y_scatter=dataset.m_observed, model_kwargs=model_kwargs);
# plotter.plot_latents(t_predict, ylim=(-1, 3), plot_barenco=True, plot_inducing=False)

In [None]:
tol = 5e-3
import time
start = time.time()

output = trainer.train(1, rtol=tol, atol=tol/10,
                       report_interval=5, plot_interval=5)
end = time.time()
print(end - start)


estep done
tensor([6.5565e-06, 6.5982e-05, 3.6025e-04,  ..., 9.9952e-01, 9.9959e-01,
        9.9988e-01]) tensor([3772, 3341, 7108,  ..., 2982, 3011, 5788])
torch.Size([8163])
tensor(6.5565e-06)
torch.Size([50, 4000, 1])
tensor(7.5565e-06, dtype=torch.float64)
torch.Size([50, 4000, 1])
tensor(2.6557e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(3.6557e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(8.6557e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(9.5445e-05, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0001, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0001, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0003, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0004, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0009, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0010, grad_fn=<AddBackward0>)
torch.Size([50, 4000, 1])
tensor(0.0011, grad_fn=<AddBackward0>)
torch.Si

### Outputs after training

In [None]:
plotter.plot_losses(trainer, last_x=100)
plotter.plot_outputs(t_predict, replicate=0, ylim=(0, 3), t_scatter=t_observed,y_scatter=m_observed, model_kwargs=model_kwargs)
plotter.plot_latents(t_predict, ylim=(-2, 4), plot_barenco=True, plot_inducing=False)
plotter.plot_kinetics()
plotter.plot_convergence(trainer)

In [None]:
save(model, 'variational_linear')

In [None]:
do_load = False
if do_load:
    model = load('variational_linear', SingleLinearLFM, num_genes, num_tfs,
                 t_inducing, dataset, extra_points=2, fixed_variance=dataset.variance)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    trainer = P53ConstrainedTrainer(model, optimizer, dataset)
print(do_load)