In [1]:
from datasets import load_from_disk
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the saved MedDialog DatasetDict; the path is resolved against the notebook's working directory.
meddialog_path = (Path("..") / "datasets" / "meddialog").resolve()
meddialog = load_from_disk(str(meddialog_path))
meddialog

DatasetDict({
    train: Dataset({
        features: ['description', 'utterances'],
        num_rows: 482
    })
    validation: Dataset({
        features: ['description', 'utterances'],
        num_rows: 60
    })
    test: Dataset({
        features: ['description', 'utterances'],
        num_rows: 61
    })
})

In [3]:
# Schema: a free-text description plus a sequence of {speaker, utterance} turns.
train_features = meddialog["train"].features
train_features


{'description': Value(dtype='string', id=None),
 'utterances': Sequence(feature={'speaker': ClassLabel(names=['patient', 'doctor'], id=None), 'utterance': Value(dtype='string', id=None)}, length=-1, id=None)}

In [13]:
# Peek at one raw training row to see how the nested utterances are laid out.
meddialog["train"][0]

{'description': 'throat a bit sore and want to get a good imune booster, especially in light of the virus. please advise. have not been in contact with nyone with the virus.',
 'utterances': {'speaker': [0, 1],
  'utterance': ['throat a bit sore and want to get a good imune booster, especially in light of the virus. please advise. have not been in contact with nyone with the virus.',
   "during this pandemic. throat pain can be from a strep throat infection (antibiotics needed), a cold or influenza or other virus, or from some other cause such as allergies or irritants. usually, a person sees the doctor (call first) if the sore throat is bothersome, recurrent, or doesn't go away quickly. covid-19 infections tend to have cough, whereas strep throat usually lacks cough but has more throat pain. (3/21/20)"]}}

In [14]:
def is_valid_dialog(example):
    """Return True if a MedDialog row is usable for building (dialogue, label) pairs.

    A row is kept only when:
      * ``example["utterances"]`` is a dict containing an ``"utterance"`` list,
      * that list holds at least two turns, and
      * ``example["description"]`` is a string with non-whitespace content.

    Args:
        example: one dataset row (mapping), as passed by ``Dataset.filter``.

    Returns:
        A real ``bool`` -- never the description string itself -- so the filter
        mask is unambiguous.
    """
    utterances = example.get("utterances")
    if not isinstance(utterances, dict):
        return False

    turns = utterances.get("utterance")
    if not isinstance(turns, list) or len(turns) < 2:
        return False

    description = example.get("description")
    # A blank description is useless as a label, so treat it like a missing one.
    return isinstance(description, str) and bool(description.strip())

# Drop train rows that fail the validity predicate and report how many survive.
filtered = meddialog["train"].filter(function=is_valid_dialog)
print(f"Filtered rows: {len(filtered)}")

Filtered rows: 482


In [15]:
def flatten_utterances(example):
    """Turn one dialogue row into a ``{"dialogue_text", "label"}`` record.

    Args:
        example: a row whose ``example["utterances"]["utterance"]`` is the list of
            turn strings and whose ``example["description"]`` is the label text.

    Returns:
        dict with
          * ``"dialogue_text"``: the turns, each stripped of surrounding
            whitespace, joined by a single space. Empty, whitespace-only and
            ``None`` turns are skipped, so the join never yields doubled spaces
            or raises ``TypeError``.
          * ``"label"``: ``example["description"]`` unchanged.
    """
    stripped_turns = (turn.strip() for turn in example["utterances"]["utterance"] if turn)
    # NOTE(review): in the first training example displayed earlier, the description
    # is the patient's first utterance verbatim, so the label also appears inside
    # dialogue_text. Confirm this is intended for the downstream task.
    return {
        "dialogue_text": " ".join(turn for turn in stripped_turns if turn),
        "label": example["description"],
    }

In [16]:
# Reuse `filtered` from the validity-filter cell instead of re-running the same filter,
# then replace the raw columns with one flattened text field and its label.
formatted = filtered.map(flatten_utterances, remove_columns=filtered.column_names)
assert set(formatted.column_names) == {"dialogue_text", "label"}, formatted.column_names
formatted

Map: 100%|██████████| 482/482 [00:00<00:00, 5881.81 examples/s]


Dataset({
    features: ['dialogue_text', 'label'],
    num_rows: 482
})

In [21]:
# Export the cleaned records as JSON Lines (one {"dialogue_text", "label"} object per line).
# NOTE(review): only the train split is exported; the validation and test splits are not.
cleaned_path = Path("../data/cleaned_meddialog.json")
# Create ../data up front so the write doesn't fail on a fresh checkout without that folder.
cleaned_path.parent.mkdir(parents=True, exist_ok=True)
formatted.to_json(str(cleaned_path), orient="records", lines=True)


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 124.17ba/s]


371522