# <center> MetaScore

## 1. 数据加载

In [1]:
import os
import tqdm
import hydra
import higher
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Batch

# predefined model
from models.gatedgcn import GenScore_GGCN

import warnings
warnings.filterwarnings('ignore')

In [2]:
data_path = './dataset/pdbbind/v2020_train_dict.pt'
data = torch.load(data_path,weights_only=False)

看一下数据组织格式：所有数据是一个大的嵌套字典，它的键是`pdb_id`，值是一个字典，其键`prot`,`lig`,`label`分别对应了蛋白质图、配体图和亲和力数据:

In [3]:
data['10gs']

{'prot': Data(x=[83, 41], edge_index=[2, 2014], edge_attr=[2014, 5], pos=[83, 24, 3]),
 'lig': Data(x=[33, 41], edge_index=[2, 68], edge_attr=[68, 10], pos=[33, 3]),
 'label': np.str_('6.4')}


与传统机器学习不同，元学习训练的数据单元是任务而非单个数据点，因此需要对pdbbind中的蛋白-配体对数据进行任务划分，在MetaScore中表现为对蛋白质结构相似度进行层次聚类，以聚类得到的簇为单个任务进行训练。对应的，`DataLoader`的写法也相应的改变：

- `DataLoader`一次从所有任务中采样`batch_size`个任务
- 每个任务应当采样`num_classes_per_task`个类别，即选取不同的聚类用于构建一个任务
- 每个聚类中随机采样`k-shot`个样本用于构建支持集，`q-query`个样本用于构建查询集
- 将一个任务中所有聚类中被选中的支持集样本和查询集样本聚合后，即可视为`meta-batch`中的一个任务
- 考虑到显存问题，无法再将`meta-batch`中的任务也做一遍聚合然后一次算完（显存会爆炸），只能在一个批次中逐任务累计梯度



聚类的结果已经拿到，加载一下:

In [4]:
cluster_path = './dataset/cluster/clustering_results.csv'
cluster_df = pd.read_csv(cluster_path)
cluster_df.head()

Unnamed: 0,PDB_ID,Cluster
0,2z7i,257
1,3i7g,333
2,2v88,197
3,3wav,150
4,6qlu,0


而后我们根据聚类的结果来构建一个新的数据字典，在这个字典中，键应当改为`Cluster`，值应为列表，每个列表的元素代表位于该聚类下的一个数据点，以原数据中字典的格式存储（在原数据字典的基础上额外增加了`pdb_id`作为数据标识符）:

In [5]:
dict_with_cluster = {}
for pdb_id, value in tqdm.tqdm(data.items()):
    value['pdb_id'] = pdb_id
    row = cluster_df[cluster_df['PDB_ID'] == pdb_id]
    cluster_id = row['Cluster'].values[0]
    if f'class_{cluster_id}' not in dict_with_cluster:
        dict_with_cluster[f'class_{cluster_id}'] = []
    dict_with_cluster[f'class_{cluster_id}'].append(value)
dict_with_cluster['class_1']

100%|██████████| 19145/19145 [00:09<00:00, 1936.29it/s]


[{'prot': Data(x=[38, 41], edge_index=[2, 794], edge_attr=[794, 5], pos=[38, 24, 3]),
  'lig': Data(x=[35, 41], edge_index=[2, 72], edge_attr=[72, 10], pos=[35, 3]),
  'label': np.str_('5.7'),
  'pdb_id': np.str_('4z0u')},
 {'prot': Data(x=[83, 41], edge_index=[2, 2264], edge_attr=[2264, 5], pos=[83, 24, 3]),
  'lig': Data(x=[21, 41], edge_index=[2, 46], edge_attr=[46, 10], pos=[21, 3]),
  'label': np.str_('6.0'),
  'pdb_id': np.str_('3dxm')}]

在此结果基础上，考虑划分元学习的`meta-train`和`meta-test`集，在`MetaScore`中，考虑到降低实现难度，该集的划分以每个聚类中数据点的数量划分，具体的，不少于17个数据点的聚类会被分类为训练集，不少于7个数据点的聚类分类为测试集，其他聚类暂时丢弃，或考虑用于模型`zeroshot`性能测试（这部分聚类数量很多但是数据量很小，可以丢弃而不造成严重影响）。为此可以写一个`dict_with_split`，存储每个聚类属于的集合:

In [6]:
# split the data into train, val, test according to the amount of data in each class
train_set_min_size = 17
val_set_min_size = 7
dict_with_split = {'train': [], 'val': [], 'test': []}
for class_idx, value in dict_with_cluster.items():
    data_count = len(value)
    if data_count >= train_set_min_size:
        dict_with_split['train'].append(class_idx)
    elif data_count >= val_set_min_size and data_count < train_set_min_size:
        dict_with_split['val'].append(class_idx)
    else:
        dict_with_split['test'].append(class_idx)
print(f"Train: {len(dict_with_split['train'])} classes, Val: {len(dict_with_split['val'])} classes, Test: {len(dict_with_split['test'])} classes")

# get the size of each split
dataset_size_dict = {
    "train": {key: len(dict_with_cluster[key]) for key in dict_with_split["train"]},
    "val": {key: len(dict_with_cluster[key]) for key in dict_with_split["val"]},
    "test": {key: len(dict_with_cluster[key]) for key in dict_with_split["test"]},
}
data_length = {name: np.sum(np.array(list(dataset_size_dict[name].values()))) for name in dict_with_split.keys()}
print("data length of each split", data_length)

Train: 186 classes, Val: 153 classes, Test: 808 classes
data length of each split {'train': np.int64(15810), 'val': np.int64(1574), 'test': np.int64(1761)}


到此为止完成了数据单元（单个聚类）的数据集划分，接下来要考虑如何使用数据单元来构建`meta-batch`，这涉及到:
- 任务构建逻辑，如何在所有聚类中选择`batch_size`个聚类构建一个任务
- 聚类中的样本采样逻辑，如何在一个聚类中选取样本构建支持集和查询集
这两点可以通过一个`TaskSampler`类来完成:

In [7]:
class TaskSampler:
    """
    Samples tasks based on specified rules.
    Can select classes uniformly or based on class size using softmax probability.
    Samples support/query items uniformly from within selected classes.
    """
    def __init__(self,
                 dataset_size_dict: dict,
                 train_num_classes_per_set: int,
                 val_num_classes_per_set: int,
                 train_num_support: int,
                 val_num_support: int,
                 train_num_query: int,
                 val_num_query: int,
                 sampling_rule: str = "uniform"): # Added sampling_rule parameter
        """
        Initializes the TaskSampler.

        Args:
            dataset_size_dict: Dictionary mapping split names ('train', 'val') to dictionaries
                               of {class_name: class_size}.
            train_num_classes_per_set: Number of classes per task in the training set.
            val_num_classes_per_set: Number of classes per task in the validation set.
            train_num_support: Number of support examples per class in the training set.
            val_num_support: Number of support examples per class in the validation set.
            train_num_query: Number of query examples per class in the training set.
            val_num_query: Number of query examples per class in the validation set.
            sampling_rule: Rule for sampling classes ('uniform' or 'softmax'). Defaults to 'uniform'.
        """
        self.dataset_size_dict = dataset_size_dict
        self.sampling_rule = sampling_rule # Store the rule, for now only uniform sampling is implemented
        
        # determine the number of classes to sample for each task
        self.num_classes_per_set = {"train": train_num_classes_per_set, "val": val_num_classes_per_set}
        self.num_support = {"train": train_num_support, "val": val_num_support}
        self.num_query = {"train": train_num_query, "val": val_num_query}

    def sample_task(self, dataset_name: str, seed: int) -> dict:
        """
        Samples a task based on the configured sampling rule.

        Args:
            dataset_name: 'train' or 'val'.
            seed: Random seed for reproducibility.

        Returns:
            Dictionary with sampled class names and their support/query indices.
            Example: {'class_A': {'support_indices': [...], 'query_indices': [...]}, ...}
        """
        rng = np.random.RandomState(seed)
        num_classes = self.num_classes_per_set[dataset_name]
        num_support = self.num_support[dataset_name]
        num_query = self.num_query[dataset_name]
        num_samples_per_class = num_support + num_query
        available_classes_dict = self.dataset_size_dict[dataset_name]
        available_class_names = list(available_classes_dict.keys())
        if len(available_class_names) < num_classes:
            raise ValueError(
                f"Not enough classes in {dataset_name} split ({len(available_class_names)}) "
                f"to sample {num_classes} classes."
            )
        
        # select classes based on the sampling rule
        assert self.sampling_rule == "uniform"
        selected_classes = rng.choice(
            available_class_names,
            size=num_classes,
            replace=False,
        )
        
        task_info = {}
        for class_name in selected_classes:
            class_size = available_classes_dict[class_name] 
            if class_size < num_samples_per_class:
                 raise ValueError(
                    f"Class '{class_name}' in '{dataset_name}' has only {class_size} samples, "
                    f"but {num_samples_per_class} are required for support+query."
                )
            # sample indices from 0 to class_size - 1
            selected_indices = rng.choice(class_size, size=num_samples_per_class, replace=False)
            task_info[class_name] = {
                'support_indices': selected_indices[:num_support],
                'query_indices': selected_indices[num_support : num_support + num_query]
            }
        return task_info

task_sampler = TaskSampler(
    dataset_size_dict=dataset_size_dict,
    train_num_classes_per_set=5,
    val_num_classes_per_set=16,
    train_num_support=5,
    val_num_support=3,
    train_num_query=5,
    val_num_query=3,
    sampling_rule="uniform"
)

这是一个非常简单的采样实现（也是`MAML`论文中的实现），基于这个`TaskSampler`尝试采样一批任务:

In [8]:
train_support_set_data = [];train_query_set_data = [];val_support_set_data = [];val_query_set_data = []
sampled_train_task_info = task_sampler.sample_task(dataset_name='train', seed=1227)
sampled_val_task_info = task_sampler.sample_task(dataset_name='val', seed=1227)

# sample the data for train set
for sampled_train_class_name in sampled_train_task_info.keys():
    train_class_data_list = dict_with_cluster[sampled_train_class_name]
    support_indices = sampled_train_task_info[sampled_train_class_name]['support_indices']
    query_indices = sampled_train_task_info[sampled_train_class_name]['query_indices']
    per_class_support_data = [train_class_data_list[i] for i in support_indices]
    train_support_set_data.append(per_class_support_data)
    per_class_query_data = [train_class_data_list[i] for i in query_indices]
    train_query_set_data.append(per_class_query_data)

# do the same for val set
for sampled_val_class_name in sampled_val_task_info.keys():
    val_class_data_list = dict_with_cluster[sampled_val_class_name]
    support_indices = sampled_val_task_info[sampled_val_class_name]['support_indices']
    query_indices = sampled_val_task_info[sampled_val_class_name]['query_indices']
    per_class_support_data = [val_class_data_list[i] for i in support_indices]
    val_support_set_data.append(per_class_support_data)
    per_class_query_data = [val_class_data_list[i] for i in query_indices]
    val_query_set_data.append(per_class_query_data)

# thus a task is sampled, data is organized as follows:
# (5,5,16,16), refer to the number of classes sampled 
len(train_support_set_data),len(train_query_set_data),len(val_support_set_data),len(val_query_set_data)
# (5,5,3,3), refer to the number of support and query samples sampled for each class
len(train_support_set_data[0]),len(train_query_set_data[0]),len(val_support_set_data[0]),len(val_query_set_data[0])

(5, 5, 3, 3)

目前采样得到的任务组织比较混乱，考虑用字典使任务变得更加清晰，可以定义一个`transform`函数实现这一点，同时，该函数可以用于定义`__getitem__`方法，即获得一个数据单元，在此我们通过复制任务来得到一个`meta-batch`方便后续演示:

In [9]:
def transform(data):
    """Transforms the raw data list from get_set into separate lists."""
    pdb_ids = [];labels = [];prots = [];ligs = []
    for class_data_list in data:
        for per_pdb_data_dict in class_data_list:
            pdb_ids.append(per_pdb_data_dict['pdb_id'])
            labels.append(float(per_pdb_data_dict['label']))
            prots.append(per_pdb_data_dict['prot'])
            ligs.append(per_pdb_data_dict['lig'])
    return {
        'pdb_ids': pdb_ids,
        'labels': labels,
        'prots': prots,
        'ligs': ligs,
    }

# (25,25,48,48), 25 = 5*5, 48 = 16*3
# train_support_set_data['pdb_ids'].__len__(),train_query_set_data['pdb_ids'].__len__(),val_support_set_data['pdb_ids'].__len__(),val_query_set_data['pdb_ids'].__len__()
train_support_set_data = transform(train_support_set_data);train_query_set_data = transform(train_query_set_data)
val_support_set_data = transform(val_support_set_data);val_query_set_data = transform(val_query_set_data)
train_task = (train_support_set_data,train_query_set_data,1227); val_task = (val_support_set_data,val_query_set_data,1227)

# simply repeat the task for 4 times to get a batch
train_batch = [train_task for _ in range(4)];val_batch = [val_task for _ in range(4)]

为了让数据格式更加工整，考虑写一个自定义的`task_collate_fn`来实现这一点，最后整合好的`meta-batch`是一个长度为`batch_size`的列表，列表的每个元素是一个字典:

In [10]:
def task_collate_fn(batch):
    """
    Collate function for creating task batches with graph data.
    Processes and collates support and query sets into batches.
    Input `batch`: A list of tuples, where each tuple is (support_set_dict, query_set_dict, seed)
                    returned by FewShotLearningDatasetParallel.__getitem__.
    """
    all_tasks = []
    
    for task in batch:
        # Unpack the task tuple
        support_set_dict, query_set_dict, seed = task

        # Process Support Set
        support_ligs = support_set_dict['ligs'] # List of ligand graph objects
        support_prots = support_set_dict['prots'] # List of protein graph objects
        support_labels = support_set_dict['labels']
        support_pdb_ids = support_set_dict['pdb_ids']
        
        query_ligs = query_set_dict['ligs'] # List of ligand graph objects
        query_prots = query_set_dict['prots'] # List of protein graph objects
        query_labels = query_set_dict['labels']
        query_pdb_ids = query_set_dict['pdb_ids']
        
        task_data = {
            "support": {
                "prots": Batch.from_data_list(support_prots),
                "ligs": Batch.from_data_list(support_ligs),
                "labels": torch.tensor(support_labels, dtype=torch.float).reshape(-1),
                "pdb_ids": support_pdb_ids,
            },
            "query": {
                "prots": Batch.from_data_list(query_prots),
                "ligs": Batch.from_data_list(query_ligs),
                "labels": torch.tensor(query_labels, dtype=torch.float).reshape(-1),
                "pdb_ids": query_pdb_ids,
            },
            "seed": seed,
        }
        all_tasks.append(task_data)
    return all_tasks

# successfully get a meta-batch for training and validation
train_batch = task_collate_fn(train_batch);val_batch = task_collate_fn(val_batch)

## 2. 基础模型的定义

该部分定义了神经网络的具体架构，实现与传统机器学习并无不同，后续考虑收录一些基于`MAML`设计模型的技巧在里面，目前相关的工作有:

- [MetaMolGen: A Neural Graph Motif Generation Model for De Novo Molecular Design](https://arxiv.org/abs/2504.15587)
- [Meta-MGNN: Few-Shot Graph Learning for Molecular Property Prediction](https://arxiv.org/abs/2102.07916)
  
该部分工作持续收录中~

此处我们就用预先写好的`./models/gatedgcn.py`中定义好的权重以及参数`./models/gatedgcn.pt`进行模型初始化，并且简单测试一下基础模型是否能够工作:

In [11]:
# load a pretrained meta model
state_dict = torch.load('./models/gatedgcn.pt',weights_only=False,map_location='cpu')
meta_model_state_dict = state_dict['network'] # meta_model_state_dict, with a 'regressor.' prefix

# strip the prefix, ensuring the new key is not empty
model_state_dict = {
    new_key: v 
    for k, v in meta_model_state_dict.items() 
    if k.startswith('regressor.') and (new_key := k[len('regressor.'):])
}
base_model = GenScore_GGCN()
base_model.load_state_dict(model_state_dict)
base_model

GenScore_GGCN(
  (ligand_model): GatedGCN(
    (node_encoder): Linear(in_features=41, out_features=128, bias=True)
    (edge_encoder): Linear(in_features=10, out_features=128, bias=True)
    (gt_block): ModuleList(
      (0-5): 6 x GatedGCNLayer()
    )
  )
  (target_model): GatedGCN(
    (node_encoder): Linear(in_features=41, out_features=128, bias=True)
    (edge_encoder): Linear(in_features=5, out_features=128, bias=True)
    (gt_block): ModuleList(
      (0-5): 6 x GatedGCNLayer()
    )
  )
  (MLP): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0)
    (3): Dropout(p=0.15, inplace=False)
  )
  (z_pi): Linear(in_features=128, out_features=10, bias=True)
  (z_sigma): Linear(in_features=128, out_features=10, bias=True)
  (z_mu): Linear(in_features=128, out_features=10, bias=True)
  (atom_types): Linear(in_features=128, out_features=17, bias=True)
  

In [12]:
# Get sample data
support_train_prot_eg = train_batch[0]['support']['prots']
support_train_lig_eg = train_batch[0]['support']['ligs']

# Test network forward pass
pi, sigma, mu, dist, atom_types, bond_types, C_batch = base_model.net_forward(
    support_train_lig_eg, support_train_prot_eg
)

print(f"Predicted Gaussian Mixture Weights: {pi.shape}\n"
      f"Predicted Gaussian Mixture Standard Deviations: {sigma.shape}\n"
      f"Predicted Gaussian Mixture Means: {mu.shape}\n"
      f"Precalculated Closest Distance: {dist.shape}\n"
      f"Predicted Atom Types: {atom_types.shape}\n"
      f"Predicted Bond Types: {bond_types.shape}\n"
      f"Batch Index: {C_batch.shape}")

# Test full model forward pass
total_loss, mdn_loss, affi_loss, atom_loss, bond_loss, y, batch = base_model.forward(
    train_batch[0]['support']
)

print(f"Total Loss: {total_loss}\n"
      f"MDN Loss: {mdn_loss}\n"
      f"Affinity Pearson Correlation: {affi_loss}\n"
      f"Atom Loss: {atom_loss}\n"
      f"Bond Loss: {bond_loss}\n"
      f"Predicted Scores: {y}\n"
      f"Batch Index: {batch}")

# end
print('All Test Passed Successfully!')


Predicted Gaussian Mixture Weights: torch.Size([64953, 10])
Predicted Gaussian Mixture Standard Deviations: torch.Size([64953, 10])
Predicted Gaussian Mixture Means: torch.Size([64953, 10])
Precalculated Closest Distance: torch.Size([64953, 1])
Predicted Atom Types: torch.Size([763, 17])
Predicted Bond Types: torch.Size([1638, 4])
Batch Index: torch.Size([64953])
Total Loss: 0.9430026737227443
MDN Loss: 0.9428443908691406
Affinity Pearson Correlation: 0.6915146708488464
Atom Loss: 0.002030132105574012
Bond Loss: 0.15622907876968384
Predicted Scores: tensor([210.0679, 219.5144, 210.9504, 133.8276, 220.4948, 106.6906, 161.8000,
         51.3203, 165.5737, 172.4626,  81.7168,  54.9609,  70.7099,  89.0071,
        204.1539, 119.4976, 173.1818, 150.1528, 118.9367, 168.6175,  57.8180,
        121.2643, 184.5635,  77.7820, 120.7989], dtype=torch.float64)
Batch Index: tensor([ 0,  0,  0,  ..., 24, 24, 24])
All Test Passed Successfully!


Excellent! 接下来就要上元学习的核心逻辑了，MetaScore采用的元学习方法是基于优化的`MAML`及其变体`MAML++`（后续考虑继续做变体，比如`Task Based Attention Mechanism`等），其介绍可以详细看同目录下的`MetaLearning.md`，下面主要强调其工程上的实现（主要基于`higher`库，比较方便，所有已经写好的并非用于元学习的模型通过`higher`的帮助均可以轻松改造为元模型）。

## 3. 元学习框架设计

为了将元学习框架套用到`GenScore_GGCN`模型上，我们可以额外定义一个类`MAMLRegressor`.在具体的训练实现中,我们使用了`higher`库以简化开发流程.

该设计基于`MAML++`,同时可以通过修改配置逐步退化为传统的`MAML`框架,有一定灵活性.`MAML++`相比于`MAML`的主要改进点有:

- `MSL`机制引入.一般而言`MAML`在内部存在多步循环的情况下只考虑通过最后一步循环得到的损失对元模型求导,而`MAML++`中则是通过对多步循环中的每一步的损失进行加权求和,然后再对元模型进行求导.权重根据内循环适应的步数获取.
- `learnable inner-loop optimizer`.为模型的每个参数维护一个学习率,这个学习率是可以通过外循环的`meta-update`进行同步更新的,旨在加速模型收敛.
- `first-order to second-order`.在一般的`MAML`中,训练流程计算了二阶导会增加内存开销和时间开销,且模型训练容易产生剧烈震荡;而其近似`FOMAML`训练所得结果并不如`MAML`优秀,为此`MAML++`提出,可以在训练的初期先采用一阶近似稳定模型的行为,并在损失基本稳定后开启二阶导训练增强模型的性能.

In [13]:
def set_torch_seed(seed):
    """
    Sets the pytorch seeds for current experiment run
    :param seed: The seed (int)
    :return: A random number generator to use
    """
    rng = np.random.RandomState(seed=seed)
    torch_seed = rng.randint(0, 999999)
    torch.manual_seed(seed=torch_seed)

    return rng

class MAMLRegressor(nn.Module):
    def __init__(self):
        """
        Initializes a MAML few shot learning system
        :param device: The device to use to use the model on.
        :param args: A namedtuple of arguments specifying various hyperparameters.
        :param logs_filepath: Path to the logs directory for storing csv files
        """
        super(MAMLRegressor, self).__init__()

        # base configs
        self.device = torch.device('cpu')
        self.batch_size = 4
        self.current_epoch = 0
        self.logs_filepath = './tmp'
        self.learnable_inner_opt_params = True
        self.rng = set_torch_seed(seed=1227)
        self.regressor = GenScore_GGCN()
        self.training_per_iter = 1
        self.evaluation_per_iter = 1
        self.multi_step_loss_num = 10
        self.enable_inner_loop_optimizable_bn_params = True
        self.first_order_to_second_order_epoch = 100
        self.use_multi_step_loss_optimization = True
        self.second_order = True
        
        # use SGD as the inner loop optimizer according to the original paper
        inner_opt_class = hydra.utils.get_class('torch.optim.SGD')
        kwargs = {'lr':0.01}

        # learnable inner loop optimizer parameters
        if self.learnable_inner_opt_params:
            param_groups = [
                {'params': p, 'lr': 0.01} for p in self.regressor.parameters()
            ]
            self.inner_opt = inner_opt_class(param_groups, **kwargs)
            t = higher.optim.get_trainable_opt_params(self.inner_opt)
            self.lrs = nn.ParameterList(map(
                nn.Parameter,
                t['lr']
            ))
        else:
            params = self.regressor.parameters()
            self.inner_opt = inner_opt_class(params, **kwargs)

        # print the parameters of the model, which should include inner loop optimizer parameters
        print("Outer Loop parameters")
        param_shapes = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape, param.device, param.requires_grad)
                param_shapes.append(param.shape)
        print(f'n_params: {sum(map(np.prod, param_shapes))}')

        # set Adam as outer loop optimizer
        self.optimizer = optim.Adam(self.trainable_parameters(), lr=0.001, amsgrad=False)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer, 
            T_max=100, # 100 epochs
            eta_min=1.0e-5 # minimum outer loop learning rate
        )

    def get_per_step_loss_importance_vector(self):
        """
        Generates a tensor of dimensionality (num_inner_loop_steps) indicating the importance of each step's target
        loss towards the optimization loss.
        :return: A tensor to be used to compute the weighted average of the loss, useful for
        the MSL (Multi Step Loss) mechanism.
        """
        loss_weights = np.ones(shape=(self.training_per_iter)) * (1.0 / self.training_per_iter)
        decay_rate = 1.0 / self.training_per_iter / self.multi_step_loss_num
        min_value_for_non_final_losses = 0.03 / self.training_per_iter
        for i in range(len(loss_weights) - 1):
            curr_value = np.maximum(loss_weights[i] - (self.current_epoch * decay_rate), min_value_for_non_final_losses)
            loss_weights[i] = curr_value
        curr_value = np.minimum(
            loss_weights[-1] + (self.current_epoch * (self.training_per_iter - 1) * decay_rate),
            1.0 - ((self.training_per_iter - 1) * min_value_for_non_final_losses))
        loss_weights[-1] = curr_value
        loss_weights = torch.Tensor(loss_weights).to(device=self.device)
        return loss_weights

    def get_inner_loop_parameter_dict(self, params):
        """
        Returns a dictionary with the parameters to use for inner loop updates.
        :param params: A dictionary of the network's parameters.
        :return: A dictionary of the parameters to use for the inner loop optimization process.
        """
        param_dict = dict()
        for name, param in params:
            #print(name, param.shape, param.device, param.requires_grad)
            if param.requires_grad:
                if self.enable_inner_loop_optimizable_bn_params:
                    param_dict[name] = param.to(device=self.device)
                else:
                    if "norm_layer" not in name:
                        param_dict[name] = param.to(device=self.device)

        return param_dict

    def forward(self, data_batch, epoch, 
                use_second_order, use_multi_step_loss_optimization, 
                num_steps, training_phase, dist_threshold = 7.0):
        # zero the gradients
        self.regressor.zero_grad()

        # initialize the lists to store the losses and predictions
        total_losses = [];mdn_losses = []
        atom_losses = [];bond_losses = [];affi_coeffs = []
        per_task_query_preds = [];per_task_query_labels = []
        
        for task_id, task in enumerate(data_batch):
            task_losses = [];task_mdn_losses = [];task_affi_coeffs = []
            task_atom_losses = [];task_bond_losses = []
            
            # We only need the predictions after the final inner loop step
            final_query_preds_for_task = None 
            
            # get the support and query tasks
            support_task = task["support"];query_task = task["query"]
            per_task_query_labels.append(query_task["labels"])

            # get the per step loss importance vector
            per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector()

            # use higher to track the gradients
            with higher.innerloop_ctx(
                self.regressor, self.inner_opt, 
                copy_initial_weights=False,
                track_higher_grads=training_phase,
            ) as (fnet, diffopt):
                # a trick, but it is obviously wrong, because the inner loop optimizer is not the same as the outer loop optimizer
                # for p in self.regressor.parameters():
                #    self.inner_opt.state[p] = copy.deepcopy(self.optimizer.state[p])
                for num_step in range(num_steps):
                    (support_loss, support_mdn_loss, support_affi_coeff, 
                     support_atom_loss, support_bond_loss, support_preds, 
                     support_batch) = fnet(task=support_task,dist_threshold=dist_threshold) # run the inner loop  
                    
                    # higher provide a differentiable optimizer
                    # we can use dict to override the parameters of the inner loop optimizer
                    if self.learnable_inner_opt_params:
                        diffopt.step(support_loss, override={'lr': self.lrs})
                    else:
                        diffopt.step(support_loss)
                    
                    # use the multi step loss optimization
                    if use_multi_step_loss_optimization and training_phase and \
                            epoch < self.multi_step_loss_num:
                        (query_loss, query_mdn_loss, query_affi_coeff, 
                         query_atom_loss, query_bond_loss, query_preds, 
                         query_batch) = fnet(query_task,dist_threshold=dist_threshold)
                        
                        task_losses.append(per_step_loss_importance_vectors[num_step] * query_loss)
                        task_mdn_losses.append(per_step_loss_importance_vectors[num_step] * query_mdn_loss)
                        task_affi_coeffs.append(per_step_loss_importance_vectors[num_step] * query_affi_coeff)
                        task_atom_losses.append(per_step_loss_importance_vectors[num_step] * query_atom_loss)
                        task_bond_losses.append(per_step_loss_importance_vectors[num_step] * query_bond_loss)
                        
                        # store the prediction only from the last step
                        if num_step == (num_steps - 1):
                            final_query_preds_for_task = query_preds

                    # if not use the multi step loss optimization
                    # use the last step's prediction and loss
                    else:
                        if num_step == (self.training_per_iter - 1):
                            (query_loss, query_mdn_loss, query_affi_coeff, 
                             query_atom_loss, query_bond_loss, query_preds, 
                             query_batch) = fnet(query_task,dist_threshold=dist_threshold)

                            task_losses.append(query_loss)
                            task_mdn_losses.append(query_mdn_loss)
                            task_affi_coeffs.append(query_affi_coeff)
                            task_atom_losses.append(query_atom_loss)
                            task_bond_losses.append(query_bond_loss)
                            
                            # store the prediction from the last step
                            final_query_preds_for_task = query_preds

            # sum the losses of all tasks
            total_losses.append(torch.sum(torch.stack(task_losses)))
            mdn_losses.append(torch.sum(torch.stack(task_mdn_losses)))
            atom_losses.append(torch.sum(torch.stack(task_atom_losses)))
            bond_losses.append(torch.sum(torch.stack(task_bond_losses)))
            affi_coeffs.append(torch.sum(torch.stack(task_affi_coeffs)))
            
            # append the final prediction for this task
            per_task_query_preds.append(final_query_preds_for_task) 

        # concatenate the final predictions from all tasks
        all_query_preds = torch.cat(per_task_query_preds, dim=0) 
        all_query_labels = torch.cat(per_task_query_labels, dim=0)
        all_query_labels = all_query_labels.to(device=self.device)

        # now shapes should match: [batch_size * num_query]
        total_affi_coeff = torch.corrcoef(torch.stack([all_query_preds, all_query_labels]))[1, 0]

        losses = {
            "total_loss": torch.mean(torch.stack(total_losses)),
            "mdn_loss": torch.mean(torch.stack(mdn_losses)),
            "affi_coeffs": torch.mean(torch.stack(affi_coeffs)),
            "atom_loss": torch.mean(torch.stack(atom_losses)),
            "bond_loss": torch.mean(torch.stack(bond_losses)),
            "total_affi_coeff": total_affi_coeff,
        }
        
        # return the losses and the predictions
        return losses, all_query_preds.cpu().numpy()

    def trainable_parameters(self):
        """
        Returns an iterator over the trainable parameters of the model.
        """
        for param in self.parameters():
            if param.requires_grad:
                yield param

    def train_forward_prop(self, data_batch, epoch):
        """
        Runs an outer loop forward prop using the meta-model and base-model.
        :param data_batch: A data batch containing the support set and target set input, output pairs.
        :param epoch: The training epoch's index
        :return: A dictionary of losses for the current step.
        """
        losses, all_query_preds = self.forward(
            data_batch=data_batch, epoch=epoch,
            use_second_order=self.second_order and
            epoch > self.first_order_to_second_order_epoch,
            use_multi_step_loss_optimization=self.use_multi_step_loss_optimization,
            num_steps=self.training_per_iter,
            training_phase=True)
        return losses, all_query_preds

    def evaluation_forward_prop(self, data_batch, epoch):
        """
        Runs an outer loop evaluation forward prop using the meta-model and base-model.
        :param data_batch: A data batch containing the support set and target set input, output pairs.
        :param epoch: The training epoch's index
        :return: A dictionary of losses for the current step.
        """
        losses, all_query_preds = self.forward(
            data_batch=data_batch, epoch=epoch, use_second_order=False,
            use_multi_step_loss_optimization=self.use_multi_step_loss_optimization,
            num_steps=self.evaluation_per_iter,
            training_phase=False)
        return losses, all_query_preds

    def meta_update(self, loss):
        """
        Applies an outer loop update on the meta-parameters of the model.
        :param loss: The current loss.
        """
        self.optimizer.zero_grad()
        loss.backward()
        for name, param in self.regressor.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad.data.clamp_(-10, 10)  # not sure if this is necessary, more experiments are needed
        self.optimizer.step()

        # set the minimum learning rate for the inner loop optimizer
        if self.learnable_inner_opt_params:
            for lr in self.lrs:
                lr.data[lr < 1e-4] = 1e-4

    def run_train_iter(self, data_batch, epoch):
        """
        Runs an outer loop update step on the meta-model's parameters.
        :param data_batch: input data batch containing the support set and target set input, output pairs
        :param epoch: the index of the current epoch
        :return: The losses of the ran iteration.
        """
        epoch = int(epoch)
        self.scheduler.step(epoch=epoch)
        if self.current_epoch != epoch:
            self.current_epoch = epoch

        # set the model to training mode
        self.train()

        # run the forward prop
        losses, all_query_preds = self.train_forward_prop(data_batch=data_batch, epoch=epoch)

        # update the meta-model
        self.meta_update(loss=losses['total_loss'])

        # get the learning rate for the outer loop optimizer
        losses['learning_rate'] = self.scheduler.get_lr()[0]

        # zero the gradients
        self.optimizer.zero_grad()
        self.zero_grad()
        
        # return the losses and the predictions
        return losses, all_query_preds

    def run_validation_iter(self, data_batch):
        """
        Runs an outer loop evaluation step on the meta-model's parameters.
        :param data_batch: input data batch containing the support set and target set input, output pairs
        :param epoch: the index of the current epoch
        :return: The losses of the ran iteration.
        """
        # still need to set the model to training mode
        # to track the gradients of the base model
        self.train()

        # run the forward prop
        losses, all_query_preds = self.evaluation_forward_prop(data_batch=data_batch, epoch=self.current_epoch)

        # return the losses and the predictions
        return losses, all_query_preds

    def save_model(self, model_save_dir, state):
        """
        Save the network parameter state and experiment state dictionary.
        :param model_save_dir: The directory to store the state at.
        :param state: The state containing the experiment state and the network. It's in the form of a dictionary
        object.
        """
        state['network'] = self.state_dict()
        torch.save(state, f=model_save_dir)

    def load_model(self, model_save_dir):
        """
        Load checkpoint and return the state dictionary containing the network state params and experiment state.
        :param model_save_dir: The directory from which to load the files.
        :param model_name: The model_name to be loaded from the direcotry.
        :param model_idx: The index of the model (i.e. epoch number or 'latest' for the latest saved model of the current
        experiment)
        :return: A dictionary containing the experiment state and the saved model parameters.
        """
        filepath = os.path.join(model_save_dir)
        state = torch.load(filepath, map_location=self.device, weights_only=False)
        state_dict_loaded = state['network']
        self.load_state_dict(state_dict=state_dict_loaded)
        return state

## 4. 元学习训练演示

现在我们使用定义好的`MAMLRegressor`和准备好的`train_batch`、`val_batch`来演示元学习的训练和验证过程：


In [14]:
# check the data batch format again
print("Check the data batch format:")
print(f"Train batch size: {len(train_batch)}")
print(f"Val batch size: {len(val_batch)}")
print(f"First task in train batch support set size: {train_batch[0]['support']['labels'].shape}")
print(f"First task in train batch query set size: {train_batch[0]['query']['labels'].shape}")
print(f"First task in val batch support set size: {val_batch[0]['support']['labels'].shape}")
print(f"First task in val batch query set size: {val_batch[0]['query']['labels'].shape}")

# initialize the MAMLRegressor
print("\nInitializing MAMLRegressor...")
maml_model = MAMLRegressor()
maml_model.load_model('./models/gatedgcn.pt')

Check the data batch format:
Train batch size: 4
Val batch size: 4
First task in train batch support set size: torch.Size([25])
First task in train batch query set size: torch.Size([25])
First task in val batch support set size: torch.Size([48])
First task in val batch query set size: torch.Size([48])

Initializing MAMLRegressor...
Outer Loop parameters
regressor.ligand_model.node_encoder.weight torch.Size([128, 41]) cpu True
regressor.ligand_model.node_encoder.bias torch.Size([128]) cpu True
regressor.ligand_model.edge_encoder.weight torch.Size([128, 10]) cpu True
regressor.ligand_model.edge_encoder.bias torch.Size([128]) cpu True
regressor.ligand_model.gt_block.0.A.weight torch.Size([128, 128]) cpu True
regressor.ligand_model.gt_block.0.A.bias torch.Size([128]) cpu True
regressor.ligand_model.gt_block.0.B.weight torch.Size([128, 128]) cpu True
regressor.ligand_model.gt_block.0.B.bias torch.Size([128]) cpu True
regressor.ligand_model.gt_block.0.C.weight torch.Size([128, 128]) cpu True

{'best_loss': 0.0,
 'best_val_iter': 0,
 'current_iter': 43800,
 'best_epoch': 0,
 'train_total_loss_mean': np.float64(0.8793659770725962),
 'train_total_loss_std': np.float64(0.0511439905202655),
 'train_mdn_loss_mean': np.float64(0.8793552911281586),
 'train_mdn_loss_std': np.float64(0.051144483024852305),
 'train_affi_coeffs_mean': np.float64(0.30650401670485733),
 'train_affi_coeffs_std': np.float64(0.15846373011194897),
 'train_atom_loss_mean': np.float64(0.003970644685905427),
 'train_atom_loss_std': np.float64(0.002066610637330492),
 'train_bond_loss_mean': np.float64(0.10284830048680306),
 'train_bond_loss_std': np.float64(0.02092453590579454),
 'train_total_affi_coeff_mean': np.float64(0.29317947251102905),
 'train_total_affi_coeff_std': np.float64(0.17060049695686635),
 'train_learning_rate_mean': np.float64(0.00043141474416372774),
 'train_learning_rate_std': np.float64(1.6263032587282567e-19),
 'val_total_loss_mean': np.float64(1.013706842350887),
 'val_total_loss_std': np.

In [15]:
# perform a training iteration
print("=" * 60)
print("Training...")
print("=" * 60)

# train an epoch
epoch = 0
train_losses, train_preds = maml_model.run_train_iter(train_batch, epoch)

print("Training results:")
print(f"Total Loss: {train_losses['total_loss']:.4f}")
print(f"MDN Loss: {train_losses['mdn_loss']:.4f}")
print(f"Affinity Coefficient: {train_losses['affi_coeffs']:.4f}")
print(f"Atom Loss: {train_losses['atom_loss']:.4f}")
print(f"Bond Loss: {train_losses['bond_loss']:.4f}")
print(f"Total Affinity Coefficient: {train_losses['total_affi_coeff']:.4f}")
print(f"Learning Rate: {train_losses['learning_rate']:.6f}")
print(f"Predicted values shape: {train_preds.shape}")

Training...
Training results:
Total Loss: 0.9161
MDN Loss: 0.9160
Affinity Coefficient: 0.6065
Atom Loss: 0.0022
Bond Loss: 0.1107
Total Affinity Coefficient: 0.6062
Learning Rate: 0.001000
Predicted values shape: (100,)


In [16]:
# perform validation iteration
print("=" * 60)
print("Evaluating...")
print("=" * 60)

# validation
val_losses, val_preds = maml_model.run_validation_iter(val_batch)

print("Validation results:")
print(f"Total Loss: {val_losses['total_loss']:.4f}")
print(f"MDN Loss: {val_losses['mdn_loss']:.4f}")
print(f"Affinity Coefficient: {val_losses['affi_coeffs']:.4f}")
print(f"Atom Loss: {val_losses['atom_loss']:.4f}")
print(f"Bond Loss: {val_losses['bond_loss']:.4f}")
print(f"Total Affinity Coefficient: {val_losses['total_affi_coeff']:.4f}")
print(f"Predicted values shape: {val_preds.shape}")


Evaluating...
Validation results:
Total Loss: 1.2348
MDN Loss: 1.2346
Affinity Coefficient: 0.5213
Atom Loss: 0.0017
Bond Loss: 0.1535
Total Affinity Coefficient: 0.5213
Predicted values shape: (192,)


In [17]:
# demonstrate multiple epochs of training loop
print("=" * 60)
print("Demonstrate multiple epochs of training loop...")
print("=" * 60)

num_epochs = 3
train_history = []
val_history = []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print("-" * 40)
    
    # training
    train_losses, train_preds = maml_model.run_train_iter(train_batch, epoch)
    train_history.append(train_losses)
    
    # validation
    val_losses, val_preds = maml_model.run_validation_iter(val_batch)
    val_history.append(val_losses)
    
    print(f"Train - Total Loss: {train_losses['total_loss']:.4f}, "
          f"Affinity Coeff: {train_losses['total_affi_coeff']:.4f}, "
          f"LR: {train_losses['learning_rate']:.6f}")
    print(f"Val   - Total Loss: {val_losses['total_loss']:.4f}, "
          f"Affinity Coeff: {val_losses['total_affi_coeff']:.4f}")

print("\nTraining completed!")


Demonstrate multiple epochs of training loop...

Epoch 1/3
----------------------------------------
Train - Total Loss: 0.8793, Affinity Coeff: 0.5860, LR: 0.001000
Val   - Total Loss: 1.2373, Affinity Coeff: 0.5264

Epoch 2/3
----------------------------------------
Train - Total Loss: 0.8457, Affinity Coeff: 0.5949, LR: 0.001000
Val   - Total Loss: 1.2400, Affinity Coeff: 0.5334

Epoch 3/3
----------------------------------------
Train - Total Loss: 0.8192, Affinity Coeff: 0.5962, LR: 0.000998
Val   - Total Loss: 1.2412, Affinity Coeff: 0.5317

Training completed!


In [18]:
# analyze the training process and results
print("=" * 60)
print("Training process analysis")
print("=" * 60)

# extract the key metrics
train_total_losses = [h['total_loss'].item() for h in train_history]
val_total_losses = [h['total_loss'].item() for h in val_history]
train_affi_coeffs = [h['total_affi_coeff'].item() for h in train_history]
val_affi_coeffs = [h['total_affi_coeff'].item() for h in val_history]

print("Training loss changes:")
for i, loss in enumerate(train_total_losses):
    print(f"  Epoch {i+1}: {loss:.4f}")

print("\nValidation loss changes:")
for i, loss in enumerate(val_total_losses):
    print(f"  Epoch {i+1}: {loss:.4f}")

# check the learnable inner-loop learning rates
flag = 0
if maml_model.learnable_inner_opt_params:
    print(f"\nLearnable inner-loop learning rates:")
    for i, lr in enumerate(maml_model.lrs):
        print(f"  Parameter group {i}: {lr.data.item():.6f}")
        flag += 1
        if flag>=3:
            break

print(f"\nCurrent outer-loop learning rate: {maml_model.scheduler.get_lr()[0]:.6f}")
print(f"Total number of model parameters: {sum(p.numel() for p in maml_model.parameters() if p.requires_grad)}")


Training process analysis
Training loss changes:
  Epoch 1: 0.8793
  Epoch 2: 0.8457
  Epoch 3: 0.8192

Validation loss changes:
  Epoch 1: 1.2373
  Epoch 2: 1.2400
  Epoch 3: 1.2412

Learnable inner-loop learning rates:
  Parameter group 0: 0.003903
  Parameter group 1: 5.531588
  Parameter group 2: 3.829666

Current outer-loop learning rate: 0.000998
Total number of model parameters: 1050225


### 总结

1. **数据准备**: 使用`TaskSampler`采样任务，通过`transform`和`task_collate_fn`处理数据格式
2. **模型初始化**: 创建`MAMLRegressor`实例并加载预训练权重
3. **训练过程**: 使用`run_train_iter`执行训练，使用`run_validation_iter`执行验证
4. **结果分析**: 监控损失变化、亲和力系数、学习率等关键指标
