In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import yaml
import xopen
import json
import warnings
warnings.filterwarnings("ignore")
from os.path import dirname
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
from data import infer_preprocess
from model import Classifier, BERTDiseaseClassifier
from utils import default_symps
from nltk.tokenize import sent_tokenize
import blingfire
import spacy

In [3]:
def cut_sentences(text, tokenizer, nlp):
    if tokenizer == 'blingfire':
        sents = blingfire.text_to_sentences(text.strip()).split("\n")
    if tokenizer == 'nltk':
        sents = sent_tokenize(text.strip())
    if tokenizer == 'spacysm':
        doc = nlp(text)
        sents = [sent.text.strip() for sent in doc.sents]
    if tokenizer == 'spacylg':
        doc = nlp(text)
        sents = [sent.text.strip() for sent in doc.sents]
    if tokenizer == 'spacytrf':
        doc = nlp(text)
        sents = [sent.text.strip() for sent in doc.sents]
    return sents

In [10]:
datastore = []
# options : blingfire, nltk, spacysm, spacylg, spacytrf
senttokenizer = 'spacysm'
# set spacy tokenizer
if senttokenizer == 'spacysm':
    nlp = spacy.load("en_core_web_sm")
if senttokenizer == 'spacylg':
    nlp = spacy.load("en_core_web_lg")
if senttokenizer == 'spacytrf':
    nlp = spacy.load("en_core_web_trf")    

symps = ["Anxious_Mood","Autonomic_symptoms","Cardiovascular_symptoms","Catatonic_behavior","Decreased_energy_tiredness_fatigue","Depressed_Mood","Gastrointestinal_symptoms","Genitourinary_symptoms","Hyperactivity_agitation","Impulsivity","Inattention","Indecisiveness","Respiratory_symptoms","Suicidal_ideas","Worthlessness_and_guilty","avoidance_of_stimuli","compensatory_behaviors_to_prevent_weight_gain","compulsions","diminished_emotional_expression","do_things_easily_get_painful_consequences","drastical_shift_in_mood_and_energy","fear_about_social_situations","fear_of_gaining_weight","fears_of_being_negatively_evaluated","flight_of_ideas","intrusion_symptoms","loss_of_interest_or_motivation","more_talktive","obsession","panic_fear","pessimism","poor_memory","sleep_disturbance","somatic_muscle","somatic_symptoms_others","somatic_symptoms_sensory","weight_and_appetite_change","Anger_Irritability"]

if __name__ == "__main__":
    batch_size = 64
    input_dir = "../../../data/postdatalines.json"
    ckpt_dir = "lightning_logs/version_0/checkpoints/epoch=0-step=720.ckpt"
    hparams_dir = os.path.join(dirname(dirname(ckpt_dir)), 'hparams.yaml')
    hparams = yaml.load(open(hparams_dir),Loader=yaml.Loader)
    max_len = hparams["max_len"]
    tokenizer = AutoTokenizer.from_pretrained(hparams["model_type"])
    clf = Classifier.load_from_checkpoint(ckpt_dir, symps=default_symps)
    clf.eval()
    clf.cuda()

    allPostSentences = []
    
    with xopen.xopen(input_dir) as fi:
        for i, line in tqdm(enumerate(fi)):
            record = json.loads(line)
            user_sents = []
            sent_bounds = [0]
            curr_sid = 0
            if record['text'] == None:
                break
            else:
                for post in record["text"]:
                    sents = cut_sentences(post, senttokenizer, nlp)
                    curr_sid += len(sents)
                    sent_bounds.append(curr_sid)
                    user_sents.extend(sents)
                all_probs = []
                all_feats = []
                for i in range(0, len(user_sents), batch_size):
                    curr_texts = user_sents[i:i+batch_size]
                    processed_batch = infer_preprocess(curr_texts, tokenizer, max_len)
                    for k, v in processed_batch.items():
                        processed_batch[k] = v.cuda()
                    with torch.no_grad():
                        feats, logits = clf.feat_extract_avg(processed_batch)
                        feats = feats.detach().cpu().numpy()
                        probs = logits.sigmoid().detach().cpu().numpy()
                    sentData = {"sentence": curr_texts, "probabilities": dict(zip(symps, probs))}
                    allPostSentences.append(sentData)
                
                
                

# df = pd.read_json("../../../data/postdatalines.json", lines=True)
# vector_df = pd.DataFrame(datastore)
# df = pd.concat([df,vector_df],axis=1)
# df.to_json(f"../../../data/vectorData/Test{senttokenizer}Vectors.json",lines=True, orient='records')

Some weights of BertModel were not initialized from the model checkpoint at mental/mental-bert-base-uncased and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


"accumulation":       1
"bal_sample":         True
"bs":                 64
"control_ratio":      0.5
"exp_name":           mbert_label_enhance_bal_sample_050_666
"gradient_clip_val":  0.1
"input_dir":          ../data/symp_data_w_control
"loss_mask":          True
"loss_type":          bce
"loss_weighting":     mean
"lr":                 0.0003
"max_len":            64
"model_type":         mental/mental-bert-base-uncased
"patience":           4
"pos_weight_setting": default
"seed":               666
"threshold":          0.5
"uncertain":          exclude
"write_result_dir":   ./lightning_logs/bal_sample_records.json


797it [04:17,  3.10it/s]


In [15]:
i = 0
print(len(allPostSentences[i]['sentence']))
print(allPostSentences[i]['probabilities']["Anxious_Mood"])

64
[7.9607414e-03 3.0320024e-04 1.0407658e-04 9.0106642e-03 1.0153279e-02
 3.4104589e-02 2.4430244e-04 1.5734080e-03 7.5825855e-02 2.9496774e-02
 4.8576433e-02 4.5706523e-03 3.1745611e-04 5.6496076e-03 8.3765751e-03
 2.6174184e-02 4.9148109e-03 6.5465295e-04 2.6951968e-03 4.7332618e-02
 7.9219081e-03 8.9806043e-02 2.1092191e-03 1.6790779e-02 1.7521903e-01
 5.5154562e-03 1.7779199e-02 4.0076017e-01 7.7681732e-04 5.2708038e-04
 4.8485263e-03 4.0203603e-03 9.6228626e-03 3.5640804e-04 3.9608913e-04
 3.9182062e-04 2.8982684e-03 1.0988766e-02]
