In [1]:
import pyro
from pyro.nn import PyroSample, PyroModule
from pyro.infer import SVI, Trace_ELBO, autoguide
import torch
from torch.nn.functional import softplus
from sklearn.metrics import mean_squared_error
import random
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ann

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Read the processed .h5ad file into an AnnData object (the file name suggests the
# GSE194122 Open Problems NeurIPS-2021 CITE-seq BMMC data set).
# NOTE(review): this is an absolute, machine-specific path; consider building it from a
# configurable data directory so the notebook can be re-run elsewhere.
data = ann.read_h5ad("/mnt/storage/thien/projectdata/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")

  utils.warn_names_duplicates("var")


In [5]:
# Print the AnnData summary: dimensions plus the available obs/var/uns/obsm/layers keys.
print(data)

AnnData object with n_obs × n_vars = 90261 × 14087
    obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'
    var: 'feature_types', 'gene_id'
    uns: 'dataset_id', 'genome', 'organism'
    obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'
    layers: 'counts'


In [6]:
# Simulate a small low-rank data set to test the factor model on.
# Problem size: observations x features, generated from a handful of latent factors.
n_obs, n_features, n_factors = 100, 20, 5

# Fix the RNG so the three draws below (factors, weights, noise) are reproducible.
torch.manual_seed(2024)
Z_in = torch.randn(n_obs, n_factors)        # ground-truth factor matrix (n_obs x n_factors)
W_in = torch.randn(n_features, n_factors)   # ground-truth weight matrix (n_features x n_factors)

# Create observed values from the simulated factor and weight matrices:
# low-rank signal Z_in @ W_in^T plus Gaussian noise with standard deviation 0.2.
Y = Z_in @ W_in.t() + 0.2 * torch.randn(n_obs, n_features)
print(Y.shape)
print(Y)

torch.Size([100, 20])
tensor([[ 5.3621e-01, -9.5076e-01, -8.2234e-01,  ..., -2.3100e+00,
          1.5223e+00,  4.2384e-01],
        [-8.2410e-01, -2.6438e+00,  5.0057e-01,  ..., -1.6188e+00,
          1.4766e+00, -1.3624e+00],
        [ 3.5991e+00,  8.5626e+00, -1.4502e+00,  ...,  9.4482e-01,
         -1.3726e+00,  1.0173e+00],
        ...,
        [-7.1978e-01, -2.0009e+00,  1.3592e-01,  ..., -6.8602e-01,
          6.1233e-01,  4.7834e-02],
        [-8.2973e-04, -2.9130e+00,  4.6990e-01,  ..., -9.3622e-01,
          9.6277e-02,  7.1257e-01],
        [-9.2778e-01,  2.7502e+00, -1.1326e+00,  ...,  5.6296e+00,
         -1.5591e+00,  2.7615e+00]])


In [7]:
class FA(PyroModule):
    """Probabilistic factor analysis fitted with MAP inference (Pyro ``AutoDelta``).

    The observed matrix ``Y`` (samples x features) is modelled as
    ``Y ~ Normal(Z @ W.T, scale)`` with standard-normal priors on the factor
    matrix ``Z`` (samples x K) and the weight matrix ``W`` (features x K), and a
    LogNormal(0, 1) prior on a per-entry ``scale``. NaN entries of ``Y`` are
    treated as missing and masked out of the likelihood.

    NOTE: ``train`` shadows ``torch.nn.Module.train(mode)``; as a consequence
    ``.eval()`` (which calls ``self.train(False)``) is not usable on this class.
    The name is kept so existing callers continue to work.
    """

    def __init__(self, Y, K):
        """
        Args:
            Y: Tensor (Samples x Features); NaN marks a missing observation.
            K: Number of Latent Factors
        """
        super().__init__()
        # Start from an empty global parameter store so a re-instantiated model
        # does not pick up parameters from a previous fit.
        pyro.clear_param_store()

        # data
        self.Y = Y
        self.K = K

        self.num_samples = self.Y.shape[0]
        self.num_features = self.Y.shape[1]

        # The missing-value mask and the NaN-free copy are computed ONCE from the
        # untouched input. Doing this inside `model` (and overwriting self.Y there)
        # would erase the NaNs after the first call, so every later call would
        # treat the missing entries as observed zeros.
        self._obs_mask = torch.logical_not(torch.isnan(self.Y))
        # A valid value has to stand in for the NaNs even though those entries
        # are ignored by the mask.
        self._Y_filled = torch.nan_to_num(self.Y, nan=0)

        self.sample_plate = pyro.plate("sample", self.num_samples)
        self.feature_plate = pyro.plate("feature", self.num_features)
        self.latent_factor_plate = pyro.plate("latent factors", self.K)

    def model(self):
        """Generative model: how the observed matrix is produced from W and Z.

        Takes no arguments; it reads the data stored on the instance. It does
        not modify any instance attribute, so repeated calls are equivalent.
        """
        with self.latent_factor_plate:
            with self.feature_plate:
                # sample weight matrix with Normal prior distribution
                W = pyro.sample("W", pyro.distributions.Normal(0., 1.))

            with self.sample_plate:
                # sample factor matrix with Normal prior distribution
                Z = pyro.sample("Z", pyro.distributions.Normal(0., 1.))

        # estimate for Y (samples x features)
        Y_hat = torch.matmul(Z, W.t())

        with pyro.plate("feature_", self.num_features), pyro.plate("sample_", self.num_samples):
            # mask the missing entries so they do not contribute to the log-density;
            # note the mask also covers the `scale` prior at those entries
            with pyro.poutine.mask(mask=self._obs_mask):
                # sample scale parameter for each feature-sample pair with
                # LogNormal prior (has to be positive)
                scale = pyro.sample("scale", pyro.distributions.LogNormal(0., 1.))
                # compare sampled estimation to the true observation Y
                pyro.sample("obs", pyro.distributions.Normal(Y_hat, scale), obs=self._Y_filled)

    def train(self, *, num_iterations=2000, lr=0.02, log_every=200):
        """Fit the model with stochastic variational inference and a Delta guide.

        Args:
            num_iterations: Number of SVI gradient steps.
            lr: Learning rate passed to the Adam optimizer.
            log_every: Print the loss every ``log_every`` iterations;
                ``0`` or ``None`` disables printing.

        Returns:
            Tuple ``(train_loss, map_estimates)``: the list of per-iteration
            losses divided by the number of samples, and the dictionary of MAP
            estimates returned by the guide (keyed by sample-site name).
        """
        # set training parameters
        optimizer = pyro.optim.Adam({"lr": lr})
        elbo = Trace_ELBO()
        guide = autoguide.AutoDelta(self.model)

        # initialize stochastic variational inference
        svi = SVI(
            model=self.model,
            guide=guide,
            optim=optimizer,
            loss=elbo,
        )

        train_loss = []
        for j in range(num_iterations):
            # calculate the loss and take a gradient step
            loss = svi.step()

            loss_per_sample = loss / self.num_samples
            train_loss.append(loss_per_sample)
            if log_every and j % log_every == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss_per_sample))

        # Obtain maximum a posteriori estimates for W and Z. The guide is called
        # with the same (empty) argument list as the model; it must not depend on
        # any variable from the surrounding notebook namespace.
        map_estimates = guide()

        return train_loss, map_estimates

In [8]:
# Fit a 5-factor model to the simulated matrix Y; keep the loss curve and the estimates.
FA_model = FA(Y, K=5)
losses, estimates = FA_model.train()



[iteration 0001] loss: 114.7042




[iteration 0201] loss: 13.8319
[iteration 0401] loss: 11.0294
[iteration 0601] loss: 10.4479
[iteration 0801] loss: 10.2793
[iteration 1001] loss: 10.2267
[iteration 1201] loss: 10.1964
[iteration 1401] loss: 10.1988
[iteration 1601] loss: 10.2002
[iteration 1801] loss: 10.1879


In [10]:
# Display the dictionary of estimates returned by FA_model.train().
estimates

{'W': tensor([[ 1.4652,  1.6283,  1.0000, -1.2970, -1.1173],
         [-1.8084,  2.3123,  2.7544,  0.5089, -3.6005],
         [ 2.4228, -0.5210, -0.9575, -0.5250, -0.1896],
         [ 0.3043,  0.9080, -1.1091,  2.1785, -1.7784],
         [-0.7594, -1.4002, -1.6800,  0.4682, -1.3413],
         [-0.8269, -0.0611, -0.1735,  1.0052,  0.2949],
         [ 2.6573, -0.0076, -2.2350,  1.7551,  0.9884],
         [ 2.0689, -0.5384,  0.1624,  1.0281, -0.1691],
         [ 0.1544, -2.2978, -1.8604,  0.0416, -1.7097],
         [ 0.0787, -0.6553,  1.6258,  0.6562,  0.4507],
         [-0.9854, -0.5038,  0.0189,  0.7944, -1.1674],
         [ 0.8928, -1.6178, -1.4007,  2.5214, -0.5550],
         [-1.1880,  1.7042, -0.3272, -2.5008,  0.5494],
         [-0.1548,  0.0899, -2.9778, -2.8621, -0.1637],
         [ 0.9466,  2.2355, -1.7860, -1.9431, -2.9382],
         [-0.5621,  2.3763,  0.3188,  0.3071, -0.7570],
         [-0.4888, -0.6190, -1.2797,  0.5718, -1.2596],
         [-2.7939, -2.2837, -0.1397,  1.283