In [1]:
import os
import shutil
import subprocess

import torch
import pandas as pd
import pyaging as pya
import tensorflow as tf
from tensorflow.keras.models import load_model

In [2]:
# Fetch the upstream AltumAge repository; it ships the trained Keras model,
# the CpG feature list and the fitted scaler loaded in the cells below.
# TODO: pin a specific commit/tag so the conversion is reproducible.
ALTUMAGE_REPO_URL = "https://github.com/rsinghlab/AltumAge"
if not os.path.isdir("AltumAge"):
    # check=True raises on a failed clone instead of silently discarding the
    # exit status (as os.system did); the isdir guard keeps re-runs safe.
    subprocess.run(["git", "clone", ALTUMAGE_REPO_URL], check=True)

0

In [3]:
# Pretrained Keras AltumAge network from the cloned repository.
AltumAge = load_model('AltumAge/example_dependencies/AltumAge.h5')

# Keras layer name -> list of weight arrays, as returned by get_weights().
weights = {layer.name: layer.get_weights() for layer in AltumAge.layers}

# PyTorch re-implementation that will receive the copied weights.
model = pya.models.AltumAge()

def copy_weights(torch_layer, tf_weights, bn=False):
    """Copy one Keras layer's weights into the matching PyTorch layer, in place.

    Args:
        torch_layer: Target PyTorch module. With ``bn=True`` it must expose
            ``weight``, ``bias``, ``running_mean`` and ``running_var``
            (e.g. ``nn.BatchNorm1d``); otherwise ``weight`` and ``bias``
            (e.g. ``nn.Linear``).
        tf_weights: Sequence of arrays from Keras ``layer.get_weights()``.
            Expected layout: ``[gamma, beta, moving_mean, moving_variance]``
            for batch norm, ``[kernel, bias]`` for dense.
        bn: Whether ``tf_weights`` come from a BatchNormalization layer.

    Raises:
        ValueError: if any source array's shape differs from its target
            tensor. Nothing is copied in that case. (The previous ``.data =``
            assignment silently replaced the parameter with a wrong-shaped
            tensor.)
    """
    if bn:
        pairs = [
            (torch_layer.weight, tf_weights[0]),        # gamma
            (torch_layer.bias, tf_weights[1]),          # beta
            (torch_layer.running_mean, tf_weights[2]),  # moving_mean
            (torch_layer.running_var, tf_weights[3]),   # moving_variance
        ]
    else:
        pairs = [
            # Keras Dense kernels are (in, out); nn.Linear.weight is (out, in).
            (torch_layer.weight, torch.as_tensor(tf_weights[0]).T),
            (torch_layer.bias, tf_weights[1]),
        ]

    with torch.no_grad():
        # Validate every shape first so a mismatch leaves the layer untouched.
        checked = []
        for target, source in pairs:
            source = torch.as_tensor(source)
            if source.shape != target.shape:
                raise ValueError(
                    f"shape mismatch copying into {type(torch_layer).__name__}: "
                    f"expected {tuple(target.shape)}, got {tuple(source.shape)}"
                )
            checked.append((target, source))
        # copy_ writes into the existing tensors (casting dtype as needed),
        # so parameter/buffer identity, dtype and device are preserved.
        for target, source in checked:
            target.copy_(source)

# Copy weights block by block. The Keras model auto-numbers its layers
# (batch_normalization_84..89, dense_84..89), while the PyTorch model exposes
# them as bn1..bn6 / linear1..linear6; block i of one maps to block i of the other.
FIRST_KERAS_LAYER_ID = 84  # numeric suffix of the first BN/Dense layer in AltumAge.h5
N_BLOCKS = 6               # number of BatchNorm -> Dense pairs

copied_keras_layers = set()
for block in range(1, N_BLOCKS + 1):
    keras_id = FIRST_KERAS_LAYER_ID + block - 1
    bn_name, dense_name = f'batch_normalization_{keras_id}', f'dense_{keras_id}'
    copy_weights(getattr(model, f'bn{block}'), weights[bn_name], bn=True)
    copy_weights(getattr(model, f'linear{block}'), weights[dense_name])
    copied_keras_layers.update((bn_name, dense_name))

# Guard against silently dropping parameters: every Keras layer that carries
# weights must have been copied into the PyTorch model.
uncopied = sorted(
    name for name, layer_weights in weights.items()
    if layer_weights and name not in copied_keras_layers
)
assert not uncopied, f"Keras layers with weights not copied to PyTorch: {uncopied}"

In [4]:
# CpG identifiers the clock consumes (presumably in model input order — confirm).
# NOTE(security): pd.read_pickle can execute arbitrary code; only load this
# file from the trusted upstream AltumAge repository.
cpg_index = pd.read_pickle('AltumAge/example_dependencies/multi_platform_cpgs.pkl')
features = cpg_index.tolist()

In [5]:
# Fitted input scaler shipped with AltumAge (it exposes `center_`; presumably an
# sklearn RobustScaler — confirm).
# NOTE(security): unpickling runs arbitrary code; trust the source of this file.
scaler_path = 'AltumAge/example_dependencies/scaler.pkl'
scaler = pd.read_pickle(scaler_path)

In [17]:
# Per-feature centre values of the fitted scaler, stored as the clock's
# reference feature values. NOTE(review): presumably pyaging uses these to fill
# in missing CpGs — confirm against pyaging's preprocessing code.
reference_feature_values = list(scaler.center_)

# The scaler must have been fit on exactly the model's feature list; otherwise
# the reference values would be misaligned with `features`.
assert len(reference_feature_values) == len(features), (
    f"scaler has {len(reference_feature_values)} centre values "
    f"but there are {len(features)} features"
)

In [18]:
# Payload written to the clock's weights file by the torch.save cell.
weights_dict = dict(
    reference_feature_values=reference_feature_values,
    preprocessing='scale',
    preprocessing_helper=scaler,
    postprocessing=None,
    postprocessing_helper=None,
    features=features,
    weight_dict=model.state_dict(),
    model_class='AltumAge',
)

altumage_citation = (
    'de Lima Camillo, Lucas Paulo, Louis R. Lapierre, and Ritambhara Singh. '
    '"A pan-tissue DNA-methylation epigenetic clock based on deep learning." '
    'npj Aging 8.1 (2022): 4.'
)

# Descriptive metadata; processing fields are mirrored from weights_dict so the
# two saved files cannot disagree.
metadata_dict = {
    'species': 'Homo sapiens',
    'data_type': 'methylation',
    'year': 2022,
    'implementation_approved_by_author(s)': '✅',
    'preprocessing': weights_dict['preprocessing'],
    'postprocessing': weights_dict['postprocessing'],
    'citation': altumage_citation,
    'doi': 'https://doi.org/10.1038/s41514-022-00085-y',
    'notes': None,
}

In [19]:
# Persist the converted clock; paths are relative to the notebook's working directory.
# NOTE(review): weights_dict holds a non-tensor object (the scaler), so newer
# torch.load defaults (weights_only=True) may refuse it — confirm how pyaging loads it.
for payload, destination in (
    (weights_dict, '../weights/altumage.pt'),
    (metadata_dict, '../metadata/altumage.pt'),
):
    torch.save(payload, destination)

In [20]:
# Remove the cloned repository now that the weights and metadata are saved.
# shutil.rmtree is portable (no dependency on a POSIX `rm`) and raises on
# failure instead of returning an ignored exit status; the isdir guard keeps
# the cell safe to re-run once the directory is gone.
if os.path.isdir("AltumAge"):
    shutil.rmtree("AltumAge")

0