In [None]:

import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import torch.nn.functional as F

from bert import preprocessing, generate_node_embeddings

import json
from tqdm import tqdm
import pandas as pd
import time

from load_pubmed import get_pubmed_casestudy
from main_pubmed_gnn import GCN

# Notebook-level mini-batch size for the BERT passes.
# NOTE(review): keep DataLoader batch sizes in sync with this value — confirm.
BATCH_SIZE = 8


def train(lm, gnn, g, loader, optimizer_lm, optimizer_gnn, device):
    """Run one epoch of two-stage training: first the GNN, then the language model.

    Stage 1 embeds every sample with ``lm`` (without building an autograd
    graph) and keeps the first-token vector of the last hidden state as the
    node feature.
    Stage 2 takes one optimisation step of ``gnn`` on those features and
    records the gradient of the GNN loss with respect to every node feature.
    Stage 3 replays ``loader`` and back-propagates the recorded gradient rows
    through ``lm``, stepping ``optimizer_lm`` after every batch.

    Args:
        lm: model called as ``lm(input_ids, token_type_ids=None,
            attention_mask=..., output_hidden_states=True)`` whose output
            exposes ``['hidden_states']``.
        gnn: model called as ``gnn(X, g.edge_index)``; its output is passed to
            ``F.nll_loss``.
        g: graph object providing ``to(device)``, ``edge_index``, ``y`` and
            ``train_mask``. ``g.x`` is neither read nor written here, so a
            caller evaluating on ``g.x`` sees the features it supplied itself.
        loader: iterable of ``(input_ids, attention_mask)`` batches. It must
            yield the same samples in the same order on every pass, because
            gradients are matched to samples by position.
        optimizer_lm: optimizer over the parameters of ``lm``.
        optimizer_gnn: optimizer over the parameters of ``gnn``.
        device: device on which all tensors and the graph are placed.

    Returns:
        float: the GNN loss computed in stage 2 (before the GNN step).
    """
    # ---- Stage 1: node features from the language model (no graph kept) ----
    node_embs = []
    with torch.no_grad():
        for batch in tqdm(loader):
            b_input_ids, b_input_mask = tuple(t.to(device) for t in batch)
            output = lm(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        output_hidden_states=True)
            # Last hidden state, first token of every sequence.
            cls_token_emb = output['hidden_states'][-1][:, 0]
            node_embs.append(cls_token_emb.cpu())
    node_embs = torch.cat(node_embs, dim=0)
    torch.cuda.empty_cache()

    # ---- Stage 2: one GNN step; keep d(loss)/d(node feature) ----
    # X is a leaf tensor, so its .grad is populated by backward().
    X = node_embs.to(device).requires_grad_(True)
    g = g.to(device)
    gnn.train()
    optimizer_gnn.zero_grad()
    out = gnn(X, g.edge_index)[g.train_mask]
    gnn_loss = F.nll_loss(out, g.y[g.train_mask])
    gnn_loss.backward()
    optimizer_gnn.step()
    node_grads = X.grad.detach()
    torch.cuda.empty_cache()

    # ---- Stage 3: push the recorded gradients through the language model ----
    lm.train()
    offset = 0  # index of the first sample of the current batch
    for batch in loader:
        b_input_ids, b_input_mask = tuple(t.to(device) for t in batch)
        optimizer_lm.zero_grad()
        output = lm(b_input_ids,
                    token_type_ids=None,
                    attention_mask=b_input_mask,
                    output_hidden_states=True)
        cls_token_emb = output['hidden_states'][-1][:, 0]
        # Use the actual batch length (the last batch may be short) instead of
        # a global batch-size constant, so rows always line up with samples.
        n_rows = cls_token_emb.size(0)
        # Chain rule: seed backward with dL/d(embedding) from the GNN stage so
        # the gradient really reaches the parameters of `lm`.
        cls_token_emb.backward(gradient=node_grads[offset:offset + n_rows])
        offset += n_rows
        optimizer_lm.step()
    torch.cuda.empty_cache()

    return gnn_loss.item()


@torch.no_grad()
def test(model, data):
    model.eval()

    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=-1)
    correct = pred.eq(data.y)

    train_acc = correct[data.train_mask].sum().item() / \
        data.train_mask.sum().item()
    val_acc = correct[data.val_mask].sum().item() / data.val_mask.sum().item()
    test_acc = correct[data.test_mask].sum().item() / \
        data.test_mask.sum().item()

    return train_acc, val_acc, test_acc


In [None]:
# Load the raw PubMed records
print("[!] Loading dataset")
with open('pubmed.json') as f:  # context manager guarantees the handle is closed
    pubmed = json.load(f)
df_pubmed = pd.DataFrame.from_dict(pubmed)

# Preprocess: one "Title / Abstract" string per record, then tokenize it
print("[!] Preprocessing")
start = time.time()
AB = df_pubmed['AB'].fillna("")
TI = df_pubmed['TI'].fillna("")
text = ['Title: ' + ti + '\n' + 'Abstract: ' + ab for ti, ab in zip(TI, AB)]
token_id = []
attention_masks = []
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased', do_lower_case=True)
for sample in tqdm(text):
    encoding_dict = preprocessing(sample, tokenizer)
    token_id.append(encoding_dict['input_ids'])
    attention_masks.append(encoding_dict['attention_mask'])
token_id = torch.cat(token_id, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
print("Time: ", time.time()-start)

# Prepare DataLoader
# Use the notebook-level BATCH_SIZE so the loader agrees with code that
# indexes per-sample tensors by batch position.
batch_size = BATCH_SIZE
dataset = TensorDataset(token_id, attention_masks)
# Sequential, unshuffled iteration: sample i of the dataset is always the
# i-th row produced by a full pass over the loader.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    sampler=SequentialSampler(dataset),
    batch_size=batch_size
)

# Load the pre-trained BERT encoder (BertModel) with hidden states exposed
bert = BertModel.from_pretrained(
    'bert-base-uncased',
    output_attentions=False,
    output_hidden_states=True,
)


In [None]:

# Run on the GPU when one is available, otherwise fall back to the CPU
print("[!] Generating node embeddings")
start = time.time()
# Pick the device first, then move the model there; an unconditional
# bert.cuda() would raise on a CPU-only machine before the fallback applies.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert.to(device)
features = generate_node_embeddings(bert, dataloader, device)
print("Time: ", time.time()-start)


In [None]:
# Load the graph and replace its node features with the BERT embeddings
data, data_pubid = get_pubmed_casestudy()
data.x = features
# Input width follows the embedding size; out_channels=3 is presumably the
# number of target classes — TODO confirm against data.y.
gnn_model = GCN(
    in_channels=data.x.shape[1], hidden_channels=128, out_channels=3, num_layers=4, dropout=0)
# Place the GNN on the same device as the language model instead of
# hard-coding CUDA, so the CPU fallback keeps working.
gnn_model.to(device)


In [None]:

print("[!] Start training")
start = time.time()
# Follow the selected device rather than assuming CUDA is present.
data = data.to(device)
optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=0.001)
# NOTE(review): lr=0.001 is applied to every BERT parameter; fine-tuning
# usually uses a much smaller rate — confirm this value is intended.
optimizer_lm = torch.optim.Adam(bert.parameters(), lr=0.001)
for epoch in range(1, 1000):
    loss = train(bert, gnn_model, data, dataloader,
                 optimizer_lm, optimizer_gnn, device)
    # Accuracies are computed by test() from data.x as currently stored.
    accs = test(gnn_model, data)
    print(
        f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train Acc: {accs[0]:.4f}, Val Acc: {accs[1]:.4f}, Test Acc: {accs[2]:.4f}')
print("Time: ", time.time()-start)


In [None]:

# import torch
# from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
# from transformers import BertTokenizer, BertModel
# import torch.nn as nn
# import torch.nn.functional as F

# from bert import preprocessing, generate_node_embeddings

# import json
# from tqdm import tqdm
# import pandas as pd
# import time

# from load_pubmed import get_pubmed_casestudy
# from main_pubmed_gnn import GCN

# data, data_pubid = get_pubmed_casestudy()


# a = torch.rand(data.x.shape[0], 768).detach()
# x = a.cuda()
# x.requires_grad = True
# x.retain_grad()

# data.cuda()
# gnn_model = GCN(
#     in_channels=x.shape[1], hidden_channels=128, out_channels=3, num_layers=4, dropout=0)
# gnn_model.cuda()
# optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001)
# out = gnn_model(x, data.edge_index)[data.train_mask]
# loss = F.nll_loss(out, data.y[data.train_mask])
# loss.backward()
# optimizer.step()
# print(x.grad.shape)

# grad = x.grad.clone()
# grad.requires_grad = True
# grad.sum().backward()


In [2]:
# Reload the dataset for inspection.
# NOTE(review): this rebinds `data`; if get_pubmed_casestudy() returns a fresh
# object (presumably it does), the earlier `data.x = features` assignment does
# not carry over to it — confirm before running cells out of order.
from load_pubmed import get_pubmed_casestudy
data, data_pubid = get_pubmed_casestudy()

In [5]:
# Display the Data object's summary (per-attribute tensor shapes).
data


Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717], train_id=[11830], val_id=[3943], test_id=[3944])