In [2]:
%load_ext autoreload
%autoreload 2

import torch
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.model.ppiformer import PPIformer
from ppiformer.utils.api import download_from_zenodo, predict_ddg, embed
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR, PPIFORMER_TEST_DATA_DIR

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Download and extract the model weights archive from Zenodo (~535 MB on the last run).
# Skip the download when every checkpoint loaded later in this notebook is already on
# disk, so Restart & Run All does not re-fetch the archive each time.
_required_ckpts = [PPIFORMER_WEIGHTS_DIR / 'masked_modeling.ckpt'] + [
    PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt' for i in range(3)
]
if not all(ckpt.exists() for ckpt in _required_ckpts):
    download_from_zenodo('weights.zip')

Downloading to /Users/anton/dev/PPIformer/weights: 100%|██████████| 535M/535M [00:30<00:00, 17.5MiB/s] 
Extracting: 100%|██████████| 5/5 [00:03<00:00,  1.54files/s]


In [5]:
# Load the ensemble of fine-tuned ddG-regression models and predict ddG for a few mutations
device = 'cuda' if torch.cuda.is_available() else 'cpu'
N_ENSEMBLE = 3  # number of fine-tuned checkpoints: ddg_regression/0.ckpt ... ddg_regression/2.ckpt
models = [
    DDGPPIformer.load_from_checkpoint(
        PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
        map_location=torch.device('cpu'),  # deserialize on CPU, then move to `device`
    ).eval().to(device)
    for i in range(N_ENSEMBLE)
]

# Specify input
ppi_path = PPIFORMER_TEST_DATA_DIR / '1bui_A_C.pdb'  # PDB or PPIRef file (see https://ppiref.readthedocs.io/en/latest/extracting_ppis.html)
# Single- or multi-point mutations; a multi-point mutation joins its substitutions with commas.
# NOTE(review): entries look like <wild-type aa><chain><position><mutant aa>, e.g. 'SC16A' — confirm
muts = ['SC16A', 'FC47A', 'SC16A,FC47A']

# Predict: a tensor with one ddG value per entry of `muts`
ddg = predict_ddg(models, ppi_path, muts)
ddg

Process 29715 preparing data: 100%|██████████| 1/1 [00:00<00:00,  6.79it/s]


1 PPIs loaded: PPIInMemoryDataset(, n_muts=3)


tensor([-0.3708,  1.5188,  1.1482])

In [15]:
# Embedding example: load the pre-trained (masked-modeling) PPIformer in inference mode
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = (
    PPIformer
    .load_from_checkpoint(PPIFORMER_WEIGHTS_DIR / 'masked_modeling.ckpt', map_location=torch.device('cpu'))
    .to(device)  # checkpoint is read on CPU, then moved to the selected device
    .eval()
)

# Input complex: a PDB file or a PPIRef-extracted PPI
# (https://ppiref.readthedocs.io/en/latest/extracting_ppis.html)
ppi_path = PPIFORMER_TEST_DATA_DIR / '1bui_A_C.pdb'

# Final type-0 features of the complex.
# NOTE(review): shape presumably (n_nodes, hidden_dim) — confirm against `embed`
h = embed(model, ppi_path)
h.shape


  rank_zero_warn(
  rank_zero_warn(
Process 29715 preparing data: 100%|██████████| 1/1 [00:00<00:00,  7.44it/s]


1 PPIs loaded: PPIInMemoryDataset()


torch.Size([124, 128])