In [1]:
import os
import pickle
import sys
import time

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from codes.GraphDPA import GCN_Model
from codes.utils import *

# Kept after the star import (as in the original ordering) so that nothing
# exported by codes.utils can shadow these names.
from prefetch_generator import BackgroundGenerator 
import torch
import torch.nn.functional as F

In [2]:
# ---------------------------------------------------------------------------
# Experiment setup: output directory, vocabulary / graph index, train-test
# split, model, optimizer, early stopping and data loaders.
# ---------------------------------------------------------------------------
print('Start Time: {}'.format(time.strftime('%Y-%m-%d %H:%M:%S')))
embed_dim = 11

# Seed for the train/test split. The training loop is skipped when
# `best_model_file` already exists, so an unseeded split would pair a cached
# model with a *different* random test set on re-run (possibly containing
# graphs the cached model was trained on).
split_seed = 42

# Create the checkpoint directory; exist_ok avoids the check-then-create race
# and makedirs also creates a missing parent directory.
os.makedirs('example_saved_data/model_tmp', exist_ok=True)

# NOTE(review): pickle.load can execute arbitrary code; only load these files
# if they come from a trusted source.
with open('example_saved_data/entities2id.pkl', 'rb') as file:
    entities2id = pickle.load(file)

with open('example_saved_data/graphs/graph_map.pkl', 'rb') as file:
    graph_map = pickle.load(file)

# Hold out 10% of the graph indices for evaluation.
train_indexs, test_indexs = train_test_split(
    range(len(graph_map)), test_size=0.1, random_state=split_seed
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GCN_Model(len(entities2id), embed_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# Checkpoint path; its existence is what the training cell checks to decide
# whether training has to run at all.
best_model_file = 'example_saved_data/model_tmp/{}+{}+GCNConv.pt'.format(embed_dim, 'None')
early_stopping = EarlyStopping(file=best_model_file, patience=5)

train_dataset = MyOwnDataset('example_saved_data', train_indexs, graph_map)
test_dataset = MyOwnDataset('example_saved_data', test_indexs, graph_map)
train_loader = DataLoader(train_dataset, batch_size=1000, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)
# ---------------------------------------------------------------------------
# Training with periodic evaluation and early stopping.
# Runs only when no checkpoint exists at `best_model_file`.
# ---------------------------------------------------------------------------
if not os.path.exists(best_model_file):
    for epoch in range(1000):
        model.train()

        train_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            logits = model(data)
            loss = F.cross_entropy(logits, data.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            # NOTE(review): only the first (shuffled) mini-batch is used per
            # epoch, so `train_loss` is a single-batch loss. Kept as in the
            # original; confirm this is intentional and not a debugging leftover.
            break

        # Evaluate every 10th epoch.
        if epoch % 10 == 0:
            model.eval()
            # Despite its name, `test_loss` holds the ROC AUC (higher is better).
            test_loss = 0
            # No gradients are needed for evaluation; no_grad avoids building
            # an autograd graph for the forward pass.
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(device)
                    logits = model(data)
                    # Score with the class-1 probability so the ranking takes
                    # both output columns into account. If the model already
                    # emits log-softmax values this ranks identically to
                    # logits[:, 1].
                    positive_scores = F.softmax(logits, dim=1)[:, 1]
                    test_loss = roc_auc_score(data.y.cpu().tolist(), positive_scores.cpu().tolist())
                    # NOTE(review): only the first test batch is scored; kept
                    # as in the original — confirm this is intentional.
                    break

            print('{} Epoch {}: train loss {}, test ROC {}'.format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, train_loss, test_loss))
            # The AUC is negated before being handed to EarlyStopping
            # (presumably it tracks a quantity to minimise — verify in codes.utils).
            early_stopping(-test_loss, model)
            if early_stopping.early_stop:
                print('Early Stopping')
                break

Start Time: 2021-10-17 15:11:30
2021-10-17 15:11:48 Epoch 0: train loss 0.7228423357009888, test ROC 0.6169596964019792
2021-10-17 15:13:32 Epoch 10: train loss 0.6871052980422974, test ROC 0.5704872620134186
EarlyStopping counter: 1 out of 5
2021-10-17 15:15:14 Epoch 20: train loss 0.675652027130127, test ROC 0.6707498678964308
2021-10-17 15:16:57 Epoch 30: train loss 0.6377252340316772, test ROC 0.7087916926870667
2021-10-17 15:18:42 Epoch 40: train loss 0.6212900876998901, test ROC 0.7159893356391411
2021-10-17 15:20:25 Epoch 50: train loss 0.629585862159729, test ROC 0.7304486717586588
2021-10-17 15:22:08 Epoch 60: train loss 0.6148384809494019, test ROC 0.7352044002497958
2021-10-17 15:23:51 Epoch 70: train loss 0.6126474738121033, test ROC 0.7416414468943652
2021-10-17 15:25:34 Epoch 80: train loss 0.5787592530250549, test ROC 0.7407367376022802
EarlyStopping counter: 1 out of 5
2021-10-17 15:27:16 Epoch 90: train loss 0.5815764665603638, test ROC 0.7439872860322493
2021-10-17 15