In [5]:
%load_ext autoreload 
%autoreload 2

from pathlib import Path
import torchdrug
from torchdrug import datasets, core, tasks, transforms
import torch

from plaid.datasets import CATHShardedDataModule
from plaid.compression.hourglass_vq import HourglassVQLightningModule
from torchdrug import layers
from plaid.esmfold.misc import batch_encode_sequences
from plaid.esmfold import esmfold_v1

# Root directory that TorchDrug downloads and extracts its benchmark datasets into.
TORCH_DRUG_DATASETS_PATH = Path("/homefs/home/lux70/storage/data/torchdrug")

# Prefer the GPU, but fall back to the CPU so that every later `.to(device)`
# call still works on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
# Checkpoints are laid out as <ckpt_dir>/<run id>/last.ckpt.
ckpt_dir = Path("/homefs/home/lux70/storage/plaid/checkpoints/hourglass_vq")
compression_model_id = "kyytc8i9"
ckpt_path = ckpt_dir.joinpath(compression_model_id, "last.ckpt")

# Restore the hourglass compression model from the run's last checkpoint.
hourglass = HourglassVQLightningModule.load_from_checkpoint(ckpt_path)

using quantizer tanh


In [23]:
class ModelWrapper(torch.nn.Module):
    """Frozen ESMFold + hourglass compressor with a small trainable prediction head.

    Sequences are embedded with ESMFold, compressed by a pretrained
    ``HourglassVQLightningModule`` checkpoint, mean-pooled over dim 1, and fed
    through an MLP followed by a linear output layer. Gradients are disabled
    for ESMFold and the hourglass model; only the MLP and the output layer
    remain trainable.

    Args:
        compression_model_id: Name of the sub-directory of ``ckpt_dir`` that
            contains ``last.ckpt``.
        ckpt_dir: Directory holding one sub-directory per compression run
            (``str`` or ``Path``).
        esmfold: Pre-built ESMFold model to reuse. When ``None``,
            ``esmfold_v1()`` is called to create one.
        num_mlp_layers: The MLP is built with ``num_mlp_layers - 1`` hidden
            dimensions, each of size ``output_dim``.
        num_classes: Width of the output layer. ``None`` yields a single
            output unit (regression).
    """

    def __init__(
        self,
        compression_model_id = "kyytc8i9",
        ckpt_dir = Path("/homefs/home/lux70/storage/plaid/checkpoints/hourglass_vq"),
        esmfold=None,
        num_mlp_layers=2,
        num_classes=None,
    ):
        super().__init__()
        self.compression_model_id = compression_model_id
        self.ckpt_dir = ckpt_dir

        if esmfold is None:
            esmfold = esmfold_v1()

        # Path(...) lets callers pass the checkpoint directory as a plain string too.
        hourglass = HourglassVQLightningModule.load_from_checkpoint(
            Path(ckpt_dir) / compression_model_id / "last.ckpt"
        )
        # NOTE(review): 1024 is presumably the width of the ESMFold "s"
        # embedding before down-projection -- confirm against the encoder.
        output_dim = 1024 // hourglass.enc.downproj_factor

        mlp = layers.MLP(
            output_dim,
            [output_dim] * (num_mlp_layers - 1),
            batch_norm=False,
        )

        # Frozen feature extractors.
        self.esmfold = esmfold.requires_grad_(False)
        self.hourglass = hourglass.requires_grad_(False)
        self.output_dim = output_dim
        # Trainable head. `device` here is the notebook-level global, not the
        # property defined below.
        self.mlp = mlp.to(device).requires_grad_(True)

        # One output unit for regression, `num_classes` units for classification.
        out_features = 1 if num_classes is None else num_classes
        self.out_layer = torch.nn.Linear(output_dim, out_features)

    @property
    def device(self):
        """Device of the module's first parameter.

        ``torch.nn.Module`` does not define ``.device``; without this property
        ``forward`` fails with ``AttributeError`` on ``self.device``.
        """
        return next(self.parameters()).device

    def forward(self, sequences):
        """Predict from raw sequences.

        Args:
            sequences: Passed unchanged to ``esmfold.infer_embedding``.

        Returns:
            A ``(prediction, compressed)`` tuple: the output-layer result on
            the mean-pooled compressed embedding, and the compressed embedding
            itself (computed under ``torch.no_grad()``).
        """
        with torch.no_grad():
            latent = self.esmfold.infer_embedding(sequences)["s"].to(self.device)
            compressed = self.hourglass(latent, mask=None, infer_only=True)  # (N, L, C)
            # The original code treated this as a NumPy array; as_tensor also
            # tolerates a tensor return without an extra copy.
            compressed = torch.as_tensor(compressed).to(self.device)  # no grad passed
        pooled = compressed.mean(dim=1)
        return self.out_layer(self.mlp(pooled)), compressed

# Create ESMFold only when this kernel does not already hold one, so the cell
# works on a fresh kernel (previously `esmfold` was undefined here) while
# re-running it reuses the existing instance.
if "esmfold" not in globals():
    esmfold = esmfold_v1()
model = ModelWrapper(esmfold=esmfold)

using quantizer tanh


# Beta Lactamase

In [6]:
# Beta-lactamase dataset, loaded lazily, with proteins truncated
# (non-randomly) to at most 350 residues.
truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
dataset = datasets.BetaLactamase(
    TORCH_DRUG_DATASETS_PATH,
    transform=truncate_transform,
    lazy=True,
)
(train_set, valid_set, test_set) = dataset.split()

19:56:10   Extracting /homefs/home/lux70/storage/data/torchdrug/beta_lactamase.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5198/5198 [00:00<00:00, 12790.06it/s]


In [24]:
# Move the wrapper to the configured device, then smoke-test it on the first
# ten sequences of the dataset.
model = model.to(device)
first_ten_sequences = dataset.sequences[:10]
model(first_ten_sequences)

tensor([[-0.0998],
        [-0.1001],
        [-0.0999],
        [-0.1002],
        [-0.0999],
        [-0.1000],
        [-0.0998],
        [-0.0999],
        [-0.0998],
        [-0.1000]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [30]:

from torchdrug import datasets, transforms, tasks, core
import torch

# Residue-level pipeline: truncate (non-randomly) to at most 200 residues,
# then switch the protein to its residue view.
truncate_transform = transforms.TruncateProtein(max_length=200, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

# Use the shared dataset root from the config cell instead of repeating the
# location as a separate "~/..." literal.
dataset = datasets.BetaLactamase(
    TORCH_DRUG_DATASETS_PATH,
    atom_feature=None,
    bond_feature=None,
    residue_feature="default",
    transform=transform,
)
train_set, valid_set, test_set = dataset.split()
# TODO: wrap the model in a TorchDrug task; kept for reference.
# task = tasks.PropertyPrediction(model, task=dataset.tasks,
#                                 criterion="mse", metric=("mae", "rmse", "spearmanr"),
#                                 normalization=False, num_mlp_layer=2)

00:59:45   Extracting /homefs/home/lux70/storage/data/torchdrug/beta_lactamase.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|████████████████████████████████████████████████████████████████████| 5198/5198 [00:08<00:00, 645.76it/s]


In [40]:
# Peek at one training example: a dict with the protein under 'graph' and a
# scalar label under 'scaled_effect1'.
train_set[0]

{'graph': Protein(num_atom=0, num_bond=0, num_residue=200),
 'scaled_effect1': 0.9426838159561157}

In [42]:
# NOTE: despite its name, `batch` is a single dataset item, not a collated batch.
batch = train_set[0]
# Integer residue-type ids of this protein, one per residue.
batch['graph'].residue_type

tensor([14,  2,  7, 11, 15, 16, 17,  4,  1,  8,  7,  3, 16, 16,  1,  1, 16,  6,
         8,  3,  4, 16,  1, 15,  3, 13,  5,  8,  4, 12,  4, 12, 10,  1, 13, 10,
        11,  8,  0,  1, 17,  4,  0, 18,  7, 13,  8, 10,  8,  9,  2,  0, 12,  7,
         8, 13,  2, 16, 17,  3, 13, 13, 17, 16,  3, 14, 14,  2,  5, 16, 12,  4,
         8,  8,  6,  0,  1,  4,  8,  2, 17,  4, 10,  1,  0, 11, 13, 11,  8,  0,
        17, 17,  7, 15, 18,  2, 11,  9, 10,  8,  4, 13, 18,  2,  3,  4,  5, 13,
        12, 15,  8,  5, 10,  0, 14,  5,  4, 17, 13,  8,  6,  2,  1,  1,  7,  5,
        14,  2, 10,  9,  5,  1,  1,  9,  8,  7,  8,  5,  5,  7,  0,  0,  3, 12,
        13,  8,  5,  1, 16,  8, 15,  9, 14,  0, 10, 15,  4,  5, 17,  8, 10, 17,
        19, 13,  3, 13,  8,  9, 13,  1,  7,  3,  9, 10, 13, 17, 10,  5,  5, 14,
         3,  1,  1, 14,  1,  5,  5,  8, 17, 12,  8,  8,  5,  0, 13,  8,  8,  5,
         8,  1])

In [33]:
# List the attribute names stored on the Protein object.
batch['graph'].__dict__.keys()

dict_keys(['_meta_contexts', 'meta_dict', '_edge_list', '_edge_weight', 'num_atom', 'num_bond', 'num_relation', 'atom_type', 'formal_charge', 'explicit_hs', 'chiral_tag', 'radical_electrons', 'atom_map', 'bond_type', 'bond_stereo', 'stereo_atoms', 'num_residue', 'view', 'atom_name', 'atom2residue', 'is_hetero_atom', 'occupancy', 'b_factor', 'residue_type', 'residue_feature', 'residue_number', 'insertion_code', 'chain_id'])

In [39]:
# Largest residue-type id present in this example.
batch['graph'].residue_type.max()

tensor(19)

# Other TorchDrug benchmark datasets

In [85]:
# Fluorescence dataset, loaded lazily with the shared `transform`.
fluorescence_ds = datasets.Fluorescence(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

19:39:21   Extracting /homefs/home/lux70/storage/data/torchdrug/fluorescence.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54025/54025 [00:00<00:00, 253494.00it/s]


In [86]:
# Solubility dataset, loaded lazily with the shared `transform`.
# (The variable name keeps its original spelling so later references still resolve.)
soluability_ds = datasets.Solubility(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

19:39:27   Downloading https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/solubility.tar.gz to /homefs/home/lux70/storage/data/torchdrug/solubility.tar.gz
19:39:31   Extracting /homefs/home/lux70/storage/data/torchdrug/solubility.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 71419/71419 [00:00<00:00, 153910.69it/s]


In [87]:
# Binary localization dataset, loaded lazily with the shared `transform`.
binary_loc_ds = datasets.BinaryLocalization(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

19:39:41   Downloading https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/subcellular_localization_2.tar.gz to /homefs/home/lux70/storage/data/torchdrug/subcellular_localization_2.tar.gz
19:39:42   Extracting /homefs/home/lux70/storage/data/torchdrug/subcellular_localization_2.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8662/8662 [00:00<00:00, 8802.92it/s]


In [88]:
# Subcellular localization dataset, loaded lazily with the shared `transform`.
sub_loc_ds = datasets.SubcellularLocalization(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

19:39:47   Downloading https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/peerdata/subcellular_localization.tar.gz to /homefs/home/lux70/storage/data/torchdrug/subcellular_localization.tar.gz
19:39:48   Extracting /homefs/home/lux70/storage/data/torchdrug/subcellular_localization.tar.gz to /homefs/home/lux70/storage/data/torchdrug


Constructing proteins from sequences: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14004/14004 [00:00<00:00, 31486.60it/s]


In [89]:
# Enzyme Commission dataset, loaded lazily with the shared `transform`.
ec_ds = datasets.EnzymeCommission(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

19:39:50   Downloading https://zenodo.org/record/6622158/files/EnzymeCommission.zip to /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission.zip
19:40:48   Extracting /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission.zip to /homefs/home/lux70/storage/data/torchdrug
19:40:51   Extracting /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission/train.zip to /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission
19:41:54   Extracting /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission/valid.zip to /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission
19:42:01   Extracting /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission/test.zip to /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission


Constructing proteins from pdbs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19198/19198 [00:00<00:00, 134829.13it/s]
Dumping to /homefs/home/lux70/storage/data/torchdrug/EnzymeCommission/enzyme_commission.pkl.gz: 100%|███████████████████████████████████████████████████| 19198/19198 [00:00<00:00, 72040.30it/s]


In [22]:
# Gene Ontology dataset, loaded lazily with the shared `transform`
# (which must already be defined by an earlier cell).
go_ds = datasets.GeneOntology(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

NameError: name 'transform' is not defined

In [None]:
# Fold dataset, loaded lazily with the shared `transform`.
fold_ds = datasets.Fold(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

In [None]:
# Secondary structure dataset, loaded lazily with the shared `transform`.
ss_ds = datasets.SecondaryStructure(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

In [None]:
# ProteinNet dataset, loaded lazily with the shared `transform`.
pn_ds = datasets.ProteinNet(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

In [None]:
# Human PPI dataset, loaded lazily with the shared `transform`.
# NOTE(review): protein-pair datasets presumably yield 'graph1'/'graph2' items,
# while `transform` was built with the default keys -- confirm before indexing.
ppi_ds = datasets.HumanPPI(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

In [None]:
# Yeast PPI dataset, loaded lazily with the shared `transform`.
# NOTE(review): protein-pair datasets presumably yield 'graph1'/'graph2' items,
# while `transform` was built with the default keys -- confirm before indexing.
yeast_ds = datasets.YeastPPI(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)

In [None]:
# PPI affinity dataset, loaded lazily with the shared `transform`.
# The TorchDrug class is `PPIAffinity`; `datasets.PIAffinity` does not exist
# and raised AttributeError.
# NOTE(review): protein-pair datasets presumably yield 'graph1'/'graph2' items,
# while `transform` was built with the default keys -- confirm before indexing.
ppi_aff_ds = datasets.PPIAffinity(
    TORCH_DRUG_DATASETS_PATH,
    transform=transform,
    lazy=True,
)