Use the trained model to tag symptom (DIS) and drug (DRUG) mentions in new sentences from the 500k Reddit dataset.

---

Environment configuration. Note that the spaCy model en_core_web_lg must be downloaded separately first.

In [7]:
from flair.models import SequenceTagger
from flair.data import Sentence 
import tqdm
import pandas as pd
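The spaCy note above can be checked before the long run starts. The following is a minimal sketch, assuming the model was installed as a Python package (which is what `python -m spacy download en_core_web_lg` does). The helper name `model_package_installed` is hypothetical, and it uses only the standard library:

```python
import importlib.util


def model_package_installed(name: str = "en_core_web_lg") -> bool:
    """Return True if a top-level package with this name is importable here."""
    # find_spec returns None for a top-level module that cannot be found
    return importlib.util.find_spec(name) is not None


if not model_package_installed():
    print("en_core_web_lg not found; install it with: python -m spacy download en_core_web_lg")
```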

In [None]:
# # included for convenience to help find correct paths
# import os
# os.getcwd()
# os.listdir("..")

Constants.

In [19]:
# top level project directory containing code, data, .gitignore, etc
MEDRED_REPRODUCIBLE_DIR = "../"
REDDIT_IN = MEDRED_REPRODUCIBLE_DIR + "data/validation/Reddit/all_sr.csv"
REDDIT_SYMPTOM_TAGGED_TOKENS_OUT = MEDRED_REPRODUCIBLE_DIR + "data/validation/Reddit/NER_Reddit_pred_dis.csv"
REDDIT_DRUG_TAGGED_TOKENS_OUT = MEDRED_REPRODUCIBLE_DIR + "data/validation/Reddit/NER_Reddit_pred_drug.csv"
EMBEDDING_MODEL_PATH = MEDRED_REPRODUCIBLE_DIR + "resources/taggers/FA_MedRed_glove_roberta/final-model.pt"
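Building the paths by string concatenation works here only because `MEDRED_REPRODUCIBLE_DIR` ends with a slash. The sketch below shows the same constants built with `pathlib`, along with a fail-fast check that the corpus and model file exist before the multi-hour prediction starts. The `check_inputs` helper is hypothetical. `open` and `pd.read_csv` both accept `Path` objects, but confirm that your installed flair version's `SequenceTagger.load` does too before swapping these in.

```python
from pathlib import Path

# same layout as the constants above; pathlib inserts separators itself
MEDRED_REPRODUCIBLE_DIR = Path("..")
REDDIT_DIR = MEDRED_REPRODUCIBLE_DIR / "data" / "validation" / "Reddit"
REDDIT_IN = REDDIT_DIR / "all_sr.csv"
REDDIT_SYMPTOM_TAGGED_TOKENS_OUT = REDDIT_DIR / "NER_Reddit_pred_dis.csv"
REDDIT_DRUG_TAGGED_TOKENS_OUT = REDDIT_DIR / "NER_Reddit_pred_drug.csv"
EMBEDDING_MODEL_PATH = (MEDRED_REPRODUCIBLE_DIR / "resources" / "taggers"
                        / "FA_MedRed_glove_roberta" / "final-model.pt")


def check_inputs(*paths: Path) -> list:
    """Return the paths that do not exist, so a missing file is caught up front."""
    return [p for p in paths if not p.exists()]


missing = check_inputs(REDDIT_IN, EMBEDDING_MODEL_PATH)
if missing:
    print("Missing inputs:", *missing, sep="\n  ")
```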

In [17]:
def predict(model_path):
	'''
	Makes predictions on the Reddit 500k corpus given a model.
	Outputs two tagged files by type:
	1. *_dis.csv for symptoms
	2. *_drug.csv for drugs
	'''

	def csv_line(row, span):
		# one output line: subreddit, post_index (the post ID), the matched span text
		# (newlines/tabs flattened, quotes doubled so the quoted field stays valid CSV),
		# the tagger's confidence score, and the span's start/end offsets
		text = span.text.replace('\n', ' ').replace('\t', ' ').replace('"', '""')
		return (str(row['subreddit']) + ',' + str(row['post_index']) + ',"' + text + '",'
			+ str(span.score) + ',' + str(span.start_position) + ',' + str(span.end_position) + '\n')

	# load the model you trained
	model = SequenceTagger.load(model_path)
	# load data; cols: (unnamed index), year, month, subreddit, body, clean_body, post_index
	data = pd.read_csv(REDDIT_IN)
	# process and write...
	with open(REDDIT_DRUG_TAGGED_TOKENS_OUT, 'w', encoding="utf-8") as f_drug:
		with open(REDDIT_SYMPTOM_TAGGED_TOKENS_OUT, 'w', encoding="utf-8") as f_dis:
			header = "subreddit,post_index,matched,score,start_pos,end_pos\n"
			f_dis.write(header)
			f_drug.write(header)
			# iterate across rows (with the tqdm progress bar)
			for _, row in tqdm.tqdm(data.iterrows(), total=data.shape[0]):
				sentence = Sentence(str(row['body']))
				# predict tags; the predictions are stored on the sentence object
				model.predict(sentence)
				# write each predicted entity to the file for its type
				# note for future adaptation: el.text is the span text, not the sentence
				# (for the sentence, use el.sentence.text)
				for el in sentence.get_spans("ner"):
					if el.tag == 'DIS':
						f_dis.write(csv_line(row, el))
					elif el.tag == 'DRUG':
						f_drug.write(csv_line(row, el))
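Hand-built lines like the ones above are valid CSV only if every double quote inside the matched text is escaped. An alternative is the standard-library `csv` module, which quotes and escapes fields itself. To keep the sketch independent of flair, the spans are stand-ins: a hypothetical `FakeSpan` namedtuple carrying the same attribute names the loop reads. With the real tagger you would pass `sentence.get_spans("ner")` instead.

```python
import csv
import io
from collections import namedtuple

# hypothetical stand-in for a flair Span: only the attributes the loop above reads
FakeSpan = namedtuple("FakeSpan", "text tag score start_position end_position")

HEADER = ["subreddit", "post_index", "matched", "score", "start_pos", "end_pos"]


def write_spans(row, spans, dis_writer, drug_writer):
    """Route each span to the writer for its tag; csv.writer handles quoting."""
    writers = {"DIS": dis_writer, "DRUG": drug_writer}
    for span in spans:
        writer = writers.get(span.tag)
        if writer is None:
            continue  # ignore any other tag
        # flatten newlines and tabs so each span stays on one physical line
        text = span.text.replace("\n", " ").replace("\t", " ")
        writer.writerow([row["subreddit"], row["post_index"], text,
                         span.score, span.start_position, span.end_position])


# usage with in-memory files and made-up example values
dis_buf, drug_buf = io.StringIO(), io.StringIO()
dis_writer, drug_writer = csv.writer(dis_buf), csv.writer(drug_buf)
dis_writer.writerow(HEADER)
drug_writer.writerow(HEADER)
row = {"subreddit": "example_sub", "post_index": "42"}
spans = [FakeSpan('a "bad" headache', "DIS", 0.97, 10, 26),
         FakeSpan("ibuprofen", "DRUG", 0.99, 35, 44)]
write_spans(row, spans, dis_writer, drug_writer)
print(dis_buf.getvalue())
```

When writing to real files instead of `io.StringIO`, open them with `newline=""`, as the `csv` module documentation recommends. This prevents blank lines between rows on Windows.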

Run predictions. This takes about 6 h 40 min on a single RTX 2080 Ti GPU with an i7-8700K CPU.
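That runtime is consistent with the throughput the progress bar below reports (about 20.49 rows per second, or roughly 49 ms per post) over the corpus of 496,958 rows:

$$t \approx \frac{N}{r} = \frac{496{,}958\ \text{rows}}{20.49\ \text{rows/s}} \approx 24{,}254\ \text{s} \approx 6.74\ \text{h} \approx 6\ \text{h}\ 44\ \text{min}$$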

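The run ends by raising a StopIteration error; _do not panic!_ Both output files are opened in `with` blocks, so they are closed (and flushed) when the error propagates, and every row tagged up to that point is kept. Note, however, that the progress bar below stops at 493,041 of 496,958 rows, so the last few thousand posts may be missing from the output.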
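Suppose the error comes from individual rows, for example a post whose body produces an empty `Sentence` (this notebook does not confirm that). In that case, one option is to catch the error per row, so the loop continues and the failing row indices are recorded. Below is a minimal sketch with a hypothetical `tag_rows` helper. `predict_fn` stands in for the `model.predict` call and is not part of flair's API, and `toy_predict` is a made-up predictor that fails on blank text.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def tag_rows(rows: Iterable[Tuple[int, str]],
             predict_fn: Callable[[str], list]) -> Tuple[Dict[int, list], List[int]]:
    """Run predict_fn on each (index, text) pair; skip rows that raise StopIteration."""
    results, failed = {}, []
    for idx, text in rows:
        try:
            results[idx] = predict_fn(text)
        except StopIteration:
            # record the row and keep going instead of aborting the whole run
            failed.append(idx)
    return results, failed


def toy_predict(text: str) -> list:
    # hypothetical predictor: next() on an empty iterator raises StopIteration for blank text
    return [next(iter(text.split()))]


results, failed = tag_rows([(0, "my head hurts"), (1, "   "), (2, "took ibuprofen")], toy_predict)
print(results, failed)
```

In the notebook, the `try`/`except` would wrap `model.predict(sentence)`, and the failed indices could be written to a log. If empty bodies turn out to be the cause, a cheaper guard is to skip sentences with `len(sentence) == 0` before predicting.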

In [18]:
predict(EMBEDDING_MODEL_PATH)

2022-05-07 15:58:31,468 loading file ../resources/taggers/FA_MedRed_glove_roberta/final-model.pt
2022-05-07 15:58:33,785 SequenceTagger predicts: Dictionary with 11 tags: O, S-DIS, B-DIS, E-DIS, I-DIS, S-DRUG, B-DRUG, E-DRUG, I-DRUG, <START>, <STOP>


 99%|█████████▉| 493041/496958 [6:41:06<03:11, 20.49it/s]   


StopIteration: 