In [13]:
import torch

### Load Data

In [14]:
import os
from math import ceil

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
#from torch_geometric.data import DenseDataLoader #To make use of this data loader, all graph attributes in the dataset need to have the same shape. In particular, this data loader should only be used when working with dense adjacency matrices.
from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool

In [15]:
data_dir = 'c:/Users/david/MT_data/extracted_patches/mutant_graphs_diffpool/'

In [16]:
from c_PatchDataset_diffpool import PatchDataset
dataset = PatchDataset(data_dir = data_dir)
len(dataset)

100

In [17]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')

data = dataset[0]  # Get the first graph object.
print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of node features: {data.num_node_features}')
#print(f'Number of edges: {data.num_edges}')
#print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
#print(f'Contains self-loops: {data.has_self_loops()}')


Dataset: PatchDataset(100):
Number of graphs: 100

Data(x=[1300, 16], y=0, pos=[1300, 3], adj=[1300, 1300], w=[1300, 1300], mutant='AAAP')
Number of nodes: 1300
Number of node features: 16


In [18]:
data.adj

tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [19]:
from torch_geometric.loader import DenseDataLoader

#max_nodes = 150
#class MyFilter(object):
#    def __call__(self, data):
#        return data.num_nodes <= max_nodes
#dataset = TUDataset('data', name='PROTEINS', transform=T.ToDense(max_nodes),
#                    pre_filter=MyFilter())
#dataset = dataset.shuffle()
#n = (len(dataset) + 9) // 10
batch_size = 2 # cannot change this at the moment

n_train = ceil((4/6) * len(dataset))
n_val = ceil((len(dataset) - n_train)/2)
n_test = len(dataset) - n_train - n_val

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_val, n_test])
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
#test_dataset = dataset[:n]
#val_dataset = dataset[n:2 * n]
#train_dataset = dataset[2 * n:]

train_loader = DenseDataLoader(dataset = train_dataset, batch_size= batch_size)
val_loader = DenseDataLoader(dataset = val_dataset, batch_size= batch_size)
test_loader = DenseDataLoader(dataset = test_dataset, batch_size= batch_size)


Number of training graphs: 67
Number of validation graphs: 17
Number of test graphs: 16


In [20]:
dataset.num_features

16

In [21]:
train_dataset[0]

Data(x=[1300, 16], y=1, pos=[1300, 3], adj=[1300, 1300], w=[1300, 1300], mutant='AYCA')

In [22]:
train_dataset[1]

Data(x=[1300, 16], y=1, pos=[1300, 3], adj=[1300, 1300], w=[1300, 1300], mutant='VSAA')

### Define Network

In [23]:
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super(GNN, self).__init__()
        
        self.convs = torch.nn.ModuleList()

        # Each instance of this GNN will have 3 convolutional layers and three batch norm layers        
        self.convs.append(GCNConv(in_channels, hidden_channels, normalize))
        self.convs.append(GCNConv(hidden_channels, hidden_channels, normalize))
        self.convs.append(GCNConv(hidden_channels, out_channels, normalize))


    def forward(self, x, adj, mask=None):
        
        #Feed the feature matrix and the adjacency matrix through the 3 conv and 3 bns layers
        for step in range(len(self.convs)):
            #x = self.bns[step](F.relu(self.convs[step](x, adj, mask)))
            x = F.relu(self.convs[step](x, adj, mask))
        return x


class DiffPool(torch.nn.Module):
    def __init__(self):
        super(DiffPool, self).__init__()

        #PoolGNN
        num_nodes = [1095, 200, 100, 25, 1]
        self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes[1]) # --> S1 num_nodes[0] x num_nodes[1] (1095x200)
        self.gnn2_pool = GNN(32, 64, num_nodes[2])                   # --> S2 num_nodes[1] x num_nodes[2] (200x100)
        self.gnn3_pool = GNN(64, 32, num_nodes[3])                   # --> S3 num_nodes[2] x num_nodes[3] (100x25)
        self.gnn4_pool = GNN(128, 64, num_nodes[4])                   # --> S4 num_nodes[3] x num_nodes[4] (25x1)

        #EmbedGNN
        self.num_features = [dataset.num_features, 32, 64, 128, 256]
        self.gnn1_embed = GNN(self.num_features[0], self.num_features[0], self.num_features[1])             # --> X1 num_nodes[0] x num_features[1] (1000x32)   --matmul--> (200x32)
        self.gnn2_embed = GNN(self.num_features[1], self.num_features[1], self.num_features[2], lin=False)  # --> X1 num_nodes[1] x num_features[2] (200x64)    --matmul--> (100x64)
        self.gnn3_embed = GNN(self.num_features[2], self.num_features[2], self.num_features[3], lin=False)  # --> X1 num_nodes[2] x num_features[3] (100x128)   --matmul--> (25x128)
        self.gnn4_embed = GNN(self.num_features[3], self.num_features[3], self.num_features[4], lin=False)  # --> X1 num_nodes[3] x num_features[4] (25x256)    --matmul--> (1x256)
        self.gnn5_embed = GNN(self.num_features[4], self.num_features[4], self.num_features[4], lin=False)  # 

        # Final Classifier
        self.lin1 = torch.nn.Linear(256, 64) 
        self.lin2 = torch.nn.Linear(64, 2)



    def forward(self, x, adj, mask=None):
        #print()
        #print('=======================')
        #print(f'input X: {x.shape}')
        #print(f'input adj: {adj.shape}')

        #Hierarchical Step #1
        #print()
        #print('First hierarchical step:')
        s = self.gnn1_pool(x, adj, mask) # cluster assignment matrix
        #print(f'S1 = {s.shape}')
        x = self.gnn1_embed(x, adj, mask) # node feature embedding
        #print(f'X1 = {x.shape}')
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask) # does the necessary matrix multiplications
        #print(f'X Output 1 = {x.shape}')
        #print(f'Adj Output 1 = {adj.shape}')

        #Hierarchical Step #2
        #print()
        #print('Second hierarchical step:')
        s = self.gnn2_pool(x, adj, mask) # cluster assignment matrix
        #print(f'S2 = {s.shape}')
        x = self.gnn2_embed(x, adj, mask) # node feature embedding
        #print(f'X2 = {x.shape}')
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask) # does the necessary matrix multiplications
        #print(f'X Output 2 = {x.shape}')
        #print(f'Adj Output 2 = {adj.shape}')
        
        # Hierarchical Step #3
        #print()
        #print('Third hierarchical step:')
        s = self.gnn3_pool(x, adj)
        #print(f'S3: {s.shape}')
        x = self.gnn3_embed(x, adj)
        #print(f'X3: {x.shape}')
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)
        #print(f'X Output 3 = {x.shape}')
        #print(f'Adj Output 3 = {adj.shape}')

        # Hierarchical Step #4
        #print()
        #print('Fourth hierarchical step:')
        s = self.gnn4_pool(x, adj)
        #print(f'S4: {s.shape}')
        x = self.gnn4_embed(x, adj)
        #print(f'X4: {x.shape}')
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)
        #print(f'X Output 4 = {x.shape}')
        #print(f'Adj Output 4 = {adj.shape}')

        # Final Classification
        #print()
        #print('Final Classification')
        x = x.mean(dim=1) # Pool the features of all nodes (global mean pool)  dim = 1 refers to columns
        #print(f'X Output after mean= {x.shape}')

        x = F.relu(self.lin1(x)) # Fully connected layer + relu
        #print(f'X Output after lin1= {x.shape}')

        x = self.lin2(x) # Reduction to num_classes
        #print(f'X Output after lin2= {x.shape}')

        not_log = F.softmax(x, dim=-1)
        log = F.log_softmax(x, dim=-1)
        #print(f'Softmax = {not_log}')
        #print(f'LogSoftmax = {log}')

        return log, l1 + l2, e1 + e2

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DiffPool().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

node_lengths = []

def train(epoch):
    model.train()
    loss_all = 0

    for data in train_loader:
        node_lengths.append(data.x.shape[0])
        data = data.to(device)
        optimizer.zero_grad()
        print(data.x.shape)
        print(data.adj.shape)
        output, _, _ = model(data.x, data.adj) #data.mask
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        loss_all += data.y.size(0) * loss.item()
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0

    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.adj)[0].max(dim=1)[1] #, data.mask
        correct += pred.eq(data.y.view(-1)).sum().item()
    return correct / len(loader.dataset)


best_val_acc = test_acc = 0
for epoch in range(1, 2):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    if val_acc > best_val_acc:
        test_acc = test(test_loader)
        best_val_acc = val_acc
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
torch.Size([2, 1300, 1300])
torch.Size([2, 1300, 16])
to

In [25]:
#import matplotlib.pyplot as plt
#len(node_lengths)
#plt.hist(node_lengths, bins = 20)

In [26]:
#min(node_lengths)

In [27]:
#max(node_lengths)