In [None]:
import os
import pyarrow.parquet as pq
import pyarrow as pa
import numpy as np
from graphstorm.gconstruct.file_io import read_data_parquet, write_data_hdf5

In [None]:
from transformers import BertTokenizer
from transformers import BertModel, BertConfig
import torch as th

Utility functions that tokenize text data and compute BERT embeddings from the resulting tokens.

In [None]:
def compute_tokens(strs, tokenizer, max_seq_len):
    """Tokenize a sequence of strings into fixed-length tensors.

    Every string is truncated/padded to exactly ``max_seq_len`` tokens, so all
    three returned tensors have shape ``(len(strs), max_seq_len)``.

    Parameters
    ----------
    strs : iterable of str
        Texts to tokenize (e.g. a column read from a parquet file).
    tokenizer : callable
        A Hugging Face-style tokenizer such as ``BertTokenizer``.
    max_seq_len : int
        Number of tokens per row after truncation/padding.

    Returns
    -------
    tuple of torch.Tensor
        ``(input_ids, attention_mask, token_type_ids)``. If the tokenizer does
        not produce ``token_type_ids``, an all-zero tensor of the same shape as
        ``input_ids`` is returned in its place.
    """
    texts = list(strs)
    if not texts:
        # th.cat on an empty list would raise; return correctly shaped empty tensors.
        empty = th.zeros((0, max_seq_len), dtype=th.long)
        return empty, empty.clone(), empty.clone()
    # One batched call instead of one call per string. padding='max_length'
    # pads every row independently, so the rows match per-string tokenization.
    encoded = tokenizer(texts, max_length=max_seq_len, truncation=True,
                        padding='max_length', return_tensors='pt')
    input_ids = encoded['input_ids']
    att_masks = encoded['attention_mask']
    type_ids = encoded.get('token_type_ids')
    if type_ids is None:
        # Some tokenizers (e.g. DistilBERT's) omit segment ids; all zeros is
        # what BERT uses when no segment ids are given.
        type_ids = th.zeros_like(input_ids)
    return input_ids, att_masks, type_ids
    
def compute_bert_embed(tokens, att_masks, type_ids, lm_model, device, bert_batch_size):
    """Embed pre-tokenized text with ``lm_model`` in mini-batches.

    Parameters
    ----------
    tokens, att_masks, type_ids : torch.Tensor
        Row-aligned token ids, attention masks and token type ids, e.g. the
        tuple returned by ``compute_tokens``.
    lm_model : torch.nn.Module
        Model whose output exposes ``pooler_output`` (e.g. ``BertModel``).
        As a side effect it is switched to eval mode and moved to ``device``.
    device : str or torch.device
        Device on which the forward passes run.
    bert_batch_size : int
        Maximum number of rows per forward pass.

    Returns
    -------
    numpy.ndarray
        Pooler outputs, one row per input row, in input order.
    """
    lm_model.eval()
    lm_model = lm_model.to(device)
    pooled_chunks = []
    with th.no_grad():
        batches = zip(th.split(tokens, bert_batch_size),
                      th.split(att_masks, bert_batch_size),
                      th.split(type_ids, bert_batch_size))
        for ids_batch, mask_batch, type_batch in batches:
            batch_out = lm_model(ids_batch.to(device),
                                 attention_mask=mask_batch.to(device),
                                 token_type_ids=type_batch.to(device))
            # Collect results in host memory rather than accumulating them on the device.
            pooled_chunks.append(batch_out.pooler_output.cpu())
        stacked = th.cat(pooled_chunks)
    return stacked.numpy()

Process the paper nodes: compute BERT embeddings of each paper's title, one parquet chunk at a time.

In [None]:
def _worker_device():
    """Return the torch device string the current process should run BERT on.

    Returns ``"cpu"`` when CUDA is unavailable. Otherwise the first entry of
    ``CUDA_VISIBLE_DEVICES`` is used as the GPU index, but only if it is a valid
    ordinal in this process. CUDA renumbers the visible devices from 0, so for
    example ``CUDA_VISIBLE_DEVICES=3`` may expose a single device, ``cuda:0``.
    Unset or non-numeric values (such as GPU UUIDs) also map to ``cuda:0``
    instead of raising.
    """
    # NOTE(review): presumably multiprocessing_data_read pins each worker to a
    # GPU via CUDA_VISIBLE_DEVICES — confirm against graphstorm's implementation.
    if not th.cuda.is_available():
        return "cpu"
    first_visible = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')[0].strip()
    gpu = int(first_visible) if first_visible.isdigit() else 0
    if gpu >= th.cuda.device_count():
        # A physical id that is out of range for the renumbered devices.
        gpu = 0
    return f"cuda:{gpu}"


def process_papers(i, max_seq_len=128, bert_model="bert-base-uncased", bert_batch_size=1024):
    """Compute BERT embeddings of paper titles for one parquet chunk.

    Reads ``mag_papers_{i}.parquet``, tokenizes its ``title`` column, embeds the
    titles with the pooler output of ``bert_model`` and writes
    ``mag_papers_bert_{i}.hdf5`` with the keys ``paper``, ``feat`` (the
    embeddings) and ``year``.

    Parameters
    ----------
    i : int
        Chunk index; selects the input and output file names.
    max_seq_len : int
        Titles are truncated/padded to this many tokens.
    bert_model : str
        Model name or path passed to ``from_pretrained``.
    bert_batch_size : int
        Number of titles embedded per forward pass.

    Returns
    -------
    None
        Results are written to disk only.
    """
    print(f'process file {i}')
    papers = read_data_parquet(f'mag_papers_{i}.parquet')
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    config = BertConfig.from_pretrained(bert_model)
    lm_model = BertModel.from_pretrained(bert_model, config=config)

    input_ids, att_masks, type_ids = compute_tokens(papers['title'], tokenizer, max_seq_len)
    device = _worker_device()
    embeds = compute_bert_embed(input_ids, att_masks, type_ids, lm_model, device, bert_batch_size)
    res = {'paper': papers['paper'], 'feat': embeds, 'year': papers['year']}
    write_data_hdf5(res, f'mag_papers_bert_{i}.hdf5')
    return None

In [None]:
# Embed paper chunks 0-50 using a pool of 8 worker processes. process_papers
# writes its results to disk and returns None, so `data` carries no embeddings.
from graphstorm.gconstruct.utils import multiprocessing_data_read
data = multiprocessing_data_read(list(range(51)),
                                 num_processes=8,
                                 user_parser=process_papers)

Process the field-of-study (fos) nodes: compute BERT embeddings of their `id` strings.

In [None]:
def process_fos(max_seq_len=16, bert_model="bert-base-uncased", bert_batch_size=1024):
    """Compute BERT embeddings for the fos nodes and write them to HDF5.

    Reads ``mag_fos.parquet``, embeds the strings in its ``id`` column (the
    same column is used as both node id and text to embed) with the pooler
    output of ``bert_model``, and writes ``mag_fos_bert.hdf5`` with the keys
    ``id`` and ``feat``.

    Parameters
    ----------
    max_seq_len : int
        Strings are truncated/padded to this many tokens.
    bert_model : str
        Model name or path passed to ``from_pretrained``.
    bert_batch_size : int
        Number of strings embedded per forward pass.

    Returns
    -------
    None
        Results are written to disk only.
    """
    fos = read_data_parquet('mag_fos.parquet')
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    config = BertConfig.from_pretrained(bert_model)
    lm_model = BertModel.from_pretrained(bert_model, config=config)

    input_ids, att_masks, type_ids = compute_tokens(fos['id'], tokenizer, max_seq_len)
    # Fall back to CPU so this cell still runs on machines without a GPU.
    device = "cuda:0" if th.cuda.is_available() else "cpu"
    embeds = compute_bert_embed(input_ids, att_masks, type_ids, lm_model, device, bert_batch_size)
    res = {'id': fos['id'], 'feat': embeds}
    write_data_hdf5(res, 'mag_fos_bert.hdf5')
    return None

In [None]:
# Embed the fos `id` strings in this kernel (no worker pool) and write mag_fos_bert.hdf5.
process_fos()