In [1]:
import sys
import os
import gc

os.chdir('..')
sys.path.append(os.getcwd())

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.data import Data
from torch_sparse import SparseTensor
import torch
from torch_geometric.nn import knn_graph
import typing

from datasetLoader import DATASETS, ClusteringDataset
from utils import config

In [3]:
def knn_same_class_rate(data:torch.Tensor, label:np.ndarray, k_max=1, distance:typing.Literal["eculidean", "cosine"]="eculidean"):
    """
    Calculate the rate of same class nodes in the k-nearest neighbors of each node
    """
    data = data.to("cuda" if torch.cuda.is_available() else "cpu")
    if distance == "eculidean":
        dist = torch.cdist(data, data, p=2)
    elif distance == "cosine":
        X_norm = data / data.norm(dim=1, keepdim=True)
        dist = 1 - torch.mm(X_norm, X_norm.t())
    elif distance == "NormCos":
        X_binary = torch.where(data > 0, 1, 0).float()
        X_norm = X_binary / X_binary.norm(dim=1, keepdim=True)
        dist = 1 - torch.mm(X_norm, X_norm.t())
    elif distance == "Heat":
        dist = 1 - torch.exp(-0.5 * torch.cdist(data, data, p=2) ** 2)
    elif distance == "Manhattan":
        dist = torch.cdist(data, data, p=1)
    else:
        raise ValueError("Invalid distance type")
    dist = dist.cpu().detach().numpy()
    same_class_rate_list = []
    inds = {}
    for k in range(1, k_max+1):
        inds[k] = []
        
    for i in range(dist.shape[0]):
        ind = np.argpartition(dist[i, :], k_max+1)[:k_max+1]
        for k in range(1, k_max+1):
            inds[k].append(ind[k])
            
    for k in range(1, k_max+1):
        inds_k = np.array(inds[k])
        same_class_rate = np.array([np.sum(label[inds_k[i]] == label[i]) for i in range(len(label))])
        same_class_rate = np.mean(same_class_rate)
        same_class_rate_list.append(same_class_rate)
    del dist, inds
    return same_class_rate_list

In [4]:
ignore_datasets = [
    'obgn_papers100M', # 50+G dataset too big
]
used_datasets = [
    # "Cora",
]
k_max = 5
dist_list = [
    "eculidean",
    "cosine",
    "NormCos",
    "Heat",
    "Manhattan",
]

In [None]:
dataset_dict = {"dataset_name": [], "n_labeled_samples": [], "n_unlabeled_samples": [], "n_features": [], "n_classes": []}
graph_dataset_dict = {"dataset_name": [], "n_nodes": [], "n_edges": [], "weighted": []}

knn_same_class_rate_dicts = {}
for dist in dist_list:
    knn_same_class_rate_dicts[dist] = {"dataset_name": []}
    for k in range(1, k_max+1):
        knn_same_class_rate_dicts[dist][f"knn_same_class_rate_{k}"] = []

for dataset_name, Dataset in DATASETS.items():
    if dataset_name in ignore_datasets:
        continue
    if len(used_datasets) > 0 and dataset_name not in used_datasets:
        continue
    cfg = config.init_by_path("./cfg/example.cfg")
    dataset:ClusteringDataset = Dataset(cfg, ["seq"])
    print(f"Processing {dataset.name}...")
    dataset_dict["dataset_name"].append(dataset.name)
    dataset_dict["n_labeled_samples"].append(dataset.label_length)
    dataset_dict["n_unlabeled_samples"].append(dataset.unlabel_length)
    dataset_dict["n_features"].append(dataset.input_dim)
    dataset_dict["n_classes"].append(dataset.num_classes)
    for dist in dist_list:
        knn_same_class_rate_dicts[dist]["dataset_name"].append(dataset.name)
    
    for dist in dist_list:
        print(f"Calculating knn_same_class_rate_{dist}...")
        try:
            same_class_rate_list = knn_same_class_rate(dataset.label_data, dataset.label, k_max, dist)
            gc.collect()
        except Exception as e:
            print(f"Calculating knn_same_class_rate_{dist} failed: {e}")
            same_class_rate_list = [np.nan for _ in range(k_max)]
        for k, same_class_rate in enumerate(same_class_rate_list):
            knn_same_class_rate_dicts[dist][f"knn_same_class_rate_{k+1}"].append(same_class_rate)
   
    if dataset._graph is not None:
        # this is a graph dataset
        print(f"Processing {dataset.name} as a graph dataset...")
        data:Data = dataset._graph
        graph_dataset_dict["dataset_name"].append(dataset.name)
        graph_dataset_dict["n_nodes"].append(data.num_nodes)
        graph_dataset_dict["n_edges"].append(data.num_edges)
        if data.edge_attr is not None or isinstance(data.edge_index, SparseTensor) and (data.edge_index.storage.value() is not None) and (data.edge_index.storage.value().unique().numel() > 1):
            graph_dataset_dict["weighted"].append(True)
        else:
            graph_dataset_dict["weighted"].append(False)


Processing MNIST_seq_resnet50...
Calculating knn_same_class_rate_eculidean...
Calculating knn_same_class_rate_cosine...
Calculating knn_same_class_rate_NormCos...
Calculating knn_same_class_rate_Heat...
Calculating knn_same_class_rate_Manhattan...
Processing FashionMNIST_seq_resnet50...
Calculating knn_same_class_rate_eculidean...
Calculating knn_same_class_rate_cosine...
Calculating knn_same_class_rate_NormCos...
Calculating knn_same_class_rate_Heat...
Calculating knn_same_class_rate_Manhattan...
Files already downloaded and verified
Files already downloaded and verified
Processing CIFAR10_seq_resnet50...
Calculating knn_same_class_rate_eculidean...
Calculating knn_same_class_rate_cosine...
Calculating knn_same_class_rate_NormCos...
Calculating knn_same_class_rate_Heat...
Calculating knn_same_class_rate_Manhattan...
Calculating knn_same_class_rate_Manhattan failed: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so

In [6]:
df_dataset = pd.DataFrame(dataset_dict)
df_graph_dataset = pd.DataFrame(graph_dataset_dict)
print("Dataset information:")
display(df_dataset)
print("Graph dataset information:")
display(df_graph_dataset)
for dist in knn_same_class_rate_dicts:
    df_knn_same_class_rate = pd.DataFrame(knn_same_class_rate_dicts[dist])
    print(f"Knn same class rate for distance {dist}:")
    display(df_knn_same_class_rate)

Dataset information:


Unnamed: 0,dataset_name,n_labeled_samples,n_unlabeled_samples,n_features,n_classes
0,MNIST_seq_resnet50,70000,0,2048,10
1,FashionMNIST_seq_resnet50,70000,0,2048,10
2,CIFAR10_seq_resnet50,60000,0,2048,10
3,CIFAR100_seq_resnet50,60000,0,2048,20
4,STL10_seq_hog_color,13000,100000,1488,10
5,USPS_seq_resnet50,9298,0,2048,10
6,Reuters10K,10000,0,2000,4
7,XYh5_scRNA_Baron_human,8569,0,17499,14
8,ACM,3025,0,1870,3
9,DBLP,4057,0,334,4


Graph dataset information:


Unnamed: 0,dataset_name,n_nodes,n_edges,weighted
0,ACM,3025,13128,False
1,DBLP,4057,3528,False
2,Cora,2708,5278,False
3,Citeseer,3327,4552,False
4,Pubmed,19717,44324,False
5,Wiki,2405,7679,False
6,BAT,131,1003,False
7,EAT,399,5993,False
8,UAT,1190,13599,False
9,Amazon_Computers,13752,245861,False


Knn same class rate for distance eculidean:


Unnamed: 0,dataset_name,knn_same_class_rate_1,knn_same_class_rate_2,knn_same_class_rate_3,knn_same_class_rate_4,knn_same_class_rate_5
0,MNIST_seq_resnet50,0.947714,0.470271,0.31069,0.230054,0.1822
1,FashionMNIST_seq_resnet50,0.890886,0.438607,0.287795,0.212729,0.167274
2,CIFAR10_seq_resnet50,0.854067,0.417817,0.271278,0.198658,0.155367
3,CIFAR100_seq_resnet50,0.771733,0.369808,0.236061,0.169712,0.13
4,STL10_seq_hog_color,0.477769,0.208654,0.121308,0.080596,0.054985
5,USPS_seq_resnet50,0.961174,0.476447,0.315803,0.234325,0.185653
6,Reuters10K,0.9462,0.4678,0.307833,0.226525,0.17788
7,XYh5_scRNA_Baron_human,0.337846,0.139339,0.074143,0.043675,0.028008
8,ACM,0.59438,0.279008,0.173223,0.119917,0.090579
9,DBLP,0.740695,0.339537,0.213212,0.148509,0.109933


Knn same class rate for distance cosine:


Unnamed: 0,dataset_name,knn_same_class_rate_1,knn_same_class_rate_2,knn_same_class_rate_3,knn_same_class_rate_4,knn_same_class_rate_5
0,MNIST_seq_resnet50,0.950214,0.4712,0.310967,0.231489,0.182843
1,FashionMNIST_seq_resnet50,0.893257,0.439843,0.288595,0.212789,0.167869
2,CIFAR10_seq_resnet50,0.865067,0.423558,0.275733,0.202154,0.15864
3,CIFAR100_seq_resnet50,0.787183,0.375775,0.242406,0.174412,0.134733
4,STL10_seq_hog_color,0.481308,0.208808,0.12359,0.081173,0.056046
5,USPS_seq_resnet50,0.963541,0.477845,0.317237,0.235104,0.186363
6,Reuters10K,0.9464,0.46775,0.3078,0.226425,0.17794
7,XYh5_scRNA_Baron_human,0.853542,0.419594,0.276928,0.210089,0.168608
8,ACM,0.820826,0.395702,0.255537,0.188595,0.146314
9,DBLP,0.787774,0.375154,0.239997,0.172726,0.133596


Knn same class rate for distance NormCos:


Unnamed: 0,dataset_name,knn_same_class_rate_1,knn_same_class_rate_2,knn_same_class_rate_3,knn_same_class_rate_4,knn_same_class_rate_5
0,MNIST_seq_resnet50,0.681171,0.321379,0.2006,0.141704,0.105903
1,FashionMNIST_seq_resnet50,0.5529,0.243443,0.145795,0.096725,0.068123
2,CIFAR10_seq_resnet50,0.42835,0.167,0.093972,0.057375,0.037963
3,CIFAR100_seq_resnet50,0.357433,0.129492,0.067117,0.03585,0.020137
4,STL10_seq_hog_color,0.420462,0.171154,0.095744,0.059365,0.038985
5,USPS_seq_resnet50,0.769951,0.364863,0.233563,0.165546,0.126178
6,Reuters10K,0.9511,0.47035,0.3116,0.2309,0.18182
7,XYh5_scRNA_Baron_human,0.913876,0.446435,0.295173,0.218549,0.173486
8,ACM,0.820826,0.395702,0.255537,0.188595,0.146314
9,DBLP,0.787774,0.375154,0.239997,0.172726,0.133596


Knn same class rate for distance Heat:


Unnamed: 0,dataset_name,knn_same_class_rate_1,knn_same_class_rate_2,knn_same_class_rate_3,knn_same_class_rate_4,knn_same_class_rate_5
0,MNIST_seq_resnet50,0.897614,0.434179,0.278843,0.201725,0.157414
1,FashionMNIST_seq_resnet50,0.469371,0.168214,0.048305,0.030132,0.021963
2,CIFAR10_seq_resnet50,0.408667,0.133025,0.061411,0.045733,0.036487
3,CIFAR100_seq_resnet50,0.37475,0.112442,0.045961,0.034021,0.027723
4,STL10_seq_hog_color,0.335769,0.122731,0.081513,0.060096,0.029446
5,USPS_seq_resnet50,0.760916,0.348193,0.229046,0.141509,0.103893
6,Reuters10K,0.5756,0.2608,0.176667,0.0838,0.07004
7,XYh5_scRNA_Baron_human,0.538453,0.222313,0.057455,0.059954,0.030855
8,ACM,0.511736,0.216529,0.145124,0.110992,0.063802
9,DBLP,0.730836,0.335716,0.206639,0.144195,0.105447


Knn same class rate for distance Manhattan:


Unnamed: 0,dataset_name,knn_same_class_rate_1,knn_same_class_rate_2,knn_same_class_rate_3,knn_same_class_rate_4,knn_same_class_rate_5
0,MNIST_seq_resnet50,0.904129,0.438736,0.282652,0.205289,0.160491
1,FashionMNIST_seq_resnet50,0.520929,0.202221,0.078081,0.052486,0.040063
2,CIFAR10_seq_resnet50,,,,,
3,CIFAR100_seq_resnet50,,,,,
4,STL10_seq_hog_color,0.498154,0.219654,0.128462,0.084731,0.060815
5,USPS_seq_resnet50,0.958916,0.476447,0.316233,0.235669,0.1869
6,Reuters10K,0.8668,0.4139,0.266433,0.192675,0.14824
7,XYh5_scRNA_Baron_human,0.471,0.203116,0.115844,0.076205,0.057346
8,ACM,0.59438,0.279008,0.173223,0.119917,0.090579
9,DBLP,0.740695,0.339537,0.213212,0.148509,0.109933
