# Generate more challenging train/test splits

Step 1: extract N attributes (e.g., the 10 used in our paper) for each sample and check how much the training/validation/test sets overlap.

Step 2: prune the training set, deleting the samples whose attributes (G) overlap most with the validation set.

Step 3: save the dataset.

In [1]:
import torch.optim as optim
import torch
import argparse
import numpy as np
import random
import os
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt
import argparse
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader
import pandas as pd
from probing import *
from utils.general import *
from rdkit import Chem
from rdkit.Chem import Draw

def get_args_parser():
    """Build the argument parser for the OGB molecular-property experiments.

    The cells in this notebook read ``device``, ``batch_size``,
    ``batch_size_train``, ``num_workers`` and ``dataset_name``; the remaining
    options describe the model structure (backbone, SEM bottleneck, head).

    Returns:
        argparse.ArgumentParser: parser with every option's default populated.
    """
    # Training settings
    # ======= Usually default settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for validation/test (default: 64)')
    parser.add_argument('--batch_size_train', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='number of workers (default: 2)')
    parser.add_argument('--dataset_name', type=str, default="ogbg-molhiv",
                        help='dataset name (default: ogbg-molhiv/moltox21/molpcba)')
    parser.add_argument('--feature', type=str, default="full",
                        help='full feature or simple feature')
    parser.add_argument('--bottle_type', type=str, default='std',
                        help='bottleneck type, can be std or sem')
    # ==== Model Structure ======
    # ----- Backbone
    parser.add_argument('--backbone_type', type=str, default='gcn',
                        help='backbone type, can be gcn, gin, gcn_virtual, gin_virtual')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    # ---- SEM
    parser.add_argument('--L', type=int, default=30,
                        help='No. word in SEM')
    parser.add_argument('--V', type=int, default=10,
                        help='word size in SEM')
    # ---- Head-type
    parser.add_argument('--head_type', type=str, default='linear',
                        help='Head type in interaction, linear or mlp')
    return parser

def rnd_seed(seed, *, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer seed applied to every random-number generator.
        deterministic: keyword-only. If True (and CUDA is available), force
            deterministic cuDNN kernels and turn off cuDNN autotuning, trading
            speed for bitwise-reproducible GPU results. The default keeps the
            original behaviour of enabling ``cudnn.benchmark``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all also covers the current device, so a separate
        # torch.cuda.manual_seed call is redundant.
        torch.cuda.manual_seed_all(seed)
        # Go through torch.backends so the module is always resolvable; the bare
        # name `cudnn` is not provided by any explicit import in this file.
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            # NOTE: autotuning may pick non-deterministic kernels; pass
            # deterministic=True when exact GPU reproducibility matters.
            torch.backends.cudnn.benchmark = True

rnd_seed(10086)

# Build the parser and parse an empty argument list, so only the defaults from
# get_args_parser() are used (sys.argv is not consulted).
args = get_args_parser().parse_args(args=[])
# Replace the integer GPU id with a torch.device, falling back to the CPU.
args.device = (
    torch.device(f"cuda:{args.device}")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
# Output directory for figures (hard-coded Windows path; not used in the cells shown here).
FIG_DIR = 'E:\\P45_disentanglement\\figures\\'



## Prepare for the probing data

In [2]:
# Per-dataset settings: location of the SMILES mapping file and how many
# validation/test samples are fingerprinted (args.batch_size).
if args.dataset_name == 'ogbg-molhiv':
    smiles_path = 'E:\\P4_Graph\\dataset\\ogbg_molhiv\\mapping\\mol.csv.gz'
    args.batch_size = 4113
elif args.dataset_name == 'ogbg-molpcba':
    args.batch_size = 24000
    smiles_path = 'E:\\P4_Graph\\dataset\\ogbg_molpcba\\mapping\\mol.csv.gz'
elif args.dataset_name == 'ogbg-moltox21':
    args.batch_size = 783
    smiles_path = 'E:\\P4_Graph\\dataset\\ogbg_moltox21\\mapping\\mol.csv.gz'
else:
    # Fail fast here instead of with a NameError on smiles_path further down.
    raise ValueError(f"Unsupported dataset_name: {args.dataset_name!r}")

# RDKit descriptors that are binarised into each sample's attribute fingerprint.
selected_prop = ['NumSaturatedRings', 'NumAromaticRings', 'NumAromaticCarbocycles', 'fr_aniline',
                 'fr_ketone', 'fr_bicyclic', 'fr_methoxy', 'fr_para_hydroxylation', 'fr_pyridine', 'fr_benzene']

dataset = PygGraphPropPredDataset(name=args.dataset_name)
args.num_tasks = dataset.num_tasks
split_idx = dataset.get_idx_split()
# The whole training split is fingerprinted and pruned, and the pruning cell
# stacks per-sample arrays of length batch_size_train next to split_idx['train'],
# so the two lengths must match for every dataset (32901 for ogbg-molhiv).
# Without this, molpcba/moltox21 would keep the parser default of 64.
args.batch_size_train = len(split_idx['train'])

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size_train, shuffle=False, drop_last=True,
                          num_workers=args.num_workers)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, drop_last=True,
                          num_workers=args.num_workers)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, drop_last=True,
                         num_workers=args.num_workers)

# Read the SMILES mapping once and slice it per split; the positional .iloc
# indexing assumes CSV row i corresponds to graph i of the dataset.
smiles_df = pd.read_csv(smiles_path)
train_smiles = smiles_df.iloc[split_idx['train']].smiles.values.tolist()
valid_smiles = smiles_df.iloc[split_idx['valid']].smiles.values.tolist()
test_smiles = smiles_df.iloc[split_idx['test']].smiles.values.tolist()

train_desc_names, train_properties = compute_properties(train_smiles)
valid_desc_names, valid_properties = compute_properties(valid_smiles)
test_desc_names, test_properties = compute_properties(test_smiles)

In [3]:
# Binarise every selected descriptor (1 where its value is > 0, else 0) for the
# first batch_size_train training rows and first batch_size validation/test rows,
# then stack the per-descriptor columns into (n_samples, len(selected_prop)) matrices.
split_props = (train_properties, valid_properties, test_properties)
split_sizes = (args.batch_size_train, args.batch_size, args.batch_size)
fp_columns = ([], [], [])
for prop in tqdm(selected_prop):
    for columns, props, size in zip(fp_columns, split_props, split_sizes):
        columns.append(np.array(props[prop].values[:size] > 0, dtype=int))
train_fp, valid_fp, test_fp = (np.array(cols).transpose() for cols in fp_columns)

100%|████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 3333.57it/s]


In [4]:
# ----- For each sample in the TRAINING set, sweep the validation set and find the
# -- validation samples whose binary fingerprint is exactly the same.
# -- Result: one duplicate count (dup_count_list) and one array of matching
# -- validation indices (dup_index_list) per training sample.
index_template = np.arange(0, args.batch_size, 1)
dup_count_list, dup_index_list = [], []
for i in tqdm(range(args.batch_size_train)):
    anchor = train_fp[i]
    # The fingerprints are 0/1 vectors. "Validation row contains every bit of
    # anchor AND anchor contains every bit of the row" is therefore exactly row
    # equality, which needs no per-iteration dot products or row sums.
    dup_mask = np.all(valid_fp == anchor, axis=1)
    dup_count_list.append(dup_mask.sum())
    dup_index_list.append(index_template[dup_mask])

100%|██████████████████████████████████████████████████████████████████████████| 32901/32901 [00:04<00:00, 6911.81it/s]


In [16]:
# ---- dup_count_list holds, per training sample, how many validation samples share
# --- its exact fingerprint. To make the split harder, rank the training samples by
# --- that count (highest first), drop the first PRUNE_IDX of them (95%), and keep
# --- only the least-overlapping remainder, so fewer near-duplicates survive.
PRUNE_IDX = int(args.batch_size_train*0.95)   # number of training samples pruned
train_index_template = np.arange(0, args.batch_size_train, 1)
train_split_index = np.array(split_idx['train'])
dup_count_list = np.array(dup_count_list)
tmp_pd = pd.DataFrame(
    np.column_stack((train_index_template, dup_count_list, train_split_index)),
    index=train_index_template,
    columns=['index', 'count', 'split_index'],
)
# NOTE: sort_values uses pandas' default (non-stable) sort, so which rows tied on
# 'count' survive at the cut-off is implementation-defined; the call is kept as-is
# so the generated split is reproduced.
ranked_pd = tmp_pd.sort_values(by='count', ascending=False)
prune_pd = ranked_pd.iloc[PRUNE_IDX:]
# Restore training-split order; the frame's index labels equal its 'index' column.
prune_pd_sort = prune_pd.sort_index()

# ----- Positions (within the training split) of the retained samples, saved as a
# ----- .npy file that is used to select samples during training.
sel_index = prune_pd_sort['index'].values
save_path = 'E:\\P4_Graph\\dataset\\' + args.dataset_name + '_hard.npy'
np.save(save_path, sel_index)

In [17]:
# Shape of the saved index array, i.e. how many training samples were kept.
np.shape(sel_index)

(1646,)

In [18]:
# Largest validation-duplicate count among the retained training samples.
prune_pd_sort.loc[:, 'count'].max()

6