In [1]:
import os
import random
import lmdb
import pickle
import torch
from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Episodic dataset for meta-learning over per-cluster LMDB stores.

    Each item is one task (one cluster): a support set of ``k_shot`` samples
    and a disjoint query set of ``k_query`` samples, drawn at random from the
    cluster's LMDB on every access.
    """

    def __init__(self, cluster_data_dir, k_shot, k_query):
        """
        Initialize MetaDataset, for meta-learning.

        Args:
            cluster_data_dir (str): directory containing one LMDB per cluster,
                named ``cluster_{cluster_id}.lmdb``. Each entry may be an LMDB
                directory or a single-file LMDB.
            k_shot (int): number of samples in the support set.
            k_query (int): number of samples in the query set.
        """
        self.cluster_data_dir = cluster_data_dir
        self.k_shot = k_shot
        self.k_query = k_query

        # All *.lmdb paths, sorted so task indices are reproducible.
        self.lmdb_paths = self._get_all_lmdb_paths()

        # cluster_id -> LMDB path.
        self.cluster_id_to_lmdb_path = self._map_cluster_ids_to_lmdb_path()

        # Task index order, plus the cached PDB-ID list of each cluster.
        self.cluster_ids = list(self.cluster_id_to_lmdb_path.keys())
        self.cluster_id_to_pdb_ids = self._get_cluster_pdb_ids()

    @staticmethod
    def _open_env(lmdb_path):
        """Open ``lmdb_path`` as a read-only, lock-free LMDB environment.

        ``subdir`` must match the on-disk layout. An LMDB directory holds
        ``data.mdb`` inside it, while a single-file LMDB *is* the path. With
        the default ``subdir=True``, a file path would be treated as a
        directory.
        """
        return lmdb.open(
            lmdb_path,
            subdir=os.path.isdir(lmdb_path),
            readonly=True,
            lock=False,
        )

    def _get_all_lmdb_paths(self):
        """Return the sorted paths of all ``*.lmdb`` entries in ``cluster_data_dir``.

        Directories and regular files both qualify. Sorting makes the resulting
        cluster/task order independent of ``os.listdir`` ordering.
        """
        lmdb_files = []
        for fname in os.listdir(self.cluster_data_dir):
            path = os.path.join(self.cluster_data_dir, fname)
            if fname.endswith('.lmdb') and (os.path.isdir(path) or os.path.isfile(path)):
                lmdb_files.append(path)
        return sorted(lmdb_files)

    def _map_cluster_ids_to_lmdb_path(self):
        """Map each cluster_id to its LMDB path.

        Entries whose names do not follow ``cluster_{cluster_id}.lmdb``
        are ignored.
        """
        prefix, suffix = 'cluster_', '.lmdb'
        cluster_id_to_lmdb_path = {}
        for lmdb_path in self.lmdb_paths:
            basename = os.path.basename(lmdb_path)
            # Naming convention assumed: cluster_{cluster_id}.lmdb
            if basename.startswith(prefix) and basename.endswith(suffix):
                cluster_id = basename[len(prefix):-len(suffix)]
                cluster_id_to_lmdb_path[cluster_id] = lmdb_path
        return cluster_id_to_lmdb_path

    def _get_cluster_pdb_ids(self):
        """Return ``{cluster_id: [pdb_id, ...]}`` built from each LMDB's keys."""
        cluster_id_to_pdb_ids = {}
        for cluster_id, lmdb_path in self.cluster_id_to_lmdb_path.items():
            env = self._open_env(lmdb_path)
            try:
                with env.begin() as txn:
                    # Iterate keys only; values are not needed just to list IDs.
                    keys = txn.cursor().iternext(keys=True, values=False)
                    pdb_ids = [key.decode() for key in keys]
            finally:
                env.close()
            cluster_id_to_pdb_ids[cluster_id] = pdb_ids
        return cluster_id_to_pdb_ids

    def __len__(self):
        """Return the number of clusters (tasks)."""
        return len(self.cluster_ids)

    @staticmethod
    def _load_samples(txn, pdb_ids):
        """Load ``(protein_graph, ligand_graph, kd_value, pdb_id)`` tuples.

        Args:
            txn: an open read transaction (anything with ``.get(bytes)``).
            pdb_ids (list[str]): keys to fetch, in order.

        Raises:
            KeyError: if a PDB ID is absent from the LMDB.
        """
        samples = []
        for pdb_id in pdb_ids:
            data = txn.get(pdb_id.encode())
            if data is None:
                raise KeyError(f"PDB ID {pdb_id} not found in LMDB.")
            # NOTE(security): pickle.loads can execute arbitrary code; only use
            # this dataset with LMDB files from a trusted source.
            item = pickle.loads(data)
            protein_graph = item['protein_graph']
            ligand_graph = item['ligand_graph']
            kd_value = torch.tensor([item['kd_value']], dtype=torch.float)
            samples.append((protein_graph, ligand_graph, kd_value, pdb_id))
        return samples

    def __getitem__(self, idx):
        """
        Get a task (cluster), including its support set and query set.

        Args:
            idx (int): index of the cluster.

        Returns:
            dict: ``{'support_set': [...], 'query_set': [...]}``. Each list holds
            ``(protein_graph, ligand_graph, kd_value, pdb_id)`` tuples.

        Raises:
            ValueError: if the cluster has fewer than ``k_shot + k_query`` samples.
            KeyError: if a sampled PDB ID is missing from the LMDB.
        """
        cluster_id = self.cluster_ids[idx]
        lmdb_path = self.cluster_id_to_lmdb_path[cluster_id]
        pdb_ids = self.cluster_id_to_pdb_ids[cluster_id]

        # Ensure there are enough samples for both sets.
        total_samples_needed = self.k_shot + self.k_query
        if len(pdb_ids) < total_samples_needed:
            raise ValueError(f"Not enough samples in cluster {cluster_id}")

        # One draw without replacement keeps support and query disjoint.
        pdb_ids_sampled = random.sample(pdb_ids, total_samples_needed)
        support_pdb_ids = pdb_ids_sampled[:self.k_shot]
        query_pdb_ids = pdb_ids_sampled[self.k_shot:]

        env = self._open_env(lmdb_path)
        try:
            # A single read transaction gives both sets the same snapshot.
            with env.begin() as txn:
                support_set = self._load_samples(txn, support_pdb_ids)
                query_set = self._load_samples(txn, query_pdb_ids)
        finally:
            # Release the environment on every path, including decode errors.
            env.close()

        return {'support_set': support_set, 'query_set': query_set}


if __name__ == '__main__':
    # Smoke test: build the dataset, report the task count, show one sampled task.
    meta_dataset = MetaDataset(
        cluster_data_dir='./cluster_data',
        k_shot=10,
        k_query=15,
    )
    num_tasks = len(meta_dataset)
    print(num_tasks)
    first_task = meta_dataset[0]
    print(first_task)


30
{'support_set': [(Data(x=[545, 1280], edge_index=[2, 5912], pos=[545, 3]), Data(x=[22, 9], edge_index=[2, 44], edge_attr=[44, 3], pos=[22, 3]), tensor([8.3400]), '3lj7'), (Data(x=[317, 1280], edge_index=[2, 3094], pos=[317, 3]), Data(x=[31, 9], edge_index=[2, 62], edge_attr=[62, 3], pos=[31, 3]), tensor([1.7000]), '6cex'), (Data(x=[426, 1280], edge_index=[2, 4122], pos=[426, 3]), Data(x=[98, 9], edge_index=[2, 204], edge_attr=[204, 3], pos=[98, 3]), tensor([8.0600]), '5ea5'), (Data(x=[387, 1280], edge_index=[2, 4050], pos=[387, 3]), Data(x=[72, 9], edge_index=[2, 150], edge_attr=[150, 3], pos=[72, 3]), tensor([6.3700]), '4r91'), (Data(x=[330, 1280], edge_index=[2, 3000], pos=[330, 3]), Data(x=[39, 9], edge_index=[2, 78], edge_attr=[78, 3], pos=[39, 3]), tensor([4.8200]), '4e6c'), (Data(x=[223, 1280], edge_index=[2, 2350], pos=[223, 3]), Data(x=[24, 9], edge_index=[2, 48], edge_attr=[48, 3], pos=[24, 3]), tensor([2.8500]), '3rxa'), (Data(x=[215, 1280], edge_index=[2, 2220], pos=[215,