# <center> Engineering Implementation: MetaScore

## 1. Data Loading

First, all the data needed for training comes from PDBbind and has already been processed into graph format, so it can be loaded directly:

In [1]:
import torch
data_path = './dataset/pdbbind/v2020_train_dict.pt'
data = torch.load(data_path,weights_only=False)

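Let us look at how the data is organized. The whole dataset is one large nested dictionary keyed by `pdb_id`; each value is a dictionary whose keys `prot`, `lig`, and `label` hold the protein graph, the ligand graph, and the affinity value, respectively: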

In [2]:
data['10gs']

{'prot': Data(x=[83, 41], edge_index=[2, 2014], edge_attr=[2014, 5], pos=[83, 24, 3]),
 'lig': Data(x=[33, 41], edge_index=[2, 68], edge_attr=[68, 10], pos=[33, 3]),
 'label': '6.4'}


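Unlike conventional machine learning, the unit of training data in meta-learning is a task rather than a single data point, so the protein-ligand pairs in PDBbind must be partitioned into tasks. In MetaScore this is done by hierarchically clustering proteins by structural similarity and using the resulting clusters as the units from which tasks are built. The `DataLoader` logic changes accordingly:

- The `DataLoader` samples `batch_size` tasks at a time from all available tasks
- Each task samples `num_classes_per_task` classes, i.e. that many distinct clusters are selected to build one task
- Within each cluster, `k-shot` samples are drawn at random for the support set and `q-query` samples for the query set
- Aggregating the selected support and query samples over all clusters of a task yields one task of the `meta-batch`
- Because of GPU memory limits, the tasks in a `meta-batch` cannot be merged once more and processed in a single pass (memory usage would blow up); gradients are instead accumulated task by task within a batch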
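
The last point can be illustrated without any deep-learning framework. The sketch below is a toy, assuming a one-parameter model `y = w * x` with a mean-squared-error loss (the model, the numbers, and the function names are illustrative and not part of MetaScore). It accumulates the gradient one task at a time and compares it with the gradient computed on all samples pooled together; the two agree here because every task has the same number of samples. In PyTorch the same pattern corresponds to calling `loss.backward()` once per task and `optimizer.step()` once per meta-batch.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_meta_grad(w, tasks):
    """Accumulate per-task gradients one task at a time (low peak memory)."""
    grad = 0.0
    for xs, ys in tasks:
        grad += mse_grad(w, xs, ys) / len(tasks)  # average over tasks
    return grad


# Toy meta-batch: 3 tasks with 2 samples each (made-up numbers).
tasks = [
    ([1.0, 2.0], [2.0, 4.1]),
    ([0.5, 1.5], [1.0, 2.9]),
    ([3.0, 1.0], [6.2, 2.0]),
]
w = 0.5

# Reference: pool every sample of every task and compute one gradient.
pooled_xs = [x for xs, _ in tasks for x in xs]
pooled_ys = [y for _, ys in tasks for y in ys]

grad_accumulated = accumulated_meta_grad(w, tasks)
grad_pooled = mse_grad(w, pooled_xs, pooled_ys)
print(grad_accumulated, grad_pooled)
```

The accumulated version never needs more than one task in memory, which is the property the training loop relies on.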



The clustering results are already available, so we just load them:

In [3]:
import pandas as pd
cluster_path = './dataset/cluster/clustering_results.csv'
cluster_df = pd.read_csv(cluster_path)
cluster_df.head()

| Unnamed: 0 | PDB_ID | Cluster |
|---|---|---|
| 0 | 2z7i | 257 |
| 1 | 3i7g | 333 |
| 2 | 2v88 | 197 |
| 3 | 3wav | 150 |
| 4 | 6qlu | 0 |


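Next we build a new data dictionary from the clustering results. Its keys are derived from `Cluster` (for example `class_1`), and each value is a list whose elements are the data points in that cluster, stored in the same dictionary format as the original data (with `pdb_id` added as an identifier):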

In [4]:
import tqdm
dict_with_cluster = {}
for pdb_id, value in tqdm.tqdm(data.items()):
    value['pdb_id'] = pdb_id
    row = cluster_df[cluster_df['PDB_ID'] == pdb_id]
    cluster_id = row['Cluster'].values[0]
    if f'class_{cluster_id}' not in dict_with_cluster:
        dict_with_cluster[f'class_{cluster_id}'] = []
    dict_with_cluster[f'class_{cluster_id}'].append(value)
dict_with_cluster['class_1']

100%|██████████| 19145/19145 [00:27<00:00, 693.98it/s]


[{'prot': Data(x=[38, 41], edge_index=[2, 794], edge_attr=[794, 5], pos=[38, 24, 3]),
  'lig': Data(x=[35, 41], edge_index=[2, 72], edge_attr=[72, 10], pos=[35, 3]),
  'label': '5.7',
  'pdb_id': '4z0u'},
 {'prot': Data(x=[83, 41], edge_index=[2, 2264], edge_attr=[2264, 5], pos=[83, 24, 3]),
  'lig': Data(x=[21, 41], edge_index=[2, 46], edge_attr=[46, 10], pos=[21, 3]),
  'label': '6.0',
  'pdb_id': '3dxm'}]
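
The row-by-row DataFrame filter above scans the whole table once per `pdb_id`. A possible alternative, sketched here on toy data, builds the `PDB_ID -> Cluster` lookup once and groups with `defaultdict`. The helper name `group_by_cluster` is hypothetical, and the cluster id `42` for `10gs` is made up for the example. With the real data the lookup could be built as `dict(zip(cluster_df['PDB_ID'], cluster_df['Cluster']))`.

```python
from collections import defaultdict


def group_by_cluster(data: dict, pdb_to_cluster: dict) -> dict:
    """Group entries by cluster id, tagging each entry with its pdb_id."""
    grouped = defaultdict(list)
    for pdb_id, value in data.items():
        entry = dict(value)       # shallow copy, leaves the input untouched
        entry['pdb_id'] = pdb_id
        grouped[f"class_{pdb_to_cluster[pdb_id]}"].append(entry)
    return dict(grouped)


# Toy stand-ins: real entries also carry the 'prot' and 'lig' graphs.
toy_data = {
    '4z0u': {'label': '5.7'},
    '3dxm': {'label': '6.0'},
    '10gs': {'label': '6.4'},
}
toy_clusters = {'4z0u': 1, '3dxm': 1, '10gs': 42}  # 42 is a made-up id

toy_grouped = group_by_cluster(toy_data, toy_clusters)
print(toy_grouped['class_1'])
```

Copying each entry avoids mutating the original `data` dictionary, which the loop above does as a side effect.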

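On top of this result we split the data into the meta-learning `meta-train` and `meta-test` sets. To keep the implementation simple, `MetaScore` splits by the number of data points in each cluster: clusters with at least 17 data points are assigned to the training set, clusters with at least 7 but fewer than 17 data points form the held-out set (stored under the key `val` in the code below), and all remaining clusters are stored under the key `test`, to be discarded for now or used to test the model's `zeroshot` performance (these clusters are numerous but contain very little data, so dropping them does no serious harm). A `dict_with_split` records which set each cluster belongs to: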

In [5]:
import numpy as np
train_set_min_size = 17
val_set_min_size = 7
dict_with_split = {'train': [], 'val': [], 'test': []}
for class_idx, value in dict_with_cluster.items():
    data_count = len(value)
    if data_count >= train_set_min_size:
        dict_with_split['train'].append(class_idx)
    elif data_count >= val_set_min_size and data_count < train_set_min_size:
        dict_with_split['val'].append(class_idx)
    else:
        dict_with_split['test'].append(class_idx)
print(f"Train: {len(dict_with_split['train'])} classes, Val: {len(dict_with_split['val'])} classes, Test: {len(dict_with_split['test'])} classes")
dataset_size_dict = {
    "train": {key: len(dict_with_cluster[key]) for key in dict_with_split["train"]},
    "val": {key: len(dict_with_cluster[key]) for key in dict_with_split["val"]},
    "test": {key: len(dict_with_cluster[key]) for key in dict_with_split["test"]},
}
data_length = {name: np.sum(np.array(list(dataset_size_dict[name].values()))) for name in dict_with_split.keys()}
print("data length of each split", data_length)

Train: 186 classes, Val: 153 classes, Test: 808 classes
data length of each split {'train': 15810, 'val': 1574, 'test': 1761}
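
As a quick sanity check on the printed numbers, the snippet below (values copied from the output above, variable names chosen for illustration) confirms that the three splits add up to the 19145 entries reported by the progress bar and computes the average cluster size per split.

```python
# Values copied from the cell output above.
num_classes = {'train': 186, 'val': 153, 'test': 808}
num_points = {'train': 15810, 'val': 1574, 'test': 1761}

total_points = sum(num_points.values())
mean_cluster_size = {
    name: num_points[name] / num_classes[name] for name in num_points
}

print(total_points)
for name, mean_size in mean_cluster_size.items():
    print(f"{name}: {mean_size:.2f} data points per cluster on average")
```

The averages come out at roughly 85, 10, and 2 data points per cluster, which is consistent with the remark that the `test` clusters are numerous but tiny.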


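This completes the split of the data units (individual clusters). The next question is how to build a `meta-batch` from these units, which involves:

- Task construction: how to select `num_classes_per_task` clusters (`*_num_classes_per_set` in the code below) from all clusters to build one task
- Within-cluster sampling: how to pick samples from a cluster to build the support set and the query set

Both can be handled by a `TaskSampler` class: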

In [6]:
class TaskSampler:
    """
    Samples tasks based on specified rules.
    Can select classes uniformly or based on class size using softmax probability.
    Samples support/query items uniformly from within selected classes.
    """
    def __init__(self,
                 dataset_size_dict: dict,
                 train_num_classes_per_set: int,
                 val_num_classes_per_set: int,
                 train_num_support: int,
                 val_num_support: int,
                 train_num_query: int,
                 val_num_query: int,
                 sampling_rule: str = "uniform"): # Added sampling_rule parameter
        """
        Initializes the TaskSampler.

        Args:
            dataset_size_dict: Dictionary mapping split names ('train', 'val') to dictionaries
                               of {class_name: class_size}.
            train_num_classes_per_set: Number of classes per task in the training set.
            val_num_classes_per_set: Number of classes per task in the validation set.
            train_num_support: Number of support examples per class in the training set.
            val_num_support: Number of support examples per class in the validation set.
            train_num_query: Number of query examples per class in the training set.
            val_num_query: Number of query examples per class in the validation set.
            sampling_rule: Rule for sampling classes ('uniform' or 'softmax'). Defaults to 'uniform'.
        """
        self.dataset_size_dict = dataset_size_dict
        self.sampling_rule = sampling_rule # Store the rule, for now only uniform sampling is implemented
        # determine the number of clusters to sample for each task
        self.num_classes_per_set = {
            "train": train_num_classes_per_set,
            "val": val_num_classes_per_set,
        }
        # determine the number of support samples to sample for each class
        self.num_support = {
            "train": train_num_support,
            "val": val_num_support,
        }
        # determine the number of query samples to sample for each class
        self.num_query = {
            "train": train_num_query,
            "val": val_num_query,
        }
        # check if the sampling rule is valid
        if sampling_rule not in ["uniform", "softmax"]:
            raise ValueError("Invalid sampling_rule. Must be 'uniform' or 'softmax'.")

    def sample_task(self, dataset_name: str, seed: int) -> dict:
        """
        Samples a task based on the configured sampling rule.

        Args:
            dataset_name: 'train' or 'val'.
            seed: Random seed for reproducibility.

        Returns:
            Dictionary with sampled class names and their support/query indices.
            Example: {'class_A': {'support_indices': [...], 'query_indices': [...]}, ...}
        """
        if dataset_name not in ["train", "val"]:
            raise ValueError("Invalid dataset_name. Must be 'train' or 'val'.")
        rng = np.random.RandomState(seed)
        num_classes = self.num_classes_per_set[dataset_name]
        num_support = self.num_support[dataset_name]
        num_query = self.num_query[dataset_name]
        num_samples_per_class = num_support + num_query
        available_classes_dict = self.dataset_size_dict[dataset_name]
        available_class_names = list(available_classes_dict.keys())
        if len(available_class_names) < num_classes:
            raise ValueError(
                f"Not enough classes in {dataset_name} split ({len(available_class_names)}) "
                f"to sample {num_classes} classes."
            )
        # Select classes based on the sampling rule
        if self.sampling_rule == "uniform":
            selected_classes = rng.choice(
                available_class_names,
                size=num_classes,
                replace=False,
            )
        elif self.sampling_rule == "softmax":
            raise NotImplementedError("Softmax sampling rule not implemented")
        else:
            raise ValueError(f"Unknown sampling rule: {self.sampling_rule}")
        task_info = {}
        for class_name in selected_classes:
            class_size = available_classes_dict[class_name] # Use the dict directly
            if class_size < num_samples_per_class:
                raise ValueError(
                    f"Class '{class_name}' in '{dataset_name}' has only {class_size} samples, "
                    f"but {num_samples_per_class} are required for support+query."
                )
            # Sample indices from 0 to class_size - 1
            selected_indices = rng.choice(
                class_size,
                size=num_samples_per_class,
                replace=False # Should not sample the same item twice for one task
            )
            task_info[class_name] = {
                'support_indices': selected_indices[:num_support],
                'query_indices': selected_indices[num_support : num_support + num_query]
            }
        return task_info

task_sampler = TaskSampler(
    dataset_size_dict=dataset_size_dict,
    train_num_classes_per_set=5,
    val_num_classes_per_set=16,
    train_num_support=5,
    val_num_support=3,
    train_num_query=5,
    val_num_query=3,
    sampling_rule="uniform"
)
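
The `softmax` rule is accepted by the constructor but raises `NotImplementedError` in `sample_task`. One possible way to fill it in is sketched below, under loud assumptions: class probabilities are taken as a softmax over class sizes divided by a `temperature`, classes are then drawn one at a time without replacement, and the function names, the `temperature` parameter, and the toy sizes are all hypothetical. The sketch uses the standard-library `random` module rather than `np.random.RandomState`, so its draws will not match the numpy-based sampler above.

```python
import math
import random


def softmax_class_probs(class_sizes: dict, temperature: float = 100.0) -> dict:
    """Softmax over class sizes: larger clusters get higher probability."""
    max_size = max(class_sizes.values())
    # Subtract the maximum before exponentiating for numerical stability.
    exps = {
        name: math.exp((size - max_size) / temperature)
        for name, size in class_sizes.items()
    }
    total = sum(exps.values())
    return {name: value / total for name, value in exps.items()}


def sample_classes_softmax(class_sizes, num_classes, seed, temperature=100.0):
    """Draw `num_classes` distinct classes, weighted by softmax probability."""
    if num_classes > len(class_sizes):
        raise ValueError("Not enough classes to sample from.")
    rng = random.Random(seed)
    remaining = softmax_class_probs(class_sizes, temperature)
    selected = []
    for _ in range(num_classes):
        names = list(remaining)
        weights = [remaining[name] for name in names]
        pick = rng.choices(names, weights=weights, k=1)[0]
        selected.append(pick)
        del remaining[pick]  # without replacement
    return selected


toy_sizes = {'class_a': 17, 'class_b': 40, 'class_c': 120, 'class_d': 300}
toy_probs = softmax_class_probs(toy_sizes, temperature=100.0)
toy_selected = sample_classes_softmax(toy_sizes, 2, seed=1227, temperature=100.0)
print(toy_probs)
print(toy_selected)
```

The temperature has to be chosen relative to the spread of cluster sizes: a small value makes the largest clusters dominate almost every task, while a large value approaches the uniform rule.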

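This is a very simple sampling scheme (also the one used in the `MAML` paper). Using this `TaskSampler`, let us sample one training task and one validation task: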

In [7]:
sampled_train_task_info = task_sampler.sample_task(dataset_name='train', seed=1227)
sampled_val_task_info = task_sampler.sample_task(dataset_name='val', seed=1227)
train_support_set_data = []
train_query_set_data = []
val_support_set_data = []
val_query_set_data = []
for sampled_train_class_name in sampled_train_task_info.keys():
    train_class_data_list = dict_with_cluster[sampled_train_class_name]
    support_indices = sampled_train_task_info[sampled_train_class_name]['support_indices']
    query_indices = sampled_train_task_info[sampled_train_class_name]['query_indices']
    per_class_support_data = [train_class_data_list[i] for i in support_indices]
    train_support_set_data.append(per_class_support_data)
    per_class_query_data = [train_class_data_list[i] for i in query_indices]
    train_query_set_data.append(per_class_query_data)
for sampled_val_class_name in sampled_val_task_info.keys():
    val_class_data_list = dict_with_cluster[sampled_val_class_name]
    support_indices = sampled_val_task_info[sampled_val_class_name]['support_indices']
    query_indices = sampled_val_task_info[sampled_val_class_name]['query_indices']
    per_class_support_data = [val_class_data_list[i] for i in support_indices]
    val_support_set_data.append(per_class_support_data)
    per_class_query_data = [val_class_data_list[i] for i in query_indices]
    val_query_set_data.append(per_class_query_data)

# A task has now been sampled; the data is organized as follows.
# Number of sampled classes per set: (5, 5, 16, 16)
len(train_support_set_data),len(train_query_set_data),len(val_support_set_data),len(val_query_set_data)
# Number of support/query samples per class: (5, 5, 3, 3); only this last expression is displayed
len(train_support_set_data[0]),len(train_query_set_data[0]),len(val_support_set_data[0]),len(val_query_set_data[0])

(5, 5, 3, 3)
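
The two loops in the cell above differ only in which task and which lists they use. A small helper could remove the duplication; the sketch below uses a hypothetical name, `gather_task_data`, and toy strings in place of the real data dictionaries.

```python
def gather_task_data(task_info: dict, class_to_items: dict):
    """Return (support, query): one list of items per sampled class."""
    support, query = [], []
    for class_name, indices in task_info.items():
        items = class_to_items[class_name]
        support.append([items[i] for i in indices['support_indices']])
        query.append([items[i] for i in indices['query_indices']])
    return support, query


# Toy stand-ins for dict_with_cluster and for the output of sample_task.
toy_items = {
    'class_A': ['a0', 'a1', 'a2', 'a3'],
    'class_B': ['b0', 'b1', 'b2', 'b3'],
}
toy_task = {
    'class_A': {'support_indices': [2, 0], 'query_indices': [3]},
    'class_B': {'support_indices': [1, 3], 'query_indices': [0]},
}

toy_support, toy_query = gather_task_data(toy_task, toy_items)
print(toy_support)  # [['a2', 'a0'], ['b1', 'b3']]
print(toy_query)    # [['a3'], ['b0']]
```

With the notebook's variables, the call would look like `gather_task_data(sampled_train_task_info, dict_with_cluster)`.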

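The sampled task is still organized rather messily, so we use dictionaries to make it clearer. A `transform` function does this; the same function can later be used in the `__getitem__` method that returns one data unit. For the demonstration that follows, we obtain a `meta-batch` simply by copying the task: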

In [8]:
def transform(data):
    """Transforms the raw data list from get_set into separate lists."""
    pdb_ids = []
    labels = []
    prots = []
    ligs = []
    for class_data_list in data:
        for per_pdb_data_dict in class_data_list:
            pdb_ids.append(per_pdb_data_dict['pdb_id'])
            labels.append(float(per_pdb_data_dict['label']))
            prots.append(per_pdb_data_dict['prot'])
            ligs.append(per_pdb_data_dict['lig'])
    return {
        'pdb_ids': pdb_ids,
        'labels': labels,
        'prots': prots,
        'ligs': ligs,
    }
train_support_set_data = transform(train_support_set_data)
train_query_set_data = transform(train_query_set_data)
val_support_set_data = transform(val_support_set_data)
val_query_set_data = transform(val_query_set_data)
# After the transform the sets hold (25, 25, 48, 48) samples: 25 = 5 * 5, 48 = 16 * 3
# len(train_support_set_data['pdb_ids']), len(train_query_set_data['pdb_ids']), len(val_support_set_data['pdb_ids']), len(val_query_set_data['pdb_ids'])
# Each task is a (support_set, query_set, seed) tuple
train_task = (train_support_set_data, train_query_set_data, 1227)
val_task = (val_support_set_data, val_query_set_data, 1227)
# BATCH_SIZE = 4: simply repeat the same task 4 times to get a meta-batch
train_batch = [train_task for _ in range(4)]
val_batch = [val_task for _ in range(4)]

To tidy up the data format further, we write a custom `task_collate_fn`. The resulting `meta-batch` is a list of length `batch_size`, and each element is a dictionary:

In [9]:
from torch_geometric.data import Batch
def task_collate_fn(batch):
    """
    Collate function for creating task batches with graph data.
    Processes and collates support and query sets into batches.
    Input `batch`: A list of tuples, where each tuple is (support_set_dict, query_set_dict, seed)
                    returned by FewShotLearningDatasetParallel.__getitem__.
    """
    all_tasks = []
    
    for task in batch:
        # Unpack the task tuple
        support_set_dict, query_set_dict, seed = task

        # Process Support Set
        support_ligs = support_set_dict['ligs'] # List of ligand graph objects
        support_prots = support_set_dict['prots'] # List of protein graph objects
        support_labels = support_set_dict['labels']
        support_pdb_ids = support_set_dict['pdb_ids']
        
        query_ligs = query_set_dict['ligs'] # List of ligand graph objects
        query_prots = query_set_dict['prots'] # List of protein graph objects
        query_labels = query_set_dict['labels']
        query_pdb_ids = query_set_dict['pdb_ids']
        
        task_data = {
            "support": {
                "prots": Batch.from_data_list(support_prots),
                "ligs": Batch.from_data_list(support_ligs),
                "labels": torch.tensor(support_labels, dtype=torch.float).reshape(-1),
                "pdb_ids": support_pdb_ids,
            },
            "query": {
                "prots": Batch.from_data_list(query_prots),
                "ligs": Batch.from_data_list(query_ligs),
                "labels": torch.tensor(query_labels, dtype=torch.float).reshape(-1),
                "pdb_ids": query_pdb_ids,
            },
            "seed": seed,
        }
        all_tasks.append(task_data)
    return all_tasks

# The meta-batches for training and validation are now ready
train_batch = task_collate_fn(train_batch)
val_batch = task_collate_fn(val_batch)

## 2. Defining the Base Model

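This part defines the concrete neural network architecture. Its implementation is no different from conventional machine learning. Techniques for designing models around `MAML` may be collected here later; related work so far includes: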

- [MetaMolGen: A Neural Graph Motif Generation Model for De Novo Molecular Design](https://arxiv.org/abs/2504.15587)
- [Meta-MGNN: Few-Shot Graph Learning for Molecular Property Prediction](https://arxiv.org/abs/2102.07916)
  
This list is still being extended.

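Here we initialize the model with the architecture defined in `./models/gatedgcn.py` and the pretrained weights stored in `./models/gatedgcn.pt`, and run a quick test to check that the base model works: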

In [68]:
from models.gatedgcn import GenScore_GGCN, mdn_loss_fn
import torch
ft = False
state_dict = torch.load('./models/gatedgcn.pt',weights_only=False)
meta_model_state_dict = state_dict['network'] # meta_model_state_dict, with a 'regressor.' prefix
# Correctly strip the prefix, ensuring the new key is not empty
model_state_dict = {
    new_key: v 
    for k, v in meta_model_state_dict.items() 
    if k.startswith('regressor.') and (new_key := k[len('regressor.'):])
}
base_model = GenScore_GGCN()
if ft is False:
    print('load meta-model')
    base_model.load_state_dict(model_state_dict)
else:
    print('load non-meta-model')
    base_model.load_state_dict(torch.load('./models/gatedgcn_nonmeta.pth',weights_only=False)['model_state_dict'])
support_train_prot_eg = train_batch[0]['support']['prots']
support_train_lig_eg = train_batch[0]['support']['ligs']
pi, sigma, mu, dist, atom_types, bond_types, C_batch =base_model.net_forward(support_train_lig_eg,support_train_prot_eg)
print(f"Predicted Gaussian Mixture Weights: {pi.shape}\nPredicted Gaussian Mixture Standard Deviations: {sigma.shape}\nPredicted Gaussian Mixture Means: {mu.shape}\nPrecalculated Closest Distance: {dist.shape}\nPredicted Atom Types: {atom_types.shape}\nPredicted Bond Types: {bond_types.shape}\nBatch Index: {C_batch.shape}")
total_loss, mdn_loss, affi_loss, atom_loss, bond_loss, y, batch = \
    base_model.forward(train_batch[0]['support'])
print(f"Total Loss: {total_loss}\nMDN Loss: {mdn_loss}\nAffinity Pearson Correlation: {affi_loss}\nAtom Loss: {atom_loss}\nBond Loss: {bond_loss}\nPredicted Scores: {y}\nBatch Index: {batch}")


load meta-model
Predicted Gaussian Mixture Weights: torch.Size([64953, 10])
Predicted Gaussian Mixture Standard Deviations: torch.Size([64953, 10])
Predicted Gaussian Mixture Means: torch.Size([64953, 10])
Precalculated Closest Distance: torch.Size([64953, 1])
Predicted Atom Types: torch.Size([763, 17])
Predicted Bond Types: torch.Size([1638, 4])
Batch Index: torch.Size([64953])
Total Loss: 0.9454381367967772
MDN Loss: 0.9453015327453613
Affinity Pearson Correlation: 0.6828477382659912
Atom Loss: 0.0032243274617940187
Bond Loss: 0.13338115811347961
Predicted Scores: tensor([208.2487, 230.9851, 212.2064, 135.1499, 214.3612, 107.6525, 157.2823,
         51.0062, 168.1118, 168.4160,  78.0566,  50.8879,  69.3069,  93.9391,
        208.8820, 115.5764, 169.4495, 150.0120, 112.4466, 166.9707,  55.4113,
        129.6583, 181.6872,  73.9615, 121.3966], dtype=torch.float64)
Batch Index: tensor([ 0,  0,  0,  ..., 24, 24, 24])


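Excellent! Next comes the core meta-learning logic. MetaScore uses the optimization-based `MAML` and its variant `MAML++` (further variants, such as a `Task Based Attention Mechanism`, may follow). See `MetaLearning.md` in the same directory for a detailed introduction; what follows focuses on the engineering implementation, which relies mainly on the `higher` library. With `higher`, existing models that were not written for meta-learning can readily be turned into meta-models.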
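
For orientation, the standard `MAML` update can be written as follows. This is the textbook form with a single inner step; the number of inner steps and any `MAML++` modifications used in MetaScore are not specified here. The symbols are notation chosen for this summary: $\theta$ are the meta-parameters, $\alpha$ and $\beta$ the inner and outer learning rates, and $B$ the number of tasks $\mathcal{T}_i$ in a `meta-batch`.

$$
\theta_i' = \theta - \alpha \, \nabla_{\theta} \mathcal{L}^{\text{support}}_{\mathcal{T}_i}(\theta)
$$

$$
\theta \leftarrow \theta - \beta \, \nabla_{\theta} \sum_{i=1}^{B} \mathcal{L}^{\text{query}}_{\mathcal{T}_i}(\theta_i')
$$

The inner step adapts the parameters on the support set of each task, and the outer step updates the shared initialization using the query-set loss of the adapted parameters, differentiating through the inner step.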

## 3. Designing the Meta-Learning Framework

To apply the meta-learning framework to the `GenScore_GGCN` model, we need to define an additional higher-level class, `MAMLRegressor`: