### Load Data

In [20]:
import math
import os

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.loader import DenseDataLoader #To make use of this data loader, all graph attributes in the dataset need to have the same shape. In particular, this data loader should only be used when working with dense adjacency matrices.
from torch_geometric.nn import DenseGCNConv, dense_diff_pool
from torch_geometric.utils import dense_to_sparse

from f_visualization_functions import visualize_points

In [21]:
# Location of the extracted patch graphs. The default is the original author's local path;
# set the MT_DATA_DIR environment variable to run the notebook on another machine without
# editing this cell.
data_dir = os.environ.get('MT_DATA_DIR', 'c:/Users/david/MT_data/extracted_patches/mutant_graphs_diffpool/')

In [22]:
# PatchDataset is a project-local dataset class; it is built from the directory held in `data_dir`.
from c_PatchDataset_diffpool_pos import PatchDataset
dataset = PatchDataset(data_dir = data_dir)
# Number of graphs in the dataset (displayed as the cell output).
len(dataset)

1500

In [23]:
# Dataset-level summary.
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')

data = dataset[1]  # Index 1, i.e. the second graph in the dataset (not the first).
print()
print(data)
print('=============================================================')

# Gather some statistics about the graph selected above.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of node features: {data.num_node_features}')


Dataset: PatchDataset(1500):
Number of graphs: 1500

Data(x=[1300, 19], y=0, pos=[1300, 3], adj=[1300, 1300])
Number of nodes: 1300
Number of node features: 19


In [24]:
data.x

tensor([[-6.4592,  2.0831,  9.3907,  ...,  0.6941,  0.7079,  0.7162],
        [-6.1470,  3.2396,  9.5020,  ...,  0.6885,  0.7668,  0.7423],
        [-5.5948,  2.1197,  9.2291,  ...,  0.7070,  0.7731,  0.7089],
        ...,
        [ 0.9024, -5.8300,  6.2016,  ...,  0.5465,  0.4768,  0.4432],
        [-1.4027, 10.7846,  3.9749,  ...,  0.4924,  0.3807,  0.3163],
        [ 6.6276,  1.4644,  3.7423,  ...,  0.7026,  0.5827,  0.5153]])

In [25]:
torch.max(data.x[:,:3])

tensor(11.4650)

In [26]:
torch.min(data.x[:,:3])

tensor(-6.4592)

In [None]:
# Convert the dense adjacency matrix of the sample graph into an edge list, attach it to the
# graph object, and plot the graph at its node positions.
sparse_edges, _ = dense_to_sparse(data.adj)
data.edge_index = sparse_edges
visualize_points(data.pos, data.edge_index)

In [None]:
batch_size = 5
split_seed = 42  # fixed seed: the train/val/test partition is identical after every kernel restart

# 4/6 of the graphs go to training; the remainder is divided between validation and test.
n_train = math.ceil((4/6) * len(dataset))
n_val = math.ceil((len(dataset) - n_train)/2)
n_test = len(dataset) - n_train - n_val

# A dedicated generator keeps the split reproducible without touching the global RNG state.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(split_seed),
)
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

# Only the training loader is shuffled; evaluation accuracy does not depend on batch order.
train_loader = DenseDataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DenseDataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DenseDataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Pull a single mini-batch from the training loader for inspection.
databatch = next(iter(train_loader))
#databatch.w.dtype
# Element-wise product of the batched adjacency and `w` -- the same product train() passes to the model.
new_adj = databatch.adj * databatch.w
new_adj


### Define Network

In [None]:
class GNN(torch.nn.Module):
    """Three stacked ``DenseGCNConv`` layers, each paired with a ``BatchNorm1d`` layer.

    Args:
        in_nodes: Size given to every ``BatchNorm1d`` as its number of features.
            The batch norm is applied directly to the convolution output, so this is
            presumably the node count of the dense graphs -- TODO confirm.
        in_channels: Input feature size of the first convolution.
        hidden_channels: Output size of the first and second convolution.
        out_channels: Output size of the third convolution.
        normalize: Passed positionally as the third argument of each ``DenseGCNConv``.
            NOTE(review): in PyTorch Geometric that positional slot appears to be
            ``improved`` rather than a normalisation switch -- confirm against the
            installed version.
        lin: Accepted for call-site compatibility; never read inside this class.
    """

    def __init__(self, in_nodes, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super().__init__()

        # Layer 1: in_channels -> hidden_channels
        self.conv1 = DenseGCNConv(in_channels, hidden_channels, normalize)
        self.bns1 = torch.nn.BatchNorm1d(in_nodes)

        # Layer 2: hidden_channels -> hidden_channels
        self.conv2 = DenseGCNConv(hidden_channels, hidden_channels, normalize)
        self.bns2 = torch.nn.BatchNorm1d(in_nodes)

        # Layer 3: hidden_channels -> out_channels
        self.conv3 = DenseGCNConv(hidden_channels, out_channels, normalize)
        self.bns3 = torch.nn.BatchNorm1d(in_nodes)

    def forward(self, x, adj, mask=None):
        """Apply conv -> batch norm three times and return the result.

        No activation function is applied between the layers. The third batch norm is
        skipped when the last axis of the third convolution's output has size 1.
        """
        x = self.bns1(self.conv1(x, adj, mask))
        x = self.bns2(self.conv2(x, adj, mask))

        x = self.conv3(x, adj, mask)
        if x.shape[2] == 1:
            return x
        return self.bns3(x)


class DiffPool(torch.nn.Module):
    """Three-stage DiffPool graph classifier that also coarsens node coordinates.

    Every stage runs a pooling GNN (produces the cluster-assignment scores ``s``) and an
    embedding GNN (produces new node features) on the current graph, then calls
    ``dense_diff_pool`` to obtain the coarsened features and adjacency. The stages reduce
    the graph to 250, 125 and 60 clusters. The final features are averaged over dim 1 and
    passed through two linear layers to produce log-probabilities for 2 classes.

    Args:
        num_nodes: Number of nodes per input graph (used as ``in_nodes`` of the first-stage GNNs).
        in_channels: Number of input node features. When ``None`` (the default) the value
            is read from the notebook-global ``dataset.num_features``, which is what the
            class always did before this parameter existed.
    """

    # Position in the training loop at which the intermediate tensors of forward() are
    # written to 'img0_data.pt' ... 'img3_data.pt'. Evaluation passes batch=None / epoch=None
    # and therefore never triggers a snapshot.
    SNAPSHOT_EPOCH = 15
    SNAPSHOT_BATCH = 199

    def __init__(self, num_nodes, in_channels=None):
        super().__init__()

        if in_channels is None:
            # Backward-compatible fallback to the global dataset of the notebook.
            in_channels = dataset.num_features

        # Hierarchical Step #1
        in_nodes = num_nodes
        out_nodes = 250  # number of clusters / nodes in the next layer
        self.gnn1_pool = GNN(in_nodes, in_channels, 64, out_nodes)  # pooling GNN -> assignment scores
        self.gnn1_embed = GNN(in_nodes, in_channels, 16, 32)        # embedding GNN -> node features

        # Hierarchical Step #2
        in_nodes = out_nodes
        out_nodes = 125
        self.gnn2_pool = GNN(in_nodes, 32, 64, out_nodes)
        self.gnn2_embed = GNN(in_nodes, 32, 32, 64, lin=False)

        # Hierarchical Step #3
        in_nodes = out_nodes
        out_nodes = 60
        self.gnn3_pool = GNN(in_nodes, 64, 64, out_nodes)
        self.gnn3_embed = GNN(in_nodes, 64, 64, 128, lin=False)

        # Final classifier
        self.lin1 = torch.nn.Linear(128, 64)
        self.lin2 = torch.nn.Linear(64, 2)

    def _coarsen(self, pool_gnn, embed_gnn, x, adj, pos, mask=None, snapshot_path=None):
        """Run one pooling stage and return ``(x, adj, pos, l, e)`` for the coarsened graph.

        ``l`` and ``e`` are the third and fourth return values of ``dense_diff_pool``
        (its two auxiliary losses). When ``snapshot_path`` is given, the tuple
        ``(embedding output, pooled x, pooled pos, pooled adj, s)`` is saved there.
        """
        s = pool_gnn(x, adj, mask)       # cluster-assignment scores (not yet softmax-normalised)
        x_emb = embed_gnn(x, adj, mask)  # node feature embedding

        # Coordinates of the clusters: weights are softmax(softmax(s).T), i.e. a second
        # softmax taken over the node axis of the already-normalised assignments.
        # NOTE(review): the inner softmax outputs lie in [0, 1], so after the outer softmax
        # the weights differ by at most a factor of e -- confirm the double softmax is intended.
        pos_weights = torch.softmax(torch.softmax(s, dim=-1).transpose(1, 2), dim=-1)
        pos = torch.matmul(pos_weights, pos)

        x, adj, l, e = dense_diff_pool(x_emb, adj, s, mask)
        adj = torch.softmax(adj, dim=-1)  # row-normalise the coarsened adjacency

        if snapshot_path is not None:
            torch.save((x_emb.detach(), x.detach(), pos.detach(), adj.detach(), s.detach()),
                       snapshot_path)
        return x, adj, pos, l, e

    def forward(self, x, adj, pos, batch, epoch, mask=None):
        """Classify a batch of dense graphs.

        Args:
            x: Node feature tensor of the batch.
            adj: Dense adjacency tensor of the batch.
            pos: Node coordinate tensor of the batch.
            batch: Index of the current batch in the epoch; only used for the snapshot trigger.
            epoch: Current epoch; only used for the snapshot trigger.
            mask: Optional node mask, forwarded to the first stage only.

        Returns:
            Tuple ``(log_probs, l1 + l2 + l3, e1 + e2 + e3)``: the log-softmax class scores
            and the two auxiliary losses of ``dense_diff_pool`` summed over the three stages.
        """
        take_snapshot = epoch == self.SNAPSHOT_EPOCH and batch == self.SNAPSHOT_BATCH

        if take_snapshot:
            torch.save((x.detach(), pos.detach(), adj.detach()), 'img0_data.pt')

        # Hierarchical Step #1 (the only stage that receives the node mask)
        x, adj, pos, l1, e1 = self._coarsen(
            self.gnn1_pool, self.gnn1_embed, x, adj, pos, mask,
            snapshot_path='img1_data.pt' if take_snapshot else None)

        # Hierarchical Step #2
        x, adj, pos, l2, e2 = self._coarsen(
            self.gnn2_pool, self.gnn2_embed, x, adj, pos,
            snapshot_path='img2_data.pt' if take_snapshot else None)

        # Hierarchical Step #3
        x, adj, pos, l3, e3 = self._coarsen(
            self.gnn3_pool, self.gnn3_embed, x, adj, pos,
            snapshot_path='img3_data.pt' if take_snapshot else None)

        # Final classification
        x = x.mean(dim=1)         # global mean pool over the remaining clusters
        x = F.relu(self.lin1(x))  # fully connected layer + relu
        x = self.lin2(x)          # reduction to the number of classes

        log = F.log_softmax(x, dim=-1)

        return log, l1 + l2 + l3, e1 + e2 + e3

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1300 is passed as the node count of every input graph.
model = DiffPool(num_nodes=1300)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Current epoch number; forwarded to the model together with the batch index.

    Returns:
        The NLL loss summed over all training graphs (batch mean times batch size),
        divided by ``len(train_dataset)``.

    Note:
        The two auxiliary values returned by the model are discarded; only the
        classification loss is back-propagated.
    """
    model.train()
    running_loss = 0

    for batch, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        # Weight the adjacency matrix element-wise with `w` before it enters the model.
        data.adj = data.adj * data.w
        log_probs, _, _ = model(data.x, data.adj, data.pos, batch, epoch)

        loss = F.nll_loss(log_probs, data.y.view(-1))
        loss.backward()
        running_loss += data.y.size(0) * loss.item()
        optimizer.step()

    return running_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return the classification accuracy of ``model`` on every graph served by ``loader``.

    The adjacency matrix is multiplied element-wise by ``data.w`` before the forward pass,
    exactly as in ``train``; previously evaluation used the unweighted adjacency and so
    scored the model on inputs that differed from what it was trained on.

    Args:
        loader: Data loader whose batches expose ``x``, ``adj``, ``w``, ``pos`` and ``y``.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0

    for data in loader:
        data = data.to(device)
        weighted_adj = data.adj * data.w  # same input transformation as in train()
        # batch=None / epoch=None: evaluation never triggers the model's snapshot saving.
        log_probs = model(data.x, weighted_adj, data.pos, batch=None, epoch=None)[0]
        pred = log_probs.argmax(dim=1)
        correct += pred.eq(data.y.view(-1)).sum().item()
    return correct / len(loader.dataset)


best_val_acc = 0
test_acc = 0

# Epochs 1..15: train, then report the accuracy on all three splits after every epoch.
for epoch in range(1, 16):
    train_loss = train(epoch)
    train_acc, val_acc, test_acc = (
        test(split_loader) for split_loader in (train_loader, val_loader, test_loader)
    )
    # Disabled alternative: evaluate the test split only when validation accuracy improves
    # (and print more information):
    #if val_acc > best_val_acc:
    #    test_acc = test(test_loader)
    #    best_val_acc = val_acc

    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.3f} '
          f'Val Acc: {val_acc:.3f}, Test Acc: {test_acc:.3f}')
    


In [None]:
def _count_labels(graphs):
    """Return ``[n_other, n_label_1]``: how many graphs have ``y == 1`` and how many do not."""
    counts = [0, 0]
    for grph in graphs:
        if grph.y == 1:
            counts[1] += 1
        else:
            counts[0] += 1
    return counts


# Per-split label counts, each stored as [count of label != 1, count of label == 1].
train_fraction = _count_labels(train_dataset)
val_fraction = _count_labels(val_dataset)
test_fraction = _count_labels(test_dataset)

# The overall counts are the element-wise sum of the three splits.
dataset_fraction = [sum(per_label) for per_label in zip(train_fraction, val_fraction, test_fraction)]

# The printed values are fractions in [0, 1]; the first message no longer ends in a stray ')'.
print(f'Overall dataset percentage of label 1 = {dataset_fraction[1]/len(dataset)}')
print(f'Training dataset percentage of label 1 = {train_fraction} = {train_fraction[1]/len(train_dataset)}')
print(f'Validation dataset percentage of label 1 = {val_fraction} = {val_fraction[1]/len(val_dataset)}')
print(f'Test dataset percentage of label 1 = {test_fraction} = {test_fraction[1]/len(test_dataset)}')

### Input Graph: 

In [None]:
x0, pos0, adj0 = torch.load('img0_data.pt')

In [None]:
# Input feature matrix (as passed to the model, before any embedding or pooling)
print(x0[0].shape)
x0[0]

In [None]:
print(pos0[0].shape)
pos0[0]

In [None]:
print(adj0[0].shape)
adj0[0]

In [None]:
edge_index, _ = dense_to_sparse(adj0[0])
visualize_points(pos0[0], edge_index)

### Graph After 1st Reduction

In [None]:
x1_emb, x1_pool, pos1, adj1, s1= torch.load('img1_data.pt')

In [None]:
# Output of Embedding GNN (adj0 @ x_0 @ w_gnn_emb)
print(x1_emb[0].shape)
x1_emb[0]

In [None]:
# Output of Pooling GNN: adj_0 @ x_0 @ w_gnn_pool
print(s1[0].shape)
s1[0]

In [None]:
# Output Coordinate Matrix (pos_out = softmax(softmax(s).t()) @ pos_in)
print(pos1[0].shape)
pos1[0]

In [None]:
# Output Feature Matrix (x_out = softmax(s).t() @ x_in)
print(x1_pool[0].shape)
x1_pool[0]

In [None]:
# Output Adjacency Matrix (adj_out = softmax(softmax(s).t() @ adj_in @ softmax(s)))
print(adj1[0].shape)
adj1[0]

In [None]:
edge_index, _ = dense_to_sparse(adj1[0])
visualize_points(pos1[0], edge_index)

### Graph after 2nd reduction

In [None]:
x2_emb, x2_pool, pos2, adj2, s2 = torch.load('img2_data.pt')

In [None]:
# Output of Embedding GNN (adj1 @ x1_pool @ w_gnn_emb)
print(x2_emb[0].shape)
x2_emb[0]

In [None]:
# Output of Pooling GNN (adj1 @ x1_pool @ w_gnn_pool), saved before any softmax is applied
print(s2[0].shape)
s2[0]

In [None]:
# Output Coordinate Matrix (pos_out = softmax(softmax(s).t()) @ pos_in)
print(pos2[0].shape)
pos2[0]

In [None]:
# Output Feature Matrix (x_out = softmax(s2).t() @ x2_emb)
print(x2_pool[0].shape)
x2_pool[0]

In [None]:
# Output Adjacency Matrix (adj_out = softmax(softmax(s).t() @ adj_in @ softmax(s)))
print(adj2[0].shape)
adj2[0]

In [None]:
edge_index, _ = dense_to_sparse(adj2[0])
visualize_points(pos2[0], edge_index)

### Graph after 3rd reduction

In [None]:
x3_emb, x3_pool, pos3, adj3, s3 = torch.load('img3_data.pt')

In [None]:
# Output of Embedding GNN (adj2 @ x2_pool @ w_gnn_emb)
print(x3_emb[0].shape)
x3_emb[0]

In [None]:
# Output of Pooling GNN (adj2 @ x2_pool @ w_gnn_pool), saved before any softmax is applied
print(s3[0].shape)
s3[0]

In [None]:
# Output Coordinate Matrix (pos_out = softmax(softmax(s).t()) @ pos_in)
print(pos3[0].shape)
pos3[0]

In [None]:
# Output Feature Matrix (x_out = softmax(s3).t() @ x3_emb)
print(x3_pool[0].shape)
x3_pool[0]

In [None]:
# Output Adjacency Matrix (adj_out = softmax(softmax(s).t() @ adj_in @ softmax(s)))
print(adj3[0].shape)
adj3[0]

In [None]:
edge_index, _ = dense_to_sparse(adj3[0])
visualize_points(pos3[0], edge_index)