In [1]:
import argparse
import os
import torch
import numpy as np
 

import random

parser = argparse.ArgumentParser()

parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)

parser.add_argument("--dataspace_path", type=str, default="./data")
parser.add_argument("--SSL_emb_dim", type=int, default=256)
parser.add_argument("--max_seq_len", type=int, default=512)

# Parse an explicit empty argv so the notebook kernel's own sys.argv
# (e.g. `-f <connection file>`) is never interpreted as script options.
args = parser.parse_args([])

# --seed used to be parsed but never applied; seed every RNG in use so that
# Restart-&-Run-All is reproducible. torch.manual_seed also seeds CUDA.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Honour --device (the CUDA ordinal), which was previously ignored.
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

In [2]:
from transformers import AutoModel, AutoTokenizer

# Hugging Face hub id of the SciBERT checkpoint; both the tokenizer and the
# encoder are downloaded into (or reused from) <dataspace_path>/pretrained_SciBERT.
SCIBERT_MODEL_ID = 'allenai/scibert_scivocab_uncased'
pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')

text_tokenizer = AutoTokenizer.from_pretrained(SCIBERT_MODEL_ID, cache_dir=pretrained_SciBERT_folder)
text_model = AutoModel.from_pretrained(SCIBERT_MODEL_ID, cache_dir=pretrained_SciBERT_folder)
text_model = text_model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# This is for BERT
def padarray(A, size, value=0):
    """Right-pad a 1-D array with ``value`` until it holds ``size`` elements.

    Args:
        A: 1-D array-like to pad.
        size: target length. Must be >= len(A); otherwise the pad width is
            negative and ``np.pad`` raises ValueError.
        value: constant used for the padded tail (default 0).

    Returns:
        A numpy array of length ``size`` whose leading entries are ``A``.
    """
    missing = size - len(A)
    return np.pad(A, (0, missing), mode='constant', constant_values=value)

def preprocess_each_sentence(sentence, tokenizer, max_seq_len):
    """Tokenize one sentence into fixed-length token-id and mask arrays.

    Args:
        sentence: a single text string.
        tokenizer: Hugging Face-style tokenizer callable, invoked with
            truncation and ``padding='max_length'``.
        max_seq_len: length of every returned array.

    Returns:
        ``[token_ids, attention_mask]`` as two 1-D numpy arrays of length
        ``max_seq_len``.
    """
    text_input = tokenizer(
        sentence, truncation=True, max_length=max_seq_len,
        padding='max_length', return_tensors='np')

    # For a single string the tokenizer returns a batch of one row. Take row 0
    # rather than .squeeze(): squeeze drops *every* size-1 axis and would give a
    # 0-d array (breaking len() in padarray) when max_seq_len == 1.
    input_ids = text_input['input_ids'][0]
    attention_mask = text_input['attention_mask'][0]

    # padding='max_length' should already yield max_seq_len entries; padarray
    # is kept as a guard so the returned length is exactly max_seq_len.
    sentence_tokens_ids = padarray(input_ids, max_seq_len)
    sentence_masks = padarray(attention_mask, max_seq_len)
    return [sentence_tokens_ids, sentence_masks]


# This is for BERT
def prepare_text_tokens(device, description, tokenizer, max_seq_len):
    """Tokenize a batch of texts into BERT-ready tensors on ``device``.

    Args:
        device: torch device (or device string) for the returned tensors.
        description: iterable of text strings, one per batch item.
        tokenizer: tokenizer forwarded to preprocess_each_sentence.
        max_seq_len: padded/truncated length of every row.

    Returns:
        ``(tokens_ids, masks)``: a ``torch.long`` tensor and a ``torch.bool``
        tensor, both of shape ``(len(description), max_seq_len)``.
    """
    tokens_outputs = [preprocess_each_sentence(sentence, tokenizer, max_seq_len) for sentence in description]
    if not tokens_outputs:
        # Keep the 2-D (batch, seq) shape for an empty batch; the previous
        # torch.Tensor([]) path produced a 1-D tensor of shape (0,).
        empty_ids = torch.empty((0, max_seq_len), dtype=torch.long, device=device)
        return empty_ids, empty_ids.bool()
    # Stack into a single ndarray before converting: building a tensor from a
    # list of ndarrays is slow, and torch.Tensor(...) would first cast the
    # integer ids to float32 before .long() casts them back.
    tokens_ids = torch.from_numpy(np.stack([o[0] for o in tokens_outputs])).long().to(device)
    masks = torch.from_numpy(np.stack([o[1] for o in tokens_outputs])).bool().to(device)
    return tokens_ids, masks

In [16]:
# Token length used for these example descriptions.
# NOTE(review): differs from args.max_seq_len (512); confirm 500 is intended.
DESCRIPTION_MAX_LEN = 500

# Two molecule descriptions plus their " # "-joined concatenation (3 inputs total).
descriptions = [
    'The molecule is an 11-oxo steroid that is corticosterone in which the hydroxy substituent at the 11beta position has been oxidised to give the corresponding ketone. It has a role as a human metabolite and a mouse metabolite. It is a 21-hydroxy steroid, a 3-oxo-Delta(4) steroid, a 20-oxo steroid, an 11-oxo steroid, a corticosteroid and a primary alpha-hydroxy ketone. It derives from a corticosterone.',
    'The molecule is a steroid ester, a 20-oxo steroid, an acetate ester, a 17alpha-hydroxy steroid, an 11-oxo steroid, a 3-oxo-Delta(1),Delta(4)-steroid and a tertiary alpha-hydroxy ketone. It derives from a prednisone.',
]
descriptions.append(' # '.join(descriptions))

description_tokens_ids, description_masks = prepare_text_tokens(device, descriptions, text_tokenizer, DESCRIPTION_MAX_LEN)

# Inference only: disable autograd so no computation graph through SciBERT is
# kept alive in description_output / description_repr.
with torch.no_grad():
    description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)
description_repr = description_output["pooler_output"]

# Encoder width read from the model config (768 for SciBERT).
# No projection to args.SSL_emb_dim is applied here; description_repr stays text_dim wide.
text_dim = text_model.config.hidden_size

In [15]:
# Sanity check: each pooled SciBERT embedding must be text_dim wide.
assert description_repr.shape[-1] == text_dim, "unexpected text embedding width"
description_repr.shape

torch.Size([3, 768])