In [1]:
import torch.optim as optim
import torch
import argparse
import numpy as np
import random
import os
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt
import argparse
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader
import pandas as pd
from probing import *
from utils.general import *

def get_args_parser():
    # Training settings
    # ======= Usually default settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='number of workers (default: 0)')
    parser.add_argument('--dataset_name', type=str, default="ogbg-molhiv",
                        help='dataset name (default: ogbg-molhiv/moltox21/molpcba)')
    parser.add_argument('--feature', type=str, default="full",
                        help='full feature or simple feature')
    parser.add_argument('--bottle_type', type=str, default='std',
                        help='bottleneck type, can be std or sem')
    # ==== Model Structure ======
        # ----- Backbone
    parser.add_argument('--backbone_type', type=str, default='gcn',
                        help='backbone type, can be gcn, gin, gcn_virtual, gin_virtual')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')  
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
        # ---- SEM
    parser.add_argument('--L', type=int, default=30,
                        help='No. word in SEM')
    parser.add_argument('--V', type=int, default=10,
                        help='word size in SEM')
                        
        # ---- Head-type
    parser.add_argument('--head_type', type=str, default='linear',
                        help='Head type in interaction, linear or mlp')    
    return parser


args = get_args_parser()
args = args.parse_args(args=[])
#args = args.parse_args()
args.device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

## Prepare for the probing data

In [2]:
if args.dataset_name == 'ogbg-molhiv':
    smiles_path = 'E:\\P4_Graph\\dataset\\ogbg_molhiv\\mapping\\mol.csv.gz'
    args.batch_size = 4113
elif args.dataset_name == 'ogbg-molpcba':
    smiles_path = 'E:\\P4_Graph\\dataset\\ogbg_molpcba\\mapping\\mol.csv.gz'

selected_prop = ['NumSaturatedRings', 'NumAromaticRings', 'NumAromaticCarbocycles', 'fr_aniline', 'fr_ketone_Topliss', 
                 'fr_ketone', 'fr_bicyclic', 'fr_methoxy', 'fr_para_hydroxylation', 'fr_pyridine', 'fr_benzene']

dataset = PygGraphPropPredDataset(name = args.dataset_name)
args.num_tasks = dataset.num_tasks
split_idx = dataset.get_idx_split()
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False, drop_last=False,
                        num_workers=args.num_workers)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, drop_last=False,
                        num_workers=args.num_workers)

valid_smiles = pd.read_csv(smiles_path).iloc[split_idx['valid']].smiles.values
valid_smiles = valid_smiles.tolist()
test_smiles = pd.read_csv(smiles_path).iloc[split_idx['test']].smiles.values
test_smiles = test_smiles.tolist()

valid_desc_names, valid_properties = compute_properties(valid_smiles)
test_desc_names, test_properties = compute_properties(test_smiles)

## Prepare the model

In [3]:
# ====== Generate features
def get_features(args, model,loader):
    with torch.no_grad():
        for step, batch in enumerate(loader):
            batch = batch.to(args.device)
            msg, hid = model.task_forward(batch)
        msg = msg.reshape(msg.shape[0],-1)
        return msg.cpu()
    
def model_probing(args, model, run_random=False):
# ====== Generate representations
    embd_valid = get_features(args,model,valid_loader)
    embd_test = get_features(args,model,test_loader)

    prop = selected_prop[0]
    performs, rnd_performs = [], []

    for prop in tqdm(selected_prop):
        x_train = embd_valid
        y_train = valid_properties[prop].values.copy()
        #y_train[y_train>1] = 1 # binarize
        x_test = embd_test
        y_test = test_properties[prop].values.copy()
        #y_test[y_test>1] = 1 # binarize

        performs.append(linear_probing_regression(embedding_train=x_train, y_train=y_train, embeding_test=x_test, 
                                           y_test=y_test, seed=args.seed, scale=True))
        if run_random:
            np.random.shuffle(y_train)
            np.random.shuffle(y_test)
            rnd_performs.append(linear_probing_regression(embedding_train=x_train, y_train=y_train, embeding_test=x_test, 
                                               y_test=y_test, seed=args.seed, scale=True))
        else:
            rnd_performs = performs
    probing_performance = np.array(performs)
    random_performance = np.array(rnd_performs)
    #perf = pd.DataFrame(np.column_stack((probing_performance, random_performance)),
    #                index=selected_prop,
    #                columns=['MAE', 'AUC', 'rnd_MAE', 'rnd_AUC'])
    return probing_performance, random_performance

In [4]:
RESULT_PATH = 'E:\\P4_Graph\\results\\gcn_std\\ogbg-molhiv'
ckp_seed = os.path.join(RESULT_PATH, 'model_seed.pt')
ckp_baseline = os.path.join(RESULT_PATH, 'model_gen_00.pt')

# ====== Init a model, load parameters
args.bottle_type = 'std'
model_seed = get_init_net(args)
model_seed.load_state_dict(torch.load(ckp_seed),strict=True)
perf_seed, rnd = model_probing(args, model_seed, run_random=True)


model_base = get_init_net(args)
model_base.load_state_dict(torch.load(ckp_baseline),strict=True)
perf_base, _ = model_probing(args, model_base, run_random=False)

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:03<00:00,  2.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  5.87it/s]


In [None]:
ckp_seed = os.path.join(RESULT_PATH, 'model_seed.pt')
ckp_baseline = os.path.join(RESULT_PATH, 'model_gen_00.pt')

In [7]:
perf = pd.DataFrame(np.column_stack((perf_base, perf_seed ,rnd)),
                      index=selected_prop,
                       columns=['baseline', 'initial', 'random'])

In [8]:
perf

Unnamed: 0,baseline,initial,random
NumSaturatedRings,0.742346,0.875128,1.101936
NumAromaticRings,0.898742,1.078103,1.452289
NumAromaticCarbocycles,0.793294,1.054372,1.226484
fr_aniline,0.541981,0.626742,0.625375
fr_ketone_Topliss,0.37445,0.397991,0.391622
fr_ketone,0.466354,0.531511,0.513402
fr_bicyclic,1.415504,1.637321,1.703217
fr_methoxy,0.435077,0.522924,0.50186
fr_para_hydroxylation,0.479379,0.536718,0.554054
fr_pyridine,0.30777,0.355498,0.332152
