<a href="https://colab.research.google.com/github/MarioZZJ/data/blob/master/onclp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!unzip -q /content/drive/MyDrive/dataset.zip -d /content

In [3]:
!pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910476 sha256=3e5929ea1fc8d7b66c809811e10cfcdb8cab8a667326ebe27b6efb71d4f01eda
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geomet

In [4]:
import torch
import pandas as pd
from torch_geometric.data import InMemoryDataset,Data
from torch_geometric.nn import GCNConv
from torch_geometric import seed_everything
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling
import numpy as np
import gc

In [5]:
def make_deterministic(random_seed = 711):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    seed_everything(random_seed)

make_deterministic(711)

In [6]:
class OncologyMeSH(InMemoryDataset):
    def __init__(self, root, year,transform=None, pre_transform=None):
        self.year = year
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
    @property
    def processed_file_names(self):
        return [f'oncology{self.year}.pt']

In [7]:
class EarlyStopping:
    def __init__(self, patience=10, delta=0, path='checkpoint.pt'):
        """
        Early stopping 定义
        
        :param patience: 当验证集损失连续多少轮没有下降时停止训练
        :param delta: 验证集损失的最小变化，当变化小于 delta 时认为模型没有明显提升
        :param path: 记录模型权重的文件路径
        """
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0  # 记录验证集损失连续没有下降的轮数
        self.best_score = None  # 记录最佳验证集损失
        self.early_stop = False  # 是否停止训练
        self.val_loss_min = np.Inf  # 记录最小验证集损失

    def __call__(self, val_loss, model):
        """
        Early stopping 逻辑

        :param val_loss: 当前轮次验证集损失
        :param model: 当前轮次模型
        :return: 如果需要停止训练，返回 True；否则返回 False
        """
        if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
        elif val_loss > self.best_score - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0

        return self.early_stop

    def save_checkpoint(self, val_loss, model):
        """
        保存模型权重

        :param val_loss: 当前轮次验证集损失
        :param model: 当前轮次模型
        """
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [8]:
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
from torch_geometric.utils import negative_sampling


def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=500, patience=20, delta=0.0003, ckp_path='checkpoint.pt'
):
    early_stopping = EarlyStopping(patience=patience, delta=delta, path=ckp_path)
    for epoch in range(1, n_epochs + 1):

        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)

        # sampling training negatives for every training epoch
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse',force_undirected=True)

        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1))
        ], dim=0)

        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()

        # val_auc = eval_link_predictor(model, val_data)
        model.eval()
        with torch.no_grad():
            z = model.encode(val_data.x, val_data.edge_index)
            out = model.decode(z, val_data.edge_label_index).view(-1).sigmoid()
            val_loss = criterion(out, val_data.edge_label)
        val_auc = roc_auc_score(val_data.edge_label.cpu().numpy(), out.cpu().numpy())

        if early_stopping(val_loss, model):
            break

        # if epoch % 1 == 0:
            # print(f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, Val AUC: {val_auc:.3f}")

    return model


@torch.no_grad()
def eval_link_predictor(model, data):

    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()

    return [roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy()),
            f1_score(data.edge_label.cpu().numpy(), out.cpu().numpy() > 0.5),
            accuracy_score(data.edge_label.cpu().numpy(), out.cpu().numpy() > 0.5),
            precision_score(data.edge_label.cpu().numpy(), out.cpu().numpy() > 0.5),
            recall_score(data.edge_label.cpu().numpy(), out.cpu().numpy() > 0.5)]

In [9]:

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(
            dim=-1
        )  # product of a pair of nodes on each edge

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

In [10]:
%%time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mesh_eval = []
eyes_eval = []
for year in range(2002,2022):
    dataset = OncologyMeSH(f"dataset/Oncology{year}",year)
    graph = dataset[0]
    graph.x = graph.x.float()
    graph = graph.to(device)
    # graph.x = torch.eye(graph.x.size(0),dtype=torch.float)
    split = T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        neg_sampling_ratio=1.0,
    )
    train_data, val_data, test_data = split(graph)
    

    model = Net(dataset.num_features, 400, 64).to(device)
    # model = Net(graph.x.size(0),512,64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005)
    criterion = torch.nn.BCEWithLogitsLoss()

    model = train_link_predictor(model, train_data, val_data, optimizer, criterion, ckp_path=f'mesh{year}_checkpoint.pt',n_epochs=500, patience=20, delta=0.0001)
    model.load_state_dict(torch.load(f'mesh{year}_checkpoint.pt'))
    test_eval = eval_link_predictor(model, test_data)
    mesh_eval.append(test_eval)

    print(f"{year} | [MeSH] AUC-{test_eval[0]:.4f} F1-{test_eval[1]:.4f} ;",end='')
    del model,optimizer,criterion
    torch.cuda.empty_cache()
    # comparable 
    mesh_num_features = graph.x.size(1)
    mesh_num_samples = graph.x.size(0)
    num_eyes = int(mesh_num_features // mesh_num_samples)
    baseline_x = torch.concat([torch.eye(mesh_num_samples,dtype=torch.float) for i in range(num_eyes)],dim=1)
    if mesh_num_features % mesh_num_samples != 0:
      baseline_x = torch.concat([baseline_x, torch.zeros(mesh_num_samples, mesh_num_features % mesh_num_samples)],dim=1)
    train_data.x = baseline_x.to(device)
    val_data.x = baseline_x.to(device)
    test_data.x = baseline_x.to(device)
    model = Net(mesh_num_features,400,64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005)
    criterion = torch.nn.BCEWithLogitsLoss()
    model = train_link_predictor(model, train_data, val_data, optimizer, criterion,ckp_path=f'eyes{year}_checkpoint.pt',n_epochs=500, patience=20, delta=0.0001)
    
    model.load_state_dict(torch.load(f'eyes{year}_checkpoint.pt'))
    test_eval = eval_link_predictor(model, test_data)
    eyes_eval.append(test_eval)

    print(f" [Eyes] AUC-{test_eval[0]:.4f} F1-{test_eval[1]:.4f} .")
    
    del dataset,graph,model,optimizer,criterion,train_data, val_data, test_data
    torch.cuda.empty_cache()


2002 | [MeSH] AUC-0.9233 F1-0.7730 ; [Eyes] AUC-0.9072 F1-0.7463 .
2003 | [MeSH] AUC-0.9194 F1-0.7703 ; [Eyes] AUC-0.9036 F1-0.7474 .
2004 | [MeSH] AUC-0.9227 F1-0.7709 ; [Eyes] AUC-0.9111 F1-0.7428 .
2005 | [MeSH] AUC-0.9328 F1-0.7797 ; [Eyes] AUC-0.9151 F1-0.7356 .
2006 | [MeSH] AUC-0.9259 F1-0.7717 ; [Eyes] AUC-0.9147 F1-0.7298 .
2007 | [MeSH] AUC-0.9230 F1-0.7664 ; [Eyes] AUC-0.9147 F1-0.7321 .
2008 | [MeSH] AUC-0.9342 F1-0.7760 ; [Eyes] AUC-0.9174 F1-0.7297 .
2009 | [MeSH] AUC-0.9289 F1-0.7718 ; [Eyes] AUC-0.9225 F1-0.7273 .
2010 | [MeSH] AUC-0.9304 F1-0.7682 ; [Eyes] AUC-0.9205 F1-0.7299 .
2011 | [MeSH] AUC-0.9287 F1-0.7685 ; [Eyes] AUC-0.9224 F1-0.7264 .
2012 | [MeSH] AUC-0.9303 F1-0.7686 ; [Eyes] AUC-0.9230 F1-0.7264 .
2013 | [MeSH] AUC-0.9284 F1-0.7696 ; [Eyes] AUC-0.9213 F1-0.7307 .
2014 | [MeSH] AUC-0.9300 F1-0.7710 ; [Eyes] AUC-0.9212 F1-0.7340 .
2015 | [MeSH] AUC-0.9298 F1-0.7700 ; [Eyes] AUC-0.9199 F1-0.7355 .
2016 | [MeSH] AUC-0.9137 F1-0.7436 ; [Eyes] AUC-0.9221 F1-0.73

In [15]:
pd.DataFrame(mesh_eval,index=range(2002,2022),columns=['auc','f1','accuracy','precision','recall']).to_csv('0508_gcn_mesh_eval.csv',sep=',',index=True)
pd.DataFrame(eyes_eval,index=range(2002,2022),columns=['auc','f1','accuracy','precision','recall']).to_csv('0508_gcn_eyes_eval.csv',sep=',',index=True)