In [27]:
import lmdb
import pickle
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch

class PDBBindDataset(Dataset):
    """Map-style dataset serving protein-ligand records read from an LMDB file.

    Every LMDB value is unpickled and iterated with ``.items()`` as
    ``pdb_id -> record`` pairs. The pairs of all LMDB keys are merged into
    ``self.data_dict`` at construction time, so items are served from memory
    afterwards. A ``pdb_id`` that occurs under several LMDB keys keeps the
    record that was read last.

    Each record must provide the keys ``'protein_graph'``, ``'ligand_graph'``
    and ``'kd_value'``.
    """

    def __init__(self, lmdb_path="./../Testset/pdbbind.lmdb"):
        """Open ``lmdb_path`` read-only and load every record into memory.

        Args:
            lmdb_path: Location of the LMDB environment to read.
        """
        # Define the handles first so close()/__del__ stay safe when
        # lmdb.open() or env.begin() raises half-way through construction.
        self.env = None
        self.txn = None
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
        self.txn = self.env.begin()
        self.data_dict = self._load_data()
        # Fixed index -> pdb_id order, built once. Rebuilding the key list on
        # every __getitem__ call made each lookup O(n).
        # data_dict is expected to stay unchanged after construction.
        self.pdb_ids = list(self.data_dict)

    def _load_data(self):
        """Return one dict merging the ``pdb_id -> record`` pairs of all LMDB values."""
        data_dict = {}
        # NOTE(security): pickle.loads can execute arbitrary code from a
        # crafted payload -- only point this class at trusted LMDB files.
        for _key, value in self.txn.cursor():
            cluster_data = pickle.loads(value)
            data_dict.update(cluster_data.items())
        return data_dict

    def __len__(self):
        """Number of distinct pdb ids that were loaded."""
        return len(self.pdb_ids)

    def __getitem__(self, idx):
        """Return ``(protein_graph, ligand_graph, kd_value, pdb_id)`` for position ``idx``.

        ``kd_value`` is wrapped in a float tensor of shape ``[1]``; the two
        graph objects are returned exactly as stored in the record.

        Raises:
            IndexError: if ``idx`` is outside the range of loaded ids.
        """
        pdb_id = self.pdb_ids[idx]
        item = self.data_dict[pdb_id]

        protein_graph = item['protein_graph']
        ligand_graph = item['ligand_graph']
        kd_value = torch.tensor([item['kd_value']], dtype=torch.float)

        return protein_graph, ligand_graph, kd_value, pdb_id

    def close(self):
        """Release the LMDB transaction and environment; safe to call repeatedly."""
        txn = getattr(self, "txn", None)
        if txn is not None:
            # End the read transaction explicitly before closing the environment.
            txn.abort()
            self.txn = None
        env = getattr(self, "env", None)
        if env is not None:
            env.close()
            self.env = None

    def __del__(self):
        # Delegates to close(), which tolerates partially constructed
        # instances (the original raised AttributeError on a missing env).
        self.close()

def collate_fn(batch):
    """Merge a list of dataset samples into one mini-batch.

    Args:
        batch: Sequence of ``(protein_graph, ligand_graph, kd_value, pdb_id)``
            tuples as returned by ``PDBBindDataset.__getitem__``.

    Returns:
        A 4-tuple of: the protein graphs combined by ``Batch.from_data_list``,
        the ligand graphs combined the same way, the ``kd_value`` tensors
        concatenated along dim 0, and the tuple of pdb ids in sample order.
    """
    # Transpose "list of samples" into one tuple per field.
    proteins, ligands, affinities, ids = zip(*batch)

    return (
        Batch.from_data_list(proteins),
        Batch.from_data_list(ligands),
        torch.cat(affinities),
        ids,
    )

# usage
# lmdb_path = "./../Testset/pdbbind.lmdb"
# The dataset is built with its default LMDB path; samples are drawn in
# shuffled order, four at a time, and merged by collate_fn.
dataset = PDBBindDataset()
dataloader = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=4,
)

In [28]:
# example: each batch is a tuple (protein_graph, molecule_graph, kd_value, pdb_id)
# batch: for every node, the index of the graph that node belongs to
# ptr: has batch_size+1 entries (see shapes below); presumably the cumulative
#      node offsets of the graphs in the batch -- TODO confirm
'''
protein_graph:DataBatch(
    x=[batch_num_nodes,in_features],
    edge_index=[2,num_edges],
    pos=[batch_num_nodes,3],
    batch=[batch_num_nodes],
    ptr=[batch_size+1])
molecule_graph:DataBatch(
    x=[batch_num_nodes,in_features],
    edge_index=[2,num_edges],
    edge_attr=[num_edges,num_edge_features],
    pos=[batch_num_nodes,3],
    batch=[batch_num_nodes],
    ptr=[batch_size+1])
kd_value:tensor([batch_size])
pdb_id:[batch_size]
'''
# Fetch only the first batch. list(dataloader)[0] would iterate and collate
# every batch of the dataset just to keep one of them.
example = next(iter(dataloader))

In [29]:
# Display the batch fetched above: (protein DataBatch, ligand DataBatch,
# kd_value tensor, tuple of pdb ids).
example

(DataBatch(x=[1427, 1284], edge_index=[2, 13640], pos=[1427, 3], batch=[1427], ptr=[5]),
 DataBatch(x=[319, 9], edge_index=[2, 656], edge_attr=[656, 3], pos=[319, 3], batch=[319], ptr=[5]),
 tensor([4.3000, 3.3000, 6.6300, 4.4000]),
 ('5ivz', '5mt0', '4xrq', '3itu'))