In [1]:
# CARDIO:DE data structure:
# data
# └── cardiode
#     ├── tsv
#     |   ├── CARDIODE100_heldout
#     |   │   └── *.tsv [x100]
#     |   └── CARDIODE400_main
#     |       └── *.tsv [x400]
#     └── txt
#         ├── CARDIODE100_heldout
#         │   └── *.txt [x100]
#         └── CARDIODE400_main 
#             └── *.txt [x400]
#
# Info:
# - CARDIO:DE has no nested entities, so there is no need to remove them.
# - The maximum length of a passage is 2195 characters; Llama2 can handle 4096 tokens (@Suteera: is this correct?), so this is fine.
# - As agreed the following nomenclature applies:
#    documents = dataset
#    document  = dataset[index]
#    entities  = dataset[index]["entities"]
#    entity    = dataset[index]["entities"][index2] 
#      e.g. [{'id': '1-49-1', 'type': 'ACTIVEING', 'text': ['Salbutamol'], 'offsets': [[6055, 6065]], 'normalized': []}]
#    passages  = dataset[index]["passages"]
#    passage   = dataset[index]["passages"][index3] 
#      e.g. [{'id': '1-1', 'type': 'sentence', 'text': ['Some text'], 'offsets': [[0, 8]]},
#    samples   = transformed data to a valid input structure suitable for Suteera's training script.
#      e.g. {'sample_id': 49, 'doc_id': '1.tsv', 'doc': 'Some text.', 'prompt': None, 'label': [('Ezetimib', 'ACTIVEING')]}
#      Note: none of the functions defined here add the prompt to the samples; this needs to be done externally.

In [2]:
from datasets import load_dataset, Dataset, disable_caching
from scipy.special import rel_entr, kl_div
from collections import Counter
import numpy as np
import random

In [3]:
%%capture --no-stderr
def loadDataCardioDE() -> dict:
    # Loads CARDIO:DE from the local "./data/cardiode" directory via the "bigbio/cardiode" loading script.
    # load_dataset(...) prints warnings here because the loading is not optimal; printing them costs
    # execution time, which is presumably why the output of this cell is captured.
    # Only the "train" split is available, so that split is returned directly.
    return load_dataset(
        "bigbio/cardiode",
        data_dir = "./data/cardiode",
        split = "train",
        trust_remote_code = True
    )

In [4]:
# Transforms the loaded CARDIO:DE dataset into samples which are a valid input structure suitable for
# training Llama2. Every sample is one whole document with the texts of all its passages concatenated.
#
# dataset: the dataset loaded with load_dataset; any iterable of documents works. Every document must
#     provide "document_id", "passages" (each with a "text" list) and "entities" (each with a "text"
#     list and a "type").
#
# return: a list of samples, one per document, in the order of the dataset. "sample_id" is the position
#     of the document, "prompt" is always None and "label" is a list of (entity text, entity type) tuples.
#
def transformToSingleDocumentsCardioDE(dataset: dict) -> list:
    samples = []

    for index, document in enumerate(dataset):
        # Every passage text is followed by one space, so a non-empty document text ends with a space.
        doc_text = "".join(passage["text"][0] + " " for passage in document["passages"])

        # All entities of the document belong to the sample. Only the first text fragment of an
        # entity is used.
        labels = [(entity["text"][0], entity["type"]) for entity in document["entities"]]

        samples.append({
            "sample_id" : index,
            "doc_id" : document["document_id"],
            "doc" : doc_text,
            "prompt" : None,
            "label" : labels
        })

    return samples

In [5]:
# Transforms the loaded CARDIO:DE dataset into samples which are a valid input structure suitable for
# training Llama2. Every sample is one single passage.
#
# An entity is assigned to a passage when its first offset span lies completely inside the first offset
# span of the passage. The borders are inclusive, so entities which start exactly at the beginning of a
# passage or end exactly at its end are kept.
#
# dataset: the dataset loaded with load_dataset. Every document must provide "document_id", "passages"
#     (each with "text" and "offsets") and "entities" (each with "text", "type" and "offsets").
#
# return: a list of samples, each contains one passage. "sample_id" is a running index over all passages
#     of all documents, "prompt" is always None and "label" is a list of (entity text, entity type) tuples.
#
def transformToSinglePassagesCardioDE(dataset: dict) -> list:
    samples = []

    for document in dataset:
        # The entities are the same for every passage of the document, so they are unpacked only once.
        # Only the first offset span and the first text fragment of an entity are used.
        entity_spans = [
            (entity["offsets"][0][0], entity["offsets"][0][1], entity["text"][0], entity["type"])
            for entity in document["entities"]
        ]

        for passage in document["passages"]:
            # Offsets of the current passage.
            passage_start, passage_end = passage["offsets"][0]

            # Keep the entities which are fully contained in the current passage (inclusive borders).
            labels = [
                (entity_text, entity_type)
                for entity_start, entity_end, entity_text, entity_type in entity_spans
                if passage_start <= entity_start and entity_end <= passage_end
            ]

            samples.append({
                # Each transformed passage gets its own index.
                "sample_id": len(samples),
                "doc_id":    document["document_id"],
                "doc":       passage["text"][0],
                "prompt":    None,
                "label":     labels
            })

    return samples

In [6]:
# Returns a dictionary whose keys are all entity names which can be found in CARDIO:DE.
# The function is mostly used as a starting point for counting, therefore every value is zero.
# The keys are sorted alphabetically and every call returns a new dictionary.
#
# return: a dictionary containing all entity names as keys and zero as values.
#
def getEntityNamesCardioDE() -> dict:
    entity_names = ["ACTIVEING", "DRUG", "STRENGTH", "FREQUENCY", "DURATION", "FORM"]
    return dict.fromkeys(sorted(entity_names), 0)

In [7]:
# Counts how often each entity name occurs in the labels of the given sample. Entity names which do
# not occur in the sample keep a count of zero.
#
# sample: a sample whose "label" entry is a list of (entity text, entity type) tuples. An entity type
#     which is not returned by "getEntityNamesCardioDE" raises a KeyError.
#
# return: a dictionary with the entity names as key and their counts as values. Nomenclature: sample_count.
#
def getEntityCountsCardioDE(sample) -> dict:
    counts = getEntityNamesCardioDE()

    # The entity type is the second element of every label tuple.
    entity_types = [pair[1] for pair in sample["label"]]
    for entity_type in entity_types:
        counts[entity_type] += 1

    return counts

In [8]:
# Converts the number of entities into a probability distribution which can be used to calculate the
# KL divergence. The given dictionary is not modified, a copy is returned.
# A sample which does not contain any entity results in an unchanged copy, i.e. all entries stay zero.
# Example:
#   {"ACTIVEING" : 1    , "DRUG" :      3    , "STRENGTH" :  0, "FREQUENCY" : 1    , "DURATION" :  2   , "FORM" :      1}
#   will be converted into:
#   {"ACTIVEING" : 0.125, "DRUG" :      0.375, "STRENGTH" :  0, "FREQUENCY" : 0.125, "DURATION" :  0.25, "FORM" :      0.125}
#
# sample_count: a dictionary with the entity names as keys and their counts as values.
#
# return: a dictionary containing the probability distribution of the entities in the given sample.
#     Nomenclature: sample_distribution
#
def toDistributionCardioDE(sample_count: dict) -> dict:
    distribution = sample_count.copy()

    # Total amount of entities found in the sample.
    total = sum(distribution.values())

    # Without any entity there is nothing to normalise.
    if not total > 0:
        return distribution

    for entity, count in sample_count.items():
        distribution[entity] = count / total

    return distribution

In [9]:
# Calculates the average probability distribution of entities in CARDIO:DE. The counts of all samples
# are summed up per entity name and the summed counts are then converted into a distribution.
#
# sample_counts: a list/array of "sample_count" which you can get with the function "getEntityCountsCardioDE".
#
# return: a dictionary with the average probability distribution of entities in CARDIO:DE.
#
def getAverageDistributionCardioDE(sample_counts: list) -> dict:
    totals = getEntityNamesCardioDE()

    # Accumulate the counts of every sample per entity name.
    for sample_count in sample_counts:
        for entity, count in sample_count.items():
            totals[entity] += count

    # Normalise the summed counts to a probability distribution.
    return toDistributionCardioDE(totals)

In [10]:
# Returns the KL divergence of a sample distribution w.r.t. the average distribution.
# The values of both dictionaries are compared by position, so both must list the entity names in the
# same order. kl_div is evaluated element-wise with the average values as first and the sample values
# as second argument. Infinite terms are dropped before the remaining terms are summed up.
#
# sample_distribution: the entity distribution of the sample.
# average_distribution: the average distribution of all samples.
#
# return: the divergence as float number. Infinity is returned when no finite term is left, e.g. for a
#   sample without entities when every entry of the average distribution is greater than zero.
#
def getKLDivergenceCardioDE(sample_distribution: dict, average_distribution: dict) -> float:
    average_values = list(average_distribution.values())
    sample_values = list(sample_distribution.values())

    infinity = float('inf')
    finite_terms = [term for term in kl_div(average_values, sample_values) if term != infinity]

    return sum(finite_terms) if finite_terms else infinity

In [11]:
# Calculates the divergence of every sample w.r.t. the average distribution of all given samples and
# sorts the samples by it in ascending order.
# Note: the divergence is stored in the given samples under the key "divergence", i.e. the passed
# samples are modified.
#
# samples: the samples to sort, each with a "label" entry.
#
# return: a new list with the samples sorted by their divergence, the most average sample first.
#
def sortSamplesByDivergenceCardioDE(samples) -> list:
    # Entity counts of every sample, in the order of the samples.
    sample_counts = [getEntityCountsCardioDE(sample) for sample in samples]

    average_distribution = getAverageDistributionCardioDE(sample_counts)

    # Divergence of every sample from the average distribution.
    sample_divergences = [
        getKLDivergenceCardioDE(toDistributionCardioDE(sample_count), average_distribution)
        for sample_count in sample_counts
    ]

    # Attach the divergence to its sample.
    for sample, divergence in zip(samples, sample_divergences):
        sample["divergence"] = divergence

    return sorted(samples, key = lambda sample: sample["divergence"])

In [12]:
# Returns the requested amount of samples, ranked by their divergence from the average sample.
# Interpretation of "amount":
#   amount < 0       : all samples are returned.
#   0 <= amount < 1  : treated as a fraction, round(amount * len(samples)) samples are returned.
#   amount >= 1      : treated as an absolute number of samples. It is truncated to an integer, so
#                      float values such as 10.0 are accepted as well. Note that exactly 1 means one
#                      sample, not 100 percent.
#
# samples: the whole set of samples. They are passed to "sortSamplesByDivergenceCardioDE", which stores
#     the divergence in every sample.
# amount: the amount of samples requested.
#
# return: a sorted list with the requested amount of samples which have the most similar distribution as the
#     average sample.
#
def getRankedSamplesCardioDE(samples: list, amount: float = -1) -> list:
    if amount < 0:
        # A negative amount requests the whole set.
        amount = len(samples)
    elif amount < 1:
        # Also percentages are acceptable.
        amount = round(amount * len(samples))
    else:
        # A slice border must be an integer, a float such as 10.0 would raise a TypeError.
        amount = int(amount)

    # Get the sorted samples.
    sorted_samples = sortSamplesByDivergenceCardioDE(samples)

    return sorted_samples[:amount]

In [13]:
# Loads the dataset, transforms every document into one sample, shuffles the samples based on the seed
# and splits them into train, test and dev. Only the train split is ranked by the divergence from the
# average train sample and reduced to the requested amount; test and dev are returned as shuffled.
# Note: the global state of the "random" module is seeded with "seed".
#
# amount: the amount of train samples requested. Look at "getRankedSamplesCardioDE" for more information.
# seed: the seed used to shuffle the samples.
# train_split: the fraction of all samples used for the train split.
# test_split: the fraction of all samples used for the test split. The dev split gets all samples
#     which are left over after train and test.
#
# return: a dictionary with the keys "train", "test" and "dev", each holding a list of samples.
#
def cardioDE(amount: float = -1, seed: int = 42, train_split: float = 0.7, test_split: float = 0.15) -> dict:
    documents = transformToSingleDocumentsCardioDE(loadDataCardioDE())

    # Randomize the data.
    random.seed(seed)
    random.shuffle(documents)

    # Borders of the splits inside the shuffled list.
    train_end = round(len(documents) * train_split)
    test_end = train_end + round(len(documents) * test_split)

    return {
        "train" : getRankedSamplesCardioDE(documents[:train_end], amount = amount),
        "test" : documents[train_end:test_end],
        "dev" : documents[test_end:]
    }

In [14]:
# An example: load everything with seed 17 and keep all train samples (amount = -1).
# The result is a dictionary with the keys "train", "test" and "dev".
samples = cardioDE(amount = -1, seed = 17)