# CNA | Final Project | P2
### Mohsen Ebadpour | 400131080 | m.ebadpour@aut.ac.ir

In [2]:
import torch 
from torch import nn
from torch import optim
from torch.utils.data import DataLoader,Dataset
from torch.nn import functional as F

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd 
from sklearn.metrics import confusion_matrix,accuracy_score
from tqdm import tqdm


import torch_geometric 
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric import transforms as T
from torch_geometric.nn import GCNConv,Linear,GATConv,GATv2Conv,SAGEConv, GATConv,ChebConv
from torch.utils.data import random_split
from torch_geometric.nn import GraphConv, TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
# Convolution classes that can be used as the pooling score layer, mapped to the
# label used in model names and report lines. Insertion order is preserved and
# is the order in which TrainModels iterates over the variants.
LAYERS = dict([
    (GCNConv, "GCNConv"),
    (GATConv, "GATConv"),
    (SAGEConv, "SAGEConv"),
    (ChebConv, "ChebConv"),
])



  from .autonotebook import tqdm as notebook_tqdm


# Report Dataset

In [3]:
# Seed the global torch RNG. torch.manual_seed seeds the CPU generator and all
# CUDA devices; the previous torch.cuda.manual_seed call only covered CUDA, so
# anything drawn from the CPU generator (models here are constructed before
# being moved to the device) was left unseeded.
torch.manual_seed(777)
def ReportDataset(name="ENZYMES"):
    """Load a TUDataset by name and collect summary statistics about it.

    Args:
        name: Dataset name passed to ``TUDataset``; the dataset root is ``./<name>``.

    Returns:
        A ``(summary, dataset)`` tuple. ``summary`` is a dict holding the node
        and edge feature counts, number of classes, number of graphs, and the
        mean number of edges and nodes per graph (rounded to 2 decimals).
    """
    graphs = TUDataset(name=name, root="./{0}".format(name))
    graph_count = len(graphs)

    # Accumulate node/edge totals over every graph in the dataset.
    total_nodes = 0
    total_edges = 0
    for graph in graphs:
        total_nodes = total_nodes + graph.num_nodes
        total_edges = total_edges + graph.num_edges

    summary = {
        "Node`s Featrue": graphs.num_node_features,
        "Edge`s Feature": graphs.num_edge_features,
        "Classes": graphs.num_classes,
        "No. graphs": graph_count,
        "Mean Edges No.": round(total_edges / graph_count, 2),
        "Mean Nodes No.": round(total_nodes / graph_count, 2),
    }
    return summary, graphs
    
    
# Load the dataset used by every experiment below and display its summary row.
name = "MUTAG"
Names = [name]
data, dataset = ReportDataset(name)
pd.DataFrame([data], index=Names)

Unnamed: 0,Node`s Featrue,Edge`s Feature,Classes,No. graphs,Mean Edges No.,Mean Nodes No.
MUTAG,7,4,2,188,39.59,17.93


# Dataset Split

In [4]:
def GetSets(dataset,train=0.8,valid=0.1):
    """Randomly split ``dataset`` into training, validation and test subsets.

    Args:
        dataset: Any sized, indexable dataset accepted by ``random_split``.
        train: Fraction of samples for the training subset (truncated to int).
        valid: Fraction of samples for the validation subset (truncated to int).

    Returns:
        ``(training_set, validation_set, test_set)``; the test subset receives
        every sample not assigned to the other two.
    """
    total = len(dataset)
    n_train = int(total * train)
    n_valid = int(total * valid)
    n_test = total - (n_train + n_valid)
    parts = random_split(dataset, [n_train, n_valid, n_test])
    return parts[0], parts[1], parts[2]

# 80% / 10% / remainder split, then one loader per subset. Only the training
# loader shuffles; the test loader yields one graph per batch.
BatchSize = 128
TrainSet, ValidationSet, TestSet = GetSets(dataset, train=0.8, valid=0.1)
TrainLoader = DataLoader(TrainSet, shuffle=True, batch_size=BatchSize)
ValidationLoader = DataLoader(ValidationSet, shuffle=False, batch_size=BatchSize)
TestLoader = DataLoader(TestSet, shuffle=False, batch_size=1)

# Train Function

In [5]:
def TestPerformance(model,loader):
    """Evaluate ``model`` on every batch yielded by ``loader``.

    Args:
        model: Module called as ``model(batch)`` that returns per-class scores
            of shape (num_samples, num_classes).
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``; must
            also expose ``.dataset`` so results can be normalised by its length.

    Returns:
        ``(accuracy, mean_loss)``: fraction of correct predictions and the mean
        per-sample cross-entropy over ``loader.dataset``.

    Side effects:
        Leaves ``model`` in eval mode and moved to the evaluation device.
    """
    # Fall back to CPU when no GPU is present instead of failing on "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    correct = 0.
    loss = 0.
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()
            # Sum (not batch-mean) so that dividing by the dataset size below
            # yields a true per-sample mean regardless of the batch size.
            loss += F.cross_entropy(out, data.y, reduction="sum").item()
    n_samples = len(loader.dataset)
    return correct / n_samples, loss / n_samples


def Train(CLASS,args,TrainLoader,ValidationLoader,TestLoader,epoch:int,lr=0.01,
          path="",weight_decay=5e-4,show=True,dataset_name="MUTAG",coment=""):
    """Train a fresh ``CLASS(**args)`` model with early stopping and report test accuracy.

    The model with the lowest validation loss is checkpointed to ``latest.pth``
    and reloaded for the final test evaluation.

    Args:
        CLASS: Model class, instantiated as ``CLASS(**args)``. The instance must
            expose ``.name`` when ``show`` is true (used in plot titles/filename).
        args: Keyword arguments for ``CLASS``.
        TrainLoader, ValidationLoader, TestLoader: Batch loaders for each split.
        epoch: Maximum number of training epochs.
        lr: Adam learning rate.
        path: Directory prefix for the saved figure (only used when ``show``).
        weight_decay: Adam weight decay.
        show: When true, plot the loss/accuracy curves and save them as a JPG.
        dataset_name: Label used in the plot title and file name.
        coment: Free-text suffix for the figure file name.

    Returns:
        Test accuracy of the selected model, as a percentage rounded to 2 decimals.
    """
    # Fall back to CPU when no GPU is present instead of failing on "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = CLASS(**args).to(device)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    loss_train, acc_train = [], []
    loss_val, acc_val = [], []
    # Per-epoch test accuracy is recorded for inspection; it is not plotted or
    # used for model selection.
    acc_test = []

    min_loss = 1e10
    patience = 0
    saved_checkpoint = False
    for _ in range(epoch):
        model.train()
        for data in TrainLoader:
            data = data.to(device)
            opt.zero_grad()
            out = model(data)
            loss = F.cross_entropy(out, data.y)
            loss.backward()
            opt.step()

        val_acc,val_loss = TestPerformance(model,ValidationLoader)
        train_acc,train_loss = TestPerformance(model,TrainLoader)
        test_acc,_ = TestPerformance(model,TestLoader)

        acc_test.append(test_acc)

        acc_val.append(val_acc)
        # Previously this appended val_acc, so the "Validation loss" curve
        # actually showed accuracy.
        loss_val.append(val_loss)

        acc_train.append(train_acc)
        loss_train.append(train_loss)

        # Early stopping on validation loss: keep the best weights on disk and
        # give up after more than 50 consecutive epochs without improvement.
        if val_loss < min_loss:
            torch.save(model.state_dict(),'latest.pth')
            saved_checkpoint = True
            min_loss = val_loss
            patience = 0
        else:
            patience += 1
        if patience > 50:
            break

    # Restore the best checkpoint of THIS run. If none was written (zero epochs
    # or a NaN validation loss), keep the in-memory model rather than loading a
    # stale latest.pth left behind by an earlier run.
    if saved_checkpoint:
        model = CLASS(**args)
        model.load_state_dict(torch.load('latest.pth', map_location=device))
        model = model.to(device)
    test_acc,_ = TestPerformance(model,TestLoader)

    if show:
        sns.set_style("whitegrid")
        plt.rcParams['figure.figsize']= (14,5)
        h,w = 1,2
        plt.subplot(h,w,1)
        plt.plot(loss_train,label="Train loss")
        plt.plot(loss_val,label="Validation loss")
        plt.title("Loss Report | {0} | {1}".format(model.name,dataset_name))
        plt.xlabel("Epoch")
        plt.ylabel("Cross Entropy Loss")
        plt.legend()

        plt.subplot(h,w,2)
        plt.plot(acc_train,label="Train Accuracy")
        plt.plot(acc_val,label="Validation Accuracy")
        plt.title("Accuracy Report | Test Accuracy: {0}%".format(round(test_acc*100,2)))
        plt.xlabel("Epoch")
        plt.legend()

        plt.tight_layout()
        # NOTE: with the default path="" this resolves to a file directly under
        # "/"; pass an existing directory when show=True.
        plt.savefig("{3}/{1} | {0} | {2}.jpg".format(model.name,dataset_name,coment,path))
        plt.clf()

    return round(test_acc*100,2)
    
    

In [6]:
class SAGPool(torch.nn.Module):
    """Pooling layer that scores nodes with a graph convolution and keeps the
    nodes selected by ``topk(score, ratio, batch)``.

    Args:
        in_channels: Number of input node features.
        ratio: Pooling ratio forwarded to ``topk``.
        non_linearity: Function applied to the kept scores before they gate
            the kept node features.
        **karg: Must contain ``"Conv"`` (the convolution class used as score
            layer). Every other entry is forwarded to that class's constructor
            (e.g. ``out_channels``, ``heads``, ``K``).
    """
    def __init__(self,in_channels,ratio=0.5,non_linearity=torch.tanh,**karg):
        super(SAGPool,self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        conv_cls = karg["Conv"]
        # Do not forward the "Conv" entry itself: it selects the class and is
        # not a constructor argument of the convolution.
        conv_kwargs = {key: value for key, value in karg.items() if key != "Conv"}
        self.score_layer = conv_cls(in_channels, **conv_kwargs)
        self.non_linearity = non_linearity

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Pool the graph(s) described by ``x`` / ``edge_index``.

        Returns:
            ``(x, edge_index, edge_attr, batch, perm)`` restricted to the kept
            nodes, where ``perm`` holds the indices of the kept nodes.
        """
        if batch is None:
            # No batch vector given: treat all nodes as one graph.
            batch = edge_index.new_zeros(x.size(0))
        # Drop only the trailing channel dimension; a bare squeeze() would also
        # collapse the node dimension when a single node is left.
        score = self.score_layer(x,edge_index).squeeze(-1)
        perm = topk(score, self.ratio, batch)

        # Gate the surviving node features by their (squashed) score.
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

In [7]:
class SAGPoolNet(torch.nn.Module):
    """Graph classifier built from three GCN layers and SAGPool pooling.

    Two layouts are supported:

    * hierarchical: conv -> pool -> readout, three times; the three readouts
      are summed (optionally weighted by learned scalars).
    * global: three stacked convs whose outputs are concatenated, pooled once
      and then read out.

    The readout is the concatenation of global max and global mean pooling.
    ``forward`` returns log-probabilities (``log_softmax`` over classes).

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        is_hierarchical: Select the hierarchical (True) or global (False) layout.
        pooling_ratio: Ratio passed to every SAGPool layer.
        p_dropout: Dropout probability after the first linear layer.
        hidden_features: Width of the convolution layers.
        use_w_for_concat: Scale the three branch outputs by learned scalars.
        **karg: Forwarded to SAGPool; must contain ``"Conv"`` (a key of LAYERS).
    """
    def __init__(self,dataset,is_hierarchical=True,pooling_ratio=0.5,p_dropout=0.5,hidden_features=128,use_w_for_concat=False,**karg):
        super().__init__()
        from torch.nn.init import xavier_uniform_

        self.num_features = dataset.num_features
        self.hidden_features = hidden_features
        self.num_classes = dataset.num_classes
        self.pooling_ratio = pooling_ratio
        self.p = p_dropout

        # Human-readable identifier used in plot titles and file names.
        self.name = "".join([
            "SAGPool | ",
            "Hierarchical | " if is_hierarchical else "Global | ",
            "Pooling ratio: {0} | ".format(str(pooling_ratio)),
            "Weighted Concat | " if use_w_for_concat else "Simple Concat | ",
            LAYERS[karg["Conv"]],
        ])

        self.is_hierarchical = is_hierarchical
        self.use_w_for_concat = use_w_for_concat

        if self.use_w_for_concat:
            # One learnable scalar per branch, Xavier-initialised.
            self._att = nn.Parameter(torch.Tensor(3, 1))
            xavier_uniform_(self._att)

        hidden = self.hidden_features
        ratio = self.pooling_ratio
        if is_hierarchical:
            self.lin1 = torch.nn.Linear(hidden * 2, hidden)
            self.conv1 = GCNConv(self.num_features, hidden)
            self.pool1 = SAGPool(hidden, ratio=ratio, **karg)
            self.conv2 = GCNConv(hidden, hidden)
            self.pool2 = SAGPool(hidden, ratio=ratio, **karg)
            self.conv3 = GCNConv(hidden, hidden)
            self.pool3 = SAGPool(hidden, ratio=ratio, **karg)
        else:
            self.conv1 = GCNConv(self.num_features, hidden)
            self.conv2 = GCNConv(hidden, hidden)
            self.conv3 = GCNConv(hidden, hidden)
            # The single pooling layer sees the three conv outputs side by side.
            self.pool = SAGPool(hidden * 3, ratio=ratio, **karg)
            self.lin1 = torch.nn.Linear(hidden * 2 * 3, hidden)

        self.lin2 = torch.nn.Linear(hidden, hidden // 2)
        self.lin3 = torch.nn.Linear(hidden // 2, self.num_classes)

    @staticmethod
    def _readout(x, batch):
        """Concatenate global max pooling and global mean pooling per graph."""
        return torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.is_hierarchical:
            stages = (
                (self.conv1, self.pool1),
                (self.conv2, self.pool2),
                (self.conv3, self.pool3),
            )
            readouts = []
            for conv, pool in stages:
                x = F.relu(conv(x, edge_index))
                x, edge_index, _, batch, _ = pool(x, edge_index, None, batch)
                readouts.append(self._readout(x, batch))
            x1, x2, x3 = readouts

            if self.use_w_for_concat:
                x = self._att[0][0]*x1 + self._att[1][0]*x2 + self._att[2][0]*x3
            else:
                x = x1 + x2 + x3
        else:
            x1 = F.relu(self.conv1(x, edge_index))
            x2 = F.relu(self.conv2(x1, edge_index))
            x3 = F.relu(self.conv3(x2, edge_index))

            if self.use_w_for_concat:
                x1 = self._att[0][0] * x1
                x2 = self._att[1][0] * x2
                x3 = self._att[2][0] * x3

            x = torch.cat([x1, x2, x3], dim=1)
            x, edge_index, _, batch, _ = self.pool(x, edge_index, None, batch)
            x = self._readout(x, batch)

        # Classification head shared by both layouts.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.p, training=self.training)
        x = F.relu(self.lin2(x))
        return F.log_softmax(self.lin3(x), dim=-1)

In [8]:
# Baseline keyword arguments for SAGPoolNet. Keys that SAGPoolNet does not name
# explicitly ("out_channels", "Conv") travel through **karg to the SAGPool
# score layer.
MAINargs = dict(
    dataset=dataset,
    out_channels=1,
    is_hierarchical=True,
    use_w_for_concat=False,
    pooling_ratio=0.5,
    p_dropout=0.5,
    Conv=GCNConv,
)


def TrainModels(glo,args):
    """Train SAGPoolNet 15 times for every score-layer type in LAYERS.

    For each layer type the 15 test accuracies are summarised (mean, std, min,
    max), printed, and saved as a ``.npy`` file under ``./P2/``.

    Args:
        glo: Architecture selector; any string containing ``"H"`` selects the
            hierarchical layout, anything else the global layout.
        args: Base keyword arguments for SAGPoolNet; copied per layer type, so
            the caller's dict is not modified.

    Returns:
        The summary lines of all layer types, newline-terminated, as one string.
    """
    import os  # local import: only needed to make sure the output directory exists

    print("========\n")
    CMD = ""
    path = "./P2/"
    # Create the output directory up front so the results of the training runs
    # are not lost to a missing-directory error when they are saved.
    os.makedirs(path, exist_ok=True)

    for layer in LAYERS:
        _args = args.copy()
        _args["Conv"] = layer
        _args["is_hierarchical"] = "H" in glo

        # Layer-specific constructor arguments for the score layer.
        if LAYERS[layer] == "GATConv":
            _args["heads"] = 6
            _args["concat"] = False

        if LAYERS[layer] == "ChebConv":
            _args["K"] = 2

        ACC = []
        for index in tqdm(range(15)):
            cmnt = str(index)
            acc = Train(SAGPoolNet,_args,TrainLoader=TrainLoader,ValidationLoader=ValidationLoader,TestLoader=TestLoader,
                epoch=500,lr=0.0005,weight_decay=0.0001,dataset_name=name,path=path,show=False,coment=cmnt)
            ACC.append(acc)

        is_hie =  "Hierarchical"  if _args["is_hierarchical"] else "Global"
        cmd = "Mean: {0}, STD:{1}, MIN:{4}, MAX:{5}, {2}, {3}".format(round(np.mean(ACC),2),round(np.std(ACC),2),is_hie,LAYERS[layer],round(np.min(ACC),2),round(np.max(ACC),2))
        print(cmd)
        with open(path+"{0} | is_hierarchical {1} | use_w_for_concat {2}.npy".format(cmd,_args["is_hierarchical"],_args["use_w_for_concat"]),"wb") as f :
            np.save(f,np.array(ACC))

        CMD += cmd + "\n"
    return CMD

# Accumulates the summary text returned by each TrainModels call below.
OUT = ""


In [9]:
# Hierarchical layout with the baseline arguments (use_w_for_concat=False).
OUT += TrainModels("H",MAINargs)




100%|███████████████████████████████████████████| 15/15 [05:44<00:00, 23.00s/it]


Mean: 72.0, STD:4.4, MIN:65.0, MAX:80.0, Hierarchical, GCNConv


100%|███████████████████████████████████████████| 15/15 [03:45<00:00, 15.06s/it]


Mean: 70.67, STD:3.59, MIN:65.0, MAX:80.0, Hierarchical, GATConv


100%|███████████████████████████████████████████| 15/15 [04:52<00:00, 19.47s/it]


Mean: 73.33, STD:5.06, MIN:70.0, MAX:85.0, Hierarchical, SAGEConv


100%|███████████████████████████████████████████| 15/15 [05:43<00:00, 22.92s/it]

Mean: 73.33, STD:2.98, MIN:70.0, MAX:80.0, Hierarchical, ChebConv





In [10]:
# Global layout with the baseline arguments (use_w_for_concat=False).
OUT += TrainModels("G",MAINargs)




100%|███████████████████████████████████████████| 15/15 [03:24<00:00, 13.65s/it]


Mean: 73.67, STD:5.62, MIN:65.0, MAX:80.0, Global, GCNConv


100%|███████████████████████████████████████████| 15/15 [03:32<00:00, 14.17s/it]


Mean: 72.67, STD:6.02, MIN:65.0, MAX:80.0, Global, GATConv


100%|███████████████████████████████████████████| 15/15 [02:41<00:00, 10.77s/it]


Mean: 74.0, STD:5.54, MIN:65.0, MAX:80.0, Global, SAGEConv


100%|███████████████████████████████████████████| 15/15 [03:47<00:00, 15.18s/it]

Mean: 74.33, STD:4.78, MIN:65.0, MAX:80.0, Global, ChebConv





In [11]:
# Same experiment with learned branch weights enabled; hierarchical layout.
# `tmp` is a copy, so MAINargs itself is left unchanged.
tmp = MAINargs.copy()
tmp["use_w_for_concat"] = True
OUT += TrainModels("H",tmp)





100%|███████████████████████████████████████████| 15/15 [05:59<00:00, 23.99s/it]


Mean: 71.67, STD:5.06, MIN:65.0, MAX:85.0, Hierarchical, GCNConv


100%|███████████████████████████████████████████| 15/15 [04:48<00:00, 19.21s/it]


Mean: 71.0, STD:2.0, MIN:70.0, MAX:75.0, Hierarchical, GATConv


100%|███████████████████████████████████████████| 15/15 [04:52<00:00, 19.50s/it]


Mean: 71.33, STD:3.4, MIN:65.0, MAX:80.0, Hierarchical, SAGEConv


100%|███████████████████████████████████████████| 15/15 [05:24<00:00, 21.60s/it]

Mean: 70.0, STD:3.16, MIN:65.0, MAX:75.0, Hierarchical, ChebConv





In [12]:
# Global layout with learned branch weights (reuses `tmp` from the previous cell).
OUT += TrainModels("G",tmp)




100%|███████████████████████████████████████████| 15/15 [03:28<00:00, 13.90s/it]


Mean: 72.33, STD:4.78, MIN:65.0, MAX:80.0, Global, GCNConv


100%|███████████████████████████████████████████| 15/15 [03:46<00:00, 15.11s/it]


Mean: 70.0, STD:3.16, MIN:65.0, MAX:75.0, Global, GATConv


100%|███████████████████████████████████████████| 15/15 [03:08<00:00, 12.58s/it]


Mean: 71.33, STD:4.99, MIN:65.0, MAX:80.0, Global, SAGEConv


100%|███████████████████████████████████████████| 15/15 [03:48<00:00, 15.26s/it]

Mean: 70.67, STD:5.12, MIN:65.0, MAX:80.0, Global, ChebConv



