Visualize ClassIL on Yelp

In [25]:
import scipy.io
import pandas as pd
import os
import scipy
import torch
import scipy.io
from torch_geometric.data import Data
import numpy as np
from torch_geometric.datasets import Yelp
from utils import sparse_mx_to_torch_sparse_tensor
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from collections import defaultdict
import itertools

In [26]:
# load data
print('Loading dataset Yelp...')
dataset = Yelp(root='../../tmp/Yelp')
data = dataset[0]
labels = data.y
features = data.x
edge_index = data.edge_index
print("dataset loaded")

Loading dataset Yelp...
dataset loaded


In [27]:
# groups assignment
n_cls = labels.shape[1]
cls_order = shuffle(range(n_cls))
n_cls_per_t = 5
groups_idx = [tuple(cls_order[i:i+n_cls_per_t]) for i in range(0, n_cls-1, n_cls_per_t)]
#print("groups index", groups_idx)

clses = torch.transpose(labels, 0, 1)
cls_asgn = {}
for i, label in enumerate(clses):
    cls_asgn[i] = torch.nonzero(label).flatten()

groups = {}
for g in groups_idx:
    groups[g] = []
for g in groups_idx:
    assert g in groups.keys()
    for c in g:
        groups[g].extend(cls_asgn[c].tolist())

        print("number of nodes in this class",len(groups[g]))
    # note only once the nodes belong to multiple classes
    groups[g] = list(set(groups[g]))
    groups[g].sort()
    print("total number of nodes after removing duplicats of node-ids because of multi-labelness", len(groups[g]))
    print("###################")

# splits is a nested dictionary
# key: groups, value: (key:train, val, test; value:node_ids)
# each time step, nodes have no more than classes seen so far

# split the train-val-test within each class
splits = {}
for g in list(groups.keys()):
    split = {}
    # get all the nodes in this group
    node_ids = groups[g]

    # split the nodes into train-val-test: 60-10-30
    ids_train, ids_val_test = train_test_split(node_ids, test_size=0.4, random_state=42)
    ids_val, ids_test = train_test_split(ids_val_test, test_size=0.75, random_state=41)
    # write in the dictionary
    split["train"] = ids_train
    split["val"] = ids_val
    split["test"] = ids_test
    splits[g] = split

#print(splits)

num_nodes = labels.shape[0]

G = Data(x=features,
         edge_index=edge_index,
         y=labels)

G.n_id = torch.arange(num_nodes)
G.splits = splits
G.groups = groups

number of nodes in this class 58722
number of nodes in this class 71095
number of nodes in this class 84641
number of nodes in this class 223504
number of nodes in this class 235923
total number of nodes after removing duplicats of node-ids because of multi-labelness 191243
###################
number of nodes in this class 50405
number of nodes in this class 141439
number of nodes in this class 154457
number of nodes in this class 193520
number of nodes in this class 253743
total number of nodes after removing duplicats of node-ids because of multi-labelness 190810
###################
number of nodes in this class 192696
number of nodes in this class 241575
number of nodes in this class 271709
number of nodes in this class 399202
number of nodes in this class 439038
total number of nodes after removing duplicats of node-ids because of multi-labelness 296751
###################
number of nodes in this class 16177
number of nodes in this class 43398
number of nodes in this class 54279
nu

In [28]:
# statistics
print("Groups of classes for each time step: ", G.groups.keys())
print("##########################################################################")
print("size of each group: ", [len(G.groups[g]) for g in G.groups.keys()])
print("##########################################################################")
print("total number of nodes in Yelp: ", G.x.shape[0])

Groups of classes for each time step:  dict_keys([(63, 62, 98, 25, 92), (52, 90, 79, 29, 54), (82, 42, 72, 53, 27), (49, 80, 96, 88, 95), (91, 7, 48, 41, 67), (9, 61, 85, 5, 32), (99, 83, 8, 76, 31), (6, 23, 55, 94, 12), (66, 43, 36, 50, 75), (60, 74, 17, 77, 57), (11, 19, 10, 37, 71), (21, 47, 86, 44, 78), (0, 13, 2, 40, 35), (89, 70, 24, 14, 81), (3, 69, 38, 16, 45), (65, 56, 93, 97, 39), (1, 59, 15, 28, 20), (30, 18, 51, 58, 26), (87, 84, 4, 46, 68), (64, 33, 73, 34, 22)])
##########################################################################
size of each group:  [191243, 190810, 296751, 70153, 173691, 268051, 312211, 155037, 234709, 154564, 272426, 199989, 346114, 339963, 218760, 212694, 557721, 139394, 115270, 320350]
##########################################################################
total number of nodes in Yelp:  716847


In [29]:
# splits
print(G.splits.keys())

dict_keys([(63, 62, 98, 25, 92), (52, 90, 79, 29, 54), (82, 42, 72, 53, 27), (49, 80, 96, 88, 95), (91, 7, 48, 41, 67), (9, 61, 85, 5, 32), (99, 83, 8, 76, 31), (6, 23, 55, 94, 12), (66, 43, 36, 50, 75), (60, 74, 17, 77, 57), (11, 19, 10, 37, 71), (21, 47, 86, 44, 78), (0, 13, 2, 40, 35), (89, 70, 24, 14, 81), (3, 69, 38, 16, 45), (65, 56, 93, 97, 39), (1, 59, 15, 28, 20), (30, 18, 51, 58, 26), (87, 84, 4, 46, 68), (64, 33, 73, 34, 22)])


In [36]:
g1 = G.splits[list(G.splits.keys())[0]]
train_id_g1 = g1["train"]
val_id_g1 = g1["val"]
test_id_g1 = g1["test"]

g2 = G.splits[list(G.splits.keys())[1]]
train_id_g2 = g2["train"]
val_id_g2 = g2["val"]
test_id_g2 = g2["test"]
g1_id = g1["train"] + g1["val"] + g1["test"]
g2_id = g2["train"] + g2["val"] + g2["test"]

print("number of duplicats in the class group1 and group2: ",
      len(g1_id)+len(g2_id)-len(list(set(g1_id + g2_id))))

number of duplicats in the class group1 and group2:  78047


In [42]:
# sample some common nodes from group1 and group3
sample_in_g1 = g1_id[:1000]
target = [s for s in sample_in_g1 if s in g2_id]
print("number sampled of common nodes: ", len(target))

number sampled of common nodes:  406


In [49]:
# take one node
id = target[0]
print("take node", id, "as an example")
if id in train_id_g1:
    print("it is in the training set of group 1")
if id in val_id_g1:
    print("it is in the val set of group 1")
if id in test_id_g1:
    print("it is in the test set of group 1")

if id in train_id_g2:
    print("and it is in the training set of group 2")
if id in val_id_g2:
    print("and it is in the val set of group 2")
if id in test_id_g2:
    print("and it is in the test set of group 2")

take node 372363 as an example
it is in the training set of group 1
and it is in the test set of group 2


In [54]:
from torch_geometric.utils import mask
from torch_geometric.utils import subgraph
def map_edge_index(node_ids, edge_index_complete):
    # input the indices of the nodes of the subgraph in the graph,
    # transform the edge_index into the subgraph index
    num_edge = edge_index_complete.shape[1]
    map_book = {x.item(): i for i, x in enumerate(node_ids)}
    edge_index_mapped = map(lambda node: map_book[node], np.asarray(edge_index_complete.flatten()))
    edge_index = torch.Tensor(list(edge_index_mapped))
    edge_index = torch.reshape(edge_index, (2, num_edge)).long()

    return edge_index


def map_split(node_ids, split):
    # map the split ids into the subgraph
    map_book = {x.item(): i for i, x in enumerate(node_ids)}

    mapped_split = {}
    for key in split.keys():
        mapped_split[key] = list(map(lambda node: map_book[node], np.asarray(split[key])))

    return mapped_split


def prepare_sub_graph(G, key, Cross_Task_Message_Passing=False):
    # task incremental
    # prepare subgraph for one task
    # target classes for each task, note for catastrophic forgetting evaluation
    target_classes = list(key)
    print(target_classes)
    #print(G.groups[key])
    # sorted nodes ids in the group
    node_ids_g = torch.Tensor(G.groups[key]).int().long()

    # also include the edges that connect the nodes in the subgraph with those that are not in the subgraph
    node_mask = mask.index_to_mask(node_ids_g, size=G.num_nodes)

    # allow nodes from other task to pass information to the nodes for this task
    if Cross_Task_Message_Passing:
        # or operation of two boolean lists
        edge_mask = node_mask[G.edge_index[0]] + node_mask[G.edge_index[1]]
        edge_index_g = G.edge_index[:, edge_mask]
        # all nodes in the subgraph(including target nodes and their neighbors) in the original graph
        node_ids_g_all = torch.unique(edge_index_g.flatten()).long()
        # index of target nodes in the graph
        target_ids_g = [i for i, n in enumerate(node_ids_g_all) if n in node_ids_g]

        # !!!!!!!!!!!!!convert target id into the subgraph !!!!!!!!!!!!!!!!!!!
        # evaluate only on the target nodes!!!!!!!!!!!!!!!!!!!!!!!!!!!
        #target_ids_sub =
        # get the edge_index in the subgraph
        edge_index_sub = map_edge_index(node_ids_g_all, edge_index_g)

    # only nodes of this task in the subgraph
    else:
        # edge index in the original graph
        edge_index_g, _ = subgraph(node_ids_g, G.edge_index, None)
        # all neighbors are in the subgraph already
        node_ids_g_all = node_ids_g
        # all nodes are target nodes
        target_ids_g = node_ids_g_all
        # node ids in the subgraph
        target_ids_sub = np.arange(node_ids_g_all.shape[0])
        # edge index in the subgraph
        edge_index_sub = map_edge_index(node_ids_g_all, edge_index_g)

    features = G.x[node_ids_g_all]
    labels = G.y[node_ids_g_all]
    # map the ids to subgraph
    split = map_split(node_ids_g_all, G.splits[key])
    # number of nodes in the subgraph
    num_nodes = node_ids_g_all.shape[0]

    sub_g = Data(x=features,
                 edge_index=edge_index_sub,
                 y=labels)

    # node id in the subgraph
    sub_g.n_id_sub = torch.arange(num_nodes)
    # node id in the original graph
    sub_g.n_id_original = node_ids_g_all
    sub_g.split = split
    sub_g.target_classes = target_classes
    # target ids in the sub graph
    sub_g.target_ids_sub = target_ids_sub
    # target ids in the original graph
    sub_g.taget_ids_g = target_ids_g

    return sub_g

In [66]:
sub_g1 = prepare_sub_graph(G, list(G.groups.keys())[0], Cross_Task_Message_Passing=False)
sub_g2 = prepare_sub_graph(G, list(G.groups.keys())[1], Cross_Task_Message_Passing=False)

[63, 62, 98, 25, 92]
[52, 90, 79, 29, 54]


In [67]:
target_classes_groups = []
target_classes_groups.append(sub_g1.target_classes)
target_classes_groups.append(sub_g1.target_classes)

sub_g1.y = sub_g1.y[:, target_classes_groups[0]]
print("true label for first group")
print(sub_g1.y[:3])
cls_seen = torch.flatten(torch.tensor(target_classes_groups))
sub_g2.y = sub_g2.y[:, cls_seen]
print("true label for second group")
print(sub_g2.y[:3])
print("###########################")
print("meaning for the shared node, the example node id", id)
# ids of the target nodes in the subgraph1
index1 = G.groups[list(G.groups.keys())[0]].index(id)
# ids of the target nodes in the subgraph2
index2 = G.groups[list(G.groups.keys())[1]].index(id)
label1 = sub_g1.y[index1]
label2 = sub_g2.y[index2]
print("true label in t0: ", label1)
print("true label in t1: ",label2)


true label for first group
tensor([[0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 1., 0.]])
true label for second group
tensor([[1., 0., 0., 1., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 1., 0., 0., 0., 1., 1., 0.]])
###########################
meaning for the shared node, the example node id 372363
true label in t0:  tensor([0., 0., 0., 1., 0.])
true label in t1:  tensor([0., 0., 0., 1., 0., 0., 0., 0., 1., 0.])
