In [None]:
from time import time
import logging
import os
import os.path as osp
import numpy as np
import time

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable

import random
from torch.optim.lr_scheduler import StepLR


from utils import stat_graph, split_class_graphs, align_graphs
from utils import two_graphons_mixup, universal_svd
from graphon_estimator import universal_svd
from models import GIN,GCN
from tensorboardX import SummaryWriter

import argparse
logdir='/home/sayan/g-mixup/logs'
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensorboard_writer = SummaryWriter(log_dir=logdir)



def prepare_dataset_x(dataset):
    if dataset[0].x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            degs += [degree(data.edge_index[0], dtype=torch.long)]
            max_degree = max( max_degree, degs[-1].max().item() )
            data.num_nodes = int( torch.max(data.edge_index) ) + 1

        if max_degree < 2000:
            # dataset.transform = T.OneHotDegree(max_degree)

            for data in dataset:
                degs = degree(data.edge_index[0], dtype=torch.long)
                data.x = F.one_hot(degs, num_classes=max_degree+1).to(torch.float)
        else:
            deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            for data in dataset:
                degs = degree(data.edge_index[0], dtype=torch.long)
                data.x = ( (degs - mean) / std ).view( -1, 1 )
    return dataset



def prepare_dataset_onehot_y(dataset):

    y_set = set()
    for data in dataset:
        y_set.add(int(data.y))
    num_classes = len(y_set)

    for data in dataset:
        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    """Origin: https://github.com/moskomule/mixup.pytorch
    in PyTorch's cross entropy, targets are expected to be labels
    so to predict probabilities this loss is needed
    suppose q is the target and p is the input
    loss(p, q) = -\sum_i q_i \log p_i
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    loss = - torch.sum(input * target)
    return loss / input.size()[0] if size_average else loss




def train(model, train_loader):
    model.train()
    loss_all = 0
    graph_all = 0
    for data in train_loader:
        # print( "data.y", data.y )
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        y = data.y.view(-1, num_classes)
        loss = mixup_cross_entropy_loss(output, y)
        loss.backward()
        loss_all += loss.item() * data.num_graphs
        graph_all += data.num_graphs
        optimizer.step()
    loss = loss_all / graph_all
    return model, loss


def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    loss = 0
    for data in loader:
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.max(dim=1)[1]
        y = data.y.view(-1, num_classes)
        loss += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
        y = y.max(dim=1)[1]
        correct += pred.eq(y).sum().item()
        total += data.num_graphs
    acc = correct / total
    loss = loss / total
    return acc, loss

originaldataset=TUDataset(root="/home/sayan/g-mixup/data",name='REDDIT-BINARY')
dataset=list(originaldataset)
random.shuffle(dataset)
for graph in dataset:
        graph.y = graph.y.view(-1)

dataset = prepare_dataset_onehot_y(dataset)
dataset = prepare_dataset_x( dataset )
train_nums = int(len(dataset) * 0.8)
train_val_nums = int(len(dataset) * 0.9)

train_dataset = dataset[:train_nums]
val_dataset = dataset[train_nums:train_val_nums]
test_dataset = dataset[train_val_nums:]
batch_size=32
learning_rate=0.01
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
num_features = dataset[0].x.shape[1]
num_classes = dataset[0].y.shape[0] 

print("Num features",num_features)
print("num_classes",num_classes)
model = GCN(num_features=num_features, num_classes=num_classes, num_hidden=64).to(device)
    


optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)




In [None]:
for epoch in range(1, 2000):
    model, train_loss = train(model, train_loader)
    train_acc = 0
    val_acc, val_loss = test(model, val_loader)
    test_acc, test_loss = test(model, test_loader)
    scheduler.step()
    print("Train loss",train_loss, "Epoch",epoch)
    print("Test Accuracy",test_acc,"Test_loss",test_loss)
    tensorboard_writer.add_scalar('Train Loss', train_loss, epoch)
    tensorboard_writer.add_scalar('Validation Loss', val_loss, epoch)
    tensorboard_writer.add_scalar('Test Loss', test_loss, epoch)
    tensorboard_writer.add_scalar('Validation Accuracy', val_acc, epoch)
    tensorboard_writer.add_scalar('Test Accuracy', test_acc, epoch)

    logger.info('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f},  Val Acc: {: .6f}, Test Acc: {: .6f}'.format(
        epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

In [None]:
 save_path='/home/sayan/g-mixup/model/redditbinary.pth'
 torch.save(model.state_dict(), save_path)

In [None]:
load_path = '/home/sayan/g-mixup/model/redditbinary.pth'



# Initialize the model architecture

# Load the saved model weights
model.load_state_dict(torch.load(load_path))
model.eval()

In [None]:
dataset2=list(originaldataset)
random.shuffle(dataset2)
# for graph in dataset2:
#      graph.y = graph.y.view(-1)

# dataset = prepare_dataset_onehot_y(dataset2)
dataset2 = prepare_dataset_x( dataset2 )
num_features = dataset2[0].x.shape[1]
num_classes = dataset2[0].y.shape[0]
print(num_features)
print(num_classes)
#explain_loader= DataLoader(dataset2[:30], batch_size=1, shuffle=True)
newdataset=[]
model.to('cpu')
for data in dataset2: 
    #data=data.to(device)
    output = model(data.x, data.edge_index,None)
    #print("Output is",output)
    #print(output)
    #output=output.to("cpu")

    pred = output.max(dim=1)[1]
    #print(pred)
    if (pred==0):
        data.y=torch.zeros_like(data.y)
        #newdataset.append(data)
        
    if (pred==1):
        data.y=torch.ones_like(data.y)
   
    newdataset.append(data)
    #print("pred is",pred)
    #y = data.y.view(-1, num_classes)

In [None]:
print(newdataset)

In [None]:
print(newdataset[1].x)

In [20]:
def prepare_dataset_x(dataset):
    if dataset[0].x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            degs += [degree(data.edge_index[0], dtype=torch.long)]
            print("Type of degs is ", type(degs))
            print("Type of _max_deg, degs[-1].max().item() is ", type(max_degree), type(degs[-1].max().item()))
            max_degree = max( max_degree, torch.tensor(degs[-1]).max())
            print("Type of max degree is ", type(max_degree))
            data.num_nodes = int( torch.max(data.edge_index) ) + 1

        if max_degree < 2000:
            # dataset.transform = T.OneHotDegree(max_degree)

            for data in dataset:
                degs = degree(data.edge_index[0], dtype=torch.long)
                data.x = F.one_hot(degs, num_classes=max_degree+1).to(torch.float)
        else:
            deg = torch.cat(degs, dim=0).to(torch.float)
            mean, std = deg.mean().item(), deg.std().item()
            for data in dataset:
                degs = degree(data.edge_index[0], dtype=torch.long)
                data.x = ( (degs - mean) / std ).view( -1, 1 )
    return dataset

In [None]:

classgraphs=split_class_graphs(newdataset)

avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(newdataset)
resolution = int(median_num_nodes)
#print("resolution is",resolution)
graphons=[]
for label,graphs in classgraphs:
    #print("Label is",label)
    #print("graph is",graphs[0])
    align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
                    graphs, padding=True, N=resolution)
    #print("Aligned adj",align_graphs_list[8].shape,align_graphs_list[56].shape)
    graphon = universal_svd(align_graphs_list, threshold=0.2)
    #print("Graphon is ",graphon.shape)

    graphons.append((label, graphon))
#two_graphons = random.sample(graphons, 2)
print(graphons)
two_graphons= [graphons[0] , graphons[1]]
print(graphons[0][0], graphons[1][0])




# ng=two_graphons_mixup(two_graphons,la=1.0,num_sample=1)
# print(new_graph)
# print(ng)


In [15]:
edgeind=new_graph[1].edge_index[0]
print(degree(edgeind) , type(degree(edgeind)), type(degree(edgeind).max().item()))
edgeind2=dataset3[1].edge_index[0]
print(degree(edgeind2), type(degree(edgeind2)), type(degree(edgeind2).max().item()))

tensor([115.,  41.,  12.,  11.,   8.,   6.,   9.,   5.,   4.,   9.,   4.,   6.,
          5.,  10.,   5.,   6.,   5.,   4.,   3.,   7.,   6.,   4.,   3.,   4.,
          3.,   4.,   5.,   4.,   2.,   3.,   7.,   5.,   3.,   5.,   1.,   2.,
          2.,   3.,   3.,   3.,   4.,   1.,   2.,   2.,   4.,   6.,   2.,   2.,
          3.,   4.,   3.,   3.,   3.,   2.,   2.,   3.,   4.,   4.,   1.,   2.,
          5.,   5.,   2.,   3.,   1.,   1.,   2.,   2.,   1.,   2.,   1.,   2.,
          3.,   3.,   3.,   3.,   2.,   3.,   4.,   1.,   1.,   3.,   3.,   2.,
          3.,   1.,   2.,   2.,   1.,   2.,   1.,   5.,   2.,   2.,   2.,   3.,
          2.,   4.,   1.,   2.,   1.,   1.,   1.,   1.,   2.,   4.,   2.,   2.,
          2.,   2.,   3.,   4.,   3.,   3.,   2.,   3.,   2.,   3.,   2.,   4.,
          2.,   1.,   3.,   1.,   1.,   1.,   4.,   1.,   1.,   1.,   3.,   1.,
          2.,   4.,   1.,   1.,   1.,   3.,   1.,   1.,   1.,   2.,   1.,   3.,
          1.,   3.,   1.,   1.,   2.,   

In [21]:
#new_graph=[]

new_graph = two_graphons_mixup(two_graphons, la=1.0, num_sample=10)
print("Label of new graph is",new_graph[0].y)
dataset3=list(originaldataset)

newlist=list(originaldataset)+new_graph
print(len(newlist))
print(newlist[-1].x)
newlist=prepare_dataset_x(newlist)
explainergraph=newlist[len(dataset3):]
print(explainergraph[0].x)
#print(new_graph[0].x,new_graph[0].y,new_graph[0].edge_index)
max=0
for data in explainergraph:
    targetoutput=model(data.x,data.edge_index,None)
    print(targetoutput)
    soft=torch.nn.Softmax(dim=1)
    problitites=soft(targetoutput)

    
    targetpred = targetoutput.max(dim=1)[1]
    print("Label of explainer graph is",targetpred)
#newlist=[]



Label of new graph is tensor([0.])
2010
None
Type of degs is  <class 'list'>
Type of _max_deg, degs[-1].max().item() is  <class 'int'> <class 'int'>


  max_degree = max( max_degree, torch.tensor(degs[-1]).max())


TypeError: 'int' object is not callable

In [None]:
print(two_graphons)

In [None]:
from torch_geometric.data import Data
graph=Data()
label=np.array(23)
sample_graph_label = torch.from_numpy(label).type(torch.float32)
graph.y=sample_graph_label
print(graph)