In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

# from src import Match, Icsr
from src.utils import get_matches
from questions import *

from datetime import datetime
import random
import datasets
from collections import defaultdict
import numpy as np
import tiktoken
from copy import deepcopy
import pandas as pd
import json

# import pandas as pd
# import matplotlib.pyplot as plt
# from sklearn.model_selection import train_test_split

import logging

# Set the logging level to INFO
logging.basicConfig(level=logging.INFO)

# Get the root logger
logger = logging.getLogger()

In [3]:
# load matches
dataset = datasets.load_dataset("FAERS-PubMed/raw_dataset")
matches = get_matches(dataset['train'])
print(len(matches))



  0%|          | 0/1 [00:00<?, ?it/s]

65648


In [4]:
# process full text
def remove_front(text):
    if '==== Body' in text:
        text = ('\n').join(text.split('==== Body')[1:])
    return text.strip()

def remove_refs(text):
    if '==== Refs' in text:
        text = ('\n').join(text.split('==== Refs')[:-1])
    return text.strip() 

def get_processed_fulltext(article):
    fulltext_filtered = remove_refs(remove_front(article.fulltext))
    data = [article.title,article.abstract, fulltext_filtered]
    return ('\n').join(data).strip()

## Filter

In [5]:
# arguments
report_cutoff = 10
fulltext_only = True
commercial_only = False
test_cutoff = datetime(year=2021, month=1, day=1)

In [6]:
# filter too many reports
matches = [m for m in matches if len(m.reports) <= report_cutoff]
print(f'Matches with <= {report_cutoff} reports: {len(matches):,}')

Matches with <= 10 reports: 62,168


In [7]:
# get articles with full text
if fulltext_only:
    matches = [m for m in matches if m.article.fulltext]
    print(f'Matches with full text: {len(matches):,}')

Matches with full text: 18,678


## Questions

In [8]:
questions = [
    WeightQuestion(),
    DrugsGivenReactionQuestion(),
    DrugAdministrationRouteQuestion(),
    DrugDosageQuestion(),
    DrugDosageTextQuestion(),
    ReactionOutcomeQuestion(),
    ReactionsQuestion(),
    PatientAgeGivenReactionQuestion(),
    PatientAgeGivenDrugQuestion()
]
conversation_flow = None

# questions specifically focussed on icsr
# questions = [
#     PatientWeightAndSexQuestion(),
#     DrugsQuestion(),
#     ReactionsQuestion()
# ]
# conversation_flow = [q.t for q in questions]



In [9]:
bad_answers = ['Unknown', 'UNK']
def get_questions(report):
    instances = []
    for question in questions:
        try:
            new_q = question.from_report(report)
        except Exception as ex:
            logger.warn(f'Error with question {question.t} on report {report.safetyreportid}: {type(ex).__name__}, {ex.args}')
        instances.extend(new_q)
    return instances

def filter_questions(questions):
    filtered_questions = []
    for q in questions:
        a = q[1]
        if 'Unknown' not in a and 'UNK' not in a:
            filtered_questions.append(q)
    return filtered_questions

In [10]:
# get questions associated with a report
report_to_question = {}
report_to_article = {}
reports = {}

for match in matches:
    for report in match.reports:
        if report.safetyreportid in report_to_question:
            raise KeyError('safetyreportid already in set')
        else:
            reports.update({
                report.safetyreportid: report
            })
            report_to_question.update({
                report.safetyreportid: filter_questions(get_questions(report))
            })
            report_to_article.update({
                report.safetyreportid: match.article.pmid
            })

articles = {}

for match in matches:
    articles.update({match.article.pmid: match.article})

  logger.warn(f'Error with question {question.t} on report {report.safetyreportid}: {type(ex).__name__}, {ex.args}')


## Conversation

In [11]:
def sample_conversation(questions, seed=4):
    random.seed(seed)
    # group questions per type
    questions_per_type = defaultdict(list)
    for q in questions:
        questions_per_type[q[2]].append(q)

    # if a conversation flow is not defined, uniformy sample some types and one question per type
    if not conversation_flow:
        types = questions_per_type.keys()
        types_sampled = random.sample(types, len(types))
    # else, follow the flow
    else:
        types_sampled = conversation_flow

    conversation = []
    for q_type in types_sampled:
        conversation.append(random.choice(questions_per_type[q_type]))

    return conversation

In [12]:
def conversation_to_chatml(conversation, article):
    # set system messagee
    chatml = [{
        'role': 'system',
        'content': 'You are a helpful assistant. You read biomedical texts and concisely answer user questions about adverse drug events. You give the most specific answer supported by the text.'
    }]
    # include article in the first full text
    processed_article = get_processed_fulltext(article)
    chatml.append({
                'role': 'user',
                'content': f'Text: {processed_article}'
            })
    chatml.append({
        'role': 'assistant',
        'content': "I'm ready for your questions!"
    })
    # QA over different turns
    for q,a, _ in conversation:
        chatml.append({
            'role': 'user',
            'content': q
        })
        chatml.append({
            'role': 'assistant',
            'content': a
        })
    return chatml

In [13]:
report_ids = []
article_ids = []
conversations = []

for report_id in reports.keys():
    article_id = report_to_article[report_id]
    article = articles[article_id]
    conversation = sample_conversation(report_to_question[report_id], seed = report_id)
    chatml = conversation_to_chatml(conversation, article)

    report_ids.append(report_id)
    article_ids.append(article_id)
    conversations.append(chatml)

since Python 3.9 and will be removed in a subsequent version.
  types_sampled = random.sample(types, len(types))


In [14]:
df = pd.DataFrame(data={
    "pmid": article_ids,
    "safetyreportid": report_ids,
    "conversation": conversations
})

# pick one report per article  for now
df = df.groupby(['pmid']).first().reset_index()

## Split and save
For now, sample one report per article. Sample the exact same one as in the ICSR-Extraction dataset

In [17]:
icsr_dataset = datasets.load_dataset('BioDEX/BioDEX-ICSR')



  0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
train_df = df[df['pmid'].isin(icsr_dataset['train']['pmid'])]
validation_df = df[df['pmid'].isin(icsr_dataset['validation']['pmid'])]
test_df = df[df['pmid'].isin(icsr_dataset['test']['pmid'])]

print(len(train_df))
print(len(validation_df))
print(len(test_df))

9624
2407
3628


In [19]:
print(len(icsr_dataset['train']))
print(len(icsr_dataset['validation']))
print(len(icsr_dataset['test']))


9624
2407
3628


In [20]:
ds = datasets.DatasetDict({
    "train": datasets.Dataset.from_pandas(train_df).remove_columns(['__index_level_0__']),
    "validation": datasets.Dataset.from_pandas(validation_df).remove_columns(['__index_level_0__']),
    "test": datasets.Dataset.from_pandas(test_df).remove_columns(['__index_level_0__'])
})

In [21]:
ds.push_to_hub('BioDEX/BioDEX-Conv')



Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]



Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]



Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/641 [00:00<?, ?B/s]

In [22]:
validation_df.to_json('biodex_conv_validation.jsonl', orient='records', lines=True)
train_df.to_json('biodex_conv_train.jsonl', orient='records', lines=True)
test_df.to_json('biodex_conv_test.jsonl', orient='records', lines=True)
