In [5]:
import os
import argparse
import random
import yaml
import logging
from functools import partial
import numpy as np
from collections import namedtuple, Counter



In [6]:
!pip install torch==2.3.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

import torch

import torch.utils.data

from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


Looking in indexes: https://download.pytorch.org/whl/cu121


In [7]:
!pip install pyyaml tqdm tensorboardX scikit-learn ogb torchdata



In [8]:
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html
from dgl.dataloading import GraphDataLoader

from dgl.data import (
    load_data,
    TUDataset,
    CoraGraphDataset,
    CiteseerGraphDataset,
    PubmedGraphDataset
)

Looking in links: https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html


In [9]:
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
import dgl


In [10]:
print("PyTorch GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("PyTorch GPU device count:", torch.cuda.device_count())
    print("PyTorch GPU device name:", torch.cuda.get_device_name(0))

PyTorch GPU available: True
PyTorch GPU device count: 1
PyTorch GPU device name: Tesla T4


In [11]:
def load_best_configs(args, path):
    with open(path, "r") as f:
        configs = yaml.load(f, yaml.FullLoader)

    if args.dataset not in configs:
        logging.info("Best args not found")
        return args

    logging.info("Using best configs")
    configs = configs[args.dataset]

    for k, v in configs.items():
        if "lr" in k or "weight_decay" in k:
            v = float(v)
        setattr(args, k, v)
    print("------ Use best configs ------")
    return args

In [12]:
def load_graph_classification_dataset(dataset_name, deg4feat=False):
    dataset_name = dataset_name.upper()
    dataset = TUDataset(dataset_name)
    graph, _ = dataset[0]

    if "attr" not in graph.ndata:
        if "node_labels" in graph.ndata and not deg4feat:
            print("Use node label as node features")
            feature_dim = 0
            for g, _ in dataset:
                feature_dim = max(feature_dim, g.ndata["node_labels"].max().item())

            feature_dim += 1
            for g, l in dataset:
                node_label = g.ndata["node_labels"].view(-1)
                feat = F.one_hot(node_label, num_classes=feature_dim).float()
                g.ndata["attr"] = feat
        else:
            print("Using degree as node features")
            feature_dim = 0
            degrees = []
            for g, _ in dataset:
                feature_dim = max(feature_dim, g.in_degrees().max().item())
                degrees.extend(g.in_degrees().tolist())
            MAX_DEGREES = 400

            oversize = 0
            for d, n in Counter(degrees).items():
                if d > MAX_DEGREES:
                    oversize += n
            # print(f"N > {MAX_DEGREES}, #NUM: {oversize}, ratio: {oversize/sum(degrees):.8f}")
            feature_dim = min(feature_dim, MAX_DEGREES)

            feature_dim += 1
            for g, l in dataset:
                degrees = g.in_degrees()
                degrees[degrees > MAX_DEGREES] = MAX_DEGREES

                feat = F.one_hot(degrees, num_classes=feature_dim).float()
                g.ndata["attr"] = feat
    else:
        print("******** Use `attr` as node features ********")
        feature_dim = graph.ndata["attr"].shape[1]

    labels = torch.tensor([x[1] for x in dataset])

    num_classes = torch.max(labels).item() + 1
    dataset = [(g.remove_self_loop().add_self_loop(), y) for g, y in dataset]

    print(f"******** # Num Graphs: {len(dataset)}, # Num Feat: {feature_dim}, # Num Classes: {num_classes} ********")

    return dataset, (feature_dim, num_classes)

In [13]:
def collate_fn(batch):
    # graphs = [x[0].add_self_loop() for x in batch]
    graphs = [x[0] for x in batch]
    labels = [x[1] for x in batch]
    batch_g = dgl.batch(graphs)
    labels = torch.cat(labels, dim=0)
    return batch_g, labels

In [14]:
def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.determinstic = True

In [15]:
from tensorboardX import SummaryWriter


In [16]:
class TBLogger(object):
    def __init__(self, log_path="./logging_data", name="run"):
        super(TBLogger, self).__init__()

        if not os.path.exists(log_path):
            os.makedirs(log_path, exist_ok=True)

        self.last_step = 0
        self.log_path = log_path
        raw_name = os.path.join(log_path, name)
        name = raw_name
        for i in range(1000):
            name = raw_name + str(f"_{i}")
            if not os.path.exists(name):
                break
        self.writer = SummaryWriter(logdir=name)

    def note(self, metrics, step=None):
        if step is None:
            step = self.last_step
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)
        self.last_step = step

    def finish(self):
        self.writer.close()

In [17]:
from typing import Optional


In [18]:
def create_activation(name):
    if name == "relu":
        return nn.ReLU()
    elif name == "gelu":
        return nn.GELU()
    elif name == "prelu":
        return nn.PReLU()
    elif name is None:
        return nn.Identity()
    elif name == "elu":
        return nn.ELU()
    else:
        raise NotImplementedError(f"{name} is not implemented.")

In [19]:
class GATConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=True,
                 norm=None,
                 concat_out=True):
        super(GATConv, self).__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._concat_out = concat_out

        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(size=(num_heads * out_feats,)))
        else:
            self.register_buffer('bias', None)
        if residual:
            if self._in_dst_feats != out_feats * num_heads:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = nn.Identity()
        else:
            self.register_buffer('res_fc', None)
        self.reset_parameters()
        self.activation = activation
        # if norm is not None:
        #     self.norm = norm(num_heads * out_feats)
        # else:
        #     self.norm = None

        self.norm = norm
        if norm is not None:
            self.norm = norm(num_heads * out_feats)

    def reset_parameters(self):
        """

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The attention weights are using xavier initialization method.
        """
        gain = nn.init.calculate_gain('relu')
        if hasattr(self, 'fc'):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, get_attention=False):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise RuntimeError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                src_prefix_shape = feat[0].shape[:-1]
                dst_prefix_shape = feat[1].shape[:-1]
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, 'fc_src'):
                    feat_src = self.fc(h_src).view(
                        *src_prefix_shape, self._num_heads, self._out_feats)
                    feat_dst = self.fc(h_dst).view(
                        *dst_prefix_shape, self._num_heads, self._out_feats)
                else:
                    feat_src = self.fc_src(h_src).view(
                        *src_prefix_shape, self._num_heads, self._out_feats)
                    feat_dst = self.fc_dst(h_dst).view(
                        *dst_prefix_shape, self._num_heads, self._out_feats)
            else:
                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    *src_prefix_shape, self._num_heads, self._out_feats)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
                    h_dst = h_dst[:graph.number_of_dst_nodes()]
                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:
            # We decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # addition could be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # e[e == 0] = -1e3
            # e = graph.edata.pop('e')
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']

            # bias
            if self.bias is not None:
                rst = rst + self.bias.view(
                    *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)

            # residual
            if self.res_fc is not None:
                # Use -1 rather than self._num_heads to handle broadcasting
                resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
                rst = rst + resval

            if self._concat_out:
                rst = rst.flatten(1)
            else:
                rst = torch.mean(rst, dim=1)

            if self.norm is not None:
                rst = self.norm(rst)

            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst

In [20]:

class GAT(nn.Module):
    def __init__(self,
                 in_dim,
                 num_hidden,
                 out_dim,
                 num_layers,
                 nhead,
                 nhead_out,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual,
                 norm,
                 concat_out=False,
                 encoding=False
                 ):
        super(GAT, self).__init__()
        self.out_dim = out_dim
        self.num_heads = nhead
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        self.concat_out = concat_out

        last_activation = create_activation(activation) if encoding else None
        last_residual = (encoding and residual)
        last_norm = norm if encoding else None

        if num_layers == 1:
            self.gat_layers.append(GATConv(
                in_dim, out_dim, nhead_out,
                feat_drop, attn_drop, negative_slope, last_residual, norm=last_norm, concat_out=concat_out))
        else:
            # input projection (no residual)
            self.gat_layers.append(GATConv(
                in_dim, num_hidden, nhead,
                feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out))
            # hidden layers
            for l in range(1, num_layers - 1):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.gat_layers.append(GATConv(
                    num_hidden * nhead, num_hidden, nhead,
                    feat_drop, attn_drop, negative_slope, residual, create_activation(activation), norm=norm, concat_out=concat_out))
            # output projection
            self.gat_layers.append(GATConv(
                num_hidden * nhead, out_dim, nhead_out,
                feat_drop, attn_drop, negative_slope, last_residual, activation=last_activation, norm=last_norm, concat_out=concat_out))

        # if norm is not None:
        #     self.norms = nn.ModuleList([
        #         norm(num_hidden * nhead)
        #         for _ in range(num_layers - 1)
        #     ])
        #     if self.concat_out:
        #         self.norms.append(norm(num_hidden * nhead))
        # else:
        #     self.norms = None

        self.head = nn.Identity()

    # def forward(self, g, inputs):
    #     h = inputs
    #     for l in range(self.num_layers):
    #         h = self.gat_layers[l](g, h)
    #         if l != self.num_layers - 1:
    #             h = h.flatten(1)
    #             if self.norms is not None:
    #                 h = self.norms[l](h)
    #     # output projection
    #     if self.concat_out:
    #         out = h.flatten(1)
    #         if self.norms is not None:
    #             out = self.norms[-1](out)
    #     else:
    #         out = h.mean(1)
    #     return self.head(out)

    def forward(self, g, inputs, return_hidden=False):
        h = inputs
        hidden_list = []
        for l in range(self.num_layers):
            h = self.gat_layers[l](g, h)
            hidden_list.append(h)
            # h = h.flatten(1)
        # output projection
        if return_hidden:
            return self.head(h), hidden_list
        else:
            return self.head(h)

    def reset_classifier(self, num_classes):
        self.head = nn.Linear(self.num_heads * self.out_dim, num_classes)

In [21]:
class NormLayer(nn.Module):
    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))

            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError

    def forward(self, graph, x):
        tensor = x
        if self.norm is not None and type(self.norm) != str:
            return self.norm(tensor)
        elif self.norm is None:
            return tensor

        batch_list = graph.batch_num_nodes
        batch_size = len(batch_list)
        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        sub = tensor - mean * self.mean_scale

        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias

In [22]:
def create_norm(name):
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        return partial(NormLayer, norm_type="groupnorm")
    else:
        return nn.Identity


In [23]:
class DotGatConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 residual,
                 activation=None,
                 norm=None,
                 concat_out=False,
                 allow_zero_in_degree=False):
        super(DotGatConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._num_heads = num_heads
        self._concat_out = concat_out

        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
        self.activation = activation

        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)
            self.fc_dst = nn.Linear(self._in_dst_feats, self._out_feats*self._num_heads, bias=False)
        else:
            self.fc = nn.Linear(self._in_src_feats, self._out_feats*self._num_heads, bias=False)

        if residual:
            if self._in_dst_feats != out_feats * num_heads:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = nn.Identity()
        else:
            self.register_buffer('res_fc', None)

        self.norm = norm
        if norm is not None:
            self.norm = norm(num_heads * out_feats)

    def forward(self, graph, feat, get_attention=False):
        r"""

        Description
        -----------
        Apply dot product version of self attention in GCN.

        Parameters
        ----------
        graph: DGLGraph or bi_partities graph
            The graph
        feat: torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` is size
            of output feature.
        torch.Tensor, optional
            The attention values of shape :math:`(E, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """

        graph = graph.local_var()

        if not self._allow_zero_in_degree:
            if (graph.in_degrees() == 0).any():
                raise ValueError('There are 0-in-degree nodes in the graph, '
                               'output for those nodes will be invalid. '
                               'This is harmful for some applications, '
                               'causing silent performance regression. '
                               'Adding self-loop on the input graph by '
                               'calling `g = dgl.add_self_loop(g)` will resolve '
                               'the issue. Setting ``allow_zero_in_degree`` '
                               'to be `True` when constructing this module will '
                               'suppress the check and let the code run.')

        # check if feat is a tuple
        if isinstance(feat, tuple):
            h_src = feat[0]
            h_dst = feat[1]
            feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
            feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            print("!! tuple input in DotGAT !!")
        else:
            feat = self.feat_drop(feat)
            h_src = feat
            feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
            if graph.is_block:
                feat_dst = feat_src[:graph.number_of_dst_nodes()]

        # Assign features to nodes
        graph.srcdata.update({'ft': feat_src})
        graph.dstdata.update({'ft': feat_dst})

        # Step 1. dot product
        graph.apply_edges(fn.u_dot_v('ft', 'ft', 'a'))

        # Step 2. edge softmax to compute attention scores
        graph.edata['sa'] = edge_softmax(graph, graph.edata['a'] / self._out_feats**0.5)
        graph.edata["sa"] = self.attn_drop(graph.edata["sa"])
        # Step 3. Broadcast softmax value to each edge, and aggregate dst node
        graph.update_all(fn.u_mul_e('ft', 'sa', 'attn'), fn.sum('attn', 'agg_u'))

        # output results to the destination nodes
        rst = graph.dstdata['agg_u']

        if self.res_fc is not None:
            # Use -1 rather than self._num_heads to handle broadcasting
            batch_size = feat.shape[0]
            resval = self.res_fc(h_dst).view(batch_size, -1, self._out_feats)
            rst = rst + resval

        if self._concat_out:
            rst = rst.flatten(1)
        else:
            rst = torch.mean(rst, dim=1)

        if self.norm is not None:
            rst = self.norm(rst)

        # activation
        if self.activation:
            rst = self.activation(rst)

        if get_attention:
            return rst, graph.edata['sa']
        else:
            return rst

In [24]:
class DotGAT(nn.Module):
    def __init__(self,
                 in_dim,
                 num_hidden,
                 out_dim,
                 num_layers,
                 nhead,
                 nhead_out,
                 activation,
                 feat_drop,
                 attn_drop,
                 residual,
                 norm,
                 concat_out=False,
                 encoding=False
                 ):
        super(DotGAT, self).__init__()
        self.out_dim = out_dim
        self.num_heads = nhead
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        self.concat_out = concat_out

        last_activation = create_activation(activation) if encoding else None
        last_residual = (encoding and residual)
        last_norm = norm if encoding else None

        if num_layers == 1:
            self.gat_layers.append(DotGatConv(
                in_dim, out_dim, nhead_out,
                feat_drop, attn_drop, last_residual, norm=last_norm, concat_out=concat_out))
        else:
            # input projection (no residual)
            self.gat_layers.append(DotGatConv(
                in_dim, num_hidden, nhead,
                feat_drop, attn_drop, residual, create_activation(activation), norm=norm, concat_out=concat_out))
            # hidden layers
            for l in range(1, num_layers - 1):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.gat_layers.append(DotGatConv(
                    num_hidden * nhead, num_hidden, nhead,
                    feat_drop, attn_drop, residual, create_activation(activation), norm=norm, concat_out=concat_out))
            # output projection
            self.gat_layers.append(DotGatConv(
                num_hidden * nhead, out_dim, nhead_out,
                feat_drop, attn_drop, last_residual, activation=last_activation, norm=last_norm, concat_out=concat_out))

        self.head = nn.Identity()

    def forward(self, g, inputs, return_hidden=False):
        h = inputs
        hidden_list = []
        for l in range(self.num_layers):
            h = self.gat_layers[l](g, h)
            hidden_list.append(h)
            # h = h.flatten(1)
        # output projection
        if return_hidden:
            return self.head(h), hidden_list
        else:
            return self.head(h)

    def reset_classifier(self, num_classes):
        self.head = nn.Linear(self.num_heads * self.out_dim, num_classes)

In [25]:
class MLP(nn.Module):
    """MLP with linear output"""
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, activation="relu", norm="batchnorm"):
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.norms = torch.nn.ModuleList()
            self.activations = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.norms.append(create_norm(norm)(hidden_dim))
                self.activations.append(create_activation(activation))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = self.norms[i](self.linears[i](h))
                h = self.activations[i](h)
            return self.linears[-1](h)

In [26]:
class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU."""
    def __init__(self, mlp, norm="batchnorm", activation="relu"):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        norm_func = create_norm(norm)
        if norm_func is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_func(self.mlp.output_dim)
        self.act = create_activation(activation)

    def forward(self, h):
        h = self.mlp(h)
        h = self.norm(h)
        h = self.act(h)
        return h

In [27]:
import dgl.function as fn
from dgl.utils import expand_as_pair


In [28]:
class GINConv(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 apply_func,
                 aggregator_type="sum",
                 init_eps=0,
                 learn_eps=False,
                 residual=False,
                 ):
        super().__init__()
        self._in_feats = in_dim
        self._out_feats = out_dim
        self.apply_func = apply_func

        self._aggregator_type = aggregator_type
        if aggregator_type == 'sum':
            self._reducer = fn.sum
        elif aggregator_type == 'max':
            self._reducer = fn.max
        elif aggregator_type == 'mean':
            self._reducer = fn.mean
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))

        if learn_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', torch.FloatTensor([init_eps]))

        if residual:
            if self._in_feats != self._out_feats:
                self.res_fc = nn.Linear(
                    self._in_feats, self._out_feats, bias=False)
                print("! Linear Residual !")
            else:
                print("Identity Residual ")
                self.res_fc = nn.Identity()
        else:
            self.register_buffer('res_fc', None)

    def forward(self, graph, feat):
        with graph.local_scope():
            aggregate_fn = fn.copy_u('h', 'm')

            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
            if self.apply_func is not None:
                rst = self.apply_func(rst)

            if self.res_fc is not None:
                rst = rst + self.res_fc(feat_dst)

            return rst

In [29]:
class GIN(nn.Module):
    def __init__(self,
                 in_dim,
                 num_hidden,
                 out_dim,
                 num_layers,
                 dropout,
                 activation,
                 residual,
                 norm,
                 encoding=False,
                 learn_eps=False,
                 aggr="sum",
                 ):
        super(GIN, self).__init__()
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.activation = activation
        self.dropout = dropout

        last_activation = create_activation(activation) if encoding else None
        last_residual = encoding and residual
        last_norm = norm if encoding else None

        if num_layers == 1:
            apply_func = MLP(2, in_dim, num_hidden, out_dim, activation=activation, norm=norm)
            if last_norm:
                apply_func = ApplyNodeFunc(apply_func, norm=norm, activation=activation)
            self.layers.append(GINConv(in_dim, out_dim, apply_func, init_eps=0, learn_eps=learn_eps, residual=last_residual))
        else:
            # input projection (no residual)
            self.layers.append(GINConv(
                in_dim,
                num_hidden,
                ApplyNodeFunc(MLP(2, in_dim, num_hidden, num_hidden, activation=activation, norm=norm), activation=activation, norm=norm),
                init_eps=0,
                learn_eps=learn_eps,
                residual=residual)
                )
            # hidden layers
            for l in range(1, num_layers - 1):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.layers.append(GINConv(
                    num_hidden, num_hidden,
                    ApplyNodeFunc(MLP(2, num_hidden, num_hidden, num_hidden, activation=activation, norm=norm), activation=activation, norm=norm),
                    init_eps=0,
                    learn_eps=learn_eps,
                    residual=residual)
                )
            # output projection
            apply_func = MLP(2, num_hidden, num_hidden, out_dim, activation=activation, norm=norm)
            if last_norm:
                apply_func = ApplyNodeFunc(apply_func, activation=activation, norm=norm)

            self.layers.append(GINConv(num_hidden, out_dim, apply_func, init_eps=0, learn_eps=learn_eps, residual=last_residual))

        self.head = nn.Identity()

    def forward(self, g, inputs, return_hidden=False):
        h = inputs
        hidden_list = []
        for l in range(self.num_layers):
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = self.layers[l](g, h)
            hidden_list.append(h)
        # output projection
        if return_hidden:
            return self.head(h), hidden_list
        else:
            return self.head(h)

    def reset_classifier(self, num_classes):
        self.head = nn.Linear(self.out_dim, num_classes)

In [30]:
class GraphConv(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 norm=None,
                 activation=None,
                 residual=True,
                 ):
        super().__init__()
        self._in_feats = in_dim
        self._out_feats = out_dim

        self.fc = nn.Linear(in_dim, out_dim)

        if residual:
            if self._in_feats != self._out_feats:
                self.res_fc = nn.Linear(
                    self._in_feats, self._out_feats, bias=False)
                print("! Linear Residual !")
            else:
                print("Identity Residual ")
                self.res_fc = nn.Identity()
        else:
            self.register_buffer('res_fc', None)

        # if norm == "batchnorm":
        #     self.norm = nn.BatchNorm1d(out_dim)
        # elif norm == "layernorm":
        #     self.norm = nn.LayerNorm(out_dim)
        # else:
        #     self.norm = None

        self.norm = norm
        if norm is not None:
            self.norm = norm(out_dim)
        self._activation = activation

        self.reset_parameters()

    def reset_parameters(self):
        self.fc.reset_parameters()

    def forward(self, graph, feat):
        with graph.local_scope():
            aggregate_fn = fn.copy_u('h', 'm')
            # if edge_weight is not None:
            #     assert edge_weight.shape[0] == graph.number_of_edges()
            #     graph.edata['_edge_weight'] = edge_weight
            #     aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
            feat_src, feat_dst = expand_as_pair(feat, graph)
            # if self._norm in ['left', 'both']:
            degs = graph.out_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5)
            shp = norm.shape + (1,) * (feat_src.dim() - 1)
            norm = torch.reshape(norm, shp)
            feat_src = feat_src * norm

            # if self._in_feats > self._out_feats:
            #     # mult W first to reduce the feature size for aggregation.
            #     # if weight is not None:
            #         # feat_src = th.matmul(feat_src, weight)
            #     graph.srcdata['h'] = feat_src
            #     graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
            #     rst = graph.dstdata['h']
            # else:
            # aggregate first then mult W
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
            rst = graph.dstdata['h']

            rst = self.fc(rst)

            # if self._norm in ['right', 'both']:
            degs = graph.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5)
            shp = norm.shape + (1,) * (feat_dst.dim() - 1)
            norm = torch.reshape(norm, shp)
            rst = rst * norm

            if self.res_fc is not None:
                rst = rst + self.res_fc(feat_dst)

            if self.norm is not None:
                rst = self.norm(rst)

            if self._activation is not None:
                rst = self._activation(rst)

            return rst

In [31]:
class GCN(nn.Module):
    def __init__(self,
                 in_dim,
                 num_hidden,
                 out_dim,
                 num_layers,
                 dropout,
                 activation,
                 residual,
                 norm,
                 encoding=False
                 ):
        super(GCN, self).__init__()
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.gcn_layers = nn.ModuleList()
        self.activation = activation
        self.dropout = dropout

        last_activation = create_activation(activation) if encoding else None
        last_residual = encoding and residual
        last_norm = norm if encoding else None

        if num_layers == 1:
            self.gcn_layers.append(GraphConv(
                in_dim, out_dim, residual=last_residual, norm=last_norm, activation=last_activation))
        else:
            # input projection (no residual)
            self.gcn_layers.append(GraphConv(
                in_dim, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
            # hidden layers
            for l in range(1, num_layers - 1):
                # due to multi-head, the in_dim = num_hidden * num_heads
                self.gcn_layers.append(GraphConv(
                    num_hidden, num_hidden, residual=residual, norm=norm, activation=create_activation(activation)))
            # output projection
            self.gcn_layers.append(GraphConv(
                num_hidden, out_dim, residual=last_residual, activation=last_activation, norm=last_norm))

        # if norm is not None:
        #     self.norms = nn.ModuleList([
        #         norm(num_hidden)
        #         for _ in range(num_layers - 1)
        #     ])
        #     if not encoding:
        #         self.norms.append(norm(out_dim))
        # else:
        #     self.norms = None
        self.norms = None
        self.head = nn.Identity()

    def forward(self, g, inputs, return_hidden=False):
        h = inputs
        hidden_list = []
        for l in range(self.num_layers):
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = self.gcn_layers[l](g, h)
            if self.norms is not None and l != self.num_layers - 1:
                h = self.norms[l](h)
            hidden_list.append(h)
        # output projection
        if self.norms is not None and len(self.norms) == self.num_layers:
            h = self.norms[-1](h)
        if return_hidden:
            return self.head(h), hidden_list
        else:
            return self.head(h)

    def reset_classifier(self, num_classes):
        self.head = nn.Linear(self.out_dim, num_classes)

In [32]:
def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm, nhead, nhead_out, attn_drop, negative_slope=0.2, concat_out=True) -> nn.Module:
    if m_type == "gat":
        mod = GAT(
            in_dim=in_dim,
            num_hidden=num_hidden,
            out_dim=out_dim,
            num_layers=num_layers,
            nhead=nhead,
            nhead_out=nhead_out,
            concat_out=concat_out,
            activation=activation,
            feat_drop=dropout,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            residual=residual,
            norm=create_norm(norm),
            encoding=(enc_dec == "encoding"),
        )
    elif m_type == "dotgat":
        mod = DotGAT(
            in_dim=in_dim,
            num_hidden=num_hidden,
            out_dim=out_dim,
            num_layers=num_layers,
            nhead=nhead,
            nhead_out=nhead_out,
            concat_out=concat_out,
            activation=activation,
            feat_drop=dropout,
            attn_drop=attn_drop,
            residual=residual,
            norm=create_norm(norm),
            encoding=(enc_dec == "encoding"),
        )
    elif m_type == "gin":
        mod = GIN(
            in_dim=in_dim,
            num_hidden=num_hidden,
            out_dim=out_dim,
            num_layers=num_layers,
            dropout=dropout,
            activation=activation,
            residual=residual,
            norm=norm,
            encoding=(enc_dec == "encoding"),
        )
    elif m_type == "gcn":
        mod = GCN(
            in_dim=in_dim,
            num_hidden=num_hidden,
            out_dim=out_dim,
            num_layers=num_layers,
            dropout=dropout,
            activation=activation,
            residual=residual,
            norm=create_norm(norm),
            encoding=(enc_dec == "encoding")
        )
    elif m_type == "mlp":
        # * just for decoder
        mod = nn.Sequential(
            nn.Linear(in_dim, num_hidden),
            nn.PReLU(),
            nn.Dropout(0.2),
            nn.Linear(num_hidden, out_dim)
        )
    elif m_type == "linear":
        mod = nn.Linear(in_dim, out_dim)
    else:
        raise NotImplementedError

    return mod

In [33]:
def sce_loss(x, y, alpha=3):
    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)

    # loss =  - (x * y).sum(dim=-1)
    # loss = (x_h - y_h).norm(dim=1).pow(alpha)

    loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)

    loss = loss.mean()
    return loss

In [34]:
def mask_edge(graph, mask_prob):
    E = graph.num_edges()

    mask_rates = torch.FloatTensor(np.ones(E) * mask_prob)
    masks = torch.bernoulli(1 - mask_rates)
    mask_idx = masks.nonzero().squeeze(1)
    return mask_idx

In [35]:
def drop_edge(graph, drop_rate, return_edges=False):
    if drop_rate <= 0:
        return graph

    n_node = graph.num_nodes()
    edge_mask = mask_edge(graph, drop_rate)
    src = graph.edges()[0]
    dst = graph.edges()[1]

    nsrc = src[edge_mask]
    ndst = dst[edge_mask]

    ng = dgl.graph((nsrc, ndst), num_nodes=n_node)
    ng = ng.add_self_loop()

    dsrc = src[~edge_mask]
    ddst = dst[~edge_mask]

    if return_edges:
        return ng, (dsrc, ddst)
    return ng

In [36]:
from itertools import chain


In [37]:
class PreModel(nn.Module):
    def __init__(
            self,
            in_dim: int,
            num_hidden: int,
            num_layers: int,
            nhead: int,
            nhead_out: int,
            activation: str,
            feat_drop: float,
            attn_drop: float,
            negative_slope: float,
            residual: bool,
            norm: Optional[str],
            mask_rate: float = 0.3,
            encoder_type: str = "gat",
            decoder_type: str = "gat",
            loss_fn: str = "sce",
            drop_edge_rate: float = 0.0,
            replace_rate: float = 0.1,
            alpha_l: float = 2,
            concat_hidden: bool = False,
         ):
        super(PreModel, self).__init__()
        self._mask_rate = mask_rate

        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0
        if encoder_type in ("gat", "dotgat"):
            enc_num_hidden = num_hidden // nhead
            enc_nhead = nhead
        else:
            enc_num_hidden = num_hidden
            enc_nhead = 1

        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden // nhead_out if decoder_type in ("gat", "dotgat") else num_hidden

        # build encoder
        self.encoder = setup_module(
            m_type=encoder_type,
            enc_dec="encoding",
            in_dim=in_dim,
            num_hidden=enc_num_hidden,
            out_dim=enc_num_hidden,
            num_layers=num_layers,
            nhead=enc_nhead,
            nhead_out=enc_nhead,
            concat_out=True,
            activation=activation,
            dropout=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            residual=residual,
            norm=norm,
        )

        # build decoder for attribute prediction
        self.decoder = setup_module(
            m_type=decoder_type,
            enc_dec="decoding",
            in_dim=dec_in_dim,
            num_hidden=dec_num_hidden,
            out_dim=in_dim,
            num_layers=1,
            nhead=nhead,
            nhead_out=nhead_out,
            activation=activation,
            dropout=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            residual=residual,
            norm=norm,
            concat_out=True,
        )

        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            self.encoder_to_decoder = nn.Linear(dec_in_dim * num_layers, dec_in_dim, bias=False)
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        # * setup loss function
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def encoding_mask_noise(self, g, x, mask_rate=0.3):
        num_nodes = g.num_nodes()
        perm = torch.randperm(num_nodes, device=x.device)
        num_mask_nodes = int(mask_rate * num_nodes)

        # random masking
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[: num_mask_nodes]
        keep_nodes = perm[num_mask_nodes: ]

        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[: int(self._mask_token_rate * num_mask_nodes)]]
            noise_nodes = mask_nodes[perm_mask[-int(self._replace_rate * num_mask_nodes):]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x = x.clone()
            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            out_x = x.clone()
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token
        use_g = g.clone()

        return use_g, out_x, (mask_nodes, keep_nodes)

    def forward(self, g, x):
        # ---- attribute reconstruction ----
        loss = self.mask_attr_prediction(g, x)
        loss_item = {"loss": loss.item()}
        return loss, loss_item

    def mask_attr_prediction(self, g, x):
        pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate)

        if self._drop_edge_rate > 0:
            use_g, masked_edges = drop_edge(pre_use_g, self._drop_edge_rate, return_edges=True)
        else:
            use_g = pre_use_g

        enc_rep, all_hidden = self.encoder(use_g, use_x, return_hidden=True)
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type not in ("mlp", "linear"):
            # * remask, re-mask
            rep[mask_nodes] = 0

        if self._decoder_type in ("mlp", "liear") :
            recon = self.decoder(rep)
        else:
            recon = self.decoder(pre_use_g, rep)

        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        loss = self.criterion(x_rec, x_init)
        return loss

    def embed(self, g, x):
        rep = self.encoder(g, x)
        return rep

    @property
    def enc_params(self):
        return self.encoder.parameters()

    @property
    def dec_params(self):
        return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])

In [38]:


def build_model(args):
    num_heads = args.num_heads
    num_out_heads = args.num_out_heads
    num_hidden = args.num_hidden
    num_layers = args.num_layers
    residual = args.residual
    attn_drop = args.attn_drop
    in_drop = args.in_drop
    norm = args.norm
    negative_slope = args.negative_slope
    encoder_type = args.encoder
    decoder_type = args.decoder
    mask_rate = args.mask_rate
    drop_edge_rate = args.drop_edge_rate
    replace_rate = args.replace_rate


    activation = args.activation
    loss_fn = args.loss_fn
    alpha_l = args.alpha_l
    concat_hidden = args.concat_hidden
    num_features = args.num_features


    model = PreModel(
        in_dim=num_features,
        num_hidden=num_hidden,
        num_layers=num_layers,
        nhead=num_heads,
        nhead_out=num_out_heads,
        activation=activation,
        feat_drop=in_drop,
        attn_drop=attn_drop,
        negative_slope=negative_slope,
        residual=residual,
        encoder_type=encoder_type,
        decoder_type=decoder_type,
        mask_rate=mask_rate,
        norm=norm,
        loss_fn=loss_fn,
        drop_edge_rate=drop_edge_rate,
        replace_rate=replace_rate,
        alpha_l=alpha_l,
        concat_hidden=concat_hidden,
    )
    return model


In [39]:
def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None):
    opt_lower = opt.lower()

    parameters = model.parameters()
    opt_args = dict(lr=lr, weight_decay=weight_decay)

    opt_split = opt_lower.split("_")
    opt_lower = opt_split[-1]
    if opt_lower == "adam":
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == "adamw":
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == "adadelta":
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == "radam":
        optimizer = optim.RAdam(parameters, **opt_args)
    elif opt_lower == "sgd":
        opt_args["momentum"] = 0.9
        return optim.SGD(parameters, **opt_args)
    else:
        assert False and "Invalid optimizer"

    return optimizer

In [40]:
from tqdm import tqdm
def get_current_lr(optimizer):
    return optimizer.state_dict()["param_groups"][0]["lr"]

In [41]:
def pretrain(model, pooler, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob=True, logger=None):
    train_loader, eval_loader = dataloaders

    epoch_iter = tqdm(range(max_epoch))
    for epoch in epoch_iter:
        model.train()
        loss_list = []
        for batch in train_loader:
            batch_g, _ = batch
            batch_g = batch_g.to(device)

            feat = batch_g.ndata["attr"]
            model.train()
            loss, loss_dict = model(batch_g, feat)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())
            if logger is not None:
                loss_dict["lr"] = get_current_lr(optimizer)
                logger.note(loss_dict, step=epoch)
        if scheduler is not None:
            scheduler.step()
        epoch_iter.set_description(f"Epoch {epoch} | train_loss: {np.mean(loss_list):.4f}")

    return model

In [42]:
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import f1_score

In [43]:
def evaluate_graph_embeddings_using_svm(embeddings, labels):
    result = []
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

    for train_index, test_index in kf.split(embeddings, labels):
        x_train = embeddings[train_index]
        x_test = embeddings[test_index]
        y_train = labels[train_index]
        y_test = labels[test_index]
        params = {"C": [1e-3, 1e-2, 1e-1, 1, 10]}
        svc = SVC(random_state=42)
        clf = GridSearchCV(svc, params)
        clf.fit(x_train, y_train)

        preds = clf.predict(x_test)
        f1 = f1_score(y_test, preds, average="micro")
        result.append(f1)
    test_f1 = np.mean(result)
    test_std = np.std(result)

    return test_f1, test_std

In [44]:
def graph_classification_evaluation(model, pooler, dataloader, num_classes, lr_f, weight_decay_f, max_epoch_f, device, mute=False):
    model.eval()
    x_list = []
    y_list = []
    with torch.no_grad():
        for i, (batch_g, labels) in enumerate(dataloader):
            batch_g = batch_g.to(device)
            feat = batch_g.ndata["attr"]
            out = model.embed(batch_g, feat)
            out = pooler(batch_g, out)

            y_list.append(labels.numpy())
            x_list.append(out.cpu().numpy())
    x = np.concatenate(x_list, axis=0)
    y = np.concatenate(y_list, axis=0)
    test_f1, test_std = evaluate_graph_embeddings_using_svm(x, y)
    print(f"#Test_f1: {test_f1:.4f}±{test_std:.4f}")
    return test_f1

In [45]:
def main(args):
    device = args.device if args.device >= 0 else "cpu"
    seeds = args.seeds
    dataset_name = args.dataset
    max_epoch = args.max_epoch
    max_epoch_f = args.max_epoch_f
    num_hidden = args.num_hidden
    num_layers = args.num_layers
    encoder_type = args.encoder
    decoder_type = args.decoder
    replace_rate = args.replace_rate

    optim_type = args.optimizer
    loss_fn = args.loss_fn

    lr = args.lr
    weight_decay = args.weight_decay
    lr_f = args.lr_f
    weight_decay_f = args.weight_decay_f
    linear_prob = args.linear_prob
    load_model = args.load_model
    save_model = args.save_model
    logs = args.logging
    use_scheduler = args.scheduler
    pooling = args.pooling
    deg4feat = args.deg4feat
    batch_size = args.batch_size

    graphs, (num_features, num_classes) = load_graph_classification_dataset(dataset_name, deg4feat=deg4feat)
    args.num_features = num_features

    train_idx = torch.arange(len(graphs))
    train_sampler = SubsetRandomSampler(train_idx)

    train_loader = GraphDataLoader(graphs, sampler=train_sampler, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)
    eval_loader = GraphDataLoader(graphs, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

    if pooling == "mean":
        pooler = AvgPooling()
    elif pooling == "max":
        pooler = MaxPooling()
    elif pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    acc_list = []
    for i, seed in enumerate(seeds):
        print(f"####### Run {i} for seed {seed}")
        set_random_seed(seed)

        if logs:
            logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
        else:
            logger = None

        model = build_model(args)
        model.to(device)
        optimizer = create_optimizer(optim_type, model, lr, weight_decay)

        if use_scheduler:
            logging.info("Use schedular")
            scheduler = lambda epoch :( 1 + np.cos((epoch) * np.pi / max_epoch) ) * 0.5
            # scheduler = lambda epoch: epoch / warmup_steps if epoch < warmup_steps \
                    # else ( 1 + np.cos((epoch - warmup_steps) * np.pi / (max_epoch - warmup_steps))) * 0.5
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler)
        else:
            scheduler = None

        if not load_model:
            model = pretrain(model, pooler, (train_loader, eval_loader), optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob,  logger)
            model = model.cpu()

        if load_model:
            logging.info("Loading Model ... ")
            model.load_state_dict(torch.load("checkpoint.pt"))
        if save_model:
            logging.info("Saveing Model ...")
            torch.save(model.state_dict(), "checkpoint.pt")

        model = model.to(device)
        model.eval()
        test_f1 = graph_classification_evaluation(model, pooler, eval_loader, num_classes, lr_f, weight_decay_f, max_epoch_f, device, mute=False)
        acc_list.append(test_f1)

    final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
    print(f"# final_acc: {final_acc:.4f}±{final_acc_std:.4f}")

In [46]:
def build_args():
    parser = argparse.ArgumentParser(description="GAT")
    parser.add_argument("--seeds", type=int, nargs="+", default=[0])
    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--max_epoch", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=-1)

    parser.add_argument("--num_heads", type=int, default=4,
                        help="number of hidden attention heads")
    parser.add_argument("--num_out_heads", type=int, default=1,
                        help="number of output attention heads")
    parser.add_argument("--num_layers", type=int, default=2,
                        help="number of hidden layers")
    parser.add_argument("--num_hidden", type=int, default=256,
                        help="number of hidden units")
    parser.add_argument("--residual", action="store_true", default=False,
                        help="use residual connection")
    parser.add_argument("--in_drop", type=float, default=.2,
                        help="input feature dropout")
    parser.add_argument("--attn_drop", type=float, default=.1,
                        help="attention dropout")
    parser.add_argument("--norm", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help="weight decay")
    parser.add_argument("--negative_slope", type=float, default=0.2,
                        help="the negative slope of leaky relu for GAT")
    parser.add_argument("--activation", type=str, default="prelu")
    parser.add_argument("--mask_rate", type=float, default=0.5)
    parser.add_argument("--drop_edge_rate", type=float, default=0.0)
    parser.add_argument("--replace_rate", type=float, default=0.0)

    parser.add_argument("--encoder", type=str, default="gat")
    parser.add_argument("--decoder", type=str, default="gat")
    parser.add_argument("--loss_fn", type=str, default="sce")
    parser.add_argument("--alpha_l", type=float, default=2, help="`pow`coefficient for `sce` loss")
    parser.add_argument("--optimizer", type=str, default="adam")

    parser.add_argument("--max_epoch_f", type=int, default=30)
    parser.add_argument("--lr_f", type=float, default=0.001, help="learning rate for evaluation")
    parser.add_argument("--weight_decay_f", type=float, default=0.0, help="weight decay for evaluation")
    parser.add_argument("--linear_prob", action="store_true", default=False)

    parser.add_argument("--load_model", action="store_true")
    parser.add_argument("--save_model", action="store_true")
    parser.add_argument("--use_cfg", action="store_true")
    parser.add_argument("--logging", action="store_true")
    parser.add_argument("--scheduler", action="store_true", default=False)
    parser.add_argument("--concat_hidden", action="store_true", default=False)

    # for graph classification
    parser.add_argument("--pooling", type=str, default="mean")
    parser.add_argument("--deg4feat", action="store_true", default=False, help="use node degree as input feature")
    parser.add_argument("--batch_size", type=int, default=32)

    args = parser.parse_args()
    return args

In [47]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='ENZYMES',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='ENZYMES', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/ENZYMES.zip from https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip...


/root/.dgl/ENZYMES.zip:   0%|          | 0.00/537k [00:00<?, ?B/s]

Extracting file to /root/.dgl/ENZYMES_67bfdeff
Use node label as node features
******** # Num Graphs: 600, # Num Feat: 3, # Num Classes: 6 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.0850: 100%|██████████| 200/200 [00:35<00:00,  5.60it/s]


#Test_f1: 0.2200±0.0267
# final_acc: 0.2200±0.0000


In [48]:
def few_shot_finetune_and_evaluate(model, pooler, support_graphs, support_labels, query_graphs, query_labels, num_classes, lr_f, weight_decay_f, max_epoch_f, device, num_hidden):
    # Create a copy of the pre-trained model
    finetune_model = build_model(args) # Assuming args is accessible, or pass necessary args
    finetune_model.load_state_dict(model.state_dict())
    finetune_model.to(device)

    # Add a classification head for the current task's number of ways
    unique_support_labels = torch.unique(support_labels)
    num_ways = len(unique_support_labels)
    # Remove the incorrect call to reset_classifier on the encoder
    # Add a new linear classification layer to the finetune_model
    finetune_model.classifier_head = nn.Linear(num_hidden, num_ways).to(device)


    # Prepare data for fine-tuning
    support_batch = dgl.batch(support_graphs).to(device)
    support_feat = support_batch.ndata["attr"]
    support_labels = support_labels.to(device)

    query_batch = dgl.batch(query_graphs).to(device)
    query_feat = query_batch.ndata["attr"]
    query_labels = query_labels.to(device)

    # Define optimizer for fine-tuning (only on the classification head and maybe last layers)
    # Here we fine-tune the whole model and the new classification head
    optimizer = torch.optim.Adam(finetune_model.parameters(), lr=lr_f, weight_decay=weight_decay_f)
    criterion = nn.CrossEntropyLoss()

    finetune_model.train()
    for epoch in range(max_epoch_f):
        # Forward pass
        # Pass through the encoder to get node embeddings
        support_node_embeddings = finetune_model.encoder(support_batch, support_feat)
        # Apply pooling to get graph embeddings
        support_graph_embeddings = pooler(support_batch, support_node_embeddings)
        # Pass graph embeddings through the new classification head
        outputs = finetune_model.classifier_head(support_graph_embeddings)
        loss = criterion(outputs, support_labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate on the query set
    finetune_model.eval()
    with torch.no_grad():
        query_node_embeddings = finetune_model.encoder(query_batch, query_feat)
        query_graph_embeddings = pooler(query_batch, query_node_embeddings)
        query_outputs = finetune_model.classifier_head(query_graph_embeddings)
        _, preds = torch.max(query_outputs, 1)
        f1 = f1_score(query_labels.cpu().numpy(), preds.cpu().numpy(), average="micro")

    return f1

In [49]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='ENZYMES',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=5, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='ENZYMES', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=5, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 600, # Num Feat: 3, # Num Classes: 6 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:15<00:00,  6.29it/s]

# Few-shot Acc (5-way 5-shot): 0.2471±0.0574





In [50]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='ENZYMES',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=4, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='ENZYMES', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=4, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 600, # Num Feat: 3, # Num Classes: 6 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:15<00:00,  6.38it/s]

# Few-shot Acc (4-way 5-shot): 0.2978±0.0588





In [51]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='ENZYMES',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=5, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='ENZYMES', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=5, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 600, # Num Feat: 3, # Num Classes: 6 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  6.02it/s]

# Few-shot Acc (5-way 10-shot): 0.2675±0.0557





In [52]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='ENZYMES',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=2e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=5, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='ENZYMES', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0002, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=5, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 600, # Num Feat: 3, # Num Classes: 6 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  6.02it/s]

# Few-shot Acc (5-way 10-shot): 0.2728±0.0564





In [54]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='SYNTHETIC',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='SYNTHETIC', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/SYNTHETIC.zip from https://www.chrsmrrs.com/graphkerneldatasets/SYNTHETIC.zip...


/root/.dgl/SYNTHETIC.zip:   0%|          | 0.00/437k [00:00<?, ?B/s]

Extracting file to /root/.dgl/SYNTHETIC_0be794dd
Use node label as node features
******** # Num Graphs: 300, # Num Feat: 8, # Num Classes: 2 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.1063: 100%|██████████| 200/200 [00:26<00:00,  7.49it/s]


#Test_f1: 0.5400±0.0327
# final_acc: 0.5400±0.0000


In [55]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='SYNTHETIC',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='SYNTHETIC', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 300, # Num Feat: 8, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  5.89it/s]

# Few-shot Acc (2-way 5-shot): 0.5000±0.0000





In [56]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='SYNTHETIC',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='SYNTHETIC', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 300, # Num Feat: 8, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  5.92it/s]

# Few-shot Acc (2-way 10-shot): 0.5000±0.0000





In [74]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_9',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='MSRC_9', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/MSRC_9.zip from https://www.chrsmrrs.com/graphkerneldatasets/MSRC_9.zip...


/root/.dgl/MSRC_9.zip:   0%|          | 0.00/98.5k [00:00<?, ?B/s]

Extracting file to /root/.dgl/MSRC_9_dd7c6f73
Use node label as node features
******** # Num Graphs: 221, # Num Feat: 10, # Num Classes: 8 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.1401: 100%|██████████| 200/200 [00:18<00:00, 10.85it/s]


#Test_f1: 0.9146±0.0636
# final_acc: 0.9146±0.0000


In [75]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_9',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=3, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_9', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=3, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 221, # Num Feat: 10, # Num Classes: 8 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   4%|▍         | 4/100 [00:00<00:10,  8.98it/s]

Skipping task 2: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:   9%|▉         | 9/100 [00:00<00:06, 14.23it/s]

Skipping task 5: Not enough samples (19) for class 0. Need 20.
Skipping task 6: Not enough samples (19) for class 0. Need 20.
Skipping task 7: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  14%|█▍        | 14/100 [00:01<00:06, 12.81it/s]

Skipping task 11: Not enough samples (19) for class 0. Need 20.
Skipping task 12: Not enough samples (19) for class 0. Need 20.
Skipping task 14: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  20%|██        | 20/100 [00:01<00:04, 16.31it/s]

Skipping task 16: Not enough samples (19) for class 0. Need 20.
Skipping task 17: Not enough samples (19) for class 0. Need 20.
Skipping task 18: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  24%|██▍       | 24/100 [00:01<00:05, 14.67it/s]

Skipping task 21: Not enough samples (19) for class 0. Need 20.
Skipping task 22: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  28%|██▊       | 28/100 [00:02<00:05, 13.82it/s]

Skipping task 25: Not enough samples (19) for class 0. Need 20.
Skipping task 26: Not enough samples (19) for class 0. Need 20.
Skipping task 28: Not enough samples (19) for class 0. Need 20.
Skipping task 29: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  36%|███▌      | 36/100 [00:02<00:03, 20.19it/s]

Skipping task 31: Not enough samples (19) for class 0. Need 20.
Skipping task 32: Not enough samples (19) for class 0. Need 20.
Skipping task 33: Not enough samples (19) for class 0. Need 20.
Skipping task 34: Not enough samples (19) for class 0. Need 20.
Skipping task 36: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  43%|████▎     | 43/100 [00:03<00:04, 12.23it/s]

Skipping task 41: Not enough samples (19) for class 0. Need 20.
Skipping task 43: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  49%|████▉     | 49/100 [00:03<00:04, 12.36it/s]

Skipping task 46: Not enough samples (19) for class 0. Need 20.
Skipping task 48: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  51%|█████     | 51/100 [00:03<00:03, 12.51it/s]

Skipping task 50: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  55%|█████▌    | 55/100 [00:04<00:04, 10.55it/s]

Skipping task 54: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  62%|██████▏   | 62/100 [00:04<00:02, 12.83it/s]

Skipping task 58: Not enough samples (19) for class 0. Need 20.
Skipping task 59: Not enough samples (19) for class 0. Need 20.
Skipping task 60: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  64%|██████▍   | 64/100 [00:05<00:02, 12.89it/s]

Skipping task 63: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  68%|██████▊   | 68/100 [00:05<00:03, 10.66it/s]

Skipping task 66: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  70%|███████   | 70/100 [00:05<00:02, 11.14it/s]

Skipping task 69: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  75%|███████▌  | 75/100 [00:06<00:02, 11.72it/s]

Skipping task 72: Not enough samples (19) for class 0. Need 20.
Skipping task 73: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  83%|████████▎ | 83/100 [00:07<00:01,  9.12it/s]

Skipping task 81: Not enough samples (19) for class 0. Need 20.
Skipping task 83: Not enough samples (19) for class 0. Need 20.
Skipping task 84: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  91%|█████████ | 91/100 [00:07<00:00, 13.89it/s]

Skipping task 87: Not enough samples (19) for class 0. Need 20.
Skipping task 88: Not enough samples (19) for class 0. Need 20.
Skipping task 89: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  95%|█████████▌| 95/100 [00:08<00:00, 13.35it/s]

Skipping task 92: Not enough samples (19) for class 0. Need 20.
Skipping task 93: Not enough samples (19) for class 0. Need 20.
Skipping task 95: Not enough samples (19) for class 0. Need 20.
Skipping task 96: Not enough samples (19) for class 0. Need 20.


Few-shot tasks: 100%|██████████| 100/100 [00:08<00:00, 11.90it/s]

Skipping task 99: Not enough samples (19) for class 0. Need 20.
# Few-shot Acc (3-way 5-shot): 0.9700±0.0383





In [77]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_9',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=5, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_9', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=5, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 221, # Num Feat: 10, # Num Classes: 8 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   0%|          | 0/100 [00:00<?, ?it/s]

Skipping task 0: Not enough samples (19) for class 0. Need 20.
Skipping task 1: Not enough samples (19) for class 0. Need 20.
Skipping task 2: Not enough samples (19) for class 0. Need 20.
Skipping task 3: Not enough samples (19) for class 0. Need 20.
Skipping task 4: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:   6%|▌         | 6/100 [00:00<00:03, 25.19it/s]

Skipping task 6: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  11%|█         | 11/100 [00:01<00:10,  8.41it/s]

Skipping task 11: Not enough samples (19) for class 0. Need 20.
Skipping task 12: Not enough samples (19) for class 0. Need 20.
Skipping task 13: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  17%|█▋        | 17/100 [00:01<00:07, 10.47it/s]

Skipping task 16: Not enough samples (19) for class 0. Need 20.
Skipping task 17: Not enough samples (19) for class 0. Need 20.
Skipping task 18: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  20%|██        | 20/100 [00:01<00:07, 11.42it/s]

Skipping task 20: Not enough samples (19) for class 0. Need 20.
Skipping task 21: Not enough samples (19) for class 0. Need 20.
Skipping task 22: Not enough samples (19) for class 0. Need 20.
Skipping task 23: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  27%|██▋       | 27/100 [00:02<00:05, 13.22it/s]

Skipping task 26: Not enough samples (19) for class 0. Need 20.
Skipping task 27: Not enough samples (19) for class 0. Need 20.
Skipping task 28: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  30%|███       | 30/100 [00:02<00:05, 13.41it/s]

Skipping task 30: Not enough samples (19) for class 0. Need 20.
Skipping task 31: Not enough samples (19) for class 0. Need 20.
Skipping task 32: Not enough samples (19) for class 0. Need 20.
Skipping task 33: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  37%|███▋      | 37/100 [00:02<00:04, 13.51it/s]

Skipping task 36: Not enough samples (19) for class 0. Need 20.
Skipping task 37: Not enough samples (19) for class 0. Need 20.
Skipping task 38: Not enough samples (19) for class 0. Need 20.
Skipping task 39: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  41%|████      | 41/100 [00:03<00:03, 14.77it/s]

Skipping task 41: Not enough samples (19) for class 0. Need 20.
Skipping task 42: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  44%|████▍     | 44/100 [00:03<00:04, 13.93it/s]

Skipping task 44: Not enough samples (19) for class 0. Need 20.
Skipping task 45: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  47%|████▋     | 47/100 [00:03<00:04, 13.00it/s]

Skipping task 47: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  51%|█████     | 51/100 [00:04<00:04,  9.89it/s]

Skipping task 50: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  57%|█████▋    | 57/100 [00:04<00:04,  9.41it/s]

Skipping task 53: Not enough samples (19) for class 0. Need 20.
Skipping task 54: Not enough samples (19) for class 0. Need 20.
Skipping task 55: Not enough samples (19) for class 0. Need 20.
Skipping task 57: Not enough samples (19) for class 0. Need 20.
Skipping task 58: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  60%|██████    | 60/100 [00:05<00:03, 10.48it/s]

Skipping task 60: Not enough samples (19) for class 0. Need 20.
Skipping task 61: Not enough samples (19) for class 0. Need 20.
Skipping task 62: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  64%|██████▍   | 64/100 [00:05<00:02, 12.34it/s]

Skipping task 64: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  70%|███████   | 70/100 [00:05<00:02, 11.36it/s]

Skipping task 67: Not enough samples (19) for class 0. Need 20.
Skipping task 68: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  74%|███████▍  | 74/100 [00:06<00:02, 12.05it/s]

Skipping task 71: Not enough samples (19) for class 0. Need 20.
Skipping task 73: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  81%|████████  | 81/100 [00:06<00:01, 17.69it/s]

Skipping task 75: Not enough samples (19) for class 0. Need 20.
Skipping task 76: Not enough samples (19) for class 0. Need 20.
Skipping task 77: Not enough samples (19) for class 0. Need 20.
Skipping task 78: Not enough samples (19) for class 0. Need 20.
Skipping task 79: Not enough samples (19) for class 0. Need 20.
Skipping task 81: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  88%|████████▊ | 88/100 [00:06<00:00, 20.53it/s]

Skipping task 83: Not enough samples (19) for class 0. Need 20.
Skipping task 84: Not enough samples (19) for class 0. Need 20.
Skipping task 85: Not enough samples (19) for class 0. Need 20.
Skipping task 86: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  91%|█████████ | 91/100 [00:07<00:00, 20.11it/s]

Skipping task 89: Not enough samples (19) for class 0. Need 20.
Skipping task 90: Not enough samples (19) for class 0. Need 20.
Skipping task 92: Not enough samples (19) for class 0. Need 20.


Few-shot tasks:  98%|█████████▊| 98/100 [00:07<00:00, 17.64it/s]

Skipping task 94: Not enough samples (19) for class 0. Need 20.
Skipping task 95: Not enough samples (19) for class 0. Need 20.
Skipping task 96: Not enough samples (19) for class 0. Need 20.


Few-shot tasks: 100%|██████████| 100/100 [00:07<00:00, 12.98it/s]

Skipping task 99: Not enough samples (19) for class 0. Need 20.
# Few-shot Acc (5-way 5-shot): 0.9441±0.0329





In [79]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_9',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=3, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_9', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=3, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 221, # Num Feat: 10, # Num Classes: 8 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   0%|          | 0/100 [00:00<?, ?it/s]

Skipping task 0: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:   2%|▏         | 2/100 [00:00<00:11,  8.33it/s]

Skipping task 2: Not enough samples (23) for class 4. Need 25.
Skipping task 3: Not enough samples (23) for class 4. Need 25.
Skipping task 4: Not enough samples (19) for class 0. Need 25.
Skipping task 5: Not enough samples (19) for class 0. Need 25.
Skipping task 6: Not enough samples (23) for class 4. Need 25.
Skipping task 7: Not enough samples (23) for class 4. Need 25.
Skipping task 8: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  10%|█         | 10/100 [00:00<00:03, 23.64it/s]

Skipping task 10: Not enough samples (19) for class 0. Need 25.
Skipping task 11: Not enough samples (23) for class 4. Need 25.
Skipping task 12: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  17%|█▋        | 17/100 [00:01<00:06, 13.46it/s]

Skipping task 16: Not enough samples (19) for class 0. Need 25.
Skipping task 17: Not enough samples (19) for class 0. Need 25.
Skipping task 18: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  20%|██        | 20/100 [00:01<00:05, 13.50it/s]

Skipping task 20: Not enough samples (23) for class 4. Need 25.
Skipping task 21: Not enough samples (19) for class 0. Need 25.
Skipping task 22: Not enough samples (19) for class 0. Need 25.
Skipping task 23: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  25%|██▌       | 25/100 [00:01<00:04, 16.09it/s]

Skipping task 25: Not enough samples (23) for class 4. Need 25.
Skipping task 26: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  30%|███       | 30/100 [00:01<00:05, 13.85it/s]

Skipping task 29: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  32%|███▏      | 32/100 [00:02<00:05, 12.50it/s]

Skipping task 31: Not enough samples (23) for class 4. Need 25.
Skipping task 32: Not enough samples (19) for class 0. Need 25.
Skipping task 33: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  38%|███▊      | 38/100 [00:02<00:04, 13.51it/s]

Skipping task 35: Not enough samples (23) for class 4. Need 25.
Skipping task 36: Not enough samples (19) for class 0. Need 25.
Skipping task 38: Not enough samples (23) for class 4. Need 25.
Skipping task 39: Not enough samples (19) for class 0. Need 25.
Skipping task 40: Not enough samples (19) for class 0. Need 25.
Skipping task 41: Not enough samples (19) for class 0. Need 25.
Skipping task 42: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  44%|████▍     | 44/100 [00:02<00:03, 17.79it/s]

Skipping task 44: Not enough samples (19) for class 0. Need 25.
Skipping task 45: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  49%|████▉     | 49/100 [00:03<00:03, 14.73it/s]

Skipping task 47: Not enough samples (23) for class 4. Need 25.
Skipping task 49: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  55%|█████▌    | 55/100 [00:04<00:04,  9.39it/s]

Skipping task 54: Not enough samples (23) for class 4. Need 25.
Skipping task 55: Not enough samples (23) for class 4. Need 25.
Skipping task 56: Not enough samples (23) for class 4. Need 25.
Skipping task 57: Not enough samples (19) for class 0. Need 25.
Skipping task 58: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  60%|██████    | 60/100 [00:04<00:03, 12.82it/s]

Skipping task 60: Not enough samples (23) for class 4. Need 25.
Skipping task 61: Not enough samples (23) for class 4. Need 25.
Skipping task 62: Not enough samples (23) for class 4. Need 25.
Skipping task 63: Not enough samples (19) for class 0. Need 25.
Skipping task 64: Not enough samples (19) for class 0. Need 25.
Skipping task 65: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  67%|██████▋   | 67/100 [00:04<00:01, 18.14it/s]

Skipping task 67: Not enough samples (19) for class 0. Need 25.
Skipping task 68: Not enough samples (19) for class 0. Need 25.
Skipping task 69: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  71%|███████   | 71/100 [00:04<00:01, 18.40it/s]

Skipping task 71: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  73%|███████▎  | 73/100 [00:04<00:01, 15.10it/s]

Skipping task 73: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  75%|███████▌  | 75/100 [00:05<00:01, 13.39it/s]

Skipping task 75: Not enough samples (23) for class 4. Need 25.
Skipping task 76: Not enough samples (23) for class 4. Need 25.
Skipping task 77: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  83%|████████▎ | 83/100 [00:05<00:01, 16.94it/s]

Skipping task 79: Not enough samples (23) for class 4. Need 25.
Skipping task 80: Not enough samples (23) for class 4. Need 25.
Skipping task 81: Not enough samples (23) for class 4. Need 25.
Skipping task 83: Not enough samples (23) for class 4. Need 25.
Skipping task 84: Not enough samples (19) for class 0. Need 25.
Skipping task 85: Not enough samples (23) for class 4. Need 25.
Skipping task 86: Not enough samples (23) for class 4. Need 25.
Skipping task 87: Not enough samples (19) for class 0. Need 25.
Skipping task 88: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  94%|█████████▍| 94/100 [00:05<00:00, 24.15it/s]

Skipping task 90: Not enough samples (23) for class 4. Need 25.
Skipping task 91: Not enough samples (23) for class 4. Need 25.
Skipping task 92: Not enough samples (19) for class 0. Need 25.


Few-shot tasks: 100%|██████████| 100/100 [00:06<00:00, 16.47it/s]

Skipping task 95: Not enough samples (19) for class 0. Need 25.
Skipping task 96: Not enough samples (19) for class 0. Need 25.
Skipping task 97: Not enough samples (23) for class 4. Need 25.
Skipping task 98: Not enough samples (19) for class 0. Need 25.
Skipping task 99: Not enough samples (19) for class 0. Need 25.
# Few-shot Acc (3-way 10-shot): 0.9755±0.0342





In [80]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_9',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=5, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_9', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=5, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 221, # Num Feat: 10, # Num Classes: 8 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:  11%|█         | 11/100 [00:00<00:00, 108.84it/s]

Skipping task 0: Not enough samples (19) for class 0. Need 25.
Skipping task 1: Not enough samples (23) for class 4. Need 25.
Skipping task 2: Not enough samples (19) for class 0. Need 25.
Skipping task 3: Not enough samples (23) for class 4. Need 25.
Skipping task 4: Not enough samples (19) for class 0. Need 25.
Skipping task 5: Not enough samples (19) for class 0. Need 25.
Skipping task 6: Not enough samples (23) for class 4. Need 25.
Skipping task 7: Not enough samples (19) for class 0. Need 25.
Skipping task 8: Not enough samples (23) for class 4. Need 25.
Skipping task 9: Not enough samples (19) for class 0. Need 25.
Skipping task 10: Not enough samples (23) for class 4. Need 25.
Skipping task 11: Not enough samples (23) for class 4. Need 25.
Skipping task 12: Not enough samples (23) for class 4. Need 25.
Skipping task 13: Not enough samples (23) for class 4. Need 25.
Skipping task 15: Not enough samples (23) for class 4. Need 25.
Skipping task 16: Not enough samples (19) for clas

Few-shot tasks:  22%|██▏       | 22/100 [00:01<00:06, 12.88it/s] 

Skipping task 22: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  27%|██▋       | 27/100 [00:01<00:05, 13.51it/s]

Skipping task 24: Not enough samples (19) for class 0. Need 25.
Skipping task 25: Not enough samples (19) for class 0. Need 25.
Skipping task 26: Not enough samples (19) for class 0. Need 25.
Skipping task 27: Not enough samples (19) for class 0. Need 25.
Skipping task 28: Not enough samples (19) for class 0. Need 25.
Skipping task 29: Not enough samples (23) for class 4. Need 25.
Skipping task 30: Not enough samples (19) for class 0. Need 25.
Skipping task 31: Not enough samples (19) for class 0. Need 25.
Skipping task 32: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  34%|███▍      | 34/100 [00:02<00:03, 16.61it/s]

Skipping task 34: Not enough samples (19) for class 0. Need 25.
Skipping task 35: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  38%|███▊      | 38/100 [00:02<00:03, 16.88it/s]

Skipping task 37: Not enough samples (19) for class 0. Need 25.
Skipping task 38: Not enough samples (19) for class 0. Need 25.
Skipping task 39: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  41%|████      | 41/100 [00:02<00:03, 16.03it/s]

Skipping task 41: Not enough samples (19) for class 0. Need 25.
Skipping task 42: Not enough samples (23) for class 4. Need 25.
Skipping task 43: Not enough samples (19) for class 0. Need 25.
Skipping task 44: Not enough samples (23) for class 4. Need 25.
Skipping task 45: Not enough samples (19) for class 0. Need 25.
Skipping task 46: Not enough samples (19) for class 0. Need 25.
Skipping task 47: Not enough samples (19) for class 0. Need 25.
Skipping task 48: Not enough samples (19) for class 0. Need 25.
Skipping task 49: Not enough samples (23) for class 4. Need 25.
Skipping task 50: Not enough samples (19) for class 0. Need 25.
Skipping task 51: Not enough samples (19) for class 0. Need 25.
Skipping task 52: Not enough samples (23) for class 4. Need 25.
Skipping task 53: Not enough samples (19) for class 0. Need 25.


Few-shot tasks:  55%|█████▌    | 55/100 [00:02<00:01, 26.88it/s]

Skipping task 55: Not enough samples (19) for class 0. Need 25.
Skipping task 56: Not enough samples (23) for class 4. Need 25.
Skipping task 57: Not enough samples (23) for class 4. Need 25.
Skipping task 58: Not enough samples (19) for class 0. Need 25.
Skipping task 59: Not enough samples (19) for class 0. Need 25.
Skipping task 60: Not enough samples (23) for class 4. Need 25.
Skipping task 61: Not enough samples (19) for class 0. Need 25.
Skipping task 62: Not enough samples (23) for class 4. Need 25.
Skipping task 63: Not enough samples (19) for class 0. Need 25.
Skipping task 64: Not enough samples (23) for class 4. Need 25.
Skipping task 65: Not enough samples (19) for class 0. Need 25.
Skipping task 66: Not enough samples (19) for class 0. Need 25.
Skipping task 67: Not enough samples (23) for class 4. Need 25.
Skipping task 68: Not enough samples (19) for class 0. Need 25.
Skipping task 69: Not enough samples (19) for class 0. Need 25.
Skipping task 70: Not enough samples (23

Few-shot tasks:  72%|███████▏  | 72/100 [00:03<00:00, 37.79it/s]

Skipping task 72: Not enough samples (23) for class 4. Need 25.
Skipping task 73: Not enough samples (19) for class 0. Need 25.
Skipping task 74: Not enough samples (23) for class 4. Need 25.
Skipping task 75: Not enough samples (23) for class 4. Need 25.
Skipping task 76: Not enough samples (19) for class 0. Need 25.
Skipping task 77: Not enough samples (23) for class 4. Need 25.
Skipping task 78: Not enough samples (19) for class 0. Need 25.
Skipping task 79: Not enough samples (23) for class 4. Need 25.
Skipping task 80: Not enough samples (19) for class 0. Need 25.
Skipping task 81: Not enough samples (23) for class 4. Need 25.
Skipping task 82: Not enough samples (23) for class 4. Need 25.
Skipping task 83: Not enough samples (23) for class 4. Need 25.


Few-shot tasks:  85%|████████▌ | 85/100 [00:03<00:00, 41.74it/s]

Skipping task 86: Not enough samples (23) for class 4. Need 25.
Skipping task 87: Not enough samples (23) for class 4. Need 25.


Few-shot tasks: 100%|██████████| 100/100 [00:03<00:00, 26.69it/s]

Skipping task 89: Not enough samples (19) for class 0. Need 25.
Skipping task 90: Not enough samples (23) for class 4. Need 25.
Skipping task 91: Not enough samples (23) for class 4. Need 25.
Skipping task 92: Not enough samples (23) for class 4. Need 25.
Skipping task 93: Not enough samples (23) for class 4. Need 25.
Skipping task 94: Not enough samples (19) for class 0. Need 25.
Skipping task 95: Not enough samples (19) for class 0. Need 25.
Skipping task 96: Not enough samples (19) for class 0. Need 25.
Skipping task 97: Not enough samples (23) for class 4. Need 25.
Skipping task 98: Not enough samples (23) for class 4. Need 25.
Skipping task 99: Not enough samples (23) for class 4. Need 25.
# Few-shot Acc (5-way 10-shot): 0.9622±0.0271





In [81]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='DD',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='DD', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/DD.zip from https://www.chrsmrrs.com/graphkerneldatasets/DD.zip...


/root/.dgl/DD.zip:   0%|          | 0.00/4.98M [00:00<?, ?B/s]

Extracting file to /root/.dgl/DD_a6864905
Use node label as node features
******** # Num Graphs: 1178, # Num Feat: 89, # Num Classes: 2 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.5731: 100%|██████████| 200/200 [04:01<00:00,  1.21s/it]


#Test_f1: 0.6800±0.0223
# final_acc: 0.6800±0.0000


In [83]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='DD',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='DD', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 1178, # Num Feat: 89, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:21<00:00,  4.71it/s]

# Few-shot Acc (2-way 5-shot): 0.5633±0.0853





In [85]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='DD',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='DD', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 1178, # Num Feat: 89, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:33<00:00,  2.96it/s]

# Few-shot Acc (2-way 10-shot): 0.5930±0.0885





In [86]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='PROTEINS',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='PROTEINS', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Use node label as node features
******** # Num Graphs: 1113, # Num Feat: 3, # Num Classes: 2 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.0880: 100%|██████████| 200/200 [01:30<00:00,  2.20it/s]


#Test_f1: 0.5930±0.0098
# final_acc: 0.5930±0.0000


In [87]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='PROTEINS',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='PROTEINS', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 1113, # Num Feat: 3, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  6.06it/s]

# Few-shot Acc (2-way 5-shot): 0.5560±0.1029





In [88]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='PROTEINS',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='PROTEINS', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 1113, # Num Feat: 3, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:20<00:00,  4.98it/s]

# Few-shot Acc (2-way 10-shot): 0.5773±0.0960





In [89]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='OHSU',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='OHSU', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/OHSU.zip from https://www.chrsmrrs.com/graphkerneldatasets/OHSU.zip...


/root/.dgl/OHSU.zip:   0%|          | 0.00/79.6k [00:00<?, ?B/s]

Extracting file to /root/.dgl/OHSU_f15c2a92
Use node label as node features
******** # Num Graphs: 79, # Num Feat: 190, # Num Classes: 2 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.6005: 100%|██████████| 200/200 [00:13<00:00, 15.38it/s]


#Test_f1: 0.5768±0.2143
# final_acc: 0.5768±0.0000


In [90]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='OHSU',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='OHSU', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 79, # Num Feat: 190, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:16<00:00,  6.18it/s]

# Few-shot Acc (2-way 5-shot): 0.5027±0.0839





In [91]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='OHSU',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=2, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='OHSU', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=2, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 79, # Num Feat: 190, # Num Classes: 2 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks: 100%|██████████| 100/100 [00:20<00:00,  4.90it/s]

# Few-shot Acc (2-way 10-shot): 0.5100±0.0723





In [94]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_21',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0,
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=50, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=False, # Default value from build_args
        save_model=True, # Default value from build_args
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)
    main(args)

Namespace(dataset='MSRC_21', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=50, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=False, save_model=True, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32)
Downloading /root/.dgl/MSRC_21.zip from https://www.chrsmrrs.com/graphkerneldatasets/MSRC_21.zip...


/root/.dgl/MSRC_21.zip:   0%|          | 0.00/517k [00:00<?, ?B/s]

Extracting file to /root/.dgl/MSRC_21_1fc24118
Use node label as node features
******** # Num Graphs: 563, # Num Feat: 24, # Num Classes: 20 ********
####### Run 0 for seed 0


Epoch 199 | train_loss: 0.1174: 100%|██████████| 200/200 [00:46<00:00,  4.33it/s]


#Test_f1: 0.9129±0.0335
# final_acc: 0.9129±0.0000


In [96]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_21',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=3, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_21', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=3, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 563, # Num Feat: 24, # Num Classes: 20 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:  17%|█▋        | 17/100 [00:03<00:14,  5.58it/s]

Skipping task 15: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  21%|██        | 21/100 [00:04<00:16,  4.74it/s]

Skipping task 21: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  26%|██▌       | 26/100 [00:06<00:21,  3.45it/s]

Skipping task 26: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  28%|██▊       | 28/100 [00:06<00:20,  3.57it/s]

Skipping task 28: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  36%|███▌      | 36/100 [00:09<00:23,  2.67it/s]

Skipping task 36: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  44%|████▍     | 44/100 [00:12<00:17,  3.23it/s]

Skipping task 44: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  57%|█████▋    | 57/100 [00:16<00:16,  2.61it/s]

Skipping task 57: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  64%|██████▍   | 64/100 [00:18<00:08,  4.31it/s]

Skipping task 64: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  66%|██████▌   | 66/100 [00:18<00:05,  5.69it/s]

Skipping task 66: Not enough samples (10) for class 19. Need 20.
Skipping task 67: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  71%|███████   | 71/100 [00:19<00:04,  5.97it/s]

Skipping task 71: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  73%|███████▎  | 73/100 [00:19<00:04,  6.43it/s]

Skipping task 73: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  75%|███████▌  | 75/100 [00:19<00:04,  5.83it/s]

Skipping task 75: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  90%|█████████ | 90/100 [00:23<00:02,  4.60it/s]

Skipping task 90: Not enough samples (10) for class 19. Need 20.


Few-shot tasks: 100%|██████████| 100/100 [00:25<00:00,  3.92it/s]

# Few-shot Acc (3-way 5-shot): 0.9282±0.0828





In [97]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_21',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=5, # Few-shot parameter: number of support samples per class
        num_ways=4, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_21', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=5, num_ways=4, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 563, # Num Feat: 24, # Num Classes: 20 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   4%|▍         | 4/100 [00:01<00:23,  4.16it/s]

Skipping task 4: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:   8%|▊         | 8/100 [00:01<00:18,  5.06it/s]

Skipping task 8: Not enough samples (10) for class 19. Need 20.
Skipping task 9: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  13%|█▎        | 13/100 [00:02<00:14,  6.18it/s]

Skipping task 13: Not enough samples (10) for class 19. Need 20.
Skipping task 14: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  17%|█▋        | 17/100 [00:02<00:11,  7.26it/s]

Skipping task 17: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  22%|██▏       | 22/100 [00:03<00:13,  5.81it/s]

Skipping task 22: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  25%|██▌       | 25/100 [00:04<00:11,  6.26it/s]

Skipping task 25: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  33%|███▎      | 33/100 [00:05<00:13,  5.02it/s]

Skipping task 33: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  35%|███▌      | 35/100 [00:05<00:10,  6.08it/s]

Skipping task 35: Not enough samples (10) for class 19. Need 20.
Skipping task 36: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  38%|███▊      | 38/100 [00:06<00:07,  8.21it/s]

Skipping task 38: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  42%|████▏     | 42/100 [00:06<00:10,  5.50it/s]

Skipping task 42: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  46%|████▌     | 46/100 [00:07<00:14,  3.78it/s]

Skipping task 46: Not enough samples (10) for class 19. Need 20.
Skipping task 47: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  49%|████▉     | 49/100 [00:08<00:09,  5.30it/s]

Skipping task 49: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  66%|██████▌   | 66/100 [00:11<00:07,  4.67it/s]

Skipping task 66: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  68%|██████▊   | 68/100 [00:12<00:05,  5.89it/s]

Skipping task 68: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  73%|███████▎  | 73/100 [00:13<00:07,  3.58it/s]

Skipping task 73: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  90%|█████████ | 90/100 [00:16<00:01,  7.36it/s]

Skipping task 88: Not enough samples (10) for class 19. Need 20.


Few-shot tasks:  98%|█████████▊| 98/100 [00:17<00:00, 11.76it/s]

Skipping task 94: Not enough samples (10) for class 19. Need 20.
Skipping task 95: Not enough samples (10) for class 19. Need 20.
Skipping task 96: Not enough samples (10) for class 19. Need 20.


Few-shot tasks: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]

# Few-shot Acc (4-way 5-shot): 0.9175±0.0601





In [98]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_21',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=3, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_21', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=3, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 563, # Num Feat: 24, # Num Classes: 20 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   6%|▌         | 6/100 [00:02<00:29,  3.22it/s]

Skipping task 6: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:   9%|▉         | 9/100 [00:02<00:25,  3.57it/s]

Skipping task 9: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  14%|█▍        | 14/100 [00:04<00:28,  3.01it/s]

Skipping task 14: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  17%|█▋        | 17/100 [00:05<00:32,  2.57it/s]

Skipping task 17: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  19%|█▉        | 19/100 [00:06<00:32,  2.52it/s]

Skipping task 19: Not enough samples (24) for class 17. Need 25.
Skipping task 20: Not enough samples (24) for class 14. Need 25.
Skipping task 21: Not enough samples (24) for class 14. Need 25.
Skipping task 22: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  29%|██▉       | 29/100 [00:08<00:22,  3.13it/s]

Skipping task 29: Not enough samples (24) for class 14. Need 25.
Skipping task 30: Not enough samples (10) for class 19. Need 25.
Skipping task 31: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  34%|███▍      | 34/100 [00:09<00:16,  3.97it/s]

Skipping task 34: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  37%|███▋      | 37/100 [00:10<00:21,  2.88it/s]

Skipping task 37: Not enough samples (24) for class 17. Need 25.
Skipping task 38: Not enough samples (24) for class 14. Need 25.
Skipping task 39: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  42%|████▏     | 42/100 [00:11<00:11,  5.01it/s]

Skipping task 42: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  44%|████▍     | 44/100 [00:11<00:09,  5.88it/s]

Skipping task 44: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  52%|█████▏    | 52/100 [00:12<00:09,  5.06it/s]

Skipping task 52: Not enough samples (10) for class 19. Need 25.
Skipping task 53: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  56%|█████▌    | 56/100 [00:13<00:06,  6.75it/s]

Skipping task 56: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  58%|█████▊    | 58/100 [00:13<00:05,  7.43it/s]

Skipping task 58: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  61%|██████    | 61/100 [00:14<00:05,  6.85it/s]

Skipping task 61: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  63%|██████▎   | 63/100 [00:14<00:04,  7.50it/s]

Skipping task 63: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  66%|██████▌   | 66/100 [00:14<00:05,  6.59it/s]

Skipping task 66: Not enough samples (24) for class 14. Need 25.
Skipping task 67: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  70%|███████   | 70/100 [00:15<00:04,  7.49it/s]

Skipping task 70: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  75%|███████▌  | 75/100 [00:16<00:04,  5.81it/s]

Skipping task 75: Not enough samples (24) for class 17. Need 25.
Skipping task 76: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  82%|████████▏ | 82/100 [00:17<00:04,  3.66it/s]

Skipping task 82: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  84%|████████▍ | 84/100 [00:18<00:03,  4.05it/s]

Skipping task 84: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  88%|████████▊ | 88/100 [00:19<00:03,  3.94it/s]

Skipping task 88: Not enough samples (24) for class 14. Need 25.
Skipping task 89: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  96%|█████████▌| 96/100 [00:20<00:00,  4.78it/s]

Skipping task 96: Not enough samples (24) for class 17. Need 25.
Skipping task 97: Not enough samples (24) for class 14. Need 25.
Skipping task 98: Not enough samples (24) for class 14. Need 25.


Few-shot tasks: 100%|██████████| 100/100 [00:20<00:00,  4.85it/s]

# Few-shot Acc (3-way 10-shot): 0.9723±0.0326





In [99]:
if __name__ == "__main__":
    # Simulate command-line arguments using argparse.Namespace
    args = argparse.Namespace(
        dataset='MSRC_21',
        encoder='gin',
        decoder='gin',
        seeds=[0],
        device=0, # Use CPU if GPU is not available
        use_cfg=False,
        max_epoch=200, # Default value from build_args
        warmup_steps=-1, # Default value from build_args
        num_heads=4, # Default value from build_args
        num_out_heads=1, # Default value from build_args
        num_layers=2, # Default value from build_args
        num_hidden=512, # Default value from build_args
        residual=False, # Default value from build_args
        in_drop=0.2, # Default value from build_args
        attn_drop=0.1, # Default value from build_args
        norm=None, # Default value from build_args
        lr=0.0005, # Default value from build_args
        weight_decay=5e-4, # Default value from build_args
        negative_slope=0.2, # Default value from build_args
        activation="prelu", # Default value from build_args
        mask_rate=0.5, # Default value from build_args
        drop_edge_rate=0.0, # Default value from build_args
        replace_rate=0.0, # Default value from build_args
        loss_fn="sce", # Default value from build_fn_args
        alpha_l=2, # Default value from build_args
        optimizer="adam", # Default value from build_args
        max_epoch_f=30, # Default value from build_args
        lr_f=0.001, # Default value from build_args
        weight_decay_f=0.0, # Default value from build_args
        linear_prob=False, # Default value from build_args
        load_model=True, # Load the pre-trained model
        save_model=False, # Set to False to avoid saving during few-shot eval
        logging=False, # Default value from build_args
        scheduler=False, # Default value from build_args
        concat_hidden=False, # Default value from build_args
        pooling="mean", # Default value from build_args
        deg4feat=False, # Default value from build_args
        batch_size=32, # Default value from build_args
        num_shots=10, # Few-shot parameter: number of support samples per class
        num_ways=4, # Few-shot parameter: number of classes per task
        num_queries=15, # Few-shot parameter: number of query samples per class
        num_tasks=100 # Few-shot parameter: number of few-shot tasks to evaluate on
    )

    if args.use_cfg:
         # This part will still load configs.yml based on args.dataset
        args = load_best_configs(args, "configs.yml")
    print(args)

    # Load the dataset
    graphs, (num_features, num_classes) = load_graph_classification_dataset(args.dataset, deg4feat=args.deg4feat)
    args.num_features = num_features
    args.num_classes = num_classes # Add num_classes to args

    # Build the model
    model = build_model(args)
    model.to(args.device)

    # Load the pre-trained model weights if load_model is True
    if args.load_model:
        try:
            model.load_state_dict(torch.load("checkpoint.pt"))
            print("Loaded pre-trained model weights from checkpoint.pt")
        except FileNotFoundError:
            print("checkpoint.pt not found. Please make sure the pre-trained model is saved.")
            # Optionally, you can exit or train the model from scratch here
            # return

    # Initialize the pooler
    if args.pooling == "mean":
        pooler = AvgPooling()
    elif args.pooling == "max":
        pooler = MaxPooling()
    elif args.pooling == "sum":
        pooler = SumPooling()
    else:
        raise NotImplementedError

    # Run few-shot evaluation with fine-tuning
    acc_list = []
    for task in tqdm(range(args.num_tasks), desc="Few-shot tasks"):
        # Randomly sample classes
        selected_classes = np.random.choice(num_classes, args.num_ways, replace=False)
        task_graphs_with_labels = [(graph, label) for graph, label in graphs if label.item() in selected_classes]
        task_labels_tensor = torch.tensor([label.item() for _, label in task_graphs_with_labels])
        task_graphs = [graph for graph, _ in task_graphs_with_labels]

        # Create mapping for selected classes to 0-num_ways-1
        class_map = {cls.item(): i for i, cls in enumerate(selected_classes)}
        mapped_labels = torch.tensor([class_map[label.item()] for _, label in task_graphs_with_labels])

        # Split into support and query sets
        support_indices = []
        query_indices = []
        skip_task = False
        for cls in selected_classes:
            # Get indices for the current class within the task_graphs_with_labels list
            cls_task_indices = [i for i, (_, label) in enumerate(task_graphs_with_labels) if label.item() == cls.item()]

            if len(cls_task_indices) < args.num_shots + args.num_queries:
                 # Skip task if not enough samples for a class
                print(f"Skipping task {task}: Not enough samples ({len(cls_task_indices)}) for class {cls}. Need {args.num_shots + args.num_queries}.")
                skip_task = True
                break

            # Shuffle indices for the current class and split
            shuffled_cls_indices = np.random.permutation(cls_task_indices)
            support_indices.extend(shuffled_cls_indices[:args.num_shots])
            query_indices.extend(shuffled_cls_indices[args.num_shots:args.num_shots + args.num_queries])

        if skip_task:
            continue

        if not support_indices or not query_indices:
             # Skip task if no support or query samples could be gathered (should be caught by the previous check, but as a safeguard)
            print(f"Skipping task {task}: No support or query samples gathered after splitting")
            continue

        # Ensure indices are within the bounds of task_graphs
        support_graphs = [task_graphs[i] for i in support_indices]
        support_labels = mapped_labels[support_indices]
        query_graphs = [task_graphs[i] for i in query_indices]
        query_labels = mapped_labels[query_indices]

        # Perform fine-tuning and evaluation
        f1 = few_shot_finetune_and_evaluate(
            model,
            pooler,
            support_graphs,
            support_labels,
            query_graphs,
            query_labels,
            num_classes,
            args.lr_f,
            args.weight_decay_f,
            args.max_epoch_f,
            args.device,
            args.num_hidden,
        )
        acc_list.append(f1)

    final_acc = np.mean(acc_list)
    final_acc_std = np.std(acc_list)
    print(f"# Few-shot Acc ({args.num_ways}-way {args.num_shots}-shot): {final_acc:.4f}±{final_acc_std:.4f}")

Namespace(dataset='MSRC_21', encoder='gin', decoder='gin', seeds=[0], device=0, use_cfg=False, max_epoch=200, warmup_steps=-1, num_heads=4, num_out_heads=1, num_layers=2, num_hidden=512, residual=False, in_drop=0.2, attn_drop=0.1, norm=None, lr=0.0005, weight_decay=0.0005, negative_slope=0.2, activation='prelu', mask_rate=0.5, drop_edge_rate=0.0, replace_rate=0.0, loss_fn='sce', alpha_l=2, optimizer='adam', max_epoch_f=30, lr_f=0.001, weight_decay_f=0.0, linear_prob=False, load_model=True, save_model=False, logging=False, scheduler=False, concat_hidden=False, pooling='mean', deg4feat=False, batch_size=32, num_shots=10, num_ways=4, num_queries=15, num_tasks=100)
Use node label as node features
******** # Num Graphs: 563, # Num Feat: 24, # Num Classes: 20 ********
Loaded pre-trained model weights from checkpoint.pt


Few-shot tasks:   1%|          | 1/100 [00:00<00:30,  3.27it/s]

Skipping task 1: Not enough samples (10) for class 19. Need 25.
Skipping task 2: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:   6%|▌         | 6/100 [00:00<00:16,  5.77it/s]

Skipping task 6: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:   9%|▉         | 9/100 [00:01<00:18,  4.98it/s]

Skipping task 9: Not enough samples (24) for class 14. Need 25.
Skipping task 10: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  12%|█▏        | 12/100 [00:02<00:17,  5.02it/s]

Skipping task 12: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  15%|█▌        | 15/100 [00:02<00:15,  5.45it/s]

Skipping task 15: Not enough samples (10) for class 19. Need 25.
Skipping task 16: Not enough samples (24) for class 17. Need 25.
Skipping task 17: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  19%|█▉        | 19/100 [00:02<00:10,  8.07it/s]

Skipping task 19: Not enough samples (24) for class 17. Need 25.
Skipping task 20: Not enough samples (24) for class 14. Need 25.
Skipping task 21: Not enough samples (10) for class 19. Need 25.
Skipping task 22: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  24%|██▍       | 24/100 [00:03<00:06, 11.14it/s]

Skipping task 24: Not enough samples (10) for class 19. Need 25.
Skipping task 25: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  29%|██▉       | 29/100 [00:03<00:06, 10.39it/s]

Skipping task 28: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  31%|███       | 31/100 [00:03<00:06,  9.90it/s]

Skipping task 30: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  33%|███▎      | 33/100 [00:04<00:08,  7.60it/s]

Skipping task 33: Not enough samples (10) for class 19. Need 25.
Skipping task 34: Not enough samples (10) for class 19. Need 25.
Skipping task 35: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  39%|███▉      | 39/100 [00:05<00:07,  7.65it/s]

Skipping task 39: Not enough samples (10) for class 19. Need 25.
Skipping task 40: Not enough samples (10) for class 19. Need 25.
Skipping task 41: Not enough samples (10) for class 19. Need 25.
Skipping task 42: Not enough samples (24) for class 17. Need 25.
Skipping task 43: Not enough samples (24) for class 14. Need 25.
Skipping task 44: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  46%|████▌     | 46/100 [00:05<00:04, 12.62it/s]

Skipping task 46: Not enough samples (24) for class 14. Need 25.
Skipping task 47: Not enough samples (24) for class 17. Need 25.
Skipping task 48: Not enough samples (24) for class 14. Need 25.
Skipping task 49: Not enough samples (10) for class 19. Need 25.
Skipping task 50: Not enough samples (24) for class 17. Need 25.
Skipping task 51: Not enough samples (24) for class 17. Need 25.
Skipping task 52: Not enough samples (24) for class 17. Need 25.
Skipping task 53: Not enough samples (24) for class 14. Need 25.
Skipping task 54: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  56%|█████▌    | 56/100 [00:05<00:02, 19.44it/s]

Skipping task 56: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  61%|██████    | 61/100 [00:06<00:02, 13.09it/s]

Skipping task 60: Not enough samples (24) for class 14. Need 25.
Skipping task 61: Not enough samples (10) for class 19. Need 25.
Skipping task 62: Not enough samples (24) for class 17. Need 25.
Skipping task 63: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  67%|██████▋   | 67/100 [00:06<00:02, 12.60it/s]

Skipping task 66: Not enough samples (24) for class 17. Need 25.
Skipping task 67: Not enough samples (24) for class 14. Need 25.
Skipping task 68: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  70%|███████   | 70/100 [00:06<00:02, 12.64it/s]

Skipping task 70: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  74%|███████▍  | 74/100 [00:07<00:02, 10.67it/s]

Skipping task 73: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  76%|███████▌  | 76/100 [00:07<00:02, 10.12it/s]

Skipping task 75: Not enough samples (24) for class 14. Need 25.
Skipping task 76: Not enough samples (10) for class 19. Need 25.
Skipping task 77: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  81%|████████  | 81/100 [00:08<00:02,  8.16it/s]

Skipping task 81: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  83%|████████▎ | 83/100 [00:08<00:02,  8.00it/s]

Skipping task 83: Not enough samples (10) for class 19. Need 25.
Skipping task 84: Not enough samples (10) for class 19. Need 25.


Few-shot tasks:  88%|████████▊ | 88/100 [00:09<00:02,  4.75it/s]

Skipping task 88: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  91%|█████████ | 91/100 [00:10<00:01,  4.78it/s]

Skipping task 91: Not enough samples (10) for class 19. Need 25.
Skipping task 92: Not enough samples (24) for class 14. Need 25.


Few-shot tasks:  96%|█████████▌| 96/100 [00:12<00:01,  3.12it/s]

Skipping task 96: Not enough samples (24) for class 17. Need 25.


Few-shot tasks:  98%|█████████▊| 98/100 [00:12<00:00,  3.67it/s]

Skipping task 98: Not enough samples (24) for class 17. Need 25.


Few-shot tasks: 100%|██████████| 100/100 [00:12<00:00,  7.90it/s]

# Few-shot Acc (4-way 10-shot): 0.9614±0.0372



