# FHE Aging

## Setup

In [1]:
import pandas as pd 
import pyaging as pya

In [2]:
pya.data.download_example_data('blood_chemistry_example')

|-----> 🏗️ Starting download_example_data function
|-----------> Data found in pyaging_data/blood_chemistry_example.pkl
|-----> 🎉 Done! [0.0024s]


In [3]:
df = pd.read_pickle('pyaging_data/blood_chemistry_example.pkl')

In [4]:
adata = pya.preprocess.df_to_adata(df)

|-----> 🏗️ Starting df_to_adata function
|-----> ⚙️ Create anndata object started
|-----> ✅ Create anndata object finished [0.0056s]
|-----> ⚙️ Add metadata to anndata started
|-----------? No metadata provided. Leaving adata.obs empty
|-----> ⚠️ Add metadata to anndata finished [0.0007s]
|-----> ⚙️ Log data statistics started
|-----------> There are 30 observations
|-----------> There are 10 features
|-----------> Total missing values: 0
|-----------> Percentage of missing values: 0.00%
|-----> ✅ Log data statistics finished [0.0017s]
|-----> ⚙️ Impute missing values started
|-----------> No missing values found. No imputation necessary
|-----> ✅ Impute missing values finished [0.0018s]
|-----> 🎉 Done! [0.0154s]


In [5]:
import marshal
import math
import ntpath
import os
import types
from typing import Dict, List, Tuple
from urllib.request import urlretrieve

import anndata
import numpy as np
import pandas as pd
import torch
from anndata.experimental.pytorch import AnnLoader
from torch.utils.data import DataLoader, TensorDataset

from pyaging.logger import LoggerManager, main_tqdm, silence_logger
from pyaging.models import *
from pyaging.utils import download, load_clock_metadata, progress
from pyaging.predict._postprocessing import *
from pyaging.predict._preprocessing import *

@progress("Predict ages with model")
def predict_ages_with_model(
    adata: anndata.AnnData,
    model: pyagingModel,
    device: str,
    batch_size: int,
    logger,
    indent_level: int = 2,
) -> torch.Tensor:
    """
    Predict biological ages using a trained model and input data.

    This function takes a machine learning model and input data, and returns predictions made by the model.
    It's primarily used for estimating biological ages based on various biological markers. The function
    assumes that the model is already trained. A dataloader is used because of possible memory constraints
    for large datasets.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object containing the dataset. Its `.X` attribute is expected to be a matrix where rows
        correspond to samples and columns correspond to features.

    model : pyagingModel
        The pyagingModel of the aging clock of interest.

    device : str
        Device to move AnnData to during inference. Eithe 'cpu' or 'cuda'.

    batch_size : int
        Batch size for the AnnLoader object to predict age.

    logger : Logger
        A logger object for logging the progress or any relevant information during the prediction process.

    indent_level : int, optional
        The indentation level for logging messages, by default 2.

    Returns
    -------
    predictions : torch.Tensor
        An array of predicted ages or biological markers, as returned by the model.

    Notes
    -----
    Ensure that the data is preprocessed (e.g., scaled, normalized) as required by the model before
    passing it to this function. The model should be in evaluation mode if it's a type that has different
    behavior during training and inference (e.g., PyTorch models).

    The exact nature of the predictions (e.g., age, biological markers) depends on the model being used.

    Examples
    --------
    >>> model = load_pretrained_model()
    >>> predictions = predict_ages_with_model(model, "cpu", logger)
    >>> print(predictions[:5])
    [34.5, 29.3, 47.8, 50.1, 42.6]

    """

    # If there is a preprocessing step
    if model.preprocess_name is not None:
        logger.info(
            f"The preprocessing method is {model.preprocess_name}",
            indent_level=indent_level + 1,
        )
    else:
        logger.info("There is no preprocessing necessary", indent_level=indent_level + 1)

    # If there is a postprocessing step
    if model.postprocess_name is not None:
        logger.info(
            f"The postprocessing method is {model.postprocess_name}",
            indent_level=indent_level + 1,
        )
    else:
        logger.info("There is no postprocessing necessary", indent_level=indent_level + 1)

    # Create an AnnLoader
    use_cuda = torch.cuda.is_available()
    dataloader = AnnLoader(adata, batch_size=batch_size, use_cuda=use_cuda)

    # with torch.no_grad():
    #     for param in model.parameters():
    #         param.zero_()

    # Use the AnnLoader for batched prediction
    predictions = []
    with torch.inference_mode():
        for batch in main_tqdm(dataloader, indent_level=indent_level + 1, logger=logger):
            batch_pred = model(batch.obsm[f"X_{model.metadata['clock_name']}"])
            predictions.append(batch_pred)
    # Concatenate all batch predictions
    predictions = torch.cat(predictions)

    return predictions

In [6]:
pya.pred.predict_age_fhe(adata, predict_ages_with_model, 'PhenoAge')

|-----> 🏗️ Starting predict_age function
|-----> ⚙️ Set PyTorch device started
|-----------> Using device: cpu
|-----> ✅ Set PyTorch device finished [0.0025s]
|-----> 🕒 Processing clock: phenoage
|-----------> ⚙️ Load clock started
|-----------------> Data found in pyaging_data/phenoage.pt
Layer: base_model.linear.weight | Size: torch.Size([1, 10]) | Values : tensor([[-0.0336,  0.0095,  0.1953,  0.0954, -0.0120,  0.0268,  0.3306,  0.0019,
          0.0554,  0.0804]], dtype=torch.float64, grad_fn=<SliceBackward0>) 

Layer: base_model.linear.bias | Size: torch.Size([1]) | Values : tensor([-19.9067], dtype=torch.float64, grad_fn=<SliceBackward0>) 

|-----------> ✅ Load clock finished [0.0066s]
|-----------> ⚙️ Check features in adata started
|-----------------> All features are present in adata.var_names.
|-----------> ✅ Check features in adata finished [0.0008s]
|-----------> ⚙️ Predict ages with model started
|-----------------> There is no preprocessing necessary
|-----------------> Th

In [7]:
# pya.pred.predict_age(adata, predict_ages_with_model, 'PhenoAge')

In [8]:
adata.obs.head()

Unnamed: 0,phenoage
patient1,74.348798
patient2,67.372
patient3,74.789739
patient4,46.991769
patient5,44.559486


In [9]:
adata

AnnData object with n_obs × n_vars = 30 × 10
    obs: 'phenoage'
    var: 'percent_na'
    uns: 'phenoage_percent_na', 'phenoage_missing_features', 'phenoage_metadata'
    layers: 'X_original'

In [10]:
combined_df = pd.DataFrame({
    'phenoage': adata.obs["phenoage"],
    'chronological_age': df["age"]
})

# Display the first few rows
combined_df[:15]

Unnamed: 0,phenoage,chronological_age
patient1,74.348798,70.2
patient2,67.372,76.5
patient3,74.789739,66.4
patient4,46.991769,46.5
patient5,44.559486,42.3
patient6,72.50946,76.9
patient7,57.37705,55.1
patient8,31.779798,34.6
patient9,50.356509,47.3
patient10,67.696706,52.3


## Manual Torch Execution

In [11]:
class PhenoAge(pyagingModel):
    def __init__(self):
        super().__init__()

    def preprocess(self, x):
        return x

    def postprocess(self, x):
        """
        Applies a convertion from a CDF of the mortality score from a Gompertz
        distribution to phenotypic age.
        """
        # lambda
        l = torch.tensor(0.0192, device=x.device, dtype=x.dtype)
        mortality_score = 1 - torch.exp(-torch.exp(x) * (torch.exp(120 * l) - 1) / l)
        age = 141.50225 + torch.log(-0.00553 * torch.log(1 - mortality_score)) / 0.090165
        return age

In [13]:
device = "cpu"
weights_path = f"./pyaging_data/phenoage.pt"
clock = torch.load(weights_path, weights_only=False)

for name, param in clock.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

# Prepare clock for inference
clock.to(torch.float64)
clock.to(device)
clock.eval()

Layer: base_model.linear.weight | Size: torch.Size([1, 10]) | Values : tensor([[-0.0336,  0.0095,  0.1953,  0.0954, -0.0120,  0.0268,  0.3306,  0.0019,
          0.0554,  0.0804]], dtype=torch.float64, grad_fn=<SliceBackward0>) 

Layer: base_model.linear.bias | Size: torch.Size([1]) | Values : tensor([-19.9067], dtype=torch.float64, grad_fn=<SliceBackward0>) 



PhenoAge(
  (base_model): LinearModel(
    (linear): Linear(in_features=10, out_features=1, bias=True)
  )
)

In [14]:
adata

AnnData object with n_obs × n_vars = 30 × 10
    obs: 'phenoage'
    var: 'percent_na'
    uns: 'phenoage_percent_na', 'phenoage_missing_features', 'phenoage_metadata'
    layers: 'X_original'

In [15]:
clock.metadata['clock_name']

'phenoage'

In [16]:
adata.obs

Unnamed: 0,phenoage
patient1,74.348798
patient2,67.372
patient3,74.789739
patient4,46.991769
patient5,44.559486
patient6,72.50946
patient7,57.37705
patient8,31.779798
patient9,50.356509
patient10,67.696706


In [31]:
adata.obs.to_csv("out.csv")

In [17]:
dataset = df.iloc[:, 0:10]

In [18]:
dataset_np = dataset.to_numpy()
phenoages_np = np.array(adata.obs["phenoage"], dtype=np.float64)

In [19]:
dataset_torch = torch.tensor(dataset_np, dtype=torch.float64)
phenoages_torch = torch.tensor(phenoages_np, dtype=torch.float64)

In [20]:
with torch.inference_mode():
    pred = clock(dataset_torch)

In [21]:
pred

tensor([[74.3488],
        [67.3720],
        [74.7897],
        [46.9918],
        [44.5595],
        [72.5095],
        [57.3771],
        [31.7798],
        [50.3565],
        [67.6967],
        [62.6020],
        [41.7359],
        [82.2387],
        [56.6775],
        [46.4021],
        [63.7108],
        [84.7842],
        [87.1650],
        [90.2054],
        [62.2351],
        [25.2728],
        [55.2115],
        [69.7079],
        [49.1802],
        [45.2600],
        [35.3339],
        [81.8737],
        [64.5594],
        [79.2270],
        [58.7839]], dtype=torch.float64)

## Quantize Linear Regression Models for FHE

In [29]:
import numpy as np
from sklearn.linear_model import ElasticNet

manual_coefficients = np.array([
    -0.0336,  0.0095,  0.1953,  0.0954, -0.0120,  0.0268,  0.3306,  0.0019, 0.0554,  0.0804
])
manual_intercept = np.array([
    -19.9067
])
sklearn_model = ElasticNet(alpha=1.0, l1_ratio=0.5)
sklearn_model.n_features_in_ = len(manual_coefficients)
sklearn_model.coef_ = manual_coefficients
sklearn_model.intercept_ = manual_intercept

raw_pred = sklearn_model.predict(dataset_np)
pred_torch = torch.tensor(raw_pred, dtype=torch.float32)

phenoage_post = PhenoAge()
pred = phenoage_post.postprocess(pred_torch)
pred

tensor([74.3488, 67.3720, 74.7897, 46.9918, 44.5595, 72.5095, 57.3770, 31.7798,
        50.3565, 67.6967, 62.6020, 41.7359, 82.2387, 56.6775, 46.4021, 63.7108,
        84.7842, 87.1649, 90.2054, 62.2351, 25.2728, 55.2115, 69.7079, 49.1802,
        45.2600, 35.3339, 81.8737, 64.5594, 79.2271, 58.7839])

## (Attempt Torch) Quantize Model for FHE

In [35]:
import time
import numpy as np

def test_in_fhe(quantized_numpy_module, X_test, y_test, simulate=True):
    if not simulate:
        print("Generating key")
        start_key = time.time()
        quantized_numpy_module.fhe_circuit.keygen()
        end_key = time.time()
        print(f"Key generation finished in {end_key - start_key:.2f} seconds")

    fhe_mode = "simulate" if simulate else "execute"

    start_infer = time.time()
    predictions = quantized_numpy_module.forward(X_test, fhe=fhe_mode).argmax(1)
    end_infer = time.time()

    if not simulate:
        print(
            f"Inferences finished in {end_infer - start_infer:.2f} seconds "
            f"({(end_infer - start_infer)/len(X_test):.2f} seconds/sample)"
        )

    # Compute accuracy
    accuracy = np.mean(predictions == y_test) * 100
    print(
        "FHE " + ("(simulation) " * simulate) + f"accuracy: {accuracy:.2f}% on "
        f"{len(X_test)} examples."
    )
    return predictions

In [36]:
os.environ["SCIPY_ARRAY_API"] = "1"

from concrete.ml.torch.compile import compile_torch_model

# We need to unprune the model before compiling
# clock.unprune()

dataset_np_f32 = dataset_np.astype(np.float32)

# print("dataset_np_f32.dtype:", dataset_np_f32.dtype)

# dataset_f32 = dataset_torch.to(torch.float32)

clock = clock.to(torch.float32)

# quantized_numpy_module = compile_brevitas_qat_model(clock, dataset_torch)
quantized_module = compile_torch_model(
    clock, # our model
    dataset_np_f32, # a representative input-set to be used for both quantization and compilation
    n_bits=8,
    rounding_threshold_bits={"n_bits": 8, "method": "approximate"}
)

# prediction_simulated = test_in_fhe(quantized_numpy_module, dataset_torch, phenoages_torch, simulate=True)

  l = torch.tensor(0.0192, device=x.device, dtype=x.dtype)


In [37]:
dataset_np_f32.shape

(30, 10)

In [42]:
y_pred = quantized_module.forward(dataset_np_f32, fhe="simulate", debug=True)

RuntimeError: Values can only be constructed from arrays of signed and unsigned integers.

In [39]:
y_pred

array([[74.04157903],
       [66.08271222],
       [74.04157903],
       [45.82377855],
       [45.82377855],
       [70.18273452],
       [57.88266764],
       [29.66486717],
       [49.92380084],
       [66.08271222],
       [61.98268993],
       [41.72375626],
       [82.24162361],
       [53.78264535],
       [45.82377855],
       [66.08271222],
       [86.3416459 ],
       [86.3416459 ],
       [90.20049041],
       [61.98268993],
       [28.70015604],
       [53.78264535],
       [70.18273452],
       [49.92380084],
       [45.82377855],
       [33.52371168],
       [82.24162361],
       [61.98268993],
       [78.14160132],
       [57.88266764]])