In [4]:
import argparse
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.multiprocessing as mp

In [5]:
from graphsaint.sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from graphsaint.config import CONFIG
from graphsaint.modules import GCNNet
from graphsaint.utils import Logger, evaluate, save_log_dir, load_data, calc_f1
import warnings

In [6]:
from Juyeong.aug import HLoss, Jensen_Shannon, generate_aug_graph

graphsaint > train_sampling.py > main()

In [7]:
from collections import namedtuple

# Run configuration, written out by hand instead of being parsed from the
# command line. Key order matters: it becomes the field order (and therefore
# the repr) of the `args` namedtuple built below.
a = dict(
    # model / optimisation
    aggr='concat',
    arch='1-0-1-0',
    dataset='ppi',
    dropout=0,
    edge_budget=4000,
    length=2,
    log_dir='none',
    lr=0.005,
    decay=0.0005,
    n_epochs=50,
    n_hidden=512,
    no_batch_norm=False,
    node_budget=6000,
    # subgraph sampling
    num_subg=50,
    num_roots=3000,
    sampler='node',
    use_val=True,
    val_every=1,
    num_workers_sampler=0,
    num_subg_sampler=10000,
    batch_size_sampler=200,
    num_workers=8,
    full=True,
    # augmentation hyper-parameters (forwarded to generate_aug_graph)
    sigma_delta_e=0.03,
    sigma_delta_v=0.03,
    mu_e=0.6,
    mu_v=0.2,
    lam1_e=1,
    lam1_v=1,
    lam2_e=0.0,
    lam2_v=0.0,
    a_e=100,
    b_e=1,
    a_v=100,
    b_v=1,
    # remaining options
    kl=2.0,
    h=0.2,
    online=False,
    gpu=0,
    task='ppi_n',
)
multilabel = True

# Expose the settings through attribute access (args.lr, args.gpu, ...),
# mimicking an argparse namespace. The type name 'a' shows up in the repr.
A = namedtuple('a', list(a))
args = A(*a.values())

In [8]:
# Load the dataset selected by `args` and pull out everything the later
# cells need: the graph, its node-level tensors, and a few size statistics.
data = load_data(args, multilabel)
g = data.g

# Boolean split masks and label matrix stored as node data on the graph.
train_mask, val_mask, test_mask = (
    g.ndata[mask_key] for mask_key in ('train_mask', 'val_mask', 'test_mask')
)
labels = g.ndata['label']

# IDs of the training nodes, as provided by the dataset object.
train_nid = data.train_nid

# Dimensions used to size the model and for reporting.
n_nodes, n_edges = g.num_nodes(), g.num_edges()
in_feats = g.ndata['feat'].shape[1]
n_classes = data.num_classes

# Number of nodes in each split (count of True entries per mask).
n_train_samples, n_val_samples, n_test_samples = (
    split_mask.int().sum().item()
    for split_mask in (train_mask, val_mask, test_mask)
)

In [9]:
# Show the repr of the loaded dataset object as this cell's output.
data

Dataset(num_classes=121, train_nid=array([   0,    1,    2, ..., 9713, 9714, 9715]), g=Graph(num_nodes=14755, num_edges=450540,
      ndata_schemes={'feat': Scheme(shape=(50,), dtype=torch.float32), 'label': Scheme(shape=(121,), dtype=torch.float32), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={}))

In [10]:
# Keyword arguments passed unchanged to whichever SAINT sampler is built next.
kwargs = dict(
    dn=args.dataset,
    g=g,
    train_nid=train_nid,
    num_workers_sampler=args.num_workers_sampler,
    num_subg_sampler=args.num_subg_sampler,
    batch_size_sampler=args.batch_size_sampler,
    online=args.online,
    num_subg=args.num_subg,
    full=args.full,
)

In [11]:
# Map each supported sampler name to a zero-argument factory so that only the
# requested sampler is actually constructed.
sampler_factories = {
    "node": lambda: SAINTNodeSampler(args.node_budget, **kwargs),
    "edge": lambda: SAINTEdgeSampler(args.edge_budget, **kwargs),
    "rw": lambda: SAINTRandomWalkSampler(args.num_roots, args.length, **kwargs),
}

build_sampler = sampler_factories.get(args.sampler)
if build_sampler is None:
    # Unknown sampler name.
    raise NotImplementedError
saint_sampler = build_sampler()

# One sampled subgraph per batch; the sampler supplies its own collate function.
loader = DataLoader(
    saint_sampler,
    batch_size=1,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,
    collate_fn=saint_sampler.__collate_fn__,
)

[3.1428573 2.9642859 3.        ... 3.        3.2807019 3.1612902]
[0.00028198 0.00228718 0.00027085 ... 0.00036113 0.00128654 0.00033201]
The number of subgraphs is:  200


In [12]:
# Device placement, model / optimizer construction and logging setup.

# When True, the graph itself is kept on the CPU even if a GPU is selected.
cpu_flag = False

# A negative GPU index selects a CPU-only run.
cuda = args.gpu >= 0
if cuda:
    torch.cuda.set_device(args.gpu)
    # Only the validation / test masks are moved here; train_mask is left
    # where it was loaded.
    val_mask = val_mask.cuda()
    test_mask = test_mask.cuda()
    if not cpu_flag:
        g = g.to('cuda:{}'.format(args.gpu))

print('labels shape:', g.ndata['label'].shape)
print("features shape:", g.ndata['feat'].shape)

model = GCNNet(
    in_dim=in_feats,
    hid_dim=args.n_hidden,
    out_dim=n_classes,
    arch=args.arch,
    dropout=args.dropout,
    batch_norm=not args.no_batch_norm,
    aggr=args.aggr
)

if cuda:
    model.cuda()

# Logging: resolve the log directory and record the full configuration.
log_dir = save_log_dir(args)
logger = Logger(os.path.join(log_dir, 'loggings'))
logger.write(args)

# NOTE(review): args.decay exists in the configuration but is not passed to
# the optimizer, so no weight decay is applied - confirm this is intended.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr)

# Move the training node IDs to the GPU.
if cuda:
    # torch.as_tensor accepts both the NumPy array produced by the loading
    # cell and the tensor left behind by an earlier run of this cell, so the
    # cell can be re-executed; torch.from_numpy would raise a TypeError on
    # the second run because train_nid is no longer an ndarray.
    train_nid = torch.as_tensor(train_nid).cuda()
    print("GPU memory allocated before training(MB)",
          torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)

# Bookkeeping for the training loop.
start_time = time.time()
best_f1 = -1

labels shape: torch.Size([14755, 121])
features shape: torch.Size([14755, 50])
a(aggr='concat', arch='1-0-1-0', dataset='ppi', dropout=0, edge_budget=4000, length=2, log_dir='none', lr=0.005, decay=0.0005, n_epochs=50, n_hidden=512, no_batch_norm=False, node_budget=6000, num_subg=50, num_roots=3000, sampler='node', use_val=True, val_every=1, num_workers_sampler=0, num_subg_sampler=10000, batch_size_sampler=200, num_workers=8, full=True, sigma_delta_e=0.03, sigma_delta_v=0.03, mu_e=0.6, mu_v=0.2, lam1_e=1, lam1_v=1, lam2_e=0.0, lam2_v=0.0, a_e=100, b_e=1, a_v=100, b_v=1, kl=2.0, h=0.2, online=False, gpu=0, task='ppi_n')
GPU memory allocated before training(MB) 26.75830078125


In [13]:
# Toy experiment: how do an "original" and an "augmented" buffer evolve across
# epochs when the augmented results are accumulated with torch.cat?
#
# Scratch names are prefixed with `toy_` so that this cell does not overwrite
# the real `data` object loaded earlier in the notebook.

# Use the current CUDA device when there is one; otherwise fall back to the
# CPU so the experiment also runs on a machine without a GPU.
toy_device = (torch.device('cuda', torch.cuda.current_device())
              if torch.cuda.is_available() else torch.device('cpu'))

toy_batches = torch.Tensor([[1], [2], [3]]).to(toy_device)
toy_loader = DataLoader(toy_batches)  # default batch_size=1: one row per step


def _augment(g):
    """Stand-in for the real augmentation: scale the input by 10."""
    return g*10


for epoch in range(2):
    for j, subg in enumerate(toy_loader):
        if epoch == 0:
            # First epoch: the "original" is the raw batch.
            subg_original = subg
        else:
            # Later epochs: bind to whatever `subg_aug` currently holds, i.e.
            # the accumulation left by the previous iteration (the complete
            # previous-epoch tensor at j == 0, a partial one afterwards).
            # No copy is made; this is the same tensor object.
            subg_original = subg_aug

        if j == 0:
            # Start a fresh accumulation for this epoch.
            subg_aug = _augment(subg)      # simulate augmentation
        else:
            # torch.cat builds a new tensor, so `subg_original` keeps
            # pointing at the previous accumulation.
            subg_aug = torch.cat((subg_aug, _augment(subg)))

        print("# epoch: ", epoch)
        print("# j: ", j)
        print("Subg original", subg_original)
        print("Subg aug", subg_aug)
        print("-------")
        print("-------")

# epoch:  0
# j:  0
Subg original tensor([[1.]], device='cuda:0')
Subg aug tensor([[10.]], device='cuda:0')
-------
# epoch:  0
# j:  1
Subg original tensor([[2.]], device='cuda:0')
Subg aug tensor([[10.],
        [20.]], device='cuda:0')
-------
# epoch:  0
# j:  2
Subg original tensor([[3.]], device='cuda:0')
Subg aug tensor([[10.],
        [20.],
        [30.]], device='cuda:0')
-------
# epoch:  1
# j:  0
Subg original tensor([[10.],
        [20.],
        [30.]], device='cuda:0')
Subg aug tensor([[10.]], device='cuda:0')
-------
# epoch:  1
# j:  1
Subg original tensor([[10.]], device='cuda:0')
Subg aug tensor([[10.],
        [20.]], device='cuda:0')
-------
# epoch:  1
# j:  2
Subg original tensor([[10.],
        [20.]], device='cuda:0')
Subg aug tensor([[10.],
        [20.],
        [30.]], device='cuda:0')
-------


In [14]:
import dgl

# Aliasing check on a list of DGL graphs: after `previous = current`, does
# rebinding `current` to a new empty list affect `previous`? (It does not:
# the printed output shows the two graphs followed by an empty list.)
#
# Scratch names are used on purpose so that this cell does not overwrite the
# notebook's real graph `g` or the configuration dict `a`.

# Use the current CUDA device when there is one; otherwise fall back to CPU.
toy_device = (torch.device('cuda', torch.cuda.current_device())
              if torch.cuda.is_available() else torch.device('cpu'))

# Two small graphs given as (source, destination) node-ID tensors.
src_a, dst_a = torch.tensor([0, 0, 0, 1]), torch.tensor([1, 2, 3, 3])
src_b, dst_b = torch.tensor([1, 2, 3, 4]), torch.tensor([4, 5, 6, 7])
toy_graph_a = dgl.graph((src_a, dst_a)).to(toy_device)
toy_graph_b = dgl.graph((src_b, dst_b)).to(toy_device)

graphs_current = [toy_graph_a, toy_graph_b]
graphs_previous = graphs_current   # second reference to the same list object
graphs_current = []                # rebinds the name; the old list is untouched
print(graphs_previous, graphs_current)

[Graph(num_nodes=4, num_edges=4,
      ndata_schemes={}
      edata_schemes={}), Graph(num_nodes=8, num_edges=4,
      ndata_schemes={}
      edata_schemes={})] []


In [15]:
# Augmentation loop (work in progress): in the first epoch every sampled
# subgraph is augmented; in later epochs the augmented graph stored at the
# same batch position in the previous epoch is augmented again.

# Debug caps that keep this exploratory cell short. Set either one to None to
# lift the cap and run the full schedule.
DEBUG_MAX_EPOCHS = 3
DEBUG_MAX_BATCHES = 2

# Loss operators; not used inside this cell - presumably for loss terms that
# are added to the loop later (TODO confirm).
h_loss_op = HLoss()
js_loss_op = Jensen_Shannon()

current_aug = []

epochs_to_run = (args.n_epochs if DEBUG_MAX_EPOCHS is None
                 else min(args.n_epochs, DEBUG_MAX_EPOCHS))

for epoch in range(epochs_to_run):
    # Rotate the buffers at the start of every epoch: last epoch's results
    # become `prev_aug`, and this epoch collects into a brand-new list. Doing
    # it here (rather than resetting at j == 0) guarantees the two names never
    # refer to the same list, even if the loader yields no batch.
    prev_aug, current_aug = current_aug, []

    for j, subg in enumerate(loader):
        if DEBUG_MAX_BATCHES is not None and j >= DEBUG_MAX_BATCHES:
            break

        if cuda:
            subg = subg.to(torch.cuda.current_device())

        # Pick what to augment.
        # NOTE(review): the loader is created with shuffle=True, so
        # prev_aug[j] is the j-th graph augmented in the previous epoch and
        # is not necessarily derived from this epoch's `subg`; in epochs > 0
        # `subg` itself is not used as input - confirm this is intended.
        target = subg if epoch == 0 else prev_aug[j]

        auged_subg, delta_G_e, delta_G_v, delta_G_e_aug, delta_G_v_aug \
            = generate_aug_graph(target, model,
                                 args.sigma_delta_e, args.sigma_delta_v, args.mu_e, args.mu_v,
                                 args.lam1_e, args.lam1_v, args.lam2_e, args.lam2_v,
                                 args.a_e, args.b_e, args.a_v, args.b_v)

        # Keep the augmented graph so the next epoch can build on it.
        current_aug.append(auged_subg)
            
    

  assert input.numel() == input.storage().size(), (


KeyboardInterrupt: 