In [1]:
import torch.optim as optim
from torch_geometric.datasets import Planetoid
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import DataLoader
import random

torch.set_printoptions(edgeitems=500)

# seed for reproducibility
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

import warnings
warnings.filterwarnings('ignore')

### Load dataset

### Dataset info:
class 0: without autism associations
class 1: autism genes

In [2]:
import read_data

data = read_data.read()

In [3]:
data

Data(x=[23472, 23472], edge_index=[2, 811236], y=[23472], train_mask=[23472], test_mask=[23472], val_mask=[23472], num_classes=23)

In [4]:
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
# dataset = data
# data.train_mask = data.y >= 0

Contains isolated nodes: False
Contains self-loops: False
Is undirected: True
Average node degree: 34.56


In [5]:
# total nodes
data.train_mask.shape

torch.Size([23472])

In [6]:
# number of training samples
data.train_mask.sum()

tensor(1284)

In [7]:
# number of testing samples
data.test_mask.sum()
# data.test_mask = data.y >= 0

tensor(276)

In [8]:
# number of validation samples
data.val_mask.sum()

tensor(275)

In [9]:
data.y[data.train_mask].shape

torch.Size([1284])

In [10]:
data.y[data.test_mask].shape

torch.Size([276])

##### Visualization of the Model Using Tensorboard Command
commandline run tensorboard
```
cd src
tensorboard --logdir log
```

In [11]:
# build model
from GCN import GCNStack

model = GCNStack(data.num_node_features, hidden_dim1=128, hidden_dim2=64, hidden_dim3=32, output_dim=data.num_classes, dropout=0.5)
print(model)


GCNStack(
  (convs): ModuleList(
    (0): GCNConv(23472, 128)
    (1): GCNConv(128, 64)
    (2): GCNConv(64, 32)
  )
  (lns): ModuleList(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (post_mp): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=23, bias=True)
  )
)


In [12]:
# Running on GPU or CPU
use_GPU = True
device = torch.device('cuda' if torch.cuda.is_available() and use_GPU else 'cpu')
model, data = model.to(device), data.to(device)

In [13]:
device


device(type='cuda')

In [14]:
# torch.cuda.empty_cache()

### Model training

In [15]:
def model_test(loader, model, is_validation=False, is_training=False):
    ''' Testing Code of the Model '''
    model.eval()

    correct = 0
    for data in loader:
        with torch.no_grad():
            emb, pred = model(data.x, data.edge_index)
            pred = pred.argmax(dim=1)

        if is_training:
            mask = data.train_mask
        elif is_validation:
            mask = data.val_mask
        else: # testing
            mask = data.test_mask
        # node classification: only evaluate on nodes in test set
        pred = pred[mask]
        label = data.y[mask]
        # testing code
        # training_status = 'Training' if is_training else 'Testing'
        # print(training_status, '$$pred', pred)
        # print(training_status, '%%label', label)
        correct += pred.eq(label).sum().item()
    total = 0
    for data in loader.dataset:
        if is_training:
            total += torch.sum(data.train_mask).item()
        elif is_validation:
            total += torch.sum(data.val_mask).item()
        else:
            total += torch.sum(data.test_mask).item()
    return correct / total, pred, label

def model_train(dataset, writer, model, epoch_num, lr, weight_decay, momentum):
    ''' Training code of the model '''
    test_loader = loader = DataLoader(dataset, shuffle=False)

    # Optimizer
    # opt = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # visualize the model architecture in tensorboard
    # writer.add_graph(model, ( data.x, data.edge_index ))

    # Training:
    for epoch in range(epoch_num + 1):
        total_loss = 0
        model.train()
        for batch in loader:
            #print(batch.train_mask, '----')
            opt.zero_grad()
            embedding, pred = model(batch.x, batch.edge_index)
            label = batch.y
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        writer.add_scalar("loss", total_loss, epoch)

        # if epoch % 5 == 0:
        train_acc, _, _ = model_test(test_loader, model, is_training=True)
        validation_acc, _, _= model_test(test_loader, model, is_training=False, is_validation=True)
        print("Epoch {}. Loss: {:.4f}. Train accuracy: {:.4f}. Validation accuracy: {:.4f}".format(
            epoch, total_loss, train_acc, validation_acc))
        writer.add_scalar("validation accuracy", validation_acc, epoch)

        if epoch % 20 == 0:
            name = 'epoch' + str(epoch)
            writer.add_embedding(embedding, global_step=epoch, tag=name, metadata=batch.y)

    return model

from datetime import datetime
from tensorboardX import SummaryWriter

writer = SummaryWriter("./log/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

model_trained = model_train([data], writer, model, epoch_num=80, lr=0.0001, weight_decay=0.00001, momentum=0.9)

Epoch 0. Loss: 3.1088. Train accuracy: 0.7679. Validation accuracy: 0.7200
Epoch 1. Loss: 2.8451. Train accuracy: 0.8294. Validation accuracy: 0.8036
Epoch 2. Loss: 2.6472. Train accuracy: 0.8442. Validation accuracy: 0.8327
Epoch 3. Loss: 2.5435. Train accuracy: 0.8497. Validation accuracy: 0.8364
Epoch 4. Loss: 2.4562. Train accuracy: 0.8536. Validation accuracy: 0.8400
Epoch 5. Loss: 2.4181. Train accuracy: 0.8567. Validation accuracy: 0.8400
Epoch 6. Loss: 2.3826. Train accuracy: 0.8583. Validation accuracy: 0.8400
Epoch 7. Loss: 2.3652. Train accuracy: 0.8614. Validation accuracy: 0.8436
Epoch 8. Loss: 2.3265. Train accuracy: 0.8621. Validation accuracy: 0.8436
Epoch 9. Loss: 2.3005. Train accuracy: 0.8645. Validation accuracy: 0.8473
Epoch 10. Loss: 2.2810. Train accuracy: 0.8668. Validation accuracy: 0.8473
Epoch 11. Loss: 2.2707. Train accuracy: 0.8676. Validation accuracy: 0.8473
Epoch 12. Loss: 2.2506. Train accuracy: 0.8684. Validation accuracy: 0.8473
Epoch 13. Loss: 2.2539

### Model Evaluation

In [16]:
test_acc, pred, label = model_test( DataLoader([data], shuffle=False), model_trained, is_training=False, is_validation=False)

In [17]:
test_acc

0.8840579710144928

In [18]:
pred_np = pred.cpu().detach().numpy()
label_np = label.cpu().detach().numpy()

In [19]:
from sklearn.metrics import f1_score,  recall_score, precision_score

In [20]:
f1_score(pred_np, label_np, average='weighted')

0.9384615384615385

In [21]:
precision_score(pred_np, label_np, average='weighted')

1.0

In [22]:
recall_score(pred_np, label_np, average='weighted')

0.8840579710144928

### Save model

In [23]:
# save model
# torch.save(model_trained, f='pretrained_model.pth')

### Precision @ k plot

In [67]:
# change k here
num_of_k = 8000

def compute_top_k(loader, model):
    """ Testing Code of the Model """
    model.eval()

    for data in loader:
        with torch.no_grad():
            emb, pred = model(data.x, data.edge_index)
            prob = F.softmax(pred, dim=1).max(dim=1)[0]
            pred = pred.argmax(dim=1)
            # compute top k with the highest probability
            val, idx = torch.topk(prob, k=num_of_k, dim=0)

    return  pred, val, idx

pred, val, idx = compute_top_k( DataLoader([data], shuffle=False), model_trained)



In [68]:
positive_position = read_data.get_autism_position()

In [69]:
idx = idx.tolist()
positive_position = positive_position.tolist()

In [70]:
intersect = set(idx).intersection(positive_position)

In [71]:
# using the paper formula to calculate
len(intersect) / num_of_k

0.0175

In [72]:
# Coordinates:
(1000, 0.029), (2000, 0.0185) , (3000, 0.017), (4000, 0.01825) , (5000, 0.0168),  (6000, 0.017167), (7000, 0.017114), (8000, 0.0175)

((1000, 0.029),
 (2000, 0.0185),
 (3000, 0.017),
 (4000, 0.01825),
 (5000, 0.0168),
 (6000, 0.017167),
 (7000, 0.017114),
 (8000, 0.0175))