In [1]:
"""
Our code is based on GraphENS:
https://github.com/JoonHyung-Park/GraphENS
"""

import os.path as osp
import random
import torch
import torch.nn.functional as F
from nets import *
from data_utils import *
from args import parse_args
from models import *
from losses import *
from sklearn.metrics import balanced_accuracy_score, f1_score
import statistics
import numpy as np

import warnings

warnings.filterwarnings("ignore")

from tqdm import tqdm
import dataloader

from utils import *

# Set GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

import sys

sys.argv = [""]
args = parse_args()

if device == "cuda":
    torch.cuda.set_device(2)
    device_id = torch.cuda.current_device()
    print(f"Now using GPU #{device_id}:\n{torch.cuda.get_device_name(device_id)}")

args

Now using GPU #2:
Tesla V100-SXM2-32GB


Namespace(dataset='cora', imb_ratio=10, net='GCN', n_layer=2, feat_dim=256, loss_type='bs', tam=False, reweight=False, pc_softmax=False, ens=False, renode=False, keep_prob=0.01, pred_temp=2, loss_name='re-weight', factor_focal=2.0, factor_cb=0.9999, rn_base=0.5, rn_max=1.5, tam_alpha=2.5, tam_beta=0.5, temp_phi=1.2, warmup=5)

In [2]:
%%time
import time
from torch_geometric.utils import to_undirected

class GraphBalanceAugmenter():

    MODE_SPACE = ['dummy', 'uncertainty', 'topology', 'discrepancy']
    
    def __init__(self, x, edge_index, y_train, train_mask, device, 
                 mode:str='uncertainty', risk_pow:int=None, risk_mul:float=None):
        
        # parameter check
        assert mode in self.MODE_SPACE
        assert risk_pow is None or risk_pow >=0
        assert risk_mul is None or (risk_mul <= 1 and risk_mul >=0)
        self.mode = mode
        self.risk_pow = risk_pow
        self.risk_mul = risk_mul

        # initialization
        x, edge_index, y_train, train_mask = \
            x.to(device), edge_index.to(device), y_train.to(device), train_mask.to(device)
        self.n_node = x.shape[0]
        self.n_edge = edge_index.shape[1]
        self.classes = y_train.unique()
        self.y_virtual = y_train.unique()
        self.n_class = y_train.unique().shape[0]
        self.y_train = y_train
        self.train_mask = train_mask
        self.adj = index_to_adj(x, edge_index)
        self.class_num_list = self.y_train.bincount()
        self.class_weights = self.class_num_list / self.class_num_list.max()
        self.empty_edge_index = torch.zeros(2, 0, dtype=torch.long, device=device)
        self.device = device

    @staticmethod
    def edge_sampling(edge_index, edge_sampling_proba):
        assert edge_sampling_proba.min() >=0 and edge_sampling_proba.max() <= 1
        edge_sample_mask = (torch.rand_like(edge_sampling_proba) < edge_sampling_proba)
        return edge_index[:, edge_sample_mask]
    
    @staticmethod
    def get_group_mean(values, labels, classes):
        new_values = torch.zeros_like(values)
        for i in classes:
            mask = labels == i
            new_values[mask] = values[mask].mean()
        return new_values

    @staticmethod
    def get_virtual_node_features(x, y_pred, classes):
        return torch.stack([x[y_pred == label].mean(axis=0) for label in classes])

    @staticmethod
    def get_connectivity_distribution(y_pred, adj, n_class, n_node):
        # get connectivity label distribution
        y_pred_mat = y_pred.mul(adj)
        y_pred_mat[~adj.bool()] = n_class
        y_neighbor_distr = torch.zeros(
            n_class+1, n_node, dtype=torch.int, device=y_pred_mat.device
        ).scatter_add_(0, y_pred_mat.T, torch.ones(
            n_node, n_node, dtype=torch.int, device=y_pred_mat.device
        ))[:n_cls].T.float()
        # row-wise normalization
        y_neighbor_distr /= y_neighbor_distr.sum(axis=1).reshape(-1, 1)
        y_neighbor_distr = y_neighbor_distr.nan_to_num(0)
        return y_neighbor_distr

    def adapt_labels_and_train_mask(self, y:torch.Tensor, train_mask:torch.Tensor):
        if self.mode == 'dummy':
            return y, train_mask
        new_y = torch.concat([y, self.y_virtual])
        new_train_mask = torch.concat([train_mask, torch.ones_like(self.y_virtual).bool()])
        return new_y, new_train_mask

    def augment(self, model, x, edge_index):

        train_mask = self.train_mask
        # do nothing if mode is 'dummy'
        if self.mode == 'dummy':
            return (x, edge_index, {'time': 0, 'node_ratio': 1, 'edge_ratio': 1})

        # initialization
        start_time = time.time()
        y_pred_proba = predict_proba_tensor(model, x, edge_index)
        y_pred = y_pred_proba.argmax(axis=1)
        y_pred[train_mask] = self.y_train
        y_neighbor_distr = self.get_connectivity_distribution(y_pred, self.adj, self.n_class, self.n_node)

        # compute node_risk and virtual link probability
        node_risk = self.get_node_risk(y_pred_proba, y_pred)
        node_similarities = self.get_node_similarity_to_candidate_classes(y_pred_proba, y_neighbor_distr)
        virtual_link_proba = self.get_virual_link_proba(node_similarities, y_pred)
        # assign link probability w.r.t node risk
        virtual_link_proba *= node_risk.reshape(-1, 1)

        # sample virtual edge_index w.r.t given probability
        virtual_adj = virtual_link_proba.T.to_sparse().coalesce()
        edge_index_candidates, edge_sampling_proba = virtual_adj.indices(), virtual_adj.values()
        virtual_edge_index = self.edge_sampling(edge_index_candidates, edge_sampling_proba)
        virtual_edge_index[0] += self.n_node    # adjust index to match original node index
        virtual_edge_index = to_undirected(virtual_edge_index)

        # compute virtual node features
        x_virtual = self.get_virtual_node_features(x, y_pred, self.classes)
        
        # concatenate results
        used_time = time.time() - start_time
        x_aug = torch.concat([x, x_virtual])
        edge_index_aug = torch.concat([edge_index, virtual_edge_index], axis=1)
        info = {
            'time': used_time,
            'node_ratio': x_aug.shape[0] / x.shape[0],
            'edge_ratio': edge_index_aug.shape[1] / edge_index.shape[1],
        }
        return x_aug, edge_index_aug, info
    
    def get_node_risk(self, y_pred_proba, y_pred):
        # compute node uncertainty
        node_uncertainty = 1 - y_pred_proba.max(axis=1).values
        # compute class-aware relative uncertainty
        node_unc_class_mean = self.get_group_mean(node_uncertainty, y_pred, self.classes)
        node_risk = (node_uncertainty - node_unc_class_mean).clip(min=0)
        # lower the risk of minority class nodes
        node_risk *= self.class_weights[y_pred]
        # rescale node risk by given hyper-parameter
        # node_risk /= node_risk.max()
        if self.risk_pow:
            node_risk = node_risk.pow(self.risk_pow)
        if self.risk_mul:
            node_risk *= self.risk_mul
        return node_risk

    def get_node_similarity_to_candidate_classes(self, y_pred_proba, y_neighbor_distr):
        mode = self.mode
        if mode == 'uncertainty':
            node_similarities = y_pred_proba
        elif mode == 'topology':
            node_similarities = y_neighbor_distr
        elif mode == 'discrepancy':
            node_similarities = y_neighbor_distr - y_pred_proba
        else: raise NotImplementedError
        return node_similarities
        
    def get_virual_link_proba(self, node_similarities, y_pred):
        # set similarity to current predicted class as 0
        node_similarities *= (1 - F.one_hot(y_pred, num_classes=self.n_class))
        node_similarities = node_similarities.clip(min=0)
        # row-wise normalize
        node_similarities /= node_similarities.sum(axis=1).reshape(-1, 1)
        virtual_link_proba = node_similarities.nan_to_num(0)
        return virtual_link_proba


# gba = GraphBalanceAugmenter(
#     x=data.x, edge_index=data.edge_index, y_train=data.y[data_train_mask], 
#     train_mask=data_train_mask, device=device, 
#     mode='uncertainty', risk_pow=None, risk_mul=None,
# )
# x_aug, edge_index_aug, info = gba.augment(model, data.x, data.edge_index)
# y_aug, train_mask_aug = gba.adapt_labels_and_train_mask(data.y, data_train_mask)
# info

CPU times: user 30 µs, sys: 0 ns, total: 30 µs
Wall time: 34.6 µs


In [3]:
import copy
from torch_geometric.utils import to_dense_adj, dense_to_sparse, mask_to_index


## For GraphENS ##
def backward_hook(module, grad_input, grad_output):
    global saliency
    saliency = grad_input[0].data


def tensor_hook(grad):
    global saliency
    saliency = grad.data


def train():
    global data, class_weight, gba, runtime_info
    global class_num_list, idx_info, prev_out, aggregator, renode_loss, tail_classes
    global data_train_mask, data_val_mask, data_test_mask
    global model, optimizer, criterion, scheduler, epoch, neighbor_dist_list

    if args.gba:
        x_aug, edge_index_aug, info = gba.augment(model, data.x, data.edge_index)
        y_aug, train_mask_aug = gba.adapt_labels_and_train_mask(data.y, data_train_mask)
        runtime_info.append([info["time"], info["node_ratio"], info["edge_ratio"]])
        if args.debug and epoch % args.num_epochs == 0:
            print(
                "Epoch: {:d} | aug_time {} ms | node_ratio {:.3%} | edge_ratio {:.3%}".format(
                    epoch, info["time"] * 1000, info["node_ratio"], info["edge_ratio"]
                )
            )
    else:
        x_aug, y_aug = data.x.clone(), data.y.clone()
        edge_index_aug = data.edge_index.clone()
        train_mask_aug = data_train_mask.clone()

    model.train()
    optimizer.zero_grad()

    if args.ens:
        # Hook saliency map of input features
        model.conv1.temp_weight.register_backward_hook(backward_hook)
        # Sampling source and destination nodes
        sampling_src_idx, sampling_dst_idx = sampling_idx_individual_dst(
            class_num_list, idx_info, device
        )
        beta = torch.distributions.beta.Beta(2, 2)
        lam = beta.sample((len(sampling_src_idx),)).unsqueeze(1)
        ori_saliency = saliency[: x_aug.shape[0]] if (saliency != None) else None
        # Augment nodes
        if epoch > args.warmup:
            with torch.no_grad():
                prev_out = aggregator(prev_out, edge_index_aug)
                prev_out = F.softmax(prev_out / args.pred_temp, dim=1).detach().clone()
            new_edge_index, dist_kl = neighbor_sampling(
                x_aug.size(0),
                edge_index_aug,
                sampling_src_idx,
                sampling_dst_idx,
                neighbor_dist_list,
                prev_out,
            )
            new_x = saliency_mixup(
                x_aug,
                sampling_src_idx,
                sampling_dst_idx,
                lam,
                ori_saliency,
                dist_kl=dist_kl,
                keep_prob=args.keep_prob,
            )
        else:
            new_edge_index = duplicate_neighbor(
                x_aug.size(0), data.edge_index, sampling_src_idx
            )
            dist_kl, ori_saliency = None, None
            new_x = saliency_mixup(
                x_aug,
                sampling_src_idx,
                sampling_dst_idx,
                lam,
                ori_saliency,
                dist_kl=dist_kl,
            )
        new_x.requires_grad = True
        # Get predictions
        output = model(new_x, new_edge_index)
        prev_out = (output[: x_aug.size(0)]).detach().clone()  # logit propagation
        ## Train_mask modification ##
        add_num = output.shape[0] - train_mask_aug.shape[0]
        new_train_mask = torch.ones(add_num, dtype=torch.bool, device=x_aug.device)
        new_train_mask = torch.cat((train_mask_aug, new_train_mask), dim=0)
        ## Label modification ##
        _new_y = y_aug[sampling_src_idx].clone()
        new_y = torch.cat((y_aug[train_mask_aug], _new_y), dim=0)
        ## Compute Loss ##
        loss = criterion(output[new_train_mask], new_y)
        if args.debug and epoch % args.num_epochs == 0:
            print(
                f"GraphENS node_ratio {new_x.shape[0]/x_aug.shape[0]:.2%} edge_ratio {new_edge_index.shape[1]/edge_index_aug.shape[1]:.2%}"
            )

    elif args.renode:
        ## ReNode ##
        if args.gba and data.rn_weight.shape[0] < train_mask_aug.shape[0]:
            data.rn_weight = torch.cat(
                [data.rn_weight, torch.ones(n_cls, device=device)]
            )
        output = model(x_aug, edge_index_aug)
        sup_logits = output[train_mask_aug]
        cls_loss = renode_loss.compute(sup_logits, y_aug[train_mask_aug].to(device))
        loss = torch.sum(
            cls_loss * data.rn_weight[train_mask_aug].to(device)
        ) / cls_loss.size(0)

    elif args.graphsmote:
        ## GraphSMOTE ##
        num_nodes_aug = x_aug.shape[0]
        embed = model.get_embed(x_aug, edge_index_aug)
        idx_train = mask_to_index(train_mask_aug)
        adj = to_dense_adj(edge_index_aug, max_num_nodes=num_nodes_aug)[0]
        embed_new, labels_new, idx_train_new, adj_up = recon_upsample(
            embed,
            y_aug,
            idx_train,
            adj=adj.detach(),
            portion=0,
            tail_classes=tail_classes,
        )
        generated_G = gs_decoder(embed_new)
        loss_rec = graphsmote.adj_mse_loss(
            generated_G[:num_nodes_aug, :][:, :num_nodes_aug], adj.detach()
        )
        adj_new = copy.deepcopy(generated_G.detach())
        threshold = 0.5
        adj_new[adj_new < threshold] = 0.0
        adj_new[adj_new >= threshold] = 1.0
        adj_new = torch.mul(adj_up.to(adj_new.device), adj_new)
        adj_new[:num_nodes_aug, :][:, :num_nodes_aug] = adj.detach()
        adj_new, _ = dense_to_sparse(adj_new)
        output = model.embed_to_pred(embed_new, adj_new)
        loss = (
            criterion(output[idx_train_new], labels_new[idx_train_new])
            + loss_rec * 1e-6
        )
        if args.debug and epoch % args.num_epochs == 0:
            print(
                f"GraphSMOTE node_ratio {embed_new.shape[0]/x_aug.shape[0]:.2%} edge_ratio {adj_new.shape[1]/edge_index_aug.shape[1]:.2%}"
            )

    elif args.smote:
        ## SMOTE ##
        adj, features, labels = edge_index_aug, x_aug, y_aug
        idx_train = mask_to_index(train_mask_aug)
        adj_new, features_new, labels_new, idx_train_new = src_smote(
            adj, features, labels, idx_train, portion=0, tail_classes=tail_classes
        )
        output = model(features_new, adj_new.indices())
        loss = criterion(output[idx_train_new], labels_new[idx_train_new])
        if args.debug and epoch % args.num_epochs == 0:
            print(
                f"SMOTE node_ratio {features_new.shape[0]/x_aug.shape[0]:.2%} edge_ratio {adj_new.indices().shape[1]/edge_index_aug.shape[1]:.2%}"
            )

    elif args.resample:
        ## Resampling ##
        adj, features, labels = edge_index_aug, x_aug, y_aug
        idx_train = mask_to_index(train_mask_aug)
        adj_new, features_new, labels_new, idx_train_new = src_upsample(
            adj, features, labels, idx_train, portion=0, tail_classes=tail_classes
        )
        output = model(features_new, adj_new.indices())
        loss = criterion(output[idx_train_new], labels_new[idx_train_new])
        if args.debug and epoch % args.num_epochs == 0:
            print(
                f"Oversample node_ratio {features_new.shape[0]/x_aug.shape[0]:.2%} edge_ratio {adj_new.indices().shape[1]/edge_index_aug.shape[1]:.2%}"
            )

    elif args.reweight:
        ## Reweight ##
        output = model(x_aug, edge_index_aug)
        loss = criterion(
            output[train_mask_aug], y_aug[train_mask_aug], weight=class_weight
        )

    else:
        ## Vanilla ##
        output = model(x_aug, edge_index_aug)
        loss = criterion(output[train_mask_aug], y_aug[train_mask_aug])

    loss.backward()

    if args.graphsmote:
        gs_decoder_optimizer.zero_grad()
        gs_decoder_optimizer.step()

    with torch.no_grad():
        model.eval()
        output = model(data.x, data.edge_index)
        val_loss = F.cross_entropy(output[data_val_mask], data.y[data_val_mask])

    optimizer.step()
    scheduler.step(val_loss)

    torch.cuda.empty_cache()


@torch.no_grad()
def test():
    model.eval()
    logits = model(data.x, data.edge_index)
    accs, baccs, f1s = [], [], []

    if args.pc_softmax:
        logits = pc_softmax(logits, class_num_list)

    for i, mask in enumerate([data_train_mask, data_val_mask, data_test_mask]):
        pred = logits[mask].max(1)[1]
        y_pred = pred.cpu().numpy()
        y_true = data.y[mask].cpu().numpy()
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        bacc = balanced_accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="macro")

        accs.append(acc)
        baccs.append(bacc)
        f1s.append(f1)

    return accs, baccs, f1s

In [4]:
def get_semi_train_val_test_split(data, n_train_num_per_class=20, test_ratio=0.5):
    from torch_geometric.utils import mask_to_index, index_to_mask

    n_nodes = data.y.shape[0]
    train_index, val_index, test_index = [], [], []
    for label in data.y.unique():
        cls_index = mask_to_index(data.y == label)
        n_cls_nodes = len(cls_index)
        n_train_nodes = n_train_num_per_class
        n_val_nodes = int((n_cls_nodes - n_train_nodes) * (1 - test_ratio))
        n_test_nodes = int((n_cls_nodes - n_train_nodes) * test_ratio)
        perm_index = torch.randperm(len(cls_index))
        cls_train_index = cls_index[perm_index[:n_train_nodes]]
        cls_val_index = cls_index[
            perm_index[n_train_nodes : n_train_nodes + n_val_nodes]
        ]
        cls_test_index = cls_index[perm_index[n_train_nodes + n_val_nodes :]]
        train_index.append(cls_train_index)
        val_index.append(cls_val_index)
        test_index.append(cls_test_index)

    device = data.y.device
    train_mask = index_to_mask(torch.concat(train_index), size=n_nodes).to(device)
    val_mask = index_to_mask(torch.concat(val_index), size=n_nodes).to(device)
    test_mask = index_to_mask(torch.concat(test_index), size=n_nodes).to(device)
    return train_mask, val_mask, test_mask

In [5]:
dataset_space = ["Cora"]
# dataset_space = ['Cora', 'CiteSeer', 'PubMed', 'CS', 'Physics']

# model_space = ['GCN', 'GAT', 'SAGE']
model_space = ["GCN"]
baseline_space = [
    "vanilla",
    "reweight",
    "resample",
    "renode",
    "smote",
    "graphsmote",
    "graphens",
]
# baseline_space = ['graphens', 'graphsmote', 'renode', 'smote', 'resample', 'reweight', 'vanilla']
# baseline_space = ['vanilla']

setting_space = ["dummy", "uncertainty", "topology"]
# setting_space = ["uncertainty", "topology"]
# setting_space = ['uncertainty']
save_path = "./results"

sys.argv = [""]
args = parse_args()

args.disable_tqdm = False
args.debug = False
# args.debug = True

args.imb_ratio = 10
# args.imb_ratio = 10
args.loss_type = "ce"
args.tam = False
args.pc_softmax = False
args.net = "GCN"
args.feat_dim = 256
args.n_layer = 3
args.num_epochs = 100
args.repeatition = 2
args.save_results = True
# args.save_results = False

if args.save_results:
    file_name = f"IR({args.imb_ratio})-rep({args.repeatition})-epoch({args.num_epochs})-gnn{str(model_space)}-data{str(dataset_space)}-bsl{str(baseline_space)}.csv"
    args.res_path = f"{save_path}/{file_name}"
    print(f"Saving to {args.res_path} ...")

args.gba = False
args.ens = False
args.reweight = False
args.renode = False
args.graphsmote = False
args.smote = False
args.resample = False

# args.ens = True
# args.reweight = True
args.gba = {
    "risk_pow": None,
    # 'risk_pow': 2,
    "risk_mul": None,
    # 'risk_mul': .5,
    # 'mode': 'dummy',
    # 'mode': 'uncertainty',
    # 'mode': 'topology',
    # 'mode': 'discrepancy',
}

all_results = []


for dataset_name in dataset_space:
    args.dataset = dataset_name

    torch.cuda.empty_cache()

    for baseline_name in baseline_space:
        if baseline_name == "vanilla":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, False, False, False, False, False)
        elif baseline_name == "reweight":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (True, False, False, False, False, False)
        elif baseline_name == "renode":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, True, False, False, False, False)
        elif baseline_name == "resample":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, False, True, False, False, False)
        elif baseline_name == "smote":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, False, False, True, False, False)
        elif baseline_name == "graphsmote":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, False, False, False, True, False)
        elif baseline_name == "graphens":
            (
                args.reweight,
                args.renode,
                args.resample,
                args.smote,
                args.graphsmote,
                args.ens,
            ) = (False, False, False, False, False, True)
        else:
            raise RuntimeError

        for model_name in model_space:
            args.net = model_name

            torch.cuda.empty_cache()

            for gba_setting in setting_space:
                args.gba["mode"] = gba_setting

                runtime_info = []

                ## Log for Experiment Setting ##
                setting_log = "Runs: {}, Dataset: {}, ratio: {}, net: {}, n_layer: {}, feat_dim: {}, baseline: {}, gba: {}".format(
                    args.repeatition,
                    args.dataset,
                    str(args.imb_ratio),
                    args.net,
                    str(args.n_layer),
                    str(args.feat_dim),
                    baseline_name,
                    gba_setting,
                )

                dataset = args.dataset
                path = osp.join("data", dataset)
                dataset = get_dataset(dataset, path, split_type="public")
                data = dataset[0]
                n_cls = data.y.max().item() + 1
                data = data.to(device)
                # print (f'Loaded dataset: {args.dataset} - ImbRatio {args.imb_ratio}')
                # print (args)

                torch.cuda.empty_cache()

                repeatition = args.repeatition
                seed = 100
                (
                    avg_val_acc_f1,
                    avg_test_acc,
                    avg_test_bacc,
                    avg_test_f1,
                    avg_test_disparity,
                ) = ([], [], [], [], [])
                avg_run_time, avg_node_ratio, avg_edge_ratio = [], [], []
                for r in range(repeatition):

                    runtime_info = []

                    ## Fix seed ##
                    torch.cuda.empty_cache()
                    seed += 1
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)
                    torch.backends.cudnn.deterministic = True
                    torch.backends.cudnn.benchmark = False
                    random.seed(seed)
                    np.random.seed(seed)

                    if args.dataset in ["Cora", "CiteSeer", "PubMed"]:
                        data_train_mask, data_val_mask, data_test_mask = (
                            data.train_mask.clone(),
                            data.val_mask.clone(),
                            data.test_mask.clone(),
                        )
                    elif args.dataset in ["CS", "Physics"]:
                        data_train_mask, data_val_mask, data_test_mask = (
                            get_semi_train_val_test_split(data, 20)
                        )
                    else:
                        data_train_mask, data_val_mask, data_test_mask = (
                            get_graphens_train_val_test_split(data, 20)
                        )
                        # data_train_mask, data_val_mask, data_test_mask = get_train_val_test_split(data, 20)
                        # data_train_mask, data_val_mask, data_test_mask = get_iid_train_val_test_split(data, train_ratio=0.6, test_ratio=0.2)

                    ## Data statistic ##
                    stats = data.y[data_train_mask]
                    n_data = []
                    for i in range(n_cls):
                        data_num = (stats == i).sum()
                        n_data.append(int(data_num.item()))
                    idx_info = get_idx_info(data.y, n_cls, data_train_mask)
                    class_num_list = n_data

                    # for artificial imbalanced setting: only the last imb_class_num classes are imbalanced
                    imb_class_num = n_cls // 2
                    new_class_num_list = []
                    tail_classes = []
                    max_num = np.max(class_num_list[: n_cls - imb_class_num])
                    for i in range(n_cls):
                        if (
                            args.imb_ratio > 1 and i > n_cls - 1 - imb_class_num
                        ):  # only imbalance the last classes
                            tail_classes.append(i)
                            new_class_num_list.append(
                                min(
                                    int(max_num * (1.0 / args.imb_ratio)),
                                    class_num_list[i],
                                )
                            )
                        else:
                            new_class_num_list.append(class_num_list[i])
                    class_num_list = new_class_num_list

                    if args.imb_ratio > 1:
                        data_train_mask, idx_info = split_semi_dataset(
                            len(data.x),
                            n_data,
                            n_cls,
                            class_num_list,
                            idx_info,
                            data.x.device,
                        )

                    if args.ens:
                        neighbor_dist_list = get_ins_neighbor_dist(
                            data.y.size(0), data.edge_index, data_train_mask, device
                        )
                    else:
                        neighbor_dist_list = None

                    if args.ens:  # for getting saliency
                        from ens_nets import *
                    else:
                        from nets import *

                    if args.renode:
                        ## ReNode method ##
                        ## hyperparam ##
                        pagerank_prob = 0.85

                        # calculating the Personalized PageRank Matrix
                        pr_prob = 1 - pagerank_prob
                        A = index2dense(data.edge_index, data.num_nodes)
                        A_hat = A.to(device) + torch.eye(A.size(0)).to(
                            device
                        )  # add self-loop
                        D = torch.diag(torch.sum(A_hat, 1))
                        D = D.inverse().sqrt()
                        A_hat = torch.mm(torch.mm(D, A_hat), D)
                        data.Pi = pr_prob * (
                            (
                                torch.eye(A.size(0)).to(device) - (1 - pr_prob) * A_hat
                            ).inverse()
                        )
                        data.Pi = data.Pi.cpu()

                        # calculating the ReNode Weight
                        gpr_matrix = []  # the class-level influence distribution
                        data.num_classes = n_cls
                        for iter_c in range(data.num_classes):
                            # iter_Pi = data.Pi[torch.tensor(target_data.train_node[iter_c]).long()]
                            iter_Pi = data.Pi[
                                idx_info[iter_c].long()
                            ]  # check! is it same with above line?
                            iter_gpr = torch.mean(iter_Pi, dim=0).squeeze()
                            gpr_matrix.append(iter_gpr)

                        temp_gpr = torch.stack(gpr_matrix, dim=0)
                        temp_gpr = temp_gpr.transpose(0, 1)
                        data.gpr = temp_gpr
                        data.rn_weight = get_renode_weight(
                            data, data_train_mask, args.rn_base, args.rn_max
                        )  # ReNode Weight
                        renode_loss = IMB_LOSS(
                            args.loss_name,
                            data,
                            idx_info,
                            args.factor_focal,
                            args.factor_cb,
                        )

                    if args.graphsmote:
                        gs_decoder = GSMOTEDecoder(
                            nembed=args.feat_dim, dropout=0.1
                        ).to(device)
                        gs_decoder_optimizer = torch.optim.Adam(
                            gs_decoder.parameters(), lr=0.001, weight_decay=5e-4
                        )

                    if args.gba:
                        gba = GraphBalanceAugmenter(
                            x=data.x,
                            edge_index=data.edge_index,
                            y_train=data.y[data_train_mask],
                            train_mask=data_train_mask,
                            device=device,
                            mode=args.gba["mode"],
                            risk_pow=args.gba["risk_pow"],
                            risk_mul=args.gba["risk_mul"],
                        )

                    ## Re-weight method ##
                    class_weight = get_weight(args.reweight, class_num_list).to(device)

                    ## Model Selection ##
                    if args.net == "GCN":
                        model = create_gcn(
                            nfeat=dataset.num_features,
                            nhid=args.feat_dim,
                            nclass=n_cls,
                            dropout=0.5,
                            nlayer=args.n_layer,
                        )
                    elif args.net == "GAT":
                        model = create_gat(
                            nfeat=dataset.num_features,
                            nhid=args.feat_dim,
                            nclass=n_cls,
                            dropout=0.5,
                            nlayer=args.n_layer,
                        )
                    elif args.net == "SAGE":
                        model = create_sage(
                            nfeat=dataset.num_features,
                            nhid=args.feat_dim,
                            nclass=n_cls,
                            dropout=0.5,
                            nlayer=args.n_layer,
                        )
                    else:
                        raise NotImplementedError("Not Implemented Architecture!")

                    ## Criterion Selection ##
                    if args.loss_type == "ce":  # CE
                        criterion = CrossEntropy()
                    elif args.loss_type == "bs":
                        criterion = BalancedSoftmax(class_num_list)
                    else:
                        raise NotImplementedError("Not Implemented Loss!")

                    model = model.to(device)
                    criterion = criterion.to(device)

                    # Set optimizer
                    optimizer = torch.optim.Adam(
                        [
                            dict(params=model.reg_params, weight_decay=5e-4),
                            dict(params=model.non_reg_params, weight_decay=0),
                        ],
                        lr=0.01,
                    )
                    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                        optimizer, mode="min", factor=0.5, patience=100, verbose=False
                    )

                    # Train models
                    best_val_acc_f1 = 0
                    saliency, prev_out = None, None
                    aggregator = MeanAggregation()

                    all_stats_list, all_maskes_list = [], []
                    for epoch in tqdm(
                        range(1, args.num_epochs + 1), disable=args.disable_tqdm
                    ):

                        train()
                        accs, baccs, f1s = test()
                        train_acc, val_acc, tmp_test_acc = accs
                        train_f1, val_f1, tmp_test_f1 = f1s
                        val_acc_f1 = (val_acc + val_f1) / 2.0
                        if val_acc_f1 > best_val_acc_f1:
                            best_val_acc_f1 = val_acc_f1
                            test_acc = accs[2]
                            test_bacc = baccs[2]
                            test_f1 = f1s[2]
                            best_epoch = epoch
                            y_pred = predict(model, data.x, data.edge_index)
                            test_disparity = get_class_wise_accuracy(
                                data.y, y_pred, data_test_mask
                            ).std()

                    avg_val_acc_f1.append(best_val_acc_f1)
                    avg_test_acc.append(test_acc)
                    avg_test_bacc.append(test_bacc)
                    avg_test_f1.append(test_f1)
                    avg_test_disparity.append(test_disparity)
                    if args.debug:
                        print(
                            f"Best epoch {best_epoch} Val Acc F1: {best_val_acc_f1:.4f}, "
                            f"Test Disp {test_disparity:.4f}, Acc: {test_acc:.4f}, "
                            f"BAcc: {test_bacc:.4f}, F1: {test_f1:.4f}"
                        )

                    avg_runtime_stats = pd.DataFrame(
                        runtime_info, columns=["runtime", "node_ratio", "edge_ratio"]
                    ).mean()
                    avg_run_time.append(avg_runtime_stats["runtime"])
                    avg_node_ratio.append(avg_runtime_stats["node_ratio"])
                    avg_edge_ratio.append(avg_runtime_stats["edge_ratio"])

                    torch.cuda.empty_cache()

                if args.repeatition > 1:
                    ## Calculate statistics ##
                    acc_CI = statistics.stdev(avg_test_acc) / (repeatition ** (1 / 2))
                    bacc_CI = statistics.stdev(avg_test_bacc) / (repeatition ** (1 / 2))
                    f1_CI = statistics.stdev(avg_test_f1) / (repeatition ** (1 / 2))
                    disp_CI = statistics.stdev(avg_test_disparity) / (
                        repeatition ** (1 / 2)
                    )
                    avg_acc = statistics.mean(avg_test_acc)
                    avg_bacc = statistics.mean(avg_test_bacc)
                    avg_f1 = statistics.mean(avg_test_f1)
                    avg_disp = statistics.mean(avg_test_disparity)
                    avg_val_acc_f1 = statistics.mean(avg_val_acc_f1)
                    avg_run_time = statistics.mean(avg_run_time)
                    avg_node_ratio = statistics.mean(avg_node_ratio)
                    avg_edge_ratio = statistics.mean(avg_edge_ratio)

                    avg_log = "Val Acc F1: {:.4f}, Test Disp {:.4f} +- {:.4f} Acc: {:.4f} +- {:.4f}, BAcc: {:.4f} +- {:.4f}, F1: {:.4f} +- {:.4f}"
                    avg_log = avg_log.format(
                        avg_val_acc_f1,
                        avg_disp,
                        disp_CI,
                        avg_acc,
                        acc_CI,
                        avg_bacc,
                        bacc_CI,
                        avg_f1,
                        f1_CI,
                    )
                    avg_runtim_log = f"node_ratio: {avg_node_ratio-1:.3%} | edge_ratio: {avg_edge_ratio-1:.3%} | Runtime: {avg_run_time:.4f} s"
                    log = "{}\n{}\n{}".format(setting_log, avg_runtim_log, avg_log)
                    print(log + "\n")

                    run_results = [
                        avg_val_acc_f1,
                        avg_acc,
                        acc_CI,
                        avg_bacc,
                        bacc_CI,
                        avg_f1,
                        f1_CI,
                        avg_disp,
                        disp_CI,
                        model_name,
                        dataset_name,
                        baseline_name,
                        gba_setting,
                        avg_run_time,
                        avg_node_ratio,
                        avg_edge_ratio,
                    ]
                    all_results.append(run_results)

                    if args.save_results:
                        columns = [
                            "avg_val_acc_f1",
                            "acc",
                            "acc_std",
                            "bacc",
                            "bacc_std",
                            "f1",
                            "f1_std",
                            "disparity",
                            "disparity_std",
                            "model",
                            "dataset",
                            "baseline",
                            "setting",
                            "runtime",
                            "node_ratio",
                            "edge_ratio",
                        ]
                        df_all_results = pd.DataFrame(all_results, columns=columns)
                        df_all_results.to_csv(args.res_path, index=None)

Saving to ./results/IR(10)-rep(2)-epoch(100)-gnn['GCN']-data['Cora']-bsl['vanilla', 'reweight', 'resample', 'renode', 'smote', 'graphsmote', 'graphens'].csv ...


100%|██████████| 100/100 [00:02<00:00, 48.57it/s]
100%|██████████| 100/100 [00:01<00:00, 67.20it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: vanilla, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6380, Test Disp 0.2797 +- 0.0139 Acc: 0.6830 +- 0.0060, BAcc: 0.6228 +- 0.0113, F1: 0.6160 +- 0.0079



100%|██████████| 100/100 [00:02<00:00, 48.31it/s]
100%|██████████| 100/100 [00:01<00:00, 51.04it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: vanilla, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 3.539% | Runtime: 0.0050 s
Val Acc F1: 0.6288, Test Disp 0.2421 +- 0.0309 Acc: 0.6755 +- 0.0135, BAcc: 0.6282 +- 0.0073, F1: 0.6103 +- 0.0170



100%|██████████| 100/100 [00:02<00:00, 48.98it/s]
100%|██████████| 100/100 [00:02<00:00, 49.36it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: vanilla, gba: topology
node_ratio: 0.258% | edge_ratio: 1.534% | Runtime: 0.0051 s
Val Acc F1: 0.6725, Test Disp 0.2365 +- 0.0119 Acc: 0.7180 +- 0.0190, BAcc: 0.6708 +- 0.0214, F1: 0.6584 +- 0.0259



100%|██████████| 100/100 [00:01<00:00, 65.51it/s]
100%|██████████| 100/100 [00:01<00:00, 67.53it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: reweight, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6799, Test Disp 0.2418 +- 0.0002 Acc: 0.7270 +- 0.0140, BAcc: 0.6720 +- 0.0054, F1: 0.6790 +- 0.0052



100%|██████████| 100/100 [00:01<00:00, 50.57it/s]
100%|██████████| 100/100 [00:01<00:00, 50.56it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: reweight, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 1.900% | Runtime: 0.0049 s
Val Acc F1: 0.6934, Test Disp 0.2046 +- 0.0024 Acc: 0.7200 +- 0.0190, BAcc: 0.6873 +- 0.0001, F1: 0.6727 +- 0.0111



100%|██████████| 100/100 [00:01<00:00, 50.70it/s]
100%|██████████| 100/100 [00:02<00:00, 49.72it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: reweight, gba: topology
node_ratio: 0.258% | edge_ratio: 1.115% | Runtime: 0.0049 s
Val Acc F1: 0.7162, Test Disp 0.2189 +- 0.0021 Acc: 0.7205 +- 0.0015, BAcc: 0.7112 +- 0.0017, F1: 0.6842 +- 0.0013



100%|██████████| 100/100 [00:01<00:00, 55.34it/s]
100%|██████████| 100/100 [00:01<00:00, 54.97it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: resample, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6156, Test Disp 0.2944 +- 0.0276 Acc: 0.6700 +- 0.0160, BAcc: 0.6077 +- 0.0210, F1: 0.5985 +- 0.0352



100%|██████████| 100/100 [00:02<00:00, 45.21it/s]
100%|██████████| 100/100 [00:02<00:00, 44.70it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: resample, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 2.383% | Runtime: 0.0047 s
Val Acc F1: 0.7327, Test Disp 0.1590 +- 0.0490 Acc: 0.7580 +- 0.0100, BAcc: 0.7360 +- 0.0124, F1: 0.7208 +- 0.0071



100%|██████████| 100/100 [00:02<00:00, 44.98it/s]
100%|██████████| 100/100 [00:02<00:00, 44.60it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: resample, gba: topology
node_ratio: 0.258% | edge_ratio: 1.190% | Runtime: 0.0048 s
Val Acc F1: 0.7247, Test Disp 0.2174 +- 0.0018 Acc: 0.7755 +- 0.0025, BAcc: 0.7262 +- 0.0108, F1: 0.7250 +- 0.0038



100%|██████████| 100/100 [00:01<00:00, 65.71it/s]
100%|██████████| 100/100 [00:01<00:00, 64.93it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: renode, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6474, Test Disp 0.2633 +- 0.0087 Acc: 0.7020 +- 0.0160, BAcc: 0.6421 +- 0.0291, F1: 0.6435 +- 0.0254



100%|██████████| 100/100 [00:02<00:00, 48.82it/s]
100%|██████████| 100/100 [00:01<00:00, 50.09it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: renode, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 1.942% | Runtime: 0.0049 s
Val Acc F1: 0.7043, Test Disp 0.1832 +- 0.0316 Acc: 0.7085 +- 0.0155, BAcc: 0.7023 +- 0.0125, F1: 0.6827 +- 0.0151



100%|██████████| 100/100 [00:01<00:00, 50.44it/s]
100%|██████████| 100/100 [00:01<00:00, 50.23it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: renode, gba: topology
node_ratio: 0.258% | edge_ratio: 1.090% | Runtime: 0.0049 s
Val Acc F1: 0.7225, Test Disp 0.1886 +- 0.0310 Acc: 0.7350 +- 0.0160, BAcc: 0.7065 +- 0.0183, F1: 0.6997 +- 0.0226



100%|██████████| 100/100 [00:02<00:00, 46.42it/s]
100%|██████████| 100/100 [00:02<00:00, 45.95it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: smote, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6106, Test Disp 0.3037 +- 0.0025 Acc: 0.6720 +- 0.0030, BAcc: 0.6024 +- 0.0030, F1: 0.5925 +- 0.0032



100%|██████████| 100/100 [00:02<00:00, 40.20it/s]
100%|██████████| 100/100 [00:02<00:00, 40.45it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: smote, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 2.425% | Runtime: 0.0047 s
Val Acc F1: 0.7325, Test Disp 0.1882 +- 0.0041 Acc: 0.7545 +- 0.0055, BAcc: 0.7261 +- 0.0047, F1: 0.7130 +- 0.0021



100%|██████████| 100/100 [00:02<00:00, 40.16it/s]
100%|██████████| 100/100 [00:02<00:00, 40.26it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: smote, gba: topology
node_ratio: 0.258% | edge_ratio: 1.228% | Runtime: 0.0046 s
Val Acc F1: 0.7356, Test Disp 0.2175 +- 0.0022 Acc: 0.7670 +- 0.0030, BAcc: 0.7282 +- 0.0003, F1: 0.7140 +- 0.0042



100%|██████████| 100/100 [00:05<00:00, 19.34it/s]
100%|██████████| 100/100 [00:05<00:00, 18.99it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphsmote, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.6868, Test Disp 0.2151 +- 0.0024 Acc: 0.7450 +- 0.0160, BAcc: 0.6908 +- 0.0050, F1: 0.6942 +- 0.0127



100%|██████████| 100/100 [00:04<00:00, 22.36it/s]
100%|██████████| 100/100 [00:04<00:00, 22.52it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphsmote, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 2.287% | Runtime: 0.0050 s
Val Acc F1: 0.6958, Test Disp 0.2167 +- 0.0058 Acc: 0.7290 +- 0.0160, BAcc: 0.6916 +- 0.0065, F1: 0.6820 +- 0.0117



100%|██████████| 100/100 [00:04<00:00, 22.50it/s]
100%|██████████| 100/100 [00:04<00:00, 22.70it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphsmote, gba: topology
node_ratio: 0.258% | edge_ratio: 1.251% | Runtime: 0.0049 s
Val Acc F1: 0.7250, Test Disp 0.2188 +- 0.0045 Acc: 0.7470 +- 0.0100, BAcc: 0.7069 +- 0.0045, F1: 0.6966 +- 0.0063



100%|██████████| 100/100 [00:01<00:00, 51.54it/s]
100%|██████████| 100/100 [00:01<00:00, 52.28it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphens, gba: dummy
node_ratio: 0.000% | edge_ratio: 0.000% | Runtime: 0.0000 s
Val Acc F1: 0.7257, Test Disp 0.2170 +- 0.0046 Acc: 0.7410 +- 0.0030, BAcc: 0.7017 +- 0.0102, F1: 0.6936 +- 0.0066



100%|██████████| 100/100 [00:02<00:00, 41.22it/s]
100%|██████████| 100/100 [00:02<00:00, 40.65it/s]


Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphens, gba: uncertainty
node_ratio: 0.258% | edge_ratio: 2.422% | Runtime: 0.0055 s
Val Acc F1: 0.7290, Test Disp 0.2209 +- 0.0029 Acc: 0.7730 +- 0.0100, BAcc: 0.7199 +- 0.0080, F1: 0.7164 +- 0.0121



100%|██████████| 100/100 [00:02<00:00, 41.04it/s]
100%|██████████| 100/100 [00:02<00:00, 41.10it/s]

Runs: 2, Dataset: Cora, ratio: 10, net: GCN, n_layer: 3, feat_dim: 256, baseline: graphens, gba: topology
node_ratio: 0.258% | edge_ratio: 1.364% | Runtime: 0.0055 s
Val Acc F1: 0.7335, Test Disp 0.2195 +- 0.0020 Acc: 0.7765 +- 0.0115, BAcc: 0.7260 +- 0.0075, F1: 0.7219 +- 0.0124




